# Source code for osculari.paradigms.paradigm_utils

"""
Utility functions for paradigms.
"""

import os
import numpy as np
import numpy.typing as npt
from typing import Union, Optional, List, Callable, Dict, Any, Sequence

import torch
import torch.nn as nn
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset as TorchDataset
from torch.optim import lr_scheduler

from ..models.readout import ProbeNet

# Names exported by a star-import of this module. Only the training entry point
# is listed; `accuracy` and `midpoint`, although defined below without a leading
# underscore, have to be imported by name.
__all__ = [
    'train_linear_probe'
]


def _accuracy_preds(output: torch.Tensor, target: torch.Tensor,
                    topk: Optional[Sequence] = (1,)) -> (List[float], List[torch.Tensor]):
    """
    Compute accuracy and correct predictions for the top-k thresholds.

    Parameters:
        output (torch.Tensor): Model predictions.
        target (torch.Tensor): Ground truth labels.
        topk (Optional[Sequence]): Top-k thresholds for accuracy computation. Default is (1,).

    Returns:
        Tuple[List[float], List[torch.Tensor]]: List of accuracies for each top-k threshold,
                                                list of correct predictions for each top-k
                                                threshold.
    """
    with torch.inference_mode():  # Ensure that the model is in inference mode
        maxk = max(topk)  # Extract the maximum top-k value
        batch_size = target.size(0)  # Get the batch size
        # Check if the target is a 2D tensor (i.e., multi-class classification)
        if target.ndim == 2:
            target = target.max(dim=1)[1]  # Extract the single class label for each sample

        _, pred = output.topk(maxk, 1, True, True)  # Extract the top-k predictions
        pred = pred.t()  # Transpose the predictions to match the target format
        correct = pred.eq(target[None])  # Create a tensor indicating the correct predictions

        accuracies = []  # List to store the computed accuracies
        corrects = []  # List to store the correct predictions for each top-k threshold
        for k in topk:  # Iterate over each top-k threshold
            # Extract the correct predictions for the current threshold
            corrects.append(correct[:k])
            # Compute the sum of correct predictions
            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
            accuracies.append(correct_k / batch_size)

    return accuracies, corrects


def accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
    """
    Compute the accuracy of model predictions.

    A single-column output is treated as a binary decision (positive when the value is
    greater than 0); any wider output is scored as top-1 multi-class accuracy.

    Parameters:
        output (torch.Tensor): Model predictions; must be two-dimensional.
        target (torch.Tensor): Ground truth labels, one entry per row of ``output``. In the
            binary case both a flat tensor and a single-column tensor are accepted.

    Returns:
        float: Accuracy of the model predictions.

    Raises:
        AssertionError: If ``output`` is not two-dimensional or its length differs from
            the length of ``target``.
    """
    # Ensure the output has two dimensions (Linear layer output is two-dimensional)
    assert len(output.shape) == 2
    # Ensure output and target have the same number of elements
    assert len(output) == len(target)

    # Check if the model is performing binary classification
    if output.shape[1] == 1:
        # Convert to binary predictions (greater than 0)
        output_class = torch.gt(output, 0).flatten()
        # Bring the labels to the same flat shape. Comparing flat predictions against a
        # single-column target would otherwise broadcast to an N x N matrix and silently
        # report a wrong accuracy.
        target = target.reshape(output_class.shape)
        # Compute accuracy for binary classification
        pred = torch.eq(output_class, target)
        return pred.float().mean().item()

    # Otherwise, the model produces multidimensional predictions
    acc, _ = _accuracy_preds(output, target, topk=[1])
    return acc[0].item()  # Extract the top-1 accuracy


def _circular_mean(a: float, b: float) -> float:
    """
    Compute the circular mean of two variables in the range of 0 to 1.

    Parameters:
        a (float): First angle in radians.
        b (float): Second angle in radians.

    Returns:
        float: Circular mean of the two angles.
    """
    # Ensure a and b are in the range of 0 to 1
    assert 0 <= a <= 1
    assert 0 <= b <= 1
    # Calculate the circular mean using a conditional expression
    mu = (a + b + 1) / 2 if abs(a - b) > 0.5 else (a + b) / 2
    # Adjust the result to be in the range [0, 1)
    return mu if mu >= 1 else mu - 1


def _compute_avg(a: Union[float, npt.NDArray], b: Union[float, npt.NDArray],
                 circular_channels: Optional[List] = None) -> Union[float, npt.NDArray]:
    if circular_channels is None:
        circular_channels = []
    if type(a) is np.ndarray:
        a, b = a.copy().squeeze(), b.copy().squeeze()
    c = (a + b) / 2
    for i in circular_channels:
        c[i] = _circular_mean(a[i], b[i])
    return c


