-
-
Notifications
You must be signed in to change notification settings - Fork 76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New layer architecture #159
base: master
Are you sure you want to change the base?
Conversation
1. Static network graph is separated from invocation context. a) Static graph captures layers, connections between them and shapes of the units of data. b) Invocation context specifies the batch size and stores all data associated with an invocation (data, gradients). 2. Batch size is now explicit in the context instead of being implicitly extracted by layers from incoming data. 3. Separation into Layer and ILayer is now gone, everything is now handled in layer implementations (with "leaf" layers focusing on data manipulations while container layers focusing on network composition). This is still a very early prototype not intended for mergin: 1. Solver architecture not changed and just crudely hacked to support new network architecture. 2. Shared weights not supported. 3. Serialization not supported.
I did an initial pass, it simplifies things on the user end, but what I see as plus, on the other hand side, it removes the ability to mix execution backends iiuc? I'll do another pass soon. |
You mean use different backends for net vs loss in In principle, we can keep the ability to have different backed for loss layer through either of these approaches:
Or did you mean mixing backends in different invocations of the network? I think already nothing precludes that, as layers don't store |
As mentioned earlier, this makes passing different execution contexts more difficult from what I can see API wise, since the creation of the layers then would have to hold an Storing the associated data as part of the descriptor is not something that seems idiomatic. The descriptor becomes the owner of the actual learned weights iiuc. A plus of this is, that all layers now have to use the same storage and not be backend specific and also allow things to extend more quickly to other serialization formats. One use case that must be supported, is to load external network definitions that only share the same input and output dimensions. This allows to i.e. hotswap networks during runtime.
I think this is the biggest gain in the new architecture.
👍 This is the first pass, it generally looks very promising, I have to give it another pass in hopefully less than 24d from now 😅 |
I think I misunderstood your earlier. Are you saying that we can't create the network using backend Spoiler: I'm toying with an idea of separating backend from context: pub trait Layer<B: IBackend>: Debug {
fn compute_output(&self, backend: &B, context: &mut Context);
} which I think is cleaner (and pub trait Layer: Debug {
fn compute_output(&self, backend: &dyn IBackend + LayerOps<f32>, context: &mut Context);
} (the latter will not compile, but hopefully the idea is clear).
Well the descriptor is just a convenient way of exposing data from a The question of ownership is a bit fuzzy with pub struct Linear {
// Weight (A) and bias (b) for the linear operation y = Ax + b.
weight: Rc<RefCell<LearnableParams>>,
bias: Rc<RefCell<LearnableParams>>,
}
This should be already be supported I think. At least I don't see any immediate issues.
Thanks. I have some updates on my end which I hope to push in about a week. Some cleanups on network side, plus I'm looking into solvers, as I need Adam optimizer for my tasks. |
1. Static network graph is separated from invocation context. a) Static graph captures layers, connections between them and shapes of the units of data. b) Invocation context specifies the batch size and stores all data associated with an invocation (data, gradients). 2. Batch size is now explicit in the context instead of being implicitly extracted by layers from incoming data. 3. Separation into Layer and ILayer is now gone, everything is now handled in layer implementations (with "leaf" layers focusing on data manipulations while container layers focusing on network composition). 4. Solvers replaced by a more linear architecture of a top-level Trainer and different Optimizers (although only SGD with momentum is currently supported since both RMSprop and Adam require squaring backend support). This is still a very early prototype not intended for mergin: 1. Shared weights not supported. 2. Serialization not supported. 3. Not all layers are migrated.
OK, pushed a refreshed version. I couldn't implement Adam since it requires squaring tensors, which is not supported by the backends currently, but I've added some placeholders for it in the new |
1. Static network graph is separated from invocation context. a) Static graph captures layers, connections between them and shapes of the units of data. b) Invocation context specifies the batch size and stores all data associated with an invocation (data, gradients). 2. Batch size is now explicit in the context instead of being implicitly extracted by layers from incoming data. 3. Separation into Layer and ILayer is now gone, everything is now handled in layer implementations (with "leaf" layers focusing on data manipulations while container layers focusing on network composition). 4. Solvers replaced by a more linear architecture of a top-level Trainer and different Optimizers (SGD with momentum and Adam are currently supported). This is still a very early prototype not intended for mergin: 1. Shared weights not supported. 2. Serialization not supported. 3. Not all layers are migrated.
Added Adam implementation, for now without backend support. |
.as_mut_slice::<f32>(); | ||
|
||
// We can rewrite the matrix equations at the top of this file in a element-wise form: | ||
// Mᵢ[j] = β₁Mᵢ₋₁[j] + (1-β₁)∇ᵢ[j] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright, this a significant chunk of work ❤️ I'd like to discuss how we can move towards filling in the missing pieces and a path to getting the adjusted arch back to master. |
I think the remaining part should be pretty mechanical -- port other layers to the new infra, write unit tests, etc. I'm happy to do all of that, or we can split the work. I think it's probably a good idea to commit the current work to a branch, maybe even split in several PRs, to make the review more manageable. The currently missing pieces can be committed as separate PRs into the branch. The branch will have old and new code alongside until everything is ported, after which old code will be deleted. Do you still want to do an in-depth review of the core infra? I'd be definitely more comfortable if someone can double-check the file structure, names, etc. |
I'll get to that. One thing that came to mind was, bring ready to impl auto differentiation with the new arch. The old one was a bit clunky in that regard. |
Sorry, not sure what this means. Could you elaborate? |
That was supposed to be |
Sorry, can you clarify this? I think it already does it. Right now the API provides 2 types of abstraction:
Both APIs hide the low-level details like constructing a |
I think we can move forward with this large refactor. We could have a sync call if you'd like? Sorry for the delay(s) |
Sure, happy to have a call! I'm in PDT timezone, so it seems the acceptable overlapping time range is your evening and my morning. How about Jun 24, 19:00 Munich time? If it works, I can send a Google Meet invite. |
That'd work, please drop to [email protected] - if you get a bounce ( I hope not) it's due some email forwarding service issues, which are hopefully dealt with by now 🤞 |
Hey 👋 - I created https://github.com/spearow/juice/tree/arch-refactor where we should land the changeset first. You should also have received an invite that allows you to create branches. |
How much do we want the RNN layer to be implemented in the new arch before switching to it? I'm looking into it, but it will likely require some extensive changes of the backend.
As far as I can tell, the existing RNN implementation is not used anywhere. I'm not even sure it's implemented correctly. |
@@ -311,9 +308,8 @@ fn run_mnist( | |||
targets.push(label_val as usize); | |||
} | |||
// train the network! | |||
let infered_out = solver.train_minibatch(inp_lock.clone(), label_lock.clone()); | |||
let mut infered = solver.train_minibatch(inp_lock.clone(), label_lock.clone()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
//pub mod layer; | ||
//pub mod layers; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
//pub mod layer; | |
//pub mod layers; |
// // Gradient is calculated as 2 * (predictions - labels). | ||
// backend.copy(&labels.borrow(), &mut input_gradient.borrow_mut()); | ||
// backend.axpby( | ||
// &native_scalar(2.0), | ||
// &predictions.borrow(), | ||
// &native_scalar(-2.0), | ||
// &mut input_gradient.borrow_mut(), | ||
// ); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be faster :) I am not sure how NaN
is treated in axpby
though.
branches: Vec<LayerConfig>, | ||
} | ||
|
||
pub struct Fanout<B: IBackend> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bit of documentation would be nice, since it'll become user visible
/// of the scenario (so the longer the agent is able to keep pole from falling, the bigger | ||
/// overall reward it gets). | ||
/// | ||
/// State "s" consists of [cart_pos, cart_vel, pole_angle, pole_angle_vel] variables. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// State "s" consists of [cart_pos, cart_vel, pole_angle, pole_angle_vel] variables. | |
/// State `s` consists of `[cart_pos, cart_vel, pole_angle, pole_angle_vel]` variables. |
if br != k || c.rows() != m || c.cols() != n { | ||
panic!("Wrong GEMM dimensions: [{},{}]x[{},{}] -> [{},{}]", ar, ac, br, bc, c.rows(), c.cols()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should consider making it a debug_assert!
and rely on the caller.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have yet to do a full code review, generally, looks excellent, a few nits.
Sorry, this is an old PR, at this point superseded by all the recent ones. I used it to ask the question: #159 (comment) (my bad, probably should have asked directly). |
New layer architecture prototype
Relates to #155 .
Changes proposed by this PR:
a) Static graph captures layers, connections between them and shapes of the units of data.
b) Invocation context specifies the batch size and stores all data associated with an invocation (data, gradients).
Notes to reviewer:
This is still a very early prototype not intended for merging:
A good order for exploring this PR is starting at comments in
net/mod.rs
,net/layer.rs
,net/descriptor.rs
andnet/context.rs
.