
BaseLogger

minnt.loggers.BaseLogger

Bases: Logger

An abstract logger providing base functionality for other loggers.

Source code in minnt/loggers/base_logger.py
class BaseLogger(Logger):
    """An abstract logger providing base functionality for other loggers."""

    def log_figure(self, label: str, figure: Any, epoch: int, tight_layout: bool = True, close: bool = True) -> Self:
        import matplotlib.pyplot as plt
        import matplotlib.backends.backend_agg as plt_backend_agg

        tight_layout and figure.tight_layout()
        canvas = plt_backend_agg.FigureCanvasAgg(figure)
        canvas.draw()
        width, height = figure.canvas.get_width_height()
        image = torch.frombuffer(canvas.buffer_rgba(), dtype=torch.uint8).view(height, width, 4)
        close and plt.close(figure)

        return self.log_image(label, image, epoch)

    @staticmethod
    def format_config_as_json(config: dict[str, Any]) -> str:
        """Make a formatted JSON from configuration and epoch number."""
        return json.dumps(dict(sorted(config.items())), ensure_ascii=False, indent=2)

    @staticmethod
    def format_config_as_text(config: dict[str, Any], epoch: int) -> str:
        """Make a human-readable plain text from configuration and epoch number."""
        return " ".join(
            [f"Config epoch={epoch}"]
            + [f"{k}={v}" for k, v in sorted(config.items())]
        )

    @staticmethod
    def format_metrics(logs: dict[str, float]) -> str:
        """Make a human-readable string from the logged metrics."""
        return " ".join([f"{k}={v:#.{0 < abs(v) < 2e-4 and '2e' or '4f'}}" for k, v in logs.items()])

    @contextlib.contextmanager
    def graph_in_eval_mode(self, graph: torch.nn.Module):
        """Context manager to temporarily set the training mode of ``graph`` to eval."""
        if not isinstance(graph, torch.jit.ScriptFunction):
            old_training = graph.training
            graph.train(False)
            try:
                yield
            finally:
                graph.train(old_training)
        else:
            yield  # do nothing for a ScriptFunction

    @staticmethod
    def preprocess_audio(audio: AnyArray) -> torch.Tensor:
        """Produce a CPU-based [torch.Tensor][] with `dtype=torch.int16` and shape `(L, {1/2})`."""
        audio = torch.as_tensor(audio, device="cpu")
        audio = audio * 32_767 if audio.dtype.is_floating_point else audio
        audio = audio.clamp(-32_768, 32_767).to(torch.int16)
        assert audio.ndim == 1 or (audio.ndim == 2 and audio.shape[1] in (1, 2)), \
            "Audio must have shape (L,) or (L, 1/2)"
        if audio.ndim == 1:
            audio = audio.unsqueeze(-1)
        return audio

    @staticmethod
    def preprocess_image(image: AnyArray, data_format: DataFormat = "HWC") -> torch.Tensor:
        """Produce a CPU-based [torch.Tensor][] with `dtype=torch.uint8` and shape `(H, W, {1/3/4})`."""
        if type(image).__module__ == "PIL.Image":
            image, data_format = np.array(image, copy=True), "HWC"
        image = torch.as_tensor(image, device="cpu")
        image = image.movedim(0, -1) if data_format == "CHW" and image.ndim == 3 else image
        image = (image * 255 if image.dtype.is_floating_point else image).clamp(0, 255).to(torch.uint8)
        assert image.ndim == 2 or (image.ndim == 3 and image.shape[2] in (1, 2, 3, 4)), \
            "Image must have shape (H, W) or (H, W, 1/2/3/4)"
        if image.ndim == 2:
            image = image.unsqueeze(-1)
        if image.shape[2] == 2:
            # Convert to RGBA
            image = torch.stack([image[:, :, 0]] * 3 + [image[:, :, 1]], dim=-1)
        return image
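
Concrete loggers build on this class by implementing the backend-specific logging calls; the call to self.log_image inside log_figure suggests that log_image(label, image, epoch) belongs to the abstract Logger interface. A minimal sketch of a hypothetical subclass (the class name is illustrative, and the real Logger interface may require further methods):

from minnt.loggers.base_logger import BaseLogger

class PrintLogger(BaseLogger):
    """Hypothetical logger that only reports what would be logged."""

    def log_image(self, label, image, epoch):
        image = self.preprocess_image(image)  # CPU uint8 tensor of shape (H, W, 1/3/4)
        print(f"[epoch {epoch}] {label}: image of shape {tuple(image.shape)}")
        return self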

log_figure

log_figure(
    label: str,
    figure: Any,
    epoch: int,
    tight_layout: bool = True,
    close: bool = True,
) -> Self

Log the given matplotlib Figure with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged image.

  • figure (Any) –

    A matplotlib Figure.

  • epoch (int) –

    The epoch number at which the image is logged.

  • tight_layout (bool, default: True ) –

    Whether to apply tight layout to the figure before logging it.

  • close (bool, default: True ) –

    Whether to close the figure after logging it.

Source code in minnt/loggers/base_logger.py
def log_figure(self, label: str, figure: Any, epoch: int, tight_layout: bool = True, close: bool = True) -> Self:
    import matplotlib.pyplot as plt
    import matplotlib.backends.backend_agg as plt_backend_agg

    tight_layout and figure.tight_layout()
    canvas = plt_backend_agg.FigureCanvasAgg(figure)
    canvas.draw()
    width, height = figure.canvas.get_width_height()
    image = torch.frombuffer(canvas.buffer_rgba(), dtype=torch.uint8).view(height, width, 4)
    close and plt.close(figure)

    return self.log_image(label, image, epoch)
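
For example, a concrete logger instance (here called logger, standing in for any BaseLogger subclass) can log a matplotlib figure like this; the label and data are purely illustrative:

import matplotlib.pyplot as plt

figure = plt.figure()
plt.plot([0, 1, 2, 3], [1.0, 0.6, 0.4, 0.3])
logger.log_figure("train/loss_curve", figure, epoch=3)  # figure is closed afterwards by default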

format_config_as_json staticmethod

format_config_as_json(config: dict[str, Any]) -> str

Make a formatted JSON string from the configuration.

Source code in minnt/loggers/base_logger.py
@staticmethod
def format_config_as_json(config: dict[str, Any]) -> str:
    """Make a formatted JSON from configuration and epoch number."""
    return json.dumps(dict(sorted(config.items())), ensure_ascii=False, indent=2)
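
Since this is a staticmethod, it can be called directly on the class; a small illustrative example with made-up configuration values (keys are sorted alphabetically):

from minnt.loggers.base_logger import BaseLogger

config = {"lr": 1e-3, "batch_size": 64, "optimizer": "adam"}
print(BaseLogger.format_config_as_json(config))
# {
#   "batch_size": 64,
#   "lr": 0.001,
#   "optimizer": "adam"
# }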

format_config_as_text staticmethod

format_config_as_text(config: dict[str, Any], epoch: int) -> str

Make a human-readable plain-text string from the configuration and epoch number.

Source code in minnt/loggers/base_logger.py
@staticmethod
def format_config_as_text(config: dict[str, Any], epoch: int) -> str:
    """Make a human-readable plain text from configuration and epoch number."""
    return " ".join(
        [f"Config epoch={epoch}"]
        + [f"{k}={v}" for k, v in sorted(config.items())]
    )
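
An illustrative call with made-up configuration values (keys are again sorted alphabetically in the output):

from minnt.loggers.base_logger import BaseLogger

print(BaseLogger.format_config_as_text({"lr": 1e-3, "batch_size": 64}, epoch=0))
# Config epoch=0 batch_size=64 lr=0.001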

format_metrics staticmethod

format_metrics(logs: dict[str, float]) -> str

Make a human-readable string from the logged metrics.

Source code in minnt/loggers/base_logger.py
@staticmethod
def format_metrics(logs: dict[str, float]) -> str:
    """Make a human-readable string from the logged metrics."""
    return " ".join([f"{k}={v:#.{0 < abs(v) < 2e-4 and '2e' or '4f'}}" for k, v in logs.items()])

graph_in_eval_mode

graph_in_eval_mode(graph: Module)

Context manager to temporarily set the training mode of graph to eval.

Source code in minnt/loggers/base_logger.py
@contextlib.contextmanager
def graph_in_eval_mode(self, graph: torch.nn.Module):
    """Context manager to temporarily set the training mode of ``graph`` to eval."""
    if not isinstance(graph, torch.jit.ScriptFunction):
        old_training = graph.training
        graph.train(False)
        try:
            yield
        finally:
            graph.train(old_training)
    else:
        yield  # do nothing for a ScriptFunction
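
For example, wrapping a model while it is being drawn or traced keeps dropout and batch normalization in inference mode and then restores the original training flag (logger stands in for any concrete BaseLogger subclass instance):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
model.train()
with logger.graph_in_eval_mode(model):
    assert not model.training  # eval mode inside the context
assert model.training  # original training mode restored afterwards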

preprocess_audio staticmethod

preprocess_audio(audio: AnyArray) -> Tensor

Produce a CPU-based torch.Tensor with dtype=torch.int16 and shape (L, {1/2}).

Source code in minnt/loggers/base_logger.py
@staticmethod
def preprocess_audio(audio: AnyArray) -> torch.Tensor:
    """Produce a CPU-based [torch.Tensor][] with `dtype=torch.int16` and shape `(L, {1/2})`."""
    audio = torch.as_tensor(audio, device="cpu")
    audio = audio * 32_767 if audio.dtype.is_floating_point else audio
    audio = audio.clamp(-32_768, 32_767).to(torch.int16)
    assert audio.ndim == 1 or (audio.ndim == 2 and audio.shape[1] in (1, 2)), \
        "Audio must have shape (L,) or (L, 1/2)"
    if audio.ndim == 1:
        audio = audio.unsqueeze(-1)
    return audio
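
An illustrative call with a synthetic mono waveform in [-1, 1]; floating-point input is rescaled to the int16 range and a channel dimension is added:

import numpy as np
from minnt.loggers.base_logger import BaseLogger

waveform = np.sin(np.linspace(0, 2 * np.pi * 440, 16_000)).astype(np.float32)
audio = BaseLogger.preprocess_audio(waveform)
print(audio.dtype, tuple(audio.shape))  # torch.int16 (16000, 1)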

preprocess_image staticmethod

preprocess_image(image: AnyArray, data_format: DataFormat = 'HWC') -> Tensor

Produce a CPU-based torch.Tensor with dtype=torch.uint8 and shape (H, W, {1/3/4}).

Source code in minnt/loggers/base_logger.py
@staticmethod
def preprocess_image(image: AnyArray, data_format: DataFormat = "HWC") -> torch.Tensor:
    """Produce a CPU-based [torch.Tensor][] with `dtype=torch.uint8` and shape `(H, W, {1/3/4})`."""
    if type(image).__module__ == "PIL.Image":
        image, data_format = np.array(image, copy=True), "HWC"
    image = torch.as_tensor(image, device="cpu")
    image = image.movedim(0, -1) if data_format == "CHW" and image.ndim == 3 else image
    image = (image * 255 if image.dtype.is_floating_point else image).clamp(0, 255).to(torch.uint8)
    assert image.ndim == 2 or (image.ndim == 3 and image.shape[2] in (1, 2, 3, 4)), \
        "Image must have shape (H, W) or (H, W, 1/2/3/4)"
    if image.ndim == 2:
        image = image.unsqueeze(-1)
    if image.shape[2] == 2:
        # Convert to RGBA
        image = torch.stack([image[:, :, 0]] * 3 + [image[:, :, 1]], dim=-1)
    return image
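
An illustrative call with a synthetic floating-point CHW image in [0, 1]; the channel dimension is moved last and values are rescaled to uint8:

import numpy as np
from minnt.loggers.base_logger import BaseLogger

chw = np.random.rand(3, 32, 48).astype(np.float32)
image = BaseLogger.preprocess_image(chw, data_format="CHW")
print(image.dtype, tuple(image.shape))  # torch.uint8 (32, 48, 3)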