Skip to content

WandbLogger

A Wandb logger interface.

The text values are by default also logged as HTML for better visualization.

minnt.loggers.WandbLogger

Bases: Logger

Source code in minnt/loggers/wandb_logger.py (lines 20–84):
class WandbLogger(Logger):
    """A Wandb logger interface.

    Metrics, configuration, text, images, audio, figures and computation graphs
    are logged to a Wandb run created in the constructor. Text values (including
    the JSON-serialized configuration) are by default also logged as HTML for
    better visualization.
    """

    def __init__(self, project: str, *, text_also_as_html: bool = True, **kwargs: Any) -> None:
        """Create the WandbLogger with the given project name.

        Additional keyword arguments are passed to `wandb.init()`.

        Parameters:
          project: The name of the Wandb project.
          text_also_as_html: Whether to log text messages also as HTML.
            That has the advantage of interactive visualization of the value
            at different epochs and preserving whitespace formatting.
          kwargs: Additional keyword arguments passed to `wandb.init()`.
        """
        # Imported here rather than at module level, so `wandb` is only needed
        # once a WandbLogger is actually constructed.
        import wandb
        self.wandb = wandb
        self.run = self.wandb.init(project=project, **kwargs)
        self._text_also_as_html = text_also_as_html

    def __del__(self) -> None:
        """Finish the Wandb run when the logger is finalized.

        If `__init__` did not complete (for example `import wandb` or
        `wandb.init()` raised), no run exists and nothing is done, instead of
        raising `AttributeError` during finalization.
        """
        run = getattr(self, "run", None)
        if run is not None:
            run.finish()

    def _maybe_as_html(self, label: str, text: str) -> dict[str, Any]:
        """Return a dict with the HTML version of the text if enabled.

        When HTML logging is enabled, the text is HTML-escaped, wrapped in a
        `<pre>` element (so whitespace is preserved), and returned as a
        `wandb.Html` object under the key `f"{label}_html"`. Otherwise an empty
        dict is returned, so the result can always be merged into a log payload.
        """
        if not self._text_also_as_html:
            return {}
        return {f"{label}_html": self.wandb.Html("<pre>" + html.escape(text) + "</pre>")}

    def log_audio(self, label: str, audio: AnyArray, sample_rate: int, epoch: int) -> Self:
        # The audio goes through the shared `_process_audio` helper, is converted
        # to NumPy, and is logged as `wandb.Audio` at step `epoch`.
        audio = self._process_audio(audio).numpy()
        self.run.log({label: self.wandb.Audio(audio, sample_rate=sample_rate)}, step=epoch)
        return self

    def log_config(self, config: dict[str, Any], epoch: int) -> Self:
        # Sort by key so both the run config and the JSON dump have a stable order.
        config = dict(sorted(config.items()))
        self.run.config.update(config)
        # Additionally log the config as pretty-printed JSON text (plus optional HTML).
        config_json = json.dumps(config, ensure_ascii=False, indent=2)
        self.run.log({"config": config_json} | self._maybe_as_html("config", config_json), step=epoch)
        return self

    def log_epoch(
        self, logs: dict[str, float], epoch: int, epochs: int | None = None, elapsed: float | None = None,
    ) -> Self:
        # All metrics are logged at step `epoch`; `epochs` and `elapsed` are not used here.
        self.run.log(logs, step=epoch)
        return self

    def log_figure(self, label: str, figure: Any, epoch: int, tight_layout: bool = True, close: bool = True) -> Self:
        # Delegates unchanged to the base `Logger` implementation.
        return super().log_figure(label, figure, epoch, tight_layout, close)

    def log_graph(self, graph: torch.nn.Module, data: TensorOrTensors, epoch: int) -> Self:
        # Watch the module for graph logging only (`log=None`: no gradients/parameters);
        # `epoch` is not used.
        self.run.watch(graph, log=None, log_graph=True)
        graph(data)  # Run the graph to log it.
        return self

    def log_image(self, label: str, image: AnyArray, epoch: int) -> Self:
        # The image goes through the shared `_process_image` helper, is converted
        # to NumPy, and is logged as `wandb.Image` at step `epoch`.
        image = self._process_image(image).numpy()
        self.run.log({label: self.wandb.Image(image)}, step=epoch)
        return self

    def log_text(self, label: str, text: str, epoch: int) -> Self:
        # Log the raw text, plus its HTML rendering when enabled, at step `epoch`.
        self.run.log({label: text} | self._maybe_as_html(label, text), step=epoch)
        return self

