from __future__ import absolute_import
from typing import Optional, Tuple
import sys
import math
import tensorflow as tf
from .protocol import memoize, nodes
from ..protocol.pond import (
Pond, PondTensor, PondPublicTensor, PondPrivateTensor, PondMaskedTensor, _type
)
from ..tensor.prime import PrimeFactory
from ..tensor.factory import AbstractFactory, AbstractTensor
from ..player import Player
from ..config import get_config
# Handle to this module object. It is passed as `container=` to
# `self.dispatch(...)` (see `lsb`, `bits`, `equal_zero` on SecureNN) so that
# dispatch can find the module-level implementations defined below, e.g.
# `_lsb_private`, `_lsb_masked`, `_bits_public`, `_equal_zero_public`.
# NOTE(review): the exact name-lookup scheme lives in `dispatch`, which is not
# visible here -- presumably `_<op>_<tensor kind>`; confirm before renaming.
_thismodule = sys.modules[__name__]
class SecureNN(Pond):
    """Implementation of SecureNN from the SecureNN paper.

    https://eprint.iacr.org/2018/442.pdf

    Extends :class:`Pond` with a third server, ``server_2``. That server is
    also handed to :class:`Pond` as its ``crypto_producer`` and takes part in
    the comparison-based operations (``msb``/``lsb``, ``relu``,
    ``maxpool2d``, ...).

    Two extra tensor factories are kept:

    - ``prime_factory``: used for bit decompositions in private comparison.
      It defaults to ``PrimeFactory(107, ...)``.
    - ``odd_factory``: defaults to the protocol's ``tensor_factory``.
    """

    def __init__(
        self,
        server_0: Optional[Player] = None,
        server_1: Optional[Player] = None,
        server_2: Optional[Player] = None,
        prime_factory: Optional[AbstractFactory] = None,
        odd_factory: Optional[AbstractFactory] = None,
        **kwargs
    ) -> None:
        """Create a SecureNN protocol instance.

        Missing players are looked up in the active config under the keys
        ``server0``, ``server1`` and ``crypto_producer`` (the latter is used
        as ``server_2``).

        :param server_0: first computing server.
        :param server_1: second computing server.
        :param server_2: helper server; also passed to :class:`Pond` as
            ``crypto_producer``.
        :param prime_factory: factory used for bit decompositions. Defaults
            to ``PrimeFactory(107)`` with the same native type as
            ``tensor_factory``.
        :param odd_factory: defaults to ``tensor_factory``.
        :param kwargs: forwarded unchanged to :class:`Pond`.
        """
        server_0 = server_0 or get_config().get_player('server0')
        server_1 = server_1 or get_config().get_player('server1')
        # TODO[Morten] use `server2` as key
        server_2 = server_2 or get_config().get_player('crypto_producer')

        super(SecureNN, self).__init__(
            server_0=server_0,
            server_1=server_1,
            crypto_producer=server_2,
            **kwargs
        )
        self.server_2 = server_2

        if prime_factory is None:
            prime_factory = PrimeFactory(107, native_type=self.tensor_factory.native_type)
        if odd_factory is None:
            odd_factory = self.tensor_factory

        self.prime_factory = prime_factory
        self.odd_factory = odd_factory

        # every factory must agree with the main factory on the native dtype
        assert self.prime_factory.native_type == self.tensor_factory.native_type
        assert self.odd_factory.native_type == self.tensor_factory.native_type

    @memoize
    def bitwise_not(self, x: PondTensor) -> PondTensor:
        """
        Computes the bitwise `NOT` of the input.

        `(1 - x)`

        :param x: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        :rtype: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        """
        assert not x.is_scaled, "Input is not supposed to be scaled"
        with tf.name_scope('bitwise_not'):
            return self.sub(1, x)

    @memoize
    def bitwise_and(self, x: 'PondTensor', y: 'PondTensor') -> 'PondTensor':
        """
        Computes the bitwise `AND` of the given inputs.

        `(x * y)`

        :param x: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        :param y: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        :rtype: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        """
        assert not x.is_scaled, "Input is not supposed to be scaled"
        assert not y.is_scaled, "Input is not supposed to be scaled"
        with tf.name_scope('bitwise_and'):
            return x * y

    @memoize
    def bitwise_or(self, x: 'PondTensor', y: 'PondTensor') -> 'PondTensor':
        """
        Computes the bitwise `OR` of the given inputs.

        `(x + y) - (x * y)`

        :param x: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        :param y: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        :rtype: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        """
        assert not x.is_scaled, "Input is not supposed to be scaled"
        assert not y.is_scaled, "Input is not supposed to be scaled"
        with tf.name_scope('bitwise_or'):
            return x + y - self.bitwise_and(x, y)

    @memoize
    def bitwise_xor(self, x: 'PondTensor', y: 'PondTensor') -> 'PondTensor':
        """
        Compute the bitwise `XOR` of the given inputs.

        `(x + y) - (x * y * 2)`

        :param x: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        :param y: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        :rtype: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        """
        assert not x.is_scaled, "Input is not supposed to be scaled"
        assert not y.is_scaled, "Input is not supposed to be scaled"
        with tf.name_scope('bitwise_xor'):
            return x + y - self.bitwise_and(x, y) * 2

    @memoize
    def msb(self, x: 'PondTensor') -> 'PondTensor':
        """
        Computes the most significant bit of the provided tensor.

        Implemented as ``lsb(2 * x)``.

        :param x: a :class:`~tensorflow_encrypted.protocol.pond.PondTensor`
        """
        # NOTE when the modulus is odd then msb reduces to lsb via x -> 2*x.
        # A check that the backing ring has odd cardinality is currently
        # disabled: it is only valid for an odd-modulus CRTTensor, while
        # NativeTensor uses an even modulus and would require share_convert.
        return self.lsb(x * 2)

    @memoize
    def lsb(self, x: PondTensor) -> PondTensor:
        """Least significant bit of `x`, dispatched on the tensor kind to
        the module-level `_lsb_*` implementations."""
        return self.dispatch('lsb', x, container=_thismodule)

    @memoize
    def bits(self, x: PondTensor, factory: Optional[AbstractFactory]=None) -> 'PondTensor':
        """Bit decomposition of `x`, dispatched on the tensor kind.

        :param factory: backing factory for the bits; it is forwarded to the
            dispatched implementation.
        """
        return self.dispatch('bits', x, container=_thismodule, factory=factory)

    @memoize
    def negative(self, x: PondTensor) -> PondTensor:
        """Return the bit ``x < 0``, computed as ``msb(x)``."""
        with tf.name_scope('negative'):
            # NOTE MSB is 1 iff xi < 0
            return self.msb(x)

    @memoize
    def non_negative(self, x: PondTensor) -> PondTensor:
        """Return the bit ``x >= 0``, computed as ``NOT msb(x)``."""
        with tf.name_scope('non_negative'):
            return self.bitwise_not(self.msb(x))

    @memoize
    def less(self, x: PondTensor, y: PondTensor) -> PondTensor:
        """Return the bit ``x < y``, computed as ``negative(x - y)``."""
        with tf.name_scope('less'):
            return self.negative(x - y)

    @memoize
    def less_equal(self, x: PondTensor, y: PondTensor) -> PondTensor:
        """Return the bit ``x <= y``, computed as ``NOT greater(x, y)``."""
        with tf.name_scope('less_equal'):
            return self.bitwise_not(self.greater(x, y))

    @memoize
    def greater(self, x: PondTensor, y: PondTensor) -> PondTensor:
        """Return the bit ``x > y``, computed as ``negative(y - x)``."""
        with tf.name_scope('greater'):
            return self.negative(y - x)

    @memoize
    def greater_equal(self, x: PondTensor, y: PondTensor) -> PondTensor:
        """Return the bit ``x >= y``, computed as ``NOT less(x, y)``."""
        with tf.name_scope('greater_equal'):
            return self.bitwise_not(self.less(x, y))

    @memoize
    def select(self, choice_bit: PondTensor, x: PondTensor, y: PondTensor) -> PondTensor:
        """Oblivious selection: ``(y - x) * choice_bit + x``.

        For a 0/1 `choice_bit` this yields `x` where the bit is 0 and `y`
        where it is 1.
        """
        with tf.name_scope('select'):
            return (y - x) * choice_bit + x

    @memoize
    def equal_zero(self, x, out_dtype: Optional[AbstractFactory]=None):
        """Elementwise ``x == 0`` test, dispatched on the tensor kind.

        :param out_dtype: factory for the result; it is forwarded to the
            dispatched implementation.
        """
        return self.dispatch('equal_zero', x, container=_thismodule, out_dtype=out_dtype)

    def share_convert(self, x):
        """Not implemented."""
        raise NotImplementedError

    def divide(self, x, y):
        """Not implemented."""
        raise NotImplementedError

    @memoize
    def relu(self, x):
        """ReLU computed as ``non_negative(x) * x``."""
        with tf.name_scope('relu'):
            drelu = self.non_negative(x)
            return drelu * x

    def maxpool2d(self, x, pool_size, strides, padding):
        """2D max pooling, dispatched on the tensor kind.

        Results are cached in ``nodes``. The cache is keyed on the arguments,
        with `pool_size` and `strides` converted to tuples so that the key is
        hashable.

        :raises TypeError: if `x` is not a public, private or masked tensor.
        """
        node_key = ('maxpool2d', x, tuple(pool_size), tuple(strides), padding)
        z = nodes.get(node_key, None)
        if z is not None:
            return z

        dispatch = {
            PondPublicTensor: _maxpool2d_public,
            PondPrivateTensor: _maxpool2d_private,
            PondMaskedTensor: _maxpool2d_masked,
        }

        func = dispatch.get(_type(x), None)
        if func is None:
            raise TypeError("Don't know how to maxpool2d {}".format(type(x)))

        z = func(self, x, pool_size, strides, padding)
        nodes[node_key] = z
        return z

    @memoize
    def maximum(self, x, y):
        """Elementwise maximum of `x` and `y`.

        The result is `x` where ``x > y`` and `y` elsewhere.
        """
        with tf.name_scope('maximum'):
            indices_of_maximum = self.greater(x, y)
            return self.select(indices_of_maximum, y, x)

    @memoize
    def reduce_max(self, x, axis=0):
        """Maximum of `x` along `axis`.

        `x` is split into ``x.shape[axis]`` slices. These are combined
        pairwise with :meth:`maximum` in a balanced binary tree, and `axis`
        is then squeezed out of the result.
        """
        with tf.name_scope('reduce_max'):

            def build_comparison_tree(ts):
                assert len(ts) > 0
                if len(ts) == 1:
                    return ts[0]
                halfway = len(ts) // 2
                ts_left, ts_right = ts[:halfway], ts[halfway:]
                maximum_left = build_comparison_tree(ts_left)
                maximum_right = build_comparison_tree(ts_right)
                return self.maximum(maximum_left, maximum_right)

            tensors = self.split(x, int(x.shape[axis]), axis=axis)
            maximum = build_comparison_tree(tensors)
            return self.squeeze(maximum, axis=(axis,))

    def dmax_pool_efficient(self, x):
        """Not implemented."""
        raise NotImplementedError
