diff --git a/julia/NEWS.md b/julia/NEWS.md
index 71ee86ff7da4..3da119496fac 100644
--- a/julia/NEWS.md
+++ b/julia/NEWS.md
@@ -1,11 +1,10 @@
-# v0.4.0 (#TBD)
+# v1.5.0 (#TBD)
 
 * Following material from `mx` module got exported (#TBD):
     * `NDArray`
     * `clip()`
     * `clip!()`
     * `context()`
-    * `empty()`
     * `expand_dims()`
     * `@inplace`
     * `σ()`
@@ -113,6 +112,16 @@
   3.0
   ```
 
+* `mx.empty` is deprecated and replaced by the `UndefInitializer` constructor. (#TBD)
+
+  E.g.
+  ```julia
+  julia> NDArray(undef, 2, 5)
+  2×5 NDArray{Float32,2} @ CPU0:
+  -21260.344f0  1.674986f19    0.00016893122f0  1.8363f-41  0.0f0
+  3.0763f-41    1.14321726f27  4.24219f-8       0.0f0       0.0f0
+  ```
+
 * A port of Python's `autograd` for `NDArray` (#274)
 
 * `size(x, dims...)` is supported now. (#TBD)
diff --git a/julia/docs/src/user-guide/overview.md b/julia/docs/src/user-guide/overview.md
index a81d7ff30e9e..5815bc6d772c 100644
--- a/julia/docs/src/user-guide/overview.md
+++ b/julia/docs/src/user-guide/overview.md
@@ -73,9 +73,13 @@ operators in Julia directly.
 
 The followings are common ways to create `NDArray` objects:
 
-- `mx.empty(shape[, context])`: create on uninitialized array of a
-  given shape on a specific device. For example,
-  `mx.empty(2, 3)`, `mx.((2, 3), mx.gpu(2))`.
+- `NDArray(undef, shape...; ctx = context, writable = true)`:
+  create an uninitialized array of a given shape on a specific device.
+  For example,
+  `NDArray(undef, 2, 3)`, `NDArray(undef, 2, 3, ctx = mx.gpu(2))`.
+- `NDArray(undef, shape; ctx = context, writable = true)`: the same as above, with the shape given as a tuple.
+- `NDArray{T}(undef, shape...; ctx = context, writable = true)`:
+  create an uninitialized array with the given element type `T`.
 - `mx.zeros(shape[, context])` and `mx.ones(shape[, context])`:
   similar to the Julia's built-in `zeros` and `ones`.
 - `mx.copy(jl_arr, context)`: copy the contents of a Julia `Array` to
@@ -101,11 +105,11 @@ shows a way to set the contents of an `NDArray`.
 ```@repl
 using MXNet
 mx.srand(42)
-a = mx.empty(2, 3)
+a = NDArray(undef, 2, 3)
 a[:] = 0.5 # set all elements to a scalar
 a[:] = rand(size(a)) # set contents with a Julia Array
 copy!(a, rand(size(a))) # set value by copying a Julia Array
-b = mx.empty(size(a))
+b = NDArray(undef, size(a))
 b[:] = a # copying and assignment between NDArrays
 ```
 
@@ -175,7 +179,7 @@ function inplace_op()
   grad = mx.ones(SHAPE, CTX)
 
   # pre-allocate temp objects
-  grad_lr = mx.empty(SHAPE, CTX)
+  grad_lr = NDArray(undef, SHAPE, ctx = CTX)
 
   for i = 1:N_REP
     copy!(grad_lr, grad)
@@ -234,7 +238,7 @@ shape = (2, 3)
 key = 3
 
 mx.init!(kv, key, mx.ones(shape) * 2)
-a = mx.empty(shape)
+a = NDArray(undef, shape)
 mx.pull!(kv, key, a) # pull value into a
 a
 ```
diff --git a/julia/src/MXNet.jl b/julia/src/MXNet.jl
index febd80cc8f8c..68663d1e561e 100644
--- a/julia/src/MXNet.jl
+++ b/julia/src/MXNet.jl
@@ -53,7 +53,6 @@ export NDArray,
        clip,
        clip!,
        context,
-       empty,
       expand_dims,
        @inplace,
        # activation funcs
diff --git a/julia/src/deprecated.jl b/julia/src/deprecated.jl
index 32819810eb8d..70079b8dcd62 100644
--- a/julia/src/deprecated.jl
+++ b/julia/src/deprecated.jl
@@ -169,3 +169,28 @@ import Base: sum, maximum, minimum, prod, cat
 import Statistics: mean
 
 @deprecate mean(x::NDArray, dims) mean(x, dims = dims)
+
+# replaced by UndefInitializer
+function empty(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType}
+  @warn("`mx.empty(T, dims, ctx)` is deprecated, " *
+        "use `NDArray{T,N}(undef, dims; ctx = ctx)` instead.")
+  NDArray{T,N}(undef, dims; ctx = ctx)
+end
+
+function empty(::Type{T}, dims::Int...) where {T<:DType}
+  @warn("`mx.empty(T, dims...)` is deprecated, " *
+        "use `NDArray{T}(undef, dims...)` instead.")
+  NDArray{T}(undef, dims...)
+end
+
+function empty(dims::NTuple{N,Int}, ctx::Context = cpu()) where N
+  @warn("`mx.empty(dims, ctx)` is deprecated, " *
+        "use `NDArray(undef, dims; ctx = ctx)` instead.")
+  NDArray(undef, dims; ctx = ctx)
+end
+
+function empty(dims::Int...)
+  @warn("`mx.empty(dims...)` is deprecated, " *
+        "use `NDArray(undef, dims...)` instead.")
+  NDArray(undef, dims...)
+end
diff --git a/julia/src/io.jl b/julia/src/io.jl
index 32f7fece7e41..6309f7ecd3f9 100644
--- a/julia/src/io.jl
+++ b/julia/src/io.jl
@@ -360,7 +360,7 @@ function ArrayDataProvider(data, label; batch_size::Int = 0, shuffle::Bool = fal
   function gen_batch_nds(arrs :: Vector{Array{MX_float}}, bsize :: Int)
     map(arrs) do arr
       shape = size(arr)
-      empty(shape[1:end-1]..., bsize)
+      NDArray(undef, shape[1:end-1]..., bsize)
     end
   end
diff --git a/julia/src/kvstore.jl b/julia/src/kvstore.jl
index 000684d5f20d..1fb6df20d27d 100644
--- a/julia/src/kvstore.jl
+++ b/julia/src/kvstore.jl
@@ -128,7 +128,7 @@ One can use ``barrier()`` to sync all workers.
 julia> kv = KVStore(:local)
 mx.KVStore @ local
 
-julia> x = mx.empty(2, 3);
+julia> x = NDArray(undef, 2, 3);
 
 julia> init!(kv, 3, x)
 
@@ -161,11 +161,11 @@ julia> x
 ```jldoctest
 julia> keys = [4, 5];
 
-julia> init!(kv, keys, [empty(2, 3), empty(2, 3)])
+julia> init!(kv, keys, [NDArray(undef, 2, 3), NDArray(undef, 2, 3)])
 
 julia> push!(kv, keys, [x, x])
 
-julia> y, z = empty(2, 3), empty(2, 3);
+julia> y, z = NDArray(undef, 2, 3), NDArray(undef, 2, 3);
 
 julia> pull!(kv, keys, [y, z])
 ```
 
@@ -279,7 +279,7 @@ julia> init!(kv, 42, mx.ones(2, 3))
 
 julia> push!(kv, 42, mx.ones(2, 3))
 
-julia> x = empty(2, 3);
+julia> x = NDArray(undef, 2, 3);
 
 julia> pull!(kv, 42, x)
diff --git a/julia/src/model.jl b/julia/src/model.jl
index cb5f95e3c1eb..0324edd1cdc6 100644
--- a/julia/src/model.jl
+++ b/julia/src/model.jl
@@ -122,7 +122,7 @@ function init_model(self::FeedForward, initializer::AbstractInitializer; overwri
         delete!(self.arg_params, name)
       end
     end
-    arg_params[name] = empty(shape)
+    arg_params[name] = NDArray(undef, shape)
   end
 
   for (name, shape) in zip(aux_names, aux_shapes)
@@ -135,7 +135,7 @@ function init_model(self::FeedForward, initializer::AbstractInitializer; overwri
         delete!(self.aux_params, name)
       end
     end
-    aux_params[name] = empty(shape)
+    aux_params[name] = NDArray(undef, shape)
   end
 
   for (k,v) in arg_params
@@ -463,8 +463,8 @@ function fit(self::FeedForward, optimizer::AbstractOptimizer, data::AbstractData
   # set up output and labels in CPU for evaluation metric
   output_shapes = [tuple(size(x)[1:end-1]...,batch_size) for x in train_execs[1].outputs]
   cpu_dev = Context(CPU)
-  cpu_output_arrays = [empty(shape, cpu_dev) for shape in output_shapes]
-  cpu_label_arrays = [empty(shape, cpu_dev) for (name,shape) in provide_label(data)]
+  cpu_output_arrays = [NDArray(undef, shape, ctx = cpu_dev) for shape in output_shapes]
+  cpu_label_arrays = [NDArray(undef, shape, ctx = cpu_dev) for (name,shape) in provide_label(data)]
 
   # invoke callbacks on epoch 0
   _invoke_callbacks(self, opts.callbacks, op_state, AbstractEpochCallback)
diff --git a/julia/src/ndarray.jl b/julia/src/ndarray.jl
index 6987d572ea7a..52c6dd2fc42e 100644
--- a/julia/src/ndarray.jl
+++ b/julia/src/ndarray.jl
@@ -110,13 +110,33 @@ mutable struct NDArray{T,N}
   handle   :: MX_NDArrayHandle
   writable :: Bool
 
-  NDArray{T,N}(handle, writable = true) where {T,N} = new(handle, writable)
+  NDArray{T,N}(handle::MX_NDArrayHandle,
+               writable::Bool = true) where {T,N} =
+    new(handle, writable)
 end
 
+# UndefInitializer constructors
+NDArray{T,N}(::UndefInitializer, dims::NTuple{N,Integer};
+             writable = true, ctx::Context = cpu()) where {T,N} =
+  NDArray{T,N}(_ndarray_alloc(T, dims, ctx, false), writable)
+NDArray{T,N}(::UndefInitializer, dims::Vararg{Integer,N}; kw...) where {T,N} =
+  NDArray{T,N}(undef, dims; kw...)
+
+NDArray{T}(::UndefInitializer, dims::NTuple{N,Integer}; kw...) where {T,N} =
+  NDArray{T,N}(undef, dims; kw...)
+NDArray{T}(::UndefInitializer, dims::Vararg{Integer,N}; kw...) where {T,N} =
+  NDArray{T,N}(undef, dims; kw...)
+
+NDArray(::UndefInitializer, dims::NTuple{N,Integer}; kw...) where {N} =
+  NDArray{DEFAULT_DTYPE,N}(undef, dims; kw...)
+NDArray(::UndefInitializer, dims::Vararg{Integer,N}; kw...) where {N} =
+  NDArray{DEFAULT_DTYPE,N}(undef, dims; kw...)
+
 NDArray(x::AbstractArray{<:DType}) = copy(collect(x), cpu())
 NDArray(x::Array{<:DType})         = copy(x, cpu())
+
 NDArray(::Type{T}, x::AbstractArray) where {T<:DType} =
   copy(convert(AbstractArray{T}, x), cpu())
+
 NDArray(handle, writable = true) =
   NDArray{eltype(handle), ndims(handle)}(handle, writable)
 
@@ -124,6 +144,13 @@ NDArray(handle, writable = true) =
 const NDArrayOrReal = Union{NDArray,Real}
 const VecOfNDArray = AbstractVector{<:NDArray}
 
+Base.unsafe_convert(::Type{MX_handle}, x::NDArray) =
+  Base.unsafe_convert(MX_handle, x.handle)
+Base.convert(T::Type{MX_handle}, x::NDArray) = Base.unsafe_convert(T, x)
+Base.cconvert(T::Type{MX_handle}, x::NDArray) = Base.unsafe_convert(T, x)
+
+MX_handle(x::NDArray) = Base.convert(MX_handle, x)
+
 function Base.show(io::IO, x::NDArray)
   print(io, "NDArray(")
   Base.show(io, try_get_shared(x, sync = :read))
@@ -139,13 +166,6 @@ function Base.show(io::IO, ::MIME{Symbol("text/plain")}, x::NDArray{T,N}) where
   Base.print_array(io, try_get_shared(x, sync = :read))
 end
 
-Base.unsafe_convert(::Type{MX_handle}, x::NDArray) =
-  Base.unsafe_convert(MX_handle, x.handle)
-Base.convert(T::Type{MX_handle}, x::NDArray) = Base.unsafe_convert(T, x)
-Base.cconvert(T::Type{MX_handle}, x::NDArray) = Base.unsafe_convert(T, x)
-
-MX_handle(x::NDArray) = Base.convert(MX_handle, x)
-
 ################################################################################
 # NDArray functions exported to the users
 ################################################################################
@@ -163,34 +183,14 @@ function context(x::NDArray)
 end
 
 """
-    empty(DType, dims[, ctx::Context = cpu()])
-    empty(DType, dims)
-    empty(DType, dim1, dim2, ...)
-
-Allocate memory for an uninitialized `NDArray` with a specified type.
-"""
-empty(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType} =
-  NDArray{T,N}(_ndarray_alloc(T, dims, ctx, false))
-empty(::Type{T}, dims::Int...) where {T<:DType} = empty(T, dims)
-
-"""
-    empty(dims::Tuple[, ctx::Context = cpu()])
-    empty(dim1, dim2, ...)
-
-Allocate memory for an uninitialized `NDArray` with specific shape of type Float32.
-"""
-empty(dims::NTuple{N,Int}, ctx::Context = cpu()) where N =
-  NDArray(_ndarray_alloc(dims, ctx, false))
-empty(dims::Int...) = empty(dims)
-
-"""
-    similar(x::NDArray)
+    similar(x::NDArray; writable, ctx)
 
 Create an `NDArray` with similar shape, data type, and context with the given
 one. Note that the returned `NDArray` is uninitialized.
""" -Base.similar(x::NDArray{T}) where {T} = empty(T, size(x), context(x)) +Base.similar(x::NDArray{T,N}; writable = x.writable, ctx = context(x)) where {T,N} = + NDArray{T,N}(undef, size(x)...; writable = writable, ctx = ctx) """ zeros([DType], dims, [ctx::Context = cpu()]) @@ -200,7 +200,7 @@ Base.similar(x::NDArray{T}) where {T} = empty(T, size(x), context(x)) Create zero-ed `NDArray` with specific shape and type. """ function zeros(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType} - x = empty(T, dims, ctx) + x = NDArray{T}(undef, dims..., ctx = ctx) x[:] = zero(T) x end @@ -222,7 +222,7 @@ Base.zeros(x::NDArray)::typeof(x) = zeros_like(x) Create an `NDArray` with specific shape & type, and initialize with 1. """ function ones(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType} - arr = empty(T, dims, ctx) + arr = NDArray{T}(undef, dims..., ctx = ctx) arr[:] = one(T) arr end @@ -504,10 +504,10 @@ copy(x::NDArray{T,D}, ctx::Context) where {T,D} = # Create copy: Julia Array -> NDArray in a given context copy(x::Array{T}, ctx::Context) where {T<:DType} = - copy!(empty(T, size(x), ctx), x) + copy!(NDArray{T}(undef, size(x); ctx = ctx), x) copy(x::AbstractArray, ctx::Context) = - copy!(empty(eltype(x), size(x), ctx), collect(x)) + copy!(NDArray{eltype(x)}(undef, size(x); ctx = ctx), collect(x)) """ convert(::Type{Array{<:Real}}, x::NDArray) @@ -866,8 +866,8 @@ end Create an `NDArray` filled with the value `x`, like `Base.fill`. """ -function fill(x, dims::NTuple{N,Integer}, ctx::Context=cpu()) where N - arr = empty(typeof(x), dims, ctx) +function fill(x::T, dims::NTuple{N,Integer}, ctx::Context = cpu()) where {T,N} + arr = NDArray{T}(undef, dims, ctx = ctx) arr[:] = x arr end diff --git a/julia/src/random.jl b/julia/src/random.jl index e18e906a5307..3f3b80bbab4a 100644 --- a/julia/src/random.jl +++ b/julia/src/random.jl @@ -23,12 +23,12 @@ Samples are uniformly distributed over the half-open interval [low, high) (includes low, but excludes high). ```julia -julia> mx.rand!(empty(2, 3)) +julia> mx.rand!(NDArray(undef, 2, 3)) 2×3 mx.NDArray{Float32,2} @ CPU0: 0.385748 0.839275 0.444536 0.0879585 0.215928 0.104636 -julia> mx.rand!(empty(2, 3), low = 1, high = 10) +julia> mx.rand!(NDArray(undef, 2, 3), low = 1, high = 10) 2×3 mx.NDArray{Float32,2} @ CPU0: 6.6385 4.18888 2.07505 8.97283 2.5636 1.95586 @@ -56,8 +56,8 @@ julia> mx.rand(2, 2; low = 1, high = 10) 9.81258 3.58068 ``` """ -rand(dims::Int...; low = 0, high = 1, context = cpu()) = - rand!(empty(dims, context), low = low, high = high) +rand(dims::Integer...; low = 0, high = 1, context = cpu()) = + rand!(NDArray(undef, dims, ctx = context), low = low, high = high) """ randn!(x::NDArray; μ = 0, σ = 1) @@ -73,7 +73,7 @@ randn!(x::NDArray; μ = 0, σ = 1) = Draw random samples from a normal (Gaussian) distribution. 
""" randn(dims::Int...; μ = 0, σ = 1, context = cpu()) = - randn!(empty(dims, context), μ = μ, σ = σ) + randn!(NDArray(undef, dims, ctx = context), μ = μ, σ = σ) """ seed!(seed::Int) diff --git a/julia/test/unittest/bind.jl b/julia/test/unittest/bind.jl index abaca884bab1..0ae0ab427b99 100644 --- a/julia/test/unittest/bind.jl +++ b/julia/test/unittest/bind.jl @@ -33,10 +33,10 @@ function test_arithmetic(::Type{T}, uf, gf) where T <: mx.DType ret = uf(lhs, rhs) @test mx.list_arguments(ret) == [:lhs, :rhs] - lhs_arr = mx.NDArray(rand(T, shape)) - rhs_arr = mx.NDArray(rand(T, shape)) - lhs_grad = mx.empty(T, shape) - rhs_grad = mx.empty(T, shape) + lhs_arr = NDArray(rand(T, shape)) + rhs_arr = NDArray(rand(T, shape)) + lhs_grad = NDArray{T}(undef, shape) + rhs_grad = NDArray{T}(undef, shape) exec2 = mx.bind(ret, mx.Context(mx.CPU), [lhs_arr, rhs_arr], args_grad=[lhs_grad, rhs_grad]) exec3 = mx.bind(ret, mx.Context(mx.CPU), [lhs_arr, rhs_arr]) diff --git a/julia/test/unittest/io.jl b/julia/test/unittest/io.jl index cf8d8368d212..7d98d28fc541 100644 --- a/julia/test/unittest/io.jl +++ b/julia/test/unittest/io.jl @@ -38,8 +38,8 @@ function test_mnist() n_batch = 0 for batch in mnist_provider if n_batch == 0 - data_array = mx.empty(28,28,1,batch_size) - label_array = mx.empty(batch_size) + data_array = NDArray(undef, 28, 28, 1, batch_size) + label_array = NDArray(undef, batch_size) # have to use "for i=1:1" to get over the legacy "feature" of using # [ ] to do concatenation in Julia data_targets = [[(1:batch_size, data_array)] for i = 1:1] diff --git a/julia/test/unittest/kvstore.jl b/julia/test/unittest/kvstore.jl index 503a1fdbd533..db6885717edc 100644 --- a/julia/test/unittest/kvstore.jl +++ b/julia/test/unittest/kvstore.jl @@ -47,7 +47,7 @@ function test_single_kv_pair() kv = init_kv() mx.push!(kv, 3, mx.ones(SHAPE)) - val = mx.empty(SHAPE) + val = NDArray(undef, SHAPE) mx.pull!(kv, 3, val) @test maximum(abs.(copy(val) .- 1)) == 0 end diff --git a/julia/test/unittest/ndarray.jl b/julia/test/unittest/ndarray.jl index 85328ff21bc8..eb69a736a6e4 100644 --- a/julia/test/unittest/ndarray.jl +++ b/julia/test/unittest/ndarray.jl @@ -57,6 +57,78 @@ function test_constructor() @test eltype(x) == Float32 @test copy(x) ≈ [1.1, 2, 3] end + + @info "NDArray::NDArray{T,N}(undef, dims...)" + let + x = NDArray{Int,2}(undef, 5, 5) + @test eltype(x) == Int + @test size(x) == (5, 5) + @test x.writable + + y = NDArray{Int,2}(undef, 5, 5, writable = false) + @test !y.writable + + # dimension mismatch + @test_throws MethodError NDArray{Int,1}(undef, 5, 5) + end + + @info "NDArray::NDArray{T,N}(undef, dims)" + let + x = NDArray{Int,2}(undef, (5, 5)) + @test eltype(x) == Int + @test size(x) == (5, 5) + @test x.writable + + y = NDArray{Int,2}(undef, (5, 5), writable = false) + @test !y.writable + + # dimension mismatch + @test_throws MethodError NDArray{Int,1}(undef, (5, 5)) + end + + @info "NDArray::NDArray{T}(undef, dims...)" + let + x = NDArray{Int}(undef, 5, 5) + @test eltype(x) == Int + @test size(x) == (5, 5) + @test x.writable + + y = NDArray{Int}(undef, 5, 5, writable = false) + @test !y.writable + end + + @info "NDArray::NDArray{T}(undef, dims)" + let + x = NDArray{Int}(undef, (5, 5)) + @test eltype(x) == Int + @test size(x) == (5, 5) + @test x.writable + + y = NDArray{Int}(undef, (5, 5), writable = false) + @test !y.writable + end + + @info "NDArray::NDArray(undef, dims...)" + let + x = NDArray(undef, 5, 5) + @test eltype(x) == mx.MX_float + @test size(x) == (5, 5) + @test x.writable + + y = NDArray(undef, 
+    y = NDArray(undef, 5, 5, writable = false)
+    @test !y.writable
+  end
+
+  @info "NDArray::NDArray(undef, dims)"
+  let
+    x = NDArray(undef, (5, 5))
+    @test eltype(x) == mx.MX_float
+    @test size(x) == (5, 5)
+    @test x.writable
+
+    y = NDArray(undef, (5, 5), writable = false)
+    @test !y.writable
+  end
 end # function test_constructor
@@ -134,8 +206,8 @@
     @info("NDArray::assign::dims = $dims")
 
     # Julia Array -> NDArray assignment
-    array = mx.empty(size(tensor))
-    array[:]= tensor
+    array = NDArray(undef, size(tensor)...)
+    array[:] = tensor
     @test tensor ≈ copy(array)
 
     array2 = mx.zeros(size(tensor))
@@ -1006,14 +1078,14 @@
 function test_eltype()
   @info("NDArray::eltype")
-  dims1 = (3,3)
+  dims = (3,3)
 
-  x = mx.empty(dims1)
+  x = NDArray(undef, dims)
   @test eltype(x) == mx.DEFAULT_DTYPE
 
   for TF in instances(mx.TypeFlag)
     T = mx.fromTypeFlag(TF)
-    x = mx.empty(T, dims1)
+    x = NDArray{T}(undef, dims)
     @test eltype(x) == T
   end
 end
diff --git a/julia/test/unittest/random.jl b/julia/test/unittest/random.jl
index 013e4f609daa..38da9601a01a 100644
--- a/julia/test/unittest/random.jl
+++ b/julia/test/unittest/random.jl
@@ -30,7 +30,7 @@ function test_uniform()
   ret1 = mx.rand(dims..., low = low, high = high)
 
   mx.seed!(seed)
-  ret2 = mx.empty(dims)
+  ret2 = NDArray(undef, dims)
   mx.rand!(ret2, low = low, high = high)
 
   @test copy(ret1) == copy(ret2)
@@ -47,7 +47,7 @@ function test_gaussian()
   ret1 = mx.randn(dims..., μ = μ, σ = σ)
 
   mx.seed!(seed)
-  ret2 = mx.empty(dims)
+  ret2 = NDArray(undef, dims)
   mx.randn!(ret2, μ = μ, σ = σ)
 
   @test copy(ret1) == copy(ret2)
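
For reference, a sketch of the old-to-new mapping this patch introduces (a hypothetical REPL session, not part of the patch; it assumes `using MXNet` and a CPU build, and the last line additionally assumes a GPU-enabled libmxnet):

```julia
using MXNet

# Old API: still callable via the shims in julia/src/deprecated.jl, but each
# call now emits a deprecation warning and forwards to a new constructor.
x = mx.empty(2, 3)             # warns; forwards to NDArray(undef, 2, 3)
y = mx.empty(Float64, (2, 3))  # warns; forwards to NDArray{Float64,2}(undef, (2, 3))

# New API, as exercised by test/unittest/ndarray.jl:
a = NDArray(undef, 2, 3)                    # DEFAULT_DTYPE (Float32) on CPU0
b = NDArray{Float64}(undef, (2, 3))         # element type chosen explicitly
c = NDArray{Float64,2}(undef, 2, 3)         # rank is part of the type
r = NDArray(undef, 2, 3, writable = false)  # read-only buffer
g = NDArray(undef, 2, 3, ctx = mx.gpu(0))   # placed on GPU 0
```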