Skip to content

Examples

Using TrainableModule

We start with an example that illustrates just the minnt.TrainableModule class; the data preparation, data loaders, and losses are pure PyTorch. Note that we do use minnt.metrics.CategoricalAccuracy as a metric, but the torchmetrics equivalent

    metrics={"accuracy": torchmetrics.Accuracy("multiclass", num_classes=10)},
could also have been used.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python3
import argparse

import torch
import torchvision

import minnt

# Command-line options; every option has a default, so the script also runs without arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=50, help="Batch size.")
parser.add_argument("--cnn_dim", type=int, default=16, help="Number of CNN filters.")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", type=int, default=256, help="Size of the hidden layer.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--threads", type=int, default=1, help="Maximum number of threads to use.")


class Model(minnt.TrainableModule):
    """A small convolutional classifier with 10 outputs.

    Three blocks of (3x3 convolution with padding 1, ReLU, max-pooling with kernel 3 and
    stride 2) use `1 *`, `2 *` and `4 * args.cnn_dim` filters. They are followed by a flatten,
    a hidden linear layer of `args.hidden_layer_size` units with ReLU and dropout
    `args.dropout`, and a final linear layer with 10 outputs. The Lazy* layers infer their
    input sizes on the first forward pass.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__()

        layers = []
        for multiplier in (1, 2, 4):
            layers.extend([
                torch.nn.LazyConv2d(multiplier * args.cnn_dim, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(3, 2),
            ])
        layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(args.hidden_layer_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(args.dropout),
            torch.nn.LazyLinear(10),
        ])
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the batch `x` through the network and return its outputs."""
        return self.network(x)


def main(args: argparse.Namespace) -> None:
    """Train the CNN on MNIST with a dev split, then evaluate it on the test set.

    The last 5000 MNIST training images form the dev set, and the rest is used for training.
    """
    # Seed the random number generators and limit the number of threads.
    minnt.startup(args.seed, args.threads)
    minnt.global_keras_initializers()

    # The MNIST PIL images are converted to tensors on load.
    to_tensor = torchvision.transforms.functional.to_tensor
    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True, transform=to_tensor)
    all_indices = list(range(len(mnist_train_dev)))
    mnist_train = torch.utils.data.Subset(mnist_train_dev, all_indices[:-5000])  # all but the last 5000
    mnist_dev = torch.utils.data.Subset(mnist_train_dev, all_indices[-5000:])  # the last 5000
    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True, transform=to_tensor)

    # Plain PyTorch data loaders; only the training data is shuffled.
    train = torch.utils.data.DataLoader(mnist_train, batch_size=args.batch_size, shuffle=True)
    dev = torch.utils.data.DataLoader(mnist_dev, batch_size=args.batch_size)
    test = torch.utils.data.DataLoader(mnist_test, batch_size=args.batch_size)

    model = Model(args)
    print("The following model has been created:", model)

    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    model.configure(
        optimizer=optimizer,
        loss=torch.nn.CrossEntropyLoss(),
        metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
        logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **vars(args)),
    )

    # Train with dev evaluation, then evaluate on the test set.
    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True)
    model.evaluate(test)


if __name__ == "__main__":
    # Where `__file__` is undefined (e.g. in a notebook), parse an empty list so only the
    # defaults are used; otherwise parse_args(None) reads the command line.
    main_args = parser.parse_args(None if "__file__" in globals() else [])
    main(main_args)

When executed, the script trains a convolutional network on MNIST, using a GPU or another accelerator if available and the CPU otherwise. It evaluates loss and accuracy on the development set after every epoch, and finally evaluates them on the test data. The logs are shown on the console, as illustrated here:

Config epoch=0 batch_size=50 cnn_dim=16 dropout=0.2 epochs=10 hidden_layer_size=256 learning_rate=0.001 seed=42 threads=1
Epoch 1/10 7.7s loss=0.2236 accuracy=0.9304 dev:loss=0.0495 dev:accuracy=0.9840
Epoch 2/10 7.1s loss=0.0564 accuracy=0.9828 dev:loss=0.0412 dev:accuracy=0.9880
Epoch 3/10 loss=0.0440 accuracy=0.9850:  42%|███████▎         | 462/1100 [00:02<00:03, 168.96batch/s]

and furthermore (because the logdir option of configure is specified), a directory logs/0_trainable_module-YYYYMMDD_HHMMSS-... is created. It contains the training and evaluation logs both as plain-text files and as TensorBoard logs (which can be browsed at http://localhost:6006 after running the command tensorboard --logdir logs).

Full Minnt Example

We now present a full Minnt example, which extends the above script in three ways: it uses minnt.TransformedDataset to process the training and evaluation data, it uses a configurable learning rate decay (cosine by default; linear or no decay can be selected instead), and it uses the minnt.losses.CategoricalCrossEntropy loss with label smoothing. Below you can view either the whole script or just the diff against the previous example.

@@ -14,5 +14,7 @@
 parser.add_argument("--epochs", default=10, type=int, help="Number of epochs.")
 parser.add_argument("--hidden_layer_size", default=256, type=int, help="Size of the hidden layer.")
+parser.add_argument("--label_smoothing", default=0.1, type=float, help="Label smoothing factor.")
 parser.add_argument("--learning_rate", default=0.001, type=float, help="Learning rate.")
+parser.add_argument("--learning_rate_decay", default="cosine", choices=["cosine", "linear", "none"], help="LR decay.")
 parser.add_argument("--seed", default=42, type=int, help="Random seed.")
 parser.add_argument("--threads", default=1, type=int, help="Maximum number of threads to use.")
@@ -36,4 +38,10 @@


+class Dataset(minnt.TransformedDataset):
+    def transform(self, image, label):
+        image = torchvision.transforms.functional.to_tensor(image)
+        return image, label
+
+
 def main(args: argparse.Namespace) -> None:
     # Set the random seed and the number of threads.
@@ -42,14 +50,13 @@

     # Load the data using torchvision.
-    to_tensor = torchvision.transforms.functional.to_tensor  # convert the MNIST PIL images to tensors
-    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True, transform=to_tensor)
+    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True)
     mnist_train = torch.utils.data.Subset(mnist_train_dev, list(range(len(mnist_train_dev)))[:-5000])
     mnist_dev = torch.utils.data.Subset(mnist_train_dev, list(range(len(mnist_train_dev)))[-5000:])
-    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True, transform=to_tensor)
+    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True)

     # Create data loaders from the datasets.
-    train = torch.utils.data.DataLoader(mnist_train, batch_size=args.batch_size, shuffle=True)
-    dev = torch.utils.data.DataLoader(mnist_dev, batch_size=args.batch_size)
-    test = torch.utils.data.DataLoader(mnist_test, batch_size=args.batch_size)
+    train = Dataset(mnist_train).dataloader(args.batch_size, shuffle=True)
+    dev = Dataset(mnist_dev).dataloader(args.batch_size)
+    test = Dataset(mnist_test).dataloader(args.batch_size)

     # Create a model according to the given arguments.
@@ -59,6 +66,7 @@
     # Configure the model for training.
     model.configure(
-        optimizer=torch.optim.Adam(model.parameters(), args.learning_rate),
-        loss=torch.nn.CrossEntropyLoss(),
+        optimizer=(optimizer := torch.optim.Adam(model.parameters(), args.learning_rate)),
+        scheduler=minnt.schedulers.GenericDecay(optimizer, args.epochs * len(train), args.learning_rate_decay),
+        loss=minnt.losses.CategoricalCrossEntropy(label_smoothing=args.label_smoothing),
         metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
         logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **vars(args)),
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python3
import argparse

