
Stack Memory #97

Merged · 19 commits · Sep 2, 2022
Changes from 17 commits
6 changes: 4 additions & 2 deletions Project.toml
@@ -1,10 +1,11 @@
name = "SimpleChains"
uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.2.12"
version = "0.3.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9"
@@ -27,6 +28,7 @@ VectorizedRNG = "33b4df10-0173-11e9-2a0c-851a7edac40e"

[compat]
ArrayInterface = "6"
ArrayInterfaceCore = "0.1.14"
CPUSummary = "0.1.8"
ChainRulesCore = "0.8, 0.9, 0.10, 1"
CloseOpenIntervals = "0.1.6"
@@ -43,7 +45,7 @@ Static = "0.7"
StaticArrays = "1"
StrideArraysCore = "0.3.5"
UnPack = "1"
VectorizationBase = "0.21.30"
VectorizationBase = "0.21.40"
VectorizedRNG = "0.2.13"
julia = "1.6"

25 changes: 17 additions & 8 deletions src/SimpleChains.jl
@@ -9,6 +9,7 @@ using UnPack,
StrideArraysCore,
Static,
VectorizedRNG
using ArrayInterfaceCore: CPUPointer
using ArrayInterface:
size,
strides,
@@ -29,7 +30,7 @@ using SIMDTypes: Bit, NativeTypes
using VectorizationBase: align, relu, stridedpointer, AbstractSIMD, NativeTypesV
using HostCPUFeatures: static_sizeof, register_size, register_count, static_sizeof
using CPUSummary: cache_linesize, num_threads, num_cores
using LayoutPointers: bytestrideindex, stridedpointer, zero_offsets, val_dense_dims
using LayoutPointers: bytestrideindex, stridedpointer, zstridedpointer, zero_offsets, val_dense_dims
using Static: One, lt
using CloseOpenIntervals: CloseOpen
using StrideArraysCore: zview, @gc_preserve
@@ -39,6 +40,7 @@ import Random
import ChainRulesCore
import ForwardDiff
import LoopVectorization
import StaticArrays

using LoopVectorization: matmul_params, @turbo
# using LoopVectorization: matmul_params
@@ -67,6 +69,7 @@ export SimpleChain,

const Integer = Union{StaticInt,Base.Integer}

include("memory.jl")
include("simple_chain.jl")
include("utils.jl")
include("activation.jl")
@@ -81,13 +84,19 @@ include("penalty.jl")
include("chain_rules.jl")
include("optimize.jl")

if VERSION >= v"1.7.0"
if hasfield(Method, :recursion_relation)
dont_limit = Returns(true)
for f = (chain_valgrad!, _chain, output_size, _numparam)
for m in methods(f)
m.recursion_relation = dont_limit
end
if VERSION >= v"1.7.0" && hasfield(Method, :recursion_relation)
dont_limit = Returns(true)
for f in (
chain_valgrad!,
chain_valgrad_pullback!,
__chain,
output_size,
forward_output_size,
_numparam,
pullback_layer!,
)
for m in methods(f)
m.recursion_relation = dont_limit
end
end
end
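
Context for readers, not part of the diff: Method.recursion_relation is an internal compiler hook. When it returns true, inference skips its recursion-limiting heuristic for that method, which lets these self-recursive chain functions keep inferring precisely as the layer tuple shrinks. A minimal sketch of the same pattern, applied to a hypothetical function rather than SimpleChains' own:

# Hypothetical illustration of the pattern used above.
sumtup(t::Tuple{}) = 0
sumtup(t::Tuple) = first(t) + sumtup(Base.tail(t))  # recursion depth grows with tuple length

if hasfield(Method, :recursion_relation)  # internal field; present on Julia >= 1.7
    for m in methods(sumtup)
        m.recursion_relation = Returns(true)  # never limit recursion for these methods
    end
end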
5 changes: 3 additions & 2 deletions src/activation.jl
@@ -11,11 +11,12 @@ struct Activation{F}
f::F
end
parameter_free(::Activation) = true
numparam(::Activation, id) = 0, id
numparam(::Activation, id) = static(0), id
init_params!(::Activation, p, id) = p, id
_check_input_dims(::Activation, _) = nothing

layer_output_size(::Val{T}, a::Activation, s) where {T} = align(prod(s) * (2sizeof(T))), s
forward_layer_output_size(::Val{T}, a::Activation, s) where {T} =
align(prod(s) * static_sizeof(T)), s
Member: The 2 was dropped here; was this on purpose?

Contributor Author: Yes. The 2 came from forward + reverse, while forward_layer_output_size is forward only.
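
To make the sizing concrete, here is a minimal sketch (not the package's exact internals) of the two computations, assuming align rounds a byte count up to an aligned boundary as in the imports above:

using VectorizationBase: align

# Forward + reverse keeps the layer output *and* space for its gradient,
# hence the factor of 2; the forward-only pass needs just the output buffer.
bytes_fwd_rev(::Type{T}, s) where {T} = align(prod(s) * 2sizeof(T))
bytes_fwd(::Type{T}, s) where {T} = align(prod(s) * sizeof(T))

bytes_fwd_rev(Float32, (128, 32))  # 32768
bytes_fwd(Float32, (128, 32))      # 16384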


Base.show(io::IO, a::Activation) = print(io, "Activation layer applying: ", a.f)

153 changes: 96 additions & 57 deletions src/chain_rules.jl
@@ -1,5 +1,5 @@

@static if isdefined(ChainRulesCore, :NoTangent)
if isdefined(ChainRulesCore, :NoTangent)
Member: Is the @static no longer needed?

Member: Perhaps this whole branch is not needed anymore if we support only the latest ChainRulesCore versions.

Contributor Author: @static is only really needed if you have invalid syntax hiding behind a branch, or it may be nice inside a function. Top level in a package should be interpreted anyway.

Contributor Author (@chriselrod, Sep 1, 2022): But yeah, we could support only ChainRulesCore >= 1. I'm not sure how much of the ecosystem is still stuck on old versions. At the time, I believe many packages were.
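
A minimal sketch of the distinction (blas_hint is a hypothetical example, not SimpleChains code): at the top level of a package a plain if suffices, because the file is interpreted as it loads and the false branch is simply never evaluated; inside a function, @static resolves the condition at macro-expansion time, so only the chosen branch is lowered.

# Hypothetical example; Sys.isapple is the classic use from the Base docs.
function blas_hint()
    @static if Sys.isapple()
        :accelerate  # only this branch is compiled on macOS
    else
        :openblas
    end
end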

const NoTangent = ChainRulesCore.NoTangent
else
const NoTangent = ChainRulesCore.DoesNotExist
@@ -24,33 +24,59 @@ function pullback_layer!(pbl::PullBackLayer, lgrad)
end
pullback_layer!(pbl::Ptr{UInt8}, grad) = grad, pbl

struct PullBack{PBL<:PullBackLayer,G,P,M}

#TODO: add support for not getting gradient with respect to input `x`
# struct PullBackParam{T,L,A,PBL}
# pg::Ptr{T}
# l::L
# arg::A
# p::Ptr{T}
# pu::Ptr{UInt8}
# pbl::PBL # either another `PullBackLayer`, or the last memory pointer from the forward pass (to start the reverse)
# end
# function pullback_layer!(pbl::PullBackParam, lgrad)
# grad, _ = pullback_layer!(pbl.pbl, lgrad)
# pullback_param!(pbl.pg, pbl.l, grad, pbl.arg, pbl.p, pbl.pu)
# end

# struct PullBack{PBL<:Union{PullBackLayer,PullBackParam},G,P,M}
struct PullBack{SA,PBL<:PullBackLayer,G,P,M}
pbl::PBL
grad::G
params::P
memory::M
function PullBack{SA}(pbl::PBL, grad::G, params::P, memory::M) where {SA,PBL,G,P,M}
new{SA,PBL,G,P,M}(pbl, grad, params, memory)
end
end
function (pb::PullBack)(x)
@inline function (pb::PullBack{SA})(x) where {SA}
@unpack pbl, grad, params, memory = pb
GC.@preserve grad params memory begin
lgrad, pu4 = pullback_layer!(pbl, x)
lgrad, _ = pullback_layer!(pbl, x)
end
if SA
NoTangent(),
_maybe_sarray(StrideArraysCore.StrideArray(lgrad, memory)),
_maybe_sarray(StrideArraysCore.StrideArray(grad, memory))
else
NoTangent(),
StrideArraysCore.StrideArray(lgrad, memory),
StrideArraysCore.StrideArray(grad, memory)
Member (on lines +61 to +64): test?
end
NoTangent(),
StrideArraysCore.StrideArray(lgrad, memory),
StrideArraysCore.StrideArray(grad, memory)
end


function unsafe_valgrad_pullback!(g, layers, params, memory::Vector{UInt8}, arg)
GC.@preserve g params memory begin
# @show pointer(g) pointer(params) pointer(memory)
l, pbl =
chain_valgrad_pullback!(pointer(g), arg, layers, pointer(params), pointer(memory))
@inline function (pb::PullBack)(x::StaticArrays.SArray)
@unpack pbl, grad, params, memory = pb
mx = StaticArrays.MArray(x);
GC.@preserve mx grad params memory begin
lgrad, _ = pullback_layer!(pbl, PtrArray(mx))
end
l, PullBack(pbl, g, params, memory)
NoTangent(),
_maybe_sarray(StrideArraysCore.StrideArray(lgrad, memory)),
_maybe_sarray(StrideArraysCore.StrideArray(grad, memory))
end

function chain_valgrad_pullback!(

@inline function chain_valgrad_pullback!(
pg,
arg,
layers::Tuple{X1,X2,Vararg},
@@ -60,75 +86,88 @@ function chain_valgrad_pullback!(
l = getfield(layers, 1)
pg2, larg, p2, pu2 = valgrad_layer!(pg, l, arg, p, pu)

# val, grad, pu3, pbl = chain_valgrad_pullback!(pg2, larg, Base.tail(layers), p2, pu2)
val, pbl = chain_valgrad_pullback!(pg2, larg, Base.tail(layers), p2, pu2)
pbl_ret = PullBackLayer(pg, l, arg, p, pu, pbl)
return val, pbl_ret
# lgrad, pu4 = pullback!(pg, l, grad, arg, p, pu, pu3)
# return val, lgrad, pu4
end
function chain_valgrad_pullback!(
@inline function chain_valgrad_pullback!(
pg,
arg,
layers::Tuple{X1},
p::Ptr,
pu::Ptr{UInt8},
) where {X1}
l = getfield(layers, 1)
pg2, val, p2, pu2 = valgrad_layer!(pg, l, arg, p, pu)
_, val, __, pu2 = valgrad_layer!(pg, l, arg, p, pu)

# val, grad, pu3, pbl = chain_valgrad!(pg2, larg, Base.tail(layers), p2, pu2)
# pu2 gets fed into eventual `pullback!` call
pbl_ret = PullBackLayer(pg, l, arg, p, pu, pu2)
return val, pbl_ret
# lgrad, pu4 = pullback!(pg, l, grad, arg, p, pu, pu3)
# return val, lgrad, pu4
end

# No loss: chain closures.
function _rrule(sc, arg, params, memory, ::False)
valgrad_noloss(sc, arg, params, memory)
function _rrule(sc, arg, params, ::False)
valgrad_noloss(sc, arg, params)
end
function valgrad_noloss(sc, arg::AbstractArray{S}, params::StaticArrays.SVector{T}) where {T,S}
mp = StaticArrays.MVector(params);
@gc_preserve valgrad_noloss(sc, arg, mp)
end
function valgrad_noloss(sc, arg, params::AbstractVector{T}, memory = sc.memory) where {T}
function valgrad_noloss(sc, arg::AbstractArray{S}, params::AbstractVector{T}) where {T,S}
c = getchain(sc)
@unpack layers = c
parg = maybe_static_size_arg(c.inputdim, arg)
arglen = length(parg)
barg = preserve_buffer(arg)
off = align(resize_memory!(layers, memory, parg, length(parg) * sizeof(eltype(parg))))
GC.@preserve barg memory begin
g = PtrArray(reinterpret(Ptr{T}, pointer(memory) + off), (static_length(params),))
l, pullback = unsafe_valgrad_pullback!(g, layers, params, memory, parg)
end
return l, pullback
# return StrideArraysCore.StrideArray(l, memory), pullback
end

glen = _try_static(numparam(sc), static_length(params))
goff = align(glen * static_sizeof(T))
aoff = align(arglen * static_sizeof(S))

num_bytes = required_bytes(Val{T}(), layers, size(parg), aoff + goff)
memory = get_heap_memory(sc, num_bytes)

# Loss: call `valgrad`.
function _rrule(sc, arg, params, memory, ::True)
l, g = valgrad(sc, arg, params, memory)
# assumes no grad w/ respect to arg
pullback = let g = g
l̄ -> begin
if !isone(l̄)
@turbo for i ∈ eachindex(g)
g[i] *= l̄
end
end
NoTangent(), NoTangent(), g
GC.@preserve barg params memory begin
pm = align(pointer(memory))
parg2 = PtrArray(Ptr{S}(pm), _try_static(c.inputdim, size(parg)))
@inbounds @simd ivdep for i in eachindex(parg)
parg2[i] = parg[i]
end
pm += aoff
g = PtrArray(Ptr{T}(pm), (glen,))
pm += goff
# @show pointer(g) pointer(params) pointer(memory)
l, pbl = chain_valgrad_pullback!(pointer(g), parg2, layers, pointer(params), pm)
end
if arg isa StaticArrays.SArray
_maybe_sarray(l), PullBack{true}(pbl, g, params, memory)
else
l, PullBack{true}(pbl, g, params, memory)
end
l, pullback
end

function ChainRulesCore.rrule(
sc::AbstractPenalty, arg, params, memory = task_local_memory()
)
_rrule(sc, arg, params, memory, True())
struct ElementwisePullback{G}
g::G
end
function ChainRulesCore.rrule(
sc::SimpleChain, arg, params, memory = task_local_memory()
)
_rrule(sc, arg, params, memory, has_loss_typed(sc))
#TODO: add support for getting gradient with respect to `arg`
function (ep::ElementwisePullback)(l̄)
g = ep.g
if !isone(l̄)
@turbo for i ∈ eachindex(g)
g[i] *= l̄
end
end
# assumes no grad w/ respect to arg
NoTangent(), NoTangent(), g
end
# Loss: call `valgrad`.
function _rrule(sc, arg, params, ::True)
l, g = valgrad(sc, arg, params)
l, ElementwisePullback(g)
end
# TODO: support penalties without returning scalars
_returns_scalar(::AbstractPenalty) = True()
_returns_scalar(sc::SimpleChain) = has_loss_typed(sc)

function ChainRulesCore.rrule(sc::Chain, arg, params)
_rrule(sc, arg, params, _returns_scalar(sc))
end
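
For reference, a hedged usage sketch of the scalar-loss path above under AD. The model, loss, and data are illustrative, not from this PR, and assume SimpleChains' exported API (SimpleChain, TurboDense, SquaredLoss, add_loss, init_params):

using SimpleChains, Zygote

model = SimpleChain(static(4), TurboDense(tanh, 8), TurboDense(identity, 2))
loss = SimpleChains.add_loss(model, SquaredLoss(rand(Float32, 2, 16)))

x = rand(Float32, 4, 16)
p = SimpleChains.init_params(loss)

# loss(x, p) returns a scalar, so _returns_scalar yields True() and the
# pullback Zygote receives is the ElementwisePullback defined above; note
# that no memory argument is threaded through the rrule anymore.
g = Zygote.gradient(p -> loss(x, p), p)[1]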
4 changes: 2 additions & 2 deletions src/conv.jl
@@ -792,9 +792,9 @@ function getparams(c::Conv, p::Ptr{T}, inputdim::Tuple{Vararg{Integer}}) where {
(K, b), p + sizeof(T) * length(b)
end

function layer_output_size(::Val{T}, c::Conv, inputdim::Tuple) where {T}
function forward_layer_output_size(::Val{T}, c::Conv, inputdim::Tuple) where {T}
_, outputdim = numparam(c, inputdim)
2align(static_sizeof(T) * prod(outputdim)), outputdim
align(static_sizeof(T) * prod(outputdim)), outputdim
end

function init_params!(c::Conv, p, inputdim)