Building an autograd engine: tinytorch 03
Clean up & refactor
(don't skip this; we also add new things here)
Let's get professional. It's a good time now to add type hints and do some cleanup. Let's start with the tensor class.
from __future__ import annotations
import numpy as np

class Tensor:
    def __init__(self, data):
        self.data: np.ndarray = Tensor._data_to_numpy(data)
        self.grad: Tensor = None
        self._ctx: Function = None

    @staticmethod
    def _data_to_numpy(data):
        if isinstance(data, (int, float)):
            return np.array([data])
        if isinstance(data, (list, tuple)):
            return np.array(data)
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, Tensor):
            return data.data.copy()
        raise ValueError("Invalid value passed to tensor")

    @staticmethod
    def ensure_tensor(data):
        if isinstance(data, Tensor):
            return data
        return Tensor(data)
We import annotations so the type hints won't be a problem on older Python versions. We also add a new _data_to_numpy method so you can pass any input, like Tensor(2), Tensor([1,2]), or Tensor(np.arange(10)), and still get consistent data underneath. And there is a new ensure_tensor, a simple method that makes sure the data is a Tensor.
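A quick sanity check of what the constructor now accepts (expected output shown in the comments):

print(Tensor(2).data)             # [2]
print(Tensor([1, 2]).data)        # [1 2]
print(Tensor(np.arange(3)).data)  # [0 1 2]
print(Tensor(Tensor(2)).data)     # [2]  (copied, not shared)

With that in place, here is how __add__ currently looks: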
def __add__(self, other):
    fn = Function(Add, self, other)
    result = Add.forward(self, other)
    result._ctx = fn
    return result
This is pretty ugly, so let's clean it up by modifying the Function class.
class Function:
    def __init__(self, op, *args):
        self.op: Function = op
        self.args: list[Tensor] = args

    @classmethod
    def apply(cls, *args):
        ctx = Function(cls, *args)
        result = cls.forward(*args)
        result._ctx = ctx
        return result

    @staticmethod
    def forward(self, *args):
        raise NotImplementedError

    @staticmethod
    def backward(self, *args):
        raise NotImplementedError
apply is a class method that essentially does what we were doing by hand in __add__ and __mul__. The idea is that all ops will inherit from the Function class and provide static methods; then the operators just call Add.apply(). The base forward and backward methods raise NotImplementedError to signal that every op must implement them. Plus some type annotations.
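To see how little a new op needs now, here is a sketch of a hypothetical Sub, which is not part of this chapter, just an illustration of the pattern:

class Sub(Function):   # hypothetical op, for illustration only
    @staticmethod
    def forward(x, y):
        return Tensor(x.data - y.data)   # z = x - y

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.args
        return Tensor([1]) * grad, Tensor([-1]) * grad   # dz/dx, dz/dy

Calling Sub.apply(a, b) would build the node and set _ctx, exactly like Add and Mul.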
Let's modify Add and Mul so they inherit these nice properties from Function.
class Add(Function):
    ...

class Mul(Function):
    ...
Time to change the magic methods.
def __add__(self, other) -> Tensor:
    return Add.apply(self, Tensor.ensure_tensor(other))

def __mul__(self, other) -> Tensor:
    return Mul.apply(self, Tensor.ensure_tensor(other))

def __radd__(self, other) -> Tensor:
    return Add.apply(self, Tensor.ensure_tensor(other))

def __rmul__(self, other) -> Tensor:
    return Mul.apply(self, Tensor.ensure_tensor(other))
Looks neat. The new __radd__ method does the addition in reverse when the left operand is not a Tensor: Tensor(1) + 1 will use __add__, and 1 + Tensor(1) will use __radd__.
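A quick check that both directions dispatch as expected:

a = Tensor(1)
print(a + 1)   # __add__  -> tensor([2])
print(1 + a)   # __radd__ -> tensor([2])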
We also want += to work; for that we need __iadd__. Let's add those too.
def __iadd__(self, other) -> Tensor:
    return Add.apply(self, Tensor.ensure_tensor(other))

def __imul__(self, other) -> Tensor:
    return Mul.apply(self, Tensor.ensure_tensor(other))
Since we are here, let's also add detach(), which removes a node from the graph. And how will we do that? By setting _ctx to None.
class Tensor:
    ...

    def detach(self) -> Tensor:
        self._ctx = None
        return self
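A quick illustration of what this buys us, using only the pieces above: once a node is detached, backward stops there and earlier tensors get no gradient.

x = Tensor([2.0])
y = x * x
y.detach()                 # cut the graph at y
z = y * Tensor([3.0])
z.backward()
print(x.grad)              # None, backward never reached x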
Since we added detach, let's clean up any unnecessary graphs that might have been generated in the backward pass.
def backward(self, grad=None):
    if self._ctx is None:
        return
    if grad is None:
        grad = Tensor([1.])
    self.grad = grad
    op = self._ctx.op
    child_nodes = self._ctx.args
    grads = op.backward(self._ctx, grad)
    if len(self._ctx.args) == 1:
        grads = [grads]
    for tensor, grad in zip(child_nodes, grads):
        if grad is None:
            continue
        if tensor.grad is None:
            tensor.grad = Tensor(np.zeros_like(tensor.data))
        tensor.grad += grad.detach()
        tensor.backward(grad)
OK, we are doing a couple of things here. Most importantly, we detach the grads before accumulating them, so the accumulation itself doesn't grow the graph. We skip child nodes whose grad is None. And since in the future we will have ops that take a single argument, we wrap grads in a list when the function has only one arg so the loop still works correctly.
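One nice side effect of accumulating into tensor.grad: if the same tensor is used more than once, its gradients add up rather than overwrite each other. A tiny check using only what we have so far:

x = Tensor([3.0])
z = x * x          # dz/dx = 2x = 6
z.backward()
print(x.grad)      # tensor([6.])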
As a bonus we will also add clone().
def clone(self) -> Tensor:
    return Tensor(self.data.copy())
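Since clone copies the underlying array, editing the copy leaves the original alone, something like:

a = Tensor([1.0, 2.0])
b = a.clone()
b.data[0] = 99.0
print(a)   # tensor([1. 2.]), the original is untouched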
Let's also add a couple of tiny utility helpers. First, a shape property on the Tensor class:
class Tensor:
    ...

    @property
    def shape(self) -> tuple:
        return self.data.shape
And two functions to create tensors of ones and zeros:
def ones(shape) -> Tensor:
    return Tensor(np.ones(shape))

def zeros(shape) -> Tensor:
    return Tensor(np.zeros(shape))
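For example:

print(zeros((2, 3)).shape)   # (2, 3)
print(ones(4))               # tensor([1. 1. 1. 1.])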
commit 96bd8a262398ad56699175cf64592b88ebd4be11
At the end, the file should look like this:
from __future__ import annotations
import numpy as np


class Tensor:
    def __init__(self, data):
        self.data: np.ndarray = Tensor._data_to_numpy(data)
        self.grad: Tensor = None
        self._ctx: Function = None

    @staticmethod
    def _data_to_numpy(data):
        if isinstance(data, (int, float)):
            return np.array([data])
        if isinstance(data, (list, tuple)):
            return np.array(data)
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, Tensor):
            return data.data.copy()
        raise ValueError("Invalid value passed to tensor")

    @staticmethod
    def ensure_tensor(data):
        if isinstance(data, Tensor):
            return data
        return Tensor(data)

    def __add__(self, other) -> Tensor:
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __mul__(self, other) -> Tensor:
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __radd__(self, other) -> Tensor:
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __rmul__(self, other) -> Tensor:
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __iadd__(self, other) -> Tensor:
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __imul__(self, other) -> Tensor:
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __repr__(self):
        return f"tensor({self.data})"

    @property
    def shape(self) -> tuple:
        return self.data.shape

    def detach(self) -> Tensor:
        self._ctx = None
        return self
    def clone(self) -> Tensor:
        return Tensor(self.data.copy())
    def backward(self, grad=None):
        if self._ctx is None:
            return
        if grad is None:
            grad = Tensor([1.0])
        self.grad = grad
        op = self._ctx.op
        child_nodes = self._ctx.args
        grads = op.backward(self._ctx, grad)
        if len(self._ctx.args) == 1:
            grads = [grads]
        for tensor, grad in zip(child_nodes, grads):
            if grad is None:
                continue
            if tensor.grad is None:
                tensor.grad = Tensor(np.zeros_like(tensor.data))
            tensor.grad += grad.detach()
            tensor.backward(grad)


class Function:
    def __init__(self, op, *args):
        self.op: Function = op
        self.args: list[Tensor] = args

    @classmethod
    def apply(cls, *args):
        ctx = Function(cls, *args)
        result = cls.forward(*args)
        result._ctx = ctx
        return result

    @staticmethod
    def forward(self, *args):
        raise NotImplementedError

    @staticmethod
    def backward(self, *args):
        raise NotImplementedError


class Add(Function):
    @staticmethod
    def forward(x, y):
        return Tensor(x.data + y.data)

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.args
        return Tensor([1]) * grad, Tensor([1]) * grad


class Mul(Function):
    @staticmethod
    def forward(x, y):
        return Tensor(x.data * y.data)  # z = x*y

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.args
        return Tensor(y.data) * grad, Tensor(x.data) * grad  # dz/dx, dz/dy


def ones(shape) -> Tensor:
    return Tensor(np.ones(shape))


def zeros(shape) -> Tensor:
    return Tensor(np.zeros(shape))


if __name__ == "__main__":
    def f(x):
        return x * x * x + x

    x = Tensor([1.2])
    z = f(x)
    z.backward()
    print(f"X: {x} grad: {x.grad}")
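As a sanity check: f(x) = x*x*x + x has derivative f'(x) = 3x² + 1, and at x = 1.2 that is 3 * 1.44 + 1 = 5.32, which is exactly the value x.grad should print.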