Source code for sequifier.train

import copy
import functools
import glob
import math
import os
import time
import uuid
import warnings
from datetime import timedelta
from typing import Any, Optional, Union

import numpy as np
import polars as pl
import torch
import torch._dynamo
import torch.distributed as dist
import torch.multiprocessing as mp
from beartype import beartype
from torch import Tensor, nn
from torch.amp import GradScaler
from torch.distributed.fsdp import (
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.nn import ModuleDict
from torch.nn.functional import one_hot
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

# NOTE(review): tells TorchDynamo to suppress errors raised while compiling
# instead of propagating them; confirm this is intended, since it can hide
# compilation problems. It is placed between import groups, which is
# presumably why the imports below carry `# noqa: E402`.
torch._dynamo.config.suppress_errors = True

from sequifier.config.train_config import TrainModel, load_train_config  # noqa: E402
from sequifier.helpers import (  # noqa: E402
    conditional_beartype,
    configure_determinism,
    configure_logger,
    construct_index_maps,
    get_torch_dtype,
)
from sequifier.io.sequifier_dataset_from_file import (  # noqa: E402
    SequifierDatasetFromFile,
)
from sequifier.io.sequifier_dataset_from_folder import (  # noqa: E402
    SequifierDatasetFromFolder,
)
from sequifier.io.sequifier_dataset_from_folder_lazy import (  # noqa: E402
    SequifierDatasetFromFolderLazy,
)
from sequifier.model.layers import RMSNorm, SequifierEncoderLayer  # noqa: E402
from sequifier.optimizers.optimizers import get_optimizer_class  # noqa: E402


@beartype
def setup(rank: int, world_size: int, backend: str = "nccl"):
    """Sets up the distributed training environment.

    Environment variables that are already set (e.g. by a launcher or the
    user's shell) are respected; only missing ones receive defaults.

    Args:
        rank: The global rank of the current process.
        world_size: The total number of processes.
        backend: The distributed backend to use (e.g. "nccl").
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Verbose NCCL / torch.distributed diagnostics by default. These used to be
    # assigned unconditionally, which overrode any level the user had chosen;
    # setdefault keeps them configurable from the environment.
    os.environ.setdefault("NCCL_DEBUG", "INFO")
    os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")

    if not dist.is_initialized():
        # Process-group timeout in seconds, overridable via NCCL_TIMEOUT.
        timeout_sec = int(os.environ.get("NCCL_TIMEOUT", 1800))
        dist.init_process_group(
            backend,
            rank=rank,
            world_size=world_size,
            timeout=timedelta(seconds=timeout_sec),
        )
def cleanup():
    """Cleans up the distributed training environment.

    Destroys the default process group if one exists. It is a no-op when no
    process group was initialized, so calling it twice (or without a prior
    `setup`) does not raise.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
@beartype
def create_dummy_data(config: TrainModel, local_rank: int) -> dict[str, Tensor]:
    """Builds a tiny synthetic batch of ones for a warm-up forward pass.

    One tensor of shape (2, config.seq_length) is created per input column on
    device ``local_rank``: int64 for categorical columns, float32 otherwise.

    Args:
        config: The training configuration.
        local_rank: Device index on which the tensors are allocated.

    Returns:
        A dictionary mapping each input column name to its dummy tensor.
    """
    input_columns = config.input_columns
    if input_columns is None:
        # TransformerModel treats input_columns=None as "use every categorical
        # and real column", so the dummy batch must cover all of them.
        input_columns = list(config.categorical_columns) + list(config.real_columns)
    categorical = set(config.categorical_columns)

    dummy_data = {}
    for col in input_columns:
        dtype = torch.int64 if col in categorical else torch.float32
        dummy_data[col] = torch.ones(
            (2, config.seq_length), dtype=dtype, device=local_rank
        )
    return dummy_data
def _build_data_loader(dataset: Any, config: TrainModel) -> DataLoader:
    """Wraps a sequifier dataset in a DataLoader with the shared loader settings."""
    num_workers = config.training_spec.num_workers
    return DataLoader(
        dataset,
        batch_size=None,  # Batching is handled natively by the IterableDataset
        sampler=None,  # Sharding is handled natively by the IterableDataset
        num_workers=num_workers,
        pin_memory=config.training_spec.device not in ["mps", "cpu"],
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=(num_workers > 0),
    )


def _resume_position(checkpoint: dict, n_batches: int) -> tuple[int, int]:
    """Returns the (epoch, batch) at which training resumes after `checkpoint`.

    A checkpoint written at (or past) the last batch of its epoch resumes at
    batch 0 of the next epoch; otherwise training continues with the batch
    after the saved one in the same epoch.
    """
    if checkpoint["batch"] + 1 >= n_batches:
        return checkpoint["epoch"] + 1, 0
    return checkpoint["epoch"], checkpoint["batch"] + 1


@beartype
def train_worker(
    local_rank: int,
    world_size: int,
    config: TrainModel,
    from_folder: bool,
    global_rank: int,
):
    """The worker function for (distributed) training.

    Builds the datasets and loaders, constructs the model, optionally resumes
    from the latest checkpoint, wraps the model for FSDP/DDP and
    torch.compile, and runs the training loop.

    Args:
        local_rank: The rank of the process on its node; used as CUDA device index.
        world_size: The total number of processes.
        config: The training configuration.
        from_folder: Whether to load data from a folder (e.g., preprocessed .pt
            files) or a single file (e.g., .parquet).
        global_rank: The rank of the process across all nodes.

    Raises:
        ValueError: If distributed training is requested with a single-file dataset.
    """
    logger = configure_logger(config.project_root, config.model_name, global_rank)

    if config.training_spec.distributed:
        setup(global_rank, world_size, config.training_spec.backend)
    if config.training_spec.device.startswith("cuda"):
        torch.cuda.set_device(local_rank)

    # 1. Datasets and DataLoaders (sharding happens inside the datasets).
    if from_folder:
        if config.training_spec.load_full_data_to_ram:
            train_dataset = SequifierDatasetFromFolder(
                config.training_data_path, config
            )
            valid_dataset = SequifierDatasetFromFolder(
                config.validation_data_path, config
            )
        else:
            train_dataset = SequifierDatasetFromFolderLazy(
                config.training_data_path, config
            )
            valid_dataset = SequifierDatasetFromFolderLazy(
                config.validation_data_path, config
            )
    else:
        if config.training_spec.distributed:
            raise ValueError(
                "Distributed training is not supported with single-file datasets."
            )
        train_dataset = SequifierDatasetFromFile(config.training_data_path, config)
        valid_dataset = SequifierDatasetFromFile(config.validation_data_path, config)

    train_loader = _build_data_loader(train_dataset, config)
    valid_loader = _build_data_loader(valid_dataset, config)

    configure_determinism(config.seed, config.training_spec.enforce_determinism)

    # 2. Model construction and (optional) checkpoint weights.
    model = TransformerModel(config, rank=global_rank, local_rank=local_rank)
    latest_model_path = model._get_latest_model_name()
    pytorch_total_params = sum(p.numel() for p in model.parameters())

    checkpoint = None
    is_fsdp = config.training_spec.fsdp
    resume = bool(config.training_spec.continue_training and latest_model_path)

    if resume:
        # Under FSDP only rank 0 reads the checkpoint; its weights reach the
        # other ranks through sync_module_states=True when wrapping below.
        if not is_fsdp or global_rank == 0:
            checkpoint = torch.load(
                latest_model_path, map_location="cpu", weights_only=False
            )
            epoch = checkpoint["epoch"]
            batch = checkpoint["batch"]
            time_string = f", epoch {epoch}, batch {batch}"
        else:
            time_string = ""

        if checkpoint is not None:
            # Load into the bare TransformerModel *before* any FSDP/DDP or
            # torch.compile wrapping: the saved keys come from the unwrapped
            # model (see _get_full_state_dict), so a wrapped model, whose keys
            # carry "module." / "_orig_mod." prefixes, would reject them.
            model.load_state_dict(checkpoint["model_state_dict"])
            del checkpoint["model_state_dict"]  # free CPU memory early
        logger.info(
            f"[INFO] Resuming training from checkpoint '{latest_model_path}'{time_string}. Total params: {format_number(pytorch_total_params)}"
        )
    else:
        model.start_epoch = 1
        model.start_batch = 0
        logger.info(
            f"[INFO] Initializing new model with {format_number(pytorch_total_params)} parameters."
        )

    # 3. Distributed wrapping.
    if config.training_spec.distributed:
        if is_fsdp:
            auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls={SequifierEncoderLayer},
            )
            strategy_map = {
                "FULL_SHARD": ShardingStrategy.FULL_SHARD,
                "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
                "NO_SHARD": ShardingStrategy.NO_SHARD,
            }
            # Unrecognized names fall back to FULL_SHARD.
            sharding_strategy = strategy_map.get(
                config.training_spec.fsdp_sharding_strategy,
                ShardingStrategy.FULL_SHARD,
            )
            mixed_precision = None
            if config.training_spec.layer_autocast:
                amp_dtype = get_torch_dtype(
                    config.training_spec.layer_type_dtypes.get("linear", "bfloat16")
                    if config.training_spec.layer_type_dtypes
                    else "float32"
                )
                mixed_precision = MixedPrecision(
                    param_dtype=amp_dtype,
                    reduce_dtype=amp_dtype,
                    buffer_dtype=amp_dtype,
                )
            cpu_offload = CPUOffload(
                offload_params=config.training_spec.fsdp_cpu_offload
            )
            model = FSDP(
                model,
                auto_wrap_policy=auto_wrap_policy,
                mixed_precision=mixed_precision,
                cpu_offload=cpu_offload,
                sharding_strategy=sharding_strategy,
                device_id=local_rank,
                use_orig_params=True,
                sync_module_states=True,
            )
            dist.barrier()
        else:
            device_ids = (
                [local_rank] if config.training_spec.device.startswith("cuda") else None
            )
            model = DDP(model, device_ids=device_ids, find_unused_parameters=False)

    if config.training_spec.device.startswith("cuda"):
        model = torch.compile(model)
        # One forward pass on a tiny dummy batch before training (presumably to
        # trigger torch.compile's compilation up front).
        dummy_data = create_dummy_data(config, local_rank)
        if config.training_spec.layer_autocast and not is_fsdp:
            with torch.no_grad(), torch.autocast(
                device_type="cuda", dtype=torch.bfloat16
            ):
                _ = model(dummy_data, False)
        else:
            with torch.no_grad():
                _ = model(dummy_data, False)
        if config.training_spec.distributed:
            dist.barrier()

    # 4. Optimizer. Under FSDP it is built from the wrapped model's parameters.
    unwrapped_model = model.module if config.training_spec.distributed else model
    params_to_optimize = model.parameters() if is_fsdp else unwrapped_model.parameters()
    unwrapped_model.initialize_optimizer(params=params_to_optimize)

    # 5. Optimizer / scheduler state and resume position.
    if resume:
        if is_fsdp:
            # Rank 0 holds the full optimizer state; scatter it over the ranks.
            full_osd = checkpoint["optimizer_state_dict"] if global_rank == 0 else None  # type: ignore
            sharded_osd = FSDP.scatter_full_optim_state_dict(
                full_optim_state_dict=full_osd,
                model=model,
                optim=unwrapped_model.optimizer,
            )
            unwrapped_model.optimizer.load_state_dict(sharded_osd)

            # Only rank 0 knows the position; broadcast it with the small
            # scheduler state.
            if global_rank == 0:
                start_epoch, start_batch = _resume_position(
                    checkpoint, len(train_loader)  # type: ignore
                )
                meta = [
                    start_epoch,
                    start_batch,
                    checkpoint.get("scheduler_state_dict"),  # type: ignore
                ]
            else:
                meta = [None, None, None]
            if config.training_spec.distributed:
                dist.broadcast_object_list(meta, src=0)
            unwrapped_model.start_epoch, unwrapped_model.start_batch, sched_state = meta
            if sched_state is not None:
                unwrapped_model.scheduler.load_state_dict(sched_state)
            if global_rank == 0:
                del checkpoint, full_osd  # Clear remaining CPU memory
            logger.info(
                f"[INFO] Resuming FSDP training from checkpoint '{latest_model_path}'. Total params: {format_number(pytorch_total_params)}"
            )
        else:
            unwrapped_model.optimizer.load_state_dict(
                checkpoint["optimizer_state_dict"]  # type: ignore
            )
            # The checkpoint was loaded to CPU: move optimizer state tensors to
            # the training device, except "step", which is kept on CPU.
            for state in unwrapped_model.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = (
                            v.cpu() if k == "step" else v.to(config.training_spec.device)
                        )
            if "scheduler_state_dict" in checkpoint:  # type: ignore
                unwrapped_model.scheduler.load_state_dict(
                    checkpoint["scheduler_state_dict"]  # type: ignore
                )
            unwrapped_model.start_epoch, unwrapped_model.start_batch = _resume_position(
                checkpoint, len(train_loader)  # type: ignore
            )
            del checkpoint  # no longer needed; release CPU memory

    # 6. Training.
    unwrapped_model.train_model(
        train_loader,
        valid_loader,
        ddp_model=model if config.training_spec.distributed else None,
    )
    if config.training_spec.distributed:
        cleanup()
