from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import numpy as np
class Session(object):
    """``Session`` performs computation on a graph.

    A session keeps a reference to the graph it was created with and a
    ``state`` dictionary that persists tensor values across calls to
    :meth:`run` (values taken from a tensor's ``initial_value`` are stored
    there the first time they are used).
    """

    def __init__(self, graph):
        """Initialize a session with a graph and an empty state dictionary.

        Args:
            graph: The graph this session performs computation on. It is
                stored as ``self.graph``; no method of this class reads it.
        """
        self.graph = graph
        # Tensor -> value; unlike the per-run context, survives across runs.
        self.state = {}

    def run_op(self, op, context):
        """Run an operation after evaluating all of its inputs.

        Args:
            op: Operation to run. Must expose ``inputs`` (an iterable of
                tensors) and ``compute(session, *input_values)``.
            context: Dict of pre-evaluated tensors for the current run. It
                is filled in as the op's inputs are evaluated.

        Returns:
            Whatever ``op.compute`` returns.
        """
        args = [self.eval_tensor(tensor, context) for tensor in op.inputs]
        return op.compute(self, *args)

    def eval_tensor(self, tensor, context):
        """Evaluate ``tensor``, memoizing the result in ``context``.

        If the tensor is already in ``context`` (pre-evaluated earlier in
        this run, or supplied via ``feed_dict``) that value is returned.
        Otherwise there are three possibilities, tried in order:

        1. The tensor has a producing op (``tensor.op is not None``): run
           the op.
        2. The tensor has a non-``None`` value in ``self.state``: use it.
        3. The tensor is not in ``self.state`` at all and has a non-``None``
           ``initial_value``: use that value and also record it in
           ``self.state`` so later runs reuse it.

        Args:
            tensor: Tensor to evaluate. Must be hashable and expose ``op``
                and (when ``op`` is ``None``) ``initial_value``.
            context: Dict of pre-evaluated tensors for the current run.

        Returns:
            The value of ``tensor``.

        Raises:
            KeyError: if none of the possibilities above applies, e.g. a
                tensor with no op, no stored state and no initial value that
                was not supplied via ``feed_dict``.
        """
        if tensor in context:
            return context[tensor]

        if tensor.op is not None:
            value = self.run_op(tensor.op, context)
        elif tensor in self.state and self.state[tensor] is not None:
            value = self.state[tensor]
        elif tensor not in self.state and tensor.initial_value is not None:
            value = self.state[tensor] = tensor.initial_value
        else:
            # Previously this fell through to `context[tensor]` and raised a
            # bare KeyError; keep the exception type but explain the cause.
            raise KeyError(
                "Cannot evaluate tensor %r: it has no op, no non-None value "
                "in session state, and cannot be initialized from "
                "initial_value; provide it via feed_dict." % (tensor,)
            )

        context[tensor] = value
        return value

    def run(self, tensors, feed_dict=None):
        """Evaluate ``tensors`` and return their values in order.

        All tensors share one evaluation context, so each tensor is
        evaluated at most once per call.

        Args:
            tensors: Iterable of tensors to evaluate.
            feed_dict: Optional mapping from tensors to values. Fed values
                override normal evaluation for this run only; they are not
                written to ``self.state``.

        Returns:
            A list with one value per requested tensor.
        """
        context = {}
        if feed_dict:
            context.update(feed_dict)
        return [self.eval_tensor(tensor, context) for tensor in tensors]