mirror of https://github.com/hpcaitech/ColossalAI
[doc] Fix gradient accumulation doc. (#4349)
* [doc] fix gradient accumulation doc
* [doc] fix gradient accumulation doc
pull/4382/head
parent
38b792aab2
commit
f40b718959
|
@@ -103,10 +103,12 @@ for idx, (img, label) in enumerate(train_dataloader):
|
|||
with sync_context:
|
||||
output = model(img)
|
||||
train_loss = criterion(output, label)
|
||||
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||
booster.backward(train_loss, optimizer)
|
||||
else:
|
||||
output = model(img)
|
||||
train_loss = criterion(output, label)
|
||||
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||
booster.backward(train_loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
|
|
@@ -106,10 +106,12 @@ for idx, (img, label) in enumerate(train_dataloader):
|
|||
with sync_context:
|
||||
output = model(img)
|
||||
train_loss = criterion(output, label)
|
||||
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||
booster.backward(train_loss, optimizer)
|
||||
else:
|
||||
output = model(img)
|
||||
train_loss = criterion(output, label)
|
||||
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||
booster.backward(train_loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
|
Loading…
Reference in New Issue