# Source code for sequifier.samplers.distributed_grouped_random_sampler

import random
from typing import Iterator, Union

import torch
from torch.utils.data import Sampler

from sequifier.io.sequifier_dataset_from_folder import SequifierDatasetFromFolder
from sequifier.io.sequifier_dataset_from_folder_lazy import (
    SequifierDatasetFromFolderLazy,
)


class DistributedGroupedRandomSampler(Sampler[int]):
    """A distributed sampler that groups samples by file to improve cache efficiency.

    This sampler partitions the set of data FILES across the distributed
    processes, not the individual samples. Each process then iterates through
    its assigned files in a random order. Within each file, the samples are
    also shuffled. Every index yielded for one file is emitted before the
    first index of the next file, which is the access pattern lazy-loading
    datasets benefit from.

    All randomness (file order and within-file order) is derived from
    ``seed + epoch``, so iterating twice in the same epoch yields the same
    sequence and ``len(sampler)`` always matches what ``iter(sampler)`` yields.

    Note:
        Files are dealt out by striding (``order[rank::num_replicas]``), so
        ranks can receive different numbers of files and samples when the file
        count is not divisible by ``num_replicas`` or when file sizes differ.
        No padding or truncation is applied here.
        NOTE(review): callers that need identical per-rank lengths must handle
        that themselves — confirm against the training loop.
    """

    def __init__(
        self,
        data_source: Union[
            "SequifierDatasetFromFolder", "SequifierDatasetFromFolderLazy"
        ],
        num_replicas: int,
        rank: int,
        shuffle: bool = True,
        seed: int = 0,
    ):
        """
        Args:
            data_source: The dataset to sample from. Must have a
                `batch_files_info` attribute: an iterable of mappings, each
                with a ``"samples"`` entry giving that file's sample count.
            num_replicas: Number of processes participating in distributed
                training.
            rank: Rank of the current process.
            shuffle: If True, shuffles the order of files and samples within
                files.
            seed: Random seed used to create the permutation.

        Raises:
            ValueError: If ``num_replicas`` is not positive or ``rank`` is
                outside ``[0, num_replicas)``.
        """
        if num_replicas <= 0:
            raise ValueError(
                f"num_replicas must be a positive integer, got {num_replicas}"
            )
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval "
                f"[0, {num_replicas - 1}]"
            )

        try:
            # The base class does not use `data_source`; newer PyTorch
            # releases warn when it is passed.
            super().__init__()
        except TypeError:
            # Presumably an older PyTorch whose Sampler.__init__ still requires
            # the positional argument — fall back to the original call.
            super().__init__(data_source)

        self.data_source = data_source
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed

        # Pre-compute the global indices for each file. Files occupy
        # consecutive, non-overlapping index ranges in `batch_files_info` order.
        self.index_groups = []
        start_index = 0
        for file_info in self.data_source.batch_files_info:
            num_samples_in_file = file_info["samples"]
            self.index_groups.append(
                list(range(start_index, start_index + num_samples_in_file))
            )
            start_index += num_samples_in_file

        self.num_files = len(self.index_groups)

        # Populates `files_for_this_rank` and `num_samples` for epoch 0.
        self._sync_assignment()

    def _epoch_generator(self) -> torch.Generator:
        """Return a fresh generator seeded for the current epoch."""
        generator = torch.Generator()
        generator.manual_seed(self.seed + self.epoch)
        return generator

    def _rank_files(self, generator=None) -> list:
        """Return the file indices assigned to this rank for the current epoch.

        When shuffling, the permutation of ALL files is the first draw from a
        generator seeded with ``seed + epoch``, so every rank computes the same
        global order and the strided slices never overlap.
        """
        if self.shuffle:
            if generator is None:
                generator = self._epoch_generator()
            all_files_order = torch.randperm(
                self.num_files, generator=generator
            ).tolist()
        else:
            all_files_order = list(range(self.num_files))
        return all_files_order[self.rank :: self.num_replicas]

    def _sync_assignment(self, generator=None) -> None:
        """Refresh `files_for_this_rank` and `num_samples` for the current epoch."""
        self.files_for_this_rank = self._rank_files(generator)
        self.num_samples = sum(
            len(self.index_groups[i]) for i in self.files_for_this_rank
        )

    def __iter__(self) -> Iterator[int]:
        """Returns an iterator over indices for the current rank."""
        generator = self._epoch_generator()

        # 1-2. Deterministically order ALL files, then take this rank's
        # non-overlapping strided subset.
        self._sync_assignment(generator)

        # 3. Create the final list of indices for this rank.
        final_indices = []
        for file_idx in self.files_for_this_rank:
            group = self.index_groups[file_idx]
            if self.shuffle:
                # Permute positions with the seeded generator instead of
                # shuffling `group` in place: the stored groups stay untouched
                # and the order is reproducible for a given seed and epoch.
                order = torch.randperm(len(group), generator=generator).tolist()
                final_indices.extend(group[pos] for pos in order)
            else:
                final_indices.extend(group)

        return iter(final_indices)

    def __len__(self) -> int:
        """Returns the number of samples for the current rank, not the total.

        The count is recomputed from the current epoch's file assignment, so it
        always equals the number of indices `__iter__` yields.
        """
        self._sync_assignment()
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        This is used to create a different shuffling order for each epoch. The
        per-rank file assignment and sample count are refreshed to match.
        """
        self.epoch = epoch
        self._sync_assignment()