Stack Memory #97
Codecov Report

```
@@           Coverage Diff           @@
##             main      #97   +/-   ##
==========================================
+ Coverage   72.95%   73.39%   +0.43%
==========================================
  Files          13       14       +1
  Lines        2326     2571     +245
==========================================
+ Hits         1697     1887     +190
- Misses        629      684      +55
```
Before getting to the very scant meat of the review, I'll go a bit meta. I think this and other big or even medium-sized PRs might be improved by treating their descriptions like summaries that provide:
- Context: "Here, I'm increasing runtime performance by providing explicit stack memory management and thus reducing the overhead of memory transfer. This should be especially powerful in the case of XXX."
- Method: "This PR alters the previous memory management idea, which was XXX, and instead replaces it with XXX. To the best of my understanding, this should have no disadvantages, but it enables XXX."
- Results: "A simple benchmark shows that XXX and that the PR achieves its goal, with the only trade-offs being XXX. A more rigorous and public benchmark is to be added in XXX in another PR (link to issue)."

This would serve a few different purposes:
- It would generate excitement among users who aren't the people you asked to review, but who keep an eye on PRs and recent updates because they're interested. Some context and some excitement go a long way toward emboldening people to weigh in. After weighing in and receiving appreciative responses (even if only "That would not work because of XXX, but it's really cool to see that someone cares enough about this work to read and comment. Positive/constructive engagement is always welcome!"), and after having had positive interactions with the lead dev, the step to deeper contributions is reduced.
- It would improve the public record of changes. The release notes link to the PR, so big and cool PRs are likely to actually be read. I know that I read the descriptions of merged PRs for packages that I follow closely enough to care about their releases.
- Finally, it would provide the reviewer with some context and an immediate high-level overview of what's going on. It would also dangle the benefits of having this merged as a carrot for the reviewer to actually get 'round to the review.

If you want to just dump your current progress in the repo and benefit from CI without spending time on a description, then you might make a draft PR and edit the top description before converting it to a regular PR and requesting reviews.

Right now, I'm lacking both a bit of context and a fair bit of expertise and understanding of the codebase needed to provide very meaningful comments (I admit that only a small chunk of those missing pieces could be fixed by a description; one can make hard reviews easier, but not easy). All I can contribute here is some minor linting and some parroting of the concerns that CI is raising. There are some tests lacking, and there is one failing CI run.

Fix the problem behind the failing CI, review Codecov's complaints (there are some functions that look like they need tests; `train_unbatched_core`, for example), and after that I don't see anything blocking, even though a description for the public record would be a plus.
```julia
else
    NoTangent(),
    StrideArraysCore.StrideArray(lgrad, memory),
    StrideArraysCore.StrideArray(grad, memory)
```
test?
src/memory.jl (outdated)
```julia
else
    first(with_heap_memory(f, sc, num_bytes, args...))
end
```
test?
```julia
if g isa AbstractMatrix && size(g, 2) == 1
    gpb = preserve_buffer(g)
    gv = PtrArray(pointer(g), (length(p),))
    GC.@preserve gpb train_batched!(gv, p, _chn, X, opt, iters; batchsize)
    return p
```
test?
Tests seem to fail on Julia 1.6 only.
That is tricky. SimpleChains relies a lot on recursion, which triggers heuristics that make Julia's compiler give up (see SimpleChains.jl/src/SimpleChains.jl lines 84 to 93 at d737d40).
It is because of this that `_numparam` infers on >= 1.7, but fails on 1.6 (and hence the `@inferred` test fails).
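As a toy illustration of that pattern (everything below is invented for illustration; it is not the actual `_numparam`), recursive accumulation over a heterogeneous tuple is exactly the shape of code where an older compiler's recursion heuristics can widen the inferred return type to `Any`:

```julia
using Test

# Invented stand-ins: `numparam_demo` plays the role of a per-layer parameter
# count, and `count_params` recurses down the layer tuple.
numparam_demo(l) = length(l)

count_params(n, ::Tuple{}) = n
count_params(n, layers::Tuple) =
    count_params(n + numparam_demo(first(layers)), Base.tail(layers))

# On versions where the recursion heuristics give up, the result widens to
# `Any` and the `@inferred` check throws.
@inferred count_params(0, (rand(3), rand(4, 4)))
```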
…rride compiler recursion-avoidance heuristics
Tests now pass on 1.6. While I think the …
Co-authored-by: Niklas Korsbo <[email protected]>
We could improve test coverage by dropping support for Julia 1.6, or by creating a small subset of tests to run with Julia 1.6 with coverage enabled.
What would we lose if we dropped LTS support? I imagine that the intersection of LTS users and SC users is currently rather small. However, not supporting the LTS might prevent other packages from using SC as a dependency. Are there any natural future SC integrations (Flux, Lux, SciML, etc.) that would suffer from a lack of LTS support?
I think that we should just live with two versions for now.
```julia
init_params!(::Activation, p, id) = p, id
_check_input_dims(::Activation, _) = nothing

layer_output_size(::Val{T}, a::Activation, s) where {T} = align(prod(s) * (2sizeof(T))), s
forward_layer_output_size(::Val{T}, a::Activation, s) where {T} =
    align(prod(s) * static_sizeof(T)), s
```
The `2` was dropped here; was this on purpose?
Yes. The 2 came from forward + reverse, while `forward_layer_output_size` is forward only.
```diff
@@ -1,5 +1,5 @@
-@static if isdefined(ChainRulesCore, :NoTangent)
+if isdefined(ChainRulesCore, :NoTangent)
```
Is the `@static` no longer needed?
Perhaps this whole branch is not needed anymore if we support only the latest ChainRulesCore versions.
`@static` is only really needed if you have invalid syntax hiding behind a branch, or it may be nice inside a function. Top level in a package should be interpreted anyway.
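A toy contrast of the two (the constant name `_no_tangent_type` is invented): with `@static`, the untaken branch is pruned at macro-expansion time, whereas without it the dead branch is still lowered but never evaluated, so at the top level of a package the two behave the same here.

```julia
using ChainRulesCore

# `DoesNotExist` was the old (pre-1.0) name that `NoTangent` replaced; with
# `@static`, the branch referencing the missing name is dropped before lowering.
@static if isdefined(ChainRulesCore, :NoTangent)
    const _no_tangent_type = ChainRulesCore.NoTangent
else
    const _no_tangent_type = ChainRulesCore.DoesNotExist
end
```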
But yeah, we could support only `ChainRulesCore` >= 1. I'm not sure how much of the ecosystem is still stuck on old versions. At the time, I believe many packages were.
```julia
end
function get_heap_memory(sc, num_bytes)
    heap_memory = task_local_memory(sc)
    length(heap_memory) >= num_bytes || resize!(empty!(heap_memory), Int(num_bytes))
```
Why `empty!` and then `resize!`?
If there isn't enough contiguous space following the current allocation, `resize!` allocates new memory and copies the old memory there before discarding the old. We don't care about the contents, and thus don't want to waste time copying any of them.
```julia
julia> x = collect(1:10);

julia> empty!(x);

julia> resize!(x, 1000000);

julia> x[1:10]'
1×10 adjoint(::Vector{Int64}) with eltype Int64:
 0  0  0  0  0  0  0  0  0  0
```
I see. Thanks for the explanation.
src/memory.jl (outdated)
```julia
_static_max_16384(::StaticInt{N}) where {N} = StaticInt{N}()
_static_max_16384(_) = StaticInt{16384}()
@inline function with_memory(f::F, sc, num_bytes, args::Vararg{Any,K}) where {F,K}
    if num_bytes <= 16384
```
Should this number be a customisable global constant?
Good idea.
By customizable, do you mean use Preferences.jl?
Yeah, or an environment variable, or even just a global constant at the top of the module for now; something fancier can be done in another PR.
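One possible shape for that, assuming Preferences.jl were added as a dependency (the preference key `"stack_memory_limit"` and the constant name are invented for illustration):

```julia
# Inside the SimpleChains module:
using Preferences

# Read once at load time; a user could override it per-project with
#   set_preferences!(SimpleChains, "stack_memory_limit" => 32768)
const STACK_MEMORY_LIMIT = @load_preference("stack_memory_limit", 16384)::Int
```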
```julia
end

function required_bytes(::Val{T}, layers, sx, additional = static(0)) where {T}
    output_size(Val(T), layers, sx) + additional + static(63)
```
Could you add a comment here? Why 63?
We want to 64-byte align the pointers, so we add 63 extra bytes to make sure that we can offset the pointer to achieve 64-byte alignment.
That is, in the worst case we get a pointer that is +1 with respect to a 64-byte boundary; we can then add 63 to get to a multiple of 64.
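A minimal sketch of that arithmetic (`align64` is an invented helper, not the package's actual code):

```julia
# Round an address up to the next multiple of 64: adding 63 and masking off the
# low six bits moves the pointer forward by at most 63 bytes.
align64(p::UInt) = (p + UInt(63)) & ~UInt(63)

buf = Vector{UInt8}(undef, 1024 + 63)  # over-allocate by 63 spare bytes
p = UInt(pointer(buf))
pa = align64(p)
@assert pa % 64 == 0 && pa - p <= 63   # aligned, and still within the buffer
```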
Realistically, the pointers we're getting are probably going to be at least 16-byte aligned, meaning we probably don't need to add anything more than 48 to get to 64.
On x64, memory reads/writes that cross a 64-byte boundary cost double.
AVX512 reads/writes up to 64 bytes at a time, meaning either 0% or 100% of reads/writes cross such a boundary.
AVX2, reading/writing 32 bytes at a time, is less extreme, with either 0% or 50%. But still, as reads/writes are often the slowest operations, that is costly.
It is also these boundaries that get fetched by the prefetcher. If you're going across a boundary, you need to fetch 128 bytes' worth of memory (the 64-byte blocks, aka cachelines, on both sides of that boundary).
Note that in both cases it's 256 elements. In the first case, the pointer is aligned to 64 bytes (the remainder is 0), while in the second case we offset by 1, throwing off the alignment and making the memory reads more expensive.
```julia
julia> Int(pointer(@view(x[begin:end-1]))) % 64
0

julia> Int(pointer(@view(x[begin+1:end]))) % 64
8
```
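The benchmark itself didn't survive into this transcript, but a sketch of the comparison being described (assuming BenchmarkTools; absolute timings depend on the CPU and vector width) might look like:

```julia
using BenchmarkTools

x = rand(Float64, 257);
aligned    = @view x[begin:end-1];  # base pointer, typically 64-byte aligned
misaligned = @view x[begin+1:end];  # offset by 8 bytes, so wide loads straddle cachelines

@btime sum($aligned);
@btime sum($misaligned);
```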
Does that make more sense / does seeing the benchmark help?
Note that on the M1 Mac, the cachelines are actually 128 bytes (2x larger), while the register size is 16 bytes (4x smaller), so loading/storing across such a boundary is going to be a much smaller issue on your computer than it is on mine.
But you should be able to reproduce something like this on JuliaHub, particularly if you replace `@simd` with `@turbo` or start Julia with `-C"native,-prefer-256-bit"` (I did the latter for the above example).
Thanks for the explanation. Yeah, I get the gist of it. It makes me wonder if this kind of functionality should just go into an array package that allows itself to be semantically C-incompatible, doing tricks like this to achieve more performance, e.g. doing column-wise alignment and padding as necessary. You had a package like this, I believe. It might be useful to hide these details there instead.
Your code often feels like you are programming at 2-3 levels of abstraction simultaneously :) My brain can only handle 1-1.5 levels at a time, usually, depending on how long I stare at the code; 2 is pushing it. I am always learning something new from reviewing your code, though, which probably means I shouldn't be the one doing the review 😆
That's a polite way to say "your code is an illegible mess" lol.
I think the `+ static(63)` and alignment should probably all be moved to the allocation functions, rather than being part of the "how much to allocate" functions.
`PaddedMatrices` was that library, but it was superseded by `StrideArrays`.
SimpleChains already depends on `StrideArraysCore`. I will probably add support for things like automatic padding again.
```diff
  mt[i] = β₁ * mt[i] + (1 - β₁) * Δᵢ
  vt[i] = β₂ * vt[i] + (1 - β₂) * Δᵢ^2
- Δxᵢ = mt[i] / ((1 - βp₁) * (sqrt(vt[i] / (1 - βp₂)) + 1f-8))
+ Δxᵢ = mt[i] / ((1 - βp₁) * (sqrt(vt[i] / (1 - βp₂)) + 1.0f-8))
```
Should the `1.0f-8` be `eps(eltype(Δ))` instead?
I'm not sure. In Flux, it's `1e-8` (a `Float64`), but it allows users to customize it. It doesn't automatically change based on element type, but Flux may assume `Float32` is used in general more than we do.
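For reference, a self-contained sketch of the update shown in the diff, with the epsilon pulled out as an argument so that either `1.0f-8` or `eps(eltype(Δ))` could be passed; `adam_step!` and the final `x` update are illustrative, not SimpleChains' actual internals:

```julia
# βp₁ and βp₂ are the running powers β₁ᵗ and β₂ᵗ used for bias correction.
function adam_step!(x, mt, vt, Δ, η, β₁, β₂, βp₁, βp₂, ϵ = 1.0f-8)
    @inbounds for i in eachindex(x, mt, vt, Δ)
        Δᵢ = Δ[i]
        mt[i] = β₁ * mt[i] + (1 - β₁) * Δᵢ
        vt[i] = β₂ * vt[i] + (1 - β₂) * Δᵢ^2
        Δxᵢ = mt[i] / ((1 - βp₁) * (sqrt(vt[i] / (1 - βp₂)) + ϵ))
        x[i] -= η * Δxᵢ  # the parameter update itself is not shown in the diff
    end
    return x
end
```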
Force-pushed from 8421c1b to 8768d8f.
```julia
@test !iszero(g3)
@test !iszero(g4)

@test gz ≈ g rtol=1e-6
```
It would be nice to compare against finite differences or something else that doesn't use the rrules defined in the package.
True.
There are a few tests using ForwardDiff, but some of those also hit the `ForwardDiff.Dual` methods defined here.
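A sketch of the suggested check using FiniteDifferences.jl, where `loss` and `p` stand in for an existing test's scalar loss closure and parameter vector:

```julia
using FiniteDifferences, Test

fdm = central_fdm(5, 1)  # 5-point stencil, first derivative
g_fd = only(FiniteDifferences.grad(fdm, loss, p))
@test g_fd ≈ g rtol = 1e-6
```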
I don't feel qualified enough to approve this PR, but if tests pass, let's go ahead and merge it and make an alpha release so we can test this package more downstream.
Well, it's a problem if no one is! The implementations in this library are supposed to be "simple" in at least the sense that it's hopefully not too difficult to guess what the assembly will look like, and everything the CPU is actually going to have to do to execute your code (at least if you know what LoopVectorization.jl is likely to do when presented with a given loop, which probably only I do). Perhaps there's a sort of declarative-vs-imperative tradeoff in readability, where some code makes the former (the "what") easier to infer, and other code makes the latter (the "how") easier. It might be interesting to document any "tricks" this library uses, to make it easier for others to get into and modify to suit their purposes, or to serve as educational material. I've also been considering taking a machine learning course and writing up NN-related exercises as SimpleChains.jl tutorials (and fixing/adding support for things it needs).
I would love to volunteer for that job, as I am sure it will be educational, but my hands are currently full between the work at Pumas and the hobby projects I am desperately trying to keep alive, like TopOpt.jl and Nonconvex.jl. Perhaps we can internally discuss allocating more of my Pumas time here. Let's see. Alternatively, if someone else comes along, we can also try to provide funding for that kind of work.
🎉 |
Before releasing, I ought to run a lot of benchmarks and add `@inline`s to make sure `StaticArray`s support is good.