# Source code for tf_encrypted.session

import os
from typing import Dict, List, Optional, Any, Union
from collections import defaultdict

import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.python import debug as tf_debug

from .config import Config, RemoteConfig, get_config
from .protocol.pond import PondPublicTensor
from .tensor.factory import AbstractTensor


def _env_flag(name: str, default: bool = False) -> bool:
    """
    Read a boolean feature flag from the environment.

    ``bool(os.getenv(name))`` would treat *any* non-empty string as true, so
    ``TFE_STATS=0`` or ``TFE_DEBUG=false`` would switch the feature ON. Here the
    values '', '0', 'false', 'no' and 'off' (case-insensitive, surrounding
    whitespace ignored) are false; any other set value is true.

    :param str name: Environment variable to read.
    :param bool default: Value returned when the variable is not set at all.
    :rtype: bool
    """
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() not in ('', '0', 'false', 'no', 'off')


# Global feature switches; they can be changed at runtime via the set*Flag helpers.
__TFE_STATS__ = _env_flag('TFE_STATS')
__TFE_TRACE__ = _env_flag('TFE_TRACE')
__TFE_DEBUG__ = _env_flag('TFE_DEBUG')
# Root directory for per-run TensorBoard logs and Chrome traces.
__TENSORBOARD_DIR__ = os.getenv('TFE_STATS_DIR', '/tmp/tensorboard')

# Number of stats-enabled runs seen so far per tag; used to give each run a unique log name.
_run_counter = defaultdict(int)  # type: Any


