Skip to content

TensorBoardLogger

A Logger implementation that writes TensorBoard logs.

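In addition to implementing the Logger interface, it provides a method for obtaining the underlying TensorBoard SummaryWriter instance for a given writer name. Labels of the form `name:label` are logged by the writer called `name` (stored in the `name` subdirectory of the log directory); labels without a colon are logged by the `train` writer.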

minnt.loggers.TensorBoardLogger

Bases: Logger

Source code in minnt/loggers/tensorboard_logger.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
class TensorBoardLogger(Logger):
    """Logger that writes TensorBoard event files through `SummaryWriter` instances.

    One `SummaryWriter` is kept per writer name, each logging into the
    subdirectory `os.path.join(logdir, name)`. Labels of the form
    `"name:label"` are logged by the writer `name`; labels without a colon
    are logged by the `"train"` writer.
    """

    def __init__(self, logdir: str) -> None:
        """Initialize the TensorBoard logger.

        Parameters:
          logdir: The root directory where the TensorBoard logs will be stored.
        """
        self._logdir: str = logdir
        # Writers are created lazily by `get_writer` and cached here by name.
        self._writers: dict[str, torch.utils.tensorboard.writer.SummaryWriter] = {}

    def __del__(self) -> None:
        # Close the writers. Use `getattr` because `__del__` can run on an
        # instance whose `__init__` never ran or failed before `_writers` was
        # assigned; a plain attribute access would then raise AttributeError
        # during garbage collection.
        writers = getattr(self, "_writers", None)
        if not writers:
            return
        for writer in writers.values():
            writer.close()

    def get_writer(self, name: str) -> torch.utils.tensorboard.writer.SummaryWriter:
        """Possibly create and return a TensorBoard writer for the given name.

        Returns:
          writer: The opened TensorBoard writer.
        """
        if name not in self._writers:
            self._writers[name] = torch.utils.tensorboard.SummaryWriter(os.path.join(self._logdir, name))
        return self._writers[name]

    def _get_writer_from_label(self, label: str) -> tuple[torch.utils.tensorboard.writer.SummaryWriter, str]:
        """Possibly create and return a TensorBoard writer for the given label.

        The text before the first `:` selects the writer; without a `:`,
        the `"train"` writer is used and the label is kept unchanged.

        Returns:
          writer: The opened TensorBoard writer.
          label: The label without the writer prefix.
        """
        name, separator, rest = label.partition(":")
        if separator:
            return self.get_writer(name), rest
        return self.get_writer("train"), label

    def log_audio(self, label: str, audio: AnyArray, sample_rate: int, epoch: int) -> Self:
        audio = self._process_audio(audio)
        # Rescale to floats in [-1, 1]. The division by 32_767 assumes
        # `_process_audio` yields samples in the 16-bit integer range (TODO
        # confirm). Then move the last axis to the front.
        audio = (audio.to(torch.float32) / 32_767).clamp(-1.0, 1.0).movedim(-1, 0)
        # Downmix two channels to one. The `ndim` check keeps a 1-D mono clip
        # of exactly two samples (if `_process_audio` can return 1-D audio)
        # from being mistaken for stereo and averaged into a single sample.
        if audio.ndim == 2 and audio.shape[0] == 2:
            audio = audio.mean(dim=0, keepdim=True)
        writer, label = self._get_writer_from_label(label)
        writer.add_audio(label, audio, epoch, sample_rate=sample_rate)
        writer.flush()
        return self

    def log_config(self, config: dict[str, Any], epoch: int) -> Self:
        # Sort top-level keys so the logged text does not depend on insertion order.
        config = dict(sorted(config.items()))
        writer = self.get_writer("train")
        writer.add_text("config", json.dumps(config, ensure_ascii=False, indent=2), epoch)
        writer.flush()
        return self

    def log_epoch(
        self, logs: dict[str, float], epoch: int, epochs: int | None = None, elapsed: float | None = None,
    ) -> Self:
        # `epochs` and `elapsed` are part of the Logger interface but are not logged here.
        for label, value in logs.items():
            writer, label = self._get_writer_from_label(label)
            writer.add_scalar(label, value, epoch)

        # Flush every open writer, including ones not touched by this call.
        for writer in self._writers.values():
            writer.flush()

        return self

    def log_figure(self, label: str, figure: Any, epoch: int, tight_layout: bool = True, close: bool = True) -> Self:
        # Explicit override that only delegates to the base-class implementation.
        return super().log_figure(label, figure, epoch, tight_layout, close)

    def log_graph(self, graph: torch.nn.Module, data: TensorOrTensors, epoch: int) -> Self:
        # The graph always goes to the "train" writer; `epoch` is not used.
        writer = self.get_writer("train")
        writer.add_graph(graph, data, use_strict_trace=False)
        writer.flush()
        return self

    def log_image(self, label: str, image: AnyArray, epoch: int) -> Self:
        image = self._process_image(image)
        writer, label = self._get_writer_from_label(label)
        # A 3-D result has a trailing channel axis; a 2-D one is a single-channel image.
        writer.add_image(label, image, epoch, dataformats="HWC" if image.ndim == 3 else "HW")
        writer.flush()
        return self

    def log_text(self, label: str, text: str, epoch: int) -> Self:
        writer, label = self._get_writer_from_label(label)
        writer.add_text(label, text, epoch)
        writer.flush()
        return self

__init__

__init__(logdir: str) -> None

Initialize the TensorBoard logger.

Parameters:

  • logdir (str) –

    The root directory where the TensorBoard logs will be stored.

Source code in minnt/loggers/tensorboard_logger.py
25
26
27
28
29
30
31
32
def __init__(self, logdir: str) -> None:
    """Set up the logger; TensorBoard writers are only created on first use.

    Parameters:
      logdir: The root directory where the TensorBoard logs will be stored.
    """
    # Cache of writers keyed by writer name, filled in lazily by `get_writer`.
    self._writers: dict[str, torch.utils.tensorboard.writer.SummaryWriter] = {}
    self._logdir: str = logdir

get_writer

get_writer(name: str) -> SummaryWriter

Possibly create and return a TensorBoard writer for the given name.

Returns:

  • writer (SummaryWriter) –

    The opened TensorBoard writer.

Source code in minnt/loggers/tensorboard_logger.py
39
40
41
42
43
44
45
46
47
def get_writer(self, name: str) -> torch.utils.tensorboard.writer.SummaryWriter:
    """Return the writer registered under `name`, creating it on first request.

    Returns:
      writer: The opened TensorBoard writer.
    """
    if name in self._writers:
        return self._writers[name]
    # Each writer logs into its own subdirectory of the root log directory.
    directory = os.path.join(self._logdir, name)
    self._writers[name] = torch.utils.tensorboard.SummaryWriter(directory)
    return self._writers[name]

log_audio

log_audio(label: str, audio: AnyArray, sample_rate: int, epoch: int) -> Self

Log the given audio with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged audio.

  • audio (AnyArray) –

    The audio to log, represented as an array with any of the following shapes:

    • (L,) or (L, 1) for mono audio,
    • (L, 2) for stereo audio.

    If the sample values are floating-point numbers, they are expected to be in the [-1, 1] range; otherwise, they are assumed to be in the [-32_768, 32_767] range.

  • sample_rate (int) –

    The sample rate of the audio.

  • epoch (int) –

    The epoch number at which the audio is logged.

Source code in minnt/loggers/tensorboard_logger.py
59
60
61
62
63
64
65
66
67
def log_audio(self, label: str, audio: AnyArray, sample_rate: int, epoch: int) -> Self:
    """Log `audio` under `label` at step `epoch`, averaging two channels into one."""
    samples = self._process_audio(audio).to(torch.float32)
    # Map the 16-bit integer range onto [-1, 1] (assumes `_process_audio`
    # yields int16-range samples — TODO confirm).
    samples = torch.clamp(samples / 32_767, -1.0, 1.0)
    # Move the last axis (the channel axis, for (L, C) input) to the front.
    samples = torch.movedim(samples, -1, 0)
    if samples.shape[0] == 2:
        samples = samples.mean(dim=0, keepdim=True)
    writer, tag = self._get_writer_from_label(label)
    writer.add_audio(tag, samples, epoch, sample_rate=sample_rate)
    writer.flush()
    return self

log_config

log_config(config: dict[str, Any], epoch: int) -> Self

Log the given configuration dictionary at the given epoch.

Parameters:

  • config (dict[str, Any]) –

    A JSON-serializable dictionary representing the configuration to log.

  • epoch (int) –

    The epoch number at which the configuration is logged.

Source code in minnt/loggers/tensorboard_logger.py
69
70
71
72
73
74
def log_config(self, config: dict[str, Any], epoch: int) -> Self:
    config = dict(sorted(config.items()))
    writer = self.get_writer("train")
    writer.add_text("config", json.dumps(config, ensure_ascii=False, indent=2), epoch)
    writer.flush()
    return self

log_epoch

log_epoch(
    logs: dict[str, float],
    epoch: int,
    epochs: int | None = None,
    elapsed: float | None = None,
) -> Self

Log metrics collected during a given epoch.

Parameters:

  • logs (dict[str, float]) –

    A dictionary of logged metrics for the epoch.

  • epoch (int) –

    The epoch number at which the logs were collected.

  • epochs (int | None, default: None ) –

    The total number of epochs, if known.

  • elapsed (float | None, default: None ) –

    The time elapsed during the epoch, in seconds, if known.

Source code in minnt/loggers/tensorboard_logger.py
76
77
78
79
80
81
82
83
84
85
86
def log_epoch(
    self, logs: dict[str, float], epoch: int, epochs: int | None = None, elapsed: float | None = None,
) -> Self:
    for label, value in logs.items():
        writer, label = self._get_writer_from_label(label)
        writer.add_scalar(label, value, epoch)

    for writer in self._writers.values():
        writer.flush()

    return self

log_figure

log_figure(
    label: str,
    figure: Any,
    epoch: int,
    tight_layout: bool = True,
    close: bool = True,
) -> Self

Log the given matplotlib Figure with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged image.

  • figure (Any) –

    A matplotlib Figure.

  • epoch (int) –

    The epoch number at which the image is logged.

  • tight_layout (bool, default: True ) –

    Whether to apply tight layout to the figure before logging it.

  • close (bool, default: True ) –

    Whether to close the figure after logging it.

Source code in minnt/loggers/tensorboard_logger.py
88
89
def log_figure(self, label: str, figure: Any, epoch: int, tight_layout: bool = True, close: bool = True) -> Self:
    """Delegate figure logging unchanged to the base-class implementation."""
    parent = super()
    return parent.log_figure(label, figure, epoch, tight_layout, close)

log_graph

log_graph(graph: Module, data: TensorOrTensors, epoch: int) -> Self

Log the given computation graph by tracing it with the given data.

Alternatively, loggers may choose to log the graph using TorchScript or other mechanisms.

Parameters:

  • graph (Module) –

    The computation graph to log, represented as a PyTorch module.

  • data (TensorOrTensors) –

    The input data to use for tracing the computation graph.

  • epoch (int) –

    The epoch number at which the computation graph is logged.

Source code in minnt/loggers/tensorboard_logger.py
91
92
93
94
95
def log_graph(self, graph: torch.nn.Module, data: TensorOrTensors, epoch: int) -> Self:
    """Trace `graph` with `data` and write it to the "train" writer (`epoch` is not used)."""
    train_writer = self.get_writer("train")
    train_writer.add_graph(graph, data, use_strict_trace=False)
    train_writer.flush()
    return self

log_image

log_image(label: str, image: AnyArray, epoch: int) -> Self

Log the given image with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged image.

  • image (AnyArray) –

    The image to log, represented as an array, which can have any of the following shapes:

    • (H, W) or (H, W, 1) for grayscale images,
    • (H, W, 2) for grayscale images with alpha channel,
    • (H, W, 3) for RGB images,
    • (H, W, 4) for RGBA images.

    If the pixel values are floating-point numbers, they are expected to be in the [0, 1] range; otherwise, they are assumed to be in the [0, 255] range.

  • epoch (int) –

    The epoch number at which the image is logged.

Source code in minnt/loggers/tensorboard_logger.py
 97
 98
 99
100
101
102
def log_image(self, label: str, image: AnyArray, epoch: int) -> Self:
    """Log `image` under `label` at step `epoch`."""
    pixels = self._process_image(image)
    writer, tag = self._get_writer_from_label(label)
    # A 3-D result carries a trailing channel axis; a 2-D one is single-channel.
    layout = "HWC" if pixels.ndim == 3 else "HW"
    writer.add_image(tag, pixels, epoch, dataformats=layout)
    writer.flush()
    return self

log_text

log_text(label: str, text: str, epoch: int) -> Self

Log the given text with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged text.

  • text (str) –

    The text to log.

  • epoch (int) –

    The epoch number at which the text is logged.

Source code in minnt/loggers/tensorboard_logger.py
104
105
106
107
108
def log_text(self, label: str, text: str, epoch: int) -> Self:
    writer, label = self._get_writer_from_label(label)
    writer.add_text(label, text, epoch)
    writer.flush()
    return self