From 2ec6eb2a5bb2fd8c7d935ed173432f55d9ef29c1 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 21 Jan 2026 10:09:38 +0800 Subject: [PATCH 1/6] Fix dead code and logic error 1. added missing SPADE test cases to LATENT_CNDM_TEST_CASES 2. fixed the SPADE branch in test_sample_intermediates Signed-off-by: ytl0623 --- tests/inferers/test_controlnet_inferers.py | 47 +++++++++++++++++++--- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index 2b6777a75f..092c45d22b 100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -201,6 +201,45 @@ (1, 1, 16, 16, 16), (1, 3, 4, 4, 4), ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "label_nc": 5, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + "label_nc": 5, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], ] LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [ [ @@ -905,7 +944,7 @@ def test_sample_intermediates( else: input_shape_seg[1] = autoencoder_params["label_nc"] input_seg = torch.randn(input_shape_seg).to(device) - sample = inferer.sample( + sample, intermediates = inferer.sample( input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, @@ -913,11 +952,9 @@ def test_sample_intermediates( seg=input_seg, controlnet=controlnet, cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, ) - - # TODO: this isn't correct, should the above produce intermediates as well? - # This test has always passed so is this branch not being used? 
- intermediates = None else: sample, intermediates = inferer.sample( input_noise=noise, From 778c953f5d7a01ae49db49f684537676b5150c5b Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 21 Jan 2026 11:19:38 +0800 Subject: [PATCH 2/6] fix missing logic in inferer tests Signed-off-by: ytl0623 --- tests/inferers/test_controlnet_inferers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index 092c45d22b..4d79ce485f 100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -781,6 +781,8 @@ def test_prediction_shape( stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) if dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) else: @@ -803,7 +805,7 @@ def test_prediction_shape( inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - if dm_model_type == "SPADEDiffusionModelUNet": + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] @@ -852,6 +854,8 @@ def test_pred_shape( stage_1 = VQVAE(**autoencoder_params) if dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) else: stage_2 = DiffusionModelUNet(**stage_2_params) controlnet = ControlNet(**controlnet_params) From e52329c4a7ea59e5e474d5737d2b9cd32c844d6f Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 21 Jan 2026 11:29:00 +0800 Subject: [PATCH 3/6] fix stage_2 selection Signed-off-by: ytl0623 --- tests/inferers/test_controlnet_inferers.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index 4d79ce485f..1a6d6f84ac 100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -848,16 +848,12 @@ def test_pred_shape( ): stage_1 = None - if ae_model_type == "AutoencoderKL": - stage_1 = AutoencoderKL(**autoencoder_params) - if ae_model_type == "VQVAE": - stage_1 = VQVAE(**autoencoder_params) if dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) - if ae_model_type == "SPADEAutoencoderKL": - stage_1 = SPADEAutoencoderKL(**autoencoder_params) else: stage_2 = DiffusionModelUNet(**stage_2_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) controlnet = ControlNet(**controlnet_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" From d40604be50a248a9b7eb97532dbe4cfdcd1dadaf Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 21 Jan 2026 11:34:45 +0800 Subject: [PATCH 4/6] missing autoencoder instantiation for non-SPADE cases Signed-off-by: ytl0623 --- tests/inferers/test_controlnet_inferers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index 1a6d6f84ac..ed6a2557b8 
100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -852,6 +852,10 @@ def test_pred_shape( stage_2 = SPADEDiffusionModelUNet(**stage_2_params) else: stage_2 = DiffusionModelUNet(**stage_2_params) + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) if ae_model_type == "SPADEAutoencoderKL": stage_1 = SPADEAutoencoderKL(**autoencoder_params) controlnet = ControlNet(**controlnet_params) From a66124dc5a7125812e61c4ff6e47b51b354d5114 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 21 Jan 2026 11:43:44 +0800 Subject: [PATCH 5/6] inconsistent SPADE handling Signed-off-by: ytl0623 --- tests/inferers/test_controlnet_inferers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index ed6a2557b8..74ab45df68 100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -1014,7 +1014,7 @@ def test_get_likelihoods( inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) - if dm_model_type == "SPADEDiffusionModelUNet": + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] @@ -1084,7 +1084,7 @@ def test_resample_likelihoods( inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) - if dm_model_type == "SPADEDiffusionModelUNet": + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] @@ -1168,7 +1168,7 @@ def test_prediction_shape_conditioned_concat( timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - if dm_model_type == "SPADEDiffusionModelUNet": + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] @@ -1250,7 +1250,7 @@ def test_sample_shape_conditioned_concat( inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) - if dm_model_type == "SPADEDiffusionModelUNet": + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] @@ -1331,7 +1331,7 @@ def test_shape_different_latents( timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - if dm_model_type == "SPADEDiffusionModelUNet": + if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] From 69c9dfbb3025612387ebd5e793c541908c4ee20b Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 21 Jan 2026 13:24:59 +0800 Subject: [PATCH 6/6] replace assert_allclose with assert_close Since PyTorch 1.10, assert_allclose is deprecated. 
This change migrates to assert_close and explicitly converts the NumPy
expected value to a tensor with a matching dtype, since assert_close applies
stricter input type and dtype checks.

Signed-off-by: ytl0623
---
 tests/inferers/test_controlnet_inferers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py
index 74ab45df68..1799fec542 100644
--- a/tests/inferers/test_controlnet_inferers.py
+++ b/tests/inferers/test_controlnet_inferers.py
@@ -700,7 +700,7 @@ def test_normal_cdf(self):
         x = torch.linspace(-10, 10, 20)
         cdf_approx = inferer._approx_standard_normal_cdf(x)
         cdf_true = norm.cdf(x)
-        torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)
+        torch.testing.assert_close(cdf_approx, torch.as_tensor(cdf_true, dtype=cdf_approx.dtype), atol=1e-3, rtol=1e-5)
 
     @parameterized.expand(CNDM_TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
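
A note on PATCH 1/6: when save_intermediates=True is passed,
ControlNetLatentDiffusionInferer.sample returns a (sample, intermediates)
tuple rather than a single tensor, which is why the fixed test now unpacks two
values. The toy sketch below illustrates only that calling convention;
toy_sample and its 0.9 decay factor are hypothetical stand-ins, not MONAI's
scheduler-driven denoising loop.

    import torch

    def toy_sample(noise: torch.Tensor, num_steps: int = 10,
                   save_intermediates: bool = False, intermediate_steps: int = 1):
        # Iteratively "denoise" the input; the 0.9 decay is a placeholder for a
        # scheduler step driven by a diffusion/ControlNet model.
        sample = noise
        intermediates = []
        for t in range(num_steps, 0, -1):
            sample = sample * 0.9
            # Record the running sample every `intermediate_steps` iterations.
            if save_intermediates and t % intermediate_steps == 0:
                intermediates.append(sample)
        # Mirror the tuple-vs-tensor return convention the test relies on.
        return (sample, intermediates) if save_intermediates else sample

    sample, intermediates = toy_sample(
        torch.randn(1, 1, 8, 8), save_intermediates=True, intermediate_steps=1
    )
    assert len(intermediates) == 10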
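
Similarly, a minimal standalone sketch of the assert_allclose -> assert_close
migration from PATCH 6/6. The exact torch.erf-based CDF below is an assumed
stand-in for the inferer's private _approx_standard_normal_cdf helper, so the
snippet runs without MONAI; the conversion and tolerances mirror the patched
test.

    import torch
    from scipy.stats import norm

    x = torch.linspace(-10, 10, 20)
    cdf_true = norm.cdf(x)  # numpy.ndarray, float64
    cdf_approx = 0.5 * (1.0 + torch.erf(x / 2.0 ** 0.5))  # torch.Tensor, float32

    # assert_close enforces matching dtypes by default (unlike the deprecated
    # assert_allclose), so the NumPy expected value is converted explicitly.
    torch.testing.assert_close(
        cdf_approx,
        torch.as_tensor(cdf_true, dtype=cdf_approx.dtype),
        atol=1e-3,
        rtol=1e-5,
    )

Passing check_dtype=False to assert_close would also work, but converting the
expected value keeps the dtype handling explicit at the call site.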