Commit 86049b46 authored by MattiaPujatti's avatar MattiaPujatti

Add complex libraries

parent 6571887c
# -*- coding: utf-8 -*-
"""
HAIKU CLASSIFIER
-----------------------------------------
Example:
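A minimal usage sketch (assumptions: `net_fn` is a hypothetical forward function,
`train_loader`/`test_loader` are PyTorch DataLoaders yielding (inputs, one-hot targets)
minibatches; none of these names are defined in this module):

    import haiku as hk
    from jax.experimental import optimizers

    def net_fn(x, is_training=False):
        return hk.nets.MLP([128, 10])(x)

    clf = Haiku_Classifier(forward_fn=net_fn, optimizer=optimizers.adam(1e-3))
    history = clf.train(n_epochs=10,
                        train_dataloader=train_loader,
                        test_dataloader=test_loader)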
----------------------------------------
Last modified: 01/09/2021
Author:
Mattia Pujatti,
Physics of Data student,
Internship at FBK-DSIP, Trento.
mpujatti@fbk.eu
----------------------------------------
"""
import jax
import haiku as hk
from jax import random, grad, jit, value_and_grad
import jax.numpy as jnp
from functools import partial
import time
from tqdm.notebook import tqdm
import numpy as np
from typing import Optional, Any, Tuple, Callable
class Haiku_Classifier:
"""This class is formulated in order to simplify all the sequence of operations that one has to write
in order to train a model written in Haiku. Initializing an instance of this class, fixed the
necessary hyperparameters, and simply calling the 'train' method on a dataset:
The idea is giving in input to the training function a 'forward' function, in which one invokes the
'__call__' method, defined in an object of type 'hk.Module', with the necessary parameters.
Attributes
----------
forward_fn: function
Function calling the network class (see the class description above).
opt_init, opt_update, get_params: functions
Triple of functions constituting an optimizer from the jax.experimental.optimizers module
(refer to official documentation at https://jax.readthedocs.io/en/latest/jax.experimental.optimizers.html?highlight=optimizers)
trained_parameters: dictionary
Container of the trained network parameters (empty before calling the train function)
rng_seq: hk.PRNGSequence
Sequence of random jax keys.
With next(self.rng_seq) you have access to a new random key.
with_state: bool
Flag that enables the presence of an internal trainable state in the network (like BatchNorm).
When True, the model is transformed with 'hk.transform_with_state' and a 'network state'
variable is maintained across the training.
network: transformed model
Internal variable containing the model to train after the transformation of 'forward' into a pure
function, according to haiku's abstraction technique (please refer to official documentation at
https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html).
Methods
-------
__init__(forward_fn, optimizer, rng_seed, with_state)
Initialize the attributes of the class.
__categorical_accuracy(params, inputs, targets, rng_key, net_state)
Compute the accuracy of the network's predictions for input data.
__crossentropy_loss(params, inputs, targets, rng_key, net_state, is_training)
Compute the categorical crossentropy loss of the network's predictions over the input data.
__update(step, params, opt_state, x, y, rng_key, net_state, is_training)
Perform an update step of the parameters of the network.
train(n_epochs, train_dataloader, test_dataloader, verbose)
Set up and run the training process, looping over both the training and validation sets to
call the update function (when required) and to collect the values of accuracy and loss
for each epoch.
evaluate_dataset(dataloader)
Once the network has been correctly trained, one can call this function and compute the values
of accuracy and loss for the given dataset.
"""
def __init__(self,
forward_fn: Callable[[hk.Module, Optional[bool]], Any],
optimizer: jax.experimental.optimizers,
rng_seed: Optional[int] = 42,
with_state: Optional[bool] = False
):
"""Initialize the attributes of the class.
Args
----
forward_fn: function
Forward function of the haiku.Module object defining the network.
optimizer: jax.experimental.optimizers
One of the optimizers provided by the jax.experimental.optimizers module, i.e. a triple of
(init, update, get_params) functions.
rng_seed: int, optional (default is 42)
Initial seed to construct the PRNGSequence.
with_state: bool, optional (default is False)
Flag that enables/disables the tracking of the network state.
"""
self.forward_fn = forward_fn
self.opt_init, self.opt_update, self.get_params = optimizer
self.trained_parameters = None
self.rng_seq = hk.PRNGSequence(random.PRNGKey(rng_seed))
self.with_state = with_state
@partial(jit, static_argnums=(0,))
def __categorical_accuracy(self,
params: hk.Params,
inputs,
targets,
rng_key: Optional[jax.random.PRNGKey] = None,
net_state: Optional[hk.State] = None
) -> Any:
"""Compute the fraction of correctly classified samples by the network, for a given input batch.
Args
----
params: hk.Params
Parameters of the network.
inputs: array
Array of samples to give in input to the network.
targets: array
Array of one-hot encoded labels.
rng_key: jax.random.PRNGKey, optional (default is None)
PRNGKey necessary to the 'apply' method of the transformed network.
net_state: hk.State, optional (default is None)
Internal state of the network. Set 'None' if the network has no internal trainable state.
Return
------
categorical_accuracy: float
Fraction of correctly classified samples by the network.
"""
target_class = jnp.argmax(targets, axis=-1)
if net_state is None:
predictions = self.network.apply(params, rng_key, inputs, is_training=False)
else:
predictions, net_state = self.network.apply(params, net_state, rng_key, inputs, is_training=False)
# Traditional accuracy is not defined for complex output
predictions = jnp.absolute(predictions)
predicted_class = jnp.argmax(predictions, axis=-1)
return jnp.mean(predicted_class == target_class)
@partial(jit, static_argnums=(0,6,))
def __crossentropy_loss(self,
params: hk.Params,
inputs,
targets,
rng_key: Optional[jax.random.PRNGKey] = None,
net_state: Optional[hk.State] = None,
is_training: bool = False
) -> Tuple[Any, hk.State]:
"""Compute the categorical crossentropy loss between the samples given in input and the
corresponding network's predictions.
Args
----
params: hk.Params
Parameters of the network.
inputs: array
Array of samples to give in input to the network.
targets: array
Array of one-hot encoded labels.
rng_key: jax.random.PRNGKey, optional (default is None)
PRNGKey necessary to the 'apply' method of the transformed network.
net_state: hk.State, optional (default is None)
Internal state of the network. Set 'None' if the network has no internal trainable state.
is_training: bool, optional (default is False)
Flag that tells the network whether it is being called in training or evaluation mode.
Useful in the presence of dropout or batch normalization layers.
Return
------
softmax_xent: float
Estimate of the crossentropy loss for the input batch.
net_state: hk.State
Updated internal state of the network.
"""
if net_state is None:
logits = self.network.apply(params, rng_key, inputs, is_training)
else:
logits, net_state = self.network.apply(params, net_state, rng_key, inputs, is_training)
# Traditional cross-entropy is not defined for complex output
logits = jnp.absolute(logits)
# Add weigth regularization
#l1_loss = sum(jnp.sum(jnp.abs(p)) for p in jax.tree_leaves(params))
#l2_loss = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree_leaves(params))).real
softmax_xent = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) / len(targets)
#total_loss = softmax_xent + 1e-4*l2_loss
return softmax_xent, net_state
@partial(jit, static_argnums=(0,8,))
def __update(self,
step: int,
params: hk.Params,
opt_state: Any,
x,
y,
rng_key: Optional[jax.random.PRNGKey] = None,
net_state: Optional[hk.State] = None,
is_training: bool = False
) -> Tuple[hk.Params, Any, float, hk.State]:
"""Given a minibatch of samples, it compute the loss and the parameters updates of the network.
Then, since jax.grad calculate the complex gradient df/dz and not the conjugate (as needed by
the complex backpropagation), an additional step performs this operation, before applying the
updates just computed.
Args
----
step: int
Index of the update step.
params: hk.Params
Parameters of the network.
opt_state: jax pytree
Object representing the current optimizer state.
x: array
Array of samples to give in input to the network.
y: array
Array of one-hot encoded labels.
rng_key: jax.random.PRNGKey, optional (default is None)
PRNGKey necessary to the 'apply' method of the transformed network.
net_state: hk.State, optional (default is None)
Internal state of the network. Set 'None' if the network has no internal trainable state.
is_training: bool, optional (default is False)
Flag that tells the network whether it is being called in training or evaluation mode.
Useful in the presence of dropout or batch normalization layers.
Return
------
new_params: hk.Params
New estimates of the network's parameters.
opt_state: jax pytree
Optimizer state after the update.
loss: float
Loss estimate for the given minibatch.
net_state: hk.State
Updated internal state of the network.
"""
(loss, net_state), grads = value_and_grad(self.__crossentropy_loss, has_aux=True)(params, x, y, rng_key, net_state, is_training)
grads = jax.tree_multimap(jnp.conjugate, grads)
#print(jax.tree_multimap(jnp.mean, grads))
opt_state = self.opt_update(step, grads, opt_state)
return self.get_params(opt_state), opt_state, loss, net_state
def train(self,
n_epochs: int,
train_dataloader,
test_dataloader,
verbose: bool = False
) -> Any:
"""Setup and run the training process looping over both the train and validation sets in order
to call the update function (when required), and to collect the values of accuracy and loss
for each epoch.
Args
----
n_epochs: int
Number of epochs of the training loop.
train_dataloader: pytorch DataLoader
Dataloader containing all the training samples.
test_dataloader: pytorch DataLoader
Dataloader containing all the validation samples.
verbose: bool, optional (default is False)
Verbosity of the output.
Return
------
training_history: dict
Dictionary containing the train/validation losses and accuracies for each epoch.
"""
# Initialize the parameters of the network
init_batch = next(iter(train_dataloader))
if self.with_state:
self.network = hk.transform_with_state(self.forward_fn)
net_params, net_state = self.network.init( next(self.rng_seq), init_batch[0].numpy(), is_training=True )
else:
self.network = hk.transform(self.forward_fn)
net_state = None
net_params = self.network.init( next(self.rng_seq), init_batch[0].numpy(), is_training=True )
# Initialize the optimizer state
opt_state = self.opt_init(net_params)
training_history = {'train_loss': [],
'val_loss': [],
'train_acc': [],
'val_acc': [],
'epoch_time': []}
step = 0
for epoch in tqdm(range(n_epochs), desc='Training for several epochs', leave=False):
start_time = time.time()
log_train_loss, log_train_acc = [], []
for batch in tqdm(train_dataloader, desc='Looping over the minibatches', leave=False):
x_batch, y_batch = batch[0].numpy(), batch[1].numpy()
net_params, opt_state, batch_loss, net_state = self.__update(
step, net_params, opt_state, x_batch, y_batch, next(self.rng_seq), net_state, is_training=True )
batch_accuracy = self.__categorical_accuracy( net_params, x_batch, y_batch, next(self.rng_seq), net_state )
log_train_loss.append(batch_loss)
log_train_acc.append(batch_accuracy)
step += 1
training_history['epoch_time'].append(time.time() - start_time)
training_history['train_loss'].append(np.mean(log_train_loss))
training_history['train_acc'].append(np.mean(log_train_acc))
log_val_loss, log_val_acc = [], []
for batch in tqdm(test_dataloader, desc='Computing the validation loss', leave=False):
x_batch, y_batch = batch[0].numpy(), batch[1].numpy()
batch_loss, _ = self.__crossentropy_loss( net_params, x_batch, y_batch, next(self.rng_seq), net_state, is_training=False )
batch_accuracy = self.__categorical_accuracy( net_params, x_batch, y_batch, next(self.rng_seq), net_state )
log_val_loss.append(batch_loss)
log_val_acc.append(batch_accuracy)
training_history['val_loss'].append(np.mean(log_val_loss))
training_history['val_acc'].append(np.mean(log_val_acc))
if verbose:
print("Epoch {} in {:0.2f} sec".format( epoch+1, training_history['epoch_time'][-1]) )
print("Training set loss {}".format( training_history['train_loss'][-1]) )
print("Test set loss {}".format( training_history['val_loss'][-1]) )
print("Training set accuracy {}".format( training_history['train_acc'][-1]) )
print("Test set accuracy {}".format( training_history['val_acc'][-1]) )
self.trained_parameters = net_params
self.trained_net_state = net_state
return training_history
def evaluate_dataset(self,
dataloader
):
"""Once the network has been correctly trained, one can call this function and compute the values
of accuracy and loss for the given dataset.
Args
----
dataloader: pytorch DataLoader
Dataloader containing all the test samples.
"""
if self.trained_parameters is None:
raise ValueError('Please train the network before evaluating the performance.')
net_params = self.trained_parameters
log_val_loss, log_val_acc = [], []
for batch in tqdm(dataloader, desc='Computing the accuracy / loss over the dataset.', leave=False):
x_batch, y_batch = batch[0].numpy(), batch[1].numpy()
batch_loss, _ = self.__crossentropy_loss( net_params, x_batch, y_batch, next(self.rng_seq), self.trained_net_state, is_training=False )
batch_accuracy = self.__categorical_accuracy( net_params, x_batch, y_batch, next(self.rng_seq), self.trained_net_state )
log_val_loss.append(batch_loss)
log_val_acc.append(batch_accuracy)
print('Final loss of the test set: {:.3f}'.format(np.mean(log_val_loss)))
print('Final accuracy of the test set: {:.2f}%'.format(np.mean(log_val_acc)*100))
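# A minimal sketch (illustration only, not part of the class above) of the gradient
# conjugation performed in __update: for the real-valued loss f(z) = |z|^2 the
# steepest-ascent direction in the complex plane is 2*z, while jax.grad returns the
# derivative df/dz mentioned in the __update docstring. Printing both the raw gradient
# and its conjugate makes it easy to check which of the two points along 2*z.
def _check_complex_grad_convention(z0=3.0 + 4.0j):
    f = lambda z: jnp.abs(z) ** 2                      # real-valued loss of a complex input
    g = grad(f)(jnp.asarray(z0, dtype=jnp.complex64))  # gradient as returned by jax.grad
    print('raw gradient:        ', g)
    print('conjugated gradient: ', jnp.conjugate(g))   # direction actually used in __update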
""" Definition of some complex activation functions according to the literature. """
import jax.numpy as jnp
import numpy as np
import torch
from jax import lax
from jax import custom_jvp
def cmplx_sigmoid(z):
return 1. / ( 1. + jnp.exp(-z) )
def separable_sigmoid(z):
return cmplx_sigmoid(jnp.real(z)) + 1.j*cmplx_sigmoid(jnp.imag(z))
def siglog(z):
c = 1. # steepness
r = 1. # scale
return z / (c + 1./r*jnp.abs(z))
def igaussian(z):
sigma = 1.
g = 1 - jnp.exp( -z*jnp.conj(z) / (2*sigma**2) )
n = z / jnp.sqrt( z*jnp.conj(z) )
return g * n
#@custom_jvp
def cardioid(z):
return 0.5 * (1 + jnp.cos(jnp.angle(z))) * z
#cardioid.defjvps(lambda g, ans, x: 0.5 + 0.5*jnp.cos(jnp.angle(x)) + 0.25j*jnp.sin(jnp.angle(x)))
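# A quick element-wise check (illustration only, on an arbitrary array): all the
# activations above act element-wise on complex arrays, so they can be applied
# directly to the complex pre-activations produced by a layer.
if __name__ == '__main__':
    z = jnp.array([1.0 + 1.0j, -0.5 + 2.0j, 0.3 - 0.7j])
    print(separable_sigmoid(z))  # sigmoid applied separately to real and imaginary parts
    print(cardioid(z))           # scales each input by a factor depending on its phase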
""" Reframe haiku initializers to allow complex initialization of the network's weights. """
import haiku as hk
import jax
import jax.numpy as jnp
from typing import Any, Sequence
class CmplxRndUniform(hk.initializers.Initializer):
"""Initializes by sampling from a uniform distribution."""
def __init__(self, minval=0, maxval=1.):
"""Constructs a :class:`RandomUniform` initializer.
Args:
minval: The lower limit of the uniform distribution.
maxval: The upper limit of the uniform distribution.
"""
self.minval = minval
self.maxval = maxval
def __call__(self, shape: Sequence[int], dtype: Any) -> jnp.ndarray:
real_part = jax.random.uniform(hk.next_rng_key(), shape, dtype='float64', minval=self.minval, maxval=self.maxval)
imag_part = jax.random.uniform(hk.next_rng_key(), shape, dtype='float64', minval=self.minval, maxval=self.maxval)
return jax.lax.complex(real_part, imag_part)
class CmplxTruncatedNormal(hk.initializers.Initializer):
"""Initializes by sampling from a truncated normal distribution."""
def __init__(self, mean=0., stddev=1., lower=-2., upper=2.):
"""Constructs a :class:`TruncatedNormal` initializer.
Args:
stddev: The standard deviation parameter of the truncated normal distribution.
mean: The mean of the truncated normal distribution.
lower: The lower bound for truncation.
upper: The upper bound for truncation.
"""
self.mean = mean
self.stddev = stddev
self.lower = lower
self.upper = upper
def __call__(self, shape: Sequence[int], dtype: Any) -> jnp.ndarray:
m = jax.lax.convert_element_type(self.mean, new_dtype='float64')
s = jax.lax.convert_element_type(self.stddev, new_dtype='float64')
unscaled_r = jax.random.truncated_normal(hk.next_rng_key(), self.lower, self.upper, shape, dtype='float64')
unscaled_c = jax.random.truncated_normal(hk.next_rng_key(), self.lower, self.upper, shape, dtype='float64')
return jax.lax.complex(s * unscaled_r + m, s * unscaled_c + m)
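# A minimal usage sketch (illustration only): the initializers above can be passed to
# standard Haiku layers through their `w_init`/`b_init` arguments to obtain
# complex-valued weights. Note that the 'float64' sampling above only takes effect if
# 64-bit mode is enabled, e.g. jax.config.update('jax_enable_x64', True); otherwise
# JAX downcasts the sampled values to 32 bits.
def _complex_linear(x):
    layer = hk.Linear(16,
                      w_init=CmplxRndUniform(minval=-0.1, maxval=0.1),
                      b_init=CmplxTruncatedNormal(stddev=0.05))
    return layer(x)
# e.g.: params = hk.transform(_complex_linear).init(jax.random.PRNGKey(0), jnp.ones((4, 8)))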
import jax
import haiku as hk
from jax import random, grad, jit
import jax.numpy as jnp
from functools import partial
import time
from tqdm.notebook import tqdm
from itertools import cycle
class Classifier_wrapper:
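"""Lightweight wrapper that trains a Haiku classification model with an optimizer triple
(init, update, get_params) from jax.experimental.optimizers."""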
def __init__(self, network, optimizer, rng_seed=42):
self.network = network
self.opt_init, self.opt_update, self.get_params = optimizer
self.rng_seq = hk.PRNGSequence(random.PRNGKey(rng_seed))
@partial(jit, static_argnums=(0,))
def __categorical_accuracy(self, params, inputs, targets):
target_class = jnp.argmax(targets, axis=1)
predictions = self.network.apply(params, inputs)
# Traditional accuracy is not defined for complex output
predictions = jnp.absolute(predictions)
predicted_class = jnp.argmax(predictions, axis=1)
return jnp.mean(predicted_class == target_class)
@partial(jit, static_argnums=(0,))
def __crossentropy_loss(self, params, inputs, targets):
logits = self.network.apply(params, inputs)
# Traditional cross-entropy is not defined for complex output
logits = jnp.absolute(logits)
return -jnp.sum(targets * jax.nn.log_softmax(logits)) / len(targets)
@partial(jit, static_argnums=(0,))
def __update(self, params, opt_state, x, y):
grads = grad(self.__crossentropy_loss)(params, x, y)
opt_state = self.opt_update(0, grads, opt_state)
return self.get_params(opt_state), opt_state
def train(self, n_epochs, batch_size, train_data, test_data, n_classes, verbose=False):
# Initialize the parameters of the network
init_batch = next(iter(train_data))[0].numpy()
net_params = self.network.init(next(self.rng_seq), init_batch)
# Initialize the optimizer state
opt_state = self.opt_init(net_params)
training_history = {'train_loss': [],
'val_loss': [],
'train_acc': [],
'val_acc': [],
'epoch_time': []}
for epoch in tqdm(range(n_epochs), desc='Training for several epochs', leave=False):
start_time = time.time()
for batch in tqdm(train_data, desc='Looping over the minibatches', leave=False):
x_batch, y_batch = batch[0].numpy(), batch[1].numpy()
y_batch = jax.nn.one_hot(y_batch, n_classes)
net_params, opt_state = self.__update(net_params, opt_state, x_batch, y_batch)
training_history['epoch_time'].append(time.time() - start_time)
# Evaluate the model
for train_batch in train_data:
x_batch, y_batch = train_batch[0].numpy(), train_batch[1].numpy()
y_batch = jax.nn.one_hot(y_batch, n_classes)
training_history['train_loss'].append(
self.__crossentropy_loss(net_params, x_batch, y_batch))
training_history['train_acc'].append(
self.__categorical_accuracy(net_params, x_batch, y_batch))