# Source code for sbss.common.callbacks.cyclic_annealer

# Copyright (C) 2025 National Institute of Advanced Industrial Science and Technology (AIST)
# SPDX-License-Identifier: MIT

from typing import Any

import torch

import lightning as lt


class CyclicAnnealerCallback(lt.Callback):
    """Lightning callback that cyclically anneals an attribute of the module.

    Before every training batch, this callback overwrites the attribute
    ``name`` on the Lightning module. The new value is computed from the
    training progress ``p``, measured in (fractional) epochs::

        phase = (p % cycle) / cycle      # position inside the current cycle, in [0, 1)
        peak  = ini_max_value if p < ini_period else max_value
        value = peak * min(2 * phase, 1)

    Within each cycle, the value:

    * ramps linearly from 0 to ``peak`` during the first half,
    * holds at ``peak`` during the second half, and
    * drops back to 0 when the next cycle begins.

    While ``p < ini_period``, the same cyclic shape is used, but the peak is
    ``ini_max_value`` instead of ``max_value``.

    Args:
        name (str): Name of the attribute on the Lightning module to overwrite.
        cycle (int): Length of one cycle, in epochs. Fractional values are
            accepted as well.
        max_value (float): Peak value of each cycle once the initial period
            is over.
        ini_period (int, optional): Length of the initial period, in epochs.
            Defaults to 0 (no initial period).
        ini_max_value (float, optional): Peak value used while progress is
            below ``ini_period``. Defaults to 1.0.

    Raises:
        ValueError: If ``cycle`` is not positive.
    """

    def __init__(
        self,
        name: str,
        cycle: int,
        max_value: float,
        ini_period: int = 0,
        ini_max_value: float = 1.0,
    ):
        super().__init__()
        # Reject this at construction time. Otherwise it surfaces later as a
        # ZeroDivisionError (or a nonsensical schedule for negative values)
        # on the first training batch.
        if cycle <= 0:
            raise ValueError(f"cycle must be positive, got {cycle!r}")
        self.name = name
        self.cycle = cycle
        self.max_value = max_value
        self.ini_period = ini_period
        self.ini_max_value = ini_max_value

    def on_train_batch_start(
        self, trainer: lt.Trainer, pl_module: Any, batch: torch.Tensor, batch_idx: int
    ) -> None:
        """Set ``pl_module.<name>`` to the scheduled value for this batch.

        Args:
            trainer: The running Lightning trainer. Its ``current_epoch`` and
                ``num_training_batches`` are read.
            pl_module: The module whose attribute ``self.name`` is overwritten.
            batch: The current training batch (unused).
            batch_idx: Index of the current batch within the epoch.
        """
        # Training progress in fractional epochs.
        #
        # This is derived from the epoch/batch counters rather than from
        # ``trainer.global_step``. That counter counts optimizer steps, so it
        # diverges from the batch count under gradient accumulation, or when
        # several optimizers step per batch. Without those, the two formulas
        # agree.
        #
        # When the epoch length is unknown (``num_training_batches`` is inf),
        # the batch fraction becomes 0.0. The schedule then advances once per
        # epoch instead of staying frozen at 0.
        step = trainer.current_epoch + batch_idx / trainer.num_training_batches

        peak = self.ini_max_value if step < self.ini_period else self.max_value
        # The linear ramp reaches 1 at mid-cycle; min() clamps it there.
        phase = (step % self.cycle) / self.cycle
        setattr(pl_module, self.name, peak * min(2 * phase, 1))