import torch
import torchvision

import minnt

# Command-line options; every option has a default, so the script also runs without arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=50, help="Batch size.")
parser.add_argument("--cnn_dim", type=int, default=16, help="Number of CNN filters.")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", type=int, default=256, help="Size of the hidden layer.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="Label smoothing factor.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate.")
parser.add_argument("--learning_rate_decay", choices=["cosine", "linear", "none"], default="cosine", help="LR decay.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--threads", type=int, default=1, help="Maximum number of threads to use.")


class Model(minnt.TrainableModule):
    """A small convolutional classifier with 10 outputs.

    Three blocks of (3x3 convolution with padding 1, ReLU, max-pooling with kernel 3 and
    stride 2) use `1 *`, `2 *` and `4 * args.cnn_dim` filters. They are followed by a flatten,
    a hidden linear layer of `args.hidden_layer_size` units with ReLU and dropout
    `args.dropout`, and a final linear layer with 10 outputs. The Lazy* layers infer their
    input sizes on the first forward pass.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__()

        layers = []
        for multiplier in (1, 2, 4):
            layers.extend([
                torch.nn.LazyConv2d(multiplier * args.cnn_dim, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(3, 2),
            ])
        layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(args.hidden_layer_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(args.dropout),
            torch.nn.LazyLinear(10),
        ])
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the batch `x` through the network and return its outputs."""
        return self.network(x)


class Dataset(minnt.TransformedDataset):
    """Turn each (PIL image, label) example into a (tensor, label) pair."""

    def transform(self, image, label):
        # `to_tensor` produces a C x H x W float tensor (scaled to [0, 1] for 8-bit images).
        return torchvision.transforms.functional.to_tensor(image), label


def main(args: argparse.Namespace) -> None:
    """Train the CNN on MNIST with LR decay and label smoothing, then evaluate it on test.

    The last 5000 MNIST training images form the dev set, and the rest is used for training.
    """
    # Seed the random number generators and limit the number of threads.
    minnt.startup(args.seed, args.threads)
    minnt.global_keras_initializers()

    # The raw torchvision datasets yield PIL images; `Dataset` converts them to tensors.
    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True)
    all_indices = list(range(len(mnist_train_dev)))
    mnist_train = torch.utils.data.Subset(mnist_train_dev, all_indices[:-5000])  # all but the last 5000
    mnist_dev = torch.utils.data.Subset(mnist_train_dev, all_indices[-5000:])  # the last 5000
    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True)

    # Wrap the datasets and create their loaders; only the training data is shuffled.
    train = Dataset(mnist_train).dataloader(args.batch_size, shuffle=True)
    dev = Dataset(mnist_dev).dataloader(args.batch_size)
    test = Dataset(mnist_test).dataloader(args.batch_size)

    model = Model(args)
    print("The following model has been created:", model)

    # The decay is configured with `args.epochs * len(train)` steps in total.
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    model.configure(
        optimizer=optimizer,
        scheduler=minnt.schedulers.GenericDecay(optimizer, args.epochs * len(train), args.learning_rate_decay),
        loss=minnt.losses.CategoricalCrossEntropy(label_smoothing=args.label_smoothing),
        metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
        logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **vars(args)),
    )

    # Train with dev evaluation, then evaluate on the test set.
    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True)
    model.evaluate(test)


if __name__ == "__main__":
    # Where `__file__` is undefined (e.g. in a notebook), parse an empty list so only the
    # defaults are used; otherwise parse_args(None) reads the command line.
    main_args = parser.parse_args(None if "__file__" in globals() else [])
    main(main_args)

Note that the data loaders are now created using minnt.TransformedDataset.dataloader, and that the data processing now happens in minnt.TransformedDataset.transform.

We now illustrate several Minnt features and extensions by modifying this example.

Data Augmentation

To implement data augmentation using PyTorch transforms, we can construct two augmentation pipelines for the training data: one applied to individual images in minnt.TransformedDataset.transform, and another applied to batches in minnt.TransformedDataset.transform_batch. Note that for the latter to work, the corresponding data loader must be created using the minnt.TransformedDataset.dataloader method.

Furthermore, the data augmentation can run in worker subprocesses when --dataloader_workers is set to a value greater than 0. By default, Minnt avoids the fork start method on Unix-like systems (Python 3.14 does the same by default). In addition, minnt.TransformedDataset.dataloader changes the default of the persistent_workers argument of torch.utils.data.DataLoader to True for better performance.

@@ -1,7 +1,9 @@
 #!/usr/bin/env python3
 import argparse
+import re

 import torch
 import torchvision
+import torchvision.transforms.v2 as v2

 import minnt
@@ -9,6 +11,8 @@
 # Parse arguments
 parser = argparse.ArgumentParser()
+parser.add_argument("--augmentation", default="", type=str, help="What data augmentation to use.")
 parser.add_argument("--batch_size", default=50, type=int, help="Batch size.")
 parser.add_argument("--cnn_dim", default=16, type=int, help="Number of CNN filters.")
+parser.add_argument("--dataloader_workers", default=0, type=int, help="Number of dataloader workers.")
 parser.add_argument("--dropout", default=0.2, type=float, help="Dropout rate.")
 parser.add_argument("--epochs", default=10, type=int, help="Number of epochs.")
@@ -39,7 +43,32 @@

 class Dataset(minnt.TransformedDataset):
+    def __init__(self, dataset: torch.utils.data.Dataset, augmentation: str = ""):
+        super().__init__(dataset)
+
+        transformations = [v2.ToImage()]
+        if "basic" in augmentation:
+            transformations.append(v2.RandomHorizontalFlip())
+            transformations.append(v2.RandomCrop((28, 28), padding=4, fill=127))
+        if randaugment := re.search(r"randaugment-(\d+)-(\d+)", augmentation):
+            n, m = map(int, randaugment.groups())
+            transformations.append(v2.RandAugment(n, m, fill=127))
+        if augmix := re.search(r"augmix-(\d+)", augmentation):
+            severity, = map(int, augmix.groups())
+            transformations.append(v2.AugMix(severity))
+        transformations.append(v2.ToDtype(torch.float32, scale=True))
+        self._transformation = v2.Compose(transformations)
+
+        batch_augmentations = []
+        if "cutmix" in augmentation:
+            batch_augmentations.append(v2.CutMix(num_classes=10))
+        if "mixup" in augmentation:
+            batch_augmentations.append(v2.MixUp(num_classes=10))
+        self._batch_augmentation = v2.RandomChoice(batch_augmentations) if batch_augmentations else None
+
     def transform(self, image, label):
-        image = torchvision.transforms.functional.to_tensor(image)
-        return image, label
+        return self._transformation(image), label
+
+    def transform_batch(self, *batch):
+        return self._batch_augmentation(*batch) if self._batch_augmentation else batch


@@ -56,5 +85,6 @@

     # Create data loaders from the datasets.
-    train = Dataset(mnist_train).dataloader(args.batch_size, shuffle=True)
+    train = Dataset(mnist_train, args.augmentation).dataloader(
+        args.batch_size, shuffle=True, num_workers=args.dataloader_workers)
     dev = Dataset(mnist_dev).dataloader(args.batch_size)
     test = Dataset(mnist_test).dataloader(args.batch_size)
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
import argparse
import re

