diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py index d3de930f7..900c58171 100644 --- a/llama_stack/providers/remote/post_training/nvidia/post_training.py +++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py @@ -67,13 +67,18 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): self.timeout = aiohttp.ClientTimeout(total=config.timeout) # TODO: filter by available models based on /config endpoint ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES) - self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout) - self.customizer_url = config.customizer_url + self.session = None + self.customizer_url = config.customizer_url if not self.customizer_url: warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2) self.customizer_url = "http://nemo.test" + async def _get_session(self) -> aiohttp.ClientSession: + if self.session is None or self.session.closed: + self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout) + return self.session + async def _make_request( self, method: str, @@ -94,8 +99,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): if json and "Content-Type" not in request_headers: request_headers["Content-Type"] = "application/json" + session = await self._get_session() for _ in range(self.config.max_retries): - async with self.session.request(method, url, params=params, json=json, **kwargs) as response: + async with session.request(method, url, params=params, json=json, **kwargs) as response: if response.status >= 400: error_data = await response.json() raise Exception(f"API request failed: {error_data}") @@ -122,8 +128,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): jobs = [] for job in response.get("data", []): job_id = job.pop("id") - job_status = job.pop("status", "unknown").lower() - mapped_status = STATUS_MAPPING.get(job_status, "unknown") + job_status = job.pop("status", "scheduled").lower() + mapped_status = STATUS_MAPPING.get(job_status, "scheduled") # Convert string timestamps to datetime objects created_at = ( @@ -177,7 +183,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): ) api_status = response.pop("status").lower() - mapped_status = STATUS_MAPPING.get(api_status, "unknown") + mapped_status = STATUS_MAPPING.get(api_status, "scheduled") return NvidiaPostTrainingJobStatusResponse( status=JobStatus(mapped_status), @@ -297,7 +303,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): User is informed about unsupported parameters via warnings. """ # Map model to nvidia model name - # ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models + # See `_MODEL_ENTRIES` for supported models nvidia_model = self.get_provider_model_id(model) # Check for unsupported method parameters @@ -389,14 +395,17 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): # Handle LoRA-specific configuration if algorithm_config: - if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA": - warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config") + algorithm_config_dict = ( + algorithm_config.model_dump() if hasattr(algorithm_config, "model_dump") else algorithm_config + ) + if isinstance(algorithm_config_dict, dict) and algorithm_config_dict.get("type") == "LoRA": + warn_unsupported_params(algorithm_config_dict, supported_params["lora_config"], "LoRA config") job_config["hyperparameters"]["lora"] = { k: v for k, v in { - "adapter_dim": algorithm_config.get("adapter_dim"), - "alpha": algorithm_config.get("alpha"), - "adapter_dropout": algorithm_config.get("adapter_dropout"), + "adapter_dim": algorithm_config_dict.get("adapter_dim"), + "alpha": algorithm_config_dict.get("alpha"), + "adapter_dropout": algorithm_config_dict.get("adapter_dropout"), }.items() if v is not None }