chore: remove vision model URL workarounds and simplify client creation

The vision models are now available at the standard URL, so the workaround
code has been removed. This also simplifies the codebase by eliminating
the need for per-model client caching.

- Remove special URL handling for meta/llama-3.2-11b/90b-vision-instruct models
- Convert _get_client method to _client property for cleaner API
- Remove unnecessary lru_cache decorator and functools import
- Simplify client creation logic to use single base URL for all models
Author: Matthew Farrellee
Date:   2025-07-16 05:21:25 -04:00
Parent: 95fdc8ea94
Commit: 8cc3fe7669

2 changed files with 15 additions and 35 deletions
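Context on the `# noqa: B019` suppressions that disappear with this change: flake8-bugbear's B019 flags `functools.lru_cache` on methods because the cache lives at class level and holds a strong reference to `self`, so decorated instances are never garbage collected. A minimal, hypothetical sketch of that pitfall (the class and method names are illustrative, not from the diff):

    import gc
    import weakref
    from functools import lru_cache

    class Leaky:
        @lru_cache  # B019: cache keys include `self`, pinning the instance
        def compute(self, x: int) -> int:
            return x * 2

    obj = Leaky()
    obj.compute(1)
    ref = weakref.ref(obj)
    del obj
    gc.collect()
    print(ref() is not None)  # True: the class-level cache still holds the instance

Switching to a plain `@property` avoids the suppression entirely, trading the cached client for a freshly constructed `AsyncOpenAI` on each access.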


@@ -7,7 +7,6 @@
 import logging
 import warnings
 from collections.abc import AsyncIterator
-from functools import lru_cache
 from typing import Any
 
 from openai import APIConnectionError, AsyncOpenAI, BadRequestError
@@ -93,42 +92,22 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         self._config = config
 
-    @lru_cache  # noqa: B019
-    def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
+    @property
+    def _client(self) -> AsyncOpenAI:
         """
-        For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
-        some models are hosted on different URLs. This function returns the appropriate client
-        for the given provider_model_id.
+        Returns an OpenAI client for the configured NVIDIA API endpoint.
 
-        This relies on lru_cache and self._default_client to avoid creating a new client for each request
-        or for each model that is hosted on https://integrate.api.nvidia.com/v1.
-
-        :param provider_model_id: The provider model ID
         :return: An OpenAI client
         """
-
-        @lru_cache  # noqa: B019
-        def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
-            """
-            Maintain a single OpenAI client per base_url.
-            """
-            return AsyncOpenAI(
-                base_url=base_url,
-                api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
-                timeout=self._config.timeout,
-            )
-
-        special_model_urls = {
-            "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
-            "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
-        }
-
-        base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
-
-        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
-            base_url = special_model_urls[provider_model_id]
-
-        return _get_client_for_base_url(base_url)
+        base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
+
+        return AsyncOpenAI(
+            base_url=base_url,
+            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
+            timeout=self._config.timeout,
+        )
 
     async def _get_provider_model_id(self, model_id: str) -> str:
         if not self.model_store:
             raise RuntimeError("Model store is not set")
@@ -169,7 +148,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            response = await self._get_client(provider_model_id).completions.create(**request)
+            response = await self._client.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -222,7 +201,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             extra_body["input_type"] = task_type_options[task_type]
 
         try:
-            response = await self._get_client(provider_model_id).embeddings.create(
+            response = await self._client.embeddings.create(
                 model=provider_model_id,
                 input=input,
                 extra_body=extra_body,
@@ -283,7 +262,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            response = await self._get_client(provider_model_id).chat.completions.create(**request)
+            response = await self._client.chat.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -339,7 +318,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            return await self._get_client(provider_model_id).completions.create(**params)
+            return await self._client.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -398,7 +377,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            return await self._get_client(provider_model_id).chat.completions.create(**params)
+            return await self._client.chat.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
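After these hunks, every call site resolves the same client. A minimal sketch of the resulting client construction, using the hosted URL mentioned in the removed docstring (the key and timeout values below are placeholders, not the adapter's real config):

    from openai import AsyncOpenAI

    # All models, including meta/llama-3.2-11b/90b-vision-instruct, are now
    # served from the one standard base URL rather than per-model /gr/ URLs.
    client = AsyncOpenAI(
        base_url="https://integrate.api.nvidia.com/v1",
        api_key="NO KEY",  # the adapter's placeholder when no key is configured
        timeout=60,        # hypothetical; the adapter reads this from its config
    )

    # e.g., inside an async function:
    # response = await client.chat.completions.create(
    #     model="meta/llama-3.2-11b-vision-instruct",
    #     messages=[{"role": "user", "content": "Describe this image."}],
    # )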


@@ -54,7 +54,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
         self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
         self.inference_mock_make_request = self.mock_client.chat.completions.create
         self.inference_make_request_patcher = patch(
-            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
+            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client",
+            new_callable=unittest.mock.PropertyMock,
             return_value=self.mock_client,
         )
         self.inference_make_request_patcher.start()
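Because `_client` is now a property rather than a method, the test patches it with `new_callable=unittest.mock.PropertyMock`; patching it as a plain attribute would replace the property object itself instead of controlling what an attribute access returns. A standalone sketch of the pattern (the `Adapter` class here is hypothetical):

    import unittest.mock as mock

    class Adapter:
        @property
        def _client(self):
            return "real client"

    with mock.patch.object(Adapter, "_client", new_callable=mock.PropertyMock) as prop:
        prop.return_value = "mock client"
        # attribute access now routes through the PropertyMock
        assert Adapter()._client == "mock client"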