import torch
import torchvision
import torchvision.transforms.v2 as v2

import minnt

# Command-line options; every option has a default, so the script also runs without arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--augmentation", type=str, default="", help="What data augmentation to use.")
parser.add_argument("--batch_size", type=int, default=50, help="Batch size.")
parser.add_argument("--cnn_dim", type=int, default=16, help="Number of CNN filters.")
parser.add_argument("--dataloader_workers", type=int, default=0, help="Number of dataloader workers.")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", type=int, default=256, help="Size of the hidden layer.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="Label smoothing factor.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate.")
parser.add_argument("--learning_rate_decay", choices=["cosine", "linear", "none"], default="cosine", help="LR decay.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--threads", type=int, default=1, help="Maximum number of threads to use.")


class Model(minnt.TrainableModule):
    """A small convolutional classifier with 10 outputs.

    Three blocks of (3x3 convolution with padding 1, ReLU, max-pooling with kernel 3 and
    stride 2) use `1 *`, `2 *` and `4 * args.cnn_dim` filters. They are followed by a flatten,
    a hidden linear layer of `args.hidden_layer_size` units with ReLU and dropout
    `args.dropout`, and a final linear layer with 10 outputs. The Lazy* layers infer their
    input sizes on the first forward pass.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__()

        layers = []
        for multiplier in (1, 2, 4):
            layers.extend([
                torch.nn.LazyConv2d(multiplier * args.cnn_dim, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(3, 2),
            ])
        layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(args.hidden_layer_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(args.dropout),
            torch.nn.LazyLinear(10),
        ])
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the batch `x` through the network and return its outputs."""
        return self.network(x)


class Dataset(minnt.TransformedDataset):
    """Image-classification dataset with optional augmentation chosen by keywords.

    Keywords recognized in `augmentation` (they can be combined in one string):
      - "basic": random horizontal flip, then a random 28x28 crop after padding by 4 pixels;
      - "randaugment-N-M": RandAugment with N operations of magnitude M;
      - "augmix-S": AugMix with severity S;
      - "cutmix", "mixup": batch-level mixing for 10 classes; when both are present,
        `v2.RandomChoice` picks one of them on each call.
    Every image is finally converted to float32 by `v2.ToDtype(..., scale=True)`.
    """

    def __init__(self, dataset: torch.utils.data.Dataset, augmentation: str = ""):
        super().__init__(dataset)

        # Per-image pipeline used by `transform`; the order of the steps is significant.
        image_steps = [v2.ToImage()]
        if "basic" in augmentation:
            # NOTE(review): a horizontally flipped digit is generally not a valid digit; confirm this is intended.
            image_steps += [v2.RandomHorizontalFlip(), v2.RandomCrop((28, 28), padding=4, fill=127)]
        randaugment = re.search(r"randaugment-(\d+)-(\d+)", augmentation)
        if randaugment is not None:
            num_ops, magnitude = (int(value) for value in randaugment.groups())
            image_steps.append(v2.RandAugment(num_ops, magnitude, fill=127))
        augmix = re.search(r"augmix-(\d+)", augmentation)
        if augmix is not None:
            image_steps.append(v2.AugMix(int(augmix.group(1))))
        image_steps.append(v2.ToDtype(torch.float32, scale=True))
        self._transformation = v2.Compose(image_steps)

        # Batch-level pipeline used by `transform_batch`; None when no mixing is requested.
        batch_steps = []
        if "cutmix" in augmentation:
            batch_steps.append(v2.CutMix(num_classes=10))
        if "mixup" in augmentation:
            batch_steps.append(v2.MixUp(num_classes=10))
        self._batch_augmentation = v2.RandomChoice(batch_steps) if len(batch_steps) > 0 else None

    def transform(self, image, label):
        """Apply the per-image pipeline to `image`; `label` is returned unchanged."""
        transformed = self._transformation(image)
        return transformed, label

    def transform_batch(self, *batch):
        """Apply the batch-level augmentation to `batch`, or return it unchanged if there is none."""
        if not self._batch_augmentation:
            return batch
        return self._batch_augmentation(*batch)


def main(args: argparse.Namespace) -> None:
    """Train the CNN on MNIST with optional training-data augmentation, then evaluate on test.

    The last 5000 MNIST training images form the dev set. Only the training data is
    augmented (per `args.augmentation`) and loaded by `args.dataloader_workers` workers.
    """
    # Seed the random number generators and limit the number of threads.
    minnt.startup(args.seed, args.threads)
    minnt.global_keras_initializers()

    # The raw torchvision datasets yield PIL images; `Dataset` converts them to tensors.
    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True)
    all_indices = list(range(len(mnist_train_dev)))
    mnist_train = torch.utils.data.Subset(mnist_train_dev, all_indices[:-5000])  # all but the last 5000
    mnist_dev = torch.utils.data.Subset(mnist_train_dev, all_indices[-5000:])  # the last 5000
    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True)

    # Only the training data is augmented, shuffled, and loaded by worker processes.
    train_dataset = Dataset(mnist_train, args.augmentation)
    train = train_dataset.dataloader(args.batch_size, shuffle=True, num_workers=args.dataloader_workers)
    dev = Dataset(mnist_dev).dataloader(args.batch_size)
    test = Dataset(mnist_test).dataloader(args.batch_size)

    model = Model(args)
    print("The following model has been created:", model)

    # The decay is configured with `args.epochs * len(train)` steps in total.
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    model.configure(
        optimizer=optimizer,
        scheduler=minnt.schedulers.GenericDecay(optimizer, args.epochs * len(train), args.learning_rate_decay),
        loss=minnt.losses.CategoricalCrossEntropy(label_smoothing=args.label_smoothing),
        metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
        logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **vars(args)),
    )

    # Train with dev evaluation, then evaluate on the test set.
    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True)
    model.evaluate(test)


if __name__ == "__main__":
    # Where `__file__` is undefined (e.g. in a notebook), parse an empty list so only the
    # defaults are used; otherwise parse_args(None) reads the command line.
    main_args = parser.parse_args(None if "__file__" in globals() else [])
    main(main_args)

Using Callbacks

Callbacks provide per-epoch hooks, for example for customized evaluation. In the example below, we do not pass the dev=dev argument to minnt.TrainableModule.fit and instead pass two example callbacks:

  • evaluate_dev, which manually logs a custom metric dev:quality, evaluates the model on dev, and adds the resulting dev metrics to the training logs;
  • dev_misclassifications, which runs prediction on dev after each epoch and stores the first misclassified image of every class in the logs (as an image and a text describing the predicted label). The prediction stops as soon as a misclassification has been found for every class.

These callbacks demonstrate that custom metrics can be added by extending the logs argument of a callback, and that the minnt.TrainableModule.logger can be used to log multimedia data like images, figures, audio, and text.

@@ -74,5 +74,23 @@

     # Train the model.