__init__

__init__(
    project: str, *, text_also_as_html: bool = True, **kwargs: Any
) -> None

Create the WandbLogger with the given project name.

Additional keyword arguments are passed to wandb.init().

Parameters:

  • project (str) –

    The name of the Wandb project.

  • text_also_as_html (bool, default: True ) –

    Whether to log text messages also as HTML. That has the advantage of interactive visualization of the value at different epochs and preserving whitespace formatting.

  • kwargs (Any, default: {} ) –

    Additional keyword arguments passed to wandb.init().

Source code in minnt/loggers/wandb_logger.py (lines 21–36):
def __init__(self, project: str, *, text_also_as_html: bool = True, **kwargs: Any) -> None:
    """Create the WandbLogger with the given project name.

    Additional keyword arguments are passed to `wandb.init()`.

    Parameters:
      project: The name of the Wandb project.
      text_also_as_html: Whether to log text messages also as HTML.
        That has the advantage of interactive visualization of the value
        at different epochs and preserving whitespace formatting.
      kwargs: Additional keyword arguments passed to `wandb.init()`.
    """
    # Imported here rather than at module level, so `wandb` is only needed
    # once a WandbLogger is actually constructed.
    import wandb
    self.wandb = wandb
    self.run = self.wandb.init(project=project, **kwargs)
    self._text_also_as_html = text_also_as_html

log_audio

log_audio(label: str, audio: AnyArray, sample_rate: int, epoch: int) -> Self

Log the given audio with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged audio.

  • audio (AnyArray) –

    The audio to log, represented as an array with any of the following shapes:

    • (L,) or (L, 1) for mono audio,
    • (L, 2) for stereo audio.

    If the sample values are floating-point numbers, they are expected to be in the [-1, 1] range; otherwise, they are assumed to be in the [-32_768, 32_767] range.

  • sample_rate (int) –

    The sample rate of the audio.

  • epoch (int) –

    The epoch number at which the audio is logged.

Source code in minnt/loggers/wandb_logger.py (lines 51–54):
def log_audio(self, label: str, audio: AnyArray, sample_rate: int, epoch: int) -> Self:
    """Log `audio` under `label` as a `wandb.Audio` clip at step `epoch`.

    The input is passed through `self._process_audio`, and the result is
    converted with `.numpy()` before being wrapped for Wandb.
    """
    samples = self._process_audio(audio).numpy()
    clip = self.wandb.Audio(samples, sample_rate=sample_rate)
    self.run.log({label: clip}, step=epoch)
    return self

log_config

log_config(config: dict[str, Any], epoch: int) -> Self

Log the given configuration dictionary at the given epoch.

Parameters:

  • config (dict[str, Any]) –

    A JSON-serializable dictionary representing the configuration to log.

  • epoch (int) –

    The epoch number at which the configuration is logged.

Source code in minnt/loggers/wandb_logger.py (lines 56–61):
def log_config(self, config: dict[str, Any], epoch: int) -> Self:
    config = dict(sorted(config.items()))
    self.run.config.update(config)
    config = json.dumps(config, ensure_ascii=False, indent=2)
    self.run.log({"config": config} | self._maybe_as_html("config", config), step=epoch)
    return self

log_epoch

log_epoch(
    logs: dict[str, float],
    epoch: int,
    epochs: int | None = None,
    elapsed: float | None = None,
) -> Self

Log metrics collected during a given epoch.

