"""Dataset that pre-loads preprocessed .pt batches into RAM (sequifier.io.sequifier_dataset_from_folder)."""
import json
import os
from typing import Dict, List, Tuple

import torch
from torch.utils.data import Dataset

from sequifier.config.train_config import TrainModel
from sequifier.helpers import normalize_path
class SequifierDatasetFromFolder(Dataset):
"""
    A PyTorch Dataset that pre-loads an entire data split into RAM.

    This strategy is ideal when the split fits into system memory: it pays a
    one-time I/O cost at initialization, after which every data access during
    training is a fast in-memory lookup.
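
    Example (a minimal sketch; the path and the populated ``TrainModel``
    config are placeholders for a real preprocessed run)::

        dataset = SequifierDatasetFromFolder("data/train", config)
        sequence, targets, seq_id, subseq_id, start_pos = dataset[0]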
"""
def __init__(self, data_path: str, config: TrainModel):
"""
Initializes the dataset by loading all .pt files from the data directory
into memory. Each .pt file is expected to contain a tuple:
(sequences_dict, targets_dict, sequence_ids_tensor, subsequence_ids_tensor, start_item_positions_tensor).
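
        For illustration only (the variable names here are hypothetical), a
        compatible file could be written with::

            torch.save(
                (sequences_dict, targets_dict, sequence_ids,
                 subsequence_ids, start_item_positions),
                "batch_000.pt",
            )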
"""
self.data_dir = normalize_path(data_path, config.project_path)
self.config = config
metadata_path = os.path.join(self.data_dir, "metadata.json")
if not os.path.exists(metadata_path):
raise FileNotFoundError(
f"metadata.json not found in '{self.data_dir}'. "
"Ensure data is pre-processed with write_format: pt."
)
with open(metadata_path, "r") as f:
metadata = json.load(f)
self.batch_files_info = metadata["batch_files"]
self.n_samples = metadata["total_samples"]
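        # For illustration, a minimal metadata.json compatible with the reads
        # above and the file loop below might look like this (a sketch
        # inferred from this loader, not an authoritative schema):
        #   {"total_samples": 1024, "batch_files": [{"path": "batch_000.pt"}]}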
print(f"[INFO] Loading training dataset into memory from '{self.data_dir}'...")
all_sequences: Dict[str, List[torch.Tensor]] = {
col: [] for col in config.selected_columns
}
all_targets: Dict[str, List[torch.Tensor]] = {
col: [] for col in config.target_columns
}
all_sequence_ids: List[torch.Tensor] = []
all_subsequence_ids: List[torch.Tensor] = []
all_starting_positions: List[torch.Tensor] = []
# Load all data files and collect tensors
        for file_info in self.batch_files_info:
file_path = os.path.join(self.data_dir, file_info["path"])
(
sequences_batch,
targets_batch,
sequence_ids,
subsequence_ids,
start_item_positions_tensor,
) = torch.load(file_path, map_location="cpu")
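            # Collect only the columns present in this batch file; selected
            # or target columns absent from a file are skipped.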
for col in all_sequences.keys():
if col in sequences_batch:
all_sequences[col].append(sequences_batch[col])
for col in all_targets.keys():
if col in targets_batch:
all_targets[col].append(targets_batch[col])
all_sequence_ids.append(sequence_ids)
all_subsequence_ids.append(subsequence_ids)
all_starting_positions.append(start_item_positions_tensor)
# Concatenate all tensors into a single large tensor for each column
self.sequences: Dict[str, torch.Tensor] = {
col: torch.cat(tensors) for col, tensors in all_sequences.items() if tensors
}
self.targets: Dict[str, torch.Tensor] = {
col: torch.cat(tensors) for col, tensors in all_targets.items() if tensors
}
self.sequence_ids = torch.cat(all_sequence_ids)
self.subsequence_ids = torch.cat(all_subsequence_ids)
self.start_item_positions = torch.cat(all_starting_positions)
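        # Move the large feature and target tensors into shared memory so
        # that DataLoader workers (num_workers > 0) can read them without
        # copying the data into each worker process.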
for tensor in self.sequences.values():
tensor.share_memory_()
for tensor in self.targets.values():
tensor.share_memory_()
print(f"[INFO] Dataset loaded with {self.n_samples} samples.")
# Verify that the number of loaded samples matches the metadata
first_key = next(iter(self.sequences.keys()))
if self.sequences[first_key].shape[0] != self.n_samples:
raise ValueError(
f"Mismatch in sample count! Metadata: {self.n_samples}, Loaded: {self.sequences[first_key].shape[0]}"
)
    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return self.n_samples
def __getitem__(
self, idx: int
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], int, int, int]:
"""Retrieves a single sample from the pre-loaded data.
Args:
idx: The index of the sample to retrieve.
Returns:
A tuple containing:
- sequence (dict): Dictionary of feature tensors for the sample.
- targets (dict): Dictionary of target tensors for the sample.
- sequence_id (int): The sequence ID of the sample.
- subsequence_id (int): The subsequence ID within the sequence.
- start_position (int): The starting item position of the subsequence
within the original full sequence.
"""
if not 0 <= idx < self.n_samples:
raise IndexError(
f"Index {idx} is out of range for a dataset with {self.n_samples} samples."
)
        # Each access is a constant-time row view into the pre-loaded tensors in RAM.
sequence = {key: tensor[idx] for key, tensor in self.sequences.items()}
targets = {key: tensor[idx] for key, tensor in self.targets.items()}
sequence_id = int(self.sequence_ids[idx].item())
subsequence_id = int(self.subsequence_ids[idx].item())
start_position = int(self.start_item_positions[idx].item())
return sequence, targets, sequence_id, subsequence_id, start_position
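

# Minimal usage sketch (illustrative only, not part of this module):
#
#     from torch.utils.data import DataLoader
#
#     dataset = SequifierDatasetFromFolder("data/train", config)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
#
# The default collate function stacks the per-sample dicts into dicts of
# batched tensors and collects the trailing ints into 1-D tensors.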