-    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True)
+    def evaluate_dev(model, epoch, logs):
+        logs["dev:quality"] = 0.42  # just an example of logging a custom metric
+        logs |= model.evaluate(dev, "dev", log_results=False)  # do not log in `evaluate` to merge all logs
+
+    def dev_misclassifications(model, epoch, logs):
+        missing_classes = set(range(10))
+        for (image, label), prediction in zip(mnist_dev, model.predict(dev, data_with_labels=True)):
+            prediction = prediction.argmax().item()
+            if prediction != label and label in missing_classes:
+                model.logger.log_image(f"dev:misclassified/{label}", image, model.epoch, data_format="CHW")
+                model.logger.log_text(f"dev:misclassified/{label}/prediction", f"{prediction}", model.epoch)
+                missing_classes.remove(label)
+                if not missing_classes:
+                    break
+
+    model.fit(train, epochs=args.epochs, log_config=vars(args), log_graph=True, callbacks=[
+        evaluate_dev,
+        dev_misclassifications,
+    ])

     # Evaluate the model on the test data.
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
#!/usr/bin/env python3
import argparse

import torch
import torchvision

import minnt

# Command-line options; every option has a default, so the script also runs without arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=50, help="Batch size.")
parser.add_argument("--cnn_dim", type=int, default=16, help="Number of CNN filters.")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", type=int, default=256, help="Size of the hidden layer.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="Label smoothing factor.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate.")
parser.add_argument("--learning_rate_decay", choices=["cosine", "linear", "none"], default="cosine", help="LR decay.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--threads", type=int, default=1, help="Maximum number of threads to use.")


class Model(minnt.TrainableModule):
    """A small convolutional classifier with 10 outputs.

    Three blocks of (3x3 convolution with padding 1, ReLU, max-pooling with kernel 3 and
    stride 2) use `1 *`, `2 *` and `4 * args.cnn_dim` filters. They are followed by a flatten,
    a hidden linear layer of `args.hidden_layer_size` units with ReLU and dropout
    `args.dropout`, and a final linear layer with 10 outputs. The Lazy* layers infer their
    input sizes on the first forward pass.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__()

        layers = []
        for multiplier in (1, 2, 4):
            layers.extend([
                torch.nn.LazyConv2d(multiplier * args.cnn_dim, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(3, 2),
            ])
        layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(args.hidden_layer_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(args.dropout),
            torch.nn.LazyLinear(10),
        ])
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the batch `x` through the network and return its outputs."""
        return self.network(x)


class Dataset(minnt.TransformedDataset):
    """Turn each (PIL image, label) example into a (tensor, label) pair."""

    def transform(self, image, label):
        # `to_tensor` produces a C x H x W float tensor (scaled to [0, 1] for 8-bit images).
        return torchvision.transforms.functional.to_tensor(image), label


def main(args: argparse.Namespace) -> None:
    """Train the CNN on MNIST with per-epoch callbacks instead of `dev=`, then evaluate on test.

    The last 5000 MNIST training images form the dev set. After every epoch,
    `evaluate_dev` merges the dev metrics (plus an example custom metric) into the epoch
    logs, and `dev_misclassifications` logs the first misclassified dev image of each class.
    """
    # Set the random seed and the number of threads.
    minnt.startup(args.seed, args.threads)
    minnt.global_keras_initializers()

    # Load the data using torchvision; without a `transform`, these datasets yield PIL images.
    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True)
    mnist_train = torch.utils.data.Subset(mnist_train_dev, list(range(len(mnist_train_dev)))[:-5000])
    mnist_dev = torch.utils.data.Subset(mnist_train_dev, list(range(len(mnist_train_dev)))[-5000:])
    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True)

    # Create data loaders from the datasets.
    train = Dataset(mnist_train).dataloader(args.batch_size, shuffle=True)
    dev = Dataset(mnist_dev).dataloader(args.batch_size)
    test = Dataset(mnist_test).dataloader(args.batch_size)

    # Create a model according to the given arguments.
    model = Model(args)
    print("The following model has been created:", model)

    # Configure the model for training.
    model.configure(
        optimizer=(optimizer := torch.optim.Adam(model.parameters(), args.learning_rate)),
        scheduler=minnt.schedulers.GenericDecay(optimizer, args.epochs * len(train), args.learning_rate_decay),
        loss=minnt.losses.CategoricalCrossEntropy(label_smoothing=args.label_smoothing),
        metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
        logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **vars(args)),
    )

    # Train the model.
    def evaluate_dev(model, epoch, logs):
        """Callback: add an example custom metric and the dev evaluation results to `logs`."""
        logs["dev:quality"] = 0.42  # just an example of logging a custom metric
        logs |= model.evaluate(dev, "dev", log_results=False)  # do not log in `evaluate` to merge all logs

    def dev_misclassifications(model, epoch, logs):
        """Callback: log the first misclassified dev image of every class and its prediction."""
        missing_classes = set(range(10))
        # `dev` is created without shuffling, so its predictions line up with `mnist_dev`.
        for (image, label), prediction in zip(mnist_dev, model.predict(dev, data_with_labels=True)):
            prediction = prediction.argmax().item()
            if prediction != label and label in missing_classes:
                # The raw `mnist_dev` yields PIL images; convert to a uint8 C x H x W tensor so the
                # logged data actually matches `data_format="CHW"`.
                image = torchvision.transforms.functional.pil_to_tensor(image)
                model.logger.log_image(f"dev:misclassified/{label}", image, model.epoch, data_format="CHW")
                model.logger.log_text(f"dev:misclassified/{label}/prediction", f"{prediction}", model.epoch)
                missing_classes.remove(label)
                if not missing_classes:
                    break  # every class has a logged misclassification; stop consuming predictions

    model.fit(train, epochs=args.epochs, log_config=vars(args), log_graph=True, callbacks=[
        evaluate_dev,
        dev_misclassifications,
    ])

    # Evaluate the model on the test data.
    model.evaluate(test)


if __name__ == "__main__":
    # Where `__file__` is undefined (e.g. in a notebook), parse an empty list so only the
    # defaults are used; otherwise parse_args(None) reads the command line.
    main_args = parser.parse_args(None if "__file__" in globals() else [])
    main(main_args)

Saving and Loading

To save and load a trained model, Minnt offers:

  • Saving and loading of model weights. By default, only the model parameters are saved; the optimizer and scheduler states can be saved to a second file by passing the additional argument optimizer_path.

  • Saving and loading of model options, which might be needed to reconstruct the model architecture before loading the weights. The options are saved as a JSON file; in addition to JSON types, argparse.Namespace objects are also supported. Note that saving the options might not be needed if the model architecture is fixed and known in advance.

@@ -79,4 +79,18 @@
     model.evaluate(test)

+    # Save the model's options and weights.
+    model.save_options("{logdir}/options.json", args=args)
+    model.save_weights("{logdir}/weights.pt")
+
+    # Construct a new model using the saved options and load the saved weights.
+    loaded_model = Model(**Model.load_options(f"{model.logdir}/options.json"))
+    loaded_model.load_weights(f"{model.logdir}/weights.pt")
+
+    # Compare the test predictions of the original and loaded model.
+    torch.testing.assert_close(
+        loaded_model.predict_tensor(test, data_with_labels=True),
+        model.predict_tensor(test, data_with_labels=True),
+    )
+

 if __name__ == "__main__":
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3
import argparse

import torch
import torchvision

import minnt

# Command-line options; every option has a default, so the script also runs without arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=50, help="Batch size.")
parser.add_argument("--cnn_dim", type=int, default=16, help="Number of CNN filters.")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", type=int, default=256, help="Size of the hidden layer.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="Label smoothing factor.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate.")
parser.add_argument("--learning_rate_decay", choices=["cosine", "linear", "none"], default="cosine", help="LR decay.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--threads", type=int, default=1, help="Maximum number of threads to use.")


class Model(minnt.TrainableModule):
    """A small CNN producing scores for 10 classes.

    Three convolutional stages (a 3x3 convolution with 1x, 2x, and 4x `args.cnn_dim` filters, ReLU,
    and a 3x3 max-pooling with stride 2) are followed by flattening, a ReLU hidden layer with
    `args.hidden_layer_size` units and dropout `args.dropout`, and a linear output layer with 10 units.
    The lazy layers infer their input sizes during the first forward call.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__()

        layers = []
        for filters_multiplier in (1, 2, 4):
            layers.extend([
                torch.nn.LazyConv2d(filters_multiplier * args.cnn_dim, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(3, 2),
            ])
        layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(args.hidden_layer_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(args.dropout),
            torch.nn.LazyLinear(10),
        ])
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 10 unnormalized class scores (logits) for the batch `x`."""
        return self.network(x)


class Dataset(minnt.TransformedDataset):
    """Dataset wrapper converting every image to a tensor while keeping its label unchanged."""

    def transform(self, image, label):
        # `to_tensor` turns the raw image into a float tensor with values scaled to [0, 1].
        return torchvision.transforms.functional.to_tensor(image), label


def main(args: argparse.Namespace) -> None:
    """Train and evaluate the MNIST CNN, then save it and check that a reloaded copy predicts identically.

    Args:
      args: The parsed command-line arguments defined by `parser`.
    """
    # Seed the random generators and limit the number of threads, then enable minnt's Keras initializers.
    minnt.startup(args.seed, args.threads)
    minnt.global_keras_initializers()

    # Load MNIST via torchvision, holding out the last 5000 training images as the development set.
    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True)
    all_indices = list(range(len(mnist_train_dev)))
    mnist_train = torch.utils.data.Subset(mnist_train_dev, all_indices[:-5000])
    mnist_dev = torch.utils.data.Subset(mnist_train_dev, all_indices[-5000:])
    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True)

    # Wrap every split in `Dataset` and create its data loader; only the training data are shuffled.
    train_data = Dataset(mnist_train).dataloader(args.batch_size, shuffle=True)
    dev_data = Dataset(mnist_dev).dataloader(args.batch_size)
    test_data = Dataset(mnist_test).dataloader(args.batch_size)

    model = Model(args)
    print("The following model has been created:", model)

    # The learning-rate schedule is given the total number of training batches (epochs x batches per epoch).
    config = vars(args)
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    model.configure(
        optimizer=optimizer,
        scheduler=minnt.schedulers.GenericDecay(optimizer, args.epochs * len(train_data), args.learning_rate_decay),
        loss=minnt.losses.CategoricalCrossEntropy(label_smoothing=args.label_smoothing),
        metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
        logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **config),
    )

    model.fit(train_data, dev=dev_data, epochs=args.epochs, log_config=config, log_graph=True)
    model.evaluate(test_data)

    # Persist the options and the weights; the `{logdir}` placeholder denotes the model's log directory,
    # which is where they are read back from below via `model.logdir`.
    model.save_options("{logdir}/options.json", args=args)
    model.save_weights("{logdir}/weights.pt")

    # Rebuild a fresh model from the saved options and load the saved weights into it.
    options_path = f"{model.logdir}/options.json"
    weights_path = f"{model.logdir}/weights.pt"
    loaded_model = Model(**Model.load_options(options_path))
    loaded_model.load_weights(weights_path)

    # The reloaded model must reproduce the original model's test predictions.
    torch.testing.assert_close(
        loaded_model.predict_tensor(test_data, data_with_labels=True),
        model.predict_tensor(test_data, data_with_labels=True),
    )


if __name__ == "__main__":
    # Without `__file__` (e.g., in an interactive session), parse an empty argument list so the defaults are used.
    main_args = parser.parse_args(None if "__file__" in globals() else [])
    main(main_args)

Saving via Callback

The model weights can also be saved after every epoch via the minnt.callbacks.SaveWeights callback. The path to save the weights to can include the {logdir} and {epoch} placeholders, which allows either saving the weights to a fixed path inside the log directory or saving them to a separate file for every epoch.

The callback can also be passed the optimizer_path argument to save the optimizer state; the options, however, must be saved separately using minnt.TrainableModule.save_options if needed.

@@ -74,5 +74,8 @@

     # Train the model.
-    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True)
+    model.save_options("{logdir}/options.json", args=args)
+    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True, callbacks=[
+        minnt.callbacks.SaveWeights("{logdir}/model_{epoch:02d}.pt")
+    ])

     # Evaluate the model on the test data.
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python3
import argparse

import torch
import torchvision

import minnt

# Command-line options of the example; they are parsed in the `__main__` block at the end of the script.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=50, help="Batch size.")
parser.add_argument("--cnn_dim", type=int, default=16, help="Number of CNN filters.")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", type=int, default=256, help="Size of the hidden layer.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="Label smoothing factor.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate.")
parser.add_argument("--learning_rate_decay", choices=["cosine", "linear", "none"], default="cosine", help="LR decay.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--threads", type=int, default=1, help="Maximum number of threads to use.")


class Model(minnt.TrainableModule):
    """A small CNN producing scores for 10 classes.

    Three convolutional stages (a 3x3 convolution with 1x, 2x, and 4x `args.cnn_dim` filters, ReLU,
    and a 3x3 max-pooling with stride 2) are followed by flattening, a ReLU hidden layer with
    `args.hidden_layer_size` units and dropout `args.dropout`, and a linear output layer with 10 units.
    The lazy layers infer their input sizes during the first forward call.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__()

        layers = []
        for filters_multiplier in (1, 2, 4):
            layers.extend([
                torch.nn.LazyConv2d(filters_multiplier * args.cnn_dim, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(3, 2),
            ])
        layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(args.hidden_layer_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(args.dropout),
            torch.nn.LazyLinear(10),
        ])
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 10 unnormalized class scores (logits) for the batch `x`."""
        return self.network(x)


class Dataset(minnt.TransformedDataset):
    """Dataset wrapper converting every image to a tensor while keeping its label unchanged."""

    def transform(self, image, label):
        # `to_tensor` turns the raw image into a float tensor with values scaled to [0, 1].
        return torchvision.transforms.functional.to_tensor(image), label


def main(args: argparse.Namespace) -> None:
    """Train the MNIST CNN while saving its weights after every epoch, then evaluate it on the test set.

    Args:
      args: The parsed command-line arguments defined by `parser`.
    """
    # Seed the random generators and limit the number of threads, then enable minnt's Keras initializers.
    minnt.startup(args.seed, args.threads)
    minnt.global_keras_initializers()

    # Load MNIST via torchvision, holding out the last 5000 training images as the development set.
    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True)
    all_indices = list(range(len(mnist_train_dev)))
    mnist_train = torch.utils.data.Subset(mnist_train_dev, all_indices[:-5000])
    mnist_dev = torch.utils.data.Subset(mnist_train_dev, all_indices[-5000:])
    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True)

    # Wrap every split in `Dataset` and create its data loader; only the training data are shuffled.
    train_data = Dataset(mnist_train).dataloader(args.batch_size, shuffle=True)
    dev_data = Dataset(mnist_dev).dataloader(args.batch_size)
    test_data = Dataset(mnist_test).dataloader(args.batch_size)

    model = Model(args)
    print("The following model has been created:", model)

    # The learning-rate schedule is given the total number of training batches (epochs x batches per epoch).
    config = vars(args)
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    model.configure(
        optimizer=optimizer,
        scheduler=minnt.schedulers.GenericDecay(optimizer, args.epochs * len(train_data), args.learning_rate_decay),
        loss=minnt.losses.CategoricalCrossEntropy(label_smoothing=args.label_smoothing),
        metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
        logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **config),
    )

    # The callback does not store the options, so save them explicitly before training; the weights
    # are then written after every epoch to a file named by the zero-padded epoch number.
    model.save_options("{logdir}/options.json", args=args)
    epoch_checkpoints = minnt.callbacks.SaveWeights("{logdir}/model_{epoch:02d}.pt")
    model.fit(
        train_data, dev=dev_data, epochs=args.epochs, log_config=config, log_graph=True,
        callbacks=[epoch_checkpoints],
    )

    model.evaluate(test_data)


if __name__ == "__main__":
    # Without `__file__` (e.g., in an interactive session), parse an empty argument list so the defaults are used.
    main_args = parser.parse_args(None if "__file__" in globals() else [])
    main(main_args)

Saving Best Weights

If you want to save the weights of a model that performed best on a development set, you can use the minnt.callbacks.SaveBestWeights callback. It works similarly to the previous callback, but only saves the weights when a specified metric improves. Apart from specifying a metric, you might also specify whether the metric should be maximized (the default; mode="max") or minimized (mode="min"). After training, the best value of the monitored metric is available as minnt.callbacks.SaveBestWeights.best_value.

@@ -74,9 +74,17 @@

     # Train the model.
-    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True)
+    model.save_options("{logdir}/options.json", args=args)
+    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True, callbacks=[
+        minnt.callbacks.SaveBestWeights("{logdir}/best_model.pt", "dev:accuracy", mode="max"),
+    ])

     # Evaluate the model on the test data.
     model.evaluate(test)

+    print("Restoring the weights reaching the best dev accuracy.")
+    model.load_weights("{logdir}/best_model.pt")
+
+    model.evaluate(test)
+

 if __name__ == "__main__":
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python3
import argparse

import torch
import torchvision

import minnt

# Command-line options of the example; they are parsed in the `__main__` block at the end of the script.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=50, help="Batch size.")
parser.add_argument("--cnn_dim", type=int, default=16, help="Number of CNN filters.")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", type=int, default=256, help="Size of the hidden layer.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="Label smoothing factor.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate.")
parser.add_argument("--learning_rate_decay", choices=["cosine", "linear", "none"], default="cosine", help="LR decay.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--threads", type=int, default=1, help="Maximum number of threads to use.")


class Model(minnt.TrainableModule):
    """A small CNN producing scores for 10 classes.

    Three convolutional stages (a 3x3 convolution with 1x, 2x, and 4x `args.cnn_dim` filters, ReLU,
    and a 3x3 max-pooling with stride 2) are followed by flattening, a ReLU hidden layer with
    `args.hidden_layer_size` units and dropout `args.dropout`, and a linear output layer with 10 units.
    The lazy layers infer their input sizes during the first forward call.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__()

        layers = []
        for filters_multiplier in (1, 2, 4):
            layers.extend([
                torch.nn.LazyConv2d(filters_multiplier * args.cnn_dim, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(3, 2),
            ])
        layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(args.hidden_layer_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(args.dropout),
            torch.nn.LazyLinear(10),
        ])
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 10 unnormalized class scores (logits) for the batch `x`."""
        return self.network(x)


class Dataset(minnt.TransformedDataset):
    """Dataset wrapper converting every image to a tensor while keeping its label unchanged."""

    def transform(self, image, label):
        # `to_tensor` turns the raw image into a float tensor with values scaled to [0, 1].
        return torchvision.transforms.functional.to_tensor(image), label


def main(args: argparse.Namespace) -> None:
    """Train while saving the weights with the best dev accuracy; evaluate both the final and the best weights.

    Args:
      args: The parsed command-line arguments defined by `parser`.
    """
    # Seed the random generators and limit the number of threads, then enable minnt's Keras initializers.
    minnt.startup(args.seed, args.threads)
    minnt.global_keras_initializers()

    # Load MNIST via torchvision, holding out the last 5000 training images as the development set.
    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True)
    all_indices = list(range(len(mnist_train_dev)))
    mnist_train = torch.utils.data.Subset(mnist_train_dev, all_indices[:-5000])
    mnist_dev = torch.utils.data.Subset(mnist_train_dev, all_indices[-5000:])
    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True)

    # Wrap every split in `Dataset` and create its data loader; only the training data are shuffled.
    train_data = Dataset(mnist_train).dataloader(args.batch_size, shuffle=True)
    dev_data = Dataset(mnist_dev).dataloader(args.batch_size)
    test_data = Dataset(mnist_test).dataloader(args.batch_size)

    model = Model(args)
    print("The following model has been created:", model)

    # The learning-rate schedule is given the total number of training batches (epochs x batches per epoch).
    config = vars(args)
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    model.configure(
        optimizer=optimizer,
        scheduler=minnt.schedulers.GenericDecay(optimizer, args.epochs * len(train_data), args.learning_rate_decay),
        loss=minnt.losses.CategoricalCrossEntropy(label_smoothing=args.label_smoothing),
        metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
        logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **config),
    )

    # Save the options explicitly, then let the callback overwrite `best_path` whenever the
    # "dev:accuracy" metric reaches a new maximum.
    best_path = "{logdir}/best_model.pt"
    model.save_options("{logdir}/options.json", args=args)
    best_checkpoint = minnt.callbacks.SaveBestWeights(best_path, "dev:accuracy", mode="max")
    model.fit(
        train_data, dev=dev_data, epochs=args.epochs, log_config=config, log_graph=True,
        callbacks=[best_checkpoint],
    )

    # Evaluate the weights from the last epoch first, then the restored best ones.
    model.evaluate(test_data)

    print("Restoring the weights reaching the best dev accuracy.")
    model.load_weights(best_path)
    model.evaluate(test_data)


if __name__ == "__main__":
    # Without `__file__` (e.g., in an interactive session), parse an empty argument list so the defaults are used.
    main_args = parser.parse_args(None if "__file__" in globals() else [])
    main(main_args)

Keeping Best Weights

In some circumstances, you might want to keep the best weights in memory instead of saving them to disk. This can be done using the minnt.callbacks.KeepBestWeights callback, which keeps the best weights on a specified device (the model device by default). At the end of training, the best weights stored in minnt.callbacks.KeepBestWeights.best_state_dict can be restored using the standard torch.nn.Module.load_state_dict method.

Note that, unlike minnt.callbacks.SaveBestWeights, this callback does not allow also saving the optimizer state.

@@ -74,9 +74,15 @@

     # Train the model.
-    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True)
+    best_weights = minnt.callbacks.KeepBestWeights("dev:accuracy", mode="max")
+    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True, callbacks=[best_weights])

     # Evaluate the model on the test data.
     model.evaluate(test)

+    print(f"Restoring the weights reaching the best dev accuracy {best_weights.best_value:.4f}.")
+    model.load_state_dict(best_weights.best_state_dict)
+
+    model.evaluate(test)
+

 if __name__ == "__main__":
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
import argparse

