Source code for osculari.models.pretrained_models

"""
A wrapper around publicly available pretrained models in PyTorch.
"""

import os
from typing import Optional, List, Dict, Union, Tuple
from itertools import chain

import torch
import torch.nn as nn
from torch.utils import model_zoo

from torchvision import models as torch_models
from visualpriors import taskonomy_network
import clip

from . import model_utils, pretrained_layers

_TORCHVISION_SEGMENTATION = [
    'deeplabv3_mobilenet_v3_large',
    'deeplabv3_resnet101',
    'deeplabv3_resnet50',
    'fcn_resnet101',
    'fcn_resnet50',
    # TODO: add lrassp, the problem is it has two outputs, low and high
    # 'lraspp_mobilenet_v3_large'
]

_TORCHVISION_IMAGENET = [
    'alexnet',
    'convnext_base',
    'convnext_large',
    'convnext_small',
    'convnext_tiny',
    'densenet121',
    'densenet161',
    'densenet169',
    'densenet201',
    'efficientnet_b0',
    'efficientnet_b1',
    'efficientnet_b2',
    'efficientnet_b3',
    'efficientnet_b4',
    'efficientnet_b5',
    'efficientnet_b6',
    'efficientnet_b7',
    'efficientnet_v2_l',
    'efficientnet_v2_m',
    'efficientnet_v2_s',
    'googlenet',
    'inception_v3',
    'maxvit_t',
    'mnasnet0_5',
    'mnasnet0_75',
    'mnasnet1_0',
    'mnasnet1_3',
    'mobilenet_v2',
    'mobilenet_v3_large',
    'mobilenet_v3_small',
    'regnet_x_16gf',
    'regnet_x_1_6gf',
    'regnet_x_32gf',
    'regnet_x_3_2gf',
    'regnet_x_400mf',
    'regnet_x_800mf',
    'regnet_x_8gf',
    'regnet_y_128gf',
    'regnet_y_16gf',
    'regnet_y_1_6gf',
    'regnet_y_32gf',
    'regnet_y_3_2gf',
    'regnet_y_400mf',
    'regnet_y_800mf',
    'regnet_y_8gf',
    'resnet101',
    'resnet152',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnext101_32x8d',
    'resnext101_64x4d',
    'resnext50_32x4d',
    'shufflenet_v2_x0_5',
    'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5',
    'shufflenet_v2_x2_0',
    'squeezenet1_0',
    'squeezenet1_1',
    'swin_b',
    'swin_s',
    'swin_t',
    'swin_v2_b',
    'swin_v2_s',
    'swin_v2_t',
    'vgg11',
    'vgg11_bn',
    'vgg13',
    'vgg13_bn',
    'vgg16',
    'vgg16_bn',
    'vgg19',
    'vgg19_bn',
    'vit_b_16',
    'vit_b_32',
    'vit_h_14',
    'vit_l_16',
    'vit_l_32',
    'wide_resnet101_2',
    'wide_resnet50_2'
]

__all__ = [
    'available_models'
]


