Stack Memory #97
Merged: chriselrod merged 19 commits into main from stackmemory on Sep 2, 2022
Conversation

@chriselrod (Contributor)

Before releasing, I ought to run a lot of benchmarks and add @inlines to make sure StaticArrays support is good.

@codecov bot commented Jul 1, 2022

Codecov Report

Merging #97 (8768d8f) into main (5a493d5) will increase coverage by 0.43%.
The diff coverage is 70.99%.

@@            Coverage Diff             @@
##             main      #97      +/-   ##
==========================================
+ Coverage   72.95%   73.39%   +0.43%     
==========================================
  Files          13       14       +1     
  Lines        2326     2571     +245     
==========================================
+ Hits         1697     1887     +190     
- Misses        629      684      +55     
Impacted Files Coverage Δ
src/simple_chain.jl 63.88% <53.04%> (-23.22%) ⬇️
src/optimize.jl 84.03% <71.42%> (-0.47%) ⬇️
src/utils.jl 78.18% <77.77%> (+0.40%) ⬆️
src/dense.jl 87.32% <91.42%> (-0.04%) ⬇️
src/memory.jl 96.77% <96.77%> (ø)
src/chain_rules.jl 96.20% <98.18%> (+75.83%) ⬆️
src/activation.jl 78.57% <100.00%> (+10.71%) ⬆️
src/conv.jl 38.04% <100.00%> (-0.12%) ⬇️
src/dropout.jl 92.10% <100.00%> (+0.10%) ⬆️
src/flatten.jl 95.83% <100.00%> (ø)
... and 8 more


@korsbo (Member) left a comment

Before going to the very scant meat of the review, I'll go a bit meta. I think this and other big or even medium-sized PRs might be improved by treating their descriptions as summaries that provide:

  • Context - "Here, I'm increasing runtime performance by providing explicit stack memory management and thus reducing the overhead of memory transfer. This should be especially powerful in the case of XXX."
  • Method - "This PR alters the previous memory management idea, which was XXX, and instead replaces it with XXX. To the best of my understanding, this should have no disadvantages, but it enables XXX."
  • Results - "A simple benchmark shows that XXX and that the PR achieves its goal, with the only trade-offs being XXX. A more rigorous and public benchmark to be added in XXX in another PR (link to issue)."

This would serve a few different purposes.

  • It would generate excitement from users who aren't the people you asked to review, but who keep an eye on PRs and recent updates because they're interested. Some context and some excitement go a long way toward emboldening people to weigh in. Weighing in and receiving appreciative responses (even if only "That would not work because of XXX, but it's really cool to see that someone cares enough about this work to read and comment. Positive/constructive engagement is always welcome!") builds engagement, and once someone is engaged and has had positive interactions with the lead dev, the step to deeper contributions is smaller.
  • It would improve the public record of changes. The release notes link to the PR, so big and cool PRs are likely to actually be read. I know that I read the descriptions of merged PRs for packages that I follow closely enough to care about their releases.
  • Finally, it would provide the reviewer with some context and an immediate high-level overview of what's going on. It would also dangle the benefits of having this merged as a carrot for the reviewer to actually get around to the review.

If you want to just dump your current progress in the repo and benefit from CI without spending time on a description, you can make a draft PR and edit the top description before converting it to a regular PR and requesting reviews.

Right now, I'm lacking both a bit of context and a fair bit of the expertise and understanding of the codebase needed to provide very meaningful comments (I admit that only a small chunk of those missing pieces could be fixed by a description; one can make hard reviews easier, but not easy). All I can contribute here is some minor linting and some parroting of the concerns that CI is raising. Some tests are lacking, and there is one failing CI run.

Fix the problem behind the failing CI, review codecov's complaints (some functions look like they need tests; train_unbatched_core, for example), and after that, I don't see anything blocking, though a description for the public record would be a plus.

Comment on lines +61 to +64
else
NoTangent(),
StrideArraysCore.StrideArray(lgrad, memory),
StrideArraysCore.StrideArray(grad, memory)
Member: test?

src/chain_rules.jl (outdated; resolved)
src/chain_rules.jl (outdated; resolved)
src/chain_rules.jl (resolved)
src/loss.jl (resolved)
src/memory.jl (outdated):
Comment on lines 41 to 43
else
first(with_heap_memory(f, sc, num_bytes, args...))
end
Member: test?

src/optimize.jl (outdated; resolved)
Comment on lines +604 to +608
if g isa AbstractMatrix && size(g, 2) == 1
gpb = preserve_buffer(g)
gv = PtrArray(pointer(g), (length(p),))
GC.@preserve gpb train_batched!(gv, p, _chn, X, opt, iters; batchsize)
return p
Member: test?

@mohamed82008 (Member)
Tests seem to fail on Julia 1.6 only.

@chriselrod (Contributor, author) commented Aug 27, 2022

> Tests seem to fail on Julia 1.6 only.

That is tricky. SimpleChains relies a lot on recursion, which trips heuristics that make Julia's compiler give up on inference.
We can override these heuristics on Julia >= 1.7:

# `Method.recursion_relation` was added in Julia 1.7; setting it to a
# relation that always returns `true` tells the compiler never to limit
# recursion when inferring these functions.
if VERSION >= v"1.7.0"
  if hasfield(Method, :recursion_relation)
    dont_limit = Returns(true)
    for f in (chain_valgrad!, _chain, output_size, _numparam)
      for m in methods(f)
        m.recursion_relation = dont_limit
      end
    end
  end
end

It is because of this that _numparam infers on >=1.7, but fails on 1.6 (and hence the @inferred test fails).

Commit: …rride compiler recursion-avoidance heuristics
@chriselrod (Contributor, author) commented Aug 27, 2022

Tests now pass on 1.6.
I still need to add more tests, and some code only runs on 1.6, so it won't show up as covered while we run 1.6 with coverage=false...

While I think the @generated functions are uglier, perhaps we should favor them over recursion so that we only need to maintain one implementation and don't rely on compiler hacks.
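For readers following along, here is a schematic sketch of the two styles being weighed (stand-in functions, not SimpleChains' actual definitions):

# Recursing over a tuple of layers: inference must recurse once per layer,
# which is what trips the compiler heuristics worked around above.
numparam(layer::AbstractArray) = length(layer)  # stand-in for a real layer
numparam_rec(::Tuple{}) = 0
numparam_rec(layers::Tuple) = numparam(first(layers)) + numparam_rec(Base.tail(layers))

# @generated: the loop over the layers runs at code-generation time, so the
# emitted method body is a flat sum with no recursion for inference to limit.
@generated function numparam_gen(layers::Tuple)
  ex = :(0)
  for i = 1:fieldcount(layers)
    ex = :($ex + numparam(layers[$i]))
  end
  ex
end

# Both return 7 for a two-"layer" tuple:
numparam_rec((rand(2, 2), rand(3))) == numparam_gen((rand(2, 2), rand(3))) == 7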

@chriselrod (Contributor, author)
We could improve test coverage by dropping support for Julia 1.6, or by creating a small subset of tests to run with Julia 1.6 + enabled coverage.

@korsbo (Member) commented Aug 30, 2022

What would we lose if we dropped LTS support? I imagine that the intersection of LTS users and SC users is currently rather small. However, not supporting LTS might prevent other packages from using SC as a dependency. Are there any natural future SC integrations (Flux, Lux, SciML, etc.) that would suffer from a lack of LTS support?

@chriselrod (Contributor, author)
I think that we should just live with two versions for now.

init_params!(::Activation, p, id) = p, id
_check_input_dims(::Activation, _) = nothing

layer_output_size(::Val{T}, a::Activation, s) where {T} = align(prod(s) * (2sizeof(T))), s
forward_layer_output_size(::Val{T}, a::Activation, s) where {T} =
align(prod(s) * static_sizeof(T)), s
Member: the 2 was dropped here; was this on purpose?

@chriselrod (Contributor, author): Yes. The 2 came from forward + reverse, while forward_layer_output_size is forward only.

@@ -1,5 +1,5 @@
-@static if isdefined(ChainRulesCore, :NoTangent)
+if isdefined(ChainRulesCore, :NoTangent)
Member: is the @static no longer needed?

Member: perhaps this whole branch is not needed anymore if we support the latest ChainRulesCore versions only.

@chriselrod (Contributor, author): @static is only really needed if you have something invalid hiding behind a branch, or it may be nice inside a function.
Top-level code in a package should be interpreted anyway, so a plain if costs nothing there.
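To make that concrete, here is a hypothetical illustration (not code from this PR) of when @static actually matters:

# `Base.@invokelatest` only exists on Julia >= 1.9, and macros are expanded
# when a method is *defined*, not when it runs. With a plain `if`, defining
# `f` on an older Julia would fail because the dead branch still gets
# macro-expanded; `@static` evaluates the condition at expansion time and
# drops the dead branch entirely.
f(x) = @static if VERSION >= v"1.9"
  Base.@invokelatest identity(x)
else
  Base.invokelatest(identity, x)
end

# At the top level of a package, by contrast, a plain `if` is fine: the
# untaken branch never executes, so it costs nothing there.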

@chriselrod (Contributor, author) commented Sep 1, 2022:

But yeah, we could support only ChainRulesCore >= 1.
I'm not sure how much of the ecosystem is still stuck on old versions. At the time, I believe many packages were.

end
function get_heap_memory(sc, num_bytes)
heap_memory = task_local_memory(sc)
length(heap_memory) >= num_bytes || resize!(empty!(heap_memory), Int(num_bytes))
Member: why empty and then resize?

@chriselrod (Contributor, author) commented Sep 1, 2022:

If there isn't enough contiguous space following the current allocation, resize! allocates new memory and copies the old contents there before discarding the old buffer.
We don't care about the old contents, and thus don't want to waste time copying them. Calling empty! first sets the length to zero, so there is nothing to copy:

julia> x = collect(1:10);

julia> empty!(x);

julia> resize!(x, 1000000);

julia> x[1:10]'
1×10 adjoint(::Vector{Int64}) with eltype Int64:
 0  0  0  0  0  0  0  0  0  0

Member: I see. Thanks for the explanation.

src/memory.jl (outdated):
_static_max_16384(::StaticInt{N}) where {N} = StaticInt{N}()
_static_max_16384(_) = StaticInt{16384}()
@inline function with_memory(f::F, sc, num_bytes, args::Vararg{Any,K}) where {F,K}
if num_bytes <= 16384
Member: should this number be a customisable global constant?

@chriselrod (Contributor, author): Good idea.

@chriselrod (Contributor, author): By customizable, do you mean use Preferences.jl?

Member: Ya, or an environment variable, or even just a global constant at the top of the module for now; something fancier can be done in another PR.
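A minimal sketch of the simplest variant suggested here (the constant and environment-variable names are hypothetical, not from this PR):

# Hypothetical: a module-level cutoff between stack and heap allocation,
# overridable via an environment variable read once at package load time.
const STACK_MEMORY_LIMIT = parse(Int, get(ENV, "SIMPLECHAINS_STACK_LIMIT", "16384"))

use_stack_memory(num_bytes) = num_bytes <= STACK_MEMORY_LIMIT

Preferences.jl would instead persist the choice per project across sessions, at the cost of an extra dependency.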

end

function required_bytes(::Val{T}, layers, sx, additional = static(0)) where {T}
output_size(Val(T), layers, sx) + additional + static(63)
Member: could you add a comment here? Why 63?

@chriselrod (Contributor, author):

We want to 64-byte align the pointers, so we add 63 extra bytes to make sure that we can offset the pointer to achieve 64-byte alignment.
That is, in the worst case we get a pointer that is +1 with respect to a 64-byte boundary; we can then offset it by 63 to land on the next multiple of 64.
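A hedged sketch of that arithmetic (not necessarily the exact align used in src/memory.jl): round the address up to the next multiple of 64; with 63 spare bytes, the rounded-up pointer is guaranteed to still land inside the buffer.

# Round a raw address up to the next 64-byte boundary.
align64(p::UInt) = (p + 63) & ~UInt(63)

buf = Vector{UInt8}(undef, 100 + 63)  # request 63 extra bytes
GC.@preserve buf begin
  p = UInt(pointer(buf))
  q = align64(p)                      # 64-byte aligned, still within buf
  @assert q % 64 == 0 && q - p < 64
end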

@chriselrod (Contributor, author):

Realistically, the pointers we're getting are probably going to be at least 16-byte aligned, meaning we probably don't need to add anything more than 48 to reach a 64-byte boundary.

@chriselrod (Contributor, author) commented Sep 1, 2022:

On x64, memory reads/writes that cross a 64-byte boundary cost double.
AVX-512 reads/writes up to 64 bytes at a time, meaning either 0% or 100% of accesses cross such a boundary.
AVX2, reading/writing 32 bytes at a time, is less extreme: either 0% or 50% do. But as reads/writes are often the slowest operations, that is still costly.

@chriselrod (Contributor, author):

It is also these boundaries that get fetched by the prefetcher. If you're going across a boundary, you need to fetch 128 bytes' worth of memory (the 64-byte blocks, a.k.a. cachelines, on both sides of that boundary).

@chriselrod (Contributor, author):

Note that in both cases it's 256 elements. In the first case, the pointer is aligned to 64 bytes (the remainder is 0), while in the second case, we offset by one element, throwing off the alignment and making the memory reads more expensive.

julia> Int(pointer(@view(x[begin:end-1]))) % 64
0

julia> Int(pointer(@view(x[begin+1:end]))) % 64
8
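The benchmark output itself isn't reproduced in this thread, but a sketch of the kind of comparison being described (the array length and tooling are my assumptions) would be:

using BenchmarkTools

x = rand(Float64, 1025)             # large enough for Julia to 64-byte align it
aligned    = @view(x[begin:end-1])  # 1024 elements, pointer % 64 == 0
misaligned = @view(x[begin+1:end])  # 1024 elements, pointer % 64 == 8

@btime sum($aligned)
@btime sum($misaligned)
# On AVX-512 hardware with 512-bit vectors enabled (e.g. Julia started with
# -C"native,-prefer-256-bit"), every load from the misaligned view straddles
# a cacheline boundary, so it should benchmark measurably slower.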

@chriselrod (Contributor, author) commented Sep 1, 2022:

Does that make more sense? Does seeing the benchmark help?

Note that on the M1 Mac, the cachelines are actually 128 bytes (2x larger), while the register size is 16 bytes (4x smaller), so loading/storing across such a boundary is going to be a much smaller issue on your computer than it is on mine.
But you should be able to reproduce something like this on JuliaHub, particularly if you replace @simd with @turbo or start Julia with -C"native,-prefer-256-bit" (I did the latter for the above example).

Member:

Thanks for the explanation. Ya, I get the gist of it. Makes me wonder if this kind of functionality should just go into an array package that allows itself to be semantically C-incompatible, doing tricks like this to achieve more performance, e.g. column-wise alignment and padding as necessary. You had a package like this, I believe. Might be useful to hide these details there instead.

Member:

Your code often feels like you are programming at 2-3 levels of abstraction simultaneously :) My brain can usually only handle 1-1.5 levels at a time, depending on how long I stare at the code; 2 is pushing it. I am always learning something new from reviewing your code, though, which probably means I shouldn't be the one doing the review 😆

@chriselrod (Contributor, author):

That's a polite way to say "your code is an illegible mess" lol

I think the + static(63) and alignment should probably all be moved to the allocation functions, rather than being a part of the "how much to allocate" functions.

PaddedMatrices was that library, but it was superseded by StrideArrays.
SimpleChains already depends on StrideArraysCore. I will probably add support for things like automatic padding again.

  mt[i] = β₁ * mt[i] + (1 - β₁) * Δᵢ
  vt[i] = β₂ * vt[i] + (1 - β₂) * Δᵢ^2
- Δxᵢ = mt[i] / ((1 - βp₁) * (sqrt(vt[i] / (1 - βp₂)) + 1f-8))
+ Δxᵢ = mt[i] / ((1 - βp₁) * (sqrt(vt[i] / (1 - βp₂)) + 1.0f-8))
Member: should the 1.0f-8 be eps(eltype(Δ)) instead?

@chriselrod (Contributor, author):

I'm not sure. In Flux, it's 1e-8 (a Float64), but Flux allows users to customize it. It doesn't automatically change based on element type, but it may assume Float32 is used in general more than we do.
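For reference (a quick check, not from the original thread), eps(Float32) is roughly an order of magnitude larger than the conventional 1e-8, so the substitution wouldn't be a numerical no-op:

julia> eps(Float32)
1.1920929f-7

julia> eps(Float64)
2.220446049250313e-16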

@test !iszero(g3)
@test !iszero(g4)

@test gz ≈ g rtol=1e-6
Member: would be nice to compare against finite differences or something else that doesn't use the rrules defined in the package

@chriselrod (Contributor, author):

True.
There are a few tests using ForwardDiff, but some of those also hit the ForwardDiff.Dual methods defined here.
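A hedged sketch of what such an rrule-independent check could look like (hypothetical test code, not from this PR), using FiniteDifferences.jl so that neither the package's rrules nor its ForwardDiff.Dual methods are exercised:

using FiniteDifferences, Test

f(x) = sum(abs2, x)  # stand-in for a SimpleChains loss closure
x = randn(8)

g_fd = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1]
g_analytic = 2 .* x  # known gradient of sum(abs2, x)

@test g_analytic ≈ g_fd rtol=1e-8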

@mohamed82008 (Member)
I don't feel qualified enough to approve this PR but if tests pass, let's go ahead and merge it and make an alpha release so we can test this package more downstream.

@chriselrod (Contributor, author)

> I don't feel qualified enough to approve this PR

Well, it's a problem if no one is!

The implementations in this library are supposed to be "simple" in at least the sense that it's hopefully not too difficult to guess what the assembly will look like, and everything the CPU is actually going to have to do to execute your code (at least if you know what LoopVectorization.jl is likely to do when presented with a given loop, which probably only I do).
In this sense, it's a lot less opaque than something like Flux.jl, which can produce miraculous amounts of allocations and the associated bad performance out of nowhere from very few lines of code.

Perhaps there's a sort of declarative-vs-imperative tradeoff in readability, where some code makes the former (the what) easier to infer, and other code makes the latter (the how) easier.

It might be interesting to document any "tricks" this library uses, to make it easier for others to get into and modify to suit their purposes, or to serve as educational material.
It may also reveal things that should be abstracted out, like the + static(63) probably should be.
I do think someone else with enough time could take a stab at it just by asking questions, like you did here.

I've also been considering taking a machine learning course and writing up NN-related exercises as SimpleChains.jl tutorials (and fix/add support for things it needs).

@mohamed82008 (Member)
I would love to volunteer for that job, as I am sure it will be educational, but my hands are currently full between the work at Pumas and the hobby projects I am desperately trying to keep alive, like TopOpt.jl and Nonconvex.jl. Perhaps we can internally discuss allocating more of my Pumas time here. Let's see. Alternatively, if someone else comes along, we can also try to provide funding for that kind of work.

@chriselrod chriselrod merged commit 13d60f0 into main Sep 2, 2022
@chriselrod chriselrod deleted the stackmemory branch September 2, 2022 10:23
@korsbo (Member) commented Sep 2, 2022

🎉