Parameters:

  • logs (dict[str, float]) –

    A dictionary of logged metrics for the epoch.

  • epoch (int) –

    The epoch number at which the logs were collected.

  • epochs (int | None, default: None ) –

    The total number of epochs, if known.

  • elapsed (float | None, default: None ) –

    The time elapsed during the epoch, in seconds, if known.

Source code in minnt/loggers/wandb_logger.py (lines 63–67):
def log_epoch(
    self, logs: dict[str, float], epoch: int, epochs: int | None = None, elapsed: float | None = None,
) -> Self:
    self.run.log(logs, step=epoch)
    return self

log_figure

log_figure(
    label: str,
    figure: Any,
    epoch: int,
    tight_layout: bool = True,
    close: bool = True,
) -> Self

Log the given matplotlib Figure with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged image.

  • figure (Any) –

    A matplotlib Figure.

  • epoch (int) –

    The epoch number at which the image is logged.

  • tight_layout (bool, default: True ) –

    Whether to apply tight layout to the figure before logging it.

  • close (bool, default: True ) –

    Whether to close the figure after logging it.

Source code in minnt/loggers/wandb_logger.py (lines 69–70):
def log_figure(self, label: str, figure: Any, epoch: int, tight_layout: bool = True, close: bool = True) -> Self:
    """Log the given matplotlib Figure with the given label at the given epoch.

    This override delegates unchanged to the base class implementation.
    """
    return super().log_figure(label, figure, epoch, tight_layout, close)

log_graph

log_graph(graph: Module, data: TensorOrTensors, epoch: int) -> Self

Log the given computation graph by tracing it with the given data.

Alternatively, loggers may choose to log the graph using TorchScript or other mechanisms.

Parameters:

  • graph (Module) –

    The computation graph to log, represented as a PyTorch module.

  • data (TensorOrTensors) –

    The input data to use for tracing the computation graph.

  • epoch (int) –

    The epoch number at which the computation graph is logged.

Source code in minnt/loggers/wandb_logger.py (lines 72–75):
def log_graph(self, graph: torch.nn.Module, data: TensorOrTensors, epoch: int) -> Self:
    """Log the computation graph of `graph` by watching it and running it on `data`.

    Only the graph is requested from `watch` (`log=None`); `epoch` is not used.
    """
    run = self.run
    run.watch(graph, log=None, log_graph=True)
    # A forward pass on the given data is what gets the graph logged.
    graph(data)
    return self

log_image

log_image(label: str, image: AnyArray, epoch: int) -> Self

Log the given image with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged image.

  • image (AnyArray) –

    The image to log, represented as an array, which can have any of the following shapes:

    • (H, W) or (H, W, 1) for grayscale images,
    • (H, W, 2) for grayscale images with alpha channel,
    • (H, W, 3) for RGB images,
    • (H, W, 4) for RGBA images.

    If the pixel values are floating-point numbers, they are expected to be in the [0, 1] range; otherwise, they are assumed to be in the [0, 255] range.

  • epoch (int) –

    The epoch number at which the image is logged.

Source code in minnt/loggers/wandb_logger.py (lines 77–80):
def log_image(self, label: str, image: AnyArray, epoch: int) -> Self:
    """Log `image` under `label` as a `wandb.Image` at step `epoch`.

    The input is passed through `self._process_image`, and the result is
    converted with `.numpy()` before being wrapped for Wandb.
    """
    pixels = self._process_image(image).numpy()
    picture = self.wandb.Image(pixels)
    self.run.log({label: picture}, step=epoch)
    return self

log_text

log_text(label: str, text: str, epoch: int) -> Self

Log the given text with the given label at the given epoch.

Parameters:

  • label (str) –

    The label of the logged text.

  • text (str) –

    The text to log.

  • epoch (int) –

    The epoch number at which the text is logged.

Source code in minnt/loggers/wandb_logger.py (lines 82–84):
def log_text(self, label: str, text: str, epoch: int) -> Self:
    self.run.log({label: text} | self._maybe_as_html(label, text), step=epoch)
    return self