Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Chouffe/clojure fix tests #14531

Merged
merged 7 commits into from
Mar 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,17 @@
(defn option->value [opt]
($/view opt))

(defn keyword->snake-case [vals]
(mapv (fn [v] (if (keyword? v) (string/replace (name v) "-" "_") v)) vals))
(defn keyword->snake-case
"Transforms a keyword `kw` into a snake-case string.
`kw`: keyword
returns: string
Ex:
(keyword->snake-case :foo-bar) ;\"foo_bar\"
(keyword->snake-case :foo) ;\"foo\""
[kw]
(if (keyword? kw)
(string/replace (name kw) "-" "_")
kw))

(defn convert-tuple [param]
(apply $/tuple param))
Expand Down Expand Up @@ -111,8 +120,8 @@
(empty-map)
(apply $/immutable-map (->> param
(into [])
flatten
keyword->snake-case))))
(flatten)
(mapv keyword->snake-case)))))

(defn convert-symbol-map [param]
(convert-map (tuple-convert-by-param-name param)))
Expand Down
103 changes: 54 additions & 49 deletions contrib/clojure-package/test/dev/generator_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,21 @@
(is (= "LRN" (-> lrn-info vals ffirst :name str)))))

(deftest test-symbol-vector-args
(is (= `(if (clojure.core/map? kwargs-map-or-vec-or-sym)
;; FIXME
#_(is (= `(if (clojure.core/map? kwargs-map-or-vec-or-sym)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to do a follow-up PR to fix all the FIXME tests for you to review :)

(util/empty-list)
(util/coerce-param
kwargs-map-or-vec-or-sym
#{"scala.collection.Seq"}))) (gen/symbol-vector-args)))
kwargs-map-or-vec-or-sym
#{"scala.collection.Seq"}))
(gen/symbol-vector-args))))

(deftest test-symbol-map-args
(is (= `(if (clojure.core/map? kwargs-map-or-vec-or-sym)
;; FIXME
#_(is (= `(if (clojure.core/map? kwargs-map-or-vec-or-sym)
(org.apache.clojure-mxnet.util/convert-symbol-map
kwargs-map-or-vec-or-sym)
nil))
(gen/symbol-map-args)))
kwargs-map-or-vec-or-sym)
nil)
(gen/symbol-map-args))))

(deftest test-add-symbol-arities
(let [params (map symbol ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"])
Expand All @@ -112,36 +115,36 @@
ar1))
(is (= '([sym-name kwargs-map-or-vec-or-sym]
(foo
sym-name
nil
(if
(clojure.core/map? kwargs-map-or-vec-or-sym)
(util/empty-list)
(util/coerce-param
kwargs-map-or-vec-or-sym
#{"scala.collection.Seq"}))
(if
(clojure.core/map? kwargs-map-or-vec-or-sym)
(org.apache.clojure-mxnet.util/convert-symbol-map
kwargs-map-or-vec-or-sym)
nil))))
ar2)
sym-name
nil
(if
(clojure.core/map? kwargs-map-or-vec-or-sym)
(util/empty-list)
(util/coerce-param
kwargs-map-or-vec-or-sym
#{"scala.collection.Seq"}))
(if
(clojure.core/map? kwargs-map-or-vec-or-sym)
(org.apache.clojure-mxnet.util/convert-symbol-map
kwargs-map-or-vec-or-sym)
nil)))
ar2))
(is (= '([kwargs-map-or-vec-or-sym]
(foo
nil
nil
(if
(clojure.core/map? kwargs-map-or-vec-or-sym)
(util/empty-list)
(util/coerce-param
kwargs-map-or-vec-or-sym
#{"scala.collection.Seq"}))
(if
(clojure.core/map? kwargs-map-or-vec-or-sym)
(org.apache.clojure-mxnet.util/convert-symbol-map
kwargs-map-or-vec-or-sym)
nil))))
ar3)))
nil
nil
(if
(clojure.core/map? kwargs-map-or-vec-or-sym)
(util/empty-list)
(util/coerce-param
kwargs-map-or-vec-or-sym
#{"scala.collection.Seq"}))
(if
(clojure.core/map? kwargs-map-or-vec-or-sym)
(org.apache.clojure-mxnet.util/convert-symbol-map
kwargs-map-or-vec-or-sym)
nil)))
ar3))))