def midpoint(
        acc: float, low: Union[float, npt.NDArray], mid: Union[float, npt.NDArray],
        high: Union[float, npt.NDArray], th: float, ep: Optional[float] = 1e-4,
        circular_channels: Optional[List] = None
) -> (
        Union[float, npt.NDArray, None], Union[float, npt.NDArray, None],
        Union[float, npt.NDArray, None]
):
    """
    Perform one step of a binary search that looks for a target accuracy.

    Parameters:
        acc (float): Accuracy measured at the current midpoint.
        low (Union[float, npt.NDArray]): Lower end of the current search interval.
        mid (Union[float, npt.NDArray]): Current midpoint of the search interval.
        high (Union[float, npt.NDArray]): Upper end of the current search interval.
        th (float): Accuracy the search is aiming for.
        ep (Optional[float]): Tolerance around the target accuracy. Default is 1e-4.
        circular_channels (Optional[List]): Channels averaged with the circular mean when
            the new midpoint is computed. Default is None.

    Returns:
        (Union[float, npt.NDArray, None], Union[float, npt.NDArray, None], Union[float, npt.NDArray, None]):
        The next (low, mid, high) triplet. When the accuracy is closer to the target than
        the tolerance, the search is over and (None, None, None) is returned.
    """
    gap = acc - th

    # Close enough to the target: signal the caller to stop searching
    if abs(gap) < ep:
        return None, None, None

    # Accuracy above the target: keep the lower half, the old midpoint becomes the new high
    if gap > 0:
        return low, _compute_avg(low, mid, circular_channels), mid

    # Accuracy below the target: keep the upper half, the old midpoint becomes the new low
    return mid, _compute_avg(high, mid, circular_channels), high


def train_linear_probe(
        model: ProbeNet, dataset: Union[TorchDataset, TorchDataLoader],
        epoch_loop: Callable[[nn.Module, TorchDataLoader, Any, torch.device], Dict],
        out_dir: str, device: Optional[torch.device] = None, epochs: Optional[int] = 10,
        optimiser: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[lr_scheduler.LRScheduler] = None
) -> Dict:
    """
    Train a linear probe on top of a frozen backbone model.

    Parameters:
        model (ProbeNet): Linear probe model.
        dataset (Union[TorchDataset, TorchDataLoader]): Training dataset or data loader. A
            data loader is used as is; a dataset is wrapped in a shuffling loader with a
            batch size of 16.
        epoch_loop (Callable): Function defining the training loop for one epoch. This
            function must accept four positional arguments (i.e., model, train_loader,
            optimiser, device). This function should return a dictionary; the mean of
            every value is logged for the epoch.
        out_dir (str): Output directory to save checkpoints. It is created if missing.
        device (Optional[torch.device]): Device on which to perform training. It is handed
            to ``epoch_loop`` unchanged.
        epochs (Optional[int]): Number of training epochs. Default is 10.
        optimiser (Optional[torch.optim.Optimizer]): Optimization algorithm. Default is SGD
            over the parameters of ``model.fc``.
        scheduler (Optional[lr_scheduler.LRScheduler]): Learning rate scheduler. Default is
            MultiStepLR at 50 and 80% of epochs.

    Returns:
        Dict: Training logs containing statistics; each key of the ``epoch_loop`` output
            maps to a list with one averaged value per epoch.
    """
    # Data loading
    if isinstance(dataset, TorchDataLoader):
        train_loader = dataset
    else:
        train_loader = TorchDataLoader(
            dataset, batch_size=16, shuffle=True, num_workers=0, pin_memory=True, sampler=None
        )

    # Optimisation
    model.freeze_backbone()
    if optimiser is None:
        # Create an optimiser for the linear probe parameters
        params_to_optimize = [{'params': list(model.fc.parameters())}]
        optimiser = torch.optim.SGD(params_to_optimize, lr=0.1, momentum=0.9, weight_decay=1e-4)
    if scheduler is None:
        # Create a learning rate scheduler
        milestones = [int(epochs * e) for e in [0.5, 0.8]]
        scheduler = lr_scheduler.MultiStepLR(optimiser, milestones=milestones, gamma=0.1)

    # The checkpoint location does not change between epochs. Preparing it up front also
    # surfaces an unusable output directory before any training time is spent.
    os.makedirs(out_dir, exist_ok=True)
    file_path = os.path.join(out_dir, 'checkpoint.pth.tar')

    # Training loop
    # NOTE(review): the per-epoch summary and checkpoint are assumed to belong inside this
    # loop (both use the epoch index) - confirm against the upstream file.
    training_logs = dict()
    for epoch in range(epochs):
        # Run an epoch of training
        train_log = epoch_loop(model, train_loader, optimiser, device)
        # Update the learning rate scheduler
        scheduler.step()

        # Store training statistics
        for log_key, log_val in train_log.items():
            training_logs.setdefault(log_key, []).append(np.mean(log_val))

        # Print training summary
        log_str = ' '.join('%s=%.3f' % (key, val[-1]) for key, val in training_logs.items())
        print('[%.3d] %s' % (epoch, log_str))

        # Save checkpoints (the same file is overwritten after every epoch)
        torch.save({
            'epoch': epoch,
            'network': model.serialisation_params(),
            'optimizer': optimiser.state_dict(),
            'scheduler': scheduler.state_dict(),
            'log': training_logs
        }, file_path)
    return training_logs