Source code for sequifier.samplers.distributed_grouped_random_sampler
from logging import Logger
from typing import Iterator, Optional, Union
import numpy as np
import torch
from beartype import beartype
from torch.utils.data import Sampler
from sequifier.io.sequifier_dataset_from_folder import SequifierDatasetFromFolder
from sequifier.io.sequifier_dataset_from_folder_lazy import (
SequifierDatasetFromFolderLazy,
)
@beartype
def get_final_indices(
    files_for_this_rank: list[int],
    index_groups: list[list[int]],
    shuffle: bool,
    generator: Optional[torch.Generator],
) -> list[int]:
    """Flatten the sample indices of the given files into a single list.

    Files are visited in the order given by ``files_for_this_rank``; the
    indices of one file always stay contiguous in the result. When ``shuffle``
    is true the indices inside each file are permuted with ``generator``.

    Args:
        files_for_this_rank: Positions into ``index_groups`` to emit, in order.
        index_groups: For every file, the list of global sample indices it holds.
        shuffle: Whether to permute the indices within each file.
        generator: RNG used for the per-file permutations. Required when
            ``shuffle`` is true, ignored otherwise.

    Returns:
        A new list of sample indices. ``index_groups`` is never modified.

    Raises:
        ValueError: If ``shuffle`` is true but no ``generator`` was supplied.
    """
    # Explicit check instead of `assert`: asserts vanish under `python -O`, and
    # torch.randperm would then silently fall back to the global RNG.
    if shuffle and generator is None:
        raise ValueError("A torch.Generator is required when shuffle=True.")

    final_indices: list[int] = []
    for file_idx in files_for_this_rank:
        group = index_groups[file_idx]
        if shuffle:
            perm = torch.randperm(len(group), generator=generator).tolist()
            # Reading through the permutation builds new elements in the
            # output list; the source group is left untouched.
            final_indices.extend(group[i] for i in perm)
        else:
            final_indices.extend(group)
    return final_indices
