Skip to content

TensorBoardLogger

minnt.loggers.TensorBoardLogger

Bases: BaseLogger

A TensorBoard logger interface.

In addition to implementing the Logger interface, it provides a method for obtaining the underlying TensorBoard SummaryWriter instance for a given writer name.

Source code in minnt/loggers/tensorboard_logger.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class TensorBoardLogger(BaseLogger):
    """A TensorBoard logger interface.

    In addition to implementing the [Logger][minnt.Logger] interface, it provides
    a method for obtaining the underlying TensorBoard
    [SummaryWriter][torch.utils.tensorboard.writer.SummaryWriter] instance for a given
    writer name.

    Labels may start with a writer name followed by a colon. For example, `"dev:loss"`
    is logged as `loss` by the `dev` writer. Labels without a colon use the `train`
    writer. Each writer stores its logs in a subdirectory of the root log directory
    named after the writer.
    """
    def __init__(self, logdir: str) -> None:
        """Initialize the TensorBoard logger.

        No writer is opened here; writers are created on demand by `get_writer`.

        Parameters:
          logdir: The root directory where the TensorBoard logs will be stored.
        """
        self._logdir: str = logdir
        # Opened writers, keyed by writer name.
        self._writers: dict[str, torch.utils.tensorboard.writer.SummaryWriter] = {}

    def close(self) -> None:
        """Close the logger and release its resources.

        All opened writers are closed, even if closing one of them fails. They are
        then forgotten, so a later `get_writer` call opens a fresh writer. The first
        error raised while closing, if any, is re-raised afterwards.
        """
        first_error: Exception | None = None
        try:
            for writer in self._writers.values():
                try:
                    writer.close()
                except Exception as error:
                    # Keep closing the remaining writers so that none of them leaks.
                    if first_error is None:
                        first_error = error
        finally:
            self._writers.clear()
        if first_error is not None:
            raise first_error

    def get_writer(self, name: str) -> torch.utils.tensorboard.writer.SummaryWriter:
        """Possibly create and return a TensorBoard writer for the given name.

        Parameters:
          name: The writer name. A newly created writer stores its logs in the
            `name` subdirectory of the root log directory.

        Returns:
          writer: The opened TensorBoard writer.
        """
        if name not in self._writers:
            self._writers[name] = torch.utils.tensorboard.SummaryWriter(os.path.join(self._logdir, name))
        return self._writers[name]

    def _get_writer_from_label(self, label: str) -> tuple[torch.utils.tensorboard.writer.SummaryWriter, str]:
        """Possibly create and return a TensorBoard writer for the given label.

        The writer name is the part of the label before the first colon. Labels
        without a colon are assigned to the `train` writer.

        Parameters:
          label: The label, optionally prefixed by `writer_name:`.

        Returns:
          writer: The opened TensorBoard writer.
          label: The label without the writer prefix.
        """
        writer_name, separator, tag = label.partition(":")
        if not separator:
            writer_name, tag = "train", label
        return self.get_writer(writer_name), tag

    def log_audio(self, label: str, audio: AnyArray, sample_rate: int, epoch: int) -> Self:
        audio = self.preprocess_audio(audio)
        # Rescale to [-1, 1]. This assumes preprocess_audio yields samples in the
        # int16 range (TODO confirm), then moves the last axis to the front.
        audio = (audio.to(torch.float32) / 32_767).clamp(-1.0, 1.0).movedim(-1, 0)
        # SummaryWriter.add_audio expects a single channel, so stereo is down-mixed
        # by averaging the two channels.
        if audio.shape[0] == 2:
            audio = audio.mean(dim=0, keepdim=True)
        writer, label = self._get_writer_from_label(label)
        writer.add_audio(label, audio, epoch, sample_rate=sample_rate)
        writer.flush()
        return self

    def log_config(self, config: dict[str, Any], epoch: int) -> Self:
        # The configuration is always logged as JSON text by the `train` writer.
        writer = self.get_writer("train")
        writer.add_text("config", self.format_config_as_json(config), epoch)
        writer.flush()
        return self

    def log_figure(self, label: str, figure: Any, epoch: int, tight_layout: bool = True, close: bool = True) -> Self:
        # Delegates unchanged to the base-class implementation.
        return super().log_figure(label, figure, epoch, tight_layout, close)

    def log_graph(self, graph: torch.nn.Module, data: TensorOrTensors, epoch: int) -> Self:
        # SummaryWriter.add_graph takes no step argument, so `epoch` is unused here.
        writer = self.get_writer("train")
        writer.add_graph(graph, data, use_strict_trace=False)
        writer.flush()
        return self

    def log_image(self, label: str, image: AnyArray, epoch: int, data_format: DataFormat = "HWC") -> Self:
        image = self.preprocess_image(image, data_format)
        writer, label = self._get_writer_from_label(label)
        # The rank of the preprocessed image selects the layout: 3 axes are HWC, otherwise HW.
        writer.add_image(label, image, epoch, dataformats="HWC" if image.ndim == 3 else "HW")
        writer.flush()
        return self

    def log_metrics(self, logs: dict[str, float], epoch: int, description: str | None = None) -> Self:
        # `description` is not used by this logger.
        for label, value in logs.items():
            writer, label = self._get_writer_from_label(label)
            writer.add_scalar(label, value, epoch)

        # Flush all opened writers, not only the ones written to above.
        for writer in self._writers.values():
            writer.flush()

        return self

    def log_text(self, label: str, text: str, epoch: int) -> Self:
        writer, label = self._get_writer_from_label(label)
        writer.add_text(label, text, epoch)
        writer.flush()
        return self

