Skip to content

Commit

Permalink
[Clojure] Add resource scope to clojure package (apache#13993)
Browse files Browse the repository at this point in the history
* Add resource scope to clojure package

* add rat

* fix integration test

* feedback from @benkamphaus
- move from defs to atoms to make the tests a bit better

* adding alias with-do and with-let 
more tests

* another test

* Add examples in docstring

* refactor example and test to use resource-scope/with-let

* fix tests and problem with laziness 
now they work as expected!

* refactor to be a bit more modular

* remove comments
  • Loading branch information
gigasquid authored and vdantu committed Mar 31, 2019
1 parent 7e7fef3 commit bdf1fae
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
[org.apache.clojure-mxnet.kvstore :as kvstore]
[org.apache.clojure-mxnet.kvstore-server :as kvstore-server]
[org.apache.clojure-mxnet.optimizer :as optimizer]
[org.apache.clojure-mxnet.eval-metric :as eval-metric])
[org.apache.clojure-mxnet.eval-metric :as eval-metric]
[org.apache.clojure-mxnet.resource-scope :as resource-scope])
(:gen-class))

(def data-dir "data/") ;; the data directory to store the mnist data
Expand All @@ -51,28 +52,6 @@
(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte")))
(sh "../../scripts/get_mnist_data.sh"))

;;; Load the MNIST datasets
(defonce train-data (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte")
:label (str data-dir "train-labels-idx1-ubyte")
:label-name "softmax_label"
:input-shape [784]
:batch-size batch-size
:shuffle true
:flat true
:silent false
:seed 10
:num-parts num-workers
:part-index 0}))

(defonce test-data (mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte")
:label (str data-dir "t10k-labels-idx1-ubyte")
:input-shape [784]
:batch-size batch-size
:flat true
:silent false
:num-parts num-workers
:part-index 0}))

(defn get-symbol []
(as-> (sym/variable "data") data
(sym/fully-connected "fc1" {:data data :num-hidden 128})
Expand All @@ -82,7 +61,31 @@
(sym/fully-connected "fc3" {:data data :num-hidden 10})
(sym/softmax-output "softmax" {:data data})))

(defn start

(defn train-data []
(mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte")
:label (str data-dir "train-labels-idx1-ubyte")
:label-name "softmax_label"
:input-shape [784]
:batch-size batch-size
:shuffle true
:flat true
:silent false
:seed 10
:num-parts num-workers
:part-index 0}))

(defn eval-data []
(mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte")
:label (str data-dir "t10k-labels-idx1-ubyte")
:input-shape [784]
:batch-size batch-size
:flat true
:silent false
:num-parts num-workers
:part-index 0}))

(defn start
([devs] (start devs num-epoch))
([devs _num-epoch]
(when scheduler-host
Expand All @@ -96,18 +99,16 @@
(do
(println "Starting Training of MNIST ....")
(println "Running with context devices of" devs)
(let [_mod (m/module (get-symbol) {:contexts devs})]
(m/fit _mod {:train-data train-data
:eval-data test-data
(resource-scope/with-let [_mod (m/module (get-symbol) {:contexts devs})]
(-> _mod
(m/fit {:train-data (train-data)
:eval-data (eval-data)
:num-epoch _num-epoch
:fit-params (m/fit-params {:kvstore kvstore
:optimizer optimizer
:eval-metric eval-metric})})
(println "Finish fit")
_mod
)

))))
(m/save-checkpoint {:prefix "target/test" :epoch _num-epoch}))
(println "Finish fit"))))))

