Skip to content

Commit

Permalink
[clojure-package] fix docstrings in normal.clj (apache#14295)
Browse files Browse the repository at this point in the history
* Fixed documentation string in `normal` function
* Added spec to catch `high` < `low` in `uniform`
* Added spec to catch `scale` <= 0 in `normal`
* Added unit tests
  • Loading branch information
Chouffe authored and vdantu committed Mar 31, 2019
1 parent 016be1c commit 403a68c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 29 deletions.
70 changes: 42 additions & 28 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
Original file line number Diff line number Diff line change
Expand Up @@ -16,70 +16,84 @@
;;

(ns org.apache.clojure-mxnet.random
"Random Number interface of mxnet."
(:require
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.context :as context]
[clojure.spec.alpha :as s]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet Context Random)))

(s/def ::low number?)
(s/def ::high number?)
(s/def ::low-high (fn [[low high]] (<= low high)))
(s/def ::shape-vec (s/coll-of pos-int? :kind vector?))
(s/def ::ctx #(instance? Context %))
(s/def ::uniform-opts (s/keys :opt-un [::ctx]))

(defn uniform
"Generate uniform distribution in [low, high) with shape.
low: The lower bound of distribution.
high: The upper bound of distribution.
shape-vec: vector shape of the ndarray generated.
opts-map {
ctx: Context of output ndarray, will use default context if not specified.
out: Output place holder}
returns: The result ndarray with generated result./"
"Generate uniform distribution in [`low`, `high`) with shape.
`low`: The lower bound of distribution.
`high`: The upper bound of distribution.
`shape-vec`: vector shape of the ndarray generated.
`opts-map` {
`ctx`: Context of output ndarray, will use default context if not specified.
`out`: Output place holder}
returns: The result ndarray with generated result.
Ex:
(uniform 0 1 [1 10])
(uniform -10 10 [100 100])"
([low high shape-vec {:keys [ctx out] :as opts}]
(util/validate! ::uniform-opts opts "Incorrect random uniform parameters")
(util/validate! ::uniform-opts opts "Incorrect random uniform parameters")
(util/validate! ::low low "Incorrect random uniform parameter")
(util/validate! ::high high "Incorrect random uniform parameters")
(util/validate! ::low-high [low high] "Incorrect random uniform parameters")
(util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters")
(Random/uniform (float low) (float high) (mx-shape/->shape shape-vec) ctx out))
([low high shape-vec]
(uniform low high shape-vec {})))

(s/def ::loc number?)
(s/def ::scale number?)
(s/def ::scale (s/and number? pos?))
(s/def ::normal-opts (s/keys :opt-un [::ctx]))

(defn normal
"Generate normal(Gaussian) distribution N(mean, stdvar^^2) with shape.
loc: The standard deviation of the normal distribution
scale: The upper bound of distribution.
shape-vec: vector shape of the ndarray generated.
opts-map {
ctx: Context of output ndarray, will use default context if not specified.
out: Output place holder}
returns: The result ndarray with generated result./"
"Generate normal (Gaussian) distribution N(mean, stdvar^^2) with shape.
`loc`: Mean (centre) of the distribution.
`scale`: Standard deviation (spread or width) of the distribution.
`shape-vec`: vector shape of the ndarray generated.
`opts-map` {
`ctx`: Context of output ndarray, will use default context if not specified.
`out`: Output place holder}
returns: The result ndarray with generated result.
Ex:
(normal 0 1 [10 10])
(normal -5 4 [2 3])"
([loc scale shape-vec {:keys [ctx out] :as opts}]
(util/validate! ::normal-opts opts "Incorrect random normal parameters")
(util/validate! ::loc loc "Incorrect random normal parameters")
(util/validate! ::scale scale "Incorrect random normal parameters")
(util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters")
(Random/normal (float loc) (float scale) (mx-shape/->shape shape-vec) ctx out))
(Random/normal (float loc)
(float scale)
(mx-shape/->shape shape-vec) ctx out))
([loc scale shape-vec]
(normal loc scale shape-vec {})))

(s/def ::seed-state number?)
(defn seed
" Seed the random number generators in mxnet.
This seed will affect behavior of functions in this module,
as well as results from executors that contains Random number
such as Dropout operators.
"Seed the random number generators in mxnet.
This seed will affect behavior of functions in this module,
as well as results from executors that contains Random number
such as Dropout operators.
seed-state: The random number seed to set to all devices.
`seed-state`: The random number seed to set to all devices.
note: The random number generator of mxnet is by default device specific.
This means if you set the same seed, the random number sequence
generated from GPU0 can be different from CPU."
generated from GPU0 can be different from CPU.
Ex:
(seed-state 42)
(seed-state 42.0)"
[seed-state]
(util/validate! ::seed-state seed-state "Incorrect seed parameters")
(Random/seed (int seed-state)))
(Random/seed (int seed-state)))
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
(is (thrown? Exception (fn_ 'a 2 [])))
(is (thrown? Exception (fn_ 1 'b [])))
(is (thrown? Exception (fn_ 1 2 [-1])))
(is (thrown? Exception (fn_ 1 0 [1 2])))
(is (thrown? Exception (fn_ 1 -1 [1 2])))
(is (thrown? Exception (fn_ 1 2 [2 3 0])))
(is (thrown? Exception (fn_ 1 2 [10 10] {:ctx "a"})))
(let [ctx (context/default-context)]
Expand All @@ -64,4 +66,4 @@
(deftest test-random-parameters-specs
(random-or-normal random/normal)
(random-or-normal random/uniform)
(is (thrown? Exception (random/seed "a"))))
(is (thrown? Exception (random/seed "a"))))

0 comments on commit 403a68c

Please sign in to comment.