Skip to content

ExactMatch

minnt.metrics.ExactMatch

Bases: Mean

Exact match metric implementation.

The elements to compare can be either tensors or generic iterables. When tensors are used, the element_dims parameter can be specified to indicate which dimensions of the tensors form an element for comparison; when iterables are used, the input one-dimensional sequences are compared directly.

Source code in minnt/metrics/exact_match.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class ExactMatch(Mean):
    """Exact match metric implementation.

    The elements to compare can be either tensors or generic iterables. When tensors are used,
    the `element_dims` parameter can be specified to indicate which dimensions of the tensors
    form an element for comparison; when iterables are used, the input one-dimensional sequences are
    compared directly.
    """

    def __init__(self, element_dims: int | tuple[int, ...] = ()) -> None:
        """Create the ExactMatch metric object.

        Parameters:
          element_dims: If the values to compare are tensors, this parameter
            can be used to specify which dimensions of the tensors form
            an element for comparison. Either a single dimension index or a
            tuple (or list) of dimension indices; the default `()` compares
            every tensor entry separately.

        Raises:
          TypeError: If `element_dims` is neither an int nor a tuple/list.
        """
        super().__init__()
        if isinstance(element_dims, int):
            self._element_dims = (element_dims,)
        elif isinstance(element_dims, (tuple, list)):
            # Normalize lists to tuples, which is what `torch.all(dim=...)` expects.
            self._element_dims = tuple(element_dims)
        else:
            raise TypeError("The element_dims argument must be an int or a tuple of ints.")

    def update(
        self,
        y: torch.Tensor | Iterable[Any],
        y_true: torch.Tensor | Iterable[Any],
        sample_weights: torch.Tensor | None = None,
    ) -> None:
        """Update the exact match by comparing the given values.

        The inputs can be either both tensors or both iterables. When they are both tensors,
        the `element_dims` parameter can be used to specify which dimensions of the tensors
        form an element for comparison; when they are both iterables, the elements are
        compared directly using `==`. If only one of the inputs is a tensor, the iterable
        code path is used.

        In both cases, the per-element match results are passed to `Mean.update` as a
        float32 tensor of 0/1 values.

        Optional sample weights may be provided; if not, all values are weighted with 1.

        Parameters:
          y: A tensor or an iterable of predicted values of the same shape as `y_true`.
          y_true: A tensor or an iterable of ground-truth targets of the same shape as `y`.
          sample_weights: Optional sample weights. Their shape must be broadcastable to a
            prefix of the shape of `y` (with `element_dims` dimensions removed, if specified).

        Raises:
          AssertionError: If tensor inputs differ in shape, or if a nonempty `element_dims`
            is used with non-tensor inputs.
          ValueError: If iterable inputs have different lengths (raised by `zip(strict=True)`).
        """
        if isinstance(y, torch.Tensor) and isinstance(y_true, torch.Tensor):
            assert y.shape == y_true.shape, "The y and y_true tensors must have the same shape."

            equals = y == y_true
            if self._element_dims:
                # An element matches only when all its entries along `element_dims` match.
                equals = torch.all(equals, dim=self._element_dims)
            # The comparison yields a bool tensor; convert it to float32 so that both code
            # paths hand `Mean.update` the same dtype (the iterable path below uses float32).
            equals = equals.to(torch.float32)
        else:
            assert not self._element_dims, "Nonempty element_dims can only be used with tensor inputs."
            # `strict=True` rejects sequences of different lengths instead of truncating.
            equals = torch.tensor([pred == true for pred, true in zip(y, y_true, strict=True)], dtype=torch.float32)

        super().update(equals, sample_weights=sample_weights)

__init__

__init__(element_dims: int | tuple[int] = ()) -> None

Create the ExactMatch metric object.

Parameters:

  • element_dims (int | tuple[int], default: () ) –

    If the values to compare are tensors, this parameter can be used to specify which dimensions of the tensors form an element for comparison.

