# Source code for sequifier.config.preprocess_config

import os
from typing import Optional

import numpy as np
import yaml
from beartype import beartype
from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator

from sequifier.helpers import try_catch_excess_keys


@beartype
def load_preprocessor_config(
    config_path: str, args_config: dict
) -> "PreprocessorModel":
    """
    Load preprocessor configuration from a YAML file and update it with args_config.

    Entries in ``args_config`` override keys of the same name read from the
    YAML file. If neither source provides ``seed``, it defaults to 1010.

    Args:
        config_path: Path to the YAML configuration file.
        args_config: Dictionary containing additional configuration arguments.
            These take precedence over values from the YAML file.

    Returns:
        PreprocessorModel instance with loaded configuration.

    Raises:
        ValueError: If the YAML file's top level is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config_values = yaml.safe_load(f)

    # An empty YAML document parses to None; treat it as an empty mapping so
    # that args_config alone can still supply the configuration.
    if config_values is None:
        config_values = {}
    elif not isinstance(config_values, dict):
        raise ValueError(
            f"Config file {config_path} must contain a mapping at its top level, "
            f"got {type(config_values).__name__}"
        )

    config_values.update(args_config)

    config_values.setdefault("seed", 1010)

    return try_catch_excess_keys(config_path, PreprocessorModel, config_values)


class PreprocessorModel(BaseModel):
    """
    Pydantic model for preprocessor configuration.

    Attributes:
        project_root: The path to the sequifier project directory.
        data_path: The path to the input data file or folder. Must exist.
        read_format: The file type of the input data. One of 'csv', 'parquet', 'pt'.
        write_format: The file type for the preprocessed output data.
            One of 'csv', 'parquet', 'pt'.
        merge_output: If True, combines all preprocessed data into a single file.
            Must be False when write_format is 'pt' and True otherwise.
        selected_columns: A list of columns to be included in the preprocessing.
            If None, all columns are used.
        split_ratios: A list of floats that define the relative sizes of data
            splits (e.g., for train, validation, test). Every ratio must be
            positive and together they must sum to 1.0.
        seq_length: The sequence length for the model inputs.
        stride_by_split: A list of step sizes for creating subsequences within
            each data split, one per entry of split_ratios. When not provided
            (or None), every split uses a stride equal to seq_length.
        max_rows: The maximum number of input rows to process. If None, all
            rows are processed.
        seed: A random seed for reproducibility.
        n_cores: The number of CPU cores to use for parallel processing. If
            None, it uses the available CPU cores.
        batches_per_file: The number of batches to process per file. Must be >= 1.
        process_by_file: A flag to indicate if processing should be done file by file.
        continue_preprocessing: Continue a preprocessing job that was
            interrupted while writing to the temp folder. Only allowed for
            folder inputs, not a single '.csv' or '.parquet' file.
        subsequence_start_mode: "distribute" to minimize max subsequence
            overlap, or "exact".
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    # NOTE: field order matters — cross-field validators below read earlier
    # fields through info.data, which only holds already-validated fields.
    project_root: str
    data_path: str
    read_format: str = "csv"
    write_format: str = "parquet"
    merge_output: bool = True
    selected_columns: Optional[list[str]] = None
    split_ratios: list[float]
    seq_length: int
    stride_by_split: Optional[list[int]] = None
    max_rows: Optional[int] = None
    seed: int
    n_cores: Optional[int] = None
    batches_per_file: int = 1024
    process_by_file: bool = True
    continue_preprocessing: bool = False
    subsequence_start_mode: str = "distribute"

    @field_validator("data_path")
    @classmethod
    def validate_data_path(cls, v: str) -> str:
        """Ensure the input path exists on disk."""
        if not os.path.exists(v):
            raise ValueError(f"{v} does not exist")
        return v

    @field_validator("read_format", "write_format")
    @classmethod
    def validate_format(cls, v: str) -> str:
        """Restrict read/write formats to the supported set."""
        supported_formats = ["csv", "parquet", "pt"]
        if v not in supported_formats:
            raise ValueError(
                f"Currently only {', '.join(supported_formats)} are supported"
            )
        return v

    @field_validator("merge_output")
    @classmethod
    def validate_format2(cls, v: bool, info: ValidationInfo) -> bool:
        """Check merge_output is consistent with write_format.

        NOTE(review): pydantic does not run field validators on default values
        unless validate_default=True, so this only fires when merge_output is
        passed explicitly (write_format='pt' with the default True is not
        rejected here) — confirm that is intended.
        """
        write_format = info.data.get("write_format")
        # 'pt' output is written as separate files and cannot be combined.
        if write_format == "pt" and v is True:
            raise ValueError(
                "With write_format 'pt', merge_output must be set to False"
            )
        # 'parquet' or 'csv' outputs must be combined into a single file.
        if write_format != "pt" and v is False:
            raise ValueError(
                f"With write_format '{write_format}', merge_output must be set to True. "
                "Only 'pt' format supports uncombined (split) output."
            )
        return v

    @field_validator("split_ratios")
    @classmethod
    def validate_proportions_sum(cls, v: list[float]) -> list[float]:
        """Require split ratios to sum to 1.0 (within float tolerance) and be positive."""
        if not np.isclose(np.sum(v), 1.0):
            raise ValueError(f"split_ratios must sum to 1.0, but sums to {np.sum(v)}")
        if not all(p > 0 for p in v):
            raise ValueError(f"All split_ratios must be positive: {v}")
        return v

    @field_validator("stride_by_split")
    @classmethod
    def validate_step_sizes(
        cls, v: Optional[list[int]], info: ValidationInfo
    ) -> list[int]:
        """Require one positive stride per split."""
        split_ratios = info.data.get("split_ratios")
        if split_ratios is None:
            raise ValueError("split_ratios must be set to validate stride_by_split")
        if not isinstance(v, list):
            raise ValueError("stride_by_split should be a list after __init__")
        if len(v) != len(split_ratios):
            raise ValueError(
                f"Length of stride_by_split ({len(v)}) must match length of "
                f"split_ratios ({len(split_ratios)})"
            )
        if not all(step > 0 for step in v):
            raise ValueError(f"All stride_by_split must be positive integers: {v}")
        return v

    @field_validator("batches_per_file")
    @classmethod
    def validate_batches_per_file(cls, v: int) -> int:
        """Require at least one batch per file."""
        if v < 1:
            raise ValueError("batches_per_file must be a positive integer")
        return v

    @field_validator("continue_preprocessing")
    @classmethod
    def validate_continue_preprocessing(cls, v: bool, info: ValidationInfo) -> bool:
        """Reject continue_preprocessing=True for single .csv/.parquet file inputs."""
        data_path = info.data.get("data_path")
        # data_path is missing from info.data when its own validation failed;
        # that error is reported separately, so skip this cross-field check.
        if v and data_path is not None:
            # Compare the file extension (the previous check compared the
            # whole list returned by str.split, which never matched).
            extension = os.path.splitext(data_path)[1]
            if extension in (".csv", ".parquet"):
                raise ValueError(
                    "'continue_preprocessing' can only be set to true for folder inputs, not single files "
                )
        return v

    @field_validator("subsequence_start_mode")
    @classmethod
    def validate_subsequence_start_mode(cls, v: str) -> str:
        """Restrict subsequence_start_mode to the supported modes."""
        if v not in ["distribute", "exact"]:
            raise ValueError(
                "subsequence_start_mode must be one of 'distribute', 'exact'"
            )
        return v

    def __init__(self, **kwargs):
        """Fill in the default stride_by_split before pydantic validation.

        When stride_by_split is absent or None, every split gets a stride equal
        to seq_length. If seq_length or split_ratios is missing (or
        split_ratios is not a list), no default is computed and pydantic
        reports the missing/invalid field instead of a raw KeyError.
        """
        seq_length = kwargs.get("seq_length")
        split_ratios = kwargs.get("split_ratios")
        if (
            kwargs.get("stride_by_split") is None
            and seq_length is not None
            and isinstance(split_ratios, (list, tuple))
        ):
            kwargs["stride_by_split"] = [seq_length] * len(split_ratios)
        super().__init__(**kwargs)