# Source code for sequifier.io.sequifier_dataset_from_folder_lazy

import bisect
import collections
import json
import os
from typing import Dict, Tuple

import psutil  # Dependency: pip install psutil
import torch
from torch.utils.data import Dataset

from sequifier.config.train_config import TrainModel
from sequifier.helpers import normalize_path


class SequifierDatasetFromFolderLazy(Dataset):
    """
    An efficient PyTorch Dataset for datasets that do not fit into RAM.

    This class loads data from individual .pt files on-demand (lazily) when an
    item is requested via `__getitem__`. It maintains an in-memory cache of
    recently used files to speed up access.

    To prevent memory exhaustion, the cache is managed by a Least Recently Used
    (LRU) policy, which evicts the oldest data chunks when the total system RAM
    usage exceeds a configurable threshold. The most recently used file is never
    evicted, so consecutive reads from the same file do not reload it from disk
    even while system RAM stays above the threshold (e.g. because of other
    processes).

    This strategy balances I/O overhead and memory usage, making it suitable for
    training on datasets larger than the available system memory.

    Note:
        The cache is a plain ``collections.OrderedDict`` with no locking, so an
        instance must not be shared between threads that call ``__getitem__``
        concurrently.
    """

    def __init__(self, data_path: str, config: TrainModel, ram_threshold: float = 70.0):
        """
        Initializes the dataset by reading metadata and setting up the cache.

        Each .pt file is expected to contain a tuple:
        (sequences_dict, targets_dict, sequence_ids_tensor,
        subsequence_ids_tensor, start_item_positions_tensor).

        Args:
            data_path (str): The path to the directory containing the
                pre-processed .pt files and a metadata.json file.
            config (TrainModel): The training configuration object.
            ram_threshold (float): The system RAM usage percentage (0-100) at
                which to trigger cache eviction.

        Raises:
            FileNotFoundError: If ``metadata.json`` is missing from the data
                directory.
        """
        self.data_dir = normalize_path(data_path, config.project_path)
        self.config = config
        self.ram_threshold = ram_threshold

        metadata_path = os.path.join(self.data_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(
                f"metadata.json not found in '{self.data_dir}'. "
                "Ensure data is pre-processed with write_format: pt."
            )

        with open(metadata_path, "r") as f:
            metadata = json.load(f)

        self.n_samples = metadata["total_samples"]
        self.batch_files_info = metadata["batch_files"]

        # --- Build an index for fast sample-to-file mapping ---
        # self.cumulative_samples stores the cumulative sample count at the end
        # of each file, e.g. [1024, 2048, 3072, ...], allowing a binary search.
        self.cumulative_samples = []
        current_sum = 0
        for file_info in self.batch_files_info:
            current_sum += file_info["samples"]
            self.cumulative_samples.append(current_sum)

        # --- Initialize the LRU cache ---
        # An OrderedDict keeps insertion/recency order: the first entry is the
        # least recently used, the last entry the most recently used.
        self.cache: collections.OrderedDict[
            str,
            Tuple[
                Dict[str, torch.Tensor],
                Dict[str, torch.Tensor],
                torch.Tensor,
                torch.Tensor,
                torch.Tensor,
            ],
        ] = collections.OrderedDict()

        print(
            f"[INFO] Initialized lazy dataset from {self.data_dir}. "
            f"Total samples: {self.n_samples}. RAM threshold: {self.ram_threshold}%"
        )

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return self.n_samples

    def _find_file_for_index(self, idx: int) -> Tuple[int, str]:
        """
        Finds which file contains the given sample index and the local index within it.

        Args:
            idx: The global sample index across all files.

        Returns:
            A tuple containing (local_index_in_file, file_path).

        Raises:
            IndexError: If ``idx`` lies beyond the samples listed in the
                metadata's ``batch_files`` (i.e. ``total_samples`` is larger
                than the sum of the per-file sample counts).
        """
        # bisect_right returns the first file whose cumulative count exceeds
        # idx, i.e. the file containing this sample (empty files are skipped).
        file_index = bisect.bisect_right(self.cumulative_samples, idx)
        if file_index >= len(self.batch_files_info):
            raise IndexError(
                f"Index {idx} is beyond the {current_total} samples listed in "
                f"metadata batch_files of '{self.data_dir}'; "
                "metadata total_samples is inconsistent with batch_files."
                if (current_total := (self.cumulative_samples[-1] if self.cumulative_samples else 0)) is not None
                else ""
            )

        # Subtract the cumulative count of all preceding files (0 for the first).
        previous_samples = (
            self.cumulative_samples[file_index - 1] if file_index > 0 else 0
        )
        local_index = idx - previous_samples
        file_path = os.path.join(
            self.data_dir, self.batch_files_info[file_index]["path"]
        )
        return local_index, file_path

    def _evict_lru_items(self) -> None:
        """
        Evicts least recently used cache entries while system RAM usage is
        above ``self.ram_threshold``.

        The most recently used entry (the last one in the OrderedDict, i.e. the
        file that was just accessed) is always kept. The samples returned from
        it are tensor views sharing its storage, so evicting it would free
        little memory but would force a disk reload on the next access to the
        same file.
        """
        while len(self.cache) > 1:
            ram_percent = psutil.virtual_memory().percent
            if ram_percent <= self.ram_threshold:
                break
            # popitem(last=False) removes the oldest (least recently used) entry.
            evicted_path, _ = self.cache.popitem(last=False)
            print(
                f"[INFO] RAM usage {ram_percent:.1f}% > {self.ram_threshold}%. "
                f"Evicting {os.path.basename(evicted_path)} from cache."
            )

    def __getitem__(
        self, idx: int
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], int, int, int]:
        """
        Retrieves a single data sample, loading from disk if not in the cache.

        This method is the core of the lazy-loading strategy and manages the
        LRU cache automatically. It is not thread-safe (the cache is not
        protected by a lock).

        Args:
            idx: The index of the sample to retrieve.

        Returns:
            A tuple containing:
                - sequence (dict): Dictionary of feature tensors for the sample.
                - targets (dict): Dictionary of target tensors for the sample.
                - sequence_id (int): The sequence ID of the sample.
                - subsequence_id (int): The subsequence ID within the sequence.
                - start_position (int): The starting item position of the
                  subsequence within the original full sequence.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < self.n_samples:
            raise IndexError(
                f"Index {idx} is out of range for dataset with {self.n_samples} samples."
            )

        local_index, file_path = self._find_file_for_index(idx)

        if file_path in self.cache:
            # Cache hit: mark as most recently used.
            self.cache.move_to_end(file_path)
            (
                sequences_batch,
                targets_batch,
                sequence_id_tensor,
                subsequence_id_tensor,
                start_item_positions_tensor,
            ) = self.cache[file_path]
        else:
            # Cache miss: load from disk. Unpacking enforces the 5-element layout.
            (
                sequences_batch,
                targets_batch,
                sequence_id_tensor,
                subsequence_id_tensor,
                start_item_positions_tensor,
            ) = torch.load(file_path, map_location="cpu")
            self.cache[file_path] = (
                sequences_batch,
                targets_batch,
                sequence_id_tensor,
                subsequence_id_tensor,
                start_item_positions_tensor,
            )
            # After adding, evict older entries if system RAM is too high.
            self._evict_lru_items()

        sequence = {key: tensor[local_index] for key, tensor in sequences_batch.items()}
        targets = {key: tensor[local_index] for key, tensor in targets_batch.items()}

        return (
            sequence,
            targets,
            int(sequence_id_tensor[local_index]),
            int(subsequence_id_tensor[local_index]),
            int(start_item_positions_tensor[local_index]),
        )