[doc] Fix gradient accumulation doc. (#4349)

* [doc] fix gradient accumulation doc

* [doc] fix gradient accumulation doc
pull/4382/head
flybird1111 2023-08-04 17:24:35 +08:00 committed by GitHub
parent 38b792aab2
commit f40b718959
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 0 deletions

View File

@ -103,10 +103,12 @@ for idx, (img, label) in enumerate(train_dataloader):
with sync_context:
output = model(img)
train_loss = criterion(output, label)
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
else:
output = model(img)
train_loss = criterion(output, label)
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
optimizer.step()
optimizer.zero_grad()

View File

@ -106,10 +106,12 @@ for idx, (img, label) in enumerate(train_dataloader):
with sync_context:
output = model(img)
train_loss = criterion(output, label)
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
else:
output = model(img)
train_loss = criterion(output, label)
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
optimizer.step()
optimizer.zero_grad()