diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index bcd5ea91a9..e45b162f1f 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -68,7 +68,7 @@ class TestTimeAugmentation: Args: transform: transform (or composed) to be applied to each realization. At least one transform must be of type `RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`). - . All random transforms must be of type `InvertibleTransform`. + All random transforms must be of type `InvertibleTransform`. batch_size: number of realizations to infer at once. num_workers: how many subprocesses to use for data. inferrer_fn: function to use to perform inference. @@ -92,6 +92,11 @@ class TestTimeAugmentation: will return the full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. progress: whether to display a progress bar. + apply_inverse_to_pred: whether to apply inverse transformations to the predictions. + If the model's prediction is spatial (e.g. segmentation), this should be `True` to map the predictions + back to the original spatial reference. + If the prediction is non-spatial (e.g. classification label or score), this should be `False` to + aggregate the raw predictions directly. Defaults to `True`. Example: .. code-block:: python @@ -125,6 +130,7 @@ def __init__( post_func: Callable = _identity, return_full_data: bool = False, progress: bool = True, + apply_inverse_to_pred: bool = True, ) -> None: self.transform = transform self.batch_size = batch_size @@ -134,6 +140,7 @@ def __init__( self.image_key = image_key self.return_full_data = return_full_data self.progress = progress + self.apply_inverse_to_pred = apply_inverse_to_pred self._pred_key = CommonKeys.PRED self.inverter = Invertd( keys=self._pred_key, @@ -199,7 +206,10 @@ def __call__( for b in tqdm(dl) if has_tqdm and self.progress else dl: # do model forward pass b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device)) - outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)]) + if self.apply_inverse_to_pred: + outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)]) + else: + outs.extend([i[self._pred_key] for i in decollate_batch(b)]) output: NdarrayOrTensor = stack(outs, 0) diff --git a/tests/integration/test_testtimeaugmentation.py b/tests/integration/test_testtimeaugmentation.py index 62e4b46282..84da7c9c15 100644 --- a/tests/integration/test_testtimeaugmentation.py +++ b/tests/integration/test_testtimeaugmentation.py @@ -104,7 +104,7 @@ def test_test_time_augmentation(self): # output might be different size, so pad so that they match train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) - model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) + model = UNet(2, 1, 1, channels=(6, 6), strides=(2,)).to(device) loss_function = DiceLoss(sigmoid=True) optimizer = torch.optim.Adam(model.parameters(), 1e-3) @@ -181,6 +181,43 @@ def test_image_no_label(self): tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image") tta(self.get_data(1, (20, 20), include_label=False)) + def test_non_spatial_output(self): + """ + Test TTA for non-spatial output (e.g., classification scores). + Verifies that setting `apply_inverse_to_pred=False` correctly aggregates + predictions without attempting spatial inversion. + """ + input_size = (20, 20) + data = {"image": np.random.rand(1, *input_size).astype(np.float32)} + + transforms = Compose( + [EnsureChannelFirstd("image", channel_dim="no_channel"), RandFlipd("image", prob=1.0, spatial_axis=0)] + ) + + def mock_classifier(x): + batch_size = x.shape[0] + return torch.tensor([[0.2, 0.8]] * batch_size, dtype=torch.float32, device=x.device) + + tt_aug = TestTimeAugmentation( + transform=transforms, + batch_size=2, + num_workers=0, + inferrer_fn=mock_classifier, + device="cpu", + orig_key="image", + apply_inverse_to_pred=False, + return_full_data=False, + ) + mode, mean, std, vvc = tt_aug(data, num_examples=4) + + self.assertEqual(mean.shape, (2,)) + np.testing.assert_allclose(mean, [0.2, 0.8], atol=1e-6) + np.testing.assert_allclose(std, [0.0, 0.0], atol=1e-6) + + tt_aug.return_full_data = True + full_output = tt_aug(data, num_examples=4) + self.assertEqual(full_output.shape, (4, 2)) + if __name__ == "__main__": unittest.main()