def _bits_public(prot, x: PondPublicTensor, factory: Optional[AbstractFactory]=None) -> PondPublicTensor:
    """Bit-decompose a public tensor locally on each of the two servers.

    :param prot: the protocol instance; its ``server_0``/``server_1`` devices are used.
    :param x: the public tensor to decompose.
    :param factory: backing factory for the bits. ``prot.tensor_factory`` is
        used when it is falsy.
    :returns: an unscaled public tensor holding the bits.
    """
    bit_factory = factory or prot.tensor_factory

    with tf.name_scope('bits'):
        value_on_0, value_on_1 = x.unwrapped

        # each server decomposes its own copy of the public value
        with tf.device(prot.server_0.device_name):
            decomposed_0 = value_on_0.to_bits(bit_factory)

        with tf.device(prot.server_1.device_name):
            decomposed_1 = value_on_1.to_bits(bit_factory)

        return PondPublicTensor(prot, decomposed_0, decomposed_1, is_scaled=False)
def _lsb_private(prot, y: PondPrivateTensor):
    """Compute shares of the least significant bit of a private tensor `y`.

    The computation has three steps:

    1. Masking (``lsb_mask``): ``server_2`` samples a uniform `x` in
       ``y.backing_dtype``, together with its bits (in ``prot.prime_factory``)
       and its lowest bit, and secret-shares all three. ``server_0`` samples
       `beta`, which is used as a public bit.
    2. Comparison (``lsb_compare``): ``r = y + x`` is revealed and
       bit-decomposed. ``_private_compare`` is then run on the shared bits of
       `x`, the public `r` and `beta`.
    3. Combination (``lsb_combine``): the result is
       ``alpha = bp XOR beta XOR lsb(x) XOR lsb(r)``.

    NOTE(review): correctness relies on the backing ring having odd modulus.
    In that case ``y = r - x`` flips parity exactly when the subtraction
    wraps, and the comparison corrects for that. TODO confirm for the factory
    in use.

    :param prot: the SecureNN protocol instance.
    :param y: private tensor whose LSB is computed.
    :returns: a private tensor with the same backing dtype as `y`.
    """
    with tf.name_scope('lsb'):

        with tf.name_scope('lsb_mask'):

            with tf.device(prot.server_2.device_name):
                x_raw = y.backing_dtype.sample_uniform(y.shape)
                xbits_raw = x_raw.to_bits(factory=prot.prime_factory)
                # lowest bit of x, cast back into y's ring
                xlsb_raw = xbits_raw[..., 0].cast(y.backing_dtype)

                x = prot._share_and_wrap(x_raw, False)
                xbits = prot._share_and_wrap(xbits_raw, False)
                xlsb = prot._share_and_wrap(xlsb_raw, False)

            with tf.device(prot.server_0.device_name):
                # TODO[Morten] pull this out as a separate `sample_bits` method on tensors (optimized for bits only)
                # presumably samples values in [0, 1], i.e. random bits -- confirm sample_bounded semantics
                beta_raw = prot.prime_factory.sample_bounded(y.shape, 1)
                beta = PondPublicTensor(prot, beta_raw, beta_raw, is_scaled=False)

        with tf.name_scope('lsb_compare'):
            # revealing y + x is meant to be safe because x is a fresh uniform mask
            r = (y + x).reveal()
            rbits = prot.bits(r)
            rlsb = rbits[..., 0]

            bp = _private_compare(prot, xbits, r, beta)

        with tf.name_scope('lsb_combine'):
            # undo the beta blinding applied inside the comparison result
            gamma = prot.bitwise_xor(bp, beta.cast_backing(prot.tensor_factory))
            delta = prot.bitwise_xor(xlsb, rlsb)
            alpha = prot.bitwise_xor(gamma, delta)

        assert alpha.backing_dtype is y.backing_dtype
        return alpha