class Session(tf.Session):
    """
    Wrap a Tensorflow Session.

    See :py:class:`tf.Session`

    :param Optional[tf.Graph] graph: A :class:`tf.Graph`. Used in the same way as
        in tensorflow. This is the graph to be launched. If nothing is specified
        then the default session graph will be used.
    :param Optional[~tensorflow_encrypted.config.Config] config: A
        :class:`Local <tensorflow_encrypted.config.LocalConfig>` or
        :class:`Remote <tensorflow_encrypted.config.RemoteConfig>` config to be
        used to execute the graph. When omitted, ``get_config()`` is used.
    """

    def __init__(
        self,
        graph=None,
        config=None
    ) -> None:
        if config is None:
            config = get_config()

        target, configProto = config.get_tf_config()

        if isinstance(config, RemoteConfig):
            print("Starting session on target '{}' using config {}".format(target, configProto))

        super(Session, self).__init__(target, graph, configProto)

        # Read-only access to the module flag; no `global` statement needed.
        if __TFE_DEBUG__:
            print('Session in debug mode')
            # NOTE(review): the original rebound `self` to this wrapper, which
            # only changes a local name. The wrapper was discarded, and callers
            # still receive the unwrapped session. The call is kept so that any
            # side effects of constructing it are preserved. To actually debug,
            # wrap at the call site: tf_debug.LocalCLIDebugWrapperSession(sess).
            tf_debug.LocalCLIDebugWrapperSession(self)

    def sanitize_fetches(
        self,
        fetches: Any
    ) -> Union[List[Any], Dict[Any, Any], tf.Tensor, tf.Operation]:
        """
        Convert TFE fetches into fetches that a plain :class:`tf.Session` accepts.

        - Lists and tuples are sanitized element-wise (both come back as lists).
        - Dictionaries are sanitized value-wise, keeping their keys.
        - ``tf.Tensor`` / ``tf.Operation`` values are passed through unchanged.
        - :class:`PondPublicTensor` values are replaced by ``decode()``.
        - Other :class:`AbstractTensor` values are replaced by ``to_native()``.

        :param Any fetches: A fetch or a (nested) list/tuple/dict of fetches.
        :raises TypeError: If a fetch has none of the supported types.
        """
        if isinstance(fetches, (list, tuple)):
            return [self.sanitize_fetches(fetch) for fetch in fetches]
        if isinstance(fetches, dict):
            # tf.Session.run accepts dict fetches and returns a dict with the
            # same keys, as documented on `run` below.
            return {key: self.sanitize_fetches(value) for key, value in fetches.items()}
        if isinstance(fetches, (tf.Tensor, tf.Operation)):
            return fetches
        if isinstance(fetches, PondPublicTensor):
            return fetches.decode()
        if isinstance(fetches, AbstractTensor):
            return fetches.to_native()
        raise TypeError("Don't know how to fetch {}".format(type(fetches)))

    def run(
        self,
        fetches: Any,
        feed_dict: Optional[Dict[tf.Tensor, np.ndarray]] = None,
        tag: Optional[str] = None,
        write_trace: bool = False
    ):
        """
        See :meth:`tf.Session.run`

        This method functions just as the one from tensorflow.

        :param Any fetches: A single graph element, a list of graph elements,
            or a dictionary whose values are graph elements or lists of graph
            elements.
        :param Dict[str,np.ndarray] feed_dict: A dictionary that maps graph
            elements to values. ``None`` (the default) means no feeds.
        :param Optional[str] tag: A namespace to run the session under. Run
            statistics are only collected when a tag is given and statistics
            monitoring is enabled.
        :param bool write_trace: If true, a Chrome trace of the run is also
            dumped next to the TensorBoard logs. This only takes effect when
            statistics are being collected (see ``tag``).
        :rtype: Any
        :returns: Either a single value if `fetches` is a single graph element,
            or a list of values if fetches is a list, or a dictionary with the
            same keys as fetches if that is a dictionary (described above).
            Order in which fetches operations are evaluated inside the call is
            undefined.
        """
        sanitized_fetches = self.sanitize_fetches(fetches)

        if not __TFE_STATS__ or tag is None:
            return super(Session, self).run(
                sanitized_fetches,
                feed_dict=feed_dict
            )

        return self._run_with_stats(sanitized_fetches, feed_dict, tag, write_trace)

    def _run_with_stats(self, sanitized_fetches, feed_dict, tag, write_trace):
        """Run with full tracing, log run metadata for TensorBoard and optionally dump a Chrome trace."""
        # Successive runs under one tag are numbered: <tag>0, <tag>1, ...
        session_tag = "{}{}".format(tag, _run_counter[tag])
        run_tag = os.path.join(__TENSORBOARD_DIR__, session_tag)
        _run_counter[tag] += 1

        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        writer = tf.summary.FileWriter(run_tag, self.graph)
        try:
            fetches_out = super(Session, self).run(
                sanitized_fetches,
                feed_dict=feed_dict,
                options=run_options,
                run_metadata=run_metadata
            )
            writer.add_run_metadata(run_metadata, session_tag)
        finally:
            # Close the writer even if the run fails so the event file handle is released.
            writer.close()

        if __TFE_TRACE__ or write_trace:
            chrome_trace = timeline.Timeline(run_metadata.step_stats).generate_chrome_trace_format()
            trace_path = os.path.join(__TENSORBOARD_DIR__, '{}.ctr'.format(session_tag))
            with open(trace_path, 'w') as f:
                f.write(chrome_trace)

        return fetches_out
def setMonitorStatsFlag(monitor_stats: bool = False) -> None:
    """
    Enable or disable statistics collection for ``Session.run``.

    Statistics are only collected for ``run`` calls that also pass a ``tag``.

    :param bool monitor_stats: Truthy to enable, falsy (the default) to disable.
    """
    global __TFE_STATS__
    if monitor_stats:
        print("Tensorflow encrypted is monitoring statistics for each session.run() call using a tag")
    # Store a real bool so the flag's type matches the environment-derived default.
    __TFE_STATS__ = bool(monitor_stats)


def setTFEDebugFlag(debug: bool = False) -> None:
    """
    Enable or disable debug mode.

    The flag is read when a ``Session`` is constructed, so it only affects
    sessions created after this call.

    :param bool debug: Truthy to enable, falsy (the default) to disable.
    """
    global __TFE_DEBUG__
    if debug:
        print("Tensorflow encrypted is running in DEBUG mode")
    __TFE_DEBUG__ = bool(debug)


def setTFETraceFlag(trace: bool = False) -> None:
    """
    Enable or disable dumping of Chrome computation traces.

    Traces are only written for runs where statistics are being collected.

    :param bool trace: Truthy to enable, falsy (the default) to disable.
    """
    global __TFE_TRACE__
    if trace:
        print("Tensorflow encrypted is dumping computation traces")
    __TFE_TRACE__ = bool(trace)