From a3e249807bfb6de7638ca7c3cc59a6f1780a49e3 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 16 Jul 2025 10:10:04 -0400 Subject: [PATCH] chore: remove vision model URL workarounds and simplify client creation (#2775) The vision models are now available at the standard URL, so the workaround code has been removed. This also simplifies the codebase by eliminating the need for per-model client caching. - Remove special URL handling for meta/llama-3.2-11b/90b-vision-instruct models - Convert _get_client method to _client property for cleaner API - Remove unnecessary lru_cache decorator and functools import - Simplify client creation logic to use single base URL for all models --- .../remote/inference/nvidia/nvidia.py | 47 +++++-------------- .../nvidia/test_supervised_fine_tuning.py | 3 +- 2 files changed, 15 insertions(+), 35 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 1dd72da3f..f790c2312 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -7,7 +7,6 @@ import logging import warnings from collections.abc import AsyncIterator -from functools import lru_cache from typing import Any from openai import APIConnectionError, AsyncOpenAI, BadRequestError @@ -93,41 +92,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): self._config = config - @lru_cache # noqa: B019 - def _get_client(self, provider_model_id: str) -> AsyncOpenAI: + @property + def _client(self) -> AsyncOpenAI: """ - For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However, - some models are hosted on different URLs. This function returns the appropriate client - for the given provider_model_id. + Returns an OpenAI client for the configured NVIDIA API endpoint. - This relies on lru_cache and self._default_client to avoid creating a new client for each request - or for each model that is hosted on https://integrate.api.nvidia.com/v1. - - :param provider_model_id: The provider model ID :return: An OpenAI client """ - @lru_cache # noqa: B019 - def _get_client_for_base_url(base_url: str) -> AsyncOpenAI: - """ - Maintain a single OpenAI client per base_url. - """ - return AsyncOpenAI( - base_url=base_url, - api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), - timeout=self._config.timeout, - ) - - special_model_urls = { - "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct", - "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct", - } - base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url - if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls: - base_url = special_model_urls[provider_model_id] - return _get_client_for_base_url(base_url) + return AsyncOpenAI( + base_url=base_url, + api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), + timeout=self._config.timeout, + ) async def _get_provider_model_id(self, model_id: str) -> str: if not self.model_store: @@ -169,7 +148,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._get_client(provider_model_id).completions.create(**request) + response = await self._client.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -222,7 +201,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): extra_body["input_type"] = task_type_options[task_type] try: - response = await self._get_client(provider_model_id).embeddings.create( + response = await self._client.embeddings.create( model=provider_model_id, input=input, extra_body=extra_body, @@ -283,7 +262,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._get_client(provider_model_id).chat.completions.create(**request) + response = await self._client.chat.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -339,7 +318,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - return await self._get_client(provider_model_id).completions.create(**params) + return await self._client.completions.create(**params) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -398,7 +377,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - return await self._get_client(provider_model_id).chat.completions.create(**params) + return await self._client.chat.completions.create(**params) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index 97ca02fba..f75b0add9 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -54,7 +54,8 @@ class TestNvidiaPostTraining(unittest.TestCase): self.mock_client.chat.completions.create = unittest.mock.AsyncMock() self.inference_mock_make_request = self.mock_client.chat.completions.create self.inference_make_request_patcher = patch( - "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client", + "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client", + new_callable=unittest.mock.PropertyMock, return_value=self.mock_client, ) self.inference_make_request_patcher.start()