class DistributedGroupedRandomSampler(Sampler[int]):
    """
    A distributed sampler that groups samples by file to improve cache efficiency.

    This sampler partitions the set of data FILES across the distributed processes,
    not the individual samples. Each process then iterates through its assigned
    files in a random order. Within each file, the samples are also shuffled.
    This ensures that each process sees a unique subset of the data per epoch
    while maximizing sequential reads from the same file, which is ideal for
    lazy-loading datasets.

    Supported values of ``sampling_strategy``:

    * ``"exact"``: every rank yields exactly the samples of its own files; the
      number of files must be divisible by ``num_replicas``.
    * ``"oversampling"``: every rank yields as many samples as the largest
      rank, padding with samples from additional randomly drawn files.
    * ``"undersampling"``: every rank yields as many samples as the smallest
      rank, truncating the remainder.
    """

    _SAMPLING_STRATEGIES = ("exact", "oversampling", "undersampling")

    def __init__(
        self,
        # Forward-reference string: not evaluated when the class is defined.
        data_source: "Union[SequifierDatasetFromFolder, SequifierDatasetFromFolderLazy]",
        num_replicas: int,
        rank: int,
        logger: Logger,
        shuffle: bool = True,
        seed: int = 0,
        sampling_strategy: str = "exact",
    ):
        """
        Args:
            data_source: The dataset to sample from. Must have a `batch_files_info`
                attribute whose entries provide a ``"samples"`` count per file.
            num_replicas: Number of processes participating in distributed training.
            rank: Rank of the current process.
            logger: Logger that receives the per-iteration file assignment.
            shuffle: If True, shuffles the order of files and samples within files.
            seed: Random seed used to create the permutation.
            sampling_strategy: How to distribute data between GPUs; one of
                ``"exact"``, ``"oversampling"`` or ``"undersampling"``.

        Raises:
            ValueError: If ``sampling_strategy`` is unknown, if ``rank`` is not in
                ``[0, num_replicas)``, if ``"exact"`` is used with a file count
                that is not divisible by ``num_replicas``, or if
                ``"oversampling"`` is used without any input file.
            Exception: If this rank received no file and the strategy is not
                ``"oversampling"``.
        """
        super().__init__(None)
        # Fail fast: an unknown strategy would otherwise leave `num_samples`
        # unset and only surface later as an AttributeError in `__len__`.
        if sampling_strategy not in self._SAMPLING_STRATEGIES:
            raise ValueError(
                f"Unknown sampling_strategy '{sampling_strategy}'. "
                f"Expected one of {self._SAMPLING_STRATEGIES}."
            )
        if num_replicas < 1 or not 0 <= rank < num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval "
                f"[0, {num_replicas - 1}]"
            )

        self.data_source = data_source
        self.num_replicas = num_replicas
        self.rank = rank
        self.logger = logger
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed
        self.sampling_strategy = sampling_strategy

        # Pre-compute the global sample indices held by each file. Files are
        # laid out back to back in the global index space.
        self.index_groups: list[list[int]] = []
        start_index = 0
        for file_info in self.data_source.batch_files_info:
            end_index = start_index + file_info["samples"]
            self.index_groups.append(list(range(start_index, end_index)))
            start_index = end_index

        # Round-robin assignment of files to ranks.
        self.num_files = len(self.index_groups)
        self.files_for_this_rank = list(
            range(self.rank, self.num_files, self.num_replicas)
        )
        if not self.files_for_this_rank:
            self.files_for_this_rank = self._fallback_files()

        if self.sampling_strategy == "exact":
            if self.num_files % self.num_replicas != 0:
                raise ValueError(
                    f"Number of input files ({self.num_files}) must be divisible by "
                    f"world_size ({self.num_replicas}) when using 'exact' sampling strategy."
                )
            self.num_samples = sum(
                len(self.index_groups[i]) for i in self.files_for_this_rank
            )
        elif self.sampling_strategy == "oversampling":
            self.num_samples = max(self._samples_per_rank())
        else:  # "undersampling"
            self.num_samples = min(self._samples_per_rank())

    def _fallback_files(self) -> list[int]:
        """Pick files for a rank that received none from the round-robin split."""
        if self.sampling_strategy != "oversampling":
            raise Exception(
                f"No file found for GPU rank {self.rank}. Total number of files found: {self.num_files}. Please adapt your data or use 'oversampling'"
            )
        if self.num_files == 0:
            raise ValueError(
                f"No input files found; cannot oversample for GPU rank {self.rank}."
            )
        # Seeded from (seed, rank) so the choice is reproducible across runs
        # instead of depending on the global NumPy RNG state.
        fallback_rng = torch.Generator()
        fallback_rng.manual_seed(self.seed + self.rank)
        viable_rank = int(
            torch.randint(0, self.num_files, (1,), generator=fallback_rng).item()
        )
        return list(range(viable_rank, self.num_files, self.num_replicas))

    def _samples_per_rank(self) -> list[int]:
        """Return the sample count each rank gets from the round-robin split."""
        file_sizes = [len(group) for group in self.index_groups]
        return [
            sum(file_sizes[r :: self.num_replicas]) for r in range(self.num_replicas)
        ]

    def __iter__(self) -> Iterator[int]:
        """
        Returns an iterator over indices for the current rank.
        """
        # 1. Deterministic generator for this epoch. It is created even when
        #    shuffle is off so that oversampling draws stay reproducible.
        generator = torch.Generator()
        generator.manual_seed(self.seed + self.epoch)

        # 2. Order this rank's files (shuffled or as assigned).
        if self.shuffle:
            file_order = torch.randperm(
                len(self.files_for_this_rank), generator=generator
            ).tolist()
            files_for_this_rank = [self.files_for_this_rank[i] for i in file_order]
        else:
            files_for_this_rank = list(self.files_for_this_rank)

        final_indices = get_final_indices(
            files_for_this_rank, self.index_groups, self.shuffle, generator
        )

        if self.sampling_strategy == "oversampling":
            # 3a. Pad with samples from randomly drawn files until this rank
            #     yields as many samples as the largest rank.
            while len(final_indices) < self.num_samples:
                additional_file = int(
                    torch.randint(0, self.num_files, (1,), generator=generator).item()
                )
                additional_indices = get_final_indices(
                    [additional_file],
                    self.index_groups,
                    self.shuffle,
                    generator,
                )
                missing = self.num_samples - len(final_indices)
                final_indices.extend(additional_indices[:missing])
                # Tracked only so the log line below shows every file touched.
                files_for_this_rank.append(additional_file)
        elif self.sampling_strategy == "undersampling":
            # 3b. Truncate to the sample count of the smallest rank.
            final_indices = final_indices[: self.num_samples]

        self.logger.info(f"Files for rank {self.rank}: {files_for_this_rank}")
        return iter(final_indices)

    def __len__(self) -> int:
        """
        Returns the number of samples for the current rank, not the total.
        """
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """
        Sets the epoch for this sampler. This is used to create a different
        shuffling order for each epoch.
        """
        self.epoch = epoch