Source code for sequifier.config.train_config

import copy
import json
from typing import Any, Optional

import numpy as np
import yaml
from beartype import beartype
from pydantic import BaseModel, Field, validator

from sequifier.helpers import normalize_path

AnyType = str | int | float

VALID_LOSS_FUNCTIONS = [
    "L1Loss",
    "MSELoss",
    "CrossEntropyLoss",
    "CTCLoss",
    "NLLLoss",
    "PoissonNLLLoss",
    "GaussianNLLLoss",
    "KLDivLoss",
    "BCELoss",
    "BCEWithLogitsLoss",
    "MarginRankingLoss",
    "HingeEmbeddingLoss",
    "MultiLabelMarginLoss",
    "HuberLoss",
    "SmoothL1Loss",
    "SoftMarginLoss",
    "MultiLabelSoftMarginLoss",
    "CosineEmbeddingLoss",
    "MultiMarginLoss",
    "TripletMarginLoss",
    "TripletMarginWithDistanceLoss",
]

VALID_OPTIMIZERS = [
    "Adadelta",
    "Adagrad",
    "Adam",
    "AdamW",
    "SparseAdam",
    "Adamax",
    "ASGD",
    "LBFGS",
    "NAdam",
    "RAdam",
    "RMSprop",
    "Rprop",
    "SGD",
    "AdEMAMix",
    "A2GradUni",
    "A2GradInc",
    "A2GradExp",
    "AccSGD",
    "AdaBelief",
    "AdaBound",
    "Adafactor",
    "Adahessian",
    "AdaMod",
    "AdamP",
    "AggMo",
    "Apollo",
    "DiffGrad",
    "Lamb",
    "LARS",
    "Lion",
    "Lookahead",
    "MADGRAD",
    "NovoGrad",
    "PID",
    "QHAdam",
    "QHM",
    "RAdam",
    "SGDP",
    "SGDW",
    "Shampoo",
    "SWATS",
    "Yogi",
]

VALID_SCHEDULERS = [
    "LambdaLR",
    "MultiplicativeLR",
    "StepLR",
    "MultiStepLR",
    "ConstantLR",
    "LinearLR",
    "ExponentialLR",
    "PolynomialLR",
    "CosineAnnealingLR",
    "ChainedScheduler",
    "SequentialLR",
    "ReduceLROnPlateau",
    "CyclicLR",
    "OneCycleLR",
    "CosineAnnealingWarmRestarts",
]


@beartype
def load_train_config(
    config_path: str, args_config: dict[str, Any], on_unprocessed: bool
) -> "TrainModel":
    """
    Load training configuration from a YAML file and update it with args_config.

    Args:
        config_path: Path to the YAML configuration file.
        args_config: Dictionary containing additional configuration arguments.
        on_unprocessed: Flag indicating whether to process the configuration or not.

    Returns:
        TrainModel instance with loaded configuration.
    """
    with open(config_path, "r") as f:
        config_values = yaml.safe_load(f)

    config_values.update(args_config)

    if not on_unprocessed:
        ddconfig_path = config_values.get("ddconfig_path")

        with open(
            normalize_path(ddconfig_path, config_values["project_path"]), "r"
        ) as f:
            dd_config = json.loads(f.read())

        split_paths = dd_config["split_paths"]

        config_values["column_types"] = config_values.get(
            "column_types", dd_config["column_types"]
        )

        if config_values["selected_columns"] is None:
            config_values["selected_columns"] = list(
                config_values["column_types"].keys()
            )

        config_values["categorical_columns"] = [
            col
            for col, type_ in dd_config["column_types"].items()
            if "int64" in type_.lower() and col in config_values["selected_columns"]
        ]
        config_values["real_columns"] = [
            col
            for col, type_ in dd_config["column_types"].items()
            if "float64" in type_.lower() and col in config_values["selected_columns"]
        ]
        assert (
            len(config_values["real_columns"] + config_values["categorical_columns"])
            > 0
        )
        config_values["n_classes"] = config_values.get(
            "n_classes", dd_config["n_classes"]
        )
        config_values["training_data_path"] = normalize_path(
            config_values.get("training_data_path", split_paths[0]),
            config_values["project_path"],
        )
        config_values["validation_data_path"] = normalize_path(
            config_values.get(
                "validation_data_path",
                split_paths[min(1, len(split_paths) - 1)],
            ),
            config_values["project_path"],
        )

        config_values["id_maps"] = dd_config["id_maps"]

    return TrainModel(**config_values)


class DotDict(dict):
    """Dot notation access to dictionary attributes."""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__  # type: ignore
    __delattr__ = dict.__delitem__  # type: ignore

    def __deepcopy__(self, memo=None):
        return DotDict(copy.deepcopy(dict(self), memo=memo))


[docs] class TrainingSpecModel(BaseModel): """Pydantic model for training specifications. Attributes: device: The torch.device to train the model on (e.g., 'cuda', 'cpu', 'mps'). device_max_concat_length: Maximum sequence length for concatenation on device. epochs: The total number of epochs to train for. log_interval: The interval in batches for logging. class_share_log_columns: A list of column names for which to log the class share of predictions. early_stopping_epochs: Number of epochs to wait for validation loss improvement before stopping. iter_save: The interval in epochs for checkpointing the model. batch_size: The training batch size. lr: The learning rate. criterion: A dictionary mapping each target column to a loss function. class_weights: A dictionary mapping categorical target columns to a list of class weights. accumulation_steps: The number of gradient accumulation steps. dropout: The dropout value for the transformer model. loss_weights: A dictionary mapping columns to specific loss weights. optimizer: The optimizer configuration. scheduler: The learning rate scheduler configuration. continue_training: If True, continue training from the latest checkpoint. distributed: If True, enables distributed training. load_full_data_to_ram: If True, loads the entire dataset into RAM. world_size: The number of processes for distributed training. num_workers: The number of worker threads for data loading. backend: The distributed training backend (e.g., 'nccl'). """ device: str device_max_concat_length: int = 12 epochs: int log_interval: int = 10 class_share_log_columns: list[str] = Field(default_factory=list) early_stopping_epochs: Optional[int] = None iter_save: int batch_size: int lr: float criterion: dict[str, str] class_weights: Optional[dict[str, list[float]]] = None accumulation_steps: Optional[int] = None dropout: float = 0.0 loss_weights: Optional[dict[str, float]] = None optimizer: DotDict = Field(default_factory=lambda: DotDict({"name": "Adam"})) scheduler: DotDict = Field( default_factory=lambda: DotDict( {"name": "StepLR", "step_size": 1, "gamma": 0.99} ) ) continue_training: bool = True distributed: bool = False load_full_data_to_ram: bool = True world_size: int = 1 num_workers: int = 0 backend: str = "nccl" def __init__(self, **kwargs): super().__init__( **{k: v for k, v in kwargs.items() if k not in ["optimizer", "scheduler"]} ) self.validate_optimizer_config(kwargs["optimizer"]) self.optimizer = DotDict(kwargs["optimizer"]) self.validate_scheduler_config(kwargs["scheduler"]) self.scheduler = DotDict(kwargs["scheduler"]) @validator("criterion") def validate_criterion(cls, v): for vv in v.values(): if vv not in VALID_LOSS_FUNCTIONS: raise ValueError( f"criterion must be in {VALID_LOSS_FUNCTIONS}, {vv} isn't" ) return v @validator("optimizer") def validate_optimizer_config(cls, v): if "name" not in v: raise ValueError("optimizer dict must specify 'name' field") if v["name"] not in VALID_OPTIMIZERS: raise ValueError(f"optimizer not valid as not found in {VALID_OPTIMIZERS}") return v @validator("scheduler") def validate_scheduler_config(cls, v): if "name" not in v: raise ValueError("scheduler dict must specify 'name' field") if v["name"] not in VALID_SCHEDULERS: raise ValueError(f"scheduler not valid as not found in {VALID_SCHEDULERS}") return v
[docs] class ModelSpecModel(BaseModel): """Pydantic model for model specifications. Attributes: d_model: The number of expected features in the input. d_model_by_column: The embedding dimensions for each input column. Must sum to d_model. nhead: The number of heads in the multi-head attention models. d_hid: The dimension of the feedforward network model. nlayers: The number of layers in the transformer model. """ d_model: int d_model_by_column: Optional[dict[str, int]] nhead: int d_hid: int nlayers: int @validator("d_model_by_column") def validate_d_model_by_column(cls, v, values): assert ( v is None or np.sum(list(v.values())) == values["d_model"] ), f'{values["d_model"]} is not the sum of the d_model_by_column values' return v
[docs] class TrainModel(BaseModel): """Pydantic model for training configuration. Attributes: project_path: The path to the sequifier project directory. ddconfig_path: The path to the data-driven configuration file. model_name: The name of the model being trained. training_data_path: The path to the training data. validation_data_path: The path to the validation data. read_format: The file format of the input data (e.g., 'csv', 'parquet'). selected_columns: The list of input columns to be used for training. column_types: A dictionary mapping each column to its numeric type ('int64' or 'float64'). categorical_columns: A list of columns that are categorical. real_columns: A list of columns that are real-valued. target_columns: The list of target columns for model training. target_column_types: A dictionary mapping target columns to their types ('categorical' or 'real'). id_maps: For each categorical column, a map from distinct values to their indexed representation. seq_length: The sequence length of the model's input. n_classes: The number of classes for each categorical column. inference_batch_size: The batch size to be used for inference after model export. seed: The random seed for numpy and PyTorch. export_generative_model: If True, exports the generative model. export_embedding_model: If True, exports the embedding model. export_onnx: If True, exports the model in ONNX format. export_pt: If True, exports the model using torch.save. export_with_dropout: If True, exports the model with dropout enabled. model_spec: The specification of the transformer model architecture. training_spec: The specification of the training run configuration. """ project_path: str ddconfig_path: str model_name: str training_data_path: str validation_data_path: str read_format: str = "parquet" selected_columns: list[str] column_types: dict[str, str] categorical_columns: list[str] real_columns: list[str] target_columns: list[str] target_column_types: dict[str, str] id_maps: dict[str, dict[str | int, int]] seq_length: int n_classes: dict[str, int] inference_batch_size: int seed: int export_generative_model: bool export_embedding_model: bool export_onnx: bool = True export_pt: bool = False export_with_dropout: bool = False model_spec: ModelSpecModel training_spec: TrainingSpecModel @validator("model_name") def validate_model_name(cls, v): assert "embedding" not in v, "model_name cannot contain 'embedding'" return v @validator("target_column_types") def validate_target_column_types(cls, v, values): assert all(vv in ["categorical", "real"] for vv in v.values()) assert ( list(v.keys()) == values["target_columns"] ), "target_columns and target_column_types must contain the same values/keys in the same order" return v @validator("read_format") def validate_read_format(cls, v): assert v in [ "csv", "parquet", "pt", ], "Currently only 'csv', 'parquet' and 'pt' are supported" return v @validator("training_spec") def validate_training_spec(cls, v, values): assert set(values["target_columns"]) == set( v.criterion.keys() ), "target_columns and criterion must contain the same values/keys" if v.distributed: assert ( values["read_format"] == "pt" ), "If distributed is set to 'true', the format has to be 'pt'" return v @validator("column_types") def validate_column_types(cls, v, values): target_columns = values.get("target_columns", []) column_ordered = list(v.keys()) columns_ordered_filtered = [c for c in column_ordered if c in target_columns] assert ( columns_ordered_filtered == target_columns ), f"{columns_ordered_filtered = } != {target_columns = }" return v @validator("model_spec") def validate_model_spec(cls, v, values): assert ( values["selected_columns"] is None or (v.d_model_by_column is None) or np.all( np.array(list(v.d_model_by_column.keys())) == np.array(list(values["selected_columns"])) ) ) return v