GenericDecay

minnt.schedulers.GenericDecay

Bases: LambdaLR

A generic cosine/linear decay learning rate scheduler with optional linear warmup.

If a warmup is specified, this scheduler first increases the learning rate linearly from 0 to the initial learning rate during the warmup phase. It then decreases the learning rate according to the chosen decay strategy (cosine/linear/none) over the remaining training steps, down to a final fraction of the initial learning rate given by final_decay (default 0.0, i.e., the learning rate decays to 0).
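
A minimal usage sketch (the model, optimizer, and hyperparameter values are placeholders; the import path is taken from the module name minnt.schedulers.GenericDecay shown above):

import torch
from minnt.schedulers import GenericDecay  # import path taken from the module name above

model = torch.nn.Linear(10, 2)                            # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # placeholder optimizer

total_steps = 1_000
scheduler = GenericDecay(
    optimizer,
    total_steps=total_steps,
    decay="cosine",       # cosine decay after the warmup phase
    warmup=0.1,           # < 1, so 10% of total_steps (100 steps) of linear warmup
    final_decay=0.05,     # decay down to 5% of the initial learning rate
)

for step in range(total_steps):
    optimizer.step()      # the forward/backward pass is omitted in this sketch
    scheduler.step()      # advance the schedule once per optimizer step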

Source code in minnt/schedulers/generic_decay.py
class GenericDecay(torch.optim.lr_scheduler.LambdaLR):
    """A generic cosine/linear decay learning rate scheduler with optional linear warmup.

    If a warmup is specified, this scheduler first increases the learning rate linearly from 0 to the
    initial learning rate during the warmup phase. It then decreases the learning rate according to the
    chosen decay strategy (cosine/linear/none) over the remaining training steps, down to a final
    fraction of the initial learning rate given by `final_decay` (default 0.0, i.e., the learning rate
    decays to 0).
    """
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        total_steps: int,
        decay: Literal["cosine", "linear", "none"],
        *,
        final_decay: float = 0.0,
        warmup: int | float = 0,
        last_epoch: int = -1,
        warn_about_exceeding_steps: bool = True,
    ) -> None:
        r"""Creates a new GenericDecay scheduler instance.

        Parameters:
          optimizer: The optimizer for which to schedule the learning rate.
          total_steps: The total number of training steps in all epochs, including
            the optional warmup phase.
          decay: The decay strategy to use after the warmup phase, one of:

            - `"cosine"`: cosine decay, computed as

            $$\operatorname{decay\_factor}(t) = \mathit{final\_decay} + (1 - \mathit{final\_decay}) \cdot
              \bigg(\frac{1 + \cos(\pi \cdot t / \mathit{decay\_steps})}{2}\bigg);$$

            - `"linear"`: linear decay, computed as

            $$\operatorname{decay\_factor}(t) = \mathit{final\_decay} + (1 - \mathit{final\_decay}) \cdot
              \bigg(1 - \frac{t}{\mathit{decay\_steps}}\bigg);$$

            - `"none"`: no decay, i.e., keeping the initial learning rate.
          final_decay: The final learning rate as a fraction of the initial
            learning rate after decay. Default is 0.0 (decays to 0).
          warmup: Specifies the warmup phase. If a number smaller than 1 is given,
            it is treated as a fraction of `total_steps`; otherwise, it is treated as
            an absolute number of steps. Default is 0 (no warmup).
          last_epoch: The index of the last epoch when resuming training. Default is -1.
          warn_about_exceeding_steps: Whether to raise a [RuntimeWarning][] if the number of steps
            exceeds the `total_steps`.
        """
        assert decay in ("cosine", "linear", "none"), f"Unknown decay strategy: {decay}"
        self._decay = decay

        self._warmup_steps = int(warmup * total_steps if warmup < 1 else warmup)
        self._decay_steps = total_steps - self._warmup_steps
        assert self._warmup_steps <= total_steps, "Warmup steps must be at most the total steps"

        self._final_decay = final_decay
        self._warn_about_exceeding_steps = warn_about_exceeding_steps
        super().__init__(optimizer, self.compute_decay_factor, last_epoch)

    def compute_decay_factor(self, step: int) -> float:
        # Linear warmup: scale from 0 to 1 over the first `warmup_steps` steps.
        if step < self._warmup_steps:
            return step / self._warmup_steps

        # Past the end of training: keep the final learning rate (optionally warning about it).
        if step > self._warmup_steps + self._decay_steps:
            if self._warn_about_exceeding_steps:
                warnings.warn(
                    f"Step {step} exceeds total steps ({self._warmup_steps + self._decay_steps}). "
                    "The final learning rate will be kept.", RuntimeWarning)
            step = self._warmup_steps + self._decay_steps

        # Decay phase: compute the factor in [0, 1] according to the chosen strategy.
        if self._decay == "none" or self._decay_steps == 0:
            decay = 1.0
        elif self._decay == "cosine":
            decay = 0.5 * (1 + math.cos(math.pi * ((step - self._warmup_steps) / self._decay_steps)))
        elif self._decay == "linear":
            decay = 1.0 - (step - self._warmup_steps) / self._decay_steps

        # Rescale so the factor ends at `final_decay` instead of 0.
        if self._final_decay:
            decay = self._final_decay + (1 - self._final_decay) * decay

        return decay
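
As a quick sanity check of compute_decay_factor above (a sketch with throwaway values; the single-parameter optimizer exists only so the scheduler can be constructed), the factor rises linearly during warmup, equals 1.0 once warmup ends, and reaches final_decay at total_steps:

import torch
from minnt.schedulers import GenericDecay  # import path taken from the module name above

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
sched = GenericDecay(opt, total_steps=100, decay="linear", warmup=10, final_decay=0.1)

print(sched.compute_decay_factor(5))    # 0.5  -> halfway through the 10-step warmup
print(sched.compute_decay_factor(10))   # 1.0  -> warmup finished, decay starts here
print(sched.compute_decay_factor(55))   # 0.55 -> 0.1 + 0.9 * (1 - 45/90)
print(sched.compute_decay_factor(100))  # 0.1  -> final_decay reached at total_steps

At every scheduler step, LambdaLR multiplies each parameter group's initial learning rate by this factor.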

__init__

__init__(
    optimizer: Optimizer,
    total_steps: int,
    decay: Literal["cosine", "linear", "none"],
    *,
    final_decay: float = 0.0,
    warmup: int | float = 0,
    last_epoch: int = -1,
    warn_about_exceeding_steps: bool = True
) -> None

Creates a new GenericDecay scheduler instance.

Parameters:

  • optimizer (Optimizer) –

    The optimizer for which to schedule the learning rate.

  • total_steps (int) –

    The total number of training steps in all epochs, including the optional warmup phase.

  • decay (Literal['cosine', 'linear', 'none']) –

    The decay strategy to use after the warmup phase, one of:

    • "cosine": cosine decay, computed as
    \[\operatorname{decay\_factor}(t) = \mathit{final\_decay} + (1 - \mathit{final\_decay}) \cdot \bigg(\frac{1 + \cos(\pi \cdot t / \mathit{decay\_steps})}{2}\bigg);\]
    • "linear": linear decay, computed as
    \[\operatorname{decay\_factor}(t) = \mathit{final\_decay} + (1 - \mathit{final\_decay}) \cdot \bigg(1 - \frac{t}{\mathit{decay\_steps}}\bigg);\]
    • "none": no decay, i.e., keeping the initial learning rate.
  • final_decay (float, default: 0.0 ) –

    The final learning rate as a fraction of the initial learning rate after decay. Default is 0.0 (decays to 0).

  • warmup (int | float, default: 0 ) –

    Specifies the warmup phase. If a number smaller than 1 is given, it is treated as a fraction of total_steps; otherwise, it is treated as an absolute number of steps. Default is 0 (no warmup).

  • last_epoch (int, default: -1 ) –

    The index of the last epoch when resuming training. Default is -1.

  • warn_about_exceeding_steps (bool, default: True ) –

    Whether to raise a RuntimeWarning if the number of steps exceeds the total_steps.
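
To make the warmup and warn_about_exceeding_steps parameters above concrete, here is a small sketch with throwaway values: a warmup below 1 is read as a fraction of total_steps, a warmup of 1 or more as an absolute step count, and stepping past total_steps keeps the final factor while emitting a RuntimeWarning:

import warnings
import torch
from minnt.schedulers import GenericDecay  # import path taken from the module name above

params = [torch.nn.Parameter(torch.zeros(1))]

# warmup as a fraction: 0.25 * 200 = 50 warmup steps
frac = GenericDecay(torch.optim.SGD(params, lr=0.1), total_steps=200, decay="cosine", warmup=0.25)
# warmup as an absolute count: exactly 50 warmup steps
absolute = GenericDecay(torch.optim.SGD(params, lr=0.1), total_steps=200, decay="cosine", warmup=50)

assert frac.compute_decay_factor(25) == absolute.compute_decay_factor(25) == 0.5

# going past total_steps clamps to the final factor and issues a RuntimeWarning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    final = frac.compute_decay_factor(250)
print(final)               # 0.0, since final_decay keeps its default of 0.0
print(caught[0].category)  # <class 'RuntimeWarning'>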

Source code in minnt/schedulers/generic_decay.py
def __init__(
    self,
    optimizer: torch.optim.Optimizer,
    total_steps: int,
    decay: Literal["cosine", "linear", "none"],
    *,
    final_decay: float = 0.0,
    warmup: int | float = 0,
    last_epoch: int = -1,
    warn_about_exceeding_steps: bool = True,
) -> None:
    r"""Creates a new GenericDecay scheduler instance.

    Parameters:
      optimizer: The optimizer for which to schedule the learning rate.
      total_steps: The total number of training steps in all epochs, including
        the optional warmup phase.
      decay: The decay strategy to use after the warmup phase, one of:

        - `"cosine"`: cosine decay, computed as

        $$\operatorname{decay\_factor}(t) = \mathit{final\_decay} + (1 - \mathit{final\_decay}) \cdot
          \bigg(\frac{1 + \cos(\pi \cdot t / \mathit{decay\_steps})}{2}\bigg);$$

        - `"linear"`: linear decay, computed as

        $$\operatorname{decay\_factor}(t) = \mathit{final\_decay} + (1 - \mathit{final\_decay}) \cdot
          \bigg(1 - \frac{t}{\mathit{decay\_steps}}\bigg);$$

        - `"none"`: no decay, i.e., keeping the initial learning rate.
      final_decay: The final learning rate as a fraction of the initial
        learning rate after decay. Default is 0.0 (decays to 0).
      warmup: Specifies the warmup phase. If a number smaller than 1 is given,
        it is treated as a fraction of `total_steps`; otherwise, it is treated as
        an absolute number of steps. Default is 0 (no warmup).
      last_epoch: The index of the last epoch when resuming training. Default is -1.
      warn_about_exceeding_steps: Whether to raise a [RuntimeWarning][] if the number of steps
        exceeds the `total_steps`.
    """
    assert decay in ("cosine", "linear", "none"), f"Unknown decay strategy: {decay}"
    self._decay = decay

    self._warmup_steps = int(warmup * total_steps if warmup < 1 else warmup)
    self._decay_steps = total_steps - self._warmup_steps
    assert self._warmup_steps <= total_steps, "Warmup steps must be at most the total steps"

    self._final_decay = final_decay
    self._warn_about_exceeding_steps = warn_about_exceeding_steps
    super().__init__(optimizer, self.compute_decay_factor, last_epoch)