Bases: Mean
Binary classification accuracy metric.
The predictions are assumed to be logits or probabilities predicted by a model,
while the ground-truth targets are binary (0 or 1) values. In both cases, the predicted
class is considered to be the one with the larger probability, i.e., class 1 is predicted
when the probability exceeds 0.5 (equivalently, when the logit is positive).
Source code in minnt/metrics/binary_accuracy.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class BinaryAccuracy(Mean):
    """Binary classification accuracy metric.

    The predictions are assumed to be logits or probabilities predicted by a model,
    while the ground-truth targets are binary (0 or 1) values. In both cases, the predicted
    class is considered to be the one with the larger probability, i.e., class 1 is predicted
    when the probability exceeds 0.5 (equivalently, when the logit is positive).
    """

    def __init__(self, *, probs: bool = False, device: torch.device | None = None) -> None:
        """Create the BinaryAccuracy metric object.

        Parameters:
          probs: If `False`, the predictions are assumed to be logits; if `True`, the
            predictions are assumed to be probabilities. Note that gold targets are
            always expected to be probabilities.
          device: Optional device, forwarded unchanged to the `Mean` base-class constructor.
        """
        super().__init__(device)
        self._probs = probs

    @torch.no_grad
    def update(
        self, y: torch.Tensor, y_true: torch.Tensor, sample_weights: torch.Tensor | None = None,
    ) -> Self:
        """Update the accumulated binary accuracy using new predictions and gold labels.

        Optional sample weights may be provided; if not, all values are weighted by 1.

        Parameters:
          y: The predicted outputs. Their shape either has to be exactly the same as `y_true`
            (no broadcasting), or can contain an additional single dimension of size 1.
            We consider the more probable class to be predicted.
          y_true: The ground-truth targets; values greater than 0.5 are treated as label 1
            and all other values as label 0.
          sample_weights: Optional sample weights. If provided, their shape must be broadcastable
            to a prefix of the shape of `y_true`, and the correctness of each prediction is
            weighted accordingly.

        Returns:
          self
        """
        y = maybe_remove_one_singleton_dimension(y, y_true)
        assert y.shape == y_true.shape, f"Shapes of y {y.shape} and y_true {y_true.shape} have to match " \
            "up to one singleton dim in y."
        # sigmoid(0) == 0.5, so thresholding logits at 0 and probabilities at 0.5 both select
        # the class with the larger probability; an exact tie is assigned to class 0.
        y = (y > (0.5 if self._probs else 0.0))
        y_true = (y_true > 0.5)
        # The boolean per-element correctness is accumulated by the `Mean` base class.
        return super().update(y == y_true, sample_weights=sample_weights)
|
__init__
__init__(*, probs: bool = False, device: device | None = None) -> None
Create the BinaryAccuracy metric object.
Parameters:
-
probs
(bool, default:
False
)
–
If False, the predictions are assumed to be logits; if True, the
predictions are assumed to be probabilities. Note that gold targets are
always expected to be probabilities.
Source code in minnt/metrics/binary_accuracy.py
22
23
24
25
26
27
28
29
30
def __init__(self, *, probs: bool = False, device: torch.device | None = None) -> None:
    """Create the BinaryAccuracy metric object.

    Parameters:
      probs: If `False`, the predictions are assumed to be logits; if `True`, the
        predictions are assumed to be probabilities. Note that gold targets are
        always expected to be probabilities.
      device: Optional device, forwarded unchanged to the `Mean` base-class constructor.
    """
    super().__init__(device)
    # Selects the decision threshold used in `update` (0.5 for probabilities, 0.0 for logits).
    self._probs = probs
|
update
Update the accumulated binary accuracy using new predictions and gold labels.
Optional sample weights may be provided; if not, all values are weighted by 1.
Parameters:
-
y
(Tensor)
–
The predicted outputs. Their shape either has to be exactly the same as y_true (no broadcasting),
or can contain an additional single dimension of size 1. We consider the more probable class
to be predicted.
-
y_true
(Tensor)
–
The ground-truth targets; values greater than 0.5 are treated as label 1 and all other values as label 0.
-
sample_weights
(Tensor | None, default:
None
)
–
Optional sample weights. If provided, their shape must be broadcastable
to a prefix of the shape of y_true, and the correctness of each prediction is weighted accordingly.
Returns:
Self – self, i.e., the metric object itself.
Source code in minnt/metrics/binary_accuracy.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@torch.no_grad
def update(
    self, y: torch.Tensor, y_true: torch.Tensor, sample_weights: torch.Tensor | None = None,
) -> Self:
    """Update the accumulated binary accuracy using new predictions and gold labels.

    Optional sample weights may be provided; if not, all values are weighted by 1.

    Parameters:
      y: The predicted outputs. Their shape either has to be exactly the same as `y_true`
        (no broadcasting), or can contain an additional single dimension of size 1.
        We consider the more probable class to be predicted.
      y_true: The ground-truth targets; values greater than 0.5 are treated as label 1
        and all other values as label 0.
      sample_weights: Optional sample weights. If provided, their shape must be broadcastable
        to a prefix of the shape of `y_true`, and the correctness of each prediction is
        weighted accordingly.

    Returns:
      self
    """
    y = maybe_remove_one_singleton_dimension(y, y_true)
    assert y.shape == y_true.shape, f"Shapes of y {y.shape} and y_true {y_true.shape} have to match " \
        "up to one singleton dim in y."
    # sigmoid(0) == 0.5, so thresholding logits at 0 and probabilities at 0.5 both select
    # the class with the larger probability; an exact tie is assigned to class 0.
    y = (y > (0.5 if self._probs else 0.0))
    y_true = (y_true > 0.5)
    # The boolean per-element correctness is accumulated by the `Mean` base class.
    return super().update(y == y_true, sample_weights=sample_weights)
|