diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py
index a81596c3de..dd74bf1f13 100644
--- a/fairseq/models/distributed_fairseq_model.py
+++ b/fairseq/models/distributed_fairseq_model.py
@@ -11,6 +11,13 @@
 from fairseq.models import BaseFairseqModel
 
 
+_GOSSIP_DISABLED = False
+try:
+    import gossip
+except ImportError:
+    _GOSSIP_DISABLED = True
+
+
 def DistributedFairseqModel(args, model, process_group=None):
     """
     Wrap a *model* to support distributed data parallel training.
@@ -26,7 +33,7 @@ def DistributedFairseqModel(args, model, process_group=None):
     """
     # determine which DDP class to extend
     assert isinstance(model, nn.Module)
-    if args.ddp_backend == 'c10d':
+    if args.distributed_wrapper == 'DDP' and args.ddp_backend == 'c10d':
         ddp_class = nn.parallel.DistributedDataParallel
         init_kwargs = dict(
             module=model,
@@ -41,7 +48,7 @@ def DistributedFairseqModel(args, model, process_group=None):
             init_kwargs['check_reduction'] = True
         if 'find_unused_parameters' in inspect.getargspec(ddp_class)[0]:
             init_kwargs['find_unused_parameters'] = args.find_unused_parameters
-    elif args.ddp_backend == 'no_c10d':
+    elif args.distributed_wrapper == 'DDP' and args.ddp_backend == 'no_c10d':
         ddp_class = LegacyDistributedDataParallel
         init_kwargs = dict(
             module=model,
@@ -49,6 +56,36 @@ def DistributedFairseqModel(args, model, process_group=None):
             buffer_size=2**28,
             process_group=process_group,
         )
+    elif args.distributed_wrapper == 'SlowMo':
+        if _GOSSIP_DISABLED:
+            raise ImportError(
+                'Cannot find gossip library. Please install from: '
+                'github.com/facebookresearch/stochastic_gradient_push'
+            )
+        ddp_class = gossip.GossipDataParallel
+
+        # The values of slowmo_momentum below were obtained by tuning on the
+        # En-De 16 dataset by training the transformer_wmt_en_de_large model
+        if args.slowmo_momentum is None:
+            if args.distributed_world_size <= 16:
+                args.slowmo_momentum = 0.0
+            elif args.distributed_world_size <= 32:
+                args.slowmo_momentum = 0.2
+            elif args.distributed_world_size <= 64:
+                args.slowmo_momentum = 0.5
+            else:
+                args.slowmo_momentum = 0.6
+
+        init_kwargs = dict(
+            module=model,
+            device_ids=[args.device_id],
+            output_device=args.device_id,
+            broadcast_buffers=args.broadcast_buffers,
+            nprocs_per_node=args.nprocs_per_node,
+            slowmo_momentum=args.slowmo_momentum,
+            localsgd=(args.slowmo_algorithm == 'LocalSGD'),
+            localsgd_frequency=args.localsgd_frequency
+        )
     else:
         raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)
 
diff --git a/fairseq/options.py b/fairseq/options.py
index f53df9779e..b54e6bd761 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -389,6 +389,23 @@ def add_distributed_training_args(parser):
     group.add_argument('--broadcast-buffers', default=False, action='store_true',
                        help='Copy non-trainable parameters between GPUs, such as '
                             'batchnorm population statistics')
+
+    group.add_argument('--distributed-wrapper', default='DDP', type=str,
+                       choices=['DDP', 'SlowMo'],
+                       help='DistributedDataParallel backend')
+    # Add arguments for SlowMo - these will be used when SlowMo is enabled via above
+    group.add_argument('--slowmo-momentum', default=None, type=float,
+                       help='SlowMo momentum term; by default use 0.0 for 16 GPUs, '
+                            '0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs')
+    group.add_argument('--slowmo-algorithm', default='LocalSGD', choices=['LocalSGD', 'SGP'],
+                       help='whether to use LocalSGD or SGP')
+    group.add_argument('--localsgd-frequency', default=3, type=int,
+                       help='Local SGD allreduce frequency')
+    group.add_argument('--nprocs-per-node', type=int, metavar='N',
+                       default=max(1, torch.cuda.device_count()),
+                       help='number of GPUs in each node. An allreduce operation across GPUs in '
+                            'a node is very fast. Hence, we do allreduce across GPUs in a node, '
+                            'and gossip across different nodes')
     # fmt: on
 
     return group
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 8213679ce2..a2f9d4c7db 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -412,6 +412,7 @@ def maybe_no_sync():
                 logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
             )
 
+        overflow = False
         try:
             # multiply gradients by (# GPUs / sample_size) since DDP
             # already normalizes by the number of GPUs. Thus we get
@@ -429,29 +430,11 @@ def maybe_no_sync():
             grad_norm = self.clip_grad_norm(self.args.clip_norm)
 
             # check that grad norms are consistent across workers
-            if not self.args.use_bmuf:
+            if not self.args.use_bmuf and self.args.distributed_wrapper != 'SlowMo':
                 self._check_grad_norms(grad_norm)
 
             # take an optimization step
             self.optimizer.step()
-            self.set_num_updates(self.get_num_updates() + 1)
-
-            # log stats
-            logging_output = self._reduce_and_log_stats(
-                logging_outputs, sample_size, grad_norm,
-            )
-
-            # clear CUDA cache to reduce memory fragmentation
-            if (
-                self.args.empty_cache_freq > 0
-                and (
-                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
-                    % self.args.empty_cache_freq
-                ) == 0
-                and torch.cuda.is_available()
-                and not self.args.cpu
-            ):
-                torch.cuda.empty_cache()
         except FloatingPointError:
             # re-run the forward and backward pass with hooks attached to print out where it fails
             with NanDetector(self.model):
@@ -461,15 +444,43 @@ def maybe_no_sync():
                 )
             raise
         except OverflowError as e:
+            overflow = True
             logger.info("NOTE: overflow detected, " + str(e))
+            grad_norm = torch.tensor(0.).cuda()
             self.zero_grad()
-            logging_output = None
         except RuntimeError as e:
             if "out of memory" in str(e):
                 self._log_oom(e)
                 logger.error("OOM during optimization, irrecoverable")
             raise e
 
+        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
+        if hasattr(self.model, 'perform_additional_optimizer_actions'):
+            if hasattr(self.optimizer, 'fp32_params'):
+                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
+            else:
+                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)
+
+        if not overflow or self.args.distributed_wrapper == 'SlowMo':
+            self.set_num_updates(self.get_num_updates() + 1)
+
+            # log stats
+            logging_output = self._reduce_and_log_stats(
+                logging_outputs, sample_size, grad_norm,
+            )
+
+            # clear CUDA cache to reduce memory fragmentation
+            if (
+                self.args.empty_cache_freq > 0
+                and (
+                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
+                    % self.args.empty_cache_freq
+                ) == 0
+                and torch.cuda.is_available()
+                and not self.args.cpu
+            ):
+                torch.cuda.empty_cache()
+
         if self.args.fp16:
             metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)
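
Note on the trainer change above: Trainer.train_step() now calls perform_additional_optimizer_actions on the wrapped model, if the wrapper defines it, immediately after optimizer.step(). The sketch below is a hypothetical illustration of that duck-typed contract (the class name is made up; in the patch it is presumably satisfied by gossip.GossipDataParallel), not part of the patch itself.

import torch.nn as nn


class ExampleSlowWrapper(nn.Module):
    """Hypothetical wrapper showing the hook Trainer.train_step() looks for."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def perform_additional_optimizer_actions(self, optimizer, fp32_params=None):
        # Called right after optimizer.step(). A SlowMo-style wrapper would run its
        # periodic averaging / slow-momentum update here. fp32_params is only passed
        # when the fairseq optimizer exposes a flat fp32 copy of the parameters
        # (the FP16 optimizer case in the hasattr check above).
        pass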