Source code for sequifier.config.preprocess_config

import os
from typing import Optional

import numpy as np
import yaml
from beartype import beartype
from pydantic import BaseModel, validator


@beartype
def load_preprocessor_config(
    config_path: str, args_config: dict
) -> "PreprocessorModel":
    """
    Load preprocessor configuration from a YAML file and update it with args_config.

    Args:
        config_path: Path to the YAML configuration file.
        args_config: Dictionary containing additional configuration arguments.

    Returns:
        PreprocessorModel instance with the loaded configuration.
    """
    with open(config_path, "r") as f:
        config_values = yaml.safe_load(f)
    config_values.update(args_config)
    return PreprocessorModel(**config_values)
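

# Example usage (illustrative only, not part of the original module): load a
# config file and override the seed from command-line arguments. The YAML path
# below is a hypothetical placeholder.
#
#     config = load_preprocessor_config("configs/preprocess.yaml", {"seed": 42})
#     config.read_format   # "csv" unless overridden in the YAML
#     config.seq_length    # one larger than the configured value; see __init__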


class PreprocessorModel(BaseModel):
    """
    Pydantic model for preprocessor configuration.

    Attributes:
        project_path: The path to the sequifier project directory.
        data_path: The path to the input data file.
        read_format: The file type of the input data: 'csv' or 'parquet'.
        write_format: The file type for the preprocessed output data: 'csv',
            'parquet', or 'pt'.
        combine_into_single_file: If True, combines all preprocessed data into a
            single file. Must be False when write_format is 'pt'.
        selected_columns: A list of columns to include in the preprocessing. If
            None, all columns are used.
        group_proportions: A list of floats defining the relative sizes of the
            data splits (e.g., for train, validation, test). The proportions
            must sum to 1.0.
        seq_length: The sequence length for the model inputs.
        seq_step_sizes: A list of step sizes for creating subsequences within
            each data split. If None, defaults to seq_length for every split.
        max_rows: The maximum number of input rows to process. If None, all
            rows are processed.
        seed: A random seed for reproducibility.
        n_cores: The number of CPU cores to use for parallel processing. If
            None, all available CPU cores are used.
        batches_per_file: The number of batches to process per file.
        process_by_file: A flag indicating whether processing should be done
            file by file.
    """

    project_path: str
    data_path: str
    read_format: str = "csv"
    write_format: str = "parquet"
    combine_into_single_file: bool = True
    selected_columns: Optional[list[str]] = None
    group_proportions: list[float]
    seq_length: int
    seq_step_sizes: Optional[list[int]] = None
    max_rows: Optional[int] = None
    seed: int
    n_cores: Optional[int] = None
    batches_per_file: int = 1024
    process_by_file: bool = True
@validator("data_path", always=True)
def validate_data_path(cls, v: str) -> str:
if not os.path.exists(v):
raise ValueError(f"{v} does not exist")
return v
@validator("read_format", "write_format", always=True)
def validate_format(cls, v: str) -> str:
supported_formats = ["csv", "parquet", "pt"]
if v not in supported_formats:
raise ValueError(
f"Currently only {', '.join(supported_formats)} are supported"
)
return v
@validator("combine_into_single_file", always=True)
def validate_format2(cls, v: bool, values: dict):
if values["write_format"] == "pt" and v is True:
raise ValueError(
"With write_format 'pt', combine_into_single_file must be set to False"
)
return v
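    # For example (illustrative values), this combination is rejected:
    #
    #     PreprocessorModel(..., write_format="pt", combine_into_single_file=True)
    #     # -> ValidationError: With write_format 'pt',
    #     #    combine_into_single_file must be set to False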
@validator("group_proportions")
def validate_proportions_sum(cls, v: list[float]) -> list[float]:
if not np.isclose(np.sum(v), 1.0):
raise ValueError(
f"group_proportions must sum to 1.0, but sums to {np.sum(v)}"
)
if not all(p > 0 for p in v):
raise ValueError(f"All group_proportions must be positive: {v}")
return v
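    # For example (illustrative values): [0.8, 0.1, 0.1] passes, while
    # [0.8, 0.1] (sums to 0.9) and [1.1, -0.1] (negative entry) both raise
    # a ValidationError.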
@validator("seq_step_sizes", always=True)
def validate_step_sizes(cls, v: Optional[list[int]], values: dict) -> list[int]:
group_proportions = values.get("group_proportions")
assert (
group_proportions is not None
), "group_proportions must be set to validate seq_step_sizes"
assert isinstance(v, list), "seq_step_sizes should be a list after __init__"
if len(v) != len(group_proportions):
raise ValueError(
f"Length of seq_step_sizes ({len(v)}) must match length of "
f"group_proportions ({len(group_proportions)})"
)
if not all(step > 0 for step in v):
raise ValueError(f"All seq_step_sizes must be positive integers: {v}")
return v
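    # For example (illustrative values): with group_proportions=[0.8, 0.1, 0.1],
    # seq_step_sizes=[16, 16, 16] passes, while [16, 16] (length mismatch) and
    # [16, 0, 16] (non-positive entry) both raise a ValidationError.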
@validator("batches_per_file")
def validate_batches_per_file(cls, v: int) -> int:
if v < 1:
raise ValueError("batches_per_file must be a positive integer")
return v

    def __init__(self, **kwargs):
        # Default each split's step size to the raw (un-incremented) sequence
        # length. The explicit None check also covers `seq_step_sizes: null` in
        # the YAML, which dict.get with a default would pass through as None.
        if kwargs.get("seq_step_sizes") is None:
            kwargs["seq_step_sizes"] = [kwargs["seq_length"]] * len(
                kwargs["group_proportions"]
            )
        # Sequences are stored one element longer than the configured seq_length.
        kwargs["seq_length"] += 1
        super().__init__(**kwargs)
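

# Minimal construction sketch (illustrative, not part of the original module).
# A temporary file stands in for real input data so that validate_data_path
# passes; all other values are placeholders.
if __name__ == "__main__":
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".csv") as tmp:
        model = PreprocessorModel(
            project_path=".",
            data_path=tmp.name,
            group_proportions=[0.8, 0.1, 0.1],
            seq_length=16,
            seed=101,
        )
        print(model.seq_length)      # 17: incremented by one in __init__
        print(model.seq_step_sizes)  # [16, 16, 16]: defaults to the raw seq_length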