diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py
index 936e89713..c74fb2a24 100644
--- a/llama_stack/providers/remote/post_training/nvidia/post_training.py
+++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py
@@ -392,19 +392,10 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
         # Handle LoRA-specific configuration
         if algorithm_config:
-            algorithm_config_dict = (
-                algorithm_config.model_dump() if hasattr(algorithm_config, "model_dump") else algorithm_config
-            )
-            if isinstance(algorithm_config_dict, dict) and algorithm_config_dict.get("type") == "LoRA":
-                warn_unsupported_params(algorithm_config_dict, supported_params["lora_config"], "LoRA config")
+            if algorithm_config.type == "LoRA":
+                warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
                 job_config["hyperparameters"]["lora"] = {
-                    k: v
-                    for k, v in {
-                        "adapter_dim": algorithm_config_dict.get("adapter_dim"),
-                        "alpha": algorithm_config_dict.get("alpha"),
-                        "adapter_dropout": algorithm_config_dict.get("adapter_dropout"),
-                    }.items()
-                    if v is not None
+                    k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
                 }
             else:
                 raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
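
A minimal sketch of the behavior after this change: the adapter now reads `type` and `alpha` straight off the typed algorithm config instead of round-tripping through `model_dump()`, and only non-`None` values are forwarded into the `lora` hyperparameter block. The `_LoraConfig` stand-in and `build_lora_hyperparameters` helper below are illustrative only, not names from the PR or the library.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _LoraConfig:
    """Stand-in for the typed LoRA algorithm config; illustrative only."""
    type: str = "LoRA"
    alpha: Optional[int] = 16


def build_lora_hyperparameters(algorithm_config) -> dict:
    """Mirror the post-change mapping: typed attribute access, drop unset fields."""
    if algorithm_config.type != "LoRA":
        raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
    return {k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None}


# build_lora_hyperparameters(_LoraConfig(alpha=16))  -> {"alpha": 16}
# build_lora_hyperparameters(_LoraConfig(alpha=None)) -> {}
```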