@beartype
def _mp_train_worker_wrapper(
    local_rank: int, world_size: int, config: TrainModel, from_folder: bool
):
    """Per-process entry point for the single-node ``mp.spawn`` launch.

    The spawned process index arrives as the first positional argument. On a
    single node that index serves as both the local and the global rank.
    """
    single_node_rank = local_rank
    train_worker(
        local_rank,
        world_size,
        config,
        from_folder,
        global_rank=single_node_rank,
    )
@beartype
def train(args: Any, args_config: dict[str, Any]) -> None:
    """The main training entry point.

    Loads the training config and dispatches to one of three launch modes:
    a plain single-process run; a process started by an external launcher
    (torchrun / srun), detected via the RANK and WORLD_SIZE environment
    variables; or a single-node multi-process run started with ``mp.spawn``.

    Args:
        args: Parsed command-line arguments (``config_path`` and
            ``skip_metadata`` are read).
        args_config: Configuration values passed through to ``load_train_config``.
    """
    config = load_train_config(
        args.config_path or "configs/train.yaml", args_config, args.skip_metadata
    )
    from_folder = config.read_format == "pt"

    if not config.training_spec.distributed:
        train_worker(0, 1, config, from_folder, 0)
        return

    env = os.environ
    if "RANK" in env and "WORLD_SIZE" in env:
        # Launched via torchrun / srun for multi-node distributed training
        global_rank = int(env["RANK"])
        launcher_world_size = int(env["WORLD_SIZE"])
        local_rank = int(env.get("LOCAL_RANK", 0))
        train_worker(local_rank, launcher_world_size, config, from_folder, global_rank)
    else:
        # Single-node multi-GPU fallback using mp.spawn
        spawn_count = config.training_spec.world_size
        mp.spawn(
            _mp_train_worker_wrapper,
            args=(spawn_count, config, from_folder),
            nprocs=spawn_count,
            join=True,
        )
@beartype
def format_number(number: Union[int, float, np.float32]) -> str:
    """Format a number in compact scientific notation for display.

    The mantissa is printed with two decimals in a 5-wide field, e.g.
    ``1234 -> " 1.23e3"`` and ``0 -> " 0.00e0"``.

    Args:
        number: The number to format.

    Returns:
        A formatted string representation of the number. NaN is rendered as
        ``"NaN"`` and infinities as ``"Inf"`` / ``"-Inf"``.
    """
    if np.isnan(number):
        return "NaN"
    if np.isinf(number):
        # math.floor(math.log10(inf)) would raise OverflowError.
        return "Inf" if number > 0 else "-Inf"
    if number == 0:
        order_of_magnitude = 0
    else:
        # log10 is exact for powers of ten; math.log(x, 10) returns
        # 2.9999999999999996 for 1000, which produced "10.00e2".
        order_of_magnitude = math.floor(math.log10(abs(number)))
    number_adjusted = number * (10 ** (-order_of_magnitude))
    # Rounding to two decimals can carry the mantissa to 10.00 (e.g. 9.996);
    # bump the exponent so the mantissa stays below 10.
    if abs(round(float(number_adjusted), 2)) >= 10:
        order_of_magnitude += 1
        number_adjusted = number * (10 ** (-order_of_magnitude))
    return f"{number_adjusted:5.2f}e{order_of_magnitude}"
class TransformerEmbeddingModel(nn.Module):
    """A wrapper around the TransformerModel to expose the embedding functionality."""

    def __init__(self, transformer_model: "TransformerModel"):
        """Initializes the TransformerEmbeddingModel.

        Args:
            transformer_model: The TransformerModel to wrap.
        """
        super().__init__()
        self.transformer_model = transformer_model
        # Shares the wrapped model's logger.
        self.logger = self.transformer_model.logger

    @beartype
    def _copy_model(self):
        """Copies the model.

        Creates a deep copy of this wrapper (including the wrapped
        TransformerModel), typically for saving the "best model". The
        ``logger`` attributes are removed before copying (the original
        implementation did so to avoid errors during the copy) and are always
        restored on ``self``, even if copying fails. The copy gets a freshly
        initialized logger on both the inner model and the wrapper.

        Returns:
            A deep copy of the current TransformerEmbeddingModel instance.
        """
        logger_ref = self.transformer_model.logger
        del self.transformer_model.logger
        del self.logger
        try:
            model_copy = copy.deepcopy(self)
        finally:
            self.transformer_model.logger = logger_ref
            self.logger = logger_ref
        model_copy.transformer_model._initialize_log_file()
        # Mirror __init__: the wrapper shares its inner model's logger. Without
        # this the copy had no `logger` attribute at all.
        model_copy.logger = model_copy.transformer_model.logger
        return model_copy

    @conditional_beartype
    def forward(self, src: dict[str, Tensor]):
        """Forward pass for the embedding model.

        Args:
            src: A dictionary mapping column names to input tensors.

        Returns:
            The output of ``TransformerModel.forward_embed`` for ``src``.
        """
        return self.transformer_model.forward_embed(src)
