Merge pull request #105 from PumasAI/fixsegfaults
promote types for calculating num_bytes
chriselrod authored Sep 7, 2022
2 parents 13d60f0 + abf9947 commit a735c44
Showing 8 changed files with 76 additions and 41 deletions.
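
The branch name and the change itself tell the story: scratch-buffer sizes were computed from a single element type (the parameter eltype), so mixing eltypes, e.g. Float32 parameters applied to Float64 inputs, could size the workspace for the smaller type while the intermediates are stored in the promoted one, and the resulting out-of-bounds writes showed up as segfaults. A minimal sketch of the sizing rule the commit adopts; scratch_bytes below is a hypothetical stand-in for SimpleChains' required_bytes:

# Size scratch memory using the promoted element type of parameters and input.
# `scratch_bytes` is illustrative only; the package's real helper is `required_bytes`.
scratch_bytes(nelems, params, input) =
    nelems * sizeof(promote_type(eltype(params), eltype(input)))

params = rand(Float32, 10)
x      = rand(Float64, 4, 3)
scratch_bytes(128, params, x)          # 128 * sizeof(Float64) = 1024 bytes
# Using eltype(params) alone would give 128 * sizeof(Float32) = 512 bytes,
# half of what the Float64 intermediates actually need.

Base.promote_eltype(params, x) expresses the same computation directly on the arrays; the training paths below switch to it.
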
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "SimpleChains"
uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.3.0"
version = "0.3.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1 change: 1 addition & 0 deletions src/SimpleChains.jl
@@ -68,6 +68,7 @@ export SimpleChain,
FrontLastPenalty

const Integer = Union{StaticInt,Base.Integer}
const MAXSTACK = 16384

include("memory.jl")
include("simple_chain.jl")
2 changes: 1 addition & 1 deletion src/chain_rules.jl
@@ -123,7 +123,7 @@ function valgrad_noloss(sc, arg::AbstractArray{S}, params::AbstractVector{T}) wh
goff = align(glen * static_sizeof(T))
aoff = align(arglen * static_sizeof(S))

num_bytes = required_bytes(Val{T}(), layers, size(parg), aoff + goff)
num_bytes = required_bytes(Val{promote_type(T,S)}(), layers, size(parg), aoff + goff)
memory = get_heap_memory(sc, num_bytes)

GC.@preserve barg params memory begin
19 changes: 10 additions & 9 deletions src/memory.jl
@@ -1,5 +1,4 @@

const MAXSTACK = 16384
_static_max_stack(::StaticInt{N}) where {N} = StaticInt{N}()
_static_max_stack(_) = StaticInt{MAXSTACK}()

@@ -13,16 +12,18 @@ function task_local_memory(sc)::Vector{UInt8}
end
)::Vector{UInt8}
end

@inline function with_stack_memory(f::F, ::StaticInt{N}, sc, args::Vararg{Any,K}) where {F,N,K}
stack_memory = Ref{NTuple{N,UInt8}}()
p = Base.unsafe_convert(Ptr{UInt8}, stack_memory)
ret = GC.@preserve stack_memory f(
sc,
align(p),
args...,
)
VectorizationBase.lifetime_end!(p, Val{N}())
# stack_memory = pointer(NOTSTACKMEM) + (Threads.threadid()-1)*MAXSTACK
GC.@preserve stack_memory begin
p = Base.unsafe_convert(Ptr{UInt8}, stack_memory)
ret = f(
sc,
align(p),
args...,
)
VectorizationBase.lifetime_end!(p, Val{N}())
end
return ret
end
function get_heap_memory(sc, num_bytes)
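
The with_stack_memory change above is about pointer lifetime rather than types: the Ref backing the stack buffer is now converted to a pointer and used entirely inside one GC.@preserve block, whereas previously the pointer was taken before the preserve expression and lifetime_end! ran after it. The MAXSTACK constant also moves up to SimpleChains.jl so that simple_chain.jl can consult it. A minimal sketch of the preserved-pointer pattern; with_scratch is a hypothetical stand-in for with_stack_memory:

# Keep every use of the raw pointer inside GC.@preserve so the Ref backing the
# buffer stays rooted for the pointer's whole lifetime.
function with_scratch(f::F, ::Val{N}) where {F,N}
    buf = Ref{NTuple{N,UInt8}}()                     # fixed-size scratch buffer
    GC.@preserve buf begin
        p = Ptr{UInt8}(Base.unsafe_convert(Ptr{NTuple{N,UInt8}}, buf))
        ret = f(p)                                   # pointer only lives inside @preserve
    end
    return ret
end

with_scratch(Val(64)) do p
    unsafe_store!(p, 0xff)                           # write into the scratch buffer
    unsafe_load(p)                                   # read it back => 0xff
end
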
23 changes: 13 additions & 10 deletions src/optimize.jl
@@ -375,12 +375,12 @@ Arguments:
"""
function train_unbatched!(
g,
p::AbstractVector{T},
p::AbstractVector,
_chn::Chain,
X,
opt::AbstractOptimizer,
t,
) where {T}
)
if g isa AbstractMatrix && size(g, 2) == 1
gpb = preserve_buffer(g)
gv = PtrArray(pointer(g), (length(p),))
@@ -392,6 +392,7 @@ function train_unbatched!(
pX = maybe_static_size_arg(chn.inputdim, X)
optoff = optmemsize(opt, p)
@unpack layers = chn
T = Base.promote_eltype(p, X)
bytes_per_thread, total_bytes =
required_bytes(Val{T}(), layers, size(pX), optoff, static(0), size(g, static(2)))
GC.@preserve X begin
@@ -411,20 +412,20 @@
end

function train_unbatched!(
p::AbstractVector{T},
p::AbstractVector,
_chn::Chain,
X::AbstractArray,
opt::AbstractOptimizer,
t,
) where {T}

)
chn = getchain(_chn)
pX = maybe_static_size_arg(chn.inputdim, X)
optoff = optmemsize(opt, p)
@unpack layers = chn
glen = _try_static(numparam(chn), static_length(params))
numthreads = _numthreads()

T = Base.promote_eltype(p, X)
bytes_per_thread, total_bytes = required_bytes(
Val{T}(),
layers,
@@ -576,31 +577,32 @@ function train_batched_core!(
c::Chain,
pu::Ptr{UInt8},
::Nothing,
p::AbstractVector{T},
p::AbstractVector,
pX,
opt::AbstractOptimizer,
iters,
leaveofflast::Bool,
mpt,
N_bs,
) where {T}
)
numthreads = _numthreads()
glen = _try_static(numparam(getchain(c)), static_length(p))
aligned_glen = align(glen)
T = Base.promote_eltype(p, pX)
g = _alloc_grad(Ptr{T}(pu), glen, numthreads, aligned_glen)
offset = static_sizeof(T) * aligned_glen * numthreads
train_batched_core!(c, pu + offset, g, p, pX, opt, iters, leaveofflast, mpt, N_bs)
end
function train_batched!(
g::Union{Nothing,AbstractVector{T},AbstractMatrix{T}},
p::AbstractVector{T},
g::Union{Nothing,AbstractVector,AbstractMatrix},
p::AbstractVector,
_chn::Chain,
X,
opt::AbstractOptimizer,
iters;
batchsize = nothing,
leaveofflast::Bool = false,
) where {T}
)
if g isa AbstractMatrix && size(g, 2) == 1
gpb = preserve_buffer(g)
gv = PtrArray(pointer(g), (length(p),))
@@ -637,6 +639,7 @@ function train_batched!(
else
base_mem = optoff + perm_mem
end
T = Base.promote_eltype(p, X)
mpt, total_bytes =
required_bytes(Val{T}(), layers, sxb, base_mem, shuffle_per_thread, nthread)
GC.@preserve X begin
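
Throughout train_unbatched! and train_batched! the element type used for workspace sizing is no longer taken from the parameter vector's type parameter; it is computed in the body as Base.promote_eltype(p, X), so buffers are sized for whatever element type the forward and backward passes actually produce when parameters and training data disagree. Base.promote_eltype simply promotes the element types of its array arguments:

# Base.promote_eltype maps eltype over its arguments and promotes the results.
Base.promote_eltype(rand(Float32, 10), rand(Float64, 4, 3))   # Float64
Base.promote_eltype(rand(Float32, 10), rand(Float32, 4, 3))   # Float32
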
63 changes: 46 additions & 17 deletions src/simple_chain.jl
@@ -134,34 +134,59 @@ function verify_arg(c, arg)
throw(ArgumentError("Input argument: !matches(chain_input_dims(c), size(arg))"))
end
end
struct SArrayOutput{T}
f::T
end
@inline function (f::SArrayOutput)(x::Vararg{Any,K}) where {K}
fx = f.f(x...)
_maybe_sarray(fx, size(fx))
end


function (c::SimpleChain)(arg::AbstractArray{T}, params) where {T}
function (c::SimpleChain)(
arg::AbstractArray{T0},
params::AbstractVector{T1}
) where {T0,T1}
verify_arg(c, arg)
@unpack layers = c
parg = maybe_static_size_arg(c.inputdim, arg)
num_bytes = required_forward_bytes(Val(T), layers, size(parg), static(0))
num_bytes = required_forward_bytes(Val(promote_type(T0,T1)), layers, size(parg), static(0))
if has_loss(c)
GC.@preserve arg params with_memory(_chain, c, num_bytes, parg, pointer(params))
else
GC.@preserve arg params begin
res, heap_memory = with_heap_memory(_chain, c, num_bytes, parg, pointer(params))
StrideArray(res, heap_memory)
lret = with_memory(_chain, c, num_bytes, parg, pointer(params))
end
return lret
else
ol = tsprod(outputdim(c, size(arg)))
if ol isa StaticInt && num_bytes isa StaticInt && ol < 64 && num_bytes <= MAXSTACK
GC.@preserve arg params begin
saret = with_stack_memory(
SArrayOutput(_chain), num_bytes, c, parg, pointer(params)
)
end
return saret
else
GC.@preserve arg params begin
res, heap_memory = with_heap_memory(_chain, c, num_bytes, parg, pointer(params))
sret = StrideArray(res, heap_memory)
end
return sret
end
end
end
@inline _maybe_sarray(x) = x
@inline _maybe_sarray(x::AbstractArray) = _maybe_sarray(x, size(x))
@inline _maybe_sarray(x::AbstractArray, _) = x
@inline _maybe_sarray(A::AbstractArray, s::Tuple{Vararg{StaticInt}}) = _to_sarray(A, s)
@generated function _marray_type(s::Tuple{Vararg{StaticInt}})
k = known(s)
t = Expr(:tuple)
ct = Expr(:curly, :Tuple)
for x in k
push!(ct.args, x)
end
:($StaticArrays.MArray{$ct})
end
@inline function _maybe_sarray(A::AbstractArray{T}, s::Tuple{Vararg{StaticInt}}) where {T}
@inline function _to_sarray(A::AbstractArray{T}, s::Tuple{Vararg{StaticInt}}) where {T}
B = _marray_type(s){T}(undef)
if T <: Base.HWReal
@turbo for i = eachindex(B)
@@ -203,13 +228,13 @@ end
mparams = StaticArrays.MArray(params)
@gc_preserve c(arg, mparams)
end
@inline function (c::SimpleChain)(arg::StaticArrays.SArray, params)
@inline function (c::SimpleChain)(arg::StaticArrays.SArray, params::AbstractVector)
verify_arg(c, arg)
@unpack layers = c
marg = StaticArrays.MArray(arg)
GC.@preserve marg params begin
parg = maybe_static_size_arg(c.inputdim, marg)
num_bytes = required_forward_bytes(Val(eltype(arg)), layers, size(parg), static(0))
num_bytes = required_forward_bytes(Val(Base.promote_eltype(arg, params)), layers, size(parg), static(0))
if has_loss(c)
with_memory(_chain, c, num_bytes, parg, pointer(params))
else
@@ -422,12 +447,14 @@ function valgrad!(memory::Ptr{UInt8}, g, c::SimpleChain, arg, params)
parg = maybe_static_size_arg(c.inputdim, arg)
GC.@preserve arg unsafe_valgrad!(c, memory, g, params, parg)
end
function valgrad!(g, c::SimpleChain, arg::AbstractArray{T}, params) where {T}
function valgrad!(
g, c::SimpleChain, arg::AbstractArray{T0}, params::AbstractVector{T1}
) where {T0, T1}
verify_arg(c, arg)
@assert has_loss(c)
@unpack layers = c
parg = maybe_static_size_arg(c.inputdim, arg)
num_bytes = required_bytes(Val{T}(), layers, size(parg), static(0))
num_bytes = required_bytes(Val{promote_type(T0,T1)}(), layers, size(parg), static(0))
GC.@preserve arg with_memory(unsafe_valgrad!, c, num_bytes, g, params, parg)
end

@@ -587,35 +614,37 @@ function valgrad_core_sarray(
)
return l, _maybe_sarray(g, (static(L),))
end
function valgrad(sc::Chain, arg, params::AbstractVector{T}) where {T}
function valgrad(sc::Chain, arg, params::AbstractVector{TP}) where {TP}
c = getchain(sc)
@unpack layers = c
parg = maybe_static_size_arg(c.inputdim, arg)
glen = _try_static(numparam(sc), static_length(params))
T = Base.promote_eltype(arg, params)
num_bytes = required_bytes(Val{T}(), layers, size(parg), glen * static_sizeof(T))
l, heap_memory = with_heap_memory(valgrad_core, sc, num_bytes, parg, params, glen)
gv = StrideArraysCore.StrideArray(
PtrArray(align(Ptr{T}(pointer(heap_memory))), (glen,)),
PtrArray(align(Ptr{TP}(pointer(heap_memory))), (glen,)),
heap_memory,
)
return l, gv
end
@inline function valgrad(
sc::Chain,
arg::StaticArrays.SArray,
params::AbstractVector{T},
) where {T}
params::AbstractVector{TP},
) where {TP}
c = getchain(sc)
@unpack layers = c
parg = maybe_static_size_arg(c.inputdim, arg)
glen = _try_static(numparam(sc), static_length(params))
T = Base.promote_eltype(arg, params)
num_bytes = required_bytes(Val{T}(), layers, size(parg), glen * static_sizeof(T))
if glen isa StaticInt
return with_memory(valgrad_core_sarray, sc, num_bytes, parg, params, glen)
else
l, heap_memory = with_heap_memory(valgrad_core, sc, num_bytes, parg, params, glen)
gv = StrideArraysCore.StrideArray(
PtrArray(Ptr{T}(pointer(heap_memory)), (glen,)),
PtrArray(Ptr{TP}(pointer(heap_memory)), (glen,)),
heap_memory,
)
return l, gv
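
Without a loss layer, the forward pass in simple_chain.jl now picks between two output paths: if the output length is a StaticInt smaller than 64 and the required workspace is a StaticInt no larger than MAXSTACK, it evaluates through with_stack_memory and SArrayOutput converts the result into a StaticArrays array; otherwise it keeps the previous behaviour of heap memory plus a StrideArray. A rough sketch of the static-output idea, using hypothetical helpers rather than SimpleChains internals:

using StaticArrays

# Rough sketch: when the output length N is known at compile time and small,
# evaluate into a fixed-size buffer and hand back an immutable SVector.
# `f` stands in for the chain's forward pass writing its output into `buf`.
function static_output(f, ::Val{N}, x) where {N}
    buf = MVector{N,Float64}(undef)      # fixed-size, mutable scratch
    f(buf, x)
    return SVector{N}(Tuple(buf))        # immutable result; no heap array escapes
end

static_output(Val(3), 2.0) do buf, x
    buf .= (x, 2x, 3x)                   # pretend forward pass
end
# => SVector{3,Float64}(2.0, 4.0, 6.0)

Returning an immutable static array lets small fixed-size outputs avoid heap allocation entirely, which is the point of routing them through the stack-memory path.
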
6 changes: 3 additions & 3 deletions test/mnist.jl
@@ -14,9 +14,9 @@ lenet = SimpleChain(
TurboDense(SimpleChains.relu, 84),
TurboDense(identity, 10),
)
# 3d and 0-indexed
xtrain3, ytrain0 = MLDatasets.MNIST.traindata(Float32);
xtest3, ytest0 = MLDatasets.MNIST.testdata(Float32);
# 3d and 0-indexed
xtrain3, ytrain0 = MLDatasets.MNIST(:train)[:]
xtest3, ytest0 = MLDatasets.MNIST(:test)[:]
xtrain4 = reshape(xtrain3, 28, 28, 1, :);
xtest4 = reshape(xtest3, 28, 28, 1, :);
ytrain1 = UInt32.(ytrain0 .+ 1);
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -268,6 +268,7 @@ InteractiveUtils.versioninfo(verbose=true)
ldd = tanh.(Add * x .+ bdd)
ldd_dd = tanh.(Add * xdd .+ bdd)
GC.@preserve pd pu begin
@test reinterpret(T, td(x, pointer(pd), pointer(pu))[1]) == reinterpret(T, SimpleChain(td)(x, pd))
@test reinterpret(T, ld) ≈ reinterpret(T, td(x, pointer(pd), pointer(pu))[1])
@test reinterpret(T, ld) ≈
reinterpret(T, td(permutedims(x)', pointer(pd), pointer(pu))[1])

2 comments on commit a735c44

@chriselrod
Contributor Author


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/67831

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.1 -m "<description of version>" a735c44c976aa699b6263629b4c21c227504ad08
git push origin v0.3.1
