diff --git a/micrograd/nn.py b/micrograd/nn.py index 30d5d777..36bc5c48 100644 --- a/micrograd/nn.py +++ b/micrograd/nn.py @@ -18,6 +18,7 @@ def __init__(self, nin, nonlin=True): self.nonlin = nonlin def __call__(self, x): + x = [x] if isinstance(x, Value) else x act = sum((wi*xi for wi,xi in zip(self.w, x)), self.b) return act.relu() if self.nonlin else act