Source code for stouputils.data_science.models.all


# Imports
import itertools

from .keras.all import (
    VGG16,
    VGG19,
    ConvNeXtBase,
    ConvNeXtLarge,
    ConvNeXtSmall,
    ConvNeXtTiny,
    ConvNeXtXLarge,
    DenseNet121,
    DenseNet169,
    DenseNet201,
    EfficientNetB0,
    EfficientNetV2B0,
    EfficientNetV2L,
    EfficientNetV2M,
    EfficientNetV2S,
    MobileNet,
    MobileNetV2,
    MobileNetV3Large,
    MobileNetV3Small,
    ResNet50V2,
    ResNet101V2,
    ResNet152V2,
    SqueezeNet,
    Xception,
)
from .model_interface import ModelInterface

# Other models
from .sandbox import Sandbox


# Create a custom dictionary class to allow for documentation
class ModelClassMap(dict[type[ModelInterface], tuple[str, ...]]):
    """ Dictionary mapping model classes to their names and aliases.

    Keys are ``ModelInterface`` subclasses; values are tuples of lowercase strings
    (the model's own name followed by group aliases).

    A dedicated ``dict`` subclass is used (instead of a plain ``dict``) so that the
    mapping has a class ``__doc__`` of its own. This module overwrites that docstring
    after ``CLASS_MAP`` is built so that it lists every registered model.
    """
# Routine map
# Each model class maps to the names it can be referred to by. The first entry is the
# model's own lowercase name; the remaining entries are group aliases ("densenets",
# "all", "often", "good", ...) that are shared by several classes (presumably so one
# alias can select a whole family / preset of models at once — confirm with callers).
CLASS_MAP: ModelClassMap = ModelClassMap({
    SqueezeNet: ("squeezenet", "squeezenets", "all", "often"),
    DenseNet121: ("densenet121", "densenets", "all", "often", "good"),
    DenseNet169: ("densenet169", "densenets", "all", "often", "good"),
    DenseNet201: ("densenet201", "densenets", "all", "often", "good"),
    EfficientNetB0: ("efficientnetb0", "efficientnets", "all"),
    EfficientNetV2B0: ("efficientnetv2b0", "efficientnets", "all"),
    EfficientNetV2S: ("efficientnetv2s", "efficientnets", "all", "often"),
    EfficientNetV2M: ("efficientnetv2m", "efficientnets", "all", "often"),
    EfficientNetV2L: ("efficientnetv2l", "efficientnets", "all", "often"),
    ConvNeXtTiny: ("convnexttiny", "convnexts", "all", "often", "good"),
    ConvNeXtSmall: ("convnextsmall", "convnexts", "all", "often"),
    ConvNeXtBase: ("convnextbase", "convnexts", "all", "often", "good"),
    ConvNeXtLarge: ("convnextlarge", "convnexts", "all", "often"),
    ConvNeXtXLarge: ("convnextxlarge", "convnexts", "all", "often", "good"),
    VGG16: ("vgg16", "vggs", "all"),
    VGG19: ("vgg19", "vggs", "all"),
    MobileNet: ("mobilenet", "mobilenets", "all"),
    MobileNetV2: ("mobilenetv2", "mobilenets", "all", "often"),
    MobileNetV3Small: ("mobilenetv3small", "mobilenets", "all", "often"),
    MobileNetV3Large: ("mobilenetv3large", "mobilenets", "all", "often", "good"),
    ResNet50V2: ("resnet50v2", "resnetsv2", "resnets", "all", "often"),
    ResNet101V2: ("resnet101v2", "resnetsv2", "resnets", "all", "often"),
    ResNet152V2: ("resnet152v2", "resnetsv2", "resnets", "all", "often"),
    Xception: ("xception", "xceptions", "all", "often"),
    Sandbox: ("sandbox",),
})

# All models names and aliases
# Flatten every alias tuple, de-duplicate shared group aliases, and sort for a stable order.
ALL_MODELS: list[str] = sorted(set(itertools.chain.from_iterable(CLASS_MAP.values())))
""" All models names and aliases found in the `CLASS_MAP` dictionary. """

# Additional docstring
# Append one "- ClassName: (aliases...)" line per registered model to the class docstring,
# so the documentation of ModelClassMap reflects the actual contents of CLASS_MAP.
new_docstring: str = "\n\n" + "\n".join(f"- {k.__name__}: {v}" for k, v in CLASS_MAP.items())
ModelClassMap.__doc__ = "Dictionary mapping class to their names and aliases. " + new_docstring
# Mirror the same text on the instance itself (stored in the instance __dict__).
CLASS_MAP.__doc__ = ModelClassMap.__doc__