import torch
import torchvision

import minnt

# Command-line options of the example; they are parsed in the `__main__` block at the end of the script.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=50, help="Batch size.")
parser.add_argument("--cnn_dim", type=int, default=16, help="Number of CNN filters.")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", type=int, default=256, help="Size of the hidden layer.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="Label smoothing factor.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate.")
parser.add_argument("--learning_rate_decay", choices=["cosine", "linear", "none"], default="cosine", help="LR decay.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--threads", type=int, default=1, help="Maximum number of threads to use.")


class Model(minnt.TrainableModule):
    """A small CNN producing scores for 10 classes.

    Three convolutional stages (a 3x3 convolution with 1x, 2x, and 4x `args.cnn_dim` filters, ReLU,
    and a 3x3 max-pooling with stride 2) are followed by flattening, a ReLU hidden layer with
    `args.hidden_layer_size` units and dropout `args.dropout`, and a linear output layer with 10 units.
    The lazy layers infer their input sizes during the first forward call.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__()

        layers = []
        for filters_multiplier in (1, 2, 4):
            layers.extend([
                torch.nn.LazyConv2d(filters_multiplier * args.cnn_dim, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(3, 2),
            ])
        layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(args.hidden_layer_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(args.dropout),
            torch.nn.LazyLinear(10),
        ])
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 10 unnormalized class scores (logits) for the batch `x`."""
        return self.network(x)


class Dataset(minnt.TransformedDataset):
    """Dataset wrapper converting every image to a tensor while keeping its label unchanged."""

    def transform(self, image, label):
        # `to_tensor` turns the raw image into a float tensor with values scaled to [0, 1].
        return torchvision.transforms.functional.to_tensor(image), label


def main(args: argparse.Namespace) -> None:
    """Train while keeping the best-dev-accuracy weights in memory; evaluate both the final and the best weights.

    Args:
      args: The parsed command-line arguments defined by `parser`.
    """
    # Seed the random generators and limit the number of threads, then enable minnt's Keras initializers.
    minnt.startup(args.seed, args.threads)
    minnt.global_keras_initializers()

    # Load MNIST via torchvision, holding out the last 5000 training images as the development set.
    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True)
    all_indices = list(range(len(mnist_train_dev)))
    mnist_train = torch.utils.data.Subset(mnist_train_dev, all_indices[:-5000])
    mnist_dev = torch.utils.data.Subset(mnist_train_dev, all_indices[-5000:])
    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True)

    # Wrap every split in `Dataset` and create its data loader; only the training data are shuffled.
    train_data = Dataset(mnist_train).dataloader(args.batch_size, shuffle=True)
    dev_data = Dataset(mnist_dev).dataloader(args.batch_size)
    test_data = Dataset(mnist_test).dataloader(args.batch_size)

    model = Model(args)
    print("The following model has been created:", model)

    # The learning-rate schedule is given the total number of training batches (epochs x batches per epoch).
    config = vars(args)
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    model.configure(
        optimizer=optimizer,
        scheduler=minnt.schedulers.GenericDecay(optimizer, args.epochs * len(train_data), args.learning_rate_decay),
        loss=minnt.losses.CategoricalCrossEntropy(label_smoothing=args.label_smoothing),
        metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
        logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **config),
    )

    # The callback keeps a copy of the weights whenever "dev:accuracy" reaches a new maximum.
    best_weights = minnt.callbacks.KeepBestWeights("dev:accuracy", mode="max")
    model.fit(
        train_data, dev=dev_data, epochs=args.epochs, log_config=config, log_graph=True,
        callbacks=[best_weights],
    )

    # Evaluate the weights from the last epoch first, then the restored best ones.
    model.evaluate(test_data)

    print(f"Restoring the weights reaching the best dev accuracy {best_weights.best_value:.4f}.")
    model.load_state_dict(best_weights.best_state_dict)
    model.evaluate(test_data)


if __name__ == "__main__":
    # Without `__file__` (e.g., in an interactive session), parse an empty argument list so the defaults are used.
    main_args = parser.parse_args(None if "__file__" in globals() else [])
    main(main_args)

Using W&B Logger

Other loggers can be used instead of the default TensorBoard logger by specifying the loggers argument of minnt.TrainableModule.configure. In the following example, we illustrate using the W&B logger minnt.loggers.WandBLogger, which saves its logs to the logs/wandb directory.

  • When loggers are specified, the default TensorBoard logger is not used; if you want both TensorBoard and W&B logging, you need to explicitly specify the minnt.loggers.TensorBoardLogger in addition to the W&B logger.

  • If the logdir argument of configure is specified, the plain text logs are still saved to the log directory.

@@ -19,4 +19,5 @@
 parser.add_argument("--seed", default=42, type=int, help="Random seed.")
 parser.add_argument("--threads", default=1, type=int, help="Maximum number of threads to use.")
+parser.add_argument("--wandb_project", default="test_project", type=str, help="W&B project name.")


@@ -71,4 +72,5 @@
         metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
         logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **vars(args)),
+        loggers=[minnt.loggers.WandBLogger(project=args.wandb_project, dir="logs")],
     )
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python3
import argparse

import torch
import torchvision

import minnt

# Command-line options of the example; they are parsed in the `__main__` block at the end of the script.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=50, help="Batch size.")
parser.add_argument("--cnn_dim", type=int, default=16, help="Number of CNN filters.")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", type=int, default=256, help="Size of the hidden layer.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="Label smoothing factor.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate.")
parser.add_argument("--learning_rate_decay", choices=["cosine", "linear", "none"], default="cosine", help="LR decay.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--threads", type=int, default=1, help="Maximum number of threads to use.")
parser.add_argument("--wandb_project", type=str, default="test_project", help="W&B project name.")


class Model(minnt.TrainableModule):
    """A small CNN producing scores for 10 classes.

    Three convolutional stages (a 3x3 convolution with 1x, 2x, and 4x `args.cnn_dim` filters, ReLU,
    and a 3x3 max-pooling with stride 2) are followed by flattening, a ReLU hidden layer with
    `args.hidden_layer_size` units and dropout `args.dropout`, and a linear output layer with 10 units.
    The lazy layers infer their input sizes during the first forward call.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__()

        layers = []
        for filters_multiplier in (1, 2, 4):
            layers.extend([
                torch.nn.LazyConv2d(filters_multiplier * args.cnn_dim, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(3, 2),
            ])
        layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(args.hidden_layer_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(args.dropout),
            torch.nn.LazyLinear(10),
        ])
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 10 unnormalized class scores (logits) for the batch `x`."""
        return self.network(x)


class Dataset(minnt.TransformedDataset):
    """Dataset wrapper converting every image to a tensor while keeping its label unchanged."""

    def transform(self, image, label):
        # `to_tensor` turns the raw image into a float tensor with values scaled to [0, 1].
        return torchvision.transforms.functional.to_tensor(image), label


