From 8cc3fe76693e0c60b8780c69a562b17831bf50a3 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 16 Jul 2025 05:21:25 -0400 Subject: [PATCH] chore: remove vision model URL workarounds and simplify client creation The vision models are now available at the standard URL, so the workaround code has been removed. This also simplifies the codebase by eliminating the need for per-model client caching. - Remove special URL handling for meta/llama-3.2-11b/90b-vision-instruct models - Convert _get_client method to _client property for cleaner API - Remove unnecessary lru_cache decorator and functools import - Simplify client creation logic to use single base URL for all models --- .../remote/inference/nvidia/nvidia.py | 47 +++++-------------- .../nvidia/test_supervised_fine_tuning.py | 3 +- 2 files changed, 15 insertions(+), 35 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 1dd72da3f..f790c2312 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -7,7 +7,6 @@ import logging import warnings from collections.abc import AsyncIterator -from functools import lru_cache from typing import Any from openai import APIConnectionError, AsyncOpenAI, BadRequestError @@ -93,41 +92,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): self._config = config - @lru_cache # noqa: B019 - def _get_client(self, provider_model_id: str) -> AsyncOpenAI: + @property + def _client(self) -> AsyncOpenAI: """ - For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However, - some models are hosted on different URLs. This function returns the appropriate client - for the given provider_model_id. + Returns an OpenAI client for the configured NVIDIA API endpoint. - This relies on lru_cache and self._default_client to avoid creating a new client for each request - or for each model that is hosted on https://integrate.api.nvidia.com/v1. - - :param provider_model_id: The provider model ID :return: An OpenAI client """ - @lru_cache # noqa: B019 - def _get_client_for_base_url(base_url: str) -> AsyncOpenAI: - """ - Maintain a single OpenAI client per base_url. - """ - return AsyncOpenAI( - base_url=base_url, - api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), - timeout=self._config.timeout, - ) - - special_model_urls = { - "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct", - "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct", - } - base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url - if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls: - base_url = special_model_urls[provider_model_id] - return _get_client_for_base_url(base_url) + return AsyncOpenAI( + base_url=base_url, + api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), + timeout=self._config.timeout, + ) async def _get_provider_model_id(self, model_id: str) -> str: if not self.model_store: @@ -169,7 +148,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._get_client(provider_model_id).completions.create(**request) + response = await self._client.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -222,7 +201,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): extra_body["input_type"] = task_type_options[task_type] try: - response = await self._get_client(provider_model_id).embeddings.create( + response = await self._client.embeddings.create( model=provider_model_id, input=input, extra_body=extra_body, @@ -283,7 +262,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._get_client(provider_model_id).chat.completions.create(**request) + response = await self._client.chat.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -339,7 +318,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - return await self._get_client(provider_model_id).completions.create(**params) + return await self._client.completions.create(**params) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -398,7 +377,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - return await self._get_client(provider_model_id).chat.completions.create(**params) + return await self._client.chat.completions.create(**params) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index 97ca02fba..f75b0add9 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -54,7 +54,8 @@ class TestNvidiaPostTraining(unittest.TestCase): self.mock_client.chat.completions.create = unittest.mock.AsyncMock() self.inference_mock_make_request = self.mock_client.chat.completions.create self.inference_make_request_patcher = patch( - "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client", + "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client", + new_callable=unittest.mock.PropertyMock, return_value=self.mock_client, ) self.inference_make_request_patcher.start()