Skip to content

KeepBestWeights

minnt.callbacks.KeepBestWeights

Bases: Callback

A callback that keeps the best model weights in memory.

Source code in minnt/callbacks/keep_best_weights.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class KeepBestWeights(Callback):
    """A callback that keeps the best model weights in memory.

    Each time the callback is invoked, the monitored metric is read from `logs`. When no value
    has been recorded yet, or the new value is strictly better than the best one seen so far,
    a copy of the module's state dictionary is stored in `best_state_dict`.
    """

    def __init__(
        self,
        metric: str,
        mode: Literal["max", "min"] = "max",
        device: str | None = None,
    ) -> None:
        """Create the KeepBestWeights callback.

        Parameters:
          metric: The metric name from `logs` dictionary to monitor.
          mode: One of `"max"` or `"min"`, indicating whether the monitored metric should be maximized
            or minimized.
          device: The device where the weights will be stored. If `None`, the weights will be stored
            on the same device as the model.

        Raises:
          ValueError: If `mode` is neither `"max"` nor `"min"`.
        """
        # Validate with an explicit exception: an `assert` is removed under `python -O`, and an
        # unknown mode would then silently prevent any update after the first call.
        if mode not in ("max", "min"):
            raise ValueError("mode must be one of 'max' or 'min'")

        self._metric = metric
        self._mode = mode
        self._device = device

        self.best_state_dict = None
        self.best_value = None

    best_state_dict: dict | None
    """The state dictionary containing the copies of best weights encountered so far."""

    best_value: float | None = None
    """The best metric value seen so far."""

    def __call__(self, module: "TrainableModule", epoch: int, logs: Logs) -> None:
        """Store a copy of the module weights when the monitored metric improves.

        Parameters:
          module: The module whose `state_dict()` is copied.
          epoch: The current epoch; not used by this callback.
          logs: The dictionary from which the monitored metric is read.

        Raises:
          KeyError: If the monitored metric is not present in `logs`.
        """
        value = logs[self._metric]

        # The comparison is strict, so on a tie the earlier weights are retained.
        improved = (
            self.best_value is None
            or (self._mode == "max" and value > self.best_value)
            or (self._mode == "min" and value < self.best_value)
        )
        if improved:
            self.best_value = value
            # `copy=True` forces a fresh copy even when the target device equals the source one,
            # so further training cannot overwrite the stored weights.
            self.best_state_dict = {
                name: tensor.to(device=self._device, copy=True)
                for name, tensor in module.state_dict().items()
            }

__init__

__init__(
    metric: str, mode: Literal["max", "min"] = "max", device: str | None = None
) -> None

Create the KeepBestWeights callback.

Parameters:

  • metric (str) –

    The name of the metric to monitor, as it appears in the logs dictionary.

  • mode (Literal['max', 'min'], default: 'max' ) –

    One of "max" or "min", indicating whether the monitored metric should be maximized or minimized.

  • device (str | None, default: None ) –

    The device where the weights will be stored. If None, the weights will be stored on the same device as the model.

Source code in minnt/callbacks/keep_best_weights.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def __init__(
    self,
    metric: str,
    mode: Literal["max", "min"] = "max",
    device: str | None = None,
) -> None:
    """Create the KeepBestWeights callback.

    Parameters:
      metric: The metric name from `logs` dictionary to monitor.
      mode: One of `"max"` or `"min"`, indicating whether the monitored metric should be maximized
        or minimized.
      device: The device where the weights will be stored. If `None`, the weights will be stored
        on the same device as the model.
    """
    assert mode in ("max", "min"), "mode must be one of 'max' or 'min'"

    self._metric = metric
    self._mode = mode
    self._device = device

    self.best_state_dict = None
    self.best_value = None

best_state_dict instance-attribute

best_state_dict: dict | None

The state dictionary containing the copies of best weights encountered so far.

best_value class-attribute instance-attribute

best_value: float | None = None

The best metric value seen so far.