diff --git a/monai/apps/nnunet/nnunetv2_runner.py b/monai/apps/nnunet/nnunetv2_runner.py index 8a10849904..ac75d96d94 100644 --- a/monai/apps/nnunet/nnunetv2_runner.py +++ b/monai/apps/nnunet/nnunetv2_runner.py @@ -525,35 +525,70 @@ def train_single_model(self, config: Any, fold: int, gpu_id: tuple | list | int kwargs.pop("npz") logger.warning("please specify the `export_validation_probabilities` in the __init__ of `nnUNetV2Runner`.") - cmd = self.train_single_model_command(config, fold, gpu_id, kwargs) - run_cmd(cmd, shell=True) + cmd, env = self.train_single_model_command(config, fold, gpu_id, kwargs) + run_cmd(cmd, env=env) - def train_single_model_command(self, config, fold, gpu_id, kwargs): - if isinstance(gpu_id, (tuple, list)): + def train_single_model_command( + self, config: str, fold: int, gpu_id: int | str | tuple | list, kwargs: dict[str, Any] + ) -> tuple[list[str], dict[str, str]]: + """ + Build the command argument list and environment for training a single nnU-Net model. + + Args: + config: Configuration name (e.g., "3d_fullres"). + fold: Cross-validation fold index (0-4). + gpu_id: Device selector—int, str (MIG UUID), or tuple/list for multi-GPU. + kwargs: Additional CLI arguments forwarded to nnUNetv2_train. + + Returns: + Tuple of (command argument list, environment mapping to run it with). + + Raises: + ValueError: If gpu_id is an empty tuple or list. 
+ """ + env = os.environ.copy() + device_setting: str | None = None + num_gpus = 1 + if isinstance(gpu_id, str): + device_setting = gpu_id + num_gpus = 1 + elif isinstance(gpu_id, (tuple, list)): + if len(gpu_id) == 0: + raise ValueError("gpu_id tuple/list cannot be empty") if len(gpu_id) > 1: - gpu_ids_str = "" - for _i in range(len(gpu_id)): - gpu_ids_str += f"{gpu_id[_i]}," - device_setting = f"CUDA_VISIBLE_DEVICES={gpu_ids_str[:-1]}" - else: - device_setting = f"CUDA_VISIBLE_DEVICES={gpu_id[0]}" + device_setting = ",".join(str(x) for x in gpu_id) + num_gpus = len(gpu_id) + elif len(gpu_id) == 1: + device_setting = str(gpu_id[0]) + num_gpus = 1 else: - device_setting = f"CUDA_VISIBLE_DEVICES={gpu_id}" - num_gpus = 1 if isinstance(gpu_id, int) or len(gpu_id) == 1 else len(gpu_id) - - cmd = ( - f"{device_setting} nnUNetv2_train " - + f"{self.dataset_name_or_id} {config} {fold} " - + f"-tr {self.trainer_class_name} -num_gpus {num_gpus}" - ) + device_setting = str(gpu_id) + num_gpus = 1 + env_cuda = env.get("CUDA_VISIBLE_DEVICES") + if env_cuda is not None and device_setting == "0": + logger.info(f"Using existing environment variable CUDA_VISIBLE_DEVICES='{env_cuda}'") + device_setting = None + elif device_setting is not None: + env["CUDA_VISIBLE_DEVICES"] = device_setting + + cmd = [ + "nnUNetv2_train", + f"{self.dataset_name_or_id}", + f"{config}", + f"{fold}", + "-tr", + f"{self.trainer_class_name}", + "-num_gpus", + f"{num_gpus}", + ] if self.export_validation_probabilities: - cmd += " --npz" + cmd.append("--npz") for _key, _value in kwargs.items(): if _key == "p" or _key == "pretrained_weights": - cmd += f" -{_key} {_value}" + cmd.extend([f"-{_key}", f"{_value}"]) else: - cmd += f" --{_key} {_value}" - return cmd + cmd.extend([f"--{_key}", f"{_value}"]) + return cmd, env def train( self, @@ -779,7 +814,7 @@ def predict( part_id: int = 0, num_processes_preprocessing: int = -1, num_processes_segmentation_export: int = -1, - gpu_id: int = 0, + gpu_id: int | str = 
0, ) -> None: """ Use this to run inference with nnU-Net. This function is used when you want to manually specify a folder containing @@ -813,9 +848,14 @@ def predict( num_processes_preprocessing: out-of-RAM issues. num_processes_segmentation_export: Number of processes used for segmentation export. More is not always better. Beware of out-of-RAM issues. - gpu_id: which GPU to use for prediction. + gpu_id: GPU device index (int) or MIG UUID (str) for prediction. + If CUDA_VISIBLE_DEVICES is already set and gpu_id is 0, the existing + environment variable is preserved. """ - os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + if "CUDA_VISIBLE_DEVICES" in os.environ and (gpu_id == 0 or gpu_id == "0"): + logger.info(f"Predict: Using existing CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}") + else: + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor