import ctypes
import json
import os
import signal
import subprocess
import sys
import time
import warnings
from typing import Union
import optuna
import torch._dynamo
import yaml
from beartype import beartype
# Make TorchDynamo (torch.compile) suppress compilation errors instead of
# raising them (per the PyTorch docs for this flag, execution falls back to
# eager mode). It is set before the sequifier imports below, which is why
# those imports carry `# noqa: E402` (module-level import not at top of file).
torch._dynamo.config.suppress_errors = True
from sequifier.config.hyperparameter_search_config import ( # noqa: E402
load_hyperparameter_search_config,
)
from sequifier.helpers import ( # noqa: E402
get_best_model_path,
get_last_training_batch_timedelta,
)
from sequifier.io.yaml import TrainModelDumper # noqa: E402
def set_pdeathsig(sig: int = signal.SIGTERM) -> None:
    """Bind the calling process's lifetime to its parent via Linux ``prctl``.

    Intended to be passed as ``preexec_fn`` to :class:`subprocess.Popen`, so it
    runs in the forked child just before ``exec``. It calls
    ``prctl(PR_SET_PDEATHSIG, sig)`` so the kernel delivers ``sig`` to the
    child when the orchestrating parent dies, instead of leaving the child
    orphaned. On non-Linux platforms this is a no-op.

    Args:
        sig: Signal delivered on parent death. Defaults to ``SIGTERM``.

    Raises:
        OSError: If ``prctl`` reports failure (e.g. an invalid signal number).
    """
    if not sys.platform.startswith("linux"):
        return
    pr_set_pdeathsig = 1  # PR_SET_PDEATHSIG from <linux/prctl.h>
    # dlopen(NULL) exposes the symbols already loaded into the interpreter,
    # including libc's prctl, without hard-coding the glibc soname
    # "libc.so.6" (absent on musl-based systems such as Alpine).
    libc = ctypes.CDLL(None, use_errno=True)
    # prctl is variadic and the kernel reads arg2 as an unsigned long, so pass
    # a full-width value rather than relying on ctypes' default C int.
    if libc.prctl(pr_set_pdeathsig, ctypes.c_ulong(int(sig))) != 0:
        err = ctypes.get_errno()
        raise OSError(err, os.strerror(err))
