Source code for sequifier.io.sequifier_dataset_from_file

from typing import Dict, Iterator, Tuple

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

from sequifier.config.train_config import TrainModel
from sequifier.helpers import PANDAS_TO_TORCH_TYPES, numpy_to_pytorch, read_data


class SequifierDatasetFromFile(IterableDataset):
    """
    An iterable-style dataset that pre-loads all data into CPU RAM and yields
    pre-collated batches.

    This is the idiomatic PyTorch solution for implementing custom 'en block'
    batching. The __iter__ method handles shuffling and batch slicing, ensuring
    maximum performance.
    """

    def __init__(self, data_path: str, config: TrainModel, shuffle: bool = True):
        super().__init__()
        self.config = config
        self.batch_size = config.training_spec.batch_size
        self.shuffle = shuffle
        self.epoch = 0

        # Create a unified list of all columns the model might need
        all_columns = sorted(list(set(config.selected_columns + config.target_columns)))

        print(f"[INFO] Loading training dataset into memory from '{data_path}'...")
        data_df = read_data(data_path, config.read_format)

        column_types = {
            col: PANDAS_TO_TORCH_TYPES[config.column_types[col]]
            for col in config.column_types
        }

        # all_tensors holds both model inputs and targets
        all_tensors = numpy_to_pytorch(
            data=data_df,
            column_types=column_types,
            all_columns=all_columns,
            seq_length=config.seq_length,
        )
        self.n_samples = all_tensors[all_columns[0]].shape[0]
        del data_df

        self.sequence_tensors = {
            key: all_tensors[key] for key in self.config.selected_columns
        }
        self.target_tensors = {
            key: all_tensors[f"{key}_target"] for key in self.config.target_columns
        }
        del all_tensors

        if config.training_spec.device.startswith("cuda"):
            for key in self.sequence_tensors:
                self.sequence_tensors[key] = self.sequence_tensors[key].pin_memory()
            for key in self.target_tensors:
                self.target_tensors[key] = self.target_tensors[key].pin_memory()

        print(f"[INFO] Dataset loaded with {self.n_samples} samples.")
    def set_epoch(self, epoch: int):
        """Allows the training loop to set the epoch for deterministic shuffling."""
        self.epoch = epoch
    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return self.n_samples
    def __iter__(
        self,
    ) -> Iterator[
        Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], None, None, None]
    ]:
        """Yields batches of data.

        Handles shuffling (if enabled) and slicing data based on distributed rank
        and worker ID.

        Yields:
            Iterator[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], None, None, None]]:
                An iterator where each item is a tuple containing:
                - data_batch (dict): Dictionary of feature tensors for the batch.
                - targets_batch (dict): Dictionary of target tensors for the batch.
                - None: Placeholder for sequence_id (not used in this dataset type).
                - None: Placeholder for subsequence_id (not used in this dataset type).
                - None: Placeholder for start_position (not used in this dataset type).
        """
        worker_info = torch.utils.data.get_worker_info()
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0

        if worker_info is None:
            # Single-process data loading
            worker_id = 0
            num_workers = 1
        else:
            # Multi-process data loading
            worker_id = worker_info.id
            num_workers = worker_info.num_workers

        indices = torch.arange(self.n_samples)
        if self.shuffle:
            g = torch.Generator()
            # Use epoch and seed for a different but deterministic shuffle each epoch
            g.manual_seed(self.config.seed + self.epoch)
            indices = indices[torch.randperm(self.n_samples, generator=g)]

        # Shard the (shuffled) indices first across distributed ranks, then across
        # DataLoader workers, so each sample is yielded by exactly one worker.
        indices_for_rank = indices[rank::world_size]
        indices_for_worker = indices_for_rank[worker_id::num_workers]

        for i in range(0, len(indices_for_worker), self.batch_size):
            batch_end = i + self.batch_size
            if batch_end > len(indices_for_worker):
                # Skip the final incomplete batch
                continue
            batch_indices = indices_for_worker[i:batch_end]

            data_batch = {
                key: tensor[batch_indices]
                for key, tensor in self.sequence_tensors.items()
            }
            targets_batch = {
                key: tensor[batch_indices]
                for key, tensor in self.target_tensors.items()
            }
            yield data_batch, targets_batch, None, None, None
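
Usage sketch (not part of the module source above). Because ``__iter__`` already yields pre-collated batches, a ``DataLoader`` wrapping this dataset should be created with ``batch_size=None`` to disable automatic batching, and ``set_epoch`` should be called once per epoch to get a different but deterministic shuffle. The data path and the way the ``TrainModel`` config is obtained below are hypothetical placeholders; only the constructor signature and ``set_epoch`` come from the class itself.

    from torch.utils.data import DataLoader

    from sequifier.config.train_config import TrainModel
    from sequifier.io.sequifier_dataset_from_file import SequifierDatasetFromFile

    # Hypothetical placeholders: how the TrainModel config is built and where the
    # training data lives depend on your project setup.
    config: TrainModel = ...  # e.g. parsed from your training config
    dataset = SequifierDatasetFromFile("data/train.parquet", config, shuffle=True)

    # batch_size=None turns off the DataLoader's own batching, since the dataset
    # already yields batches of size config.training_spec.batch_size.
    loader = DataLoader(dataset, batch_size=None, num_workers=2)

    for epoch in range(3):
        dataset.set_epoch(epoch)  # deterministic, epoch-dependent shuffle
        for data_batch, targets_batch, _, _, _ in loader:
            pass  # forward/backward pass goes here

Note that the dataset shards indices by distributed rank internally, so no ``DistributedSampler`` is needed when running under ``torch.distributed``.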