From 98e43eec9a60c9cbcf7ae31b4070b51597e18bd4 Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley
Date: Tue, 17 Oct 2023 13:01:17 -0700
Subject: [PATCH] enable running without bias

---
 megablocks/layers/arguments.py |  2 ++
 megablocks/layers/moe.py       | 24 ++++++++++++++++--------
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index 7fdd30ef..22391128 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -17,6 +17,8 @@ class Arguments:
     hidden_size : int = 1024
     ffn_hidden_size : int = 4096
     num_layers : int = 1
+    bias : bool = True
+    return_bias : bool = True
 
     # MoE arguments.
     moe_num_experts : int = 1
diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 87e6a5b9..adf52c57 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -116,13 +116,16 @@ def __init__(self, args : Arguments):
         # Expert MLP.
         self.mlp = mlp.MLP(args)
 
-        # Note that the output bias is not parallelized with expert
-        # model parallelism.
-        self.bias = torch.nn.Parameter(torch.empty(
-            args.hidden_size,
-            device=args.device,
-            dtype=common.dtype(args)))
-        torch.nn.init.zeros_(self.bias)
+        if self.args.bias:
+            # Note that the output bias is not parallelized with expert
+            # model parallelism.
+            self.bias = torch.nn.Parameter(torch.empty(
+                args.hidden_size,
+                device=args.device,
+                dtype=common.dtype(args)))
+            torch.nn.init.zeros_(self.bias)
+        else:
+            self.register_parameter('bias', None)
 
         # Select the forward function for the operating mode.
         self.forward_fn = (
@@ -420,4 +423,9 @@ def forward(self, x):
         x, tokens_per_expert = self.forward_fn(
             x, expert_weights, top_experts)
         save_load_balancing_loss((tokens_per_expert, scores))
-        return x.view(sl, bs, hs), self.bias
+        x = x.view(sl, bs, hs)
+        if self.bias is not None:
+            if self.args.return_bias:
+                return x, self.bias
+            return x + self.bias
+        return x