def _lsb_masked(prot, x: PondMaskedTensor):
    """Least significant bit of a masked tensor.

    The protocol's ``lsb`` is applied to the tensor's ``unmasked`` form.
    """
    underlying = x.unmasked
    return prot.lsb(underlying)
def _private_compare(prot, x_bits: PondPrivateTensor, r: PondPublicTensor, beta: PondPublicTensor):
    """Private comparison between shared bits `x_bits` and public `r`.

    This follows the PrivateCompare construction of the SecureNN paper. For
    each element, a row ``c`` of length ``bit_length`` is built in the prime
    field and multiplied by a random public mask. The masked row is
    reconstructed on ``server_2``, and the number of zero entries in each row
    is re-shared.

    NOTE(review): the output is presumably ``beta XOR (x > r)`` as in the
    paper. `_lsb_private` XORs it with `beta` again. Confirm against the paper
    before relying on it.

    :param prot: the SecureNN protocol instance.
    :param x_bits: private shares of bits in ``prot.prime_factory``. The last
        axis is the bit axis; the leading axes match ``r.shape``.
    :param r: public tensor in ``prot.tensor_factory``.
    :param beta: public bits in the prime factory, with the same shape as `r`.
    :returns: a private tensor of shape ``r.shape`` backed by ``r.backing_dtype``.
    """

    # TODO[Morten] no need to check this (should be free)
    assert r.backing_dtype == prot.tensor_factory
    assert x_bits.backing_dtype == prot.prime_factory

    out_shape = r.shape
    out_dtype = r.backing_dtype
    prime_dtype = x_bits.backing_dtype
    bit_length = x_bits.shape[-1]

    # NOTE(review): the r/out_* asserts and the x_bits/prime_dtype assert below
    # are trivially true given the assignments just above.
    assert r.shape == out_shape
    assert r.backing_dtype == out_dtype
    assert x_bits.shape[:-1] == out_shape
    assert x_bits.backing_dtype == prime_dtype
    assert beta.shape == out_shape
    assert beta.backing_dtype == prime_dtype

    with tf.name_scope('private_compare'):

        with tf.name_scope('bit_comparisons'):

            # use either r or t = r + 1 according to beta
            # (select picks r where beta is 0 and r + 1 where beta is 1)
            s = prot.select(beta.cast_backing(r.backing_dtype), r, r + 1)
            s_bits = prot.bits(s, factory=prime_dtype)
            assert s_bits.shape[-1] == bit_length

            # compute w_sum: exclusive reverse cumulative sum of x XOR s along the bit axis
            w_bits = prot.bitwise_xor(x_bits, s_bits)
            w_sum = prot.cumsum(w_bits, axis=-1, reverse=True, exclusive=True)
            assert w_sum.backing_dtype == prime_dtype

            # compute c, ignoring edge cases at first
            # sign is +1 where beta is 0 and -1 where beta is 1
            sign = prot.select(beta, 1, -1)
            sign = prot.expand_dims(sign, axis=-1)
            c_except_edge_case = (s_bits - x_bits) * sign + 1 + w_sum

            assert c_except_edge_case.backing_dtype == prime_dtype

        with tf.name_scope('edge_cases'):

            # adjust for edge cases, i.e. where beta is 1 and s is zero (meaning r was -1)
            edge_cases = prot.bitwise_and(
                beta,
                prot.equal_zero(s, prime_dtype)
            )
            edge_cases = prot.expand_dims(edge_cases, axis=-1)
            # constant row [0, 1, ..., 1]: exactly one zero entry; presumably
            # broadcast over the leading axes of c by `select`
            c_edge_case_raw = prime_dtype.tensor(tf.constant([0] + [1] * (bit_length - 1), dtype=prime_dtype.native_type, shape=(1, bit_length)))
            c_edge_case = prot._share_and_wrap(c_edge_case_raw, False)

            c = prot.select(
                edge_cases,
                c_except_edge_case,
                c_edge_case
            )  # type: PondPrivateTensor
            assert c.backing_dtype == prime_dtype

        with tf.name_scope('zero_search'):

            # generate multiplicative mask to hide non-zero values
            # (minval=1 is presumably meant to exclude zero -- confirm sample_uniform bounds)
            with tf.device(prot.server_0.device_name):
                mask_raw = prime_dtype.sample_uniform(c.shape, minval=1)
                mask = PondPublicTensor(prot, mask_raw, mask_raw, False)

            # mask non-zero values; this is safe when we're in a field
            c_masked = c * mask
            assert c_masked.backing_dtype == prime_dtype

            # TODO[Morten] permute

            # reconstruct masked values on server 2 to find entries with zeros
            with tf.device(prot.server_2.device_name):
                d = prot._reconstruct(*c_masked.unwrapped)
                # find all zero entries
                zeros = d.equal_zero(out_dtype)
                # for each bit sequence, determine whether it has one or no zero in it
                rows_with_zeros = zeros.reduce_sum(axis=-1, keepdims=False)
                # reshare result
                result = prot._share_and_wrap(rows_with_zeros, False)

        assert result.backing_dtype == out_dtype
        return result
