From 0fc51ea6b2e3b0475ec55acad8a94a29a3006bf4 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Thu, 15 Jan 2026 13:04:54 +0000 Subject: [PATCH 1/2] Fix #8239: Enhance SoftclDiceLoss and SoftDiceclDiceLoss with additional parameters - Add include_background, to_onehot_y, sigmoid, softmax, other_act, and reduction parameters - Fix argument order in forward() to match other losses (y_pred, y_true) - Add proper input validation and comprehensive docstrings - These changes make the losses consistent with DiceLoss API and fix zero loss issues Signed-off-by: Soumya Snigdha Kundu --- monai/losses/cldice.py | 268 +++++++++++++++++++++++++------ tests/losses/test_cldice_loss.py | 131 ++++++++++++--- 2 files changed, 325 insertions(+), 74 deletions(-) diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py index 406cc3825f..06e9634817 100644 --- a/monai/losses/cldice.py +++ b/monai/losses/cldice.py @@ -11,10 +11,17 @@ from __future__ import annotations +import warnings +from collections.abc import Callable + import torch import torch.nn.functional as F from torch.nn.modules.loss import _Loss +from monai.losses.dice import DiceLoss +from monai.networks import one_hot +from monai.utils import LossReduction + def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore """ @@ -92,26 +99,6 @@ def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor: return skel -def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor: - """ - Function to compute soft dice loss - - Adapted from: - https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22 - - Args: - y_true: the shape should be BCH(WD) - y_pred: the shape should be BCH(WD) - - Returns: - dice loss - """ - intersection = torch.sum((y_true * y_pred)[:, 1:, ...]) - coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth) - soft_dice: torch.Tensor = 1.0 - coeff - return soft_dice - - class SoftclDiceLoss(_Loss): """ Compute the Soft clDice loss defined in: @@ -121,64 +108,241 @@ class SoftclDiceLoss(_Loss): Adapted from: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7 + + The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). + Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input, + must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target` + can be 1 or N (one-hot format). + """ - def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None: + def __init__( + self, + iter_: int = 3, + smooth: float = 1.0, + include_background: bool = True, + to_onehot_y: bool = False, + sigmoid: bool = False, + softmax: bool = False, + other_act: Callable | None = None, + reduction: LossReduction | str = LossReduction.MEAN, + ) -> None: """ Args: - iter_: Number of iterations for skeletonization - smooth: Smoothing parameter + iter_: Number of iterations for skeletonization. + smooth: Smoothing parameter. + include_background: if False, channel index 0 (background category) is excluded from the calculation. + if the non-background segmentations are small compared to the total image size they can get overwhelmed + by the signal from the background so excluding it in such cases helps convergence. + to_onehot_y: whether to convert the ``target`` into the one-hot format, + using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False. 
+ sigmoid: if True, apply a sigmoid function to the prediction. + softmax: if True, apply a softmax function to the prediction. + other_act: callable function to execute other activation layers, Defaults to ``None``. for example: + ``other_act = torch.tanh``. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + + Raises: + TypeError: When ``other_act`` is not an ``Optional[Callable]``. + ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. + Incompatible values. + """ - super().__init__() + super().__init__(reduction=LossReduction(reduction).value) + if other_act is not None and not callable(other_act): + raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") + if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: + raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") self.iter = iter_ self.smooth = smooth + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.sigmoid = sigmoid + self.softmax = softmax + self.other_act = other_act + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD], where N is the number of classes. + target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. + + Raises: + AssertionError: When input and target (after one hot transform if set) + have different shapes. + + """ + n_pred_ch = input.shape[1] + + if self.sigmoid: + input = torch.sigmoid(input) + + if self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.") + else: + input = torch.softmax(input, dim=1) - def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: - skel_pred = soft_skel(y_pred, self.iter) - skel_true = soft_skel(y_true, self.iter) - tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( - torch.sum(skel_pred[:, 1:, ...]) + self.smooth + if self.other_act is not None: + input = self.other_act(input) + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + else: + target = one_hot(target, num_classes=n_pred_ch) + + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + target = target[:, 1:] + input = input[:, 1:] + + if target.shape != input.shape: + raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + + skel_pred = soft_skel(input, self.iter) + skel_true = soft_skel(target, self.iter) + + # Compute per-batch clDice by reducing over channel and spatial dimensions + # reduce_axis includes all dimensions except batch (dim 0) + reduce_axis: list[int] = list(range(1, len(input.shape))) + + tprec = (torch.sum(torch.multiply(skel_pred, target), dim=reduce_axis) + self.smooth) / ( + torch.sum(skel_pred, dim=reduce_axis) + self.smooth ) - tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( - torch.sum(skel_true[:, 1:, ...]) + self.smooth + tsens = (torch.sum(torch.multiply(skel_true, input), dim=reduce_axis) + self.smooth) / ( + 
torch.sum(skel_true, dim=reduce_axis) + self.smooth ) cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + + # Apply reduction + if self.reduction == LossReduction.MEAN.value: + cl_dice = torch.mean(cl_dice) + elif self.reduction == LossReduction.SUM.value: + cl_dice = torch.sum(cl_dice) + elif self.reduction == LossReduction.NONE.value: + pass # keep per-batch values + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + return cl_dice class SoftDiceclDiceLoss(_Loss): """ - Compute the Soft clDice loss defined in: + Compute both Dice loss and clDice loss, and return the weighted sum of these two losses. + The details of Dice loss is shown in ``monai.losses.DiceLoss``. + The details of clDice loss is shown in ``monai.losses.SoftclDiceLoss``. + Adapted from: Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311) - Adapted from: - https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38 """ - def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None: + def __init__( + self, + iter_: int = 3, + alpha: float = 0.5, + smooth: float = 1.0, + include_background: bool = True, + to_onehot_y: bool = False, + sigmoid: bool = False, + softmax: bool = False, + other_act: Callable | None = None, + reduction: LossReduction | str = LossReduction.MEAN, + ) -> None: """ Args: - iter_: Number of iterations for skeletonization - smooth: Smoothing parameter - alpha: Weighing factor for cldice + iter_: Number of iterations for skeletonization, used by clDice. + alpha: Weighing factor for cldice component. Total loss = (1 - alpha) * dice + alpha * cldice. + Defaults to 0.5. + smooth: Smoothing parameter, used by both Dice and clDice. + include_background: if False, channel index 0 (background category) is excluded from the calculation. + if the non-background segmentations are small compared to the total image size they can get overwhelmed + by the signal from the background so excluding it in such cases helps convergence. + to_onehot_y: whether to convert the ``target`` into the one-hot format, + using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False. + sigmoid: if True, apply a sigmoid function to the prediction. + softmax: if True, apply a softmax function to the prediction. + other_act: callable function to execute other activation layers, Defaults to ``None``. for example: + ``other_act = torch.tanh``. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + + Raises: + TypeError: When ``other_act`` is not an ``Optional[Callable]``. + ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. + Incompatible values. 
+ """ super().__init__() - self.iter = iter_ - self.smooth = smooth - self.alpha = alpha - - def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: - dice = soft_dice(y_true, y_pred, self.smooth) - skel_pred = soft_skel(y_pred, self.iter) - skel_true = soft_skel(y_true, self.iter) - tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( - torch.sum(skel_pred[:, 1:, ...]) + self.smooth + self.dice = DiceLoss( + include_background=include_background, + to_onehot_y=False, + sigmoid=sigmoid, + softmax=softmax, + other_act=other_act, + reduction=reduction, + smooth_nr=smooth, + smooth_dr=smooth, ) - tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( - torch.sum(skel_true[:, 1:, ...]) + self.smooth + self.cldice = SoftclDiceLoss( + iter_=iter_, + smooth=smooth, + include_background=include_background, + to_onehot_y=False, + sigmoid=sigmoid, + softmax=softmax, + other_act=other_act, + reduction=reduction, ) - cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) - total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice + self.alpha = alpha + self.to_onehot_y = to_onehot_y + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD], where N is the number of classes. + target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. + + Raises: + ValueError: When number of dimensions for input and target are different. + ValueError: When number of channels for target is neither 1 nor the same as input. + + """ + if input.dim() != target.dim(): + raise ValueError( + "the number of dimensions for input and target should be the same, " + f"got shape {input.shape} and {target.shape}." + ) + + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + "number of channels for target is neither 1 nor the same as input, " + f"got shape {input.shape} and {target.shape}." + ) + + if self.to_onehot_y: + n_pred_ch = input.shape[1] + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + else: + target = one_hot(target, num_classes=n_pred_ch) + + dice_loss = self.dice(input, target) + cldice_loss = self.cldice(input, target) + total_loss: torch.Tensor = (1.0 - self.alpha) * dice_loss + self.alpha * cldice_loss + return total_loss diff --git a/tests/losses/test_cldice_loss.py b/tests/losses/test_cldice_loss.py index 14d3575e3b..adb23008ce 100644 --- a/tests/losses/test_cldice_loss.py +++ b/tests/losses/test_cldice_loss.py @@ -1,3 +1,6 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software @@ -16,35 +19,119 @@ from monai.losses import SoftclDiceLoss, SoftDiceclDiceLoss -TEST_CASES = [ - [{"y_pred": torch.ones((7, 3, 11, 10)), "y_true": torch.ones((7, 3, 11, 10))}, 0.0], - [{"y_pred": torch.ones((2, 3, 13, 14, 5)), "y_true": torch.ones((2, 3, 13, 14, 5))}, 0.0], +# Reusable test tensors +ONES_2D = {"input": torch.ones((2, 3, 8, 8)), "target": torch.ones((2, 3, 8, 8))} +ONES_3D = {"input": torch.ones((2, 3, 8, 8, 8)), "target": torch.ones((2, 3, 8, 8, 8))} + +# Partial overlap: two 2x2 squares shifted by 1 pixel +PARTIAL_OVERLAP = { + "input": torch.tensor( + [[[[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]] + ), + "target": torch.tensor( + [[[[0.0, 1.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]] + ), +} + +# Test cases: [loss_params, input_data, expected_value] +CLDICE_CASES = [ + [{}, ONES_2D, 0.0], + [{}, ONES_3D, 0.0], + [ + {"sigmoid": True, "smooth": 1e-5}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[0.5, 0.5], [0.5, 0.5]]]]), + "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]]), + }, + 0.192777, + ], + [ + {"softmax": True, "smooth": 1e-5}, + { + "input": torch.tensor([[[[2.0, 0.0], [0.0, 2.0]], [[-2.0, 0.0], [0.0, -2.0]]]]), + "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]]), + }, + 0.148760, + ], + [ + {"to_onehot_y": True, "smooth": 1e-5}, + { + "input": torch.tensor([[[[0.9, 0.1], [0.1, 0.9]], [[0.1, 0.9], [0.9, 0.1]]]]), + "target": torch.tensor([[[[0, 1], [1, 0]]]]), + }, + 0.052631, + ], ] +COMBINED_CASES = [ + [{"alpha": 0.5}, ONES_2D, 0.0], + [{"alpha": 0.5, "smooth": 1e-5}, PARTIAL_OVERLAP, 0.624995], + [{"alpha": 0.0, "smooth": 1e-5}, PARTIAL_OVERLAP, 0.250000], # pure Dice + [{"alpha": 1.0, "smooth": 1e-5}, PARTIAL_OVERLAP, 0.999990], # pure clDice +] -class TestclDiceLoss(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_result(self, y_pred_data, expected_val): - loss = SoftclDiceLoss() - loss_dice = SoftDiceclDiceLoss() - result = loss(**y_pred_data) - result_dice = loss_dice(**y_pred_data) +class TestSoftclDiceLoss(unittest.TestCase): + + @parameterized.expand(CLDICE_CASES) + def test_result(self, loss_params, input_data, expected_val): + loss = SoftclDiceLoss(**loss_params) + result = loss(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) - np.testing.assert_allclose(result_dice.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) - def test_with_cuda(self): + def test_cuda(self): + if not torch.cuda.is_available(): + return + loss = SoftclDiceLoss() + result = loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda()) + np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4) + + def test_reduction_shapes(self): + input_tensor = torch.ones((4, 2, 8, 8)) + target = torch.ones((4, 2, 8, 8)) + + self.assertEqual(SoftclDiceLoss(reduction="mean")(input_tensor, target).shape, torch.Size([])) + self.assertEqual(SoftclDiceLoss(reduction="sum")(input_tensor, target).shape, torch.Size([])) + self.assertEqual(SoftclDiceLoss(reduction="none")(input_tensor, target).shape, torch.Size([4])) + + def test_ill_shape(self): loss = SoftclDiceLoss() - loss_dice = SoftDiceclDiceLoss() - i = torch.ones((100, 3, 256, 256)) - j = torch.ones((100, 
3, 256, 256)) - if torch.cuda.is_available(): - i = i.cuda() - j = j.cuda() - output = loss(i, j) - output_dice = loss_dice(i, j) - np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) - np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + with self.assertRaisesRegex(AssertionError, "ground truth has different shape"): + loss(torch.ones((1, 3, 8, 8)), torch.ones((1, 4, 8, 8))) + + def test_invalid_activation_combination(self): + with self.assertRaises(ValueError): + SoftclDiceLoss(sigmoid=True, softmax=True) + + def test_invalid_other_act(self): + with self.assertRaises(TypeError): + SoftclDiceLoss(other_act="invalid") + + +class TestSoftDiceclDiceLoss(unittest.TestCase): + + @parameterized.expand(COMBINED_CASES) + def test_result(self, loss_params, input_data, expected_val): + loss = SoftDiceclDiceLoss(**loss_params) + result = loss(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + + def test_cuda(self): + if not torch.cuda.is_available(): + return + loss = SoftDiceclDiceLoss() + result = loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda()) + np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4) + + def test_dimension_mismatch(self): + loss = SoftDiceclDiceLoss() + with self.assertRaises(ValueError): + loss(torch.ones(2, 3, 8, 8), torch.ones(2, 3, 8)) + + def test_channel_mismatch(self): + loss = SoftDiceclDiceLoss() + with self.assertRaises(ValueError): + loss(torch.ones(2, 3, 8, 8), torch.ones(2, 2, 8, 8)) if __name__ == "__main__": From e8a2579bff50886be0ee398c66de3e86ca24ea16 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Thu, 15 Jan 2026 19:46:43 +0000 Subject: [PATCH 2/2] address coderabbit comment Signed-off-by: Soumya Snigdha Kundu --- monai/losses/cldice.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py index 06e9634817..dacb7cb847 100644 --- a/monai/losses/cldice.py +++ b/monai/losses/cldice.py @@ -130,7 +130,7 @@ def __init__( """ Args: iter_: Number of iterations for skeletonization. - smooth: Smoothing parameter. + smooth: Smoothing parameter to avoid division by zero. Defaults to 1.0. include_background: if False, channel index 0 (background category) is excluded from the calculation. if the non-background segmentations are small compared to the total image size they can get overwhelmed by the signal from the background so excluding it in such cases helps convergence. 
@@ -158,6 +158,8 @@ def __init__( raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") + if smooth <= 0: + raise ValueError(f"smooth must be a positive value but got {smooth}.") self.iter = iter_ self.smooth = smooth self.include_background = include_background @@ -220,7 +222,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: tsens = (torch.sum(torch.multiply(skel_true, input), dim=reduce_axis) + self.smooth) / ( torch.sum(skel_true, dim=reduce_axis) + self.smooth ) - cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-8) # Apply reduction if self.reduction == LossReduction.MEAN.value: @@ -264,7 +266,7 @@ def __init__( iter_: Number of iterations for skeletonization, used by clDice. alpha: Weighing factor for cldice component. Total loss = (1 - alpha) * dice + alpha * cldice. Defaults to 0.5. - smooth: Smoothing parameter, used by both Dice and clDice. + smooth: Smoothing parameter to avoid division by zero, used by both Dice and clDice. Defaults to 1.0. include_background: if False, channel index 0 (background category) is excluded from the calculation. if the non-background segmentations are small compared to the total image size they can get overwhelmed by the signal from the background so excluding it in such cases helps convergence. @@ -288,6 +290,8 @@ def __init__( """ super().__init__() + if smooth <= 0: + raise ValueError(f"smooth must be a positive value but got {smooth}.") self.dice = DiceLoss( include_background=include_background, to_onehot_y=False,
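
For reference, a minimal usage sketch of SoftclDiceLoss with the options added in this series (the shapes, values, and variable names below are illustrative assumptions, not taken from the patch or its test suite): raw logits and single-channel integer labels can now be passed directly, mirroring the DiceLoss API.

    import torch
    from monai.losses import SoftclDiceLoss

    # Two-class logits for a batch of one 8x8 image (BNHW layout).
    logits = torch.randn(1, 2, 8, 8)
    # Integer class labels in a single channel (B1HW); to_onehot_y=True expands them.
    labels = torch.randint(0, 2, (1, 1, 8, 8))

    # softmax=True turns the logits into probabilities, include_background=False drops
    # channel 0, and reduction="none" keeps one loss value per batch element.
    loss_fn = SoftclDiceLoss(
        iter_=3,
        smooth=1.0,
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        reduction="none",
    )
    loss = loss_fn(logits, labels)  # shape: torch.Size([1])

With reduction="mean" or reduction="sum" the result collapses to a scalar instead, matching the shapes exercised by test_reduction_shapes in the updated tests.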
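A similar sketch for the combined SoftDiceclDiceLoss, including the smooth validation added in the second commit (again, the tensors and values are made up for illustration):

    import torch
    from monai.losses import SoftDiceclDiceLoss

    # Single-channel 3D logits and a matching binary mask (BNHWD with N=1).
    pred = torch.randn(2, 1, 8, 8, 8)
    mask = (torch.rand(2, 1, 8, 8, 8) > 0.5).float()

    # alpha weights the clDice term: total = (1 - alpha) * Dice + alpha * clDice.
    loss_fn = SoftDiceclDiceLoss(iter_=3, alpha=0.3, smooth=1e-5, sigmoid=True)
    loss = loss_fn(pred, mask)  # scalar, "mean" reduction by default

    # Non-positive smoothing values are now rejected at construction time.
    try:
        SoftDiceclDiceLoss(smooth=0.0)
    except ValueError as err:
        print(err)  # smooth must be a positive value but got 0.0.

Note that the wrapper performs the one-hot conversion itself and passes to_onehot_y=False to both sub-losses, so the target is only converted once before being fed to Dice and clDice.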