diff --git a/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj b/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj index 017e19bc2c75..750329ff2351 100644 --- a/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj +++ b/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj @@ -75,43 +75,40 @@ (do (println "Starting Training of MNIST ....") (println "Running with context devices of" devs) - (resource-scope/using - (let [_mod (m/module (get-symbol) {:contexts devs})] - (m/fit _mod {:train-data (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") - :label (str data-dir "train-labels-idx1-ubyte") - :label-name "softmax_label" - :input-shape [784] - :batch-size batch-size - :shuffle true - :flat true - :silent false - :seed 10 - :num-parts num-workers - :part-index 0}) - :eval-data (mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte") - :label (str data-dir "t10k-labels-idx1-ubyte") + (let [_mod (m/module (get-symbol) {:contexts devs})] + (m/fit _mod {:train-data (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") + :label (str data-dir "train-labels-idx1-ubyte") + :label-name "softmax_label" :input-shape [784] :batch-size batch-size + :shuffle true :flat true :silent false + :seed 10 :num-parts num-workers :part-index 0}) - :num-epoch _num-epoch - :fit-params (m/fit-params {:kvstore kvstore - :optimizer optimizer - :eval-metric eval-metric})}) - (println "Finish fit") - _mod - )) - - )))) + :eval-data (mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte") + :label (str data-dir "t10k-labels-idx1-ubyte") + :input-shape [784] + :batch-size batch-size + :flat true + :silent false + :num-parts num-workers + :part-index 0}) + :num-epoch _num-epoch + :fit-params (m/fit-params {:kvstore kvstore + :optimizer optimizer + :eval-metric eval-metric})}) + (println "Finish fit") + _mod + ))))) (defn -main [& args] (let [[dev dev-num] args devs (if (= dev ":gpu") (mapv #(context/gpu %) (range (Integer/parseInt (or dev-num "1")))) (mapv #(context/cpu %) (range (Integer/parseInt (or dev-num "1")))))] - (start devs))) + (resource-scope/using (start devs)))) (comment (start [(context/cpu)]))