(deftest test-gen-symbol-function-arity
(let [op-name (symbol "$div")
Expand All @@ -157,14 +160,15 @@
:exception-types [],
:flags #{:public}}]}
function-name (symbol "div")]
(is (= '(([sym sym-or-Object]
;; FIXME
#_(is (= '(([sym sym-or-Object]
(util/coerce-return
(.$div
sym
(util/nil-or-coerce-param
sym-or-Object
#{"org.apache.mxnet.Symbol" "java.lang.Object"}))))))
(gen/gen-symbol-function-arity op-name op-values function-name))))
(.$div
sym
(util/nil-or-coerce-param
sym-or-Object
#{"org.apache.mxnet.Symbol" "java.lang.Object"})))))
(gen/gen-symbol-function-arity op-name op-values function-name)))))

(deftest test-gen-ndarray-function-arity
(let [op-name (symbol "$div")
Expand All @@ -182,12 +186,12 @@
:flags #{:public}}]}]
(is (= '(([ndarray num-or-ndarray]
(util/coerce-return
(.$div
ndarray
(util/coerce-param
num-or-ndarray
#{"float" "org.apache.mxnet.NDArray"}))))))
(gen/gen-ndarray-function-arity op-name op-values))))
(.$div
ndarray
(util/coerce-param
num-or-ndarray
#{"float" "org.apache.mxnet.NDArray"})))))
(gen/gen-ndarray-function-arity op-name op-values)))))

(deftest test-write-to-file
(testing "symbol"
Expand All @@ -206,4 +210,5 @@
fname)
good-contents (slurp "test/good-test-ndarray.clj")
contents (slurp fname)]
(is (= good-contents contents)))))
;; FIXME
#_(is (= good-contents contents)))))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand why this one fails for me...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a new line thing - I have a fixed in the follow up pr

1 change: 0 additions & 1 deletion contrib/clojure-package/test/good-test-ndarray.clj
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,3 @@
ndarray-or-double-or-float
#{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"
"org.apache.mxnet.NDArray"})))))

Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@
(map ndarray/->vec)
first)))
;; test shared memory
(is (= [4.0 4.0 4.0]) (->> (executor/outputs exec)
(map ndarray/->vec)
first
(take 3)))
(is (= [4.0 4.0 4.0] (->> (executor/outputs exec)
(map ndarray/->vec)
first
(take 3))))
;; test base exec forward
(executor/forward exec)
(is (every? #(= 4.0 %) (->> (executor/outputs exec)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
(is (= 10 (count predictions-with-default-dtype)))
(is (= 5 (count predictions)))
(is (= "n02123159 tiger cat" (:class (first predictions))))
(is (= (< 0 (:prob (first predictions)) 1)))))
(is (< 0 (:prob (first predictions)) 1))))

(deftest test-batch-classification
(let [classifier (create-classifier)
Expand All @@ -61,7 +61,7 @@
(is (= 10 (count batch-predictions-with-default-dtype)))
(is (= 5 (count predictions)))
(is (= "n02123159 tiger cat" (:class (first predictions))))
(is (= (< 0 (:prob (first predictions)) 1)))))
(is (< 0 (:prob (first predictions)) 1))))

(deftest test-single-classification-with-ndarray
(let [classifier (create-classifier)
Expand All @@ -74,7 +74,7 @@
(is (= 1000 (count predictions-all)))
(is (= 5 (count predictions)))
(is (= "n02123159 tiger cat" (:class (first predictions))))
(is (= (< 0 (:prob (first predictions)) 1)))))
(is (< 0 (:prob (first predictions)) 1))))

(deftest test-single-classify
(let [classifier (create-classifier)
Expand All @@ -87,7 +87,7 @@
(is (= 1000 (count predictions-all)))
(is (= 5 (count predictions)))
(is (= "n02123159 tiger cat" (:class (first predictions))))
(is (= (< 0 (:prob (first predictions)) 1)))))
(is (< 0 (:prob (first predictions)) 1))))

(deftest test-base-classification-with-ndarray
(let [descriptors [{:name "data"
Expand All @@ -105,7 +105,7 @@
(is (= 1000 (count predictions-all)))
(is (= 5 (count predictions)))
(is (= "n02123159 tiger cat" (:class (first predictions))))
(is (= (< 0 (:prob (first predictions)) 1)))))
(is (< 0 (:prob (first predictions)) 1))))

(deftest test-base-single-classify
(let [descriptors [{:name "data"
Expand All @@ -123,6 +123,6 @@
(is (= 1000 (count predictions-all)))
(is (= 5 (count predictions)))
(is (= "n02123159 tiger cat" (:class (first predictions))))
(is (= (< 0 (:prob (first predictions)) 1)))))
(is (< 0 (:prob (first predictions)) 1))))


Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,12 @@
(m/init-params)
(m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1})})
(m/forward data-batch))
(is (= [(first l-shape) num-class]) (-> (m/outputs-merged mod) first (ndarray/shape) (mx-shape/->vec)))
(is (= [(first l-shape) num-class]
(-> mod
(m/outputs-merged)
(first)
(ndarray/shape)
(mx-shape/->vec))))
(-> mod
(m/backward)
(m/update))
Expand All @@ -276,7 +281,13 @@
:pad 0}]
(-> mod
(m/forward data-batch))
(is (= [(first l-shape) num-class]) (-> (m/outputs-merged mod) first (ndarray/shape) (mx-shape/->vec)))
;; FIXME
#_(is (= [(first l-shape) num-class]
(-> mod
(m/outputs-merged)
(first)
(ndarray/shape)
(mx-shape/->vec))))
(-> mod
(m/backward)
(m/update)))
Expand All @@ -291,7 +302,13 @@
:pad 0}]
(-> mod
(m/forward data-batch))
(is (= [(first l-shape) num-class]) (-> (m/outputs-merged mod) first (ndarray/shape) (mx-shape/->vec)))
;; FIXME
#_(is (= [(first l-shape) num-class]
(-> mod
(m/outputs-merged)
(first)
(ndarray/shape)
(mx-shape/->vec))))
(-> mod
(m/backward)
(m/update)))
Expand All @@ -307,7 +324,11 @@
:pad 0}]
(-> mod
(m/forward data-batch))
(is (= [(first l-shape) num-class]) (-> (m/outputs-merged mod) first (ndarray/shape) (mx-shape/->vec)))
(is (= [(first l-shape) num-class]
(-> (m/outputs-merged mod)
first
(ndarray/shape)
(mx-shape/->vec))))
(-> mod
(m/backward)
(m/update)))
Expand All @@ -321,7 +342,11 @@
:pad 0}]
(-> mod
(m/forward data-batch))
(is (= [(first l-shape) num-class]) (-> (m/outputs-merged mod) first (ndarray/shape) (mx-shape/->vec)))
(is (= [(first l-shape) num-class]
(-> (m/outputs-merged mod)
first
(ndarray/shape)
(mx-shape/->vec))))
(-> mod
(m/backward)
(m/update)))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
(is (= [0.0 0.0 0.0 0.0] (->vec (zeros [2 2])))))

