Skip to content

EarlyStopping

minnt.callbacks.EarlyStopping

Bases: Callback

A callback that stops the training after a metric stops improving.

Source code in minnt/callbacks/early_stopping.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class EarlyStopping(Callback):
    """A callback that stops the training after a metric stops improving.

    After every epoch, the monitored metric from `logs` is compared with the best value seen so
    far; once it has not improved for `patience` consecutive epochs, the training is stopped.
    """

    def __init__(
        self,
        metric: str,
        patience: int,
        mode: Literal["max", "min"] = "max",
    ) -> None:
        """Create the EarlyStopping callback.

        Parameters:
          metric: The metric name from `logs` dictionary to monitor.
          patience: The callback stops the training if the monitored metric does not improve for
            `patience` consecutive epochs. Must be non-negative; with `patience=0`, the training
            stops at the first epoch without an improvement.
          mode: One of `"max"` or `"min"`, indicating whether the monitored metric should be maximized
            or minimized.

        Raises:
          ValueError: If `mode` is not `"max"` or `"min"`, or if `patience` is negative.
        """
        # Raise instead of `assert`, so that the validation is not stripped under `python -O`.
        if mode not in ("max", "min"):
            raise ValueError("mode must be one of 'max' or 'min'")
        if patience < 0:
            raise ValueError("patience must be non-negative")

        self._metric = metric
        self._mode = mode
        self._patience = patience
        self._epochs_without_improvement = 0

        self.best_value = None

    best_value: float | None
    """The best metric value seen so far."""

    def __call__(self, module: "TrainableModule", epoch: int, logs: Logs) -> StopTraining | None:
        """Track the monitored metric and stop the training when it stops improving.

        Parameters:
          module: The module passed by the training loop; not used by this callback.
          epoch: The current epoch number; not used by this callback.
          logs: The logs of the current epoch; must contain the monitored metric.

        Returns:
          `STOP_TRAINING` when the metric has not improved for `patience` consecutive epochs,
          `None` otherwise.
        """
        value = logs[self._metric]
        if (self.best_value is None
                or (self._mode == "max" and value > self.best_value)
                or (self._mode == "min" and value < self.best_value)):
            self.best_value = value
            self._epochs_without_improvement = 0
            # An improving epoch never stops the training; checking the counter here would make
            # `patience=0` stop right after the first (always improving) epoch.
            return None

        self._epochs_without_improvement += 1
        if self._epochs_without_improvement >= self._patience:
            return STOP_TRAINING
        return None

__init__

__init__(
    metric: str, patience: int, mode: Literal["max", "min"] = "max"
) -> None

Create the EarlyStopping callback.

Parameters:

  • metric (str) –

    The name of the metric in the `logs` dictionary to monitor.

  • patience (int) –

    The callback stops the training if the monitored metric does not improve for `patience` consecutive epochs.

  • mode (Literal['max', 'min'], default: 'max' ) –

    One of "max" or "min", indicating whether the monitored metric should be maximized or minimized.

Source code in minnt/callbacks/early_stopping.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def __init__(
    self,
    metric: str,
    patience: int,
    mode: Literal["max", "min"] = "max",
) -> None:
    """Create the EarlyStopping callback.

    Parameters:
      metric: The metric name from `logs` dictionary to monitor.
      patience: The callback stops the training if the monitored metric does not improve for
        `patience` consecutive epochs. Must be non-negative.
      mode: One of `"max"` or `"min"`, indicating whether the monitored metric should be maximized
        or minimized.

    Raises:
      ValueError: If `mode` is not `"max"` or `"min"`, or if `patience` is negative.
    """
    # Raise instead of `assert`, so that the validation is not stripped under `python -O`.
    if mode not in ("max", "min"):
        raise ValueError("mode must be one of 'max' or 'min'")
    if patience < 0:
        raise ValueError("patience must be non-negative")

    self._metric = metric
    self._mode = mode
    self._patience = patience
    self._epochs_without_improvement = 0

    self.best_value = None

best_value instance-attribute

best_value: float | None

The best metric value seen so far.