Skip to content

Mean

minnt.metrics.Mean

Bases: Metric

A class tracking the (optionally weighted) mean of given values.

Source code in minnt/metrics/mean.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Mean(Metric):
    """A class tracking the (optionally weighted) mean of given values."""

    def __init__(self) -> None:
        """Create the Mean metric object."""
        super().__init__()
        # Running weighted sum of the values and running sum of their weights, both
        # float32 scalars; registered as non-persistent, so not saved in the state dict.
        self.register_buffer("_total", torch.tensor(0.0, dtype=torch.float32), persistent=False)
        self.register_buffer("_count", torch.tensor(0.0, dtype=torch.float32), persistent=False)

    def update(
        self, y: torch.Tensor, y_true: torch.Tensor | None = None, sample_weights: torch.Tensor | None = None,
    ) -> None:
        """Update the accumulated mean by introducing new values.

        Optional sample weights may be provided; if not, all values are weighted with 1.
        The inputs are detached before accumulation, so the metric state never becomes
        part of an autograd graph.

        Parameters:
          y: The values to average.
          y_true: This parameter is present for [minnt.Metric][] compatibility, but must be `None` for this metric.
          sample_weights: Optional sample weights. Their shape must be broadcastable to a prefix of the shape of `y`.
        """
        assert y_true is None, "The y_true parameter must be None for the Mean metric."

        # Accumulating a tensor that requires grad (e.g., a loss) in place into the buffers
        # would attach them to the autograd graph, chaining the graphs of successive
        # updates together and keeping them alive; detaching prevents that.
        y = y.detach()
        if sample_weights is not None:
            sample_weights = broadcast_to_prefix(sample_weights, y.shape).detach()
            self._total.add_(torch.sum(y * sample_weights))
            self._count.add_(torch.sum(sample_weights))
        else:
            self._total.add_(torch.sum(y))
            self._count.add_(y.numel())

    def compute(self) -> torch.Tensor:
        """Return the accumulated mean, i.e., the (weighted) sum divided by the total weight.

        Directly after construction or `reset`, both accumulators are zero and the result
        is NaN (0/0).
        """
        return self._total / self._count

    def reset(self):
        """Zero both accumulators in place, discarding all previously seen values."""
        self._total.zero_()
        self._count.zero_()

__init__

__init__() -> None

Create the Mean metric object.

Source code in minnt/metrics/mean.py
15
16
17
18
19
def __init__(self) -> None:
    """Create the Mean metric object with both accumulators set to zero."""
    super().__init__()
    # `_total` holds the running (weighted) sum of values and `_count` the running sum
    # of weights; both are float32 scalars registered as non-persistent buffers.
    for buffer_name in ("_total", "_count"):
        self.register_buffer(buffer_name, torch.tensor(0.0, dtype=torch.float32), persistent=False)

update

update(
    y: Tensor,
    y_true: Tensor | None = None,
    sample_weights: Tensor | None = None,
) -> None

Update the accumulated mean by introducing new values.

Optional sample weights may be provided; if not, all values are weighted with 1.

Parameters:

  • y (Tensor) –

    The values to average.

  • y_true (Tensor | None, default: None ) –

    This parameter is present for minnt.Metric compatibility, but must be None for this metric.

  • sample_weights (Tensor | None, default: None ) –

    Optional sample weights. Their shape must be broadcastable to a prefix of the shape of y.

Source code in minnt/metrics/mean.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def update(
    self, y: torch.Tensor, y_true: torch.Tensor | None = None, sample_weights: torch.Tensor | None = None,
) -> None:
    """Update the accumulated mean by introducing new values.

    Optional sample weight might be provided; if not, all values are weighted with 1.

    Parameters:
      y: The values to average.
      y_true: This parameter is present for [minnt.Metric][] compatibility, but must be `None` for this metric.
      sample_weights: Optional sample weights. Their shape must be broadcastable to a prefix of the shape of `y`.
    """
    assert y_true is None, "The y_true parameter must be None for the Mean metric."

    if sample_weights is not None:
        sample_weights = broadcast_to_prefix(sample_weights, y.shape)

    self._total.add_(torch.sum(y * sample_weights) if sample_weights is not None else torch.sum(y))
    self._count.add_(torch.sum(sample_weights) if sample_weights is not None else y.numel())

compute

compute() -> Tensor

Compute the accumulated metric value.

Returns:

  • Tensor

    A (usually scalar) tensor representing the accumulated metric value.

Source code in minnt/metrics/mean.py
41
42
def compute(self) -> torch.Tensor:
    """Return the accumulated mean: the (weighted) sum divided by the total weight.

    If nothing has been accumulated since creation or the last reset, both
    accumulators are zero and the result is NaN (0/0).
    """
    numerator, denominator = self._total, self._count
    return numerator / denominator

reset

reset()

Reset the internal state of the metric.

Source code in minnt/metrics/mean.py
44
45
46
def reset(self):
    """Zero both accumulators in place, discarding all previously seen values."""
    for accumulator in (self._total, self._count):
        accumulator.zero_()