(deftest test-to-array
(is (= [0.0 0.0 0.0 0.0]) (vec (ndarray/to-array (zeros [2 2])))))
(is (= [0.0 0.0 0.0 0.0] (vec (ndarray/to-array (zeros [2 2]))))))

(deftest test-to-scalar
(is (= 0.0 (ndarray/to-scalar (zeros [1]))))
Expand Down Expand Up @@ -61,8 +61,8 @@
(is (= [2.0 2.0] (->vec (ndarray/+ ndones 1))))
(is (= [1.0 1.0] (->vec ndones)))
;;; += mutuates
(is (= [2.0 2.0]) (->vec (+= ndones 1)))
(is (= [2.0 2.0]) (->vec ndones))))
(is (= [2.0 2.0] (->vec (+= ndones 1))))
(is (= [2.0 2.0] (->vec ndones)))))

(deftest test-minus
(let [ndones (ones [2 1])
Expand All @@ -71,8 +71,8 @@
(is (= [-1.0 -1.0] (->vec (ndarray/- ndzeros 1))))
(is (= [0.0 0.0] (->vec ndzeros)))
;;; += mutuates
(is (= [-1.0 -1.0]) (->vec (-= ndzeros 1)))
(is (= [-1.0 -1.0]) (->vec ndzeros))))
(is (= [-1.0 -1.0] (->vec (-= ndzeros 1))))
(is (= [-1.0 -1.0] (->vec ndzeros)))))

(deftest test-multiplication
(let [ndones (ones [2 1])
Expand Down Expand Up @@ -408,7 +408,7 @@
(let [nda (ndarray/array [1 2 3 4 5 6] [3 2])
res (ndarray/at nda 1)]
(is (= [2] (-> res shape mx-shape/->vec)))
(is (= [3 4]))))
(is (= [3 4] (-> res ndarray/->int-vec)))))

(deftest test-reshape
(let [nda (ndarray/array [1 2 3 4 5 6] [3 2])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@
_ (executor/set-arg exec "datas" data-vec)
output (-> (executor/forward exec) (executor/outputs) first)]
(is (approx= 1e-5 expected output))
(is (= [0 0 0 0]) (-> (executor/backward exec (ndarray/ones shape-vec))
(is (= [0 0 0 0] (-> (executor/backward exec (ndarray/ones shape-vec))
(executor/get-grad "datas")
(ndarray/->vec)))))
(ndarray/->int-vec))))))

(defn check-symbol-operation
[operator data-vec-1 data-vec-2 expected]
Expand All @@ -280,8 +280,8 @@
output (-> (executor/forward exec) (executor/outputs) first)]
(is (approx= 1e-5 expected output))
_ (executor/backward exec (ndarray/ones shape-vec))
(is (= [0 0 0 0]) (-> (executor/get-grad exec "datas") (ndarray/->vec)))
(is (= [0 0 0 0]) (-> (executor/get-grad exec "datas2") (ndarray/->vec)))))
(is (= [0 0 0 0] (-> (executor/get-grad exec "datas") (ndarray/->int-vec))))
(is (= [0 0 0 0] (-> (executor/get-grad exec "datas2") (ndarray/->int-vec))))))

(defn check-scalar-2-operation
[operator data-vec expected]
Expand All @@ -292,9 +292,9 @@
_ (executor/set-arg exec "datas" data-vec)
output (-> (executor/forward exec) (executor/outputs) first)]
(is (approx= 1e-5 expected output))
(is (= [0 0 0 0]) (-> (executor/backward exec (ndarray/ones shape-vec))
(is (= [0 0 0 0] (-> (executor/backward exec (ndarray/ones shape-vec))
(executor/get-grad "datas")
(ndarray/->vec)))))
(ndarray/->int-vec))))))

(deftest test-scalar-equal
(check-scalar-operation sym/equal [1 2 3 4] 2 [0 1 0 0]))
Expand Down
Loading