-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add GLU support #38
Add GLU support #38
Conversation
I've updated the PR based on @tgale96 's feedback, mainly:
I've left a few of the issues unresolved as I work through them, namely |
@@ -38,8 +38,8 @@ class Arguments: | |||
|
|||
# Compute arguments. | |||
memory_optimized_mlp : bool = False | |||
mlp_type: str = 'mlp' | |||
grouped_mlp: bool = False | |||
mlp_type : str = 'mlp' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you mean to add these spaces? Looks like we're actually mixed on having them and not having them in this file...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmmm yea I thought that the spaces would match the existing style of the file
megablocks/layers/dmlp_registry.py
Outdated
|
||
MlpType = Union[mlp.SparseMLP, glu.SparseGLU] | ||
|
||
class dMlpRegistry: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Stylistic thing - can we remove this class and have get
just be a function on the module? Then REGISTRY can be a private, global? i.e., _REGISTRY
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've refactored it to get rid of the class, let me know if it looks good now!
LGTM! Ready to merge? |
@tgale96 great, yes! |
Thanks for the contribution Sasha! This is awesome. |
This change adds GLU blocks to megablocks (replacing vanilla MLPs), and does some refactoring around the mlp types, including an
MLP_TYPE_REGISTRY
.Note, this is unoptimized at the moment: