from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import deque

from tensor import Tensor
from ops import AddOp, SubOp, MulOp, DivOp, \
               DotOp, TransposeOp, SquareOp, NegOp, \
               MeanOp, SigmoidOp, AssignOp, GroupOp
"""A ``Graph`` represents a computation to be evaluated by a ``Session``.

With the exception of ``Graph#tensor``, ``Graph#convert``, and
``Graph#gradients``, most methods simply create an operation and return the
output tensor of the operation.
"""
class Graph(object):
    """A computation graph: a collection of tensors and the operations
    that produce them.

    With the exception of ``tensor``, ``convert``, and ``gradients``, most
    methods simply create an operation and return the output tensor of the
    operation.
    """

    def tensor(self, initial_value=None, op=None):
        """Define a new tensor in this graph.

        Args:
            initial_value: optional starting value for the tensor.
            op: optional operation that produces this tensor as output.

        Returns:
            A new ``Tensor`` bound to this graph.
        """
        return Tensor(initial_value=initial_value, graph=self, op=op)

    def convert(self, value):
        """Return ``value`` unchanged if it is already a ``Tensor``,
        otherwise wrap it in a new constant tensor of this graph.
        """
        if isinstance(value, Tensor):
            return value
        return self.tensor(initial_value=value)

    def gradients(self, y, xs):
        """Backpropagate from ``y`` to each x in ``xs`` via reverse
        accumulation and the chain rule.

        A FIFO queue tracks the next tensor for which to compute a
        gradient, and ``grads`` accumulates the partial gradients computed
        thus far. Iteration starts from the target output ``y`` with an
        output gradient of 1.

        Args:
            y: the output tensor to differentiate.
            xs: iterable of input tensors to differentiate with respect to.

        Returns:
            A list with the accumulated gradient for each x in ``xs``.

        Raises:
            KeyError: if some x in ``xs`` is not reachable from ``y``.
        """
        # deque gives O(1) pops from the left; list.pop(0) is O(n).
        queue = deque()
        queue.append((y, 1))
        grads = {}
        while queue:
            y, grad_y = queue.popleft()
            grad_y = self.convert(grad_y)
            gradients = y.op.gradient(grad_y)
            # Each op must report one gradient per input.
            assert len(gradients) == len(y.op.inputs)
            for tensor, gradient in zip(y.op.inputs, gradients):
                # Accumulate contributions from every path reaching `tensor`.
                if tensor in grads:
                    grads[tensor] += gradient
                else:
                    grads[tensor] = gradient
                # Only propagate further if `tensor` was itself produced by
                # an operation; leaf tensors terminate the traversal.
                if tensor.op:
                    queue.append((tensor, gradient))
        return [grads[x] for x in xs]

    # Each operation method below defines a new operation over the provided
    # input tensors and returns the operation's output tensor.

    def add(self, a, b):
        """Return the output tensor of ``a + b``."""
        op = AddOp([a, b], graph=self)
        return op.output

    def sub(self, a, b):
        """Return the output tensor of ``a - b``."""
        op = SubOp([a, b], graph=self)
        return op.output

    def mul(self, a, b):
        """Return the output tensor of ``a * b``."""
        op = MulOp([a, b], graph=self)
        return op.output

    def div(self, a, b):
        """Return the output tensor of ``a / b``."""
        op = DivOp([a, b], graph=self)
        return op.output

    def neg(self, x):
        """Return the output tensor of ``-x``."""
        op = NegOp([x], graph=self)
        return op.output

    def square(self, x):
        """Return the output tensor of ``x ** 2``."""
        op = SquareOp([x], graph=self)
        return op.output

    def sigmoid(self, x):
        """Return the output tensor of the sigmoid of ``x``."""
        op = SigmoidOp([x], graph=self)
        return op.output

    def dot(self, a, b):
        """Return the output tensor of the dot product of ``a`` and ``b``."""
        op = DotOp([a, b], graph=self)
        return op.output

    def transpose(self, x):
        """Return the output tensor of the transpose of ``x``."""
        op = TransposeOp([x], graph=self)
        return op.output

    def mean(self, x):
        """Return the output tensor of the mean of ``x``."""
        op = MeanOp([x], graph=self)
        return op.output

    def assign(self, a, b):
        """Return the output tensor of assigning ``b`` to ``a``."""
        op = AssignOp([a, b], graph=self)
        return op.output

    def group(self, inputs):
        """Return the output tensor grouping ``inputs`` into one op."""
        op = GroupOp(inputs, graph=self)
        return op.output