Skip to content

Commit

Permalink
Fix gradient scaling to account for world_size normalization (#2172)
Browse files Browse the repository at this point in the history
  • Loading branch information
mirceamironenco authored Jan 9, 2025
1 parent cce8ef6 commit e420bc0
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 9 deletions.
7 changes: 5 additions & 2 deletions recipes/full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,9 @@ def train(self) -> None:
if self._optimizer_in_bwd:
torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss / num_tokens

# We multiply by world_size to undo FSDP2 gradient normalization.
current_loss = current_loss * (world_size / num_tokens)

current_loss.backward()

Expand All @@ -778,7 +780,8 @@ def train(self) -> None:
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand Down
4 changes: 2 additions & 2 deletions recipes/knowledge_distillation_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,6 @@ def save_checkpoint(self, epoch: int) -> None:
def _loss_step(
self, batch: Dict[str, torch.Tensor]
) -> (torch.Tensor, torch.Tensor):

# Both are shape [b, s]
tokens, labels = batch["tokens"], batch["labels"]

Expand Down Expand Up @@ -875,7 +874,8 @@ def train(self) -> None:
torch.distributed.all_reduce(running_class_loss)
torch.distributed.all_reduce(running_kd_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
class_loss_to_log = running_class_loss.item() / num_tokens
kd_loss_to_log = running_kd_loss.item() / num_tokens
self._optimizer.step()
Expand Down
3 changes: 2 additions & 1 deletion recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,8 @@ def train(self) -> None:
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand Down
3 changes: 2 additions & 1 deletion recipes/lora_finetune_distributed_multi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,8 @@ def train(self) -> None:
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand Down
7 changes: 5 additions & 2 deletions recipes/qat_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,9 @@ def train(self) -> None:
if self._optimizer_in_bwd:
torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss / num_tokens

# We multiply by world_size to undo FSDP2 gradient normalization.
current_loss = current_loss * (world_size / num_tokens)

current_loss.backward()

Expand All @@ -849,7 +851,8 @@ def train(self) -> None:
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand Down
3 changes: 2 additions & 1 deletion recipes/qat_lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,8 @@ def train(self) -> None:
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand Down

0 comments on commit e420bc0

Please sign in to comment.