Skip to content

Commit

Permalink
Merge pull request #31 from vchiley/no_bias
Browse files (browse the repository at this point in the history)
Enable running MegaBlocks MoE without bias
  • Loading branch information
tgale96 authored Oct 17, 2023
2 parents 6640ebd + 98e43ee commit 52aa1b2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
2 changes: 2 additions & 0 deletions megablocks/layers/arguments.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,6 +17,8 @@ class Arguments:
hidden_size : int = 1024
ffn_hidden_size : int = 4096
num_layers : int = 1
bias : bool = True
return_bias : bool = True

# MoE arguments.
moe_num_experts : int = 1
Expand Down
24 changes: 16 additions & 8 deletions megablocks/layers/moe.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -116,13 +116,16 @@ def __init__(self, args : Arguments):
# Expert MLP.
self.mlp = mlp.MLP(args)

# Note that the output bias is not parallelized with expert
# model parallelism.
self.bias = torch.nn.Parameter(torch.empty(
args.hidden_size,
device=args.device,
dtype=common.dtype(args)))
torch.nn.init.zeros_(self.bias)
if self.args.bias:
# Note that the output bias is not parallelized with expert
# model parallelism.
self.bias = torch.nn.Parameter(torch.empty(
args.hidden_size,
device=args.device,
dtype=common.dtype(args)))
torch.nn.init.zeros_(self.bias)
else:
self.register_parameter('bias', None)

# Select the forward function for the operating mode.
self.forward_fn = (
Expand Down Expand Up @@ -420,4 +423,9 @@ def forward(self, x):
x, tokens_per_expert = self.forward_fn(
x, expert_weights, top_experts)
save_load_balancing_loss((tokens_per_expert, scores))
return x.view(sl, bs, hs), self.bias
x = x.view(sl, bs, hs)
if self.bias is not None:
if self.args.return_bias:
return x, self.bias
return x + self.bias
return x

0 comments on commit 52aa1b2

Please sign in to comment.