# Source code for sequifier.optimizers.optimizers

import torch
import torch_optimizer  # noqa: F401

from sequifier.optimizers.ademamix import AdEMAMix

# Names of optimizers shipped with this package (see the AdEMAMix import
# above). get_optimizer_class checks this list first when resolving a name.
CUSTOM_OPTIMIZERS = ["AdEMAMix"]

# Optimizer names that get_optimizer_class resolves as attributes of the
# ``torch_optimizer`` package. A name that is in neither this list nor
# CUSTOM_OPTIMIZERS is looked up on ``torch.optim`` instead.
TORCH_OPTIMIZERS = [
    "A2GradUni",
    "A2GradInc",
    "A2GradExp",
    "AccSGD",
    "AdaBelief",
    "AdaBound",
    "Adafactor",
    "Adahessian",
    "AdaMod",
    "AdamP",
    "AggMo",
    "Apollo",
    "DiffGrad",
    "Lamb",
    "LARS",
    "Lion",
    "Lookahead",
    "MADGRAD",
    "NovoGrad",
    "PID",
    "QHAdam",
    "QHM",
    "RAdam",
    "SGDP",
    "SGDW",
    "Shampoo",
    "SWATS",
    "Yogi",
]


def get_optimizer_class(optimizer_name: str) -> "type[torch.optim.Optimizer]":
    """Resolve an optimizer name to its optimizer class.

    The lookup proceeds in this order:

    1. Names listed in ``CUSTOM_OPTIMIZERS`` map to classes bundled with this
       package.
    2. Names listed in ``TORCH_OPTIMIZERS`` are looked up on ``torch_optimizer``.
    3. Any other name is looked up on ``torch.optim``.

    Args:
        optimizer_name: The name of the optimizer, e.g. ``"Adam"``. It must be
            a plain attribute name; dotted paths and expressions are not
            interpreted.

    Returns:
        The optimizer class (not an instance).

    Raises:
        Exception: If the name is listed in ``CUSTOM_OPTIMIZERS`` but no class
            is wired up for it in this function.
        AttributeError: If the selected module has no attribute with that name.
    """
    if optimizer_name in CUSTOM_OPTIMIZERS:
        if optimizer_name == "AdEMAMix":
            return AdEMAMix
        # Reached only when CUSTOM_OPTIMIZERS lists a name that has no
        # corresponding branch above.
        raise Exception(f"Optimizer '{optimizer_name}' is not available")

    # Use a plain attribute lookup rather than eval(): optimizer_name is an
    # arbitrary caller-supplied string, and eval() would execute any
    # expression embedded in it. Unknown names still raise AttributeError.
    if optimizer_name in TORCH_OPTIMIZERS:
        source_module = torch_optimizer
    else:
        source_module = torch.optim
    return getattr(source_module, optimizer_name)