-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stack Memory #97
Stack Memory #97
Changes from 17 commits
5e84192
3391fc8
cdf7897
066e4e2
e642422
1b03a11
99b42e3
9f5ce7e
ff1f72b
6ed1a40
4c16e77
7fd6498
ea35e44
b424555
5a49d17
2f90fe5
ed3988f
254afc2
8768d8f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
name = "SimpleChains" | ||
uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" | ||
authors = ["Chris Elrod <[email protected]> and contributors"] | ||
version = "0.2.12" | ||
version = "0.3.0" | ||
|
||
[deps] | ||
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" | ||
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" | ||
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9" | ||
|
@@ -27,6 +28,7 @@ VectorizedRNG = "33b4df10-0173-11e9-2a0c-851a7edac40e" | |
|
||
[compat] | ||
ArrayInterface = "6" | ||
ArrayInterfaceCore = "0.1.14" | ||
CPUSummary = "0.1.8" | ||
ChainRulesCore = "0.8, 0.9, 0.10, 1" | ||
CloseOpenIntervals = "0.1.6" | ||
|
@@ -43,7 +45,7 @@ Static = "0.7" | |
StaticArrays = "1" | ||
StrideArraysCore = "0.3.5" | ||
UnPack = "1" | ||
VectorizationBase = "0.21.30" | ||
VectorizationBase = "0.21.40" | ||
VectorizedRNG = "0.2.13" | ||
julia = "1.6" | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
|
||
@static if isdefined(ChainRulesCore, :NoTangent) | ||
if isdefined(ChainRulesCore, :NoTangent) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps this whole branch is not needed anymore if we support the latest ChainRulesCore versions only. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But yeah, we could support only |
||
const NoTangent = ChainRulesCore.NoTangent | ||
else | ||
const NoTangent = ChainRulesCore.DoesNotExist | ||
|
@@ -24,33 +24,59 @@ function pullback_layer!(pbl::PullBackLayer, lgrad) | |
end | ||
pullback_layer!(pbl::Ptr{UInt8}, grad) = grad, pbl | ||
|
||
struct PullBack{PBL<:PullBackLayer,G,P,M} | ||
|
||
#TODO: add support for not getting gradient with respect to input `x` | ||
# struct PullBackParam{T,L,A,PBL} | ||
# pg::Ptr{T} | ||
# l::L | ||
# arg::A | ||
# p::Ptr{T} | ||
# pu::Ptr{UInt8} | ||
# pbl::PBL # either another `PullBackLayer`, or the last memory pointer from the forward pass (to start the reverse) | ||
# end | ||
# function pullback_layer!(pbl::PullBackParam, lgrad) | ||
# grad, _ = pullback_layer!(pbl.pbl, lgrad) | ||
# pullback_param!(pbl.pg, pbl.l, grad, pbl.arg, pbl.p, pbl.pu) | ||
# end | ||
|
||
# struct PullBack{PBL<:Union{PullBackLayer,PullBackParam},G,P,M} | ||
struct PullBack{SA,PBL<:PullBackLayer,G,P,M} | ||
pbl::PBL | ||
grad::G | ||
params::P | ||
memory::M | ||
function PullBack{SA}(pbl::PBL, grad::G, params::P, memory::M) where {SA,PBL,G,P,M} | ||
new{SA,PBL,G,P,M}(pbl, grad, params, memory) | ||
end | ||
end | ||
function (pb::PullBack)(x) | ||
@inline function (pb::PullBack{SA})(x) where {SA} | ||
@unpack pbl, grad, params, memory = pb | ||
GC.@preserve grad params memory begin | ||
lgrad, pu4 = pullback_layer!(pbl, x) | ||
lgrad, _ = pullback_layer!(pbl, x) | ||
end | ||
if SA | ||
NoTangent(), | ||
_maybe_sarray(StrideArraysCore.StrideArray(lgrad, memory)), | ||
_maybe_sarray(StrideArraysCore.StrideArray(grad, memory)) | ||
else | ||
NoTangent(), | ||
StrideArraysCore.StrideArray(lgrad, memory), | ||
StrideArraysCore.StrideArray(grad, memory) | ||
Comment on lines
+61
to
+64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. test? |
||
end | ||
NoTangent(), | ||
StrideArraysCore.StrideArray(lgrad, memory), | ||
StrideArraysCore.StrideArray(grad, memory) | ||
end | ||
|
||
|
||
function unsafe_valgrad_pullback!(g, layers, params, memory::Vector{UInt8}, arg) | ||
GC.@preserve g params memory begin | ||
# @show pointer(g) pointer(params) pointer(memory) | ||
l, pbl = | ||
chain_valgrad_pullback!(pointer(g), arg, layers, pointer(params), pointer(memory)) | ||
@inline function (pb::PullBack)(x::StaticArrays.SArray) | ||
@unpack pbl, grad, params, memory = pb | ||
mx = StaticArrays.MArray(x); | ||
GC.@preserve mx grad params memory begin | ||
lgrad, _ = pullback_layer!(pbl, PtrArray(mx)) | ||
end | ||
l, PullBack(pbl, g, params, memory) | ||
NoTangent(), | ||
_maybe_sarray(StrideArraysCore.StrideArray(lgrad, memory)), | ||
_maybe_sarray(StrideArraysCore.StrideArray(grad, memory)) | ||
end | ||
|
||
function chain_valgrad_pullback!( | ||
|
||
@inline function chain_valgrad_pullback!( | ||
pg, | ||
arg, | ||
layers::Tuple{X1,X2,Vararg}, | ||
|
@@ -60,75 +86,88 @@ function chain_valgrad_pullback!( | |
l = getfield(layers, 1) | ||
pg2, larg, p2, pu2 = valgrad_layer!(pg, l, arg, p, pu) | ||
|
||
# val, grad, pu3, pbl = chain_valgrad_pullback!(pg2, larg, Base.tail(layers), p2, pu2) | ||
val, pbl = chain_valgrad_pullback!(pg2, larg, Base.tail(layers), p2, pu2) | ||
pbl_ret = PullBackLayer(pg, l, arg, p, pu, pbl) | ||
return val, pbl_ret | ||
# lgrad, pu4 = pullback!(pg, l, grad, arg, p, pu, pu3) | ||
# return val, lgrad, pu4 | ||
end | ||
function chain_valgrad_pullback!( | ||
@inline function chain_valgrad_pullback!( | ||
pg, | ||
arg, | ||
layers::Tuple{X1}, | ||
p::Ptr, | ||
pu::Ptr{UInt8}, | ||
) where {X1} | ||
l = getfield(layers, 1) | ||
pg2, val, p2, pu2 = valgrad_layer!(pg, l, arg, p, pu) | ||
_, val, __, pu2 = valgrad_layer!(pg, l, arg, p, pu) | ||
|
||
# val, grad, pu3, pbl = chain_valgrad!(pg2, larg, Base.tail(layers), p2, pu2) | ||
# pu2 gets fed into eventual `pullback!` call | ||
pbl_ret = PullBackLayer(pg, l, arg, p, pu, pu2) | ||
return val, pbl_ret | ||
# lgrad, pu4 = pullback!(pg, l, grad, arg, p, pu, pu3) | ||
# return val, lgrad, pu4 | ||
end | ||
|
||
# No loss: chain closures. | ||
function _rrule(sc, arg, params, memory, ::False) | ||
valgrad_noloss(sc, arg, params, memory) | ||
function _rrule(sc, arg, params, ::False) | ||
valgrad_noloss(sc, arg, params) | ||
end | ||
function valgrad_noloss(sc, arg::AbstractArray{S}, params::StaticArrays.SVector{T}) where {T,S} | ||
mp = StaticArrays.MVector(params); | ||
@gc_preserve valgrad_noloss(sc, arg, mp) | ||
end | ||
function valgrad_noloss(sc, arg, params::AbstractVector{T}, memory = sc.memory) where {T} | ||
function valgrad_noloss(sc, arg::AbstractArray{S}, params::AbstractVector{T}) where {T,S} | ||
c = getchain(sc) | ||
@unpack layers = c | ||
parg = maybe_static_size_arg(c.inputdim, arg) | ||
arglen = length(parg) | ||
barg = preserve_buffer(arg) | ||
off = align(resize_memory!(layers, memory, parg, length(parg) * sizeof(eltype(parg)))) | ||
GC.@preserve barg memory begin | ||
g = PtrArray(reinterpret(Ptr{T}, pointer(memory) + off), (static_length(params),)) | ||
l, pullback = unsafe_valgrad_pullback!(g, layers, params, memory, parg) | ||
end | ||
return l, pullback | ||
# return StrideArraysCore.StrideArray(l, memory), pullback | ||
end | ||
|
||
glen = _try_static(numparam(sc), static_length(params)) | ||
goff = align(glen * static_sizeof(T)) | ||
aoff = align(arglen * static_sizeof(S)) | ||
|
||
num_bytes = required_bytes(Val{T}(), layers, size(parg), aoff + goff) | ||
memory = get_heap_memory(sc, num_bytes) | ||
|
||
# Loss: call `valgrad`. | ||
function _rrule(sc, arg, params, memory, ::True) | ||
l, g = valgrad(sc, arg, params, memory) | ||
# assumes no grad w/ respect to arg | ||
pullback = let g = g | ||
l̄ -> begin | ||
if !isone(l̄) | ||
@turbo for i ∈ eachindex(g) | ||
g[i] *= l̄ | ||
end | ||
end | ||
NoTangent(), NoTangent(), g | ||
GC.@preserve barg params memory begin | ||
pm = align(pointer(memory)) | ||
parg2 = PtrArray(Ptr{S}(pm), _try_static(c.inputdim, size(parg))) | ||
@inbounds @simd ivdep for i in eachindex(parg) | ||
parg2[i] = parg[i] | ||
end | ||
pm += aoff | ||
g = PtrArray(Ptr{T}(pm), (glen,)) | ||
pm += goff | ||
# @show pointer(g) pointer(params) pointer(memory) | ||
l, pbl = chain_valgrad_pullback!(pointer(g), parg2, layers, pointer(params), pm) | ||
end | ||
if arg isa StaticArrays.SArray | ||
_maybe_sarray(l), PullBack{true}(pbl, g, params, memory) | ||
else | ||
l, PullBack{true}(pbl, g, params, memory) | ||
end | ||
l, pullback | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
sc::AbstractPenalty, arg, params, memory = task_local_memory() | ||
) | ||
_rrule(sc, arg, params, memory, True()) | ||
struct ElementwisePullback{G} | ||
g::G | ||
end | ||
function ChainRulesCore.rrule( | ||
sc::SimpleChain, arg, params, memory = task_local_memory() | ||
) | ||
_rrule(sc, arg, params, memory, has_loss_typed(sc)) | ||
#TODO: add support for getting gradient with respect to `arg` | ||
function (ep::ElementwisePullback)(l̄) | ||
g = ep.g | ||
if !isone(l̄) | ||
@turbo for i ∈ eachindex(g) | ||
g[i] *= l̄ | ||
end | ||
chriselrod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
# assumes no grad w/ respect to arg | ||
NoTangent(), NoTangent(), g | ||
end | ||
# Loss: call `valgrad`. | ||
function _rrule(sc, arg, params, ::True) | ||
l, g = valgrad(sc, arg, params) | ||
l, ElementwisePullback(g) | ||
end | ||
# TODO: support penalties without returning scalars | ||
_returns_scalar(::AbstractPenalty) = True() | ||
_returns_scalar(sc::SimpleChain) = has_loss_typed(sc) | ||
|
||
function ChainRulesCore.rrule(sc::Chain, arg, params) | ||
_rrule(sc, arg, params, _returns_scalar(sc)) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the 2 was dropped here, was this on purpose?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. The 2 came from forward + reverse, while
forward_layer_output_size
is forwad only.