TrainableModule

minnt.Callback

Bases: Protocol

__call__

__call__(
    module: TrainableModule, epoch: int, logs: Logs
) -> Literal["stop_training"] | None

Represents a callback called after every training epoch.

If the callback returns TrainableModule.STOP_TRAINING, the training stops.

Parameters:

  • module (TrainableModule) –

    the module being trained

  • epoch (int) –

    the current epoch number (one-based)

  • logs (Logs) –

a dictionary of logs; newly computed metrics or losses should be added here

Returns:

  • Literal['stop_training'] | None –

    TrainableModule.STOP_TRAINING to stop the training, or None to continue.
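
For illustration, a minimal early-stopping callback might look as follows. This is only a sketch: the "loss" key and the threshold are assumptions for this example, not fixed minnt names, and logs is treated as a plain dictionary as described above.

from typing import Literal

from minnt import TrainableModule


class EarlyStopping:
    """A sketch of a callback that stops training once the loss is small enough."""

    def __init__(self, threshold: float = 0.01):
        self.threshold = threshold

    def __call__(
        self, module: TrainableModule, epoch: int, logs: dict
    ) -> Literal["stop_training"] | None:
        # Assumes the training loop stores the epoch loss under the "loss" key.
        if logs.get("loss", float("inf")) < self.threshold:
            return module.STOP_TRAINING
        return None
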
minnt.TrainableModule

Bases: Module

Source code in minnt/trainable_module.py
class TrainableModule(torch.nn.Module):
    STOP_TRAINING: Literal["stop_training"] = "stop_training"
    """A constant returned by callbacks to stop the training."""

    def __init__(self, module: torch.nn.Module | None = None):
        """Initialize the module, optionally with an existing PyTorch module.

        Parameters:
          module: An optional existing PyTorch module to wrap, e.g., a [torch.nn.Sequential][]
            or a pretrained Transformer. If given, the module still must be configured.
        """
        raise NotImplementedError()

STOP_TRAINING class-attribute instance-attribute

STOP_TRAINING: Literal['stop_training'] = 'stop_training'

A constant returned by callbacks to stop the training.

__init__

__init__(module: Module | None = None)

Initialize the module, optionally with an existing PyTorch module.

Parameters:

  • module (Module | None, default: None ) –

    An optional existing PyTorch module to wrap, e.g., a torch.nn.Sequential or a pretrained Transformer. If given, the module still must be configured.

Source code in minnt/trainable_module.py
def __init__(self, module: torch.nn.Module | None = None):
    """Initialize the module, optionally with an existing PyTorch module.

    Parameters:
      module: An optional existing PyTorch module to wrap, e.g., a [torch.nn.Sequential][]
        or a pretrained Transformer. If given, the module still must be configured.
    """
    raise NotImplementedError()
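
As a usage sketch, wrapping an existing PyTorch module looks like this. Per the docstring, the wrapped module must still be configured before training; the layer sizes below are illustrative only.

import torch

from minnt import TrainableModule

# Wrap a plain PyTorch module; per the docstring above, the result
# still has to be configured (optimizer, loss, ...) before training.
model = TrainableModule(torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
))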