def main(args: argparse.Namespace) -> None:
    """Train and evaluate the MNIST CNN, logging to W&B instead of the default TensorBoard logger.

    Args:
      args: The parsed command-line arguments defined by `parser`.
    """
    # Seed the random generators and limit the number of threads, then enable minnt's Keras initializers.
    minnt.startup(args.seed, args.threads)
    minnt.global_keras_initializers()

    # Load MNIST via torchvision, holding out the last 5000 training images as the development set.
    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True)
    all_indices = list(range(len(mnist_train_dev)))
    mnist_train = torch.utils.data.Subset(mnist_train_dev, all_indices[:-5000])
    mnist_dev = torch.utils.data.Subset(mnist_train_dev, all_indices[-5000:])
    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True)

    # Wrap every split in `Dataset` and create its data loader; only the training data are shuffled.
    train_data = Dataset(mnist_train).dataloader(args.batch_size, shuffle=True)
    dev_data = Dataset(mnist_dev).dataloader(args.batch_size)
    test_data = Dataset(mnist_test).dataloader(args.batch_size)

    model = Model(args)
    print("The following model has been created:", model)

    # The learning-rate schedule is given the total number of training batches (epochs x batches per epoch).
    config = vars(args)
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    model.configure(
        optimizer=optimizer,
        scheduler=minnt.schedulers.GenericDecay(optimizer, args.epochs * len(train_data), args.learning_rate_decay),
        loss=minnt.losses.CategoricalCrossEntropy(label_smoothing=args.label_smoothing),
        metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
        logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **config),
        # Explicit loggers replace the default TensorBoard one; the W&B files are kept under `logs`.
        loggers=[minnt.loggers.WandBLogger(project=args.wandb_project, dir="logs")],
    )

    model.fit(train_data, dev=dev_data, epochs=args.epochs, log_config=config, log_graph=True)
    model.evaluate(test_data)


if __name__ == "__main__":
    # Without `__file__` (e.g., in an interactive session), parse an empty argument list so the defaults are used.
    main_args = parser.parse_args(None if "__file__" in globals() else [])
    main(main_args)

Profiling CPU & GPU

A minnt.TrainableModule can be profiled by using the minnt.TrainableModule.profile method. The profiler tracks CPU usage, accelerator usage (if available), and memory usage, and the resulting trace file can be inspected in TensorBoard using the torch-tb-profiler plugin (which can be installed using pip install torch-tb-profiler). A given number of steps (forward calls) is profiled (either during training or evaluation), after an optional number of warmup steps.

The example below profiles 2 steps after a warmup of 3 steps. Note that we disable graph logging to avoid profiling the corresponding graph tracing; we could have also used a larger warmup (e.g., warmup=5) to achieve a similar effect.

A memory timeline is not generated in the example because it requires the Matplotlib package; you can try enabling it if you have it installed.

@@ -72,7 +72,12 @@
         logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **vars(args)),
     )
+    model.profile(
+        2, "{logdir}/profile", warmup=3, quit_when_done=True,
+        export_memory_timeline=False,  # to generate a memory timeline, matplotlib package is required
+        export_cuda_allocations=True,  # generate CUDA allocations if CUDA available; ignored otherwise
+    )

     # Train the model.
-    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args), log_graph=True)
+    model.fit(train, dev=dev, epochs=args.epochs, log_config=vars(args))

     # Evaluate the model on the test data.
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python3
import argparse

import torch
import torchvision

import minnt

# Command-line options of the example; they are parsed in the `__main__` block at the end of the script.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=50, help="Batch size.")
parser.add_argument("--cnn_dim", type=int, default=16, help="Number of CNN filters.")
parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", type=int, default=256, help="Size of the hidden layer.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="Label smoothing factor.")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate.")
parser.add_argument("--learning_rate_decay", choices=["cosine", "linear", "none"], default="cosine", help="LR decay.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--threads", type=int, default=1, help="Maximum number of threads to use.")


class Model(minnt.TrainableModule):
    """A small CNN producing scores for 10 classes.

    Three convolutional stages (a 3x3 convolution with 1x, 2x, and 4x `args.cnn_dim` filters, ReLU,
    and a 3x3 max-pooling with stride 2) are followed by flattening, a ReLU hidden layer with
    `args.hidden_layer_size` units and dropout `args.dropout`, and a linear output layer with 10 units.
    The lazy layers infer their input sizes during the first forward call.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__()

        layers = []
        for filters_multiplier in (1, 2, 4):
            layers.extend([
                torch.nn.LazyConv2d(filters_multiplier * args.cnn_dim, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(3, 2),
            ])
        layers.extend([
            torch.nn.Flatten(),
            torch.nn.LazyLinear(args.hidden_layer_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(args.dropout),
            torch.nn.LazyLinear(10),
        ])
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 10 unnormalized class scores (logits) for the batch `x`."""
        return self.network(x)


class Dataset(minnt.TransformedDataset):
    """Dataset wrapper converting every image to a tensor while keeping its label unchanged."""

    def transform(self, image, label):
        # `to_tensor` turns the raw image into a float tensor with values scaled to [0, 1].
        return torchvision.transforms.functional.to_tensor(image), label


def main(args: argparse.Namespace) -> None:
    """Configure the MNIST CNN and profile a few of its steps, storing the trace in the log directory.

    Args:
      args: The parsed command-line arguments defined by `parser`.
    """
    # Seed the random generators and limit the number of threads, then enable minnt's Keras initializers.
    minnt.startup(args.seed, args.threads)
    minnt.global_keras_initializers()

    # Load MNIST via torchvision, holding out the last 5000 training images as the development set.
    mnist_train_dev = torchvision.datasets.MNIST("mnist", train=True, download=True)
    all_indices = list(range(len(mnist_train_dev)))
    mnist_train = torch.utils.data.Subset(mnist_train_dev, all_indices[:-5000])
    mnist_dev = torch.utils.data.Subset(mnist_train_dev, all_indices[-5000:])
    mnist_test = torchvision.datasets.MNIST("mnist", train=False, download=True)

    # Wrap every split in `Dataset` and create its data loader; only the training data are shuffled.
    train_data = Dataset(mnist_train).dataloader(args.batch_size, shuffle=True)
    dev_data = Dataset(mnist_dev).dataloader(args.batch_size)
    test_data = Dataset(mnist_test).dataloader(args.batch_size)

    model = Model(args)
    print("The following model has been created:", model)

    # The learning-rate schedule is given the total number of training batches (epochs x batches per epoch).
    config = vars(args)
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    model.configure(
        optimizer=optimizer,
        scheduler=minnt.schedulers.GenericDecay(optimizer, args.epochs * len(train_data), args.learning_rate_decay),
        loss=minnt.losses.CategoricalCrossEntropy(label_smoothing=args.label_smoothing),
        metrics={"accuracy": minnt.metrics.CategoricalAccuracy()},
        logdir=minnt.format_logdir("logs/{file-}{timestamp}{-config}", **config),
    )

    # Profile 2 steps after 3 warmup steps into `{logdir}/profile`; with `quit_when_done=True` the run
    # presumably stops once profiling finishes.
    model.profile(
        2, "{logdir}/profile", warmup=3, quit_when_done=True,
        export_memory_timeline=False,  # to generate a memory timeline, matplotlib package is required
        export_cuda_allocations=True,  # generate CUDA allocations if CUDA available; ignored otherwise
    )

    # Graph logging is left disabled so that the graph tracing does not end up in the profile.
    model.fit(train_data, dev=dev_data, epochs=args.epochs, log_config=config)
    model.evaluate(test_data)


if __name__ == "__main__":
    # Without `__file__` (e.g., in an interactive session), parse an empty argument list so the defaults are used.
    main_args = parser.parse_args(None if "__file__" in globals() else [])
    main(main_args)