Commit 608847ec authored by MattiaPujatti

completed normalization layers

parent af0634b9
@@ -333,7 +333,6 @@ class Haiku_Classifier:
        # yet supported, especially haiku-transformed functions (see https://github.com/deepmind/dm-haiku/issues/59)
# So we are allowed to save only a few parameters (mainly the network's parameters and states, the optimizer
# state and the training history)
-        backup_dict = {k: self.__dict__[k]
         backup_dict = {'net_params': self.get_net_params(),
                        'net_state': self.get_net_state(),
                        'opt_state': self.get_opt_state(),
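Since haiku-transformed functions cannot be pickled, only these picklable entries go into `backup_dict`. A minimal sketch of how such a dictionary could be saved and restored with the standard library (the helper names and the path are assumptions, not part of the commit):

import pickle

def save_backup(backup_dict, path="checkpoint.pkl"):
    # Hypothetical helper: persist only the picklable training state
    # (network params/state, optimizer state, training history).
    with open(path, "wb") as f:
        pickle.dump(backup_dict, f)

def load_backup(path="checkpoint.pkl"):
    with open(path, "rb") as f:
        return pickle.load(f)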
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from haiku._src.utils import get_channel_index
from typing import Optional, Sequence
class CmplxBatchNorm(hk.Module):
    """Naive complex batch normalization: mean and variance are computed
    directly on the complex values, mirroring the interface of `hk.BatchNorm`."""

    def __init__(
        self,
        create_scale: bool,
        create_offset: bool,
        decay_rate: float,
        eps: float = 1e-5,
        scale_init: Optional[hk.initializers.Initializer] = None,
        offset_init: Optional[hk.initializers.Initializer] = None,
        axis: Optional[Sequence[int]] = None,
        data_format: str = "channels_last",
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        if not create_scale and scale_init is not None:
            raise ValueError("Cannot set `scale_init` if `create_scale=False`")
        if not create_offset and offset_init is not None:
            raise ValueError("Cannot set `offset_init` if `create_offset=False`")
        self.create_scale = create_scale
        self.create_offset = create_offset
        self.eps = eps
        self.scale_init = scale_init or jnp.ones
        self.offset_init = offset_init or jnp.zeros
        self.axis = axis
        self.channel_index = get_channel_index(data_format)
        self.mean_ema = hk.ExponentialMovingAverage(decay_rate, name="mean_ema")
        self.var_ema = hk.ExponentialMovingAverage(decay_rate, name="var_ema")
    def __call__(
        self,
        inputs: jnp.ndarray,
        is_training: bool,
        test_local_stats: bool = False,
        scale: Optional[jnp.ndarray] = None,
        offset: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        if self.create_scale and scale is not None:
            raise ValueError(
                "Cannot pass `scale` at call time if `create_scale=True`.")
        if self.create_offset and offset is not None:
            raise ValueError(
                "Cannot pass `offset` at call time if `create_offset=True`.")
        channel_index = self.channel_index
        if channel_index < 0:
            channel_index += inputs.ndim
        # By default, reduce over every axis except the channel one.
        if self.axis is not None:
            axis = self.axis
        else:
            axis = [i for i in range(inputs.ndim) if i != channel_index]
        if is_training or test_local_stats:
            # Use the statistics of the current batch.
            mean = jnp.mean(inputs, axis, keepdims=True)
            mean_of_squares = jnp.mean(jnp.square(inputs), axis, keepdims=True)
            var = mean_of_squares - jnp.square(mean)
        else:
            # Use the moving averages accumulated during training.
            mean = self.mean_ema.average
            var = self.var_ema.average
        if is_training:
            self.mean_ema(mean)
            self.var_ema(var)
        w_shape = [1 if i in axis else inputs.shape[i] for i in range(inputs.ndim)]
        w_dtype = inputs.dtype
        if self.create_scale:
            scale = hk.get_parameter("scale", w_shape, w_dtype, self.scale_init)
        elif scale is None:
            scale = np.ones([], dtype=w_dtype)
        if self.create_offset:
            offset = hk.get_parameter("offset", w_shape, w_dtype, self.offset_init)
        elif offset is None:
            offset = np.zeros([], dtype=w_dtype)
        eps = jax.lax.convert_element_type(self.eps, var.dtype)
        inv = scale * jax.lax.rsqrt(var + eps)
        return (inputs - mean) * inv + offset
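Like `hk.BatchNorm`, the module carries mutable state (the moving averages), so it has to live inside `hk.transform_with_state`. A minimal usage sketch, assuming complex64 inputs of shape (batch, features); the `forward` wrapper below is illustrative, not part of the module:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, is_training):
    # Hypothetical wrapper around the module defined above.
    bn = CmplxBatchNorm(create_scale=True, create_offset=True, decay_rate=0.99)
    return bn(x, is_training)

forward_t = hk.transform_with_state(forward)
rng = jax.random.PRNGKey(42)
x = jnp.ones((8, 16), dtype=jnp.complex64)
params, state = forward_t.init(rng, x, is_training=True)
out, state = forward_t.apply(params, state, rng, x, is_training=True)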
@@ -5,6 +5,8 @@
import jax
import jax.numpy as jnp
from jax import lax
import numpy as np
+from haiku._src.utils import get_channel_index
+from functools import partial
from typing import Optional, Tuple, Union, Sequence
import warnings
@@ -509,10 +511,154 @@ class Cmplx_Normalization(hk.Module):
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        """Implementation of a complex normalization."""
-        norm = jnp.absolute( x.flatten() )
+        norm = jnp.absolute( x.reshape(x.shape[0],-1) ).mean(axis=1).reshape(-1,1)
        return x / norm
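The new `norm` is the mean modulus of each sample's own entries, so every sample is rescaled independently instead of by a single factor computed over the flattened batch. A small worked example (values chosen purely for illustration):

import jax.numpy as jnp

x = jnp.array([[3. + 4.j, 0. + 0.j],    # mean modulus = (5 + 0) / 2 = 2.5
               [1. + 0.j, 0. + 1.j]])   # mean modulus = (1 + 1) / 2 = 1.0
norm = jnp.absolute(x.reshape(x.shape[0], -1)).mean(axis=1).reshape(-1, 1)
# norm == [[2.5], [1.0]]; each row of x / norm now has mean modulus 1.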
class CmplxBatchNorm(hk.Module):
    """Complex batch normalization that whitens the (real, imaginary) pair of
    each channel with the inverse square root of its 2x2 covariance matrix."""

    def __init__(
        self,
        create_scale: bool,
        create_offset: bool,
        decay_rate: float,
        eps: float = 1e-5,
        scale_rr_init: float = 1./jnp.sqrt(2.),
        scale_ri_init: float = 0.,
        scale_ii_init: float = 1./jnp.sqrt(2.),
        offset_init: Optional[hk.initializers.Initializer] = None,
        axis: Optional[Sequence[int]] = None,
        data_format: str = "channels_last",
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        if not create_offset and offset_init is not None:
            raise ValueError("Cannot set `offset_init` if `create_offset=False`")
        self.create_scale = create_scale
        self.create_offset = create_offset
        self.eps = eps
        # `hk.get_parameter` calls initializers as `init(shape, dtype)`, so wrap the
        # scalar defaults accordingly. The diagonal starts at 1/sqrt(2) so that the
        # rescaled complex signal keeps unit overall variance.
        self.scale_rr_init = hk.initializers.Constant(scale_rr_init)
        self.scale_ri_init = hk.initializers.Constant(scale_ri_init)
        self.scale_ii_init = hk.initializers.Constant(scale_ii_init)
        self.offset_init = offset_init or jnp.zeros
        self.axis = axis
        self.channel_index = get_channel_index(data_format)
        self.mean_ema = hk.ExponentialMovingAverage(decay_rate, name="mean_ema")
        self.cov_ema = hk.ExponentialMovingAverage(decay_rate, name="cov_ema")
    def __call__(
        self,
        inputs: jnp.ndarray,
        is_training: bool,
        test_local_stats: bool = False,
        scale: Optional[jnp.ndarray] = None,
        offset: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """The implementation is based on an equivalent module written for PyTorch:
        https://github.com/ivannz/cplxmodule/blob/master/cplxmodule/nn/modules/batchnorm.py"""
        if self.create_scale and scale is not None:
            raise ValueError(
                "Cannot pass `scale` at call time if `create_scale=True`.")
        if self.create_offset and offset is not None:
            raise ValueError(
                "Cannot pass `offset` at call time if `create_offset=True`.")
        channel_index = self.channel_index
        channels_first = (channel_index == 1)
        if channel_index < 0:
            channel_index += inputs.ndim
        if self.axis is not None:
            axis = self.axis
        else:
            axis = [i for i in range(inputs.ndim) if i != channel_index]
        # Stack the real and imaginary components; the array now has shape (2, ...).
        inputs = jnp.array([inputs.real, inputs.imag])
        # Shift the channel index and the reduction axes to account for the new leading dim.
        channel_index += 1
        axis = [i + 1 for i in axis]
        if is_training or test_local_stats:
            # Compute the mean for each channel, separately for real/imag.
            mean = jnp.mean(inputs, axis, keepdims=True)  # shape = (2,1,C,1,..,1) or (2,1,..,1,C)
            # Center the inputs.
            centered_inputs = inputs - mean
            # Compute the variances (add a small epsilon to increase stability).
            variances = (centered_inputs * centered_inputs).mean(axis) + self.eps  # shape = (2,C)
            Var_Rez, Var_Imz = variances[0], variances[1]  # shape = (C,)
            # Compute the covariances (the matrix is symmetric).
            Cov_ReIm = Cov_ImRe = (centered_inputs[0] * centered_inputs[1]).mean([a-1 for a in axis])  # shape = (C,)
            # Construct the covariance matrix.
            covariance_matrix = jnp.array([[Var_Rez, Cov_ReIm], [Cov_ImRe, Var_Imz]]).reshape(2,2,-1)  # shape = (2,2,C)
        else:
            # Use the moving averages accumulated during training.
            mean = self.mean_ema.average
            covariance_matrix = self.cov_ema.average
            centered_inputs = inputs - mean
            Var_Rez, Cov_ReIm, Cov_ImRe, Var_Imz = covariance_matrix.reshape(4, -1)
        # Update the moving averages.
        if is_training:
            self.mean_ema(mean)
            self.cov_ema(covariance_matrix)
        # Construct the inverse square root of the covariance matrix V without explicit
        # inversion, following https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix:
        # with s = sqrt(det(V)) and t = sqrt(tr(V) + 2s), sqrt(V) = (V + s*I) / t,
        # whose inverse is adj(V + s*I) / (s*t).
        sqrt_det = jnp.sqrt(Var_Rez * Var_Imz - Cov_ReIm * Cov_ImRe)
        sqrt_tr = jnp.sqrt(Var_Rez + Var_Imz + 2*sqrt_det)
        denom = sqrt_det * sqrt_tr
        inverse_root_covmat = jnp.array([[Var_Imz + sqrt_det, -Cov_ReIm],
                                         [-Cov_ImRe, Var_Rez + sqrt_det]]).reshape(2,2,-1)  # shape = (2,2,C)
        inverse_root_covmat /= denom
        # Whiten the input data: contract the (2,2,C) matrix with the leading
        # (real, imag) axis of the centered inputs, channel by channel.
        if channels_first:
            einstein_formula = 'ijk,jlk...->ilk...'
        else:
            einstein_formula = 'ij...,j...->i...'
        normalized_input = jnp.einsum(einstein_formula, inverse_root_covmat, centered_inputs)
        w_shape = Var_Rez.shape
        w_dtype = Var_Rez.dtype
        b_shape = [1 if i in axis else inputs.shape[i] for i in range(inputs.ndim)]
        if self.create_scale:
            # Learn a full 2x2 rescaling matrix per channel (symmetric by construction).
            scale_rr = hk.get_parameter("scale_rr", w_shape, w_dtype, self.scale_rr_init)
            scale_ri = hk.get_parameter("scale_ri", w_shape, w_dtype, self.scale_ri_init)
            scale_ii = hk.get_parameter("scale_ii", w_shape, w_dtype, self.scale_ii_init)
            scale = jnp.array([[scale_rr, scale_ri], [scale_ri, scale_ii]], dtype=w_dtype).reshape(2,2,-1)
        elif scale is None:
            # Per-channel 2x2 identity matrices, so the einsum leaves the input unchanged.
            scale = jnp.broadcast_to(jnp.eye(2, dtype=w_dtype)[..., None], (2, 2, w_shape[0]))
        if self.create_offset:
            offset = hk.get_parameter("offset", b_shape, w_dtype, self.offset_init)
        elif offset is None:
            offset = np.zeros([], dtype=w_dtype)
        # Apply the affine transform.
        output = jnp.einsum(einstein_formula, scale, normalized_input) + offset
        # Reassemble the complex normalized output.
        return jax.lax.complex(output[0], output[1])
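As with the plain version, this module is stateful and is meant to be driven through `hk.transform_with_state`. A minimal sketch (the `forward` wrapper is illustrative, not part of the commit); with `create_scale=False` and `create_offset=False` the whitening effect can be inspected directly on the output:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, is_training):
    # Hypothetical wrapper around the whitened module defined above.
    bn = CmplxBatchNorm(create_scale=False, create_offset=False, decay_rate=0.9)
    return bn(x, is_training)

forward_t = hk.transform_with_state(forward)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
# Complex inputs with deliberately unbalanced real/imag variances.
x = jax.random.normal(k1, (256, 4)) + 1j * 3. * jax.random.normal(k2, (256, 4))
params, state = forward_t.init(k1, x, is_training=True)
out, state = forward_t.apply(params, state, k1, x, is_training=True)
# For every channel c, the real and imaginary parts are now decorrelated with
# (roughly) unit variance: jnp.cov(out[:, c].real, out[:, c].imag) is close to
# the 2x2 identity matrix.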