(defn -main [& args]
(let [[dev dev-num] args
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
;;

(ns imclassification.train-mnist-test
(:require
(:require
[clojure.test :refer :all]
[clojure.java.io :as io]
[clojure.string :as s]
Expand All @@ -26,14 +26,15 @@

(defn- file-to-filtered-seq [file]
(->>
file
file
(io/file)
(io/reader)
(line-seq)
(filter #(not (s/includes? % "mxnet_version")))))

(deftest mnist-two-epochs-test
(module/save-checkpoint (mnist/start [(context/cpu)] 2) {:prefix "target/test" :epoch 2})
(is (=
(file-to-filtered-seq "test/test-symbol.json.ref")
(file-to-filtered-seq "target/test-symbol.json"))))
(do
(mnist/start [(context/cpu)] 2)
(is (=
(file-to-filtered-seq "test/test-symbol.json.ref")
(file-to-filtered-seq "target/test-symbol.json")))))
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.resource-scope
(:require [org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet ResourceScope)))

(defmacro
using
"Uses a Resource Scope for all forms. This is a way to manage all Native Resources like NDArray and Symbol - it will deallocate all Native Resources by calling close on them automatically. It will not call close on Native Resources returned from the form.
Example:
(resource-scope/using
(let [temp-x (ndarray/ones [3 1])
temp-y (ndarray/ones [3 1])]
(ndarray/+ temp-x temp-y))) "
[& forms]
`(ResourceScope/using (new ResourceScope) (util/forms->scala-fn ~@forms)))


(defmacro
with-do
"Alias for a do within a resource scope using.
Example:
(resource-scope/with-do
(ndarray/ones [3 1])
:all-cleaned-up)
"
[& forms]
`(using (do ~@forms)))

(defmacro
with-let
"Alias for a let within a resource scope using.
Example:
(resource-scope/with-let [temp-x (ndarray/ones [3 1])
temp-y (ndarray/ones [3 1])]
(ndarray/+ temp-x temp-y))"
[& forms]
`(using (let ~@forms)))
6 changes: 6 additions & 0 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,9 @@
(apply $/immutable-list))
;; pass-through
map-or-tuple-seq)))

(defmacro forms->scala-fn
"Creates a scala fn of zero args from forms"
[& forms]
`($/fn []
(do ~@forms)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.resource-scope-test
(:require [org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.symbol :as sym]
[org.apache.clojure-mxnet.resource-scope :as resource-scope]
[clojure.test :refer :all]))


(deftest test-resource-scope-with-ndarray
(let [native-resources (atom {})
x (ndarray/ones [2 2])
return-val (resource-scope/using
(let [temp-x (ndarray/ones [3 1])
temp-y (ndarray/ones [3 1])]
(swap! native-resources assoc :temp-x temp-x)
(swap! native-resources assoc :temp-y temp-y)
(ndarray/+ temp-x 1)))]
(is (true? (ndarray/is-disposed (:temp-x @native-resources))))
(is (true? (ndarray/is-disposed (:temp-y @native-resources))))
(is (false? (ndarray/is-disposed return-val)))
(is (false? (ndarray/is-disposed x)))
(is (= [2.0 2.0 2.0] (ndarray/->vec return-val)))))

(deftest test-nested-resource-scope-with-ndarray
(let [native-resources (atom {})
x (ndarray/ones [2 2])
return-val (resource-scope/using
(let [temp-x (ndarray/ones [3 1])]
(swap! native-resources assoc :temp-x temp-x)
(resource-scope/using
(let [temp-y (ndarray/ones [3 1])]
(swap! native-resources assoc :temp-y temp-y)))))]
(is (true? (ndarray/is-disposed (:temp-y @native-resources))))
(is (true? (ndarray/is-disposed (:temp-x @native-resources))))
(is (false? (ndarray/is-disposed x)))))

(deftest test-resource-scope-with-sym
(let [native-resources (atom {})
x (sym/ones [2 2])
return-val (resource-scope/using
(let [temp-x (sym/ones [3 1])
temp-y (sym/ones [3 1])]
(swap! native-resources assoc :temp-x temp-x)
(swap! native-resources assoc :temp-y temp-y)
(sym/+ temp-x 1)))]
(is (true? (sym/is-disposed (:temp-x @native-resources))))
(is (true? (sym/is-disposed (:temp-y @native-resources))))
(is (false? (sym/is-disposed return-val)))
(is (false? (sym/is-disposed x)))))

(deftest test-nested-resource-scope-with-ndarray
(let [native-resources (atom {})
x (ndarray/ones [2 2])
return-val (resource-scope/using
(let [temp-x (ndarray/ones [3 1])]
(swap! native-resources assoc :temp-x temp-x)
(resource-scope/using
(let [temp-y (ndarray/ones [3 1])]
(swap! native-resources assoc :temp-y temp-y)))))]
(is (true? (ndarray/is-disposed (:temp-y @native-resources))))
(is (true? (ndarray/is-disposed (:temp-x @native-resources))))
(is (false? (ndarray/is-disposed x)))))

(deftest test-nested-resource-scope-with-sym
(let [native-resources (atom {})
x (sym/ones [2 2])
return-val (resource-scope/using
(let [temp-x (sym/ones [3 1])]
(swap! native-resources assoc :temp-x temp-x)
(resource-scope/using
(let [temp-y (sym/ones [3 1])]
(swap! native-resources assoc :temp-y temp-y)))))]
(is (true? (sym/is-disposed (:temp-y @native-resources))))
(is (true? (sym/is-disposed (:temp-x @native-resources))))
(is (false? (sym/is-disposed x)))))

(deftest test-list-creation-with-returning-first
(let [native-resources (atom [])
return-val (resource-scope/using
(let [temp-ndarrays (doall (repeatedly 3 #(ndarray/ones [3 1])))
_ (reset! native-resources temp-ndarrays)]
(first temp-ndarrays)))]
(is (false? (ndarray/is-disposed return-val)))
(is (= [false true true] (mapv ndarray/is-disposed @native-resources)))))

(deftest test-list-creation
(let [native-resources (atom [])
return-val (resource-scope/using
(let [temp-ndarrays (doall (repeatedly 3 #(ndarray/ones [3 1])))
_ (reset! native-resources temp-ndarrays)]
(ndarray/ones [3 1])))]
(is (false? (ndarray/is-disposed return-val)))
(is (= [true true true] (mapv ndarray/is-disposed @native-resources)))))

(deftest test-list-creation-without-let
(let [native-resources (atom [])
return-val (resource-scope/using
(first (doall (repeatedly 3 #(do
(let [x (ndarray/ones [3 1])]
(swap! native-resources conj x)
x))))))]
(is (false? (ndarray/is-disposed return-val)))
(is (= [false true true] (mapv ndarray/is-disposed @native-resources)))))

(deftest test-with-let
(let [native-resources (atom {})
x (ndarray/ones [2 2])
return-val (resource-scope/with-let [temp-x (ndarray/ones [3 1])
temp-y (ndarray/ones [3 1])]
(swap! native-resources assoc :temp-x temp-x)
(swap! native-resources assoc :temp-y temp-y)
(ndarray/+ temp-x 1))]
(is (true? (ndarray/is-disposed (:temp-x @native-resources))))
(is (true? (ndarray/is-disposed (:temp-y @native-resources))))
(is (false? (ndarray/is-disposed return-val)))
(is (false? (ndarray/is-disposed x)))
(is (= [2.0 2.0 2.0] (ndarray/->vec return-val)))))

(deftest test-with-do
(let [native-resources (atom {})
x (ndarray/ones [2 2])
return-val (resource-scope/with-do
(swap! native-resources assoc :temp-x (ndarray/ones [3 1]))
(swap! native-resources assoc :temp-y (ndarray/ones [3 1]))
(ndarray/ones [3 1]))]
(is (true? (ndarray/is-disposed (:temp-x @native-resources))))
(is (true? (ndarray/is-disposed (:temp-y @native-resources))))
(is (false? (ndarray/is-disposed return-val)))
(is (false? (ndarray/is-disposed x)))
(is (= [1.0 1.0 1.0] (ndarray/->vec return-val)))))
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,10 @@
(let [nda (util/map->scala-tuple-seq {:a-b (ndarray/ones [1 2])})]
(is (= "a_b" (._1 (.head nda))))
(is (= [1.0 1.0] (ndarray/->vec (._2 (.head nda)))))))

(deftest test-forms->scala-fn
(let [scala-fn (util/forms->scala-fn
(def x 1)
(def y 2)
{:x x :y y})]
(is (= {:x 1 :y 2} (.apply scala-fn)))))

0 comments on commit bdf1fae

Please sign in to comment.