Merge pull request #93 from PumasAI/staticarrays2staticarrays
StaticArrays in, StaticArrays out
chriselrod authored Jun 29, 2022
2 parents 5a493d5 + cefd7f2 commit 287adee
Showing 8 changed files with 73 additions and 49 deletions.
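
In short: a SimpleChain called on a StaticArrays input now returns a StaticArrays output, and valgrad returns an SVector gradient for SArray inputs (see the test/layer_tests.jl changes below). A minimal usage sketch — the chain definition here is hypothetical; only the SArray-in/SArray-out behavior comes from this commit:

  using SimpleChains, StaticArrays

  # Hypothetical two-layer chain with a statically sized input dimension.
  model = SimpleChain(
    static(4),
    TurboDense{true}(tanh, static(8)),
    TurboDense{true}(identity, static(2)),
  )
  p = SimpleChains.init_params(model)

  x = @SVector rand(Float32, 4)
  y = model(x, p)   # now an SArray rather than a heap-backed StrideArray

  loss = SimpleChains.add_loss(model, SquaredLoss(@SVector rand(Float32, 2)))
  l, g = SimpleChains.valgrad(loss, x, p)   # g is an SVector, too
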
2 changes: 2 additions & 0 deletions Project.toml
@@ -19,6 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4"
 SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
 Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da"
 UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
 VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
@@ -39,6 +40,7 @@ Polyester = "0.4, 0.5, 0.6"
 SIMDTypes = "0.1"
 SLEEFPirates = "0.6"
 Static = "0.7"
+StaticArrays = "1"
 StrideArraysCore = "0.3.5"
 UnPack = "1"
 VectorizationBase = "0.21.30"
14 changes: 5 additions & 9 deletions src/SimpleChains.jl
@@ -29,7 +29,7 @@ using SIMDTypes: Bit, NativeTypes
 using VectorizationBase: align, relu, stridedpointer, AbstractSIMD, NativeTypesV
 using HostCPUFeatures: static_sizeof, register_size, register_count, static_sizeof
 using CPUSummary: cache_linesize, num_threads, num_cores
-using LayoutPointers: bytestrideindex, stridedpointer, zero_offsets
+using LayoutPointers: bytestrideindex, stridedpointer, zero_offsets, val_dense_dims
 using Static: One, lt
 using CloseOpenIntervals: CloseOpen
 using StrideArraysCore: zview, @gc_preserve
@@ -84,14 +84,10 @@ include("optimize.jl")
 if VERSION >= v"1.7.0"
   if hasfield(Method, :recursion_relation)
     dont_limit = Returns(true)
-    for m in methods(chain_valgrad!)
-      m.recursion_relation = dont_limit
-    end
-    for m in methods(_chain)
-      m.recursion_relation = dont_limit
-    end
-    for m in methods(output_size)
-      m.recursion_relation = dont_limit
+    for f = (chain_valgrad!, _chain, output_size, _numparam)
+      for m in methods(f)
+        m.recursion_relation = dont_limit
+      end
     end
   end
 end
5 changes: 2 additions & 3 deletions src/conv.jl
@@ -793,9 +793,8 @@ function getparams(c::Conv, p::Ptr{T}, inputdim::Tuple{Vararg{Integer}}) where {
 end
 
 function layer_output_size(::Val{T}, c::Conv, inputdim::Tuple) where {T}
-  g1, outputdim = numparam(c, inputdim)
-  g2 = prod(outputdim)
-  align(static_sizeof(T) * g1) + 2align(static_sizeof(T) * g2), outputdim
+  _, outputdim = numparam(c, inputdim)
+  2align(static_sizeof(T) * prod(outputdim)), outputdim
 end
 
 function init_params!(c::Conv, p, inputdim)
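
Context for the layer_output_size change here (and the matching one in src/dense.jl below): per-layer scratch no longer reserves space for the layer's parameter gradient — the align(static_sizeof(T) * g1) term — since valgrad now requests that memory once, through resize_memory!'s additional argument. A rough sketch of the accounting, with a stand-in align that rounds up to a 64-byte boundary (the real one comes from VectorizationBase) and hypothetical sizes:

  align(x::Integer) = (x + 63) & ~63   # stand-in for VectorizationBase.align

  T = Float32
  g1 = 350       # hypothetical parameter count of the layer
  outlen = 100   # hypothetical prod(outputdim)

  old_scratch = align(sizeof(T) * g1) + 2align(sizeof(T) * outlen)  # 1408 + 896 = 2304
  new_scratch = 2align(sizeof(T) * outlen)                          # 896
  # The dropped align(sizeof(T) * g1) bytes are requested in valgrad instead:
  # resize_memory!(layers, memory, parg, glen * static_sizeof(T))
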
45 changes: 21 additions & 24 deletions src/dense.jl
@@ -54,13 +54,13 @@ function numparam(d::TurboDense, inputdim::Tuple)
   np, (d.outputdim, Base.tail(inputdim)...)
 end
 _numparam(d::TurboDense{false}, inputdim::Integer) = inputdim * d.outputdim
-_numparam(d::TurboDense{true}, inputdim::Integer) = (inputdim + 1) * d.outputdim
+_numparam(d::TurboDense{true}, inputdim::Integer) = inputdim * d.outputdim + d.outputdim
 parameter_free(::TurboDense) = false
 function layer_output_size(::Val{T}, td::TurboDense, inputdim::Tuple) where {T}
-  g1, outputdim = numparam(td, inputdim)
-  g2 = prod(outputdim)
-  align(static_sizeof(T) * g1) + 2align(static_sizeof(T) * g2), outputdim
+  _, outputdim = numparam(td, inputdim)
+  2align(static_sizeof(T) * prod(outputdim)), outputdim
 end
 
+fast_fuse(td::TurboDense) = fast_fuse(getfield(td, :f))
 
 function getparams(td::TurboDense{false}, p::Ptr{T}, inputdim::Integer) where {T}
@@ -153,7 +153,6 @@ function (td::TurboDense{O})(
   put = Base.unsafe_convert(Ptr{T}, pu)
   A, p = getparams(td, p, size(B, StaticInt(1)))
   C, _pu =
-    # alloc_return(td, size(pB, StaticInt(2)), put, contiguous_axis(B), stride_rank(B))
     alloc_return(
       td,
       size(pB, StaticInt(2)),
@@ -429,13 +428,13 @@ function dense!(
 end
 function dense!(
   ::typeof(relu),
-  ∂C::AbstractArray{Bool,N},
-  C::AbstractArray{T1,N},
+  ∂C::AbstractMatrix{Bool},
+  C::AbstractMatrix{T1},
   A::AbstractMatrix,
-  B::AbstractArray{T2,N},
+  B::AbstractMatrix{T2},
   ::True,
-) where {T1<:Base.HWReal,T2<:Base.HWReal,N}
-  Kp1 = ArrayInterface.size(A, StaticInt(2))
+) where {T1<:Base.HWReal,T2<:Base.HWReal}
+  Kp1 = size(A, StaticInt(2))
   K = Kp1 - StaticInt(1)
   @turbo for n ∈ indices((B, C), 2), m ∈ indices((A, C), 1)
     Cmn = zero(eltype(C))
@@ -456,29 +455,28 @@
   B::AbstractVector{T2},
   ::True,
 ) where {T1<:Base.HWReal,T2<:Base.HWReal}
-  Kp1 = ArrayInterface.size(A, StaticInt(2))
+  Kp1 = size(A, StaticInt(2))
   K = Kp1 - StaticInt(1)
-  n = StaticInt(1)
   @turbo for m ∈ indices((A, C), 1)
     Cmn = zero(eltype(C))
     for k ∈ 1:K
-      Cmn += A[m, k] * B[k, n]
+      Cmn += A[m, k] * B[k]
     end
     Cmnr = Cmn + A[m, Kp1]
     Cmnr_gt_0 = Cmnr > zero(Cmnr)
-    C[m, n] = ifelse(Cmnr_gt_0, Cmnr, zero(Cmnr))
-    ∂C[m, n] = Cmnr_gt_0
+    C[m] = ifelse(Cmnr_gt_0, Cmnr, zero(Cmnr))
+    ∂C[m] = Cmnr_gt_0
   end
 end
 
 function dense!(
   ::typeof(relu),
-  ∂C::AbstractArray{Bool,N},
-  C::AbstractArray{T1,N},
+  ∂C::AbstractMatrix{Bool},
+  C::AbstractMatrix{T1},
   A::AbstractMatrix,
-  B::AbstractArray{T2,N},
+  B::AbstractMatrix{T2},
   ::False,
-) where {T1<:Base.HWReal,T2<:Base.HWReal,N}
+) where {T1<:Base.HWReal,T2<:Base.HWReal}
   @turbo for n ∈ indices((B, C), 2), m ∈ indices((A, C), 1)
     Cmn = zero(eltype(C))
     for k ∈ indices((A, B), (2, 1))
@@ -498,15 +496,14 @@
   ::False,
 ) where {T1<:Base.HWReal,T2<:Base.HWReal}
   K = ArrayInterface.size(A, StaticInt(2))
-  n = StaticInt(1)
   @turbo for m ∈ indices((A, C), 1)
     Cmn = zero(eltype(C))
     for k ∈ 1:K
-      Cmn += A[m, k] * B[k, n]
+      Cmn += A[m, k] * B[k]
     end
     Cmn_gt_0 = Cmn > zero(Cmn)
-    C[m, n] = ifelse(Cmn_gt_0, Cmn, zero(Cmn))
-    ∂C[m, n] = Cmn_gt_0
+    C[m] = ifelse(Cmn_gt_0, Cmn, zero(Cmn))
+    ∂C[m] = Cmn_gt_0
   end
 end
 function dense!(
@@ -874,7 +871,7 @@ alloc_return_B_dense(B::PtrArray, pu::Ptr{UInt8}, _) = (B, pu) # assume `PtrArra
 function alloc_return_B_dense(B::AbstractArray{T}, pu::Ptr{UInt8}, input_dim) where {T}
   si = bytestrideindex(B)
   sp = stridedpointer(reinterpret(Ptr{T}, pu), si)
-  B̄ = PtrArray(sp, (input_dim, size(B, static(2))), StrideArraysCore.val_dense_dims(B))
+  B̄ = PtrArray(sp, (input_dim, size(B, static(2))), val_dense_dims(B))
   B̄, pu + align(length(B̄) * sizeof(T))
 end
 function pullback!(
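
For orientation: the relu dense! methods above all compute the same thing, and the specializations only differ in whether B (and C, ∂C) are matrices or vectors, and in whether A carries a trailing bias column (the ::True/::False argument — which is why Kp1 = size(A, 2) and K = Kp1 - 1). A plain-Julia sketch of the semantics, not the @turbo implementation, with a hypothetical helper name:

  # Reference semantics for the biased relu kernels.
  function dense_relu_reference(A::AbstractMatrix, B::AbstractVecOrMat)
    K = size(A, 2) - 1              # last column of A holds the bias
    C = A[:, 1:K] * B .+ A[:, K+1]  # affine map with the bias folded into A
    ∂C = C .> 0                     # mask saved for the pullback
    return ifelse.(∂C, C, zero.(C)), ∂C
  end
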
2 changes: 1 addition & 1 deletion src/dropout.jl
@@ -24,7 +24,7 @@ end
 
 
 gradval(::Val{T}, d::Dropout) where {T} = T(0xffffffff) / (T(0xffffffff) - d.p)
-numparam(::Dropout, id) = 0, id
+numparam(::Dropout, id) = static(0), id
 parameter_free(::Dropout) = true
 
 init_params!(::Dropout, p, id) = p, id
2 changes: 1 addition & 1 deletion src/loss.jl
@@ -33,7 +33,7 @@ _iterate_over_losses(_) = false
 iterate_over_losses(sc) = _iterate_over_losses(target(sc))
 
 parameter_free(::AbstractLoss) = true
-numparam(::AbstractLoss, _) = 0, 1
+numparam(::AbstractLoss, _) = static(0), 1
 function _layer_output_size_needs_temp(
   ::Val{T},
   sl::AbstractLoss{<:AbstractArray{<:AbstractArray}},
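
Why static(0) here and in src/dropout.jl: numparam totals are accumulated starting from _numparam(static(0), ...) (see src/simple_chain.jl below), and keeping every summand a StaticInt keeps the chain's total parameter count in the type domain — which is what lets init_params allocate a statically sized StrideArray and valgrad return a statically sized gradient. A small sketch of the distinction using Static.jl (the counts are made up):

  using Static

  static(0) + static(12) + static(114)   # static(126): length known at compile time
  0 + static(12) + static(114)           # 126::Int — one dynamic term loses the static size
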
49 changes: 38 additions & 11 deletions src/simple_chain.jl
@@ -80,7 +80,7 @@ Base.vcat(c::SimpleChain, l) = SimpleChain(chain_input_dims(c), (c.layers..., l)
 
 function numparam(c::SimpleChain, id = nothing)
   _id = chain_input_dims(c, id)
-  _numparam(0, c.layers, _id)
+  _numparam(static(0), c.layers, _id)
 end
 _numparam(s, ::Tuple{}, _) = s
 function _numparam(s, layers::Tuple{L,Vararg}, id) where {L}
@@ -113,7 +113,7 @@ end
   additional_per_thread,
   nthread,
 ) where {T}
-  base_mem_per_thread = 2output_size(Val(T), layers, sx) + additional_per_thread
+  base_mem_per_thread = output_size(Val(T), layers, sx) + additional_per_thread
   mem_total = additional + base_mem_per_thread * nthread
   if mem_total > length(memory)
     empty!(memory)
@@ -178,6 +178,28 @@ function (c::SimpleChain)(arg, params, memory = task_local_memory())
   resize_memory!(layers, memory, parg)
   GC.@preserve arg unsafe_chain(layers, params, memory, parg)
 end
+using StaticArrays
+@inline _maybe_sarray(x) = x
+@inline _maybe_sarray(x::AbstractArray) = _maybe_sarray(x, size(x))
+@inline _maybe_sarray(x::AbstractArray, _) = x
+@generated function _maybe_sarray(A::AbstractArray, s::Tuple{Vararg{StaticInt}})
+  k = known(s)
+  t = Expr(:tuple)
+  ct = Expr(:curly, :Tuple)
+  for x in k
+    push!(ct.args, x)
+  end
+  for i = 1:prod(k)::Int
+    push!(t.args, :(unsafe_load(p, $i)))
+  end
+  Expr(:block, Expr(:meta, :inline), :(p = pointer(A)), :(GC.@preserve A SArray{$ct}($t)))
+end
+function (c::SimpleChain)(arg::SArray, params, memory = task_local_memory())
+  marg = MArray(arg)
+  GC.@preserve marg begin
+    _maybe_sarray(c(PtrArray(marg), params, memory))
+  end
+end
 @inline function unsafe_chain(layers, params, memory::Vector{UInt8}, arg)
   GC.@preserve params memory _chain(arg, layers, pointer(params), pointer(memory))
 end
@@ -186,7 +208,9 @@ end
 @inline function (output_size(::Val{T}, x::Tuple{X}, s1)::Int) where {T,X}
   first(layer_output_size(Val{T}(), getfield(x, 1), s1))
 end
-@inline function (output_size(::Val{T}, x::Tuple{X1,X2,Vararg}, s1::Tuple)::Int) where {T,X1,X2}
+@inline function (
+  output_size(::Val{T}, x::Tuple{X1,X2,Vararg}, s1::Tuple)::Int
+) where {T,X1,X2}
   b, s2 = layer_output_size(Val{T}(), getfield(x, 1), s1)
   b + output_size(Val{T}(), Base.tail(x), s2)
 end
@@ -258,7 +282,7 @@ function init_params(
   ::Type{T} = Float32,
 ) where {T}
   _id = chain_input_dims(Λ, id)
-  init_params!(Λ, Vector{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id))
+  init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id))
 end
 """
     SimpleChains.init_params(chn[, id = nothing][, ::Type{T} = Float32])
@@ -376,22 +400,25 @@ function chain_valgrad!(pg, arg, layers::Tuple{X}, p::Ptr, pu::Ptr{UInt8}) where
   return val, lgrad, pu3
 end
 @inline getchain(sc::SimpleChain) = sc
-function valgrad(
-  sc, arg, params::AbstractVector{T},
-  memory = task_local_memory()
-) where {T}
+function valgrad(sc, arg, params::AbstractVector{T}, memory = task_local_memory()) where {T}
   c = getchain(sc)
   @unpack layers = c
   parg = maybe_static_size_arg(c.inputdim, arg)
-  off = align(resize_memory!(layers, memory, parg))
+  glen = _try_static(numparam(sc), static_length(params))
+  off = align(resize_memory!(layers, memory, parg, glen*static_sizeof(T)))
   GC.@preserve memory arg begin
-    g = PtrArray(reinterpret(Ptr{T}, pointer(memory) + off), (static_length(params),))
+    g = PtrArray(Ptr{T}(pointer(memory) + off), (glen,))
     l = Base.FastMath.add_fast(
       unsafe_valgrad!(g, layers, params, memory, parg),
       apply_penalty!(g, getpenalty(sc), params, size(parg)),
     )
   end
-  return l, StrideArraysCore.StrideArray(g, memory)
+  gv = StrideArraysCore.StrideArray(g, memory)
+  if arg isa SArray
+    return l, _maybe_sarray(gv)
+  else
+    return l, gv
+  end
 end
 
 isstochastic(_) = false
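
To make the @generated _maybe_sarray above concrete: when the size tuple is fully static, the generated body unrolls one unsafe_load per element and wraps the tuple in an SArray. For a hypothetical s = (static(2), static(3)) it expands to roughly:

  # Illustrative expansion for a 2×3 array (not literal @macroexpand output):
  p = pointer(A)
  GC.@preserve A SArray{Tuple{2,3}}((
    unsafe_load(p, 1), unsafe_load(p, 2), unsafe_load(p, 3),
    unsafe_load(p, 4), unsafe_load(p, 5), unsafe_load(p, 6),
  ))

The arg::SArray entry point pairs with this: the input SArray is copied into an MArray so the chain can write through a PtrArray view, and _maybe_sarray loads the result back out, so in typical non-escaping use the round trip avoids heap-allocating the input or output arrays.
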
3 changes: 3 additions & 0 deletions test/layer_tests.jl
@@ -18,6 +18,9 @@ sc = SimpleChain(
 p = SimpleChains.init_params(sc)
 
 @test SimpleChains.remove_loss(sc)(x, p) isa AbstractVector
+using SimpleChains.StaticArrays
+@test @inferred(SimpleChains.remove_loss(sc)(SVector{5}(x), p)) isa SVector{2,Float64}
+@test @inferred(SimpleChains.valgrad(sc, SVector{5}(x), p)) isa Tuple{Float64,SVector{126,Float32}}
 
 g = similar(p);
 g2 = similar(g);