import bisect
import collections
import json
import os
from typing import Dict, Tuple
import psutil # Dependency: pip install psutil
import torch
from loguru import logger
from torch.utils.data import Dataset
from sequifier.config.train_config import TrainModel
from sequifier.helpers import normalize_path
class SequifierDatasetFromFolderLazy(Dataset):
"""
An efficient PyTorch Dataset for datasets that do not fit into RAM.
This class loads data from individual .pt files on-demand (lazily) when an
item is requested via `__getitem__`. It maintains an in-memory cache of
recently used files to speed up access. To prevent memory exhaustion,
    the cache is managed by a Least Recently Used (LRU) policy, which
    evicts the least recently used data chunks when the cache grows beyond a
    fixed number of files or memory usage approaches a configurable threshold.
This strategy balances I/O overhead and memory usage, making it suitable
for training on datasets larger than the available system memory.
"""
    def __init__(self, data_path: str, config: TrainModel):
"""
Initializes the dataset by reading metadata and setting up the cache.
Each .pt file is expected to contain a tuple:
(sequences_dict, targets_dict, sequence_ids_tensor, subsequence_ids_tensor, start_item_positions_tensor).
Args:
data_path (str): The path to the directory containing the pre-processed
.pt files and a metadata.json file.
config (TrainModel): The training configuration object.
"""
self.data_dir = normalize_path(data_path, config.project_root)
self.config = config
self.max_ram_gb = config.training_spec.max_ram_gb
self.max_ram_bytes = config.training_spec.max_ram_gb * (1024**3)
metadata_path = os.path.join(self.data_dir, "metadata.json")
if not os.path.exists(metadata_path):
raise FileNotFoundError(
f"metadata.json not found in '{self.data_dir}'. "
"Ensure data is pre-processed with write_format: pt."
)
with open(metadata_path, "r") as f:
metadata = json.load(f)
self.n_samples = metadata["total_samples"]
self.batch_files_info = metadata["batch_files"]
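        # Illustrative metadata.json shape (values are hypothetical; the keys
        # "total_samples", "batch_files", "path" and "samples" are the ones read
        # here and in _find_file_for_index):
        # {
        #     "total_samples": 3072,
        #     "batch_files": [
        #         {"path": "batch_000.pt", "samples": 1024},
        #         {"path": "batch_001.pt", "samples": 1024},
        #         {"path": "batch_002.pt", "samples": 1024}
        #     ]
        # }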
# --- Build an index for fast sample-to-file mapping ---
# self.cumulative_samples will store the cumulative sample count at the end
# of each file, e.g., [1024, 2048, 3072, ...], allowing for a fast binary search.
self.cumulative_samples = []
current_sum = 0
for file_info in self.batch_files_info:
current_sum += file_info["samples"]
self.cumulative_samples.append(current_sum)
        # --- Initialize the LRU cache ---
        # An OrderedDict preserves access order and is used to implement the LRU logic.
self.cache: collections.OrderedDict[
str,
Tuple[
Dict[str, torch.Tensor],
Dict[str, torch.Tensor],
torch.Tensor,
torch.Tensor,
torch.Tensor,
],
] = collections.OrderedDict()
self.max_cache_files = 2
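        # LRU sketch (illustrative): with max_cache_files = 2, accessing files
        # A, B, A, C fills the cache with {A, B}, moves A to the back on the
        # second access, and evicts B (the least recently used entry) when C
        # has to be loaded.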
logger.info(
f"[INFO] Initialized lazy dataset from {self.data_dir}. "
f"Total samples: {self.n_samples}. RAM threshold in GB: {self.max_ram_gb}, max cache files: {self.max_cache_files}"
)
    def __len__(self) -> int:
"""Returns the total number of samples in the dataset."""
return self.n_samples
def _find_file_for_index(self, idx: int) -> Tuple[int, str]:
"""
Finds which file contains the given sample index and the local index within it.
Args:
idx: The global sample index across all files.
Returns:
A tuple containing (local_index_in_file, file_path).
"""
# bisect_right finds the insertion point for idx, which corresponds to the
# index of the file containing this sample.
file_index = bisect.bisect_right(self.cumulative_samples, idx)
# Calculate the local index within the identified file.
# If it's the first file (index 0), the local index is just idx.
# Otherwise, subtract the cumulative sample count of the previous file.
previous_samples = (
self.cumulative_samples[file_index - 1] if file_index > 0 else 0
)
local_index = idx - previous_samples
file_path = os.path.join(
self.data_dir, self.batch_files_info[file_index]["path"]
)
return local_index, file_path
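    # Worked example for _find_file_for_index (values illustrative): with
    # cumulative_samples = [1024, 2048, 3072] and idx = 1500,
    # bisect.bisect_right([1024, 2048, 3072], 1500) == 1, so the sample lives in
    # the second file and local_index = 1500 - 1024 = 476. For idx = 1024, the
    # insertion point is also 1 and local_index = 0, i.e. the first sample of
    # that file.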
    def _get_memory_usage_percent(self) -> float:
        """Returns the in-use fraction (0.0-1.0) of the memory limit, preferring
        cgroup v2, then cgroup v1, then host-wide usage reported by psutil."""
        # Try cgroup v2
if os.path.exists("/sys/fs/cgroup/memory.current"):
with open("/sys/fs/cgroup/memory.current", "r") as f:
used = int(f.read())
with open("/sys/fs/cgroup/memory.max", "r") as f:
limit_str = f.read().strip()
# Handle 'max' which means no limit
limit = int(limit_str) if limit_str != "max" else self.max_ram_bytes
return used / limit if limit > 0 else 0
# Try cgroup v1
elif os.path.exists("/sys/fs/cgroup/memory/memory.usage_in_bytes"):
with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f:
used = int(f.read())
with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f:
limit = int(f.read())
return used / limit if limit > 0 else 0
# Fallback to psutil (host memory)
else:
return psutil.virtual_memory().percent / 100.0
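    # Example (illustrative): in a container limited to 8 GiB with 6 GiB in use,
    # the cgroup branches return 6 / 8 = 0.75; on a bare host, psutil's
    # virtual_memory().percent of 75.0 is likewise returned as the fraction 0.75.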
    def _evict_lru_items(self):
        """Evicts least recently used cache entries while memory usage exceeds
        90% of the limit (safety buffer)."""
while self._get_memory_usage_percent() > 0.90:
if not self.cache:
break
self.cache.popitem(last=False)
    def __getitem__(
self, idx: int
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], int, int, int]:
"""
Retrieves a single data sample, loading from disk if not in the cache.
        This method is the core of the lazy-loading strategy and manages the
        file cache automatically.
Args:
idx: The index of the sample to retrieve.
Returns:
A tuple containing:
- sequence (dict): Dictionary of feature tensors for the sample.
- targets (dict): Dictionary of target tensors for the sample.
- sequence_id (int): The sequence ID of the sample.
- subsequence_id (int): The subsequence ID within the sequence.
- start_position (int): The starting item position of the subsequence
within the original full sequence.
"""
if not 0 <= idx < self.n_samples:
raise IndexError(
f"Index {idx} is out of range for dataset with {self.n_samples} samples."
)
local_index, file_path = self._find_file_for_index(idx)
# 1. Check if file is already in cache
if file_path in self.cache:
self.cache.move_to_end(file_path)
data_tuple = self.cache[file_path]
else:
# 2. PRE-EVICTION: Make space BEFORE loading
# If adding this file would exceed the limit, evict the oldest one first.
while len(self.cache) >= self.max_cache_files:
                # remove the least recently used file (front of the OrderedDict)
evicted_path, _ = self.cache.popitem(last=False)
                # Optional: force garbage collection if memory is tight. Eviction
                # drops the cache's reference to the evicted chunk, and because
                # samples are returned as .clone()d copies, nothing else keeps it alive.
                # import gc; gc.collect()
# 3. Load from disk (Safe now, as we made space)
data_tuple = torch.load(file_path, map_location="cpu")
self.cache[file_path] = data_tuple
# Unpack
(
sequences_batch,
targets_batch,
sequence_id_tensor,
subsequence_id_tensor,
start_item_positions_tensor,
) = data_tuple
        # 4. CRITICAL: .clone() returns independent copies, so samples do not
        #    hold views into the cached tensors and evicted chunks can be freed.
train_seq_len = self.config.seq_length
sequence = {
key: tensor[local_index, -train_seq_len:].clone()
for key, tensor in sequences_batch.items()
}
targets = {
key: tensor[local_index, -train_seq_len:].clone()
for key, tensor in targets_batch.items()
}
return (
sequence,
targets,
int(sequence_id_tensor[local_index]),
int(subsequence_id_tensor[local_index]),
int(start_item_positions_tensor[local_index]),
)
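

# --- Illustrative usage sketch (not part of the class) ---
# A minimal sketch assuming a pre-processed directory that contains metadata.json
# and .pt chunk files, and a TrainModel config whose training_spec.max_ram_gb and
# seq_length are set; the path and DataLoader parameters below are hypothetical.
#
#     from torch.utils.data import DataLoader
#
#     dataset = SequifierDatasetFromFolderLazy("data/preprocessed/pt", config)
#     loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)
#     sequences, targets, seq_ids, subseq_ids, start_positions = next(iter(loader))
#     # `sequences` and `targets` are dicts of tensors batched along dim 0.
#
# Note: with num_workers > 0, each worker process holds its own copy of the file
# cache, so the effective memory footprint scales with the number of workers.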