[docs]class TransformerModel(nn.Module): """The main Transformer model for the sequifier. This class implements the Transformer model, including the training and evaluation loops, as well as the export functionality. """
[docs] @beartype def __init__( self, hparams: Any, rank: Optional[int] = None, local_rank: Optional[int] = None ): """Initializes the TransformerModel. Based on the hyperparameters, this initializes: - Embeddings for categorical and real features (self.encoder) - Positional encoders (self.pos_encoder) - The main TransformerEncoder (self.transformer_encoder) - Output decoders for each target column (self.decoder) - Loss functions (self.criterion) - Optimizer (self.optimizer) and scheduler (self.scheduler) Args: hparams: The hyperparameters for the model (e.g., from TrainModel config). rank: The rank of the current process (for distributed training). """ super().__init__() self.project_root = hparams.project_root self.model_type = "Transformer" self.rank = rank self.model_name = hparams.model_name or uuid.uuid4().hex[:8] self._initialize_log_file() self.logger.info(f"--- Starting Training for model: {self.model_name} ---") self.input_columns = hparams.input_columns self.categorical_columns = [ col for col in hparams.categorical_columns if self.input_columns is None or col in self.input_columns ] self.real_columns = [ col for col in hparams.real_columns if self.input_columns is None or col in self.input_columns ] self.logger.info(f"{self.categorical_columns = }") self.logger.info(f"{self.real_columns = }") self.target_columns = hparams.target_columns self.target_column_types = hparams.target_column_types self.loss_weights = hparams.training_spec.loss_weights self.seq_length = hparams.seq_length self.n_classes = hparams.n_classes self.inference_batch_size = hparams.inference_batch_size self.log_interval = hparams.training_spec.log_interval self.class_share_log_columns = hparams.training_spec.class_share_log_columns self.index_maps = construct_index_maps( hparams.id_maps, self.class_share_log_columns, True ) self.export_embedding_model = hparams.export_embedding_model self.export_generative_model = hparams.export_generative_model self.export_onnx = hparams.export_onnx 
self.export_pt = hparams.export_pt self.export_with_dropout = hparams.export_with_dropout self.early_stopping_epochs = hparams.training_spec.early_stopping_epochs self.hparams = hparams self.drop = nn.Dropout(hparams.training_spec.dropout) self.encoder = ModuleDict() self.dim_model = self.hparams.model_spec.dim_model self.initial_embedding_dim = self.hparams.model_spec.initial_embedding_dim self.joint_embedding_dim = hparams.model_spec.joint_embedding_dim if self.joint_embedding_dim is not None: self.joint_embedding_layer = nn.Linear( self.initial_embedding_dim, self.joint_embedding_dim ) else: self.joint_embedding_layer = None self.use_rope = hparams.model_spec.positional_encoding == "rope" if hparams.model_spec.feature_embedding_dims is not None: self.feature_embedding_dims = hparams.model_spec.feature_embedding_dims else: self.feature_embedding_dims = self._get_feature_embedding_dims( self.initial_embedding_dim, self.categorical_columns, self.real_columns ) self.real_columns_with_embedding = [] self.real_columns_direct = [] for col in self.real_columns: if self.feature_embedding_dims[col] > 1: self.encoder[col] = nn.Linear(1, self.feature_embedding_dims[col]) self.real_columns_with_embedding.append(col) else: if self.feature_embedding_dims[col] != 1: raise ValueError( f"Real column {col} without embedding must have feature_embedding_dims=1" ) self.real_columns_direct.append(col) for col, n_classes in self.n_classes.items(): if col in self.categorical_columns: self.encoder[col] = nn.Embedding( n_classes, self.feature_embedding_dims[col] ) if not self.use_rope: self.pos_encoder = ModuleDict() for col in self.real_columns: self.pos_encoder[col] = nn.Embedding( self.seq_length, self.feature_embedding_dims[col] ) for col, n_classes in self.n_classes.items(): if col in self.categorical_columns: self.pos_encoder[col] = nn.Embedding( self.seq_length, self.feature_embedding_dims[col] ) else: self.pos_encoder = None self.layers = nn.ModuleList( [ SequifierEncoderLayer( 
hparams.model_spec, self.dim_model, hparams.model_spec.n_head, hparams.model_spec.dim_feedforward, hparams.training_spec.dropout, hparams.seq_length, ) for _ in range(hparams.model_spec.num_layers) ] ) if hparams.model_spec.norm_first: NormClass = ( RMSNorm if hparams.model_spec.normalization == "rmsnorm" else nn.LayerNorm ) self.final_norm = NormClass(self.dim_model) else: self.final_norm = nn.Identity() self.prediction_length = hparams.model_spec.prediction_length self.decoder = ModuleDict() self.softmax = ModuleDict() for target_column, target_column_type in self.target_column_types.items(): if target_column_type == "categorical": self.decoder[target_column] = nn.Linear( self.dim_model, self.n_classes[target_column], ) self.softmax[target_column] = nn.LogSoftmax(dim=-1) elif target_column_type == "real": self.decoder[target_column] = nn.Linear(self.dim_model, 1) else: raise ValueError( f"Target column type {target_column_type} not in ['categorical', 'real']" ) self.device = hparams.training_spec.device self.device_max_concat_length = hparams.training_spec.device_max_concat_length if hparams.training_spec.device.startswith("cuda"): if local_rank is not None: self.device = f"cuda:{local_rank}" elif self.rank is not None: # Backwards compatibility self.device = f"cuda:{self.rank}" else: self.device = hparams.training_spec.device else: self.device = hparams.training_spec.device if not self.hparams.training_spec.fsdp: self.to(self.device) self.criterion = self._init_criterion(hparams=hparams) self.batch_size = hparams.training_spec.batch_size self.accumulation_steps = hparams.training_spec.accumulation_steps self.register_buffer( "src_mask", self._generate_square_subsequent_mask(self.seq_length), persistent=False, # Optional: prevents the mask from being saved in your checkpoints ) self._init_weights() self.scheduler_step_on = hparams.training_spec.scheduler_step_on self.save_interval_epochs = hparams.training_spec.save_interval_epochs 
self.save_latest_interval_minutes = ( hparams.training_spec.save_latest_interval_minutes ) self.save_batch_interval_minutes = ( hparams.training_spec.save_batch_interval_minutes ) self.save_batch_interval_minutes_val_loss = ( hparams.training_spec.save_batch_interval_minutes_val_loss ) self.continue_training = hparams.training_spec.continue_training use_scaler = False if hparams.training_spec.layer_type_dtypes: if "float16" in hparams.training_spec.layer_type_dtypes.values(): use_scaler = True if self.hparams.training_spec.fsdp: self.scaler = ShardedGradScaler(enabled=use_scaler) else: self.scaler = GradScaler( device=self.device.split(":")[0], enabled=use_scaler ) self._apply_layer_dtypes()
[docs] @beartype def initialize_optimizer(self, params: Any = None) -> None: """Initializes the optimizer and scheduler.""" if params is None: params = self.parameters() opt_kwargs = dict(self.hparams.training_spec.optimizer) self.optimizer = self._get_optimizer( params=params, **self._filter_key(opt_kwargs, "name") ) sched_kwargs = dict(self.hparams.training_spec.scheduler) self.scheduler = self._get_scheduler(**self._filter_key(sched_kwargs, "name")) self.scheduler_step_on = self.hparams.training_spec.scheduler_step_on
@beartype def _apply_layer_dtypes(self) -> None: """Casts specific layer types to configured dtypes (e.g., bfloat16, float8).""" layer_config = self.hparams.training_spec.layer_type_dtypes if not layer_config: return self.logger.info(f"[INFO] Applying custom layer dtypes: {layer_config}") # Iterate over all sub-modules and cast based on type for name, module in self.named_modules(): # Linear Layers if isinstance(module, nn.Linear): is_decoder = any(module is m for m in self.decoder.values()) if is_decoder and "decoder" in layer_config: module.to(dtype=get_torch_dtype(layer_config["decoder"])) elif "linear" in layer_config: module.to(dtype=get_torch_dtype(layer_config["linear"])) # Embeddings elif isinstance(module, nn.Embedding) and "embedding" in layer_config: target_dtype = get_torch_dtype(layer_config["embedding"]) module.to(dtype=target_dtype) # Normalization (RMSNorm, LayerNorm) elif isinstance(module, (nn.LayerNorm, RMSNorm)) and "norm" in layer_config: target_dtype = get_torch_dtype(layer_config["norm"]) module.to(dtype=target_dtype) if "linear" in layer_config: target_dtype = get_torch_dtype(layer_config["linear"]) for criterion in self.criterion.values(): if hasattr(criterion, "weight") and criterion.weight is not None: criterion.weight.data = criterion.weight.data.to(dtype=target_dtype) @beartype def _init_criterion(self, hparams: Any) -> ModuleDict: """Initializes the criterion (loss function) for each target column. Args: hparams: The hyperparameters for the model, used to find criterion names and class weights. Returns: A dictionary mapping target column names to their loss function instances. 
""" criterion = ModuleDict() for target_column in self.target_columns: criterion_name = hparams.training_spec.criterion[target_column] if hasattr(torch.nn, criterion_name): criterion_class = getattr(torch.nn, criterion_name) else: raise ValueError(f"Criterion {criterion_name} not found in torch.nn") criterion_kwargs = {} if ( hparams.training_spec.class_weights is not None and target_column in hparams.training_spec.class_weights ): criterion_kwargs["weight"] = Tensor( hparams.training_spec.class_weights[target_column] ) criterion_kwargs["reduction"] = "none" criterion[target_column] = criterion_class(**criterion_kwargs) return criterion @beartype def _get_feature_embedding_dims( self, embedding_size: int, categorical_columns: list[str], real_columns: list[str], ) -> dict[str, int]: """Calculates the embedding dimension for each column. This attempts to distribute the total `embedding_size` across all input columns. Args: embedding_size: The total embedding dimension (initial_embedding_dim). categorical_columns: List of categorical column names. real_columns: List of real-valued column names. Returns: A dictionary mapping column names to their calculated embedding dimension. """ if not (len(categorical_columns) + len(real_columns)) > 0: raise ValueError("No columns found") if len(categorical_columns) == 0 and len(real_columns) > 0: if embedding_size < len(real_columns): raise ValueError( f"initial_embedding_dim ({embedding_size}) is smaller than the number of real input columns ({len(real_columns)}). " "Cannot allocate at least 1 dimension per column." 
) feature_embedding_dims = {col: 1 for col in real_columns} column_index = dict(enumerate(real_columns)) remaining_dims = embedding_size - len(real_columns) for i in range(remaining_dims): j = i % len(real_columns) feature_embedding_dims[column_index[j]] += 1 if sum(feature_embedding_dims.values()) != embedding_size: raise ValueError( f"Auto-calculated embedding dimensions ({sum(feature_embedding_dims.values())}) do not sum to initial_embedding_dim ({embedding_size})." ) elif len(real_columns) == 0 and len(categorical_columns) > 0: if embedding_size < len(categorical_columns): raise ValueError( f"initial_embedding_dim ({embedding_size}) is smaller than the number of categorical columns ({len(categorical_columns)}). " "Resulting embedding dimension would be 0." ) if (embedding_size % len(categorical_columns)) != 0: raise ValueError( f"initial_embedding_dim ({embedding_size}) must be divisible by n_categorical ({len(categorical_columns)})" ) dim_model_comp = embedding_size // len(categorical_columns) feature_embedding_dims = { col: dim_model_comp for col in categorical_columns } else: raise ValueError( "If both real and categorical variables are present, feature_embedding_dims config value must be set" ) return feature_embedding_dims @staticmethod def _generate_square_subsequent_mask(sz: int) -> Tensor: """Generates an upper-triangular matrix of -inf, with zeros on diag. This is used as a mask to prevent attention to future tokens in the transformer. Args: sz: The size of the square mask (sequence length). Returns: A square tensor of shape (sz, sz) with -inf in the upper triangle. """ return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1) @staticmethod def _filter_key(dict_: dict[str, Any], key: str) -> dict[str, Any]: """Filters a key from a dictionary. Args: dict_: The dictionary to filter. key: The key to remove. Returns: A new dictionary without the specified key. 
""" return {k: v for k, v in dict_.items() if k != key} @beartype def _init_weights(self) -> None: """Initializes the weights of the model.""" init_std = 0.02 for col in self.categorical_columns: self.encoder[col].weight.data.normal_(mean=0.0, std=init_std) for target_column in self.target_columns: self.decoder[target_column].bias.data.zero_() self.decoder[target_column].weight.data.normal_(mean=0.0, std=init_std) if self.pos_encoder is not None: for col_name in self.pos_encoder: self.pos_encoder[col_name].weight.data.normal_(mean=0.0, std=init_std) if self.joint_embedding_layer is not None: self.joint_embedding_layer.weight.data.normal_(mean=0.0, std=init_std) if self.joint_embedding_layer.bias is not None: self.joint_embedding_layer.bias.data.zero_() @conditional_beartype def _recursive_concat(self, srcs: list[Tensor]): """Recursively concatenates a list of tensors. This is used to avoid device-specific limits on the number of tensors that can be concatenated at once by breaking the operation into smaller, recursive chunks. Args: srcs: A list of tensors to concatenate along dimension 2. Returns: A single tensor resulting from the recursive concatenation. """ if len(srcs) <= self.device_max_concat_length: return torch.cat(srcs, 2) else: srcs_inner = [] for start in range(0, len(srcs), self.device_max_concat_length): src = self._recursive_concat( srcs[start : start + self.device_max_concat_length] ) srcs_inner.append(src) return self._recursive_concat(srcs_inner)
[docs] @conditional_beartype def forward_inner(self, src: dict[str, Tensor]) -> Tensor: """The inner forward pass of the model. This handles embedding lookup, positional encoding, and passing the combined tensor through the transformer encoder. Args: src: A dictionary mapping column names to input tensors (batch_size, seq_length). Returns: The raw output tensor from the TransformerEncoder (seq_length, batch_size, dim_model). """ srcs = [] for col in self.categorical_columns: src_t = self.encoder[col](src[col].T) * math.sqrt( self.initial_embedding_dim ) if not self.use_rope: pos = ( torch.arange( 0, self.seq_length, dtype=torch.long, device=src_t.device ) .repeat(src_t.shape[1], 1) .T ) src_p = self.pos_encoder[col](pos) # type: ignore src_c = self.drop(src_t + src_p) else: src_c = self.drop(src_t) srcs.append(src_c) for col in self.real_columns: if col in self.real_columns_direct: target_dtype = self.layers[0].ff.get_first_layer_dtype() src_t = src[col].T.unsqueeze(2).to(dtype=target_dtype) * math.sqrt( self.initial_embedding_dim ) else: assert col in self.real_columns_with_embedding layer = self.encoder[col] inp = src[col].T[:, :, None].to(dtype=layer.weight.dtype) src_t = layer(inp) * math.sqrt(self.initial_embedding_dim) if not self.use_rope: pos = ( torch.arange( 0, self.seq_length, dtype=torch.long, device=src_t.device ) .repeat(src_t.shape[1], 1) .T ) src_p = self.pos_encoder[col](pos) # type: ignore src_c = self.drop(src_t + src_p) else: src_c = self.drop(src_t) srcs.append(src_c) src2 = self._recursive_concat(srcs) src2 = src2.transpose(0, 1) if self.joint_embedding_layer is not None: src2 = self.joint_embedding_layer(src2) mask = self.src_mask.to(dtype=src2.dtype) for layer in self.layers: src2 = layer(src2, src_mask=mask) src2 = self.final_norm(src2) return src2.transpose(0, 1)
[docs] @conditional_beartype def forward_embed(self, src: dict[str, Tensor]) -> Tensor: """Forward pass for the embedding model. This returns only the embedding from the *last* token in the sequence. Args: src: A dictionary mapping column names to input tensors (batch_size, seq_length). Returns: The embedding tensor for the last token (batch_size, dim_model). """ return self.forward_inner(src)[-self.prediction_length :, :, :]
[docs] @conditional_beartype def forward_train(self, src: dict[str, Tensor]) -> dict[str, Tensor]: """Forward pass for training. This runs the inner forward pass and then applies the appropriate decoder for each target column. Args: src: A dictionary mapping column names to input tensors (batch_size, seq_length). Returns: A dictionary mapping target column names to their raw output (logit) tensors (seq_length, batch_size, n_classes/1). """ output = self.forward_inner(src) output = { target_column: self.decode(target_column, output) for target_column in self.target_columns } return output
[docs] @conditional_beartype def decode(self, target_column: str, output: Tensor) -> Tensor: """Decodes the output of the transformer encoder. Applies the appropriate final linear layer for a given target column. Args: target_column: The name of the target column to decode. output: The raw output tensor from the TransformerEncoder (seq_length, batch_size, dim_model). Returns: The decoded output (logits or real value) for the target column (seq_length, batch_size, n_classes/1). """ target_dtype = self.decoder[target_column].weight.dtype decoded = self.decoder[target_column](output.to(target_dtype)).to(torch.float32) return decoded
[docs] @conditional_beartype def apply_softmax(self, target_column: str, output: Tensor) -> Tensor: """Applies softmax to the output of the decoder. If the target is real, it returns the output unchanged. If the target is categorical, it applies LogSoftmax. Args: target_column: The name of the target column. output: The decoded output tensor (logits or real value). Returns: The output tensor, with LogSoftmax applied if categorical. """ if target_column in self.real_columns: return output else: return self.softmax[target_column](output.float())
[docs] @conditional_beartype def forward( self, src: dict[str, Tensor], return_logits: Union[bool, Tensor] = False ) -> dict[str, Tensor]: """The main forward pass of the model. This is typically used for inference/evaluation, returning the probabilities/values for the *last* token in the sequence. Args: src: A dictionary mapping column names to input tensors (batch_size, seq_length). return_logits: Return logits Returns: A dictionary mapping target column names to their final output (LogSoftmax probabilities or real values) for the last token (batch_size, n_classes/1). """ output = self.forward_train(src) if return_logits: return output return { target_column: self.apply_softmax( target_column, out[-self.prediction_length :, :, :] ) for target_column, out in output.items() }
@beartype
def _get_full_state_dict(
    self, ddp_model: Optional[nn.Module] = None
) -> dict[str, Tensor]:
    """Return a CPU copy of the full model state dict.

    Under FSDP, the wrapped model (``ddp_model`` if given, else ``self``) is
    gathered with a FULL_STATE_DICT policy (offloaded to CPU, rank 0 only).
    Every rank must therefore call this method. Otherwise ``self.state_dict()``
    is used directly.

    Keys have any ``_orig_mod.`` prefix (added by ``torch.compile``) removed,
    and every tensor is moved to CPU and cloned.
    """

    def _to_cpu_copy(raw: dict) -> dict[str, Tensor]:
        return {
            name.replace("_orig_mod.", ""): tensor.cpu().clone()
            for name, tensor in raw.items()
        }

    if not self.hparams.training_spec.fsdp:
        return _to_cpu_copy(self.state_dict())

    wrapped = self if ddp_model is None else ddp_model
    policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(wrapped, StateDictType.FULL_STATE_DICT, policy):
        gathered = wrapped.state_dict()
    return _to_cpu_copy(gathered)
@beartype
def train_model(
    self,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    ddp_model: Optional[nn.Module] = None,
) -> None:
    """Run the full training loop, then export the 'last' and 'best' models.

    When starting fresh (``start_epoch == 1``), a validation pass is run and
    logged as epoch 0 first. Each epoch then trains, validates, logs, and
    keeps a CPU copy of the weights whenever the validation loss improves.
    It also steps an epoch-level scheduler and writes a checkpoint every
    ``save_interval_epochs`` epochs.

    Training stops early when either condition holds:
    - ``early_stopping_epochs`` consecutive epochs did not improve; or
    - the last validation loss was NaN (after the first epoch of this run).

    On Ctrl+C, rank 0 asks whether to export the models, and the answer is
    broadcast to every rank. The method then returns without running the
    normal end-of-training export.

    Args:
        train_loader: DataLoader for the training dataset.
        valid_loader: DataLoader for the validation dataset.
        ddp_model: The DDP/FSDP-wrapped model, if any.
    """
    best_val_loss = float("inf")
    n_epochs_no_improvement = 0
    last_epoch = self.start_epoch - 1
    best_model_state = None

    try:
        self.last_latest_save_time = time.time()
        self.last_batch_save_time = time.time()

        if self.start_epoch == 1:
            total_loss, total_losses, output = self._evaluate(
                valid_loader, ddp_model
            )
            elapsed = 0.0
            self._log_epoch_results(0, 0, elapsed, total_loss, total_losses, output)

        for epoch in range(self.start_epoch, self.hparams.training_spec.epochs + 1):
            if (
                self.early_stopping_epochs is not None
                and n_epochs_no_improvement >= self.early_stopping_epochs
            ):
                break
            # total_loss is always bound here: every epoch after the first
            # one of this run has assigned it below.
            if epoch > self.start_epoch and np.isnan(total_loss):  # type: ignore # noqa: F821
                break

            epoch_start_time = time.time()
            train_loader.dataset.set_epoch(epoch)
            valid_loader.dataset.set_epoch(epoch)
            self._train_epoch(train_loader, valid_loader, epoch, ddp_model)
            total_loss, total_losses, output = self._evaluate(
                valid_loader, ddp_model
            )
            elapsed = time.time() - epoch_start_time
            self._log_epoch_results(
                epoch,
                len(train_loader),
                elapsed,
                total_loss,
                total_losses,
                output,
            )

            if total_loss < best_val_loss:
                best_val_loss = total_loss
                # All ranks must call this (FSDP gathers collectively).
                best_model_state = self._get_full_state_dict(ddp_model)
                n_epochs_no_improvement = 0
            else:
                n_epochs_no_improvement += 1

            if self.scheduler_step_on == "epoch":
                self.scheduler.step()

            if epoch % self.save_interval_epochs == 0:
                self._save(
                    epoch,
                    len(train_loader) - 1,
                    total_loss,
                    ddp_model=ddp_model,
                    suffix=f"epoch-{epoch}",
                )
            last_epoch = epoch
    except KeyboardInterrupt:
        self.logger.info("\n" + "=" * 89)
        self.logger.info("[WARNING] Training interrupted by user (Ctrl+C).")
        if self.hparams.training_spec.distributed:
            dist.barrier()

        # Only rank 0 prompts. The answer lives in a list so that
        # broadcast_object_list can share it with every rank.
        answer_list = ["n"]
        if self.rank == 0:
            try:
                answer = (
                    input("Do you want to export the 'best' and 'last' models? (y/n): ")
                    .lower()
                    .strip()
                )
                if answer == "y":
                    answer_list[0] = "y"
            except EOFError:
                # Non-interactive environment: default to not exporting.
                answer_list[0] = "n"

        if self.hparams.training_spec.distributed:
            dist.broadcast_object_list(answer_list, src=0)

        if answer_list[0] == "y":
            if self.rank == 0:
                self.logger.info("[INFO] User opted to export models.")
            if best_model_state is not None:
                if self.rank == 0:
                    self.logger.info(
                        f"[INFO] Exporting 'last' model from epoch {last_epoch}..."
                    )
                # All ranks must execute this to avoid FSDP all_gather deadlocks.
                last_model_state = self._get_full_state_dict(ddp_model)
                if self.rank == 0:
                    self._export(last_model_state, "last", last_epoch)
                    self.logger.info(
                        "[INFO] Exporting 'best' model (based on best val loss)..."
                    )
                    self._export(best_model_state, "best", last_epoch)
                    self.logger.info("[INFO] Models exported.")
            elif self.rank == 0:
                self.logger.info("[INFO] Could not export model as no epoch ran.")
        elif self.rank == 0:
            self.logger.info("[INFO] User opted *not* to export. Exiting.")

        if self.hparams.training_spec.distributed:
            dist.barrier()
        # Without this return, execution would fall through to the
        # end-of-training export below, ignoring the user's answer.
        return

    if self.hparams.training_spec.distributed:
        dist.barrier()

    # All ranks take part in the gather; only rank 0 writes files.
    last_model_state = self._get_full_state_dict(ddp_model)
    if best_model_state is None:
        if self.rank == 0:
            self.logger.info(
                "[INFO] No validation improvement... Saving last model as 'best'."
            )
        best_model_state = last_model_state

    if self.rank == 0:
        self._export(last_model_state, "last", last_epoch)  # type: ignore
        self._export(best_model_state, "best", last_epoch)  # type: ignore
        self.logger.info("--- Training Complete ---")

    if self.hparams.training_spec.distributed:
        dist.barrier()
@beartype
def _train_epoch(
    self,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    epoch: int,
    ddp_model: Optional[nn.Module] = None,
) -> None:
    """Train the model for one epoch.

    Iterates over ``train_loader``, computes the loss and backpropagates it
    through ``self.scaler``. The optimizer steps every ``accumulation_steps``
    batches, on every batch when that is ``None``, and always on the last
    batch. When resuming, batches with an index below ``self.start_batch``
    are skipped, and ``self.start_batch`` is reset to 0 for later epochs.

    The loader is expected to yield ``(sequences_dict, targets_dict,
    sequence_ids, subsequence_ids, start_positions)`` tuples. Only the first
    two elements are used here.

    Mid-epoch checkpoints ("latest" and per-batch) are triggered by
    wall-clock intervals. In distributed mode, rank 0 makes the timing
    decision and broadcasts it. Every rank then runs the optional validation
    pass and the save, because both contain collective operations
    (``all_reduce`` in ``_evaluate``, the FSDP full-state-dict gather in
    ``_save``).

    Args:
        train_loader: DataLoader for the training dataset.
        valid_loader: DataLoader for the validation dataset (used for the
            optional mid-epoch validation loss).
        epoch: The current epoch number (used in logs and checkpoint names).
        ddp_model: The DDP/FSDP-wrapped model, if any; otherwise ``self``.
    """
    total_loss = 0.0
    batches_aggregated = 0
    start_time = time.time()
    num_batches = len(train_loader)
    start_batch = self.start_batch
    self.start_batch = 0

    model_to_call = ddp_model if ddp_model is not None else self
    model_to_call.train()
    is_fsdp = self.hparams.training_spec.fsdp
    is_distributed = self.hparams.training_spec.distributed

    # Only use torch.autocast when FSDP MixedPrecision is NOT handling it.
    # These settings are loop-invariant, so they are resolved once.
    use_autocast = self.hparams.training_spec.layer_autocast and not is_fsdp
    amp_dtype = None
    device_type = None
    if use_autocast:
        layer_type_dtypes = self.hparams.training_spec.layer_type_dtypes
        amp_dtype = get_torch_dtype(
            layer_type_dtypes.get("linear", "bfloat16")
            if layer_type_dtypes
            else "float32"
        )
        device_type = self.device.split(":")[0]

    for batch_count, (data, targets, _, _, _) in enumerate(train_loader):
        if batch_count < start_batch:
            # Already consumed before the checkpoint this run resumed from.
            continue

        data = {
            k: v.to(self.device, non_blocking=True)
            for k, v in data.items()
            if k in self.input_columns
        }
        targets = {
            k: v.to(self.device, non_blocking=True)
            for k, v in targets.items()
            if k in self.target_column_types
        }

        if use_autocast:
            with torch.autocast(device_type=device_type, dtype=amp_dtype):
                output = model_to_call(data, True)
                loss, losses = self._calculate_loss(output, targets)
        else:
            output = model_to_call(data, True)
            loss, losses = self._calculate_loss(output, targets)

        self.scaler.scale(loss).backward()

        if (
            self.accumulation_steps is None
            or (batch_count + 1) % self.accumulation_steps == 0
            or (batch_count + 1) == num_batches
        ):
            # Unscale first so the clipping threshold applies to true gradients.
            self.scaler.unscale_(self.optimizer)
            if is_fsdp:
                model_to_call.clip_grad_norm_(0.5)
            else:
                torch.nn.utils.clip_grad_norm_(self.parameters(), 0.5)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

        total_loss += loss.item()
        batches_aggregated += 1

        if (batch_count + 1) % self.log_interval == 0 and self.rank == 0:
            learning_rate = self.scheduler.get_last_lr()[0]
            s_per_batch = (time.time() - start_time) / max(1, batches_aggregated)
            avg_train_loss = total_loss / max(1, batches_aggregated)
            self.logger.info(
                f"[INFO] Epoch {epoch:3d} | Batch {(batch_count+1):5d}/{num_batches:5d} | Loss: {format_number(avg_train_loss)} | LR: {format_number(learning_rate)} | S/Batch {format_number(s_per_batch)}"
            )
            total_loss = 0.0
            batches_aggregated = 0
            start_time = time.time()

        del data, targets, output, loss, losses

        if self.scheduler_step_on == "batch":
            self.scheduler.step()

        # --- Mid-epoch checkpointing: rank 0 decides, all ranks act ---
        save_latest = False
        save_batch = False
        batch_window = 0.0
        if not is_distributed or self.rank == 0:
            now = time.time()
            if self.save_latest_interval_minutes is not None and (
                now - self.last_latest_save_time
            ) >= (self.save_latest_interval_minutes * 60):
                save_latest = True
                self.last_latest_save_time = now
            if self.save_batch_interval_minutes is not None and (
                now - self.last_batch_save_time
            ) >= (self.save_batch_interval_minutes * 60):
                save_batch = True
                batch_window = now - self.last_batch_save_time

        if is_distributed:
            flags = torch.tensor(
                [int(save_latest), int(save_batch)],
                dtype=torch.int32,
                device=self.device,
            )
            dist.broadcast(flags, src=0)
            save_latest = bool(flags[0].item())
            save_batch = bool(flags[1].item())

        if save_latest:
            self._save(
                epoch,
                batch_count,
                np.float32(np.nan),
                ddp_model,
                suffix="latest",
            )
            if self.rank != 0:
                # Keep ranks roughly aligned.
                self.last_latest_save_time = time.time()

        if save_batch:
            if self.save_batch_interval_minutes_val_loss:
                # Every rank must run this: _evaluate performs all_reduce.
                val_loss, val_losses, val_output = self._evaluate(
                    valid_loader, ddp_model
                )
                self._log_epoch_results(
                    epoch,
                    batch_count + 1,
                    batch_window,
                    val_loss,
                    val_losses,
                    val_output,
                )
            else:
                val_loss = np.float32(np.nan)
            self.last_batch_save_time = time.time()
            self._save(
                epoch,
                batch_count,
                val_loss,
                ddp_model,
                suffix=f"epoch-{epoch}-batch-{batch_count + 1}",
            )


@beartype
def _calculate_loss(
    self, output: dict[str, Tensor], targets: dict[str, Tensor]
) -> tuple[Tensor, dict[str, Tensor]]:
    """Compute the masked, weighted loss for one batch.

    The padding mask comes from the first categorical target, or from the
    first target if there is no categorical one:
    - categorical: positions whose target index is 0 are excluded;
    - real: positions before the first non-zero value of each sequence are
      excluded.

    Each column's loss is the masked mean of its criterion output. This
    assumes the criteria use ``reduction="none"`` — TODO confirm.

    Note:
        ``output`` is modified in place. Each entry is replaced by its
        flattened float32 form: (seq_length * batch_size, n_classes) for
        categorical targets, (seq_length * batch_size,) for real ones.
        ``_log_epoch_results`` calls ``argmax(1)`` on this flattened form.

    Args:
        output: Model outputs from ``forward_train``, each of shape
            (seq_length, batch_size, n_classes or 1).
        targets: Target tensors of shape (batch_size, seq_length).

    Returns:
        A tuple of:
        - the total loss, the sum of the weighted per-column losses;
        - a dict of per-column losses, with ``loss_weights`` already applied.

    Raises:
        RuntimeError: If ``targets`` produced no loss terms.
    """
    mask_col = next(
        (
            col
            for col in targets.keys()
            if self.target_column_types[col] == "categorical"
        ),
        list(targets.keys())[0],
    )
    if self.target_column_types[mask_col] == "real":
        seq_mask_2d = (targets[mask_col] != 0.0).long().cumsum(dim=1) > 0
    else:
        seq_mask_2d = targets[mask_col] != 0
    # Transpose to seq-major to match the (seq_length, batch_size) output layout.
    mask = seq_mask_2d.T.contiguous().reshape(-1)

    losses = {}
    for target_column in targets.keys():
        target_column_type = self.target_column_types[target_column]
        if target_column_type == "categorical":
            output[target_column] = (
                output[target_column]
                .float()
                .reshape(-1, self.n_classes[target_column])
            )
        elif target_column_type == "real":
            output[target_column] = (
                output[target_column].to(dtype=torch.float32).reshape(-1)
            )

        target_tensor = targets[target_column].T.contiguous().reshape(-1)
        if target_column_type == "real":
            target_tensor = target_tensor.to(dtype=output[target_column].dtype)

        raw_loss = self.criterion[target_column](output[target_column], target_tensor)
        current_mask = mask.to(dtype=raw_loss.dtype)
        losses[target_column] = (raw_loss * current_mask).sum() / (
            current_mask.sum() + 1e-9
        )

    loss = None
    for target_column in targets.keys():
        losses[target_column] = losses[target_column] * (
            self.loss_weights[target_column]
            if self.loss_weights is not None
            else 1.0
        )
        if loss is None:
            loss = losses[target_column].clone()
        else:
            loss += losses[target_column]

    if loss is None:
        raise RuntimeError("Loss calculation failed; no loss tensors were generated.")

    return loss, losses


@beartype
def _copy_model(self):
    """Return a deep copy of this model with its own logger.

    ``self.logger`` is detached during the deepcopy so it is not copied. It
    is restored even if the copy fails. The copy then gets a fresh logger
    via ``_initialize_log_file``.

    Returns:
        A deep copy of the current TransformerModel instance.
    """
    logger_ref = self.logger
    del self.logger
    try:
        model_copy = copy.deepcopy(self)
    finally:
        self.logger = logger_ref
    model_copy._initialize_log_file()
    return model_copy


@beartype
def _transform_val(self, col: str, val: Tensor) -> Tensor:
    """Convert input values into the layout of model output for the baseline loss.

    Used only when computing the baseline loss. The input value at each
    position is treated as the "prediction", so categorical indices are
    one-hot encoded to look like logits.

    Args:
        col: Name of the column being transformed.
        val: Input tensor (categorical indices or real values).

    Returns:
        For categorical columns, a (N, n_classes) one-hot tensor in the
        decoder's weight dtype. For real columns, ``val`` unchanged.

    Raises:
        ValueError: If the column type is neither 'categorical' nor 'real'.
    """
    if self.target_column_types[col] == "categorical":
        target_dtype = self.decoder[col].weight.dtype
        return (
            one_hot(val, self.n_classes[col])
            .reshape(-1, self.n_classes[col])
            .to(dtype=target_dtype)
        )
    if self.target_column_types[col] != "real":
        raise ValueError(f"Column {col} must be 'real' if not 'categorical'.")
    return val


@beartype
def _evaluate(
    self, valid_loader: DataLoader, ddp_model: Optional[nn.Module] = None
) -> tuple[np.float32, dict[str, np.float32], dict[str, Tensor]]:
    """Evaluate the model on the validation set.

    Computes mean batch losses on this rank's shard. In distributed mode the
    results are averaged across ranks with ``all_reduce``, so every rank must
    call this method. On the first call, a baseline loss is also computed.
    The baseline uses each position's input value as its prediction.

    The loader is expected to yield ``(sequences_dict, targets_dict,
    sequence_ids, subsequence_ids, start_positions)`` tuples. Only the first
    two elements are used. The model is put back into train mode before
    returning.

    Args:
        valid_loader: DataLoader for the validation dataset.
        ddp_model: The DDP/FSDP-wrapped model, if any.

    Returns:
        A tuple of:
        - the aggregated total validation loss;
        - the aggregated per-target losses;
        - the output dict of the last batch, flattened by ``_calculate_loss``
          (empty if the loader yielded no batches).
    """
    total_loss_collect = []
    total_losses_collect = {col: [] for col in self.target_columns}
    output = {}  # stays empty if the loader yields no batches

    model_to_call = ddp_model if ddp_model is not None else self
    model_to_call.eval()
    is_fsdp = self.hparams.training_spec.fsdp

    use_autocast = self.hparams.training_spec.layer_autocast and not is_fsdp
    amp_dtype = None
    device_type = None
    if use_autocast:
        layer_type_dtypes = self.hparams.training_spec.layer_type_dtypes
        amp_dtype = get_torch_dtype(
            layer_type_dtypes.get("linear", "bfloat16")
            if layer_type_dtypes
            else "float32"
        )
        device_type = self.device.split(":")[0]

    with torch.no_grad():
        for data, targets, _, _, _ in valid_loader:
            data = {
                k: v.to(self.device, non_blocking=True)
                for k, v in data.items()
                if k in self.input_columns
            }
            targets = {
                k: v.to(self.device, non_blocking=True)
                for k, v in targets.items()
                if k in self.target_column_types
            }

            if use_autocast:
                with torch.autocast(device_type=device_type, dtype=amp_dtype):
                    output = model_to_call(data, True)
                    loss, losses = self._calculate_loss(output, targets)
            else:
                output = model_to_call(data, True)
                loss, losses = self._calculate_loss(output, targets)

            total_loss_collect.append(loss.item())
            for col, col_loss in losses.items():
                total_losses_collect[col].append(col_loss.item())

            del data, targets, loss, losses
            if self.device == "cuda":
                torch.cuda.empty_cache()

    # 1. Local (per-rank) means
    if len(total_loss_collect) > 0:
        total_loss_local = np.mean(total_loss_collect)
        total_losses_local = {
            col: np.mean(loss_list) for col, loss_list in total_losses_collect.items()
        }
    else:
        total_loss_local = 0.0
        total_losses_local = {col: 0.0 for col in self.target_columns}

    # 2. Average across ranks in distributed mode
    if self.hparams.training_spec.distributed:
        total_loss_tensor = torch.tensor(
            total_loss_local, device=self.device, dtype=torch.float32
        )
        # Sorted keys give every rank the same element order.
        loss_keys = sorted(total_losses_local.keys())
        losses_tensor = torch.tensor(
            [total_losses_local[k] for k in loss_keys],
            device=self.device,
            dtype=torch.float32,
        )
        dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(losses_tensor, op=dist.ReduceOp.SUM)
        world_size = dist.get_world_size()
        total_loss_tensor /= world_size
        losses_tensor /= world_size

        total_loss_global = total_loss_tensor.cpu().numpy()
        total_losses_global = dict(zip(loss_keys, losses_tensor.cpu().numpy()))
    else:
        total_loss_global = total_loss_local
        total_losses_global = total_losses_local

    # 3. One-time baseline loss (also averaged across ranks)
    if not hasattr(self, "baseline_loss"):
        baseline_loss_local_collect = []
        baseline_losses_local_collect = {col: [] for col in self.target_columns}

        with torch.no_grad():
            for data, targets, _, _, _ in valid_loader:
                data = {
                    k: v.to(self.device, non_blocking=True)
                    for k, v in data.items()
                    if k in self.input_columns
                }
                targets = {
                    k: v.to(self.device, non_blocking=True)
                    for k, v in targets.items()
                    if k in self.target_column_types
                }

                pseudo_output = {}
                targets_for_baseline = {}
                for col in self.target_columns:
                    if col in data:
                        pseudo_output[col] = self._transform_val(
                            col, data[col].transpose(0, 1)
                        )
                        targets_for_baseline[col] = targets[col]

                if len(pseudo_output) > 0:
                    loss, losses = self._calculate_loss(
                        pseudo_output, targets_for_baseline
                    )
                    baseline_loss_local_collect.append(loss.item())
                    for col, loss_ in losses.items():
                        baseline_losses_local_collect[col].append(loss_.item())

        if len(baseline_loss_local_collect):
            baseline_loss_local = np.mean(baseline_loss_local_collect)
            baseline_losses_local = {
                col: np.mean(loss_list)
                for col, loss_list in baseline_losses_local_collect.items()
            }
        else:
            baseline_loss_local = -1.0
            baseline_losses_local = {col: -1.0 for col in self.target_columns}

        if self.hparams.training_spec.distributed:
            # Average (not broadcast) the per-rank baselines.
            total_loss_tensor = torch.tensor(
                baseline_loss_local, device=self.device, dtype=torch.float32
            )
            dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM)

            loss_keys = sorted(baseline_losses_local.keys())
            losses_tensor = torch.tensor(
                [baseline_losses_local[k] for k in loss_keys],
                device=self.device,
                dtype=torch.float32,
            )
            dist.all_reduce(losses_tensor, op=dist.ReduceOp.SUM)

            world_size = dist.get_world_size()
            total_loss_tensor /= world_size
            losses_tensor /= world_size
            self.baseline_loss = total_loss_tensor.item()
            self.baseline_losses = dict(zip(loss_keys, losses_tensor.cpu().numpy()))
        else:
            self.baseline_loss = baseline_loss_local
            self.baseline_losses = baseline_losses_local

    model_to_call.train()
    torch.clear_autocast_cache()
    return (
        np.float32(total_loss_global),
        {k: np.float32(v) for k, v in total_losses_global.items()},
        output,
    )


@beartype
def _export(self, state_dict: dict[str, Tensor], suffix: str, epoch: int) -> None:
    """Export the model on rank 0; all other ranks return immediately.

    Builds a fresh ``TransformerModel`` from ``self.hparams``, loads
    ``state_dict`` into it, and exports it in eval mode. Export happens when
    ``export_generative_model`` is set, and as a
    ``TransformerEmbeddingModel`` when ``export_embedding_model`` is set.

    Args:
        state_dict: The state dict to export (e.g. best or last model).
        suffix: Filename suffix (e.g. "best", "last").
        epoch: Epoch number included in the filename.
    """
    if self.rank != 0:
        return

    # A fresh, unwrapped (not DDP/FSDP) model for export.
    export_model = TransformerModel(self.hparams)
    export_model.load_state_dict(state_dict)
    export_model.eval()

    os.makedirs(os.path.join(self.project_root, "models"), exist_ok=True)
    if self.export_generative_model:
        self._export_model(export_model, suffix, epoch)
    if self.export_embedding_model:
        model2 = TransformerEmbeddingModel(export_model)
        self._export_model(model2, f"{suffix}-embedding", epoch)


@beartype
def _export_model(
    self,
    model: Union["TransformerModel", "TransformerEmbeddingModel"],
    suffix: str,
    epoch: int,
) -> None:
    """Write the model to ONNX and/or .pt, depending on the config.

    For ONNX, a float32 copy is exported if any parameter is float16,
    bfloat16 or float64. Random dummy inputs of shape
    (inference_batch_size, seq_length) are used for tracing.
    For .pt, the state dict and the ``export_with_dropout`` flag are saved.

    Args:
        model: The TransformerModel or TransformerEmbeddingModel to export.
        suffix: Filename suffix (e.g. "best", "last-embedding").
        epoch: Epoch number included in the filename.
    """
    if self.export_onnx:
        is_different_type = any(
            p.dtype in [torch.float16, torch.bfloat16, torch.float64]
            for p in model.parameters()
        )
        model_to_export = model
        if is_different_type:
            self.logger.info(
                "[INFO] Casting model to float32 for ONNX export compatibility..."
            )
            # Copy so the caller's model keeps its dtypes.
            model_to_export = model._copy_model().float()

        export_device = next(model_to_export.parameters()).device
        x_cat = {
            col: torch.randint(
                0, self.n_classes[col], (self.inference_batch_size, self.seq_length)
            ).to(export_device, non_blocking=True)
            for col in self.categorical_columns
        }
        dtype_real = torch.float32 if is_different_type else None
        x_real = {
            col: torch.rand(self.inference_batch_size, self.seq_length).to(
                export_device, non_blocking=True, dtype=dtype_real
            )
            for col in self.real_columns
        }
        x = {"src": {**x_cat, **x_real}}

        export_path = os.path.join(
            self.project_root,
            "models",
            f"sequifier-{self.model_name}-{suffix}-{epoch}.onnx",
        )
        training_mode = (
            torch._C._onnx.TrainingMode.TRAINING
            if self.export_with_dropout
            else torch._C._onnx.TrainingMode.EVAL
        )
        # Constant folding would bake dropout away, so disable it in that case.
        constant_folding = self.export_with_dropout == False  # noqa: E712
        torch.onnx.export(
            model_to_export,
            x,
            export_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=constant_folding,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size"},
                "output": {0: "batch_size"},
            },
            training=training_mode,
        )

    if self.export_pt:
        export_path = os.path.join(
            self.project_root,
            "models",
            f"sequifier-{self.model_name}-{suffix}-{epoch}.pt",
        )
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "export_with_dropout": self.export_with_dropout,
            },
            export_path,
        )


@beartype
def _save(
    self,
    epoch: int,
    batch: int,
    val_loss: np.float32,
    ddp_model: Optional[nn.Module] = None,
    suffix: Optional[str] = None,
) -> None:
    """Save a training checkpoint to ``checkpoints/{model_name}-{suffix}.pt``.

    Under FSDP, the full model and optimizer state are gathered (offloaded
    to CPU, rank 0 only), so every rank must call this method. The file is
    written only by rank 0.

    Args:
        epoch: Current epoch number.
        batch: Index of the current batch within the epoch.
        val_loss: Validation loss to record (may be NaN).
        ddp_model: The DDP/FSDP-wrapped model, if any.
        suffix: Checkpoint file suffix.
    """
    model_to_extract = ddp_model if ddp_model is not None else self

    if self.hparams.training_spec.fsdp:
        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        optim_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(
            model_to_extract,
            StateDictType.FULL_STATE_DICT,
            save_policy,
            optim_policy,
        ):
            model_state_dict = {
                k.replace("_orig_mod.", ""): v
                for k, v in model_to_extract.state_dict().items()
            }
            optim_state_dict = FSDP.full_optim_state_dict(
                model_to_extract, self.optimizer
            )
    else:
        model_state_dict = {
            k.replace("_orig_mod.", ""): v for k, v in self.state_dict().items()
        }
        optim_state_dict = self.optimizer.state_dict()

    if self.rank != 0:
        return

    os.makedirs(os.path.join(self.project_root, "checkpoints"), exist_ok=True)
    output_path = os.path.join(
        self.project_root, "checkpoints", f"{self.model_name}-{suffix}.pt"
    )
    torch.save(
        {
            "epoch": epoch,
            "batch": batch,
            "model_state_dict": model_state_dict,
            "optimizer_state_dict": optim_state_dict,
            "scheduler_state_dict": self.scheduler.state_dict(),
            "loss": val_loss,
        },
        output_path,
    )
    self.logger.info(f"[INFO] Saved checkpoint to {output_path}")


@beartype
def _get_optimizer(self, params: Any, **kwargs):
    """Build the optimizer named in the training config.

    Args:
        params: Parameters (or parameter groups) to optimize.
        **kwargs: Extra constructor arguments (e.g. weight_decay).

    Returns:
        An optimizer instance using ``training_spec.learning_rate``.
    """
    optimizer_class = get_optimizer_class(self.hparams.training_spec.optimizer.name)
    return optimizer_class(
        params, lr=self.hparams.training_spec.learning_rate, **kwargs
    )


@beartype
def _get_scheduler(self, **kwargs):
    """Build the ``torch.optim.lr_scheduler`` class named in the config.

    Args:
        **kwargs: Extra constructor arguments (e.g. step_size).

    Returns:
        A scheduler bound to ``self.optimizer``.

    Raises:
        ValueError: If no such scheduler exists in torch.optim.lr_scheduler.
    """
    scheduler_name = self.hparams.training_spec.scheduler.name
    if not hasattr(torch.optim.lr_scheduler, scheduler_name):
        raise ValueError(
            f"Scheduler {scheduler_name} not found in torch.optim.lr_scheduler"
        )
    scheduler_class = getattr(torch.optim.lr_scheduler, scheduler_name)
    return scheduler_class(self.optimizer, **kwargs)


@beartype
def _initialize_log_file(self):
    """Attach a logger from ``configure_logger`` for this model name and rank."""
    self.logger = configure_logger(self.project_root, self.model_name, self.rank)


@beartype
def _get_latest_model_name(self) -> Optional[str]:
    """Return the newest checkpoint file for this model, if any.

    Only files named ``{model_name}-*`` (the pattern ``_save`` writes) are
    considered. This prevents another model whose name merely starts with
    this one (e.g. "model" vs "model2") from being picked up. The newest file
    is chosen by ``os.path.getctime``.

    Returns:
        Path to the newest matching checkpoint, or None if there is none.
    """
    checkpoint_path = os.path.join(self.project_root, "checkpoints", "*")
    prefix = f"{self.model_name}-"
    files = [
        file
        for file in glob.glob(checkpoint_path)
        if os.path.basename(file).startswith(prefix)
    ]
    if files:
        return max(files, key=os.path.getctime)
    return None


@beartype
def _log_epoch_results(
    self,
    epoch: int,
    batch: int,
    elapsed: float,
    total_loss: np.float32,
    total_losses: dict[str, np.float32],
    output: dict[str, Tensor],
) -> None:
    """Log validation results (rank 0 only).

    Logs the total and baseline loss, elapsed time and current learning
    rate. With several targets, per-target losses are logged too. For each
    column in ``class_share_log_columns``, the share of each predicted class
    in ``output`` is also logged.

    Args:
        epoch: Current epoch number.
        batch: Current batch number.
        elapsed: Seconds elapsed for the logged interval.
        total_loss: Aggregated validation loss.
        total_losses: Aggregated per-target losses.
        output: Last validation batch output, as flattened by
            ``_calculate_loss`` (categorical: (N, n_classes)); may be empty.
    """
    if self.rank != 0:
        return

    # Read lr directly instead of building the whole optimizer state dict.
    learning_rate = self.optimizer.param_groups[0]["lr"]
    log_string = f"[INFO] Validation | Epoch: {epoch:3d} | Batch: {batch} | Loss: {format_number(total_loss)} | Baseline Loss: {format_number(self.baseline_loss)} | Time: {elapsed:5.2f}s | LR {format_number(learning_rate)}"
    self.logger.info("-" * 89)
    self.logger.info(log_string)

    if len(total_losses) > 1:
        loss_strs = [
            f"{key}_loss: {format_number(value)}"
            for key, value in total_losses.items()
        ]
        self.logger.info("[INFO] - " + ", ".join(loss_strs))

    for categorical_column in self.class_share_log_columns:
        if categorical_column not in output:
            # _evaluate returns an empty dict for an empty validation set.
            continue
        output_values = output[categorical_column].argmax(1).cpu().detach().numpy()
        output_counts_df = (
            pl.Series("values", output_values).value_counts().sort("values")
        )
        total_count = output_counts_df.get_column("count").sum()
        value_shares = " | ".join(
            f"{self.index_maps[categorical_column][row['values']]}: {row['count'] / total_count:5.5f}"
            for row in output_counts_df.iter_rows(named=True)
        )
        self.logger.info(f"[INFO] {categorical_column}: {value_shares}")

    self.logger.info("-" * 89)
@beartype
def load_inference_model(
    model_type: str,
    model_path: str,
    training_config_path: str,
    args_config: dict[str, Any],
    device: str,
    infer_with_dropout: bool,
) -> torch.nn.Module:
    """Load a trained model for inference.

    Args:
        model_type: "generative" or "embedding".
        model_path: Path to the saved .pt file (must contain "model_state_dict").
        training_config_path: Path to the .yaml config used for training.
        args_config: Override configuration. "model_path" and "data_path" are
            not forwarded to the training config, and "skip_metadata" is
            read from here.
        device: Device to load the model onto (e.g. "cuda", "cpu").
        infer_with_dropout: Keep dropout layers active during inference.

    Returns:
        The TransformerModel or TransformerEmbeddingModel in eval mode
        (dropout re-enabled if requested). It is wrapped in ``torch.compile``
        when ``device`` starts with "cuda".

    Raises:
        ValueError: If ``model_type`` is neither "generative" nor "embedding".
    """
    skip_metadata = args_config.get("skip_metadata", False)
    args_config_subset = {
        k: v for k, v in args_config.items() if k not in ["model_path", "data_path"]
    }
    training_config = load_train_config(
        training_config_path, args_config_subset, skip_metadata
    )

    with torch.no_grad():
        if model_type == "generative":
            model = TransformerModel(training_config)
        elif model_type == "embedding":
            model = TransformerEmbeddingModel(TransformerModel(training_config))
        else:
            # Raise rather than assert: asserts vanish under `python -O`.
            raise ValueError(
                f"model_type must be 'generative' or 'embedding', got {model_type!r}"
            )

        model.logger.info(f"[INFO] Loading model weights from {model_path}")
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only load model files from trusted sources.
        model_state = torch.load(
            model_path, map_location=torch.device(device), weights_only=False
        )
        model.load_state_dict(model_state["model_state_dict"])

    model.eval()

    if infer_with_dropout:
        # Training checkpoints (from _save) do not carry this flag.
        if not model_state.get("export_with_dropout", False):
            warnings.warn(
                "Model was exported with 'export_with_dropout'==False. By setting 'infer_with_dropout' to True, you are overriding this configuration"
            )
        for module in model.modules():
            if isinstance(module, torch.nn.Dropout):
                module.train()

    if device.startswith("cuda"):
        model = torch.compile(model).to(device)
    else:
        model.to(device)

    return model
@beartype
def infer_with_embedding_model(
    model: nn.Module,
    x: list[dict[str, np.ndarray]],
    device: str,
    size: int,
    target_columns: list[str],
) -> np.ndarray:
    """Run batched inference with an embedding model.

    Categorical columns are sent to ``device`` as int64. All other columns use
    the dtype configured for "linear" layers (float32 by default).

    Args:
        model: The loaded TransformerEmbeddingModel.
        x: Batched input dictionaries mapping column name to a NumPy array.
        device: Device to run inference on.
        size: Total number of samples (unused here; kept for interface parity
            with ``infer_with_generative_model``).
        target_columns: Target column names (unused here).

    Returns:
        The concatenated embeddings of all batches. Each batch output has its
        first two axes swapped and is flattened to 2-D before concatenation.
    """
    transformer = model.transformer_model
    categorical_cols = set(transformer.categorical_columns)
    # The real-valued input dtype depends only on the config; resolve it once.
    layer_types = transformer.hparams.training_spec.layer_type_dtypes or {}
    ref_dtype = get_torch_dtype(layer_types.get("linear", "float32"))

    outs0 = []
    with torch.no_grad():
        for x_sub in x:
            data_gpu = {
                col: torch.from_numpy(x_).to(
                    device,
                    dtype=torch.int64 if col in categorical_cols else ref_dtype,
                )
                for col, x_ in x_sub.items()
            }
            output_gpu = model.forward(data_gpu)
            output_cpu = output_gpu.cpu().detach().float().numpy()
            # Swap the first two axes, then merge them into one row axis.
            output_cpu = output_cpu.transpose(1, 0, 2).reshape(
                output_cpu.shape[0] * output_cpu.shape[1], output_cpu.shape[2]
            )
            outs0.append(output_cpu)
            if device == "cuda":
                torch.cuda.empty_cache()

    return np.concatenate(outs0, axis=0)
@beartype
def infer_with_generative_model(
    model: nn.Module,
    x: list[dict[str, np.ndarray]],
    device: str,
    size: int,
    target_columns: list[str],
) -> dict[str, np.ndarray]:
    """Run batched inference with a generative model.

    Categorical columns are sent to ``device`` as int64. All other columns use
    the dtype configured for "linear" layers (float32 by default).

    Args:
        model: The loaded TransformerModel.
        x: Batched input dictionaries mapping column name to a NumPy array.
        device: Device to run inference on.
        size: Number of rows to keep from each concatenated output.
        target_columns: Target column names to extract from the output.

    Returns:
        Mapping from target column name to a 2-D float32 array made from all
        batches, trimmed to the first ``size`` rows. TransformerModel.forward
        yields (prediction_length, batch_size, n) per target. Each such block
        becomes (batch_size * prediction_length, n) rows in batch-major
        order.
    """
    categorical_cols = set(model.categorical_columns)
    # The real-valued input dtype depends only on the config; resolve it once.
    layer_types = model.hparams.training_spec.layer_type_dtypes or {}
    ref_dtype = get_torch_dtype(layer_types.get("linear", "float32"))

    outs0 = []
    with torch.no_grad():
        for x_sub in x:
            data_gpu = {
                col: torch.from_numpy(x_).to(
                    device,
                    dtype=torch.int64 if col in categorical_cols else ref_dtype,
                )
                for col, x_ in x_sub.items()
            }
            output_gpu = model.forward(data_gpu)
            outs0.append({k: v.cpu().detach() for k, v in output_gpu.items()})
            if device == "cuda":
                torch.cuda.empty_cache()

    def _flatten(tensor: torch.Tensor) -> np.ndarray:
        # Swap the first two axes, then merge them into one row axis.
        return (
            tensor.float()
            .numpy()
            .transpose(1, 0, 2)
            .reshape(tensor.shape[0] * tensor.shape[1], tensor.shape[2])
        )

    return {
        target_column: np.concatenate(
            [_flatten(o[target_column]) for o in outs0], axis=0
        )[:size, :]
        for target_column in target_columns
    }