# Source code for osculari.datasets.gratings

"""
Collection of grating datasets commonly used in psychophysical studies.
"""

import numpy as np
import numpy.typing as npt
from typing import Optional, Callable, Sequence, Union

import torch
from torch.utils.data import Dataset as TorchDataset

from . import dataset_utils

__all__ = [
    'GratingsDataset'
]


def sinusoid_grating(img_size: int, amplitude: float, theta: float, phase: float,
                     spatial_frequency: int) -> npt.NDArray[float]:
    """Create a square image containing a sinusoidal grating.

    The cosine is evaluated on an integer pixel grid centred on zero, so the
    pixel at index ``img_size // 2`` (in both axes) corresponds to the origin.

    Args:
        img_size: Side length, in pixels, of the returned image.
        amplitude: Multiplier applied to the cosine; values lie in
            ``[-|amplitude|, |amplitude|]``.
        theta: Orientation of the grating, in radians.
        phase: Phase offset added inside the cosine, in radians.
        spatial_frequency: Number of cycles across ``img_size`` pixels.

    Returns:
        An ``(img_size, img_size)`` array holding the grating.
    """
    half = img_size // 2
    axis = range(-half, half + 1)
    xx, yy = np.meshgrid(axis, axis)

    # Position of every pixel along the direction of modulation.
    projection = np.cos(theta) * xx + np.sin(theta) * yy

    # Dividing by this scale is the same as multiplying by
    # 2 * pi * spatial_frequency / img_size, i.e. `spatial_frequency` cycles
    # across the image.
    scale = (img_size * 0.5) / (np.pi * spatial_frequency)
    grating = amplitude * np.cos(projection / scale + phase)

    if img_size % 2 == 0:
        # The symmetric grid has 2 * half + 1 == img_size + 1 samples for even
        # sizes; drop the trailing row and column to get img_size x img_size.
        grating = grating[:-1, :-1]

    return grating


def gaussian_img(img_size: int, sigma: float) -> npt.NDArray[float]:
    """Create a square, isotropic 2-D Gaussian envelope.

    The Gaussian is centred on the pixel at index ``img_size // 2`` in both
    axes.

    Args:
        img_size: Side length, in pixels, of the returned image.
        sigma: Standard deviation of the Gaussian, in pixels.

    Returns:
        An ``(img_size, img_size)`` array divided by its maximum, so the peak
        value is 1.
    """
    half = img_size // 2
    axis = range(-half, half + 1)
    xx, yy = np.meshgrid(axis, axis)

    squared_dist = np.power(xx, 2) + np.power(yy, 2)
    envelope = np.exp(-squared_dist / (2 * np.power(sigma, 2)))

    if img_size % 2 == 0:
        # For even sizes the symmetric grid is one sample too large in each
        # axis; trimming the last row/column yields exactly img_size x img_size.
        envelope = envelope[:-1, :-1]

    # Rescale so the largest value is 1.
    return envelope / np.max(envelope)


class GratingsDataset(TorchDataset):
    """A dataset of sinusoidal grating stimuli.

    Each item is the grating for one (orientation, spatial frequency) pair.
    Items are ordered orientation-major: index ``i`` uses
    ``thetas[i // len(sfs)]`` and ``sfs[i % len(sfs)]``.

    Args:
        img_size: Side length, in pixels, of every grating image (int).
        spatial_frequencies: Spatial frequencies in cycles per image
            (optional). Defaults to every divisor of ``img_size`` in
            ``[1, img_size // 2]``.
        thetas: Orientations in radians (optional). Defaults to
            ``np.arange(0, pi + 1e-3, pi / 12)``, i.e. 13 values from 0 to pi
            inclusive.
        gaussian_sigma: Standard deviation, in pixels, of a Gaussian envelope
            multiplied into each grating (optional; no envelope if ``None``).
        transform: Callable applied to each image in ``__getitem__``
            (optional).
    """

    def __init__(self, img_size: int, spatial_frequencies: Optional[Sequence[int]] = None,
                 thetas: Optional[Sequence[float]] = None, gaussian_sigma: Optional[float] = None,
                 transform: Optional[Callable] = None) -> None:
        super().__init__()
        self.img_size = img_size
        self.transform = transform
        if spatial_frequencies is None:
            # Divisors of img_size give a whole number of pixels per cycle.
            spatial_frequencies = [
                i for i in range(1, img_size // 2 + 1) if img_size % i == 0
            ]
        self.sfs = spatial_frequencies
        if thetas is None:
            # NOTE: includes both 0 and pi, which describe the same orientation
            # for a zero-phase grating (up to floating-point error).
            thetas = np.arange(0, np.pi + 1e-3, np.pi / 12)
        self.thetas = thetas
        self.gaussian_sigma = gaussian_sigma

    def __len__(self) -> int:
        return len(self.thetas) * len(self.sfs)

    def make_grating(self, idx: int, amplitude: float, channels: int = 3) -> npt.NDArray[float]:
        """Construct the sinusoidal grating image for item ``idx``.

        Args:
            idx: Non-negative flat index into the (theta, sf) grid.
            amplitude: Amplitude of the sinusoidal modulation (float).
            channels: Number of output channels (int; optional).

        Returns:
            The grating image after channel repetition by
            ``dataset_utils.repeat_channels`` (presumably
            ``(img_size, img_size, channels)`` — confirm in dataset_utils).

        Raises:
            ValueError: From ``np.unravel_index`` if ``idx`` is out of range.
        """
        # C-order unravel: theta varies slowest, spatial frequency fastest.
        theta_ind, sf_ind = np.unravel_index(idx, (len(self.thetas), len(self.sfs)))
        theta = self.thetas[theta_ind]
        sf = self.sfs[sf_ind]
        phase = 0
        stimuli = sinusoid_grating(self.img_size, amplitude, theta, phase, sf)

        # Apply the Gaussian envelope if requested.
        if self.gaussian_sigma is not None:
            stimuli *= gaussian_img(self.img_size, self.gaussian_sigma)

        # Map [-1, 1] onto [0, 1]; the result stays within [0, 1] only when
        # |amplitude| <= 1.
        stimuli = (stimuli + 1) / 2

        # Repeat the image across channels (if requested).
        return dataset_utils.repeat_channels(stimuli, channels=channels)

    def __getitem__(self, idx: int) -> Union[torch.Tensor, npt.NDArray]:
        """Retrieve the grating at ``idx`` (negative indices count from the end).

        Args:
            idx: Index of the grating to be retrieved.

        Returns:
            The grating image (amplitude 1.0), passed through ``transform``
            when one was given.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_items = len(self)
        # Raise IndexError rather than numpy's ValueError: without __iter__,
        # `for x in dataset` relies on the sequence protocol, which stops only
        # on IndexError.
        if not -num_items <= idx < num_items:
            raise IndexError(f"index {idx} is out of range for dataset of size {num_items}")
        if idx < 0:
            idx += num_items
        stimuli = self.make_grating(idx, 1.0)
        if self.transform is not None:
            stimuli = self.transform(stimuli)
        return stimuli