Source code in minnt/metrics/exact_match.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def __init__(self, element_dims: int | tuple[int, ...] = ()) -> None:
    """Create the ExactMatch metric object.

    Parameters:
      element_dims: If the values to compare are tensors, this parameter
        can be used to specify which dimensions of the tensors form
        an element for comparison. Either a single dimension index or a
        tuple (or list) of dimension indices; the default `()` compares
        every tensor entry separately.

    Raises:
      TypeError: If `element_dims` is neither an int nor a tuple/list.
    """
    super().__init__()
    if isinstance(element_dims, int):
        self._element_dims = (element_dims,)
    elif isinstance(element_dims, (tuple, list)):
        # Normalize lists to tuples, which is what `torch.all(dim=...)` expects.
        self._element_dims = tuple(element_dims)
    else:
        raise TypeError("The element_dims argument must be an int or a tuple of ints.")

update

update(
    y: Tensor | Iterable[Any],
    y_true: Tensor | Iterable[Any],
    sample_weights: Tensor | None = None,
) -> None

Update the exact match by comparing the given values.

The inputs can be either both tensors or both iterables. When they are both tensors, the element_dims parameter can be used to specify which dimensions of the tensors form an element for comparison; when they are both iterables, the elements are compared directly.

Optional sample weights may be provided; if not, all values are weighted with 1.

Parameters:

  • y (Tensor | Iterable[Any]) –

    A tensor or an iterable of predicted values of the same shape as y_true.

  • y_true (Tensor | Iterable[Any]) –

    A tensor or an iterable of ground-truth targets of the same shape as y.

  • sample_weights (Tensor | None, default: None ) –

    Optional sample weights. Their shape must be broadcastable to a prefix of the shape of y (with element_dims dimensions removed, if specified).

Source code in minnt/metrics/exact_match.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def update(
    self,
    y: torch.Tensor | Iterable[Any],
    y_true: torch.Tensor | Iterable[Any],
    sample_weights: torch.Tensor | None = None,
) -> None:
    """Update the exact match by comparing the given values.

    The inputs can be either both tensors or both iterables. When they are both tensors,
    the `element_dims` parameter can be used to specify which dimensions of the tensors
    form an element for comparison; when they are both iterables, the elements are
    compared directly using `==`. If only one of the inputs is a tensor, the iterable
    code path is used.

    In both cases, the per-element match results are passed to `Mean.update` as a
    float32 tensor of 0/1 values.

    Optional sample weights may be provided; if not, all values are weighted with 1.

    Parameters:
      y: A tensor or an iterable of predicted values of the same shape as `y_true`.
      y_true: A tensor or an iterable of ground-truth targets of the same shape as `y`.
      sample_weights: Optional sample weights. Their shape must be broadcastable to a
        prefix of the shape of `y` (with `element_dims` dimensions removed, if specified).

    Raises:
      AssertionError: If tensor inputs differ in shape, or if a nonempty `element_dims`
        is used with non-tensor inputs.
      ValueError: If iterable inputs have different lengths (raised by `zip(strict=True)`).
    """
    if isinstance(y, torch.Tensor) and isinstance(y_true, torch.Tensor):
        assert y.shape == y_true.shape, "The y and y_true tensors must have the same shape."

        equals = y == y_true
        if self._element_dims:
            # An element matches only when all its entries along `element_dims` match.
            equals = torch.all(equals, dim=self._element_dims)
        # The comparison yields a bool tensor; convert it to float32 so that both code
        # paths hand `Mean.update` the same dtype (the iterable path below uses float32).
        equals = equals.to(torch.float32)
    else:
        assert not self._element_dims, "Nonempty element_dims can only be used with tensor inputs."
        # `strict=True` rejects sequences of different lengths instead of truncating.
        equals = torch.tensor([pred == true for pred, true in zip(y, y_true, strict=True)], dtype=torch.float32)

    super().update(equals, sample_weights=sample_weights)