__init__

__init__(logdir: str) -> None

Initialize the TensorBoard logger.

Parameters:

  • logdir (str) –

    The root directory where the TensorBoard logs will be stored.

Source code in minnt/loggers/tensorboard_logger.py
24
25
26
27
28
29
30
31
def __init__(self, logdir: str) -> None:
    """Set up a TensorBoard logger rooted at the given directory.

    No writer is opened here; writers are created on demand by `get_writer`.

    Parameters:
      logdir: The root directory where the TensorBoard logs will be stored.
    """
    # Registry of opened writers, keyed by writer name; starts empty.
    self._writers: dict[str, torch.utils.tensorboard.writer.SummaryWriter] = {}
    self._logdir: str = logdir

close

close() -> None

Close the logger and release its resources.

Source code in minnt/loggers/tensorboard_logger.py
33
34
35
36
def close(self) -> None:
    """Close the logger and release its resources.

    All opened writers are closed, even if closing one of them fails. They are
    then forgotten, so a later `get_writer` call opens a fresh writer. The first
    error raised while closing, if any, is re-raised afterwards.
    """
    first_error: Exception | None = None
    try:
        for writer in self._writers.values():
            try:
                writer.close()
            except Exception as error:
                # Keep closing the remaining writers so that none of them leaks.
                if first_error is None:
                    first_error = error
    finally:
        self._writers.clear()
    if first_error is not None:
        raise first_error

get_writer

get_writer(name: str) -> SummaryWriter

Possibly create and return a TensorBoard writer for the given name.

Returns:

  • writer (SummaryWriter) –

    The opened TensorBoard writer.

Source code in minnt/loggers/tensorboard_logger.py
38
39
40
41
42
43
44
45
46
def get_writer(self, name: str) -> torch.utils.tensorboard.writer.SummaryWriter:
    """Return the TensorBoard writer registered under `name`, creating it on first use.

    Parameters:
      name: The writer name. A newly created writer stores its logs in the
        `name` subdirectory of the root log directory.

    Returns:
      writer: The opened TensorBoard writer.
    """
    if name in self._writers:
        return self._writers[name]
    created = torch.utils.tensorboard.SummaryWriter(os.path.join(self._logdir, name))
    self._writers[name] = created
    return created

log_audio

log_audio(label: str, audio: AnyArray, sample_rate: int, epoch: int) -> Self

Log the given audio with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged audio.

  • audio (AnyArray) –

    The audio to log, represented as an array with any of the following shapes:

    • (L,) or (L, 1) for mono audio,
    • (L, 2) for stereo audio.

    If the sample values are floating-point numbers, they are expected to be in the [-1, 1] range; otherwise, they are assumed to be in the [-32_768, 32_767] range.

  • sample_rate (int) –

    The sample rate of the audio.

  • epoch (int) –

    The epoch number at which the audio is logged.

Source code in minnt/loggers/tensorboard_logger.py
58
59
60
61
62
63
64
65
66
def log_audio(self, label: str, audio: AnyArray, sample_rate: int, epoch: int) -> Self:
    # Let the base class normalize the input first. The division below assumes the
    # result holds samples in the int16 range (TODO confirm in preprocess_audio).
    samples = self.preprocess_audio(audio).to(torch.float32)
    samples = torch.clamp(samples / 32_767, -1.0, 1.0)
    # Bring the last axis (presumably the channel axis) to the front.
    samples = torch.movedim(samples, -1, 0)
    is_stereo = samples.shape[0] == 2
    if is_stereo:
        # Down-mix the two channels to a single one by averaging them.
        samples = torch.mean(samples, dim=0, keepdim=True)
    target, tag = self._get_writer_from_label(label)
    target.add_audio(tag, samples, epoch, sample_rate=sample_rate)
    target.flush()
    return self

log_config

log_config(config: dict[str, Any], epoch: int) -> Self

Log the given configuration dictionary at the given epoch.

Parameters:

  • config (dict[str, Any]) –

    A JSON-serializable dictionary representing the configuration to log.

  • epoch (int) –

    The epoch number at which the configuration is logged.

Source code in minnt/loggers/tensorboard_logger.py
68
69
70
71
72
def log_config(self, config: dict[str, Any], epoch: int) -> Self:
    # The configuration always goes to the `train` writer as a JSON text entry
    # tagged "config"; flush so it is written out immediately.
    train_writer = self.get_writer("train")
    serialized = self.format_config_as_json(config)
    train_writer.add_text("config", serialized, epoch)
    train_writer.flush()
    return self

