diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index 2b6777a75f..1799fec542 100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -201,6 +201,45 @@ (1, 1, 16, 16, 16), (1, 3, 4, 4, 4), ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "label_nc": 5, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + "label_nc": 5, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], ] LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [ [ @@ -661,7 +700,7 @@ def test_normal_cdf(self): x = torch.linspace(-10, 10, 20) cdf_approx = inferer._approx_standard_normal_cdf(x) cdf_true = norm.cdf(x) - torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + torch.testing.assert_close(cdf_approx, torch.as_tensor(cdf_true, dtype=cdf_approx.dtype), atol=1e-3, rtol=1e-5) @parameterized.expand(CNDM_TEST_CASES) @skipUnless(has_einops, "Requires einops") @@ -742,6 +781,8 @@ def test_prediction_shape( stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) if dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) else: @@ -764,7 +805,7 @@ def test_prediction_shape( inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - if dm_model_type == "SPADEDiffusionModelUNet": + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] @@ -807,14 +848,16 @@ def test_pred_shape( ): stage_1 = None - if ae_model_type == "AutoencoderKL": - stage_1 = AutoencoderKL(**autoencoder_params) - if ae_model_type == "VQVAE": - stage_1 = VQVAE(**autoencoder_params) if dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) else: stage_2 = DiffusionModelUNet(**stage_2_params) + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) controlnet = ControlNet(**controlnet_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -905,7 +948,7 @@ def test_sample_intermediates( else: input_shape_seg[1] = autoencoder_params["label_nc"] input_seg = torch.randn(input_shape_seg).to(device) - sample = inferer.sample( + sample, intermediates = inferer.sample( input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, @@ -913,11 +956,9 @@ def test_sample_intermediates( seg=input_seg, controlnet=controlnet, cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, ) - - # TODO: this isn't correct, should the above produce intermediates as well? - # This test has always passed so is this branch not being used? - intermediates = None else: sample, intermediates = inferer.sample( input_noise=noise, @@ -973,7 +1014,7 @@ def test_get_likelihoods( inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) - if dm_model_type == "SPADEDiffusionModelUNet": + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] @@ -1043,7 +1084,7 @@ def test_resample_likelihoods( inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) - if dm_model_type == "SPADEDiffusionModelUNet": + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] @@ -1127,7 +1168,7 @@ def test_prediction_shape_conditioned_concat( timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - if dm_model_type == "SPADEDiffusionModelUNet": + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] @@ -1209,7 +1250,7 @@ def test_sample_shape_conditioned_concat( inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) - if dm_model_type == "SPADEDiffusionModelUNet": + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"] @@ -1290,7 +1331,7 @@ def test_shape_different_latents( timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - if dm_model_type == "SPADEDiffusionModelUNet": + if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL": input_shape_seg = list(input_shape) if "label_nc" in stage_2_params.keys(): input_shape_seg[1] = stage_2_params["label_nc"]