[docs]def available_models(flatten: Optional[bool] = False) -> Union[Dict, List]: """ List of supported models. Parameters: flatten (bool, optional): If True, returns a flattened list of all supported models. If False, returns a dictionary of model categories. Returns: Union[Dict, List]: List of supported models or a dictionary of model categories. """ # Define categories and corresponding models all_models = { 'imagenet': _TORCHVISION_IMAGENET, 'segmentation': _TORCHVISION_SEGMENTATION, 'taskonomy': ['taskonomy_%s' % t for t in taskonomy_network.LIST_OF_TASKS], 'clip': ['clip_%s' % c for c in clip.available_models()] } # Return either a flattened list or the dictionary of model categories return list(chain.from_iterable(all_models.values())) if flatten else all_models
class ViTLayers(nn.Module): def __init__(self, parent_model: nn.Module, layer: str) -> None: """ Initialize the ViTLayers module. Parameters: parent_model (nn.Module): The parent ViT model. layer (str): The layer to extract features up to (format: 'blockX' where X is the block number). Raises: RuntimeError: If the specified layer exceeds the total number of blocks in the parent model. """ super(ViTLayers, self).__init__() self.parent_model = parent_model block = int(layer.replace('block', '')) + 1 max_blocks = len(self.parent_model.encoder.layers) if block > max_blocks: raise RuntimeError( f'Layer {layer} exceeds the total number of {max_blocks} blocks.' ) self.parent_model.encoder.layers = self.parent_model.encoder.layers[:block] delattr(self.parent_model, 'heads') def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through the ViT model up to the specified layer. Parameters: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The output tensor after processing up to the specified layer. """ # Reshape and permute the input tensor x = self.parent_model._process_input(x) n = x.shape[0] # Expand the class token to the full batch batch_class_token = self.parent_model.class_token.expand(n, -1, -1) x = torch.cat([batch_class_token, x], dim=1) x = self.parent_model.encoder(x) return x class ViTClipLayers(nn.Module): def __init__(self, parent_model: nn.Module, layer: str) -> None: """ Initialize the ViTClipLayers module. Parameters: parent_model (nn.Module): The parent ViT-Clip model. layer (str): The layer to extract features up to (format: 'blockX' where X is the block number). Raises: RuntimeError: If the specified layer exceeds the total number of blocks in the parent model. """ super(ViTClipLayers, self).__init__() self.parent_model = parent_model block = int(layer.replace('block', '')) + 1 max_blocks = len(self.parent_model.transformer.resblocks) if block > max_blocks: raise RuntimeError( f'Layer {layer} exceeds the total number of {max_blocks} blocks.' ) self.parent_model.transformer.resblocks = self.parent_model.transformer.resblocks[:block] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through the ViT-Clip model up to the specified layer. Parameters: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The output tensor after processing up to the specified layer. """ x = self.parent_model.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] x = torch.cat( [self.parent_model.class_embedding.to(x.dtype) + torch.zeros( x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1 ) # shape = [*, grid ** 2 + 1, width] x = x + self.parent_model.positional_embedding.to(x.dtype) x = self.parent_model.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND x = self.parent_model.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD return x class AuxiliaryLayers(nn.Module): def __init__(self, parent_model: nn.Module) -> None: """ Initialize the AuxiliaryLayers module (e.g., GoogLeNet, Inception, segmentations). Parameters: parent_model (nn.Module): The parent model. """ super(AuxiliaryLayers, self).__init__() self.parent_model = parent_model def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through the model up to the specified layer and returning only the output. Parameters: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The output tensor after processing up to the specified layer. """ # Reshape and permute the input tensor x = self.parent_model(x) if isinstance(x, (torch_models.GoogLeNetOutputs, torch_models.InceptionOutputs)): x = getattr(x, 'logits') elif isinstance(x, dict) and 'out' in x: # segmentation models x = x['out'] return x def _vit_features(model: nn.Module, layer: str) -> ViTLayers: """Creating a feature extractor from ViT network.""" if layer == 'conv_proj': return model.conv_proj return ViTLayers(model, layer) def _swin_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from SwinTransformer network.""" layer = int(layer.replace('block', '')) + 1 return nn.Sequential(*list(model.features.children())[:layer], model.permute) def _sequential_features(model: nn.Module, layer: str, architecture: str, avgpool: Optional[bool] = True) -> nn.Module: """Creating a feature extractor from sequential network.""" if 'feature' in layer: layer = int(layer.replace('feature', '')) + 1 features = nn.Sequential(*list(model.features.children())[:layer]) elif 'classifier' in layer: layer = int(layer.replace('classifier', '')) + 1 avgpool_layers = [model.avgpool, nn.Flatten(1)] if avgpool else [] features = nn.Sequential( model.features, *avgpool_layers, *list(model.classifier.children())[:layer] ) else: raise RuntimeError('Unsupported %s layer %s' % (architecture, layer)) return features def _vgg_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from VGG network.""" return _sequential_features(model, layer, 'vgg') def _alexnet_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from AlexNet network.""" return _sequential_features(model, layer, 'alexnet') def _mobilenet_features(model: nn.Module, layer: str, architecture: str) -> nn.Module: """Creating a feature extractor from MobileNet network.""" if architecture in ['lraspp_mobilenet_v3_large', 'deeplabv3_mobilenet_v3_large']: layer = int(layer.replace('feature', '')) + 1 return nn.Sequential(*list(model.parent_model.children())[:layer]) return _sequential_features(model, layer, 'mobilenet') def _convnext_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from ComvNeXt network.""" return _sequential_features(model, layer, 'convnext') def _squeezenet_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from SqueezeNet network.""" return _sequential_features(model, layer, 'squeezenet', avgpool=False) def _efficientnet_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from EfficientNet network.""" return _sequential_features(model, layer, 'efficientnet') def _googlenet_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from GoogLeNet network.""" l_ind = pretrained_layers.googlenet_cutoff_slice(layer) return nn.Sequential(*list(model.children())[:l_ind]) def _inception_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from Inception network.""" l_ind = pretrained_layers.inception_cutoff_slice(layer) return nn.Sequential(*list(model.children())[:l_ind]) def _mnasnet_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from MnasNet network.""" l_ind = int(layer.replace('layer', '')) + 1 return nn.Sequential(*list(model.layers.children())[:l_ind]) def _shufflenet_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from ShuffleNet network.""" l_ind = int(layer.replace('layer', '')) + 1 return nn.Sequential(*list(model.children())[:l_ind]) def _densenet_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from DenseNet network.""" return _sequential_features(model, layer, 'densenet') def _clip_features(model: nn.Module, layer: str, architecture: str) -> nn.Module: """Creating a feature extractor from CLIP network.""" clip_arch = architecture.replace('clip_', '') if layer == 'encoder': features = model else: if clip_arch in ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64']: features = _resnet_features(model, layer, architecture) elif layer == 'conv_proj': features = model.conv1 else: features = ViTClipLayers(model, layer) return features def _maxvit_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from MaxVit network.""" if 'stem' in layer: features = model.stem elif 'block' in layer: layer = int(layer.replace('block', '')) features = nn.Sequential(model.stem, *list(model.blocks.children())[:layer]) elif 'classifier' in layer: layer = int(layer.replace('classifier', '')) + 1 features = nn.Sequential( model.stem, *list(model.blocks.children()), *list(model.classifier.children())[:layer] ) else: raise RuntimeError('Unsupported regnet layer %s' % layer) return features def _regnet_features(model: nn.Module, layer: str) -> nn.Module: """Creating a feature extractor from RegNet network.""" if 'stem' in layer: features = model.stem elif 'block' in layer: layer = int(layer.replace('block', '')) features = nn.Sequential(model.stem, *list(model.trunk_output.children())[:layer]) else: raise RuntimeError('Unsupported regnet layer %s' % layer) return features def _resnet_features(model: nn.Module, layer: str, architecture: str) -> nn.Module: """Creating a feature extractor from ResNet network.""" is_clip = 'clip' in architecture l_ind = pretrained_layers.resnet_cutoff_slice(layer, is_clip=is_clip) if architecture in _TORCHVISION_SEGMENTATION: return nn.Sequential(*list(model.parent_model.children())[:l_ind]) return nn.Sequential(*list(model.children())[:l_ind]) def model_features(model: nn.Module, architecture: str, layer: str) -> nn.Module: """ Return features extracted from one layer. Parameters: model (nn.Module): The pretrained model. architecture (str): Name of the architecture. layer (str): Name of the layer. Raises: RuntimeError: If the specified layer is not supported for the given architecture. Returns: nn.Module: The features extracted from the specified layer. """ if layer not in pretrained_layers.available_layers(architecture): raise RuntimeError( 'Layer %s is not supported for architecture %s. Call pretrained_layers.available_layers' ' to see a list of supported layers for an architecture.' % (layer, architecture) ) if architecture in _TORCHVISION_IMAGENET and layer == 'fc': features = model elif 'clip' in architecture: features = _clip_features(model, layer, architecture) else: if model_utils.is_resnet_backbone(architecture): features = _resnet_features(model, layer, architecture) elif 'regnet' in architecture: features = _regnet_features(model, layer) elif 'vgg' in architecture: features = _vgg_features(model, layer) elif architecture == 'alexnet': features = _alexnet_features(model, layer) elif architecture == 'googlenet': features = _googlenet_features(model, layer) elif architecture == 'inception_v3': features = _inception_features(model, layer) elif 'convnext' in architecture: features = _convnext_features(model, layer) elif 'densenet' in architecture: features = _densenet_features(model, layer) elif 'mnasnet' in architecture: features = _mnasnet_features(model, layer) elif 'shufflenet' in architecture: features = _shufflenet_features(model, layer) elif 'squeezenet' in architecture: features = _squeezenet_features(model, layer) elif 'efficientnet' in architecture: features = _efficientnet_features(model, layer) elif 'mobilenet' in architecture: features = _mobilenet_features(model, layer, architecture) elif 'maxvit' in architecture: features = _maxvit_features(model, layer) elif 'swin_' in architecture: features = _swin_features(model, layer) elif 'vit_' in architecture: features = _vit_features(model, layer) else: raise RuntimeError('Unsupported network %s' % architecture) return features def _taskonomy_weights(network_name: str, weights: str) -> Union[str, None]: """Handling default Taskonomy weights.""" if weights is None: return None if network_name == weights: feature_task = weights.replace('taskonomy_', '') weights = taskonomy_network.TASKONOMY_PRETRAINED_URLS[feature_task + '_encoder'] return weights def _torchvision_weights(network_name: str, weights: str) -> Union[str, None]: """Handling default torchvision weights.""" if weights == 'none': return None return 'DEFAULT' if network_name == weights else weights def _load_weights(model: nn.Module, weights: str) -> nn.Module: """Loading the weights of a network from URL or local file.""" if weights in ['none', None]: pass elif 'https://' in weights or 'http://' in weights: checkpoint = model_zoo.load_url(weights, model_dir=None, progress=True) model.load_state_dict(checkpoint['state_dict']) elif os.path.exists(weights): checkpoint = torch.load(weights, map_location='cpu') model.load_state_dict(checkpoint['state_dict']) else: raise RuntimeError('Unknown weights: %s' % weights) return model def get_pretrained_model(network_name: str, weights: str) -> nn.Module: """ Load a network with or without pretrained weights. Parameters: network_name (str): Name of the network. weights (str): Path to the pretrained weights file. Raises: RuntimeError: If the specified network is not supported. Returns: nn.Module: The pretrained model. """ if network_name not in available_models(flatten=True): raise RuntimeError('Network %s is not supported.' % network_name) if 'clip_' in network_name: # Load CLIP model # TODO: support for None clip_version = network_name.replace('clip_', '') device = "cuda" if torch.cuda.is_available() and weights not in ['none', None] else "cpu" model, _ = clip.load(clip_version, device=device) elif 'taskonomy_' in network_name: # Load Taskonomy model model = taskonomy_network.TaskonomyEncoder() model = _load_weights(model, _taskonomy_weights(network_name, weights)) else: # Load torchvision networks weights = _torchvision_weights(network_name, weights) net_fun = torch_models.segmentation if network_name in _TORCHVISION_SEGMENTATION else torch_models model = net_fun.__dict__[network_name](weights=weights) return model def get_image_encoder(network_name: str, model: nn.Module) -> nn.Module: """ Returns the encoder block of a network. Parameters: network_name (str): Name of the network. model (nn.Module): The pretrained model. Returns: nn.Module: The encoder block of the network. """ if 'clip' in network_name: return model.visual elif network_name in _TORCHVISION_SEGMENTATION: return AuxiliaryLayers(model.backbone) elif network_name in ['googlenet', 'inception_v3']: model.AuxLogits = None return AuxiliaryLayers(model) return model def preprocess_mean_std(network_name: str) -> (Tuple[float], Tuple[float]): """ Returns the mean and std used in pretrained preprocessing. Parameters: network_name (str): Name of the network. Returns: Tuple[float], Tuple[float]: Mean and std for preprocessing. """ if 'clip_' in network_name: mean = (0.48145466, 0.4578275, 0.40821073) std = (0.26862954, 0.26130258, 0.27577711) elif 'taskonomy_' in network_name: mean = (0.5, 0.5, 0.5) std = (0.5, 0.5, 0.5) elif network_name in [ *torch_models.list_models(module=torch_models.segmentation), *torch_models.list_models(module=torch_models) ]: mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225) else: raise RuntimeError('The preprocess for architecture %s is unknown.' % network_name) return mean, std