Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 65 additions & 25 deletions monai/apps/nnunet/nnunetv2_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,35 +525,70 @@ def train_single_model(self, config: Any, fold: int, gpu_id: tuple | list | int
kwargs.pop("npz")
logger.warning("please specify the `export_validation_probabilities` in the __init__ of `nnUNetV2Runner`.")

cmd = self.train_single_model_command(config, fold, gpu_id, kwargs)
run_cmd(cmd, shell=True)
cmd, env = self.train_single_model_command(config, fold, gpu_id, kwargs)
run_cmd(cmd, env=env)

def train_single_model_command(self, config, fold, gpu_id, kwargs):
if isinstance(gpu_id, (tuple, list)):
def train_single_model_command(
self, config: str, fold: int, gpu_id: int | str | tuple | list, kwargs: dict[str, Any]
) -> tuple[list[str], dict[str, str]]:
"""
Build the shell command string for training a single nnU-Net model.

Args:
config: Configuration name (e.g., "3d_fullres").
fold: Cross-validation fold index (0-4).
gpu_id: Device selector—int, str (MIG UUID), or tuple/list for multi-GPU.
kwargs: Additional CLI arguments forwarded to nnUNetv2_train.

Returns:
Shell command string.

Raises:
ValueError: If gpu_id is an empty tuple or list.
"""
env = os.environ.copy()
device_setting: str | None = None
num_gpus = 1
if isinstance(gpu_id, str):
device_setting = gpu_id
num_gpus = 1
elif isinstance(gpu_id, (tuple, list)):
if len(gpu_id) == 0:
raise ValueError("gpu_id tuple/list cannot be empty")
if len(gpu_id) > 1:
gpu_ids_str = ""
for _i in range(len(gpu_id)):
gpu_ids_str += f"{gpu_id[_i]},"
device_setting = f"CUDA_VISIBLE_DEVICES={gpu_ids_str[:-1]}"
else:
device_setting = f"CUDA_VISIBLE_DEVICES={gpu_id[0]}"
device_setting = ",".join(str(x) for x in gpu_id)
num_gpus = len(gpu_id)
elif len(gpu_id) == 1:
device_setting = str(gpu_id[0])
num_gpus = 1
else:
device_setting = f"CUDA_VISIBLE_DEVICES={gpu_id}"
num_gpus = 1 if isinstance(gpu_id, int) or len(gpu_id) == 1 else len(gpu_id)

cmd = (
f"{device_setting} nnUNetv2_train "
+ f"{self.dataset_name_or_id} {config} {fold} "
+ f"-tr {self.trainer_class_name} -num_gpus {num_gpus}"
)
device_setting = str(gpu_id)
num_gpus = 1
env_cuda = env.get("CUDA_VISIBLE_DEVICES")
if env_cuda is not None and device_setting == "0":
logger.info(f"Using existing environment variable CUDA_VISIBLE_DEVICES='{env_cuda}'")
device_setting = None
elif device_setting is not None:
env["CUDA_VISIBLE_DEVICES"] = device_setting

cmd = [
"nnUNetv2_train",
f"{self.dataset_name_or_id}",
f"{config}",
f"{fold}",
"-tr",
f"{self.trainer_class_name}",
"-num_gpus",
f"{num_gpus}",
]
if self.export_validation_probabilities:
cmd += " --npz"
cmd.append("--npz")
for _key, _value in kwargs.items():
if _key == "p" or _key == "pretrained_weights":
cmd += f" -{_key} {_value}"
cmd.extend([f"-{_key}", f"{_value}"])
else:
cmd += f" --{_key} {_value}"
return cmd
cmd.extend([f"--{_key}", f"{_value}"])
return cmd, env

def train(
self,
Expand Down Expand Up @@ -779,7 +814,7 @@ def predict(
part_id: int = 0,
num_processes_preprocessing: int = -1,
num_processes_segmentation_export: int = -1,
gpu_id: int = 0,
gpu_id: int | str = 0,
) -> None:
"""
Use this to run inference with nnU-Net. This function is used when you want to manually specify a folder containing
Expand Down Expand Up @@ -813,9 +848,14 @@ def predict(
num_processes_preprocessing: out-of-RAM issues.
num_processes_segmentation_export: Number of processes used for segmentation export.
More is not always better. Beware of out-of-RAM issues.
gpu_id: which GPU to use for prediction.
gpu_id: GPU device index (int) or MIG UUID (str) for prediction.
If CUDA_VISIBLE_DEVICES is already set and gpu_id is 0, the existing
environment variable is preserved.
"""
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
if "CUDA_VISIBLE_DEVICES" in os.environ and (gpu_id == 0 or gpu_id == "0"):
logger.info(f"Predict: Using existing CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
else:
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

Expand Down
Loading