diff --git a/llama_stack/providers/remote/post_training/nvidia/__init__.py b/llama_stack/providers/remote/post_training/nvidia/__init__.py
index 964e1fdaa..9210090e7 100644
--- a/llama_stack/providers/remote/post_training/nvidia/__init__.py
+++ b/llama_stack/providers/remote/post_training/nvidia/__init__.py
@@ -10,14 +10,18 @@
 from llama_stack.distribution.datatypes import Api, ProviderSpec
 
 from .config import NvidiaPostTrainingConfig
 
-# post_training api and the torchtune provider is still experimental and under heavy development
-
 async def get_adapter_impl(
     config: NvidiaPostTrainingConfig,
     deps: Dict[Api, ProviderSpec],
 ):
-    from .post_training import NvidiaPostTrainingImpl
+    from .post_training import NvidiaPostTrainingAdapter
 
-    impl = NvidiaPostTrainingImpl(config)
+    if not isinstance(config, NvidiaPostTrainingConfig):
+        raise RuntimeError(f"Unexpected config type: {type(config)}")
+
+    impl = NvidiaPostTrainingAdapter(config)
     return impl
+
+
+__all__ = ["get_adapter_impl", "NvidiaPostTrainingAdapter"]
diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py
index 15089c3b1..973bf3201 100644
--- a/llama_stack/providers/remote/post_training/nvidia/post_training.py
+++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py
@@ -50,7 +50,7 @@ class ListNvidiaPostTrainingJobs(BaseModel):
     data: List[NvidiaPostTrainingJob]
 
 
-class NvidiaPostTrainingImpl:
+class NvidiaPostTrainingAdapter:
     def __init__(self, config: NvidiaPostTrainingConfig):
         self.config = config
         self.headers = {}
@@ -226,12 +226,8 @@ class NvidiaPostTrainingImpl:
             # Extract LoRA-specific parameters
             lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
             job_config["hyperparameters"]["lora"] = lora_config
-
-            # Add adapter_dim if available in training_config
-            if training_config.get("algorithm_config", {}).get("adapter_dim"):
-                job_config["hyperparameters"]["lora"]["adapter_dim"] = training_config["algorithm_config"][
-                    "adapter_dim"
-                ]
+        else:
+            raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
 
         # Create the customization job
         response = await self._make_request(