def _equal_zero_public(prot, x: PondPublicTensor, out_dtype: Optional[AbstractFactory]=None) -> PondPublicTensor:
    """Elementwise zero test on a public tensor.

    Each server evaluates the test locally on its own copy of the value.

    :param prot: the protocol instance; its ``server_0``/``server_1`` devices are used.
    :param x: the public tensor to test.
    :param out_dtype: factory for the result, forwarded to ``equal_zero`` on
        the backing tensors.
    :returns: an unscaled public tensor.
    """
    with tf.name_scope('equal_zero'):
        value_on_0, value_on_1 = x.unwrapped

        per_server = []
        for server, value in ((prot.server_0, value_on_0), (prot.server_1, value_on_1)):
            with tf.device(server.device_name):
                per_server.append(value.equal_zero(out_dtype))

        return PondPublicTensor(prot, per_server[0], per_server[1], is_scaled=False)
#
# max pooling helpers
#
def _im2col(prot: Pond,
            x: PondTensor,
            pool_size: Tuple[int, int],
            strides: Tuple[int, int],
            padding: str) -> Tuple[AbstractTensor, AbstractTensor, list]:
    """Lay out the pooling windows of `x` with ``im2col`` on both servers.

    Each server reshapes its value of `x` to
    ``(batch * channels, 1, height, width)`` before calling ``im2col``, so
    every channel is processed as an independent single-channel image.

    :param prot: the protocol instance; its ``server_0``/``server_1`` devices are used.
    :param x: a tensor whose ``unwrapped`` yields two backing tensors. Its
        shape is unpacked as ``(batch, channels, height, width)``.
    :param pool_size: ``(pool_height, pool_width)``.
    :param strides: ``(stride_height, stride_width)``. The two values must be
        equal, because the backing ``im2col`` only receives a single stride.
    :param padding: ``"SAME"``; any other value is treated as ``"VALID"``
        when computing the output size.
    :returns: ``(y_on_0, y_on_1, [out_height, out_width, batch, channels])``.
        These are the im2col results on each server, plus the shape the
        pooled values should be reshaped to.
    :raises NotImplementedError: if ``strides[0] != strides[1]``.
    """
    if strides[0] != strides[1]:
        # The output size is computed per axis, but only strides[0] reaches
        # im2col. Unequal strides would give a mismatched (or silently wrong)
        # output shape.
        raise NotImplementedError(
            "maxpool2d only supports equal strides, got {}".format(tuple(strides)))

    x_on_0, x_on_1 = x.unwrapped
    batch, channels, height, width = x.shape
    pool_height, pool_width = pool_size

    if padding == "SAME":
        out_height = math.ceil(int(height) / strides[0])
        out_width = math.ceil(int(width) / strides[1])
    else:
        out_height = math.ceil((int(height) - pool_height + 1) / strides[0])
        out_width = math.ceil((int(width) - pool_width + 1) / strides[1])

    with tf.device(prot.server_0.device_name):
        x_split = x_on_0.reshape((batch * channels, 1, height, width))
        y_on_0 = x_split.im2col(pool_height, pool_width, padding, strides[0])

    with tf.device(prot.server_1.device_name):
        x_split = x_on_1.reshape((batch * channels, 1, height, width))
        y_on_1 = x_split.im2col(pool_height, pool_width, padding, strides[0])

    return y_on_0, y_on_1, [out_height, out_width, int(batch), int(channels)]
