Skip to content

SaveBestWeights

minnt.callbacks.SaveBestWeights

Bases: Callback

A callback that saves the best model weights to a file.

Source code in minnt/callbacks/save_best_weights.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class SaveBestWeights(Callback):
    """A callback that saves the model weights whenever the monitored metric improves."""

    def __init__(
        self,
        path: str,
        metric: str,
        mode: Literal["max", "min"] = "max",
        optimizer_path: str | None = None,
    ) -> None:
        """Create the SaveBestWeights callback.

        Parameters:
          path: A path where weights will be saved using the [minnt.TrainableModule.save_weights][]
            method whenever the monitored metric improves. Note that you can use templates like
            `{logdir}` and `{epoch[:formatting]}`.
          metric: The metric name from the `logs` dictionary to monitor.
          mode: One of `"max"` or `"min"`, indicating whether the monitored metric should be maximized
            or minimized.
          optimizer_path: An optional path passed to [minnt.TrainableModule.save_weights][] to
            also save the optimizer state; it is relative to `path`.
        """
        assert mode in ("max", "min"), "mode must be one of 'max' or 'min'"

        self._path = path
        self._metric = metric
        self._mode = mode
        self._optimizer_path = optimizer_path

        # Initialize the attribute that `__call__` actually reads and updates, so that
        # every (re)initialization starts tracking from scratch.
        self.best_value = None

    best_value: float | None = None
    """The best metric value seen so far."""

    def __call__(self, module: "TrainableModule", epoch: int, logs: Logs) -> None:
        """Save the weights of `module` if the monitored metric improved.

        The first call always saves, because no best value has been recorded yet. Later
        calls save only on a strict improvement (`>` for `"max"`, `<` for `"min"`).

        Parameters:
          module: The module whose `save_weights` method is called.
          epoch: The current epoch; not used directly by this callback.
          logs: The logs to read the monitored metric from; a missing metric propagates
            the lookup error raised by `logs` (a `KeyError` if it is a plain dictionary).
        """
        value = logs[self._metric]

        if (self.best_value is None
                or (self._mode == "max" and value > self.best_value)
                or (self._mode == "min" and value < self.best_value)):
            self.best_value = value

            module.save_weights(self._path, optimizer_path=self._optimizer_path)

__init__

__init__(
    path: str,
    metric: str,
    mode: Literal["max", "min"] = "max",
    optimizer_path: str | None = None,
) -> None

Create the SaveBestWeights callback.

Parameters:

  • path (str) –

    A path where weights will be saved using the minnt.TrainableModule.save_weights method whenever the monitored metric improves. Note that you can use templates like {logdir} and {epoch[:formatting]}.

  • metric (str) –

    The metric name from the logs dictionary to monitor.

  • mode (Literal['max', 'min'], default: 'max' ) –

    One of "max" or "min", indicating whether the monitored metric should be maximized or minimized.

  • optimizer_path (str | None, default: None ) –

    An optional path passed to minnt.TrainableModule.save_weights to also save the optimizer state; it is relative to path.

Source code in minnt/callbacks/save_best_weights.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def __init__(
    self,
    path: str,
    metric: str,
    mode: Literal["max", "min"] = "max",
    optimizer_path: str | None = None,
) -> None:
    """Create the SaveBestWeights callback.

    Parameters:
      path: A path where weights will be saved using the [minnt.TrainableModule.save_weights][]
        method after each epoch. Note that you can use templates like `{logdir}` and `{epoch[:formatting]}`.
      metric: The metric name from `logs` dictionary to monitor.
      mode: One of `"max"` or `"min"`, indicating whether the monitored metric should be maximized
        or minimized.
      optimizer_path: An optional path passed to [minnt.TrainableModule.save_weights][] to
        save also the optimizer state; it is relative to `path`.
    """
    assert mode in ("max", "min"), "mode must be one of 'max' or 'min'"

    self._path = path
    self._metric = metric
    self._mode = mode
    self._optimizer_path = optimizer_path

    self.best_metric_value = None

best_value class-attribute instance-attribute

best_value: float | None = None

The best metric value seen so far.