log_figure

log_figure(
    label: str,
    figure: Any,
    epoch: int,
    tight_layout: bool = True,
    close: bool = True,
) -> Self

Log the given matplotlib Figure with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged image.

  • figure (Any) –

    A matplotlib Figure.

  • epoch (int) –

    The epoch number at which the image is logged.

  • tight_layout (bool, default: True ) –

    Whether to apply tight layout to the figure before logging it.

  • close (bool, default: True ) –

    Whether to close the figure after logging it.

Source code in minnt/loggers/tensorboard_logger.py
74
75
def log_figure(self, label: str, figure: Any, epoch: int, tight_layout: bool = True, close: bool = True) -> Self:
    # Log a matplotlib figure as an image. This override adds no behavior: it
    # delegates unchanged (same arguments, same order) to the parent class.
    return super().log_figure(label, figure, epoch, tight_layout, close)

log_graph

log_graph(graph: Module, data: TensorOrTensors, epoch: int) -> Self

Log the given computation graph by tracing it with the given data.

Alternatively, loggers may choose to log the graph using TorchScript, run it on the given data, or use any other mechanism they see fit.

Parameters:

  • graph (Module) –

    The computation graph to log, represented as a PyTorch module.

  • data (TensorOrTensors) –

    The input data to use for tracing the computation graph.

  • epoch (int) –

    The epoch number at which the computation graph is logged.

Source code in minnt/loggers/tensorboard_logger.py
77
78
79
80
81
def log_graph(self, graph: torch.nn.Module, data: TensorOrTensors, epoch: int) -> Self:
    # Trace `graph` on `data` and record it with the `train` writer.
    # SummaryWriter.add_graph has no step argument, so `epoch` is intentionally unused.
    train_writer = self.get_writer("train")
    train_writer.add_graph(graph, data, use_strict_trace=False)
    train_writer.flush()
    return self

log_image

log_image(
    label: str, image: AnyArray, epoch: int, data_format: DataFormat = "HWC"
) -> Self

Log the given image with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged image.

  • image (AnyArray) –

    The image to log, represented as a PIL image or as an array of any of the following shapes (assuming "HWC" data format):

    • (H, W) or (H, W, 1) for grayscale images,
    • (H, W, 2) for grayscale images with alpha channel,
    • (H, W, 3) for RGB images,
    • (H, W, 4) for RGBA images.

    If the pixel values are floating-point numbers, they are expected to be in the [0, 1] range; otherwise, they are assumed to be in the [0, 255] range.

  • epoch (int) –

    The epoch number at which the image is logged.

  • data_format (DataFormat, default: 'HWC' ) –

    The data format of the image specifying whether the channels are stored in the last dimension ("HWC", the default) or in the first dimension ("CHW"); ignored for a PIL image.

Source code in minnt/loggers/tensorboard_logger.py
83
84
85
86
87
88
def log_image(self, label: str, image: AnyArray, epoch: int, data_format: DataFormat = "HWC") -> Self:
    # Normalize the image via the base class, then choose the TensorBoard layout
    # from its rank: three axes are treated as HWC, anything else as HW.
    pixels = self.preprocess_image(image, data_format)
    target, tag = self._get_writer_from_label(label)
    if pixels.ndim == 3:
        layout = "HWC"
    else:
        layout = "HW"
    target.add_image(tag, pixels, epoch, dataformats=layout)
    target.flush()
    return self

log_metrics

log_metrics(
    logs: dict[str, float], epoch: int, description: str | None = None
) -> Self

Log metrics collected during a given epoch, with an optional description.

Parameters:

  • logs (dict[str, float]) –

    A dictionary of logged metrics for the epoch.

  • epoch (int) –

    The epoch number at which the logs were collected.

  • description (str | None, default: None ) –

    An optional description of the logged metrics (used only by some loggers).

Source code in minnt/loggers/tensorboard_logger.py
90
91
92
93
94
95
96
97
98
def log_metrics(self, logs: dict[str, float], epoch: int, description: str | None = None) -> Self:
    # Each metric name may carry a "writer:" prefix; `description` is not used by this logger.
    for metric_name, metric_value in logs.items():
        target, tag = self._get_writer_from_label(metric_name)
        target.add_scalar(tag, metric_value, epoch)

    # Flush every writer opened so far, not only the ones written to above.
    for each_writer in self._writers.values():
        each_writer.flush()

    return self

log_text

log_text(label: str, text: str, epoch: int) -> Self

Log the given text with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged text.

  • text (str) –

    The text to log.

  • epoch (int) –

    The epoch number at which the text is logged.

Source code in minnt/loggers/tensorboard_logger.py
100
101
102
103
104
def log_text(self, label: str, text: str, epoch: int) -> Self:
    # Resolve the writer from an optional "writer:" prefix, record the text under
    # the remaining tag, and flush so the entry is written out immediately.
    target, tag = self._get_writer_from_label(label)
    target.add_text(tag, text, epoch)
    target.flush()
    return self