Commit e4fbd08c authored by MattiaPujatti's avatar MattiaPujatti

reformat layers

parent 57fdda4b
......@@ -53,7 +53,10 @@ As anticipated in the introductory section, the interest of researchers in this
\item[-] he relied on "bad" activation functions, since, as he himself admitted, the algorithm often failed to converge.
\end{itemize}
I decided to report his work because it was still one of the first working attempts to develop a complex backpropagation algorithm, but also because of the purely theoretical analysis of the transformations that a complex network can learn. Nitta managed to teach his networks several transformations in $\mathds{R}^2$, like rotations, reductions and parallel displacements, that the corresponding real-valued model could not learn. He was the first to understand that this was possible thanks to the higher degrees of freedom offered by complex multiplication (discussed in section \ref{subsec:cmplx_multiplication}). But what I believe is even more interesting is the relation that Nitta found between complex-valued networks and the \textbf{Identity theorem} \ref{th:identity}:\\
\textit{``We believe that Complex-BP networks satisfy the Identity Theorem, that is, Complex-BP networks can approximate complex functions just by training them only over a part of the domain of the complex functions."}\\
This means that exploiting holomorphic functions when building a complex-valued network can sometimes impact its generalization capabilities (since its shape will be rigidly determined by its characteristics on a small local region of its domain) \cite{hirose_cvnn}. Unfortunately, no additional work has been done on this statement over the years, but I think it is an aspect deserving further attention.\\
In section \ref{sec:cmplx_differentiability} we discussed complex differentiability, and we also noted that holomorphicity is not a property guaranteed for most functions: even simple ones, like the square modulus, may not be differentiable in the complex sense. In our architectures we have mainly two sources of \textit{nonholomorphicity}: the loss and the activations. For reasons that will become clearer later on, boundedness and analyticity cannot be achieved simultaneously in the complex domain, and the first feature is often preferred \cite{amin_wirtinger}.\\
An elegant approach that can save computational labor is the use of Wirtinger calculus to set up optimization problems, solvable via gradient descent, for functions that are not holomorphic but are at least differentiable with respect to their real and imaginary components.\\
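As a brief reminder of the machinery involved, for $z = x + iy$ the Wirtinger derivatives are
\[ \frac{\partial f}{\partial z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}\right), \qquad \frac{\partial f}{\partial \bar{z}} = \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\right), \]
and for a real-valued loss $L(z)$ the steepest-descent update follows the conjugate of the complex gradient,
\[ z \leftarrow z - \eta\, \overline{\left(\frac{\partial L}{\partial z}\right)} = z - \eta\, \frac{\partial L}{\partial \bar{z}}, \]
which is exactly the conjugation step applied to the gradients in the \texttt{update} function of the JAX implementation below.\\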
......@@ -180,13 +183,36 @@ In simple \texttt{Complex Normalization} we scales a complex scalar input $\vb{z
There are many layers that do not need any further re-definition to work in the complex domain: \texttt{Dropout}, padding or attention layers, for example. There are also many other structures that should be re-derived (e.g. recurrent layers, LSTMs, etc.), but they were out of our scope, so we have not examined them. This should be interpreted just as a starting point in the development of a higher-level complex-valued deep learning framework.
\section{Complex-Valued Activation Functions}
One of the main issues encountered in the last 30 years in the development of a complex-valued deep learning framework was precisely the definition of reliable activation functions. The extension from the real-valued domain is anything but easy: over the years, a large number of complex-valued non-linear functions have been proposed and tested, but the limitations imposed by Liouville's theorem \ref{th:Liouville}, together with the fact that many operations (like \textit{max}) are undefined on $\mathds{C}$, were a huge obstacle. Additionally, with complex-valued outputs, we lose the probabilistic interpretation that functions like \texttt{sigmoid} and \texttt{softmax} used to provide.\\
We have to say that most of the candidate functions that have been proposed were developed in a split fashion, i.e. by considering the real and imaginary parts of the activation separately. But, as discussed in the previous chapter, this approach should be abandoned, since it risks losing the complex correlations stored in those variables.\\
In this section, we will explore a few of the complex-valued activations proposed over the years: first the ones that are direct extensions of their real counterparts, and then more "abstract" candidates, which have stronger reasons to live and work in the complex domain.\\
\subsection*{Complex Sigmoid}
The most straightforward complex-valued non-linear function we can think of is the \textbf{complex sigmoid}, which is nothing but the real-valued sigmoid extended to $\mathds{C}$:
\[ \sigma_\mathds{C}(z) = \frac{1}{1+e^{-z}} \]
Problem: the sigmoid function diverges periodically along the imaginary axis of the complex plane, since it has poles wherever $e^{-z} = -1$, i.e. at $z = i(2k+1)\pi$. The instability can be mitigated by limiting the domain of the input values; however, this does not seem a good approach to begin with.\\
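A minimal JAX sketch (illustrative only, not part of the layer library presented later) that evaluates the complex sigmoid and shows the blow-up close to the pole at $z = i\pi$:
\begin{verbatim}
import jax.numpy as jnp

def cmplx_sigmoid(z):
    # Naive extension of the logistic sigmoid to complex inputs.
    return 1.0 / (1.0 + jnp.exp(-z))

print(cmplx_sigmoid(jnp.array(2.0 + 0.0j)))           # ~0.88, well behaved on the real axis
print(cmplx_sigmoid(jnp.array(0.001 + 1j * jnp.pi)))  # magnitude ~1e3, close to the pole at i*pi
\end{verbatim}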
\subsection*{Separable Activations}
As already explained, the main tendency in the development of complex-valued activation functions has been to take back the "old" designs of real-valued models and apply them independently to the real and imaginary components of the input, as sketched below.
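A minimal sketch of this split construction (the choice of the logistic sigmoid as the underlying real activation is only illustrative):
\begin{verbatim}
import jax
import jax.numpy as jnp

def split_sigmoid(z):
    # Apply the real-valued sigmoid separately to the real and imaginary parts.
    return jax.nn.sigmoid(z.real) + 1j * jax.nn.sigmoid(z.imag)
\end{verbatim}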
\subsection*{Phase-preserving Activations}
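Phase-preserving activations act only on the magnitude of the input, $f(z) = g(|z|)\, e^{i\angle z}$, leaving its phase untouched. A minimal sketch of this family (the choice of $g$ is an illustrative assumption):
\begin{verbatim}
import jax.numpy as jnp

def phase_preserving(z, g=jnp.tanh):
    # Transform the magnitude with a real non-linearity g and keep the phase.
    return g(jnp.abs(z)) * jnp.exp(1j * jnp.angle(z))
\end{verbatim}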
\subsection*{Complex Cardioid}
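The complex cardioid, as usually defined in the literature, attenuates the input according to its phase, $f(z) = \frac{1}{2}\left(1 + \cos\angle z\right) z$; a minimal sketch, assuming this standard form is the variant intended here:
\begin{verbatim}
import jax.numpy as jnp

def complex_cardioid(z):
    # Inputs aligned with the positive real axis pass through unchanged,
    # those aligned with the negative real axis are suppressed.
    return 0.5 * (1.0 + jnp.cos(jnp.angle(z))) * z
\end{verbatim}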
\begin{table}[!ht]
\centering
\begin{tabular}{c c c}
\toprule
\textbf{Activation} & \textbf{Analytic Form} & \textbf{Reference}\\
\midrule
Sigmoid & $\sigma_\mathds{C}(z) = \frac{1}{1+e^{-z}}$ & \\
Separable Sigmoid & $\sigma(\Re z) + i\,\sigma(\Im z)$ & \\
Siglog & & \\
\bottomrule
\end{tabular}
\end{table}
\section{JAX Implementation}
......
import jax
import jax.numpy as jnp
import numpy as np
import haiku as hk
from functools import partial
from jax import jit, value_and_grad
from tqdm.notebook import tqdm
from typing import Any, Callable, Optional, Tuple
def initialize_cmplx_haiku_model(model, init_shape, rng_seed=42, **model_kwargs):
......@@ -18,3 +19,91 @@ def initialize_cmplx_haiku_model(model, init_shape, rng_seed=42, **model_kwargs)
net_params, net_state = network.init( key, dummy_input, is_training=True )
return network, net_params, net_state
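# Hypothetical usage sketch (the model class name and the input shape are illustrative):
#
#   network, net_params, net_state = initialize_cmplx_haiku_model(
#       MyComplexCNN, init_shape=(1, 28, 28, 1), rng_seed=0)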
def evaluate_model(network,
                   net_params,
                   net_state,
                   metric,
                   dataloader,
                   rng_key=None,
                   ) -> Any:
    """Compute the `metric` estimate over the input dataset, provided as a Pytorch dataloader.

    Args
    ----
    dataloader: pytorch DataLoader
        Dataloader containing all the test samples.
    rng_key: jax.random.PRNGKey, optional (default is None)
        PRNG key forwarded to `metric`; usually not needed in evaluation mode.
    """
    log_metric = []
    for batch in tqdm(dataloader, desc='Looping over the dataset.', unit='batches', leave=False):
        x_batch, y_batch = batch[0].numpy(), batch[1].numpy()
        batch_metric, _ = metric(network, net_params, x_batch, y_batch, rng_key, net_state, is_training=False)
        log_metric.append(batch_metric)
    return np.mean(log_metric)
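# A hypothetical `metric` callable matching the call above (the name and the accuracy
# computation are illustrative assumptions, not part of the original code); it assumes
# `network` was built with hk.transform_with_state and that labels are one-hot encoded.
def accuracy_metric(network, net_params, x, y, rng_key, net_state, is_training=False):
    preds, net_state = network.apply(net_params, net_state, rng_key, x, is_training=is_training)
    acc = jnp.mean(jnp.argmax(preds, axis=-1) == jnp.argmax(y, axis=-1))
    return acc, net_state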
# The callable arguments (`metric`, `opt_update`, `get_params`) and the boolean
# `is_training` flag cannot be traced by `jit`, so they are marked as static.
@partial(jit, static_argnums=(3, 5, 6, 9))
def update(
params: hk.Params,
x: jnp.array,
y: jnp.array,
metric: Callable,
opt_state: Any,
opt_update: Callable,
get_params: Callable,
rng_key: Optional[jax.random.PRNGKey] = None,
net_state: Optional[hk.State] = None,
is_training: bool = False,
step: int = 0,
) -> Tuple[hk.Params, Any, float, hk.State]:
"""Given a minibatch of samples, it compute the loss and the parameters updates of the network.
Then, since jax.grad calculate the complex gradient df/dz and not the conjugate (as needed by
the complex backpropagation), an additional step performs this operation, before applying the
updates just computed.
Args
----
params: hk.Module.parameters()
Parameters of the network.
opt_state: jax pytree
Object representing the actual optimizer state.
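metric: Callable
Loss function to differentiate; it must return a (loss, net_state) tuple and is assumed to close over the transformed network.
opt_update: Callable
Optimizer update function, with the (step, grads, opt_state) signature used below.
get_params: Callable
Function extracting the network parameters from the optimizer state.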
x: array
Array of samples to give in input to the network.
y: array
Array of one-hot encoded labels.
rng_key: jax.random.PRNGKey, optional (default is None)
PRNGKey necessary to the 'apply' method of the transformed network.
net_state: hk.State, optional (default is None)
Internal state of the network. Set 'None' if the network has no internal trainable state.
is_training: bool, optional (default is False)
Flags that alert the network if it is called in training or evaluation mode. Useful in presence
of dropout or batchnormalization layers.
step: int, optional (default is 0)
Index of the update step.
Return
------
new_params: hk.Module.parameters
New estimates of network's parameters.
opt_state: jax pytree
Optimizer state after the update.
loss: float
Loss estimate for the given minibatch.
net_state:
Internal state of the network.
"""
# `metric` is assumed to be a loss closure over the transformed network, returning (loss, net_state).
(loss, net_state), grads = value_and_grad(metric, has_aux=True)(params, x, y, rng_key, net_state, is_training)
# jax.grad/value_and_grad return dL/dz; complex backpropagation needs the conjugate dL/dz*.
grads = jax.tree_multimap(jnp.conjugate, grads)
opt_state = opt_update(step, grads, opt_state)
return get_params(opt_state), opt_state, loss, net_state
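# Hypothetical training-step sketch using `update` (the optimizer construction and the
# loss closure `loss_fn` are illustrative assumptions, following the
# jax.experimental.optimizers API implied by the calls above):
#
#   from jax.experimental import optimizers
#   opt_init, opt_update, get_params = optimizers.adam(1e-3)
#   opt_state = opt_init(net_params)
#   for step, (x, y) in enumerate(train_loader):
#       net_params, opt_state, loss, net_state = update(
#           net_params, x.numpy(), y.numpy(), loss_fn, opt_state,
#           opt_update, get_params, rng_key, net_state, True, step)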
......@@ -12,6 +12,11 @@ import warnings
from complex_nn.initializers import CmplxRndUniform, CmplxTruncatedNormal
##################### LINEAR LAYERS ##################################################
class Cmplx_Linear(hk.Module):
"""Linear module."""
......@@ -72,6 +77,9 @@ class Cmplx_Linear(hk.Module):
############# DROPOUT ####################################################################
class Dropout(hk.Module):
"""Basic implementation of a Dropout layer."""
......@@ -103,31 +111,9 @@ class Dropout(hk.Module):
def to_dimension_numbers(
num_spatial_dims: int,
channels_last: bool,
transpose: bool,
) -> lax.ConvDimensionNumbers:
"""Create a `lax.ConvDimensionNumbers` for the given inputs."""
num_dims = num_spatial_dims + 2
if channels_last:
spatial_dims = tuple(range(1, num_dims - 1))
image_dn = (0, num_dims - 1) + spatial_dims
else:
spatial_dims = tuple(range(2, num_dims))
image_dn = (0, 1) + spatial_dims
if transpose:
kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
else:
kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
return lax.ConvDimensionNumbers(lhs_spec=image_dn, rhs_spec=kernel_dn,
out_spec=image_dn)
##################### CONVOLUTIONAL LAYERS ###########################################
class Cmplx_ConvND(hk.Module):
......@@ -381,6 +367,29 @@ def _infer_shape(
return tuple(size)
def to_dimension_numbers(
num_spatial_dims: int,
channels_last: bool,
transpose: bool,
) -> lax.ConvDimensionNumbers:
"""Create a `lax.ConvDimensionNumbers` for the given inputs."""
num_dims = num_spatial_dims + 2
if channels_last:
spatial_dims = tuple(range(1, num_dims - 1))
image_dn = (0, num_dims - 1) + spatial_dims
else:
spatial_dims = tuple(range(2, num_dims))
image_dn = (0, 1) + spatial_dims
if transpose:
kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
else:
kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
return lax.ConvDimensionNumbers(lhs_spec=image_dn, rhs_spec=kernel_dn,
out_spec=image_dn)
......@@ -438,9 +447,7 @@ def cmplx_max_pool(
################ POOLING LAYERS #######################################################
class MaxMagnitude_Pooling(hk.Module):
......@@ -480,4 +487,32 @@ class MaxMagnitude_Pooling(hk.Module):
################## NORMALIZATION LAYERS ####################################################
class Cmplx_Normalization(hk.Module):
"""Basic implementation of a Complex normalization layer, that normalize all the input data
to a unitary magnitude, leaving the phase untouched."""
def __init__(
self,
name: Optional[str] = None,
):
"""Constructs the Cmplx_Normalization module.
Args:
name: Name of the module.
"""
super().__init__(name=name)
def __call__(
self,
x: jnp.ndarray,
) -> jnp.ndarray:
"""Implementation of a complex normalization."""
norm = jnp.absolute( x.flatten() )
return x / norm
import haiku as hk
from haiku import RNNCore, LSTMState
from haiku._src import utils
import jax
import jax.numpy as jnp
from jax import lax
import numpy as np
from typing import Optional, Tuple, Union, Sequence
import warnings
from complex_nn.initializers import CmplxRndUniform, CmplxTruncatedNormal
class Cmplx_Linear(hk.Module):
"""Linear module."""
def __init__(
self,
output_size: int,
with_bias: bool = True,
w_init: Optional[hk.initializers.Initializer] = None,
b_init: Optional[hk.initializers.Initializer] = None,
name: Optional[str] = None,
):
"""Constructs the Linear module.
Args:
output_size: Output dimensionality.
with_bias: Whether to add a bias to the output.
w_init: Optional initializer for weights. By default, uses random values
from truncated normal, with stddev ``1 / sqrt(fan_in)``. See
https://arxiv.org/abs/1502.03167v3.
b_init: Optional initializer for bias. By default, uniform in [-0.001, 0.001].
name: Name of the module.
"""
super().__init__(name=name)
self.input_size = None
self.output_size = output_size
self.with_bias = with_bias
self.w_init = w_init
self.b_init = b_init or CmplxRndUniform(minval=-0.001, maxval=0.001)
def __call__(
self,
inputs: jnp.ndarray,
*,
precision: Optional[jax.lax.Precision] = None,
) -> jnp.ndarray:
"""Computes a linear transform of the input."""
if not inputs.shape:
raise ValueError("Input must not be scalar.")
input_size = self.input_size = inputs.shape[-1]
output_size = self.output_size
dtype = inputs.dtype
w_init = self.w_init
if w_init is None:
stddev = 1. / np.sqrt(self.input_size)
w_init = CmplxTruncatedNormal(mean=0., stddev=stddev)
w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
out = jnp.dot(inputs, w, precision=precision)
if self.with_bias:
b = hk.get_parameter("b", [self.output_size], dtype, init=self.b_init)
b = jnp.broadcast_to(b, out.shape)
out = out + b
return out
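# Minimal usage sketch of the complex linear layer (illustrative only):
#
#   def forward(x):
#       return Cmplx_Linear(4)(x)
#   net = hk.transform(forward)
#   x = jnp.ones((2, 3), dtype=jnp.complex64)
#   params = net.init(jax.random.PRNGKey(0), x)
#   y = net.apply(params, None, x)   # complex-valued output of shape (2, 4)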
class Dropout(hk.Module):
"""Basic implementation of a Dropout layer."""
def __init__(
self,
rate: float,
name: Optional[str] = None,
):
"""Constructs the Dropout module.
Args:
rate: Probability that each element of x is discarded. Must be a scalar in the range [0, 1).
name: Name of the module.
"""
super().__init__(name=name)
self.rate = rate
def __call__(
self,
x: jnp.ndarray,
is_training: Optional[bool] = True,
) -> jnp.ndarray:
"""Wrapper layer of the function hk.Dropout."""
if is_training:
return hk.dropout(hk.next_rng_key(), self.rate, x)
else:
return x
def to_dimension_numbers(
num_spatial_dims: int,
channels_last: bool,
transpose: bool,
) -> lax.ConvDimensionNumbers:
"""Create a `lax.ConvDimensionNumbers` for the given inputs."""
num_dims = num_spatial_dims + 2
if channels_last:
spatial_dims = tuple(range(1, num_dims - 1))
image_dn = (0, num_dims - 1) + spatial_dims
else:
spatial_dims = tuple(range(2, num_dims))
image_dn = (0, 1) + spatial_dims
if transpose:
kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
else:
kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
return lax.ConvDimensionNumbers(lhs_spec=image_dn, rhs_spec=kernel_dn,
out_spec=image_dn)
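# For example, num_spatial_dims=2 with channels_last=True and transpose=False yields
# lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2), i.e. NHWC images
# and HWIO kernels, the layout expected by lax.conv_general_dilated.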
class Cmplx_ConvND(hk.Module):
"""General N-dimensional complex convolutional."""
def __init__(
self,
num_spatial_dims: int,
output_channels: int,
kernel_shape: Union[int, Sequence[int]],
stride: Union[int, Sequence[int]] = 1,
rate: Union[int, Sequence[int]] = 1,
padding: Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn,
Sequence[hk.pad.PadFn]] = "SAME",
with_bias: bool = True,
w_init: Optional[hk.initializers.Initializer] = None,
b_init: Optional[hk.initializers.Initializer] = None,
data_format: str = "channels_last",
mask: Optional[jnp.ndarray] = None,
feature_group_count: int = 1,
name: Optional[str] = None,
):
"""Initializes the module.
Args:
num_spatial_dims: The number of spatial dimensions of the input.
output_channels: Number of output channels.
kernel_shape: The shape of the kernel. Either an integer or a sequence of
length ``num_spatial_dims``.
stride: Optional stride for the kernel. Either an integer or a sequence of
length ``num_spatial_dims``. Defaults to 1.
rate: Optional kernel dilation rate. Either an integer or a sequence of
length ``num_spatial_dims``. 1 corresponds to standard ND convolution,
``rate > 1`` corresponds to dilated convolution. Defaults to 1.
padding: Optional padding algorithm. Either ``VALID`` or ``SAME`` or a
sequence of n ``(low, high)`` integer pairs that give the padding to
apply before and after each spatial dimension. or a callable or sequence
of callables of size ``num_spatial_dims``. Any callables must take a
single integer argument equal to the effective kernel size and return a
sequence of two integers representing the padding before and after. See
``haiku.pad.*`` for more details and example functions. Defaults to
``SAME``. See:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias: Whether to add a bias. By default, true.
w_init: Optional weight initialization. By default, truncated normal.
b_init: Optional bias initialization. By default, zeros.
data_format: The data format of the input. Can be either
``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
default, ``channels_last``.
mask: Optional mask of the weights.
feature_group_count: Optional number of groups in group convolution.
Default value of 1 corresponds to normal dense convolution. If a higher
value is used, convolutions are applied separately to that many groups,
then stacked together. This reduces the number of parameters
and possibly the compute for a given ``output_channels``. See:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
name: The name of the module.
"""
super().__init__(name=name)
if num_spatial_dims <= 0:
raise ValueError(
"We only support convolution operations for `num_spatial_dims` "
f"greater than 0, received num_spatial_dims={num_spatial_dims}.")
self.num_spatial_dims = num_spatial_dims
self.output_channels = output_channels
self.kernel_shape = (utils.replicate(kernel_shape, num_spatial_dims, "kernel_shape"))
self.with_bias = with_bias
self.stride = utils.replicate(stride, num_spatial_dims, "strides")
self.w_init = w_init
self.b_init = b_init or CmplxRndUniform(minval=-0.001, maxval=0.001)
self.mask = mask
self.feature_group_count = feature_group_count
self.lhs_dilation = utils.replicate(1, num_spatial_dims, "lhs_dilation")
self.kernel_dilation = (utils.replicate(rate, num_spatial_dims, "kernel_dilation"))
self.data_format = data_format
self.channel_index = utils.get_channel_index(data_format)
self.dimension_numbers = to_dimension_numbers(
num_spatial_dims, channels_last=(self.channel_index == -1),
transpose=False)
if isinstance(padding, str):
self.padding = padding.upper()
elif hk.pad.is_padfn(padding):
self.padding = hk.pad.create_from_padfn(padding=padding,
kernel=self.kernel_shape,
rate=self.kernel_dilation,
n=self.num_spatial_dims)
else:
self.padding = hk.pad.create_from_tuple(padding, self.num_spatial_dims)
def __call__(
self,
inputs: jnp.ndarray,
*,
precision: Optional[lax.Precision] = None,
) -> jnp.ndarray: