Source code for sequifier.config.infer_config

import json
import os
from typing import Optional, Union

import numpy as np
import yaml
from beartype import beartype
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

from sequifier.helpers import normalize_path, try_catch_excess_keys


@beartype
def load_inferer_config(
    config_path: str, args_config: dict, skip_metadata: bool
) -> "InfererModel":
    """
    Load inferer configuration from a YAML file and update it with args_config.

    Args:
        config_path: Path to the YAML configuration file.
        args_config: Dictionary containing additional configuration arguments.
        skip_metadata: If True, skip loading the metadata config and the
            column/data-path defaults derived from it.

    Returns:
        InfererModel instance with loaded configuration.
    """
    with open(config_path, "r") as f:
        config_values = yaml.safe_load(f)
    config_values.update(args_config)

    config_values["seed"] = config_values.get("seed", 1010)

    if not skip_metadata:
        metadata_config_path = config_values.get("metadata_config_path")

        with open(
            normalize_path(metadata_config_path, config_values["project_root"]), "r"
        ) as f:
            metadata_config = json.load(f)

        config_values["column_types"] = config_values.get(
            "column_types", metadata_config["column_types"]
        )

        if config_values["input_columns"] is None:
            config_values["input_columns"] = list(config_values["column_types"].keys())

        config_values["categorical_columns"] = [
            col
            for col, type_ in metadata_config["column_types"].items()
            if "int64" in type_.lower() and col in config_values["input_columns"]
        ]
        config_values["real_columns"] = [
            col
            for col, type_ in metadata_config["column_types"].items()
            if "float64" in type_.lower() and col in config_values["input_columns"]
        ]
        if not (
            config_values["categorical_columns"] or config_values["real_columns"]
        ):
            raise ValueError("No categorical or real input columns found in config")
        config_values["data_path"] = normalize_path(
            config_values.get(
                "data_path",
                metadata_config["split_paths"][
                    min(2, len(metadata_config["split_paths"]) - 1)
                ],
            ),
            config_values["project_root"],
        )

    return try_catch_excess_keys(config_path, InfererModel, config_values)
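
# --- Illustrative usage sketch (not part of the library source) --------------
# A minimal, hedged example of how load_inferer_config might be driven: a YAML
# file on disk plus a dictionary of CLI-style overrides. The paths and override
# values below are hypothetical; only the function signature is taken from the
# code above. The helper is defined but not called, since the referenced files
# would have to exist for the call to succeed.
def _example_load_inferer_config() -> "InfererModel":
    # Overrides applied on top of the YAML values, e.g. parsed from CLI args.
    args_config = {"device": "cpu", "inference_batch_size": 32}
    # skip_metadata=False makes the loader read the metadata config and fill in
    # column_types, the categorical/real column lists, and a default data_path.
    return load_inferer_config(
        config_path="configs/infer.yaml",  # hypothetical path
        args_config=args_config,
        skip_metadata=False,
    )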


class InfererModel(BaseModel):
    """Pydantic model for inference configuration.

    Attributes:
        project_root: The path to the sequifier project directory.
        metadata_config_path: The path to the data-driven configuration file.
        model_path: The path to the trained model file(s).
        model_type: The type of model, either 'embedding' or 'generative'.
        data_path: The path to the data to be used for inference.
        training_config_path: The path to the training configuration file.
        read_format: The file format of the input data (e.g., 'csv', 'parquet').
        write_format: The file format for the inference output.
        input_columns: The list of input columns used for inference.
        categorical_columns: A list of columns that are categorical.
        real_columns: A list of columns that are real-valued.
        target_columns: The list of target columns for inference.
        column_types: A dictionary mapping each column to its numeric type ('int64' or 'float64').
        target_column_types: A dictionary mapping target columns to their types ('categorical' or 'real').
        enforce_determinism: If True, enforces deterministic behavior during inference.
        output_probabilities: If True, outputs the probability distributions for categorical target columns.
        map_to_id: If True, maps categorical output values back to their original IDs.
        seed: The random seed for reproducibility.
        device: The device to run inference on (e.g., 'cuda', 'cpu', 'mps').
        seq_length: The sequence length of the model's input.
        prediction_length: The number of steps predicted per inference step.
        inference_batch_size: The batch size for inference.
        distributed: If True, enables distributed inference.
        load_full_data_to_ram: If True, loads the entire dataset into RAM.
        world_size: The number of processes for distributed inference.
        num_workers: The number of worker threads for data loading.
        sample_from_distribution_columns: A list of columns from which to sample from the distribution.
        infer_with_dropout: If True, applies dropout during inference.
        autoregression: If True, performs autoregressive inference.
        autoregression_extra_steps: The number of additional steps for autoregressive inference.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    project_root: str
    metadata_config_path: str
    model_path: Union[str, list[str]]
    model_type: str
    data_path: str
    training_config_path: str = Field(default="configs/train.yaml")
    read_format: str = Field(default="parquet")
    write_format: str = Field(default="csv")
    input_columns: list[str]
    categorical_columns: list[str]
    real_columns: list[str]
    target_columns: list[str]
    column_types: dict[str, str]
    target_column_types: dict[str, str]
    enforce_determinism: bool = Field(default=False)
    output_probabilities: bool = Field(default=False)
    map_to_id: bool = Field(default=True)
    seed: int
    device: str
    seq_length: int
    prediction_length: int = Field(default=1)
    inference_batch_size: int
    distributed: bool = False
    load_full_data_to_ram: bool = True
    world_size: int = 1
    num_workers: int = 0
    sample_from_distribution_columns: Optional[list[str]] = Field(default=None)
    infer_with_dropout: bool = Field(default=False)
    autoregression: bool = Field(default=False)
    autoregression_extra_steps: Optional[int] = Field(default=None)

    @field_validator("model_type")
    @classmethod
    def validate_model_type(cls, v: str) -> str:
        if v not in ["embedding", "generative"]:
            raise ValueError(
                f"model_type must be one of 'embedding' and 'generative', {v} isn't"
            )
        return v

    @field_validator("output_probabilities")
    @classmethod
    def validate_output_probabilities(cls, v: bool, info: ValidationInfo) -> bool:
        if v and info.data.get("model_type") == "embedding":
            raise ValueError(
                "For embedding models, 'output_probabilities' must be set to false"
            )
        return v

    @field_validator("training_config_path")
    @classmethod
    def validate_training_config_path(cls, v: str) -> str:
        if not (v is None or os.path.exists(v)):
            raise ValueError(f"{v} does not exist")
        return v

    @field_validator("autoregression_extra_steps")
    @classmethod
    def validate_autoregression_extra_steps(
        cls, v: Optional[int], info: ValidationInfo
    ) -> Optional[int]:
        if v is not None and v > 0:
            if not info.data.get("autoregression"):
                raise ValueError(
                    f"'autoregression_extra_steps' can only be larger than 0 if 'autoregression' is true: {info.data.get('autoregression')}"
                )
            if not np.all(
                np.array(sorted(info.data.get("input_columns")))
                == np.array(sorted(info.data.get("target_columns")))
            ):
                raise ValueError(
                    "'autoregression_extra_steps' can only be larger than 0 if 'input_columns' and 'target_columns' are identical"
                )
        return v

    @field_validator("autoregression")
    @classmethod
    def validate_autoregression(cls, v: bool, info: ValidationInfo) -> bool:
        if v and info.data.get("model_type") == "embedding":
            raise ValueError("Autoregression is not possible for embedding models")
        if v and info.data.get("prediction_length") > 1:
            raise ValueError(
                "Autoregressive inference is not possible for models with prediction_length > 1"
            )
        if v and not np.all(
            np.array(sorted(info.data.get("input_columns")))
            == np.array(sorted(info.data.get("target_columns")))
        ):
            raise ValueError(
                "Autoregressive inference with non-identical 'input_columns' and 'target_columns' is not supported"
            )
        return v

    @field_validator("data_path")
    @classmethod
    def validate_data_path(cls, v: str, info: ValidationInfo) -> str:
        if isinstance(v, str):
            v2 = normalize_path(v, info.data.get("project_root"))
            if not os.path.exists(v2):
                raise ValueError(f"{v2} does not exist")
        if isinstance(v, list):
            for vv in v:
                v2 = normalize_path(vv, info.data.get("project_root"))
                if not os.path.exists(v2):
                    raise ValueError(f"{v2} does not exist")
        return v

    @field_validator("read_format", "write_format")
    @classmethod
    def validate_format(cls, v: str) -> str:
        if v not in ["csv", "parquet", "pt"]:
            raise ValueError("Currently only 'csv', 'parquet' and 'pt' are supported")
        return v

    @field_validator("target_column_types")
    @classmethod
    def validate_target_column_types(cls, v: dict, info: ValidationInfo) -> dict:
        if not all(vv in ["categorical", "real"] for vv in v.values()):
            raise ValueError(
                "Target column types must be either 'categorical' or 'real'"
            )
        if list(v.keys()) != info.data.get("target_columns", []):
            raise ValueError(
                "target_columns and target_column_types must contain the same keys in the same order"
            )
        return v

    @field_validator("map_to_id")
    @classmethod
    def validate_map_to_id(cls, v: bool, info: ValidationInfo) -> bool:
        if v and not any(
            vv == "categorical"
            for vv in info.data.get("target_column_types", {}).values()
        ):
            raise ValueError(
                "map_to_id can only be True if at least one target variable is categorical"
            )
        return v

    @field_validator("distributed")
    @classmethod
    def validate_distributed_inference(cls, v: bool, info: ValidationInfo) -> bool:
        if v and info.data.get("read_format") != "pt":
            raise ValueError(
                "Distributed inference is only supported for preprocessed '.pt' files. Please set read_format to 'pt'."
            )
        return v

    def __init__(self, **data):
        super().__init__(**data)
        columns_ordered = list(self.column_types.keys())
        columns_ordered_filtered = [
            c for c in columns_ordered if c in self.target_columns
        ]
        if columns_ordered_filtered != self.target_columns:
            raise ValueError(
                f"target_columns must follow the column order of column_types: "
                f"{columns_ordered_filtered} != {self.target_columns}"
            )
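

# --- Illustrative config sketch (not part of the library source) -------------
# A minimal example of the kind of YAML document load_inferer_config consumes.
# All file names and values below are hypothetical assumptions; only the keys
# correspond to InfererModel fields. The document is merely parsed here (using
# the yaml import at the top of this module) to show which keys it supplies;
# constructing InfererModel itself would additionally require the referenced
# paths to exist.
_EXAMPLE_INFER_YAML = """
project_root: .                          # hypothetical project directory
metadata_config_path: configs/metadata.json
model_path: models/sequifier-model.pt    # hypothetical checkpoint
model_type: generative
training_config_path: configs/train.yaml
read_format: parquet
write_format: csv
input_columns: [itemId, price]
target_columns: [itemId]
target_column_types:
  itemId: categorical
map_to_id: true
device: cpu
seq_length: 32
inference_batch_size: 64
"""

if __name__ == "__main__":
    parsed = yaml.safe_load(_EXAMPLE_INFER_YAML)
    # Keys not present here (data_path, column_types, categorical_columns,
    # real_columns, seed, ...) are filled in by load_inferer_config from the
    # metadata config and its own defaults before InfererModel is constructed.
    print(sorted(parsed))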