Integrating SlowMo into fairseq and adding SlowMo to fbcode
Summary:
This diff contains the following changes (a brief usage sketch follows the list):
* Adding the SlowMo algorithm to fbcode (this is the latest implementation, complete with reduced memory usage for slow momentum, a faster forward pass, and linting, among other things)
* Integrating the SlowMo algorithm into fairseq (this includes the code changes needed for the integration as well as command-line arguments for SlowMo)
* Scripts for calling SlowMo
* Adding log-dir alongside save-dir so that different directories can be used for logging and saving
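
Usage sketch (not part of the commit): the snippet below shows how the new flags might be passed through fairseq's CLI parsing once this diff is applied. The data path, architecture, and flag values are placeholders, and the parser entry points (options.get_training_parser / options.parse_args_and_arch) are assumed from the fairseq version this commit targets.

from fairseq import options

# Hypothetical invocation: everything except the flag names added in
# fairseq/options.py below is a placeholder.
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser, input_args=[
    'data-bin/wmt16_en_de',              # placeholder data directory
    '--arch', 'transformer_wmt_en_de',   # placeholder architecture
    '--distributed-wrapper', 'SlowMo',   # select the SlowMo wrapper instead of DDP
    '--slowmo-algorithm', 'LocalSGD',    # LocalSGD or SGP
    '--localsgd-frequency', '3',         # allreduce every 3 local steps
])

# --slowmo-momentum is left unset here; DistributedFairseqModel picks a default
# based on the distributed world size (see the diff below).
print(args.distributed_wrapper, args.slowmo_algorithm, args.slowmo_momentum)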

Reviewed By: myleott, mikerabbat

Differential Revision: D19184997

fbshipit-source-id: b42b298ac5297fb83a3335fa7ce262c8f48fb2bc
vtantia authored and facebook-github-bot committed Apr 21, 2020
1 parent 91f7cf6 commit 0dac0ff
Showing 3 changed files with 87 additions and 22 deletions.
41 changes: 39 additions & 2 deletions fairseq/models/distributed_fairseq_model.py
@@ -11,6 +11,13 @@
from fairseq.models import BaseFairseqModel


_GOSSIP_DISABLED = False
try:
    import gossip
except ImportError:
    _GOSSIP_DISABLED = True


def DistributedFairseqModel(args, model, process_group=None):
    """
    Wrap a *model* to support distributed data parallel training.
@@ -26,7 +33,7 @@ def DistributedFairseqModel(args, model, process_group=None):
    """
    # determine which DDP class to extend
    assert isinstance(model, nn.Module)
    if args.ddp_backend == 'c10d':
    if args.distributed_wrapper == 'DDP' and args.ddp_backend == 'c10d':
        ddp_class = nn.parallel.DistributedDataParallel
        init_kwargs = dict(
            module=model,
@@ -41,14 +48,44 @@ def DistributedFairseqModel(args, model, process_group=None):
            init_kwargs['check_reduction'] = True
        if 'find_unused_parameters' in inspect.getargspec(ddp_class)[0]:
            init_kwargs['find_unused_parameters'] = args.find_unused_parameters
    elif args.ddp_backend == 'no_c10d':
    elif args.distributed_wrapper == 'DDP' and args.ddp_backend == 'no_c10d':
        ddp_class = LegacyDistributedDataParallel
        init_kwargs = dict(
            module=model,
            world_size=args.distributed_world_size,
            buffer_size=2**28,
            process_group=process_group,
        )
    elif args.distributed_wrapper == 'SlowMo':
        if _GOSSIP_DISABLED:
            raise ImportError(
                'Cannot find gossip library. Please install from: '
                'github.com/facebookresearch/stochastic_gradient_push'
            )
        ddp_class = gossip.GossipDataParallel

        # The values of slowmo_momentum below were obtained by tuning on the
        # En-De 16 dataset by training the transformer_wmt_en_de_large model
        if args.slowmo_momentum is None:
            if args.distributed_world_size <= 16:
                args.slowmo_momentum = 0.0
            elif args.distributed_world_size <= 32:
                args.slowmo_momentum = 0.2
            elif args.distributed_world_size <= 64:
                args.slowmo_momentum = 0.5
            else:
                args.slowmo_momentum = 0.6

        init_kwargs = dict(
            module=model,
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            nprocs_per_node=args.nprocs_per_node,
            slowmo_momentum=args.slowmo_momentum,
            localsgd=(args.slowmo_algorithm == 'LocalSGD'),
            localsgd_frequency=args.localsgd_frequency
        )
    else:
        raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)

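The world-size-dependent defaults above boil down to a small lookup; the following self-contained sketch simply restates that tiering (an illustration, not fairseq code):

# Illustrative restatement of the slowmo_momentum defaults chosen above,
# tuned on En-De 16 with a large transformer per the code comment.
def default_slowmo_momentum(world_size: int) -> float:
    if world_size <= 16:
        return 0.0
    elif world_size <= 32:
        return 0.2
    elif world_size <= 64:
        return 0.5
    return 0.6

for n in (8, 16, 32, 64, 128):
    print(n, default_slowmo_momentum(n))
# 8 -> 0.0, 16 -> 0.0, 32 -> 0.2, 64 -> 0.5, 128 -> 0.6
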
17 changes: 17 additions & 0 deletions fairseq/options.py
@@ -389,6 +389,23 @@ def add_distributed_training_args(parser):
    group.add_argument('--broadcast-buffers', default=False, action='store_true',
                       help='Copy non-trainable parameters between GPUs, such as '
                            'batchnorm population statistics')

    group.add_argument('--distributed-wrapper', default='DDP', type=str,
                       choices=['DDP', 'SlowMo'],
                       help='DistributedDataParallel backend')
    # Add arguments for SlowMo - these will be used when SlowMo is enabled via above
    group.add_argument('--slowmo-momentum', default=None, type=float,
                       help='SlowMo momentum term; by default use 0.0 for 16 GPUs, '
                            '0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs')
    group.add_argument('--slowmo-algorithm', default='LocalSGD', choices=['LocalSGD', 'SGP'],
                       help='whether to use LocalSGD or SGP')
    group.add_argument('--localsgd-frequency', default=3, type=int,
                       help='Local SGD allreduce frequency')
    group.add_argument('--nprocs-per-node', type=int, metavar='N',
                       default=max(1, torch.cuda.device_count()),
                       help='number of GPUs in each node. An allreduce operation across GPUs in '
                            'a node is very fast. Hence, we do allreduce across GPUs in a node, '
                            'and gossip across different nodes')
    # fmt: on
    return group

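To see how the new options behave in isolation, the standalone sketch below mirrors them with plain argparse (argument names and defaults copied from the hunk above; this is not fairseq's parser and only needs PyTorch for the GPU-count default):

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument('--distributed-wrapper', default='DDP', choices=['DDP', 'SlowMo'])
parser.add_argument('--slowmo-momentum', default=None, type=float)
parser.add_argument('--slowmo-algorithm', default='LocalSGD', choices=['LocalSGD', 'SGP'])
parser.add_argument('--localsgd-frequency', default=3, type=int)
parser.add_argument('--nprocs-per-node', type=int,
                    default=max(1, torch.cuda.device_count()))  # resolves to 1 on a CPU-only box

args = parser.parse_args(['--distributed-wrapper', 'SlowMo', '--slowmo-momentum', '0.5'])
print(args.distributed_wrapper, args.slowmo_momentum,
      args.localsgd_frequency, args.nprocs_per_node)

The --nprocs-per-node default means intra-node allreduce covers every visible GPU, while gossip is reserved for communication across nodes.
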
51 changes: 31 additions & 20 deletions fairseq/trainer.py
@@ -412,6 +412,7 @@ def maybe_no_sync():
                logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
            )

        overflow = False
        try:
            # multiply gradients by (# GPUs / sample_size) since DDP
            # already normalizes by the number of GPUs. Thus we get
@@ -429,29 +430,11 @@ def maybe_no_sync():
            grad_norm = self.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if not self.args.use_bmuf:
            if not self.args.use_bmuf and self.args.distributed_wrapper != 'SlowMo':
                self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (
                self.args.empty_cache_freq > 0
                and (
                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
                    % self.args.empty_cache_freq
                ) == 0
                and torch.cuda.is_available()
                and not self.args.cpu
            ):
                torch.cuda.empty_cache()
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print out where it fails
            with NanDetector(self.model):
@@ -461,15 +444,43 @@
                )
            raise
        except OverflowError as e:
            overflow = True
            logger.info("NOTE: overflow detected, " + str(e))
            grad_norm = torch.tensor(0.).cuda()
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
        if hasattr(self.model, 'perform_additional_optimizer_actions'):
            if hasattr(self.optimizer, 'fp32_params'):
                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
            else:
                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)

        if not overflow or self.args.distributed_wrapper == 'SlowMo':
            self.set_num_updates(self.get_num_updates() + 1)

        # log stats
        logging_output = self._reduce_and_log_stats(
            logging_outputs, sample_size, grad_norm,
        )

        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
            torch.cuda.empty_cache()

        if self.args.fp16:
            metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)
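
To make the perform_additional_optimizer_actions hook concrete, here is a hypothetical, heavily simplified wrapper showing the contract the trainer relies on: the method is duck-typed (hence the hasattr checks), and the FP32 master parameters are forwarded only when the FP16 optimizer exposes them. This is an illustration of the hook's shape, not the gossip.GossipDataParallel implementation.

import torch
import torch.nn as nn

class ToySlowMoWrapper(nn.Module):
    """Hypothetical wrapper illustrating the hook called by fairseq's trainer."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def perform_additional_optimizer_actions(self, optimizer, fp32_params=None):
        # A real SlowMo wrapper would run its periodic averaging and slow-momentum
        # step here, after the inner optimizer has applied the fast update. With
        # FP16 training, the FP32 master copy is passed in so the wrapper can keep
        # it consistent with any parameter changes it makes.
        del optimizer, fp32_params  # no-op in this toy sketch

# Usage mirroring the trainer logic above:
model = ToySlowMoWrapper(nn.Linear(4, 4))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
if hasattr(model, 'perform_additional_optimizer_actions'):
    model.perform_additional_optimizer_actions(opt)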
