import os
from typing import Optional, Union
import numpy as np
import polars as pl
import torch
from beartype import beartype
from torch import Tensor
# Mapping from pandas/Polars dtype names to torch dtypes. Note that 64-bit
# float columns are narrowed to 32-bit torch floats.
PANDAS_TO_TORCH_TYPES = {
    "Int64": torch.int64,
    "Float64": torch.float32,
    "int64": torch.int64,
    "float64": torch.float32,
}
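# Example lookup (a sketch; "price" is an illustrative column name):
#   column_types = {"price": PANDAS_TO_TORCH_TYPES["Float64"]}  # -> torch.float32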
@beartype
def construct_index_maps(
id_maps: Optional[dict[str, dict[Union[str, int], int]]],
target_columns_index_map: list[str],
map_to_id: Optional[bool],
) -> dict[str, dict[int, Union[str, int]]]:
"""Constructs reverse index maps (int index to original ID).
This function creates reverse mappings from the integer indices back to
the original string or integer identifiers. It only performs this
operation if `map_to_id` is True and `id_maps` is provided.
A special mapping for index 0 is added:
- If original IDs are strings, 0 maps to "unknown".
- If original IDs are integers, 0 maps to (minimum original ID) - 1.
Args:
id_maps: A nested dictionary mapping column names to their
respective ID-to-index maps (e.g.,
`{'col_name': {'original_id': 1, ...}}`). Expected to be
provided if `map_to_id` is True.
target_columns_index_map: A list of column names for which to
construct the reverse maps.
map_to_id: A boolean flag. If True, the reverse maps are
constructed. If False or None, an empty dictionary is returned.
Returns:
A dictionary where keys are column names from
`target_columns_index_map` and values are the reverse maps
(index-to-original-ID). Returns an empty dict if `map_to_id`
is not True.
Raises:
AssertionError: If `map_to_id` is True but `id_maps` is None.
AssertionError: If the values of a map are not consistently
string or integer (excluding the added '0' key).
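
    Example:
        A minimal doctest-style sketch (the column name "user" and its
        IDs are illustrative, not from the library):

        >>> id_maps = {"user": {"alice": 1, "bob": 2}}
        >>> construct_index_maps(id_maps, ["user"], map_to_id=True)
        {'user': {1: 'alice', 2: 'bob', 0: 'unknown'}}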
"""
index_map = {}
    if map_to_id:  # None and False both mean "do not construct maps"
        assert id_maps is not None, "id_maps is required when map_to_id is True"
for target_column in target_columns_index_map:
map_ = {v: k for k, v in id_maps[target_column].items()}
            # Sample one value to decide whether the original IDs are
            # strings or integers.
            val = next(iter(map_.values()))
            if isinstance(val, str):
                map_[0] = "unknown"
            else:
                assert isinstance(val, int)
                # Index 0 is reserved for "unknown": map it one below the
                # smallest original ID.
                map_[0] = min(map_.values()) - 1  # type: ignore
index_map[target_column] = map_
return index_map
@beartype
def read_data(
path: str, read_format: str, columns: Optional[list[str]] = None
) -> pl.DataFrame:
"""Reads data from a CSV or Parquet file into a Polars DataFrame.
Args:
path: The file path to read from.
read_format: The format of the file. Supported formats are
"csv" and "parquet".
columns: An optional list of column names to read. This argument
is only used when `read_format` is "parquet".
Returns:
A Polars DataFrame containing the data from the file.
Raises:
ValueError: If `read_format` is not "csv" or "parquet".
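
    Example:
        A minimal sketch (the file names are hypothetical):

        >>> df = read_data("events.parquet", "parquet", columns=["user"])
        >>> df = read_data("events.csv", "csv")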
"""
if read_format == "csv":
return pl.read_csv(path, separator=",")
if read_format == "parquet":
return pl.read_parquet(path, columns=columns)
raise ValueError(f"Unsupported read format: {read_format}")
@beartype
def write_data(data: pl.DataFrame, path: str, write_format: str, **kwargs) -> None:
"""Writes a Polars (or Pandas) DataFrame to a CSV or Parquet file.
This function detects the type of the input DataFrame.
- For Polars DataFrames, it uses `.write_csv()` or `.write_parquet()`.
- For other DataFrame types (presumably Pandas), it uses `.to_csv()`
or `.to_parquet()`.
Note: The type hint specifies `pl.DataFrame`, but the implementation
includes a fallback path that suggests compatibility with Pandas
DataFrames.
Args:
data: The Polars (or Pandas) DataFrame to write.
path: The destination file path.
write_format: The format to write. Supported formats are
"csv" and "parquet".
**kwargs: Additional keyword arguments passed to the underlying
write function (e.g., `write_csv` for Polars, `to_csv` for
Pandas).
Returns:
None.
Raises:
ValueError: If `write_format` is not "csv" or "parquet".
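
    Example:
        A minimal sketch (writes to the working directory):

        >>> df = pl.DataFrame({"a": [1, 2]})
        >>> write_data(df, "out.csv", "csv")
        >>> write_data(df, "out.parquet", "parquet")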
"""
    if isinstance(data, pl.DataFrame):
        if write_format == "csv":
            data.write_csv(path, **kwargs)
        elif write_format == "parquet":
            data.write_parquet(path, **kwargs)
        else:
            raise ValueError(
                f"Unsupported write format for Polars DataFrame: {write_format}"
            )
        return
    # Fallback for Pandas-like DataFrames (see the note in the docstring).
    if write_format == "csv":
        data.to_csv(path, sep=",", index=False, **kwargs)
    elif write_format == "parquet":
        data.to_parquet(path, **kwargs)
    else:
        raise ValueError(f"Unsupported write format: {write_format}")
@beartype
def subset_to_selected_columns(
data: Union[pl.DataFrame, pl.LazyFrame], selected_columns: list[str]
) -> Union[pl.DataFrame, pl.LazyFrame]:
"""Filters a DataFrame to rows where 'inputCol' is in a selected list.
This function supports both Polars (DataFrame, LazyFrame) and Pandas
DataFrames, dispatching to the appropriate filtering method.
- For Polars objects, it uses `data.filter(pl.col("inputCol").is_in(...))`.
- For other objects (presumably Pandas), it builds a numpy boolean
mask and filters using `data.loc[...]`.
Note: The type hint only specifies Polars objects, but the
implementation includes a fallback path for Pandas-like objects.
Args:
data: The Polars (DataFrame, LazyFrame) or Pandas DataFrame to
filter. It must contain a column named "inputCol".
selected_columns: A list of values. Rows will be kept if their
value in "inputCol" is present in this list.
Returns:
A filtered DataFrame or LazyFrame of the same type as the input.
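
    Example:
        A minimal sketch with a toy long-format frame:

        >>> df = pl.DataFrame({"inputCol": ["price", "volume"],
        ...                    "0": [1.0, 2.0]})
        >>> subset_to_selected_columns(df, ["price"]).shape
        (1, 2)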
"""
    if isinstance(data, (pl.DataFrame, pl.LazyFrame)):
        return data.filter(pl.col("inputCol").is_in(selected_columns))
    # Fallback for Pandas-like DataFrames: OR together one boolean mask
    # per selected value, then filter with .loc.
    column_filters = [
        (data["inputCol"].values == input_col) for input_col in selected_columns
    ]
    filter_ = np.logical_or.reduce(column_filters)
    return data.loc[filter_, :]
@beartype
def numpy_to_pytorch(
    data: pl.DataFrame,
    column_types: dict[str, torch.dtype],
    all_columns: list[str],
    seq_length: int,
) -> dict[str, Tensor]:
"""Converts a long-format Polars DataFrame to a dict of sequence tensors.
This function assumes the input DataFrame `data` is in a long format
where each row represents a sequence for a specific feature. It expects
a column named "inputCol" that contains the feature name (e.g.,
'price', 'volume') and other columns representing time steps (e.g.,
"0", "1", ..., "L").
It generates two tensors for each column in `all_columns`:
1. An "input" tensor (from time steps L down to 1).
2. A "target" tensor (from time steps L-1 down to 0).
Example:
For `seq_length = 3` and `all_columns = ['price']`, it will create:
- 'price': Tensor from columns ["3", "2", "1"]
- 'price_target': Tensor from columns ["2", "1", "0"]
Args:
data: The long-format Polars DataFrame. Must contain "inputCol"
and columns named as strings of integers for time steps.
column_types: A dictionary mapping feature names (from "inputCol")
to their desired `torch.dtype`.
all_columns: A list of all feature names (from "inputCol") to
be processed and converted into tensors.
seq_length: The total sequence length (L). This determines the
column names for time steps (e.g., "0" to "L").
Returns:
A dictionary mapping feature names to their corresponding PyTorch
tensors. Target tensors are stored with a `_target` suffix
(e.g., `{'price': <tensor>, 'price_target': <tensor>}`).
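
        A runnable doctest-style sketch with `seq_length = 2` (values are
        illustrative):

        >>> df = pl.DataFrame({"inputCol": ["price"],
        ...                    "0": [1.0], "1": [2.0], "2": [3.0]})
        >>> out = numpy_to_pytorch(df, {"price": torch.float32},
        ...                        ["price"], seq_length=2)
        >>> out["price"]
        tensor([[3., 2.]])
        >>> out["price_target"]
        tensor([[2., 1.]])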
"""
    # Input sequences run from time step L down to 1; target sequences are
    # shifted one step earlier, from L-1 down to 0.
    input_seq_cols = [str(c) for c in range(seq_length, 0, -1)]
    target_seq_cols = [str(c) for c in range(seq_length - 1, -1, -1)]
    unified_tensors = {}
    for col_name in all_columns:
        # Filter once per feature, then slice out both column windows.
        feature_rows = data.filter(pl.col("inputCol") == col_name)
        unified_tensors[col_name] = torch.tensor(
            feature_rows.select(input_seq_cols).to_numpy(),
            dtype=column_types[col_name],
        )
        # The target sequence is stored under a "_target" suffix.
        unified_tensors[f"{col_name}_target"] = torch.tensor(
            feature_rows.select(target_seq_cols).to_numpy(),
            dtype=column_types[col_name],
        )
    return unified_tensors
class LogFile:
"""Manages logging to multiple files based on verbosity levels.
This class opens multiple log files based on a path template and a
hardcoded list of levels (2 and 3). Messages are written to files
based on their assigned level, and high-level messages are also
printed to the console on the main process (rank 0).
Attributes:
rank (Optional[int]): The rank of the current process, used to
control console output.
levels (list[int]): The hardcoded list of log levels [2, 3]
for which files are created.
_files (dict[int, io.TextIOWrapper]): A dictionary mapping log
levels to their open file handlers.
_path (str): The original path template provided.
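
    Example:
        A minimal construction sketch (the template name is illustrative):

        >>> log = LogFile("run_log_[NUMBER].txt", rank=0)
        >>> sorted(log._files.values())
        ['run_log_2.txt', 'run_log_3.txt']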
"""
@beartype
def __init__(self, path: str, rank: Optional[int] = None):
"""Initializes the LogFile and opens log files.
The `path` argument should be a template containing "[NUMBER]",
which will be replaced by the log levels (2 and 3) to create
separate log files.
Args:
path: The path template for the log files (e.g.,
"run_log_[NUMBER].txt").
rank: The rank of the current process (e.g., in distributed
training). If None or 0, high-level messages will be
printed to stdout.
"""
self.rank = rank
self.levels = [2, 3]
self._files = {
level: path.replace("[NUMBER]", str(level)) for level in self.levels
}
self._path = path
@beartype
def write(self, string: str, level: int = 3) -> None:
"""Writes a string to log files and potentially the console.
The string is written to all log files whose level is less than
or equal to the specified `level`.
- A message with `level=2` goes to file 2.
- A message with `level=3` goes to file 2 and file 3.
If `level` is 3 or greater, the message is also printed to stdout
if `self.rank` is None or 0.
Args:
string: The message to log.
level: The verbosity level of the message. Defaults to 3.
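
        Example:
            A minimal sketch (appends to the files derived from the path
            template; the level-3 message is also echoed to stdout):

            >>> log = LogFile("run_log_[NUMBER].txt", rank=0)
            >>> log.write("epoch 1 done", level=3)
            epoch 1 done
            >>> log.write("debug detail", level=2)  # file 2 only, no echo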
"""
        for file_level in self.levels:
            if file_level <= level:
                with open(self._files[file_level], "a+", encoding="utf-8") as f:
                    f.write(f"{string}\n")
        if level >= 3 and (self.rank is None or self.rank == 0):
            print(string)
@beartype
def normalize_path(path: str, project_path: str) -> str:
"""Normalizes a path to be relative to a project path, then joins them.
This function ensures that a given `path` is correctly expressed as
an absolute path rooted at `project_path`. It does this by first
removing the `project_path` prefix from `path` (if it exists)
and then joining the result back to `project_path`.
This is useful for handling paths that might be provided as either
relative (e.g., "data/file.txt") or absolute
(e.g., "/abs/path/to/project/data/file.txt").
Args:
path: The path to normalize.
project_path: The absolute path to the project's root directory.
Returns:
A normalized, absolute path.
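
    Example:
        A POSIX sketch (the paths are illustrative):

        >>> normalize_path("data/file.txt", "/abs/project")
        '/abs/project/data/file.txt'
        >>> normalize_path("/abs/project/data/file.txt", "/abs/project")
        '/abs/project/data/file.txt'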
"""
    # Ensure the project prefix ends with exactly one separator, strip it
    # from `path` if present, then re-root the remainder at `project_path`.
    project_path_normalized = (project_path + os.sep).replace(os.sep + os.sep, os.sep)
    return os.path.join(project_path, path.replace(project_path_normalized, ""))