def _maxpool2d_public(prot: Pond,
                      x: PondPublicTensor,
                      pool_size: Tuple[int, int],
                      strides: Tuple[int, int],
                      padding: str) -> PondPublicTensor:
    """2D max pooling for public tensors.

    The pooling windows are laid out with ``_im2col``, and the maximum is
    taken over axis 0. The result is reshaped to
    ``[out_height, out_width, batch, channels]`` and transposed to
    ``[batch, channels, out_height, out_width]``.
    """
    with tf.name_scope('maxpool2d'):
        y_on_0, y_on_1, reshape_to = _im2col(prot, x, pool_size, strides, padding)
        windows = PondPublicTensor(prot, y_on_0, y_on_1, x.is_scaled)
        # named `pooled` rather than `max` to avoid shadowing the builtin
        pooled = windows.reduce_max(axis=0)
        return pooled.reshape(reshape_to).transpose([2, 3, 0, 1])
def _maxpool2d_private(prot: Pond,
                       x: PondPrivateTensor,
                       pool_size: Tuple[int, int],
                       strides: Tuple[int, int],
                       padding: str) -> PondPrivateTensor:
    """2D max pooling for private tensors.

    The pooling windows are laid out with ``_im2col`` on each server's share,
    and the maximum is taken over axis 0 using the protocol's secure
    comparisons. The result is reshaped to
    ``[out_height, out_width, batch, channels]`` and transposed to
    ``[batch, channels, out_height, out_width]``.
    """
    with tf.name_scope('maxpool2d'):
        y_on_0, y_on_1, reshape_to = _im2col(prot, x, pool_size, strides, padding)
        windows = PondPrivateTensor(prot, y_on_0, y_on_1, x.is_scaled)
        # named `pooled` rather than `max` to avoid shadowing the builtin
        pooled = windows.reduce_max(axis=0)
        return pooled.reshape(reshape_to).transpose([2, 3, 0, 1])
def _maxpool2d_masked(prot: Pond,
                      x: PondMaskedTensor,
                      pool_size: Tuple[int, int],
                      strides: Tuple[int, int],
                      padding: str) -> PondPrivateTensor:
    """2D max pooling for masked tensors, applied to the unmasked tensor.

    ``prot.maxpool2d`` only dispatches on public, private and masked Pond
    tensors, so it must receive ``x.unmasked``, as ``_lsb_masked`` also uses.
    ``x.unwrapped`` is the tuple of raw backing values and would fall through
    to maxpool2d's TypeError.
    """
    with tf.name_scope('maxpool2d'):
        return prot.maxpool2d(x.unmasked, pool_size, strides, padding)