Source code for sequifier.train
import copy
import glob
import math
import os
import time
import uuid
import warnings
from typing import Any, Optional, Union
import numpy as np
import polars as pl
import torch
import torch._dynamo
import torch.distributed as dist
import torch.multiprocessing as mp
from beartype import beartype
from torch import Tensor, nn
from torch.nn import ModuleDict, TransformerEncoder, TransformerEncoderLayer
from torch.nn.functional import one_hot
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
torch._dynamo.config.suppress_errors = True
from sequifier.config.train_config import TrainModel, load_train_config # noqa: E402
from sequifier.helpers import LogFile # noqa: E402
from sequifier.helpers import construct_index_maps # noqa: E402
from sequifier.io.sequifier_dataset_from_file import ( # noqa: E402
SequifierDatasetFromFile,
)
from sequifier.io.sequifier_dataset_from_folder import ( # noqa: E402
SequifierDatasetFromFolder,
)
from sequifier.io.sequifier_dataset_from_folder_lazy import ( # noqa: E402
SequifierDatasetFromFolderLazy,
)
from sequifier.optimizers.optimizers import get_optimizer_class # noqa: E402
from sequifier.samplers.distributed_grouped_random_sampler import ( # noqa: E402
DistributedGroupedRandomSampler,
)
@beartype
def setup(rank: int, world_size: int, backend: str = "nccl"):
"""Sets up the distributed training environment.
Args:
rank: The rank of the current process.
world_size: The total number of processes.
backend: The distributed backend to use.
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
dist.init_process_group(backend, rank=rank, world_size=world_size)
def cleanup():
"""Cleans up the distributed training environment."""
dist.destroy_process_group()
@beartype
def train_worker(rank: int, world_size: int, config: TrainModel, from_folder: bool):
"""The worker function for distributed training.
Args:
rank: The rank of the current process.
world_size: The total number of processes.
config: The training configuration.
from_folder: Whether to load data from a folder (e.g., preprocessed .pt files)
or a single file (e.g., .parquet).
"""
if config.training_spec.distributed:
setup(rank, world_size, config.training_spec.backend)
# 1. Create Datasets and DataLoaders with DistributedSampler
if from_folder:
if config.training_spec.load_full_data_to_ram:
train_dataset = SequifierDatasetFromFolder(
config.training_data_path, config
)
valid_dataset = SequifierDatasetFromFolder(
config.validation_data_path, config
)
else:
train_dataset = SequifierDatasetFromFolderLazy(
config.training_data_path, config
)
valid_dataset = SequifierDatasetFromFolderLazy(
config.validation_data_path, config
)
else:
        assert not config.training_spec.distributed
train_dataset = SequifierDatasetFromFile(config.training_data_path, config)
valid_dataset = SequifierDatasetFromFile(config.validation_data_path, config)
if from_folder:
if config.training_spec.distributed:
            # Use the distributed grouped sampler for the multi-GPU case
train_sampler = DistributedGroupedRandomSampler(
train_dataset, num_replicas=world_size, rank=rank
)
valid_sampler = DistributedGroupedRandomSampler(
valid_dataset, num_replicas=world_size, rank=rank, shuffle=False
)
else:
            # Use a plain random sampler for the single-GPU case
train_sampler = RandomSampler(train_dataset)
valid_sampler = None
else:
train_sampler = (
DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
if config.training_spec.distributed
else None
)
valid_sampler = (
DistributedSampler(
valid_dataset, num_replicas=world_size, rank=rank, shuffle=False
)
if config.training_spec.distributed
else None
)
if from_folder:
train_loader = DataLoader(
train_dataset,
batch_size=config.training_spec.batch_size,
sampler=train_sampler,
            shuffle=False,  # shuffling is handled by the sampler
num_workers=config.training_spec.num_workers, # Use multiple workers for data loading
pin_memory=config.training_spec.device not in ["mps", "cpu"],
)
        # Validation also goes through the (possibly distributed) sampler, without shuffling
valid_loader = DataLoader(
valid_dataset,
batch_size=config.training_spec.batch_size,
sampler=valid_sampler,
shuffle=False,
)
    else:
train_loader = DataLoader(
train_dataset,
batch_size=None,
sampler=None,
num_workers=config.training_spec.num_workers,
pin_memory=False,
persistent_workers=(config.training_spec.num_workers > 0),
)
valid_loader = DataLoader(
valid_dataset, batch_size=None, sampler=None, shuffle=False
)
# 2. Instantiate and wrap the model
torch.manual_seed(config.seed)
np.random.seed(config.seed)
model = TransformerModel(config, rank)
if config.training_spec.distributed:
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
model = torch.compile(model)
# 3. Start training
# When using DDP, the original model is accessed via the .module attribute
original_model = model.module if config.training_spec.distributed else model
original_model.train_model(train_loader, valid_loader, train_sampler, valid_sampler)
if config.training_spec.distributed:
cleanup()
@beartype
def train(args: Any, args_config: dict[str, Any]) -> None:
"""The main training function.
Args:
args: The command-line arguments.
args_config: The configuration dictionary.
"""
config_path = args.config_path or "configs/train.yaml"
config = load_train_config(config_path, args_config, args.on_unprocessed)
print(f"--- Starting Training for model: {config.model_name} ---")
world_size = config.training_spec.world_size
from_folder = config.read_format == "pt"
if config.training_spec.distributed:
mp.spawn(
train_worker,
args=(world_size, config, from_folder),
nprocs=world_size,
join=True,
)
else:
# Fallback to single-GPU/CPU training
train_worker(0, world_size, config, from_folder)
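# Invocation sketch (illustrative, not part of the original source): the argparse
# namespace only needs the two attributes this function reads, and whether an
# empty override dict is accepted depends on load_train_config.
#   import argparse
#   args = argparse.Namespace(config_path="configs/train.yaml", on_unprocessed=False)
#   train(args, args_config={})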
@beartype
def format_number(number: Union[int, float, np.float32]) -> str:
"""Format a number for display.
Args:
number: The number to format.
Returns:
A formatted string representation of the number.
"""
if np.isnan(number):
return "NaN"
elif number == 0:
order_of_magnitude = 0
else:
        order_of_magnitude = math.floor(math.log(abs(number), 10))
number_adjusted = number * (10 ** (-order_of_magnitude))
return f"{number_adjusted:5.2f}e{order_of_magnitude}"
class TransformerEmbeddingModel(nn.Module):
"""A wrapper around the TransformerModel to expose the embedding functionality."""
def __init__(self, transformer_model: "TransformerModel"):
"""Initializes the TransformerEmbeddingModel.
Args:
transformer_model: The TransformerModel to wrap.
"""
super().__init__()
self.transformer_model = transformer_model
self.log_file = self.transformer_model.log_file
def forward(self, src: dict[str, Tensor]):
"""Forward pass for the embedding model.
Args:
src: The input data.
Returns:
The embedded output.
"""
return self.transformer_model.forward_embed(src)
class TransformerModel(nn.Module):
"""The main Transformer model for the sequifier.
This class implements the Transformer model, including the training and
evaluation loops, as well as the export functionality.
"""
@beartype
def __init__(self, hparams: Any, rank: Optional[int] = None):
"""Initializes the TransformerModel.
Based on the hyperparameters, this initializes:
- Embeddings for categorical and real features (self.encoder)
- Positional encoders (self.pos_encoder)
- The main TransformerEncoder (self.transformer_encoder)
- Output decoders for each target column (self.decoder)
- Loss functions (self.criterion)
- Optimizer (self.optimizer) and scheduler (self.scheduler)
Args:
hparams: The hyperparameters for the model (e.g., from TrainModel config).
rank: The rank of the current process (for distributed training).
"""
super().__init__()
self.project_path = hparams.project_path
self.model_type = "Transformer"
self.model_name = hparams.model_name or uuid.uuid4().hex[:8]
self.rank = rank
self.selected_columns = hparams.selected_columns
self.categorical_columns = [
col
for col in hparams.categorical_columns
if self.selected_columns is None or col in self.selected_columns
]
self.real_columns = [
col
for col in hparams.real_columns
if self.selected_columns is None or col in self.selected_columns
]
self.target_columns = hparams.target_columns
self.target_column_types = hparams.target_column_types
self.loss_weights = hparams.training_spec.loss_weights
self.seq_length = hparams.seq_length
self.n_classes = hparams.n_classes
self.inference_batch_size = hparams.inference_batch_size
self.log_interval = hparams.training_spec.log_interval
self.class_share_log_columns = hparams.training_spec.class_share_log_columns
self.index_maps = construct_index_maps(
hparams.id_maps, self.class_share_log_columns, True
)
self.export_embedding_model = hparams.export_embedding_model
self.export_generative_model = hparams.export_generative_model
self.export_onnx = hparams.export_onnx
self.export_pt = hparams.export_pt
self.export_with_dropout = hparams.export_with_dropout
self.early_stopping_epochs = hparams.training_spec.early_stopping_epochs
self.hparams = hparams
self.drop = nn.Dropout(hparams.training_spec.dropout)
self.encoder = ModuleDict()
self.pos_encoder = ModuleDict()
self.embedding_size = max(
self.hparams.model_spec.d_model, self.hparams.model_spec.nhead
)
if hparams.model_spec.d_model_by_column is not None:
self.d_model_by_column = hparams.model_spec.d_model_by_column
else:
self.d_model_by_column = self._get_d_model_by_column(
self.embedding_size, self.categorical_columns, self.real_columns
)
self.real_columns_with_embedding = []
self.real_columns_direct = []
for col in self.real_columns:
if self.d_model_by_column[col] > 1:
self.encoder[col] = nn.Linear(1, self.d_model_by_column[col])
self.real_columns_with_embedding.append(col)
else:
assert self.d_model_by_column[col] == 1
self.real_columns_direct.append(col)
self.pos_encoder[col] = nn.Embedding(
self.seq_length, self.d_model_by_column[col]
)
for col, n_classes in self.n_classes.items():
if col in self.categorical_columns:
self.encoder[col] = nn.Embedding(n_classes, self.d_model_by_column[col])
self.pos_encoder[col] = nn.Embedding(
self.seq_length, self.d_model_by_column[col]
)
encoder_layers = TransformerEncoderLayer(
self.embedding_size,
hparams.model_spec.nhead,
hparams.model_spec.d_hid,
hparams.training_spec.dropout,
)
self.transformer_encoder = TransformerEncoder(
encoder_layers, hparams.model_spec.nlayers, enable_nested_tensor=False
)
self.decoder = ModuleDict()
self.softmax = ModuleDict()
for target_column, target_column_type in self.target_column_types.items():
if target_column_type == "categorical":
self.decoder[target_column] = nn.Linear(
self.embedding_size,
self.n_classes[target_column],
)
self.softmax[target_column] = nn.LogSoftmax(dim=-1)
elif target_column_type == "real":
self.decoder[target_column] = nn.Linear(self.embedding_size, 1)
else:
raise ValueError(
f"Target column type {target_column_type} not in ['categorical', 'real']"
)
self.device = hparams.training_spec.device
self.device_max_concat_length = hparams.training_spec.device_max_concat_length
if hparams.training_spec.device == "cuda" and self.rank is not None:
self.device = f"cuda:{self.rank}"
else:
self.device = hparams.training_spec.device
self.to(self.device)
self.criterion = self._init_criterion(hparams=hparams)
self.batch_size = hparams.training_spec.batch_size
self.accumulation_steps = hparams.training_spec.accumulation_steps
self.src_mask = self._generate_square_subsequent_mask(self.seq_length).to(
self.device
)
self._init_weights()
self.optimizer = self._get_optimizer(
**self._filter_key(hparams.training_spec.optimizer, "name")
)
self.scheduler = self._get_scheduler(
**self._filter_key(hparams.training_spec.scheduler, "name")
)
self.iter_save = hparams.training_spec.iter_save
self.continue_training = hparams.training_spec.continue_training
load_string = self._load_weights_conditional()
self._initialize_log_file()
self.log_file.write(load_string)
@beartype
def _init_criterion(self, hparams: Any) -> dict[str, Any]:
"""Initializes the criterion (loss function) for each target column.
Args:
hparams: The hyperparameters for the model, used to find criterion names
and class weights.
Returns:
A dictionary mapping target column names to their loss function instances.
"""
criterion = {}
for target_column in self.target_columns:
            criterion_class = getattr(
                torch.nn, hparams.training_spec.criterion[target_column]
            )
criterion_kwargs = {}
if (
hparams.training_spec.class_weights is not None
and target_column in hparams.training_spec.class_weights
):
criterion_kwargs["weight"] = Tensor(
hparams.training_spec.class_weights[target_column]
).to(self.device)
criterion[target_column] = criterion_class(**criterion_kwargs)
return criterion
@beartype
def _get_d_model_by_column(
self,
embedding_size: int,
categorical_columns: list[str],
real_columns: list[str],
) -> dict[str, int]:
"""Calculates the embedding dimension for each column.
This attempts to distribute the total `embedding_size` across all
input columns.
Args:
embedding_size: The total embedding dimension (d_model).
categorical_columns: List of categorical column names.
real_columns: List of real-valued column names.
Returns:
A dictionary mapping column names to their calculated embedding dimension.
"""
print(f"{len(categorical_columns) = } {len(real_columns) = }")
assert (len(categorical_columns) + len(real_columns)) > 0, "No columns found"
if len(categorical_columns) == 0 and len(real_columns) > 0:
d_model_by_column = {col: 1 for col in real_columns}
column_index = dict(enumerate(real_columns))
for i in range(embedding_size):
if sum(d_model_by_column.values()) % embedding_size != 0:
j = i % len(real_columns)
d_model_by_column[column_index[j]] += 1
assert sum(d_model_by_column.values()) % embedding_size == 0
elif len(real_columns) == 0 and len(categorical_columns) > 0:
assert (
(embedding_size % len(categorical_columns)) == 0
), f"If only categorical variables are included, d_model must be a multiple of the number of categorical variables ({embedding_size = } % {len(categorical_columns) = }) != 0"
d_model_comp = embedding_size // len(categorical_columns)
d_model_by_column = {col: d_model_comp for col in categorical_columns}
else:
raise UserWarning(
"If both real and categorical variables are present, d_model_by_column config value must be set"
)
return d_model_by_column
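    # Worked example (illustrative, not part of the original source): with
    # embedding_size=64, four categorical columns and no real columns, each
    # column receives 64 // 4 = 16 embedding dimensions, so the concatenated
    # per-token representation is 4 * 16 = 64 wide, matching d_model.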
@staticmethod
def _generate_square_subsequent_mask(sz: int) -> Tensor:
"""Generates an upper-triangular matrix of -inf, with zeros on diag.
This is used as a mask to prevent attention to future tokens in the
transformer.
Args:
sz: The size of the square mask (sequence length).
Returns:
A square tensor of shape (sz, sz) with -inf in the upper triangle.
"""
return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1)
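    # For sz=3 the mask above is (illustrative):
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so position i can only attend to positions <= i.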
@staticmethod
def _filter_key(dict_: dict[str, Any], key: str) -> dict[str, Any]:
"""Filters a key from a dictionary.
Args:
dict_: The dictionary to filter.
key: The key to remove.
Returns:
A new dictionary without the specified key.
"""
return {k: v for k, v in dict_.items() if k != key}
@beartype
def _init_weights(self) -> None:
"""Initializes the weights of the model."""
init_std = 0.02
for col in self.categorical_columns:
self.encoder[col].weight.data.normal_(mean=0.0, std=init_std)
for target_column in self.target_columns:
self.decoder[target_column].bias.data.zero_()
self.decoder[target_column].weight.data.normal_(mean=0.0, std=init_std)
for col_name in self.pos_encoder:
self.pos_encoder[col_name].weight.data.normal_(mean=0.0, std=init_std)
@beartype
    def _recursive_concat(self, srcs: list[Tensor]) -> Tensor:
"""Recursively concatenates a list of tensors.
This is used to avoid device-specific limits on the number of tensors
that can be concatenated at once by breaking the operation into
smaller, recursive chunks.
Args:
srcs: A list of tensors to concatenate along dimension 2.
Returns:
A single tensor resulting from the recursive concatenation.
"""
if len(srcs) <= self.device_max_concat_length:
return torch.cat(srcs, 2)
else:
srcs_inner = []
for start in range(0, len(srcs), self.device_max_concat_length):
src = self._recursive_concat(
srcs[start : start + self.device_max_concat_length]
)
srcs_inner.append(src)
return self._recursive_concat(srcs_inner)
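    # Illustrative behaviour (not part of the original source): with
    # device_max_concat_length=4 and ten per-column tensors, the first pass
    # concatenates the chunks [0:4], [4:8] and [8:10] along dim 2, and the
    # recursive call then concatenates the three intermediate results, so no
    # single torch.cat ever exceeds the device limit.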
@beartype
def forward_inner(self, src: dict[str, Tensor]) -> Tensor:
"""The inner forward pass of the model.
This handles embedding lookup, positional encoding, and passing the
combined tensor through the transformer encoder.
Args:
src: A dictionary mapping column names to input tensors
(batch_size, seq_length).
Returns:
The raw output tensor from the TransformerEncoder
(seq_length, batch_size, d_model).
"""
srcs = []
for col in self.categorical_columns:
src_t = self.encoder[col](src[col].T) * math.sqrt(self.embedding_size)
pos = (
torch.arange(0, self.seq_length, dtype=torch.long, device=self.device)
.repeat(src_t.shape[1], 1)
.T
)
src_p = self.pos_encoder[col](pos)
src_c = self.drop(src_t + src_p)
srcs.append(src_c)
for col in self.real_columns:
if col in self.real_columns_direct:
                src_t = src[col].T.unsqueeze(2) * math.sqrt(
                    self.embedding_size
                )
else:
assert col in self.real_columns_with_embedding
src_t = self.encoder[col](src[col].T[:, :, None]) * math.sqrt(
self.embedding_size
)
pos = (
torch.arange(0, self.seq_length, dtype=torch.long, device=self.device)
.repeat(src_t.shape[1], 1)
.T
)
src_p = self.pos_encoder[col](pos)
src_c = self.drop(src_t + src_p)
srcs.append(src_c)
src2 = self._recursive_concat(srcs)
output = self.transformer_encoder(src2, self.src_mask)
return output
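    # Shape sketch (illustrative): each per-column tensor src_c has shape
    # (seq_length, batch_size, d_model_by_column[col]); concatenating along the
    # last dimension gives (seq_length, batch_size, d_model), which the
    # TransformerEncoder consumes together with the causal src_mask.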
@beartype
def forward_embed(self, src: dict[str, Tensor]) -> Tensor:
"""Forward pass for the embedding model.
This returns only the embedding from the *last* token in the sequence.
Args:
src: A dictionary mapping column names to input tensors
(batch_size, seq_length).
Returns:
The embedding tensor for the last token
(batch_size, d_model).
"""
return self.forward_inner(src)[-1, :, :]
@beartype
def forward_train(self, src: dict[str, Tensor]) -> dict[str, Tensor]:
"""Forward pass for training.
This runs the inner forward pass and then applies the appropriate
decoder for each target column.
Args:
src: A dictionary mapping column names to input tensors
(batch_size, seq_length).
Returns:
A dictionary mapping target column names to their raw output
(logit) tensors (seq_length, batch_size, n_classes/1).
"""
output = self.forward_inner(src)
output = {
target_column: self.decode(target_column, output)
for target_column in self.target_columns
}
return output
@beartype
def decode(self, target_column: str, output: Tensor) -> Tensor:
"""Decodes the output of the transformer encoder.
Applies the appropriate final linear layer for a given target column.
Args:
target_column: The name of the target column to decode.
output: The raw output tensor from the TransformerEncoder
(seq_length, batch_size, d_model).
Returns:
The decoded output (logits or real value) for the target column
(seq_length, batch_size, n_classes/1).
"""
decoded = self.decoder[target_column](output)
return decoded
@beartype
def apply_softmax(self, target_column: str, output: Tensor) -> Tensor:
"""Applies softmax to the output of the decoder.
If the target is real, it returns the output unchanged.
If the target is categorical, it applies LogSoftmax.
Args:
target_column: The name of the target column.
output: The decoded output tensor (logits or real value).
Returns:
The output tensor, with LogSoftmax applied if categorical.
"""
if target_column in self.real_columns:
return output
else:
return self.softmax[target_column](output)
@beartype
def forward(self, src: dict[str, Tensor]) -> dict[str, Tensor]:
"""The main forward pass of the model.
This is typically used for inference/evaluation, returning the
probabilities/values for the *last* token in the sequence.
Args:
src: A dictionary mapping column names to input tensors
(batch_size, seq_length).
Returns:
A dictionary mapping target column names to their final
output (LogSoftmax probabilities or real values) for the
last token (batch_size, n_classes/1).
"""
output = self.forward_train(src)
return {
target_column: self.apply_softmax(target_column, out[-1, :, :])
for target_column, out in output.items()
}
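    # Minimal usage sketch (illustrative; column names, sizes and devices are
    # assumptions, and inputs must live on the model's device):
    #   model = TransformerModel(config, rank=0)
    #   src = {
    #       "item_id": torch.randint(0, config.n_classes["item_id"], (32, config.seq_length)),
    #       "price": torch.rand(32, config.seq_length),
    #   }
    #   out = model(src)  # e.g. {"item_id": (32, n_classes) log-probs, "price": (32, 1)}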
@beartype
def train_model(
self,
train_loader: DataLoader,
valid_loader: DataLoader,
train_sampler: Optional[
Union[RandomSampler, DistributedSampler, DistributedGroupedRandomSampler]
],
valid_sampler: Optional[
Union[RandomSampler, DistributedSampler, DistributedGroupedRandomSampler]
],
) -> None:
"""Trains the model.
This method contains the main training loop, including epoch iteration,
validation, early stopping logic, and model saving/exporting.
Args:
train_loader: DataLoader for the training dataset.
valid_loader: DataLoader for the validation dataset.
train_sampler: Sampler for the training DataLoader, used to set
the epoch in distributed training.
valid_sampler: Sampler for the validation DataLoader, used to set
the epoch in distributed training.
"""
best_val_loss = float("inf")
n_epochs_no_improvement = 0
for epoch in range(self.start_epoch, self.hparams.training_spec.epochs + 1):
if (
self.early_stopping_epochs is None
or n_epochs_no_improvement < self.early_stopping_epochs
) and (
epoch == self.start_epoch
or epoch > self.start_epoch
and not np.isnan(total_loss) # type: ignore # noqa: F821
):
epoch_start_time = time.time()
if train_sampler and not isinstance(train_sampler, RandomSampler):
train_sampler.set_epoch(epoch)
self._train_epoch(train_loader, epoch)
if valid_sampler and not isinstance(valid_sampler, RandomSampler):
valid_sampler.set_epoch(epoch)
total_loss, total_losses, output = self._evaluate(valid_loader)
elapsed = time.time() - epoch_start_time
self._log_epoch_results(
epoch, elapsed, total_loss, total_losses, output
)
if total_loss < best_val_loss:
best_val_loss = total_loss
best_model = self._copy_model()
n_epochs_no_improvement = 0
else:
n_epochs_no_improvement += 1
self.scheduler.step()
if epoch % self.iter_save == 0:
self._save(epoch, total_loss)
last_epoch = epoch
self._export(self, "last", last_epoch) # type: ignore
self._export(best_model, "best", last_epoch) # type: ignore
self.log_file.write("--- Training Complete ---")
@beartype
def _train_epoch(
self,
train_loader: DataLoader,
epoch: int,
) -> None:
"""Trains the model for one epoch.
Iterates through the training DataLoader, computes loss, performs
backpropagation, and updates model parameters. The DataLoader is expected
to yield tuples of (data_dict, targets_dict, sequence_ids, subsequence_ids, start_positions).
The IDs and positions are currently unused in this training loop.
Args:
train_loader: DataLoader for the training dataset.
epoch: The current epoch number (used for logging).
"""
self.train()
total_loss = 0.0
start_time = time.time()
num_batches = len(train_loader)
for batch_count, (data, targets, _, _, _) in enumerate(train_loader):
data = {
k: v.to(self.device, non_blocking=True)
for k, v in data.items()
if k in self.selected_columns
}
targets = {
k: v.to(self.device, non_blocking=True)
for k, v in targets.items()
if k in self.target_column_types
}
output = self.forward_train(data)
loss, losses = self._calculate_loss(output, targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.parameters(), 0.5)
if (
self.accumulation_steps is None
or (batch_count + 1) % self.accumulation_steps == 0
or (batch_count + 1) == num_batches
):
self.optimizer.step()
self.optimizer.zero_grad()
total_loss += loss.item()
if (batch_count + 1) % self.log_interval == 0 and self.rank == 0:
lr = self.scheduler.get_last_lr()[0]
s_per_batch = (time.time() - start_time) / self.log_interval
self.log_file.write(
f"[INFO] Epoch {epoch:3d} | Batch {(batch_count+1):5d}/{num_batches:5d} | Loss: {format_number(total_loss)} | LR: {format_number(lr)} | S/Batch {format_number(s_per_batch)}"
)
total_loss = 0.0
start_time = time.time()
del data, targets, output, loss, losses
@beartype
def _calculate_loss(
self, output: dict[str, Tensor], targets: dict[str, Tensor]
) -> tuple[Tensor, dict[str, Tensor]]:
"""Calculates the loss for the given output and targets.
Compares the model's output (from `forward_train`) with the target
values, applying the appropriate criterion for each target column
and combining them using loss weights.
Args:
output: A dictionary of output tensors from the model
(seq_length, batch_size, n_classes/1).
targets: A dictionary of target tensors
(batch_size, seq_length).
Returns:
A tuple containing:
- The total combined (weighted) loss as a single Tensor.
- A dictionary of individual (unweighted) loss Tensors for each
target column.
"""
losses = {}
for target_column, target_column_type in self.target_column_types.items():
if target_column_type == "categorical":
output[target_column] = output[target_column].reshape(
-1, self.n_classes[target_column]
)
elif target_column_type == "real":
output[target_column] = output[target_column].reshape(-1)
losses[target_column] = self.criterion[target_column](
output[target_column], targets[target_column].T.contiguous().reshape(-1)
)
loss = None
for target_column in self.target_columns:
losses[target_column] = losses[target_column] * (
self.loss_weights[target_column]
if self.loss_weights is not None
else 1.0
)
if loss is None:
loss = losses[target_column].clone()
else:
loss += losses[target_column]
assert loss is not None
return loss, losses
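    # Shape sketch (illustrative): for a categorical target with batch_size=32,
    # seq_length=16 and n_classes=100, the decoder output (16, 32, 100) is
    # reshaped to (512, 100) and compared against the transposed, flattened
    # targets of shape (512,); per-column losses are then weighted and summed.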
@beartype
def _copy_model(self):
"""Copies the model.
This creates a deep copy of the model, typically for saving the
"best model". It temporarily removes the `log_file` attribute
before copying to avoid errors, then re-initializes it.
Returns:
A deep copy of the current TransformerModel instance.
"""
log_file = self.log_file
del self.log_file
model_copy = copy.deepcopy(self)
model_copy._initialize_log_file()
self.log_file = log_file
return model_copy
@beartype
def _transform_val(self, col: str, val: Tensor) -> Tensor:
"""Transforms input data to match the format of model output.
This is used *only* for calculating the baseline loss, where
the input (e.g., categorical indices) needs to be one-hot encoded
to be comparable to the model's (logit) output.
Args:
col: The name of the column being transformed.
val: The input tensor (categorical indices).
Returns:
A tensor transformed to be compatible with the loss function
(e.g., one-hot encoded).
"""
if self.target_column_types[col] == "categorical":
return (
one_hot(val, self.n_classes[col])
.reshape(-1, self.n_classes[col])
.float()
)
else:
assert self.target_column_types[col] == "real"
return val
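    # Illustrative (categorical case): with n_classes=4, an input batch
    # [[2, 0]] becomes the one-hot rows [[0, 0, 1, 0], [1, 0, 0, 0]] of shape
    # (2, 4), i.e. the same (positions, n_classes) layout the criterion sees
    # for model output.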
@beartype
def _evaluate(
self, valid_loader: DataLoader
) -> tuple[np.float32, dict[str, np.float32], dict[str, Tensor]]:
"""Evaluates the model on the validation set.
Iterates through the validation data, calculates the total loss,
and aggregates results across all processes if in distributed mode.
Also calculates a one-time baseline loss on the first call.
The DataLoader is expected to yield tuples of
(data_dict, targets_dict, sequence_ids, subsequence_ids, start_positions).
The IDs and positions are currently unused during evaluation.
Args:
valid_loader: DataLoader for the validation dataset.
Returns:
A tuple containing:
- The total aggregated validation loss (float).
- A dictionary of aggregated losses for each target column (dict[str, float]).
- The output tensor dictionary from the last batch (used for class share logging).
"""
self.eval() # Turn on evaluation mode
total_loss_collect = []
# Initialize a dict to hold lists of losses for each target
total_losses_collect = {col: [] for col in self.target_columns}
output = {} # for type checking
with torch.no_grad():
for data, targets, _, _, _ in valid_loader:
# Move data to the current process's assigned GPU
data = {
k: v.to(self.device, non_blocking=True)
for k, v in data.items()
if k in self.selected_columns
}
targets = {
k: v.to(self.device, non_blocking=True)
for k, v in targets.items()
if k in self.target_column_types
}
output = self.forward_train(data)
loss, losses = self._calculate_loss(output, targets)
total_loss_collect.append(loss.item())
                for col, loss_ in losses.items():
                    total_losses_collect[col].append(loss_.item())
                # Free up GPU memory
                del data, targets, loss, losses
                if self.device.startswith("cuda"):
                    torch.cuda.empty_cache()
# 1. Sum the losses calculated on this GPU process
total_loss_local = np.sum(total_loss_collect)
total_losses_local = {
col: np.sum(loss_list) for col, loss_list in total_losses_collect.items()
}
# 2. Aggregate losses across all GPUs if in distributed mode
if self.hparams.training_spec.distributed:
# Put local losses into tensors for reduction
total_loss_tensor = torch.tensor(total_loss_local, device=self.device)
# Ensure consistent order for the losses tensor
loss_keys = sorted(total_losses_local.keys())
losses_values = [total_losses_local[k] for k in loss_keys]
losses_tensor = torch.tensor(losses_values, device=self.device)
# Sum losses from all processes. The result is broadcast back to all processes.
dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(losses_tensor, op=dist.ReduceOp.SUM)
# Update local variables with the aggregated global results
total_loss_global = total_loss_tensor.cpu().numpy()
losses_global_values = losses_tensor.cpu().numpy()
total_losses_global = dict(zip(loss_keys, losses_global_values))
else:
# If not distributed, local losses are the global losses
total_loss_global = total_loss_local
total_losses_global = total_losses_local
# 3. Handle one-time baseline loss calculation (must also be synchronized)
if not hasattr(self, "baseline_loss"):
baseline_loss_local_collect = []
baseline_losses_local_collect = {col: [] for col in self.target_columns}
# Iterate over the sharded validation loader
for data, targets, _, _, _ in valid_loader:
data = {
k: v.to(self.device, non_blocking=True)
for k, v in data.items()
if k in self.selected_columns
}
targets = {
k: v.to(self.device, non_blocking=True)
for k, v in targets.items()
if k in self.target_column_types
}
# Replicate original logic of using input as pseudo-output
pseudo_output = {
col: self._transform_val(col, data[col]) for col in targets.keys()
}
loss, losses = self._calculate_loss(pseudo_output, targets)
baseline_loss_local_collect.append(loss.item())
for col, loss_ in losses.items():
baseline_losses_local_collect[col].append(loss_.item())
# Sum the losses for the local shard
baseline_loss_local = np.sum(baseline_loss_local_collect)
baseline_losses_local = {
col: np.sum(loss_list)
for col, loss_list in baseline_losses_local_collect.items()
}
# Broadcast the baseline values from the main process to all others
if self.hparams.training_spec.distributed:
total_loss_tensor = torch.tensor(
baseline_loss_local, device=self.device
)
dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM)
self.baseline_loss = total_loss_tensor.item()
loss_keys = sorted(baseline_losses_local.keys())
losses_values = [baseline_losses_local[k] for k in loss_keys]
losses_tensor = torch.tensor(losses_values, device=self.device)
dist.all_reduce(losses_tensor, op=dist.ReduceOp.SUM)
self.baseline_losses = dict(zip(loss_keys, losses_tensor.cpu().numpy()))
else:
# If not distributed, local is global
self.baseline_loss = baseline_loss_local
self.baseline_losses = baseline_losses_local
return (
np.float32(total_loss_global),
{k: np.float32(v) for k, v in total_losses_global.items()},
output,
)
@beartype
def _get_batch(
self,
X: dict[str, Tensor],
y: dict[str, Tensor],
batch_start: int,
batch_size: int,
to_device: bool,
) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
"""Gets a batch of data.
(Note: This method seems unused in favor of DataLoader iteration).
Args:
X: A dictionary of feature tensors.
y: A dictionary of target tensors.
batch_start: The starting index for the batch.
batch_size: The size of the batch.
to_device: Whether to move the batch tensors to the model's device.
Returns:
A tuple (X_batch, y_batch) containing the sliced batch data.
"""
if to_device:
return (
{
col: X[col][batch_start : batch_start + batch_size, :].to(
self.device, non_blocking=True
)
for col in X.keys()
},
{
target_column: y[target_column][
batch_start : batch_start + batch_size, :
].to(self.device, non_blocking=True)
for target_column in y.keys()
},
)
else:
return (
{
col: X[col][batch_start : batch_start + batch_size, :]
for col in X.keys()
},
{
target_column: y[target_column][
batch_start : batch_start + batch_size, :
]
for target_column in y.keys()
},
)
@beartype
def _export(self, model: "TransformerModel", suffix: str, epoch: int) -> None:
"""Exports the model.
This is a wrapper function that handles exporting the model (and
optionally the embedding-only model) on rank 0 only.
Args:
model: The model instance to export (e.g., best model or last model).
suffix: A string suffix to append to the model filename (e.g., "best", "last").
epoch: The current epoch number, included in the filename.
"""
if self.rank != 0:
return
self.eval()
os.makedirs(os.path.join(self.project_path, "models"), exist_ok=True)
if self.export_generative_model:
self._export_model(model, suffix, epoch)
if self.export_embedding_model:
model2 = TransformerEmbeddingModel(model)
suffix = f"{suffix}-embedding"
self._export_model(model2, suffix, epoch)
def _export_model(
self,
model: Union["TransformerModel", "TransformerEmbeddingModel"],
suffix: str,
epoch: int,
) -> None:
"""Exports the model to ONNX and/or PyTorch format.
Saves the model weights as a .pt file and/or exports the model
graph and weights as an .onnx file based on the config.
Args:
model: The model instance (TransformerModel or TransformerEmbeddingModel).
suffix: A string suffix for the filename (e.g., "best", "last-embedding").
epoch: The current epoch number, included in the filename.
"""
if self.export_onnx:
x_cat = {
col: torch.randint(
0,
self.n_classes[col],
(self.inference_batch_size, self.seq_length),
).to(self.device, non_blocking=True)
for col in self.categorical_columns
}
x_real = {
col: torch.rand(self.inference_batch_size, self.seq_length).to(
self.device, non_blocking=True
)
for col in self.real_columns
}
x = {"src": {**x_cat, **x_real}}
# Export the model
export_path = os.path.join(
self.project_path,
"models",
f"sequifier-{self.model_name}-{suffix}-{epoch}.onnx",
)
training_mode = (
torch._C._onnx.TrainingMode.TRAINING
if self.export_with_dropout
else torch._C._onnx.TrainingMode.EVAL
)
            constant_folding = not self.export_with_dropout
torch.onnx.export(
model, # model being run
x, # model input (or a tuple for multiple inputs)
export_path, # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=14, # the ONNX version to export the model to
do_constant_folding=constant_folding, # whether to execute constant folding for optimization
input_names=["input"], # the model's input names
output_names=["output"], # the model's output names
dynamic_axes={
"input": {0: "batch_size"}, # variable length axes
"output": {0: "batch_size"},
},
training=training_mode,
)
if self.export_pt:
export_path = os.path.join(
self.project_path,
"models",
f"sequifier-{self.model_name}-{suffix}-{epoch}.pt",
)
torch.save(
{
"model_state_dict": model.state_dict(),
"export_with_dropout": self.export_with_dropout,
},
export_path,
)
@beartype
def _save(self, epoch: int, val_loss: np.float32) -> None:
"""Saves the model checkpoint.
Saves the model state, optimizer state, and epoch number to a .pt
file in the checkpoints directory. Only runs on rank 0.
Args:
epoch: The current epoch number.
val_loss: The validation loss at the current epoch.
"""
if self.rank != 0:
return
os.makedirs(os.path.join(self.project_path, "checkpoints"), exist_ok=True)
output_path = os.path.join(
self.project_path,
"checkpoints",
f"{self.model_name}-epoch-{epoch}.pt",
)
torch.save(
{
"epoch": epoch,
"model_state_dict": self.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"loss": val_loss,
},
output_path,
)
if self.rank == 0:
self.log_file.write(f"[INFO] Saved model to {output_path}")
@beartype
def _get_optimizer(self, **kwargs):
"""Gets the optimizer.
Initializes the optimizer specified in the hyperparameters.
Args:
**kwargs: Additional arguments to pass to the optimizer constructor
(e.g., weight_decay).
Returns:
An initialized torch.optim.Optimizer instance.
"""
optimizer_class = get_optimizer_class(self.hparams.training_spec.optimizer.name)
return optimizer_class(
self.parameters(), lr=self.hparams.training_spec.lr, **kwargs
)
@beartype
def _get_scheduler(self, **kwargs):
"""Gets the scheduler.
Initializes the learning rate scheduler specified in the hyperparameters.
Args:
**kwargs: Additional arguments to pass to the scheduler constructor
(e.g., step_size).
Returns:
An initialized torch.optim.lr_scheduler._LRScheduler instance.
"""
        scheduler_class = getattr(
            torch.optim.lr_scheduler, self.hparams.training_spec.scheduler.name
        )
return scheduler_class(self.optimizer, **kwargs)
@beartype
def _initialize_log_file(self):
"""Initializes the log file."""
os.makedirs(os.path.join(self.project_path, "logs"), exist_ok=True)
path = os.path.join(
self.project_path, "logs", f"sequifier-{self.model_name}-[NUMBER].txt"
)
if self.rank is not None:
path = path.replace("[NUMBER]", f"rank{self.rank}-[NUMBER]")
self.log_file = LogFile(path, self.rank)
@beartype
def _load_weights_conditional(self) -> str:
"""Loads the weights of the model if a checkpoint is found.
If `continue_training` is True and a checkpoint for the current
`model_name` exists, it loads the model and optimizer states.
Otherwise, it initializes a new model.
Returns:
A string message indicating whether training is resuming or starting new.
"""
latest_model_path = self._get_latest_model_name()
pytorch_total_params = sum(p.numel() for p in self.parameters())
if latest_model_path is not None and self.continue_training:
checkpoint = torch.load(
latest_model_path,
map_location=torch.device(self.device),
weights_only=False,
)
self.load_state_dict(checkpoint["model_state_dict"])
self.start_epoch = checkpoint["epoch"] + 1
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
for state in self.optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
if k == "step":
# Keep the 'step' tensor on the CPU.
state[k] = v.cpu()
else:
# Move all other state tensors to the model's device.
state[k] = v.to(self.device)
return f"[INFO] Resuming training from checkpoint '{latest_model_path}'. Total params: {format_number(pytorch_total_params)}"
else:
self.start_epoch = 1
return f"[INFO] Initializing new model with {format_number(pytorch_total_params)} parameters."
@beartype
def _get_latest_model_name(self) -> Optional[str]:
"""Gets the name of the latest model checkpoint.
Scans the checkpoints directory for files matching the current
`model_name` and returns the path to the most recently modified one.
Returns:
The file path (str) to the latest checkpoint, or None if no
checkpoint is found.
"""
checkpoint_path = os.path.join(self.project_path, "checkpoints", "*")
files = glob.glob(checkpoint_path)
files = [
file for file in files if os.path.split(file)[1].startswith(self.model_name)
]
if files:
return max(files, key=os.path.getctime)
else:
return None
@beartype
def _log_epoch_results(
self,
epoch: int,
elapsed: float,
total_loss: np.float32,
total_losses: dict[str, np.float32],
output: dict[str, Tensor],
) -> None:
"""Logs the results of an epoch.
Writes validation loss, individual losses, learning rate, and
class share statistics (if configured) to the log file.
Only runs on rank 0.
Args:
epoch: The current epoch number.
elapsed: Time taken for the epoch (in seconds).
total_loss: The total aggregated validation loss.
total_losses: A dictionary of aggregated losses for each target.
output: The output tensor dictionary from the last validation batch,
used for class share logging.
"""
if self.rank == 0:
lr = self.optimizer.state_dict()["param_groups"][0]["lr"]
self.log_file.write("-" * 89)
self.log_file.write(
f"[INFO] Validation | Epoch: {epoch:3d} | Loss: {format_number(total_loss)} | Baseline Loss: {format_number(self.baseline_loss)} | Time: {elapsed:5.2f}s | LR {format_number(lr)}"
)
if len(total_losses) > 1:
loss_strs = [
f"{key}_loss: {format_number(value)}"
for key, value in total_losses.items()
]
self.log_file.write("[INFO] - " + ", ".join(loss_strs))
for categorical_column in self.class_share_log_columns:
output_values = (
output[categorical_column].argmax(1).cpu().detach().numpy()
)
output_counts_df = (
pl.Series("values", output_values).value_counts().sort("values")
)
                output_counts_df = output_counts_df.with_columns(
                    (pl.col("count") / pl.col("count").sum()).alias("share")
                )
                value_shares = " | ".join(
                    [
                        f"{self.index_maps[categorical_column][row['values']]}: {row['share']:5.5f}"
                        for row in output_counts_df.iter_rows(named=True)
                    ]
                )
self.log_file.write(f"[INFO] {categorical_column}: {value_shares}")
self.log_file.write("-" * 89)
@beartype
def load_inference_model(
model_type: str,
model_path: str,
training_config_path: str,
args_config: dict[str, Any],
device: str,
infer_with_dropout: bool,
) -> torch.nn.Module:
"""Loads a trained model for inference.
Args:
model_type: "generative" or "embedding".
model_path: Path to the saved .pt model file.
training_config_path: Path to the .yaml config file used for training.
args_config: A dictionary of override configurations.
device: The device to load the model onto (e.g., "cuda", "cpu").
infer_with_dropout: Whether to force dropout layers to be active
during inference.
Returns:
The loaded and compiled torch.nn.Module (TransformerModel or
TransformerEmbeddingModel) in evaluation mode.
"""
training_config = load_train_config(
training_config_path, args_config, args_config["on_unprocessed"]
)
with torch.no_grad():
if model_type == "generative":
model = TransformerModel(training_config)
elif model_type == "embedding":
model_inner = TransformerModel(training_config)
model = TransformerEmbeddingModel(model_inner)
else:
assert False, "impossible"
model.log_file.write(f"[INFO] Loading model weights from {model_path}")
model_state = torch.load(
model_path, map_location=torch.device(device), weights_only=False
)
model.load_state_dict(model_state["model_state_dict"])
model.eval()
if infer_with_dropout:
if not model_state["export_with_dropout"]:
warnings.warn(
"Model was exported with 'export_with_dropout'==False. By setting 'infer_with_dropout' to True, you are overriding this configuration"
)
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.train()
model = torch.compile(model).to(device)
return model
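# Usage sketch (illustrative; paths and the remaining args_config contents are
# assumptions, but the "on_unprocessed" key is required by this function):
#   model = load_inference_model(
#       model_type="generative",
#       model_path="models/sequifier-mymodel-best-10.pt",
#       training_config_path="configs/train.yaml",
#       args_config={"on_unprocessed": False},
#       device="cpu",
#       infer_with_dropout=False,
#   )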
@beartype
def infer_with_embedding_model(
model: nn.Module,
x: list[dict[str, np.ndarray]],
device: str,
size: int,
target_columns: list[str],
) -> np.ndarray:
"""Performs inference with an embedding model.
Args:
model: The loaded TransformerEmbeddingModel.
x: A list of input data dictionaries (batched).
device: The device to run inference on.
size: The total number of samples (unused in this function).
target_columns: List of target column names (unused in this function).
Returns:
A NumPy array containing the concatenated embeddings from all batches.
"""
outs0 = []
with torch.no_grad():
for x_sub in x:
data_gpu = {
col: torch.from_numpy(x_).to(device) for col, x_ in x_sub.items()
}
output_gpu = model.forward(data_gpu)
outs0.append(output_gpu.cpu().detach())
if device == "cuda":
torch.cuda.empty_cache()
outs = np.concatenate(outs0, axis=0)
return outs
@beartype
def infer_with_generative_model(
model: nn.Module,
x: list[dict[str, np.ndarray]],
device: str,
size: int,
target_columns: list[str],
) -> dict[str, np.ndarray]:
"""Performs inference with a generative model.
Args:
model: The loaded TransformerModel.
x: A list of input data dictionaries (batched).
device: The device to run inference on.
size: The total number of samples to trim the final output to.
target_columns: List of target column names to extract from the output.
Returns:
A dictionary mapping target column names to their concatenated
output NumPy arrays, trimmed to `size`.
"""
outs0 = []
with torch.no_grad():
for x_sub in x:
data_gpu = {
col: torch.from_numpy(x_).to(device) for col, x_ in x_sub.items()
}
output_gpu = model.forward(data_gpu)
output_cpu = {k: v.cpu().detach() for k, v in output_gpu.items()}
outs0.append(output_cpu)
if device == "cuda":
torch.cuda.empty_cache()
outs = {
target_column: np.concatenate(
[o[target_column].numpy() for o in outs0],
axis=0,
)[:size, :]
for target_column in target_columns
}
return outs
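# Usage sketch for generative inference (illustrative; column name, shapes and
# batching are assumptions):
#   batches = [{"item_id": np.random.randint(0, 100, (32, 16)).astype(np.int64)}]
#   preds = infer_with_generative_model(
#       model, batches, device="cpu", size=32, target_columns=["item_id"]
#   )
#   preds["item_id"].shape  # (32, n_classes): log-probabilities for the last position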