def _remove_if_exists(path: str) -> None:
    """Delete ``path`` if it exists; a missing file is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _prune_grace_period(run_name: str, project_root: str) -> float:
    """Return how long to wait for a cooperatively pruned trainer to exit.

    Twice the value reported by ``get_last_training_batch_timedelta`` plus 30,
    or 60.0 when that value is unavailable.
    """
    try:
        # NOTE(review): assumed to be the last batch duration in seconds, since
        # it is used in plain numeric arithmetic — confirm in sequifier.helpers.
        batch_duration = get_last_training_batch_timedelta(run_name, 0, project_root)
    except (ValueError, FileNotFoundError):
        return 60.0  # Safe default fallback
    return (batch_duration * 2) + 30


def _evaluate_best_checkpoint(
    config, run_name: str, last_epoch
) -> Union[float, tuple[float, ...]]:
    """Run the user evaluation script on the best checkpoint and collect metrics.

    The script is invoked as ``python <evaluation_script> <run_name>-best-<epoch>``
    from ``config.project_root`` and must write
    ``outputs/evaluations/<run_name>-best-<epoch>.json``.

    Returns:
        A single float when one metric is configured, otherwise a tuple of
        floats in the order of ``config.evaluation_metrics``.

    Raises:
        RuntimeError: If the evaluation script exits non-zero.
        FileNotFoundError: If the expected JSON output is missing.
        ValueError: If a configured metric is missing from the JSON output.
    """
    eval_name = f"{run_name}-best-{last_epoch}"
    eval_process = subprocess.run(
        [sys.executable, config.evaluation_script, eval_name],
        capture_output=True,
        text=True,
        cwd=config.project_root,
    )
    if eval_process.returncode != 0:
        raise RuntimeError(
            f"Evaluation script failed (exit code {eval_process.returncode}):\n{eval_process.stderr}"
        )

    eval_json_path = os.path.join(
        config.project_root, "outputs", "evaluations", f"{eval_name}.json"
    )
    if not os.path.exists(eval_json_path):
        raise FileNotFoundError(
            f"Evaluation JSON not found at expected path: {eval_json_path}"
        )
    with open(eval_json_path, "r", encoding="utf-8") as f:
        eval_results = json.load(f)

    reported = set(eval_results)
    requested = set(config.evaluation_metrics)
    missing_metrics = requested - reported
    if missing_metrics:
        raise ValueError(
            f"Some of the configured evaluation metrics are not in the script output: {missing_metrics}"
        )
    excess_metrics = reported - requested
    if excess_metrics:
        warnings.warn(
            f"Some metrics output by the script are not used in hyperparameter optimization: {excess_metrics}"
        )

    # Every configured metric is guaranteed present by the check above.
    metrics = tuple(float(eval_results[metric]) for metric in config.evaluation_metrics)
    return metrics[0] if len(metrics) == 1 else metrics


def objective(trial: optuna.Trial, config) -> Union[float, tuple[float, ...]]:
    """Run one hyperparameter-search trial as a ``sequifier train`` subprocess.

    Steps:

    1. Sample a training configuration for ``trial`` and write it to
       ``<project_root>/<model_config_write_path>/<run_name>.yaml``.
    2. Launch ``sequifier train`` on that file. On Linux the child is bound to
       this process via :func:`set_pdeathsig`.
    3. While training runs, poll ``logs/sequifier-<run_name>-metrics.jsonl``
       every 2 seconds. For single-objective studies, report each ``val_loss``
       at its ``global_step`` to Optuna. When pruning is enabled and Optuna
       asks for it, request a cooperative stop by creating
       ``logs/sequifier-<run_name>.prune``, and kill the trainer if it does not
       exit within a grace period.
    4. Optionally run ``sequifier infer`` and a user evaluation script on the
       best checkpoint.

    Args:
        trial (optuna.Trial): The Optuna trial managing the current
            hyperparameter combination.
        config (HyperparameterSearchConfig): The parsed hyperparameter search
            configuration.

    Returns:
        The configured evaluation metric(s) when an evaluation script and
        metrics are configured: a float, or a tuple of floats for
        multi-objective studies. Otherwise, the lowest ``val_loss`` read from
        the metrics file (``inf`` if none was read).

    Raises:
        optuna.TrialPruned: If the trial was pruned by the Optuna orchestrator.
        RuntimeError: If training exits non-zero, is terminated by SIGTERM
            without a pruning request (external preemption), or the evaluation
            script fails.
        FileNotFoundError: If the evaluation script produced no JSON output.
        ValueError: If configured evaluation metrics are missing from the
            evaluation output.
        subprocess.CalledProcessError: If ``sequifier infer`` fails.
    """
    run_config = config.sample_trial(trial, trial.number)
    run_name = run_config.model_name
    # Optuna's Trial.report / Trial.should_prune only support single-objective
    # studies, so intermediate reporting is skipped for multi-objective runs.
    is_multi_objective = (
        config.evaluation_metrics is not None and len(config.evaluation_metrics) > 1
    )

    # 1. YAML generation.
    config_path = os.path.join(
        config.project_root, config.model_config_write_path, f"{run_name}.yaml"
    )
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    with open(config_path, "w") as f:
        yaml.dump(run_config, f, Dumper=TrainModelDumper, sort_keys=False)

    metrics_path = os.path.join(
        config.project_root, "logs", f"sequifier-{run_name}-metrics.jsonl"
    )
    prune_path = os.path.join(
        config.project_root, "logs", f"sequifier-{run_name}.prune"
    )
    # Files left by an earlier run with the same name would be mistaken for this
    # run's output. A stale prune marker turns an external SIGTERM into a
    # "pruned" trial, and stale metric lines would be reported to Optuna.
    _remove_if_exists(prune_path)
    _remove_if_exists(metrics_path)

    # 2. Launch training. The flag is also set in this process's own
    #    environment, so later subprocesses (e.g. `sequifier infer`) inherit it.
    os.environ["SEQUIFIER_HYPERPARAMETER_SEARCH_RUN"] = "1"
    env = os.environ.copy()
    cmd = ["sequifier", "train", f"--config-path={config_path}"]
    process = subprocess.Popen(
        cmd,
        env=env,
        preexec_fn=set_pdeathsig if sys.platform.startswith("linux") else None,
    )

    def request_pruning() -> None:
        """Ask the trainer to stop via the prune marker; kill it if it lingers."""
        with open(prune_path, "w"):
            pass
        try:
            process.wait(timeout=_prune_grace_period(run_name, config.project_root))
        except subprocess.TimeoutExpired:
            process.kill()  # Escalation
            process.wait()  # Reap the killed child so it does not linger as a zombie.
            raise optuna.TrialPruned()

    def consume_metrics(last_read_pos: int, best_val_loss: float) -> tuple[int, float]:
        """Process metric lines appended since ``last_read_pos``.

        Returns the file position after the last complete line and the updated
        best validation loss.
        """
        if not os.path.exists(metrics_path):
            return last_read_pos, best_val_loss
        with open(metrics_path, "r", encoding="utf-8") as f:
            f.seek(last_read_pos)
            while True:
                line = f.readline()
                if not line.endswith("\n"):
                    # EOF, or a line still being written: retry on the next poll.
                    break
                last_read_pos = f.tell()
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # A complete line that does not parse never will. Skip it
                    # rather than re-reading it (and blocking later lines) forever.
                    continue
                if not isinstance(data, dict):
                    continue
                val_loss = data.get("val_loss")
                global_step = data.get("global_step")
                if val_loss is None or global_step is None:
                    continue
                best_val_loss = min(best_val_loss, val_loss)
                if is_multi_objective:
                    continue
                # Cooperative pruning evaluation.
                trial.report(val_loss, global_step)
                if config.prune_trials and trial.should_prune():
                    request_pruning()
        return last_read_pos, best_val_loss

    # 3. Poll metrics while the trainer runs.
    last_read_pos = 0
    best_val_loss = float("inf")
    try:
        while process.poll() is None:
            last_read_pos, best_val_loss = consume_metrics(last_read_pos, best_val_loss)
            time.sleep(2)
        # Pick up anything written between the last poll and process exit.
        _, best_val_loss = consume_metrics(last_read_pos, best_val_loss)
    except BaseException:
        # Do not leave the trainer running if polling is aborted (e.g.
        # KeyboardInterrupt or an error while reporting to Optuna).
        if process.poll() is None:
            process.kill()
            process.wait()
        raise

    exit_code = process.returncode
    # Popen reports death by signal N as -N. 128 + N is the shell/wrapper
    # convention for the same event. Accept both forms of SIGTERM.
    sigterm_exit_codes = (128 + int(signal.SIGTERM), -int(signal.SIGTERM))
    if exit_code in sigterm_exit_codes:
        if os.path.exists(prune_path):
            raise optuna.TrialPruned()
        raise RuntimeError(
            f"Trial pre-empted externally by cluster (SIGTERM). Exit code: {exit_code}"
        )
    if exit_code != 0:
        raise RuntimeError(f"Training failed with exit code {exit_code}")

    # 4. Optional evaluation of the best checkpoint.
    model_type = "onnx" if run_config.export_onnx else "pt"
    model_path, last_epoch = get_best_model_path(
        config.project_root, run_name, model_type
    )
    if config.evaluation_inference_config:
        subprocess.run(
            [
                "sequifier",
                "infer",
                f"--config-path={config.evaluation_inference_config}",
                f"--model-path={model_path}",
            ],
            check=True,
        )
    if config.evaluation_script and config.evaluation_metrics:
        return _evaluate_best_checkpoint(config, run_name, last_epoch)
    return best_val_loss
@beartype
def hyperparameter_search(config_path: str, skip_metadata: bool) -> None:
    """Run an Optuna-based hyperparameter search and print the best result(s).

    Loads the search configuration and selects a sampler from
    ``search_strategy``:

    - ``"sample"`` uses random sampling.
    - ``"grid"`` uses ``BruteForceSampler``.
    - Anything else, including the ``"bayesian"`` default, uses TPE.

    The study is created (or resumed, if it exists) in SQLite storage at
    ``<project_root>/state/optuna/<hp_search_name>.db``. Its optimization
    direction(s) come from ``evaluation_metric_directions``; a single-objective
    study falls back to ``"minimize"``. After optimization, the function prints
    the Pareto front for multi-objective studies, or the best trial otherwise.

    Args:
        config_path (str): Path to the hyperparameter search YAML configuration file.
        skip_metadata (bool): Flag indicating whether to skip loading/processing data metadata.

    Raises:
        ValueError: If ``n_trials`` is not defined and the strategy is not ``"grid"``.
        RuntimeError: If grid search is requested but the installed Optuna lacks
            ``BruteForceSampler``.
    """
    config = load_hyperparameter_search_config(config_path, skip_metadata)
    strategy = getattr(config, "search_strategy", "bayesian")

    # Validate before touching the filesystem. Grid search may run without
    # n_trials and relies on the sampler to end the study once the grid is
    # exhausted.
    n_trials = config.n_trials
    if n_trials is None and strategy != "grid":
        raise ValueError(
            "n_trials/n_samples must be specified for hyperparameter search."
        )

    if strategy == "sample":
        sampler = optuna.samplers.RandomSampler()
    elif strategy == "grid":
        if not hasattr(optuna.samplers, "BruteForceSampler"):
            raise RuntimeError(
                "Grid search requires Optuna >= 3.1 for BruteForceSampler."
            )
        sampler = optuna.samplers.BruteForceSampler()
    else:  # "bayesian"
        sampler = optuna.samplers.TPESampler()

    state_dir = os.path.join(config.project_root, "state", "optuna")
    os.makedirs(state_dir, exist_ok=True)
    storage_path = os.path.join(state_dir, f"{config.hp_search_name}.db")

    is_multivariate = (
        config.evaluation_metrics is not None and len(config.evaluation_metrics) > 1
    )
    directions = config.evaluation_metric_directions
    if is_multivariate:
        direction_kwargs = {"directions": directions}
    else:
        direction_kwargs = {
            "direction": (
                directions[0] if directions and len(directions) == 1 else "minimize"
            )
        }
    study = optuna.create_study(
        study_name=config.hp_search_name,
        sampler=sampler,
        storage=f"sqlite:///{storage_path}",
        load_if_exists=True,
        **direction_kwargs,
    )

    study.optimize(lambda trial: objective(trial, config), n_trials=n_trials)

    if is_multivariate:
        print("\nBest trials (Pareto front):")
        for best in study.best_trials:
            print(f" Values: {best.values}")
            print(" Params: ")
            for key, value in best.params.items():
                print(f" {key}: {value}")
        return

    try:
        best = study.best_trial
    except ValueError:
        # Optuna raises ValueError when no trial completed (e.g. all were
        # pruned or failed); report that instead of crashing after the search.
        print("\nNo trial completed successfully; there is no best trial to report.")
        return
    print("\nBest trial:")
    print(f" Value: {best.value}")
    print(" Params: ")
    for key, value in best.params.items():
        print(f" {key}: {value}")