monai/data/test_time_augmentation.py (14 changes: 12 additions & 2 deletions)
@@ -68,7 +68,7 @@ class TestTimeAugmentation:
Args:
transform: transform (or composed) to be applied to each realization. At least one transform must be of type
`RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).
- . All random transforms must be of type `InvertibleTransform`.
+ All random transforms must be of type `InvertibleTransform`.
batch_size: number of realizations to infer at once.
num_workers: how many subprocesses to use for data.
inferrer_fn: function to use to perform inference.
@@ -92,6 +92,11 @@ class TestTimeAugmentation:
will return the full data. Dimensions will be same size as when passing a single image through
`inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
progress: whether to display a progress bar.
apply_inverse_to_pred: whether to apply inverse transformations to the predictions.
If the model's prediction is spatial (e.g. segmentation), this should be `True` to map the predictions
back to the original spatial reference.
If the prediction is non-spatial (e.g. classification label or score), this should be `False` to
aggregate the raw predictions directly. Defaults to `True`.

Example:
.. code-block:: python
@@ -125,6 +130,7 @@ def __init__(
post_func: Callable = _identity,
return_full_data: bool = False,
progress: bool = True,
apply_inverse_to_pred: bool = True,
) -> None:
self.transform = transform
self.batch_size = batch_size
@@ -134,6 +140,7 @@ def __init__(
self.image_key = image_key
self.return_full_data = return_full_data
self.progress = progress
self.apply_inverse_to_pred = apply_inverse_to_pred
self._pred_key = CommonKeys.PRED
self.inverter = Invertd(
keys=self._pred_key,
@@ -199,7 +206,10 @@ def __call__(
for b in tqdm(dl) if has_tqdm and self.progress else dl:
# do model forward pass
b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
- outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
+ if self.apply_inverse_to_pred:
+     outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
+ else:
+     outs.extend([i[self._pred_key] for i in decollate_batch(b)])

output: NdarrayOrTensor = stack(outs, 0)

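For orientation, a minimal usage sketch of the new flag (illustrative only, not part of this diff; it mirrors the new test below, and the mock classifier, data shape, and transform choices are assumptions rather than prescribed API usage):

    import numpy as np
    import torch
    from monai.data import TestTimeAugmentation
    from monai.transforms import Compose, EnsureChannelFirstd, RandFlipd

    # a non-spatial inferrer: one [0.2, 0.8] score pair per batch item
    def classifier(x):
        return torch.tensor([[0.2, 0.8]] * x.shape[0], dtype=torch.float32, device=x.device)

    transforms = Compose(
        [EnsureChannelFirstd("image", channel_dim="no_channel"), RandFlipd("image", prob=1.0, spatial_axis=0)]
    )
    tta = TestTimeAugmentation(
        transform=transforms,
        batch_size=2,
        num_workers=0,
        inferrer_fn=classifier,
        orig_key="image",
        apply_inverse_to_pred=False,  # keep raw predictions; skip spatial inversion
    )
    mode, mean, std, vvc = tta({"image": np.random.rand(1, 20, 20).astype(np.float32)}, num_examples=4)
    # with the default apply_inverse_to_pred=True, each prediction would instead be
    # passed through PadListDataCollate.inverse and Invertd before aggregation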
tests/integration/test_testtimeaugmentation.py (39 changes: 38 additions & 1 deletion)
@@ -104,7 +104,7 @@ def test_test_time_augmentation(self):
# output might be different size, so pad so that they match
train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

- model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
+ model = UNet(2, 1, 1, channels=(6, 6), strides=(2,)).to(device)
loss_function = DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

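Side note on the one-line UNet fix above (my reading, not stated in the diff): MONAI's UNet validates that len(strides) == len(channels) - 1, so channels=(6, 6) takes a single stride. A sketch of the equivalent keyword form, assuming a recent MONAI where the first positional argument is spatial_dims:

    from monai.networks.nets import UNet

    # channels has length 2, so exactly one stride is expected
    model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(6, 6), strides=(2,))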
@@ -181,6 +181,43 @@ def test_image_no_label(self):
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image")
tta(self.get_data(1, (20, 20), include_label=False))

def test_non_spatial_output(self):
"""
Test TTA for non-spatial output (e.g., classification scores).
Verifies that setting `apply_inverse_to_pred=False` correctly aggregates
predictions without attempting spatial inversion.
"""
input_size = (20, 20)
data = {"image": np.random.rand(1, *input_size).astype(np.float32)}

transforms = Compose(
[EnsureChannelFirstd("image", channel_dim="no_channel"), RandFlipd("image", prob=1.0, spatial_axis=0)]
)

def mock_classifier(x):
batch_size = x.shape[0]
return torch.tensor([[0.2, 0.8]] * batch_size, dtype=torch.float32, device=x.device)

tt_aug = TestTimeAugmentation(
transform=transforms,
batch_size=2,
num_workers=0,
inferrer_fn=mock_classifier,
device="cpu",
orig_key="image",
apply_inverse_to_pred=False,
return_full_data=False,
)
mode, mean, std, vvc = tt_aug(data, num_examples=4)

self.assertEqual(mean.shape, (2,))
np.testing.assert_allclose(mean, [0.2, 0.8], atol=1e-6)
np.testing.assert_allclose(std, [0.0, 0.0], atol=1e-6)

tt_aug.return_full_data = True
full_output = tt_aug(data, num_examples=4)
self.assertEqual(full_output.shape, (4, 2))


if __name__ == "__main__":
unittest.main()
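A quick sanity check of the new assertions (a sketch under the assumption that the four realizations are stacked along a new leading axis, as the (4, 2) full_output shape indicates):

    import torch

    # four identical [0.2, 0.8] score pairs stacked along N -> shape (4, 2)
    stacked = torch.tensor([[0.2, 0.8]] * 4)
    print(stacked.mean(0))  # tensor([0.2000, 0.8000]) -> matches mean.shape == (2,)
    print(stacked.std(0))   # tensor([0., 0.])         -> zero spread across realizations
    # vvc is a scalar summary of variation across realizations; the test unpacks it
    # but does not assert its value, so it is not recomputed here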