"""
A collection of adaptive psychophysical experimental methods.
"""
import numpy as np
import numpy.typing as npt
from typing import Callable, Optional, Dict, Tuple
import torch
import torch.nn as nn
from torch.utils.data import DataLoader as TorchDataLoader
from . import paradigm_utils
__all__ = [
'staircase'
]
[docs]def staircase(model: nn.Module,
test_fun: Callable[[nn.Module, TorchDataLoader, torch.device], Dict],
dataset_fun: Callable[[float], Tuple],
low_val: float, high_val: float, device: Optional[torch.device] = None,
max_attempts: Optional[int] = 20) -> npt.NDArray:
"""
Computes the psychometric function following the staircase procedure.
Args:
model: The neural network model to be evaluated.
test_fun: Function for evaluating the model. This function must accept three
positional arguments (i.e., model, db_loader, device). The output of this function
should be a dictionary containing the key `accuracy`.
dataset_fun: Function for creating the dataset and dataloader. This function must
accept one argument (mid_val, i.e., the current value to be tested). This funtion must
return a tuple of three elements (i.e., dataset, batch_size, threshold).
low_val: The lower bound of the stimulus range.
high_val: The upper bound of the stimulus range.
device: The device to which the model and data will be transferred
(default: CUDA if available) (optional).
max_attempts: The maximum number of attempts allowed in the staircase procedure.
Returns:
A NumPy array containing the psychometric function data points containing two columns,
the first column tested values and the second column obtained accuracies.
"""
# Set the device
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Calculate the midpoint of the initial stimulus range
mid_val = (low_val + high_val) / 2
# List to store the psychometric function data points
results = []
# Number of attempts to perform the staircase procedure
attempt_num = 1
# Perform the staircase procedure until convergence
while True:
# Create the dataset and dataloader for the current midpoint
dataset, batch_size, th = dataset_fun(mid_val)
db_loader = TorchDataLoader(
dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True
)
# Evaluate the model on the current midpoint
test_log = test_fun(model, db_loader, device)
accuracy = np.mean(test_log['accuracy'])
# Check if accuracy is within the acceptable range
if 1 < accuracy or accuracy < 0:
raise RuntimeError('Accuracy for staircase procedure must be between 0 and 1.')
# Append the current midpoint and accuracy to the results list
results.append(np.array([mid_val, accuracy]))
# Calculate the new stimulus range for the next iteration
new_low, new_mid, new_high = paradigm_utils.midpoint(
accuracy, low_val, mid_val, high_val, th=th
)
# Check if the procedure has converged or reached the maximum number of attempts
if new_mid is None or attempt_num == max_attempts:
break
else:
# Update the stimulus range
low_val, mid_val, high_val = new_low, new_mid, new_high
# Increment the attempt counter
attempt_num += 1
# Return the psychometric function data points
return np.array(results)