Source code for tf_encrypted.layers.pooling

import math
from . import core

from ..protocol.pond import TFEVariable
from typing import Optional, Union, Tuple, List
from abc import abstractmethod

IntTuple = Union[int, Tuple[int, int], List[int]]


class Pooling2D(core.Layer):
    """
    Subclass for Average and Max pooling layers

    Do not instantiate.
    """
    def __init__(self,
                 input_shape: List[int],
                 pool_size: IntTuple,
                 strides: Optional[IntTuple] = None,
                 padding: str = "SAME", channels_first: bool = True) -> None:
        if type(pool_size) == int:
            pool_size = (pool_size, pool_size)  # type: ignore
        self.pool_size = pool_size
        if strides is None:
            strides = pool_size
        elif type(strides) == int:
            strides = (strides, strides)  # type: ignore
        self.strides = strides
        if padding not in ['SAME', 'VALID']:
            raise ValueError("Don't know how to do padding of type {}".format(padding))
        self.padding = padding
        self.channels_first = channels_first

        super(Pooling2D, self).__init__(input_shape)
        self.cache = None
        self.cached_input_shape = None

    def initialize(self,
                   input_shape: IntTuple,
                   initializer: Optional[TFEVariable] = None) -> None:
        pass

    def get_output_shape(self) -> List[int]:
        if self.channels_first:
            _, channels, H_in, W_in = self.input_shape
        else:
            _, H_in, W_in, channels = self.input_shape

        if self.padding == "SAME":
            H_out = math.ceil(H_in / self.strides[0])
            W_out = math.ceil(W_in / self.strides[1])
        else:
            H_out = math.ceil((H_in - self.pool_size[0] + 1) / self.strides[0])
            W_out = math.ceil((W_in - self.pool_size[1] + 1) / self.strides[1])
        return [self.input_shape[0], self.input_shape[1], H_out, W_out]

    @abstractmethod
    def pool(self, x: TFEVariable, pool_size, strides, padding) -> TFEVariable:
        raise NotImplementedError

    def forward(self, x: TFEVariable) -> TFEVariable:
        if not self.channels_first:
            x = self.prot.transpose(x, perm=[0, 3, 1, 2])

        self.cached_input_shape = x.shape
        self.cache = x

        out = self.pool(x, self.pool_size, self.strides, self.padding)

        if not self.channels_first:
            out = self.prot.transpose(out, perm=[0, 2, 3, 1])

        return out

    def backward(self, d_y: TFEVariable, learning_rate: float) -> Optional[TFEVariable]:
        raise NotImplementedError


[docs]class AveragePooling2D(Pooling2D): """ AveragePooling2D :See: tf.nn.avg_pool """ def __init__(self, input_shape: List[int], pool_size: IntTuple, strides: Optional[IntTuple] = None, padding: str = "SAME", channels_first: bool = True) -> None: super(AveragePooling2D, self).__init__(input_shape, pool_size, strides, padding, channels_first) def pool(self, x: TFEVariable, pool_size, strides, padding) -> TFEVariable: return self.prot.avgpool2d(x, pool_size, strides, padding)
[docs]class MaxPooling2D(Pooling2D): """ MaxPooling2D :See: tf.nn.max_pool """ def __init__(self, input_shape: List[int], pool_size: IntTuple, strides: Optional[IntTuple] = None, padding: str = "SAME", channels_first: bool = True) -> None: super(MaxPooling2D, self).__init__(input_shape, pool_size, strides, padding, channels_first) # TODO -- throw an error here if the protocol is not secureNN def pool(self, x: TFEVariable, pool_size, strides, padding) -> TFEVariable: return self.prot.maxpool2d(x, pool_size, strides, padding)