"""Optimizer resolution utilities for sequifier.optimizers.optimizers."""
import torch
import torch_optimizer # noqa: F401
from sequifier.optimizers.ademamix import AdEMAMix
# Registry of optimizers implemented in this package; checked first by
# get_optimizer_class, before torch_optimizer and torch.optim.
CUSTOM_OPTIMIZERS = {"AdEMAMix": AdEMAMix}
def get_optimizer_class(optimizer_name: str) -> type[torch.optim.Optimizer]:
    """Resolve an optimizer class from its name.

    Lookup order: package-local optimizers in ``CUSTOM_OPTIMIZERS`` first,
    then the ``torch_optimizer`` package, then ``torch.optim``.

    Args:
        optimizer_name: Name of the optimizer class, e.g. ``"Adam"``.

    Returns:
        The optimizer class itself (not an instance).

    Raises:
        ValueError: If no optimizer with that name is found in any source.
    """
    if optimizer_name in CUSTOM_OPTIMIZERS:
        return CUSTOM_OPTIMIZERS[optimizer_name]
    if hasattr(torch_optimizer, optimizer_name):
        return getattr(torch_optimizer, optimizer_name)
    if hasattr(torch.optim, optimizer_name):
        return getattr(torch.optim, optimizer_name)
    raise ValueError(f"Optimizer '{optimizer_name}' not found.")