diff --git a/monai/metrics/active_learning_metrics.py b/monai/metrics/active_learning_metrics.py index 7a1654191e..4756e421d5 100644 --- a/monai/metrics/active_learning_metrics.py +++ b/monai/metrics/active_learning_metrics.py @@ -129,9 +129,7 @@ def compute_variance( y_pred = y_pred.float() if not include_background: - y = y_pred - # TODO If this utils is made to be optional for 'y' it would be nice - y_pred, y = ignore_background(y_pred=y_pred, y=y) + y_pred = ignore_background(y_pred=y_pred) # Set any values below 0 to threshold y_pred[y_pred <= 0] = threshold diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index a451b1a770..24817800a2 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -15,7 +15,7 @@ from collections.abc import Iterable, Sequence from functools import cache, partial from types import ModuleType -from typing import Any +from typing import Any, overload import numpy as np import torch @@ -51,21 +51,44 @@ ] -def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: +@overload +def ignore_background(y_pred: NdarrayTensor, y: None = ...) -> NdarrayTensor: ... + + +@overload +def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: ... + + +def ignore_background( + y_pred: NdarrayTensor, y: NdarrayTensor | None = None +) -> NdarrayTensor | tuple[NdarrayTensor, NdarrayTensor]: """ This function is used to remove background (the first channel) for `y_pred` and `y`. Args: y_pred: predictions. As for classification tasks, - `y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks, + `y_pred` should have the shape [BN] where N is larger than 1. As for segmentation tasks, the shape should be [BNHW] or [BNHWD]. - y: ground truth, the first dim is batch. + y: ground truth, the first dim is batch. (Optional) + Returns: + NdarrayTensor | tuple[NdarrayTensor, NdarrayTensor]: + - If `y` is None: returns background-removed `y_pred` only. + - If `y` is provided: returns a tuple of (background-removed `y_pred`, background-removed `y`). """ - y = y[:, 1:] if y.shape[1] > 1 else y # type: ignore[assignment] - y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred # type: ignore[assignment] - return y_pred, y + y_pred_out = y_pred + if y_pred.shape[1] > 1: + y_pred_out = y_pred[:, 1:] # type: ignore + + if y is None: + return y_pred_out + + y_out = y + if y.shape[1] > 1: + y_out = y[:, 1:] # type: ignore + + return y_pred_out, y_out def do_metric_reduction(