session.py

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import numpy as np
#

A Session performs computation on a graph by evaluating its tensors and running the operations that produce them.

class Session(object):
    # Evaluates tensors of a computation graph. Values of stateful tensors are
    # kept in self.state, which persists on the session between run() calls.

Initialize a session with a graph and an empty state dictionary that holds tensor values across runs.

    def __init__(self, graph):
#
        self.graph = graph
        self.state = {}
#

run_op takes as input an operation to run and a context of already-evaluated tensors; it evaluates the operation's inputs and returns the operation's result.

    def run_op(self, op, context):
#
        args = [self.eval_tensor(tensor, context) for tensor in op.inputs]
        return op.compute(self, *args)
#

eval_tensor takes as input a tensor to evaluate and a context of already-evaluated tensors, and returns the tensor's value. If the tensor is not already in the context, there are three possibilities for evaluating it:

  • The tensor has an operation, so it is the result of that operation, which must be computed.
  • The tensor has an active (non-None) state from a previous session run that can be fetched.
  • The tensor has no state yet but has an initial value from its instantiation, which is fetched and added to the state.

If none of these applies, evaluation fails with a KeyError.
    def eval_tensor(self, tensor, context):
#
        if tensor not in context:
            if tensor.op is not None:
                context[tensor] = self.run_op(tensor.op, context)
            elif tensor in self.state and self.state[tensor] is not None:
                context[tensor] = self.state[tensor]
            elif tensor not in self.state and tensor.initial_value is not None:
                context[tensor] = self.state[tensor] = tensor.initial_value

        return context[tensor]
#

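run takes a list of tensors to evaluate and an optional feed dictionary whose entries override (pre-fill) tensor values for this run; it returns the evaluated values in the same order as the input tensors.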

    def run(self, tensors, feed_dict=None):
#
        context = {}

        if feed_dict:
            context.update(feed_dict)

        return [self.eval_tensor(tensor, context) for tensor in tensors]