feat: support nvidia hosted vision models (llama 3.2 11b/90b) (#1278)

# What does this PR do?

Add support for the NVIDIA-hosted Llama 3.2 11B/90B vision models. They are not served from the common https://integrate.api.nvidia.com/v1 endpoint; each model is hosted at its own dedicated URL.
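
For context, the change boils down to picking a per-model `base_url` before constructing the OpenAI client. Below is a minimal, self-contained sketch of that routing idea; the `client_for` helper and the `NVIDIA_API_KEY` environment variable are illustrative assumptions, not the adapter's actual API.

```python
# Sketch of the per-model base_url routing this PR introduces (not the adapter code itself).
# Assumes OpenAI-compatible endpoints and an API key provided via NVIDIA_API_KEY.
import os

from openai import AsyncOpenAI

DEFAULT_BASE_URL = "https://integrate.api.nvidia.com/v1"

# These two vision models are served from their own per-model URLs.
SPECIAL_MODEL_URLS = {
    "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
    "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}


def client_for(model_id: str) -> AsyncOpenAI:
    """Return a client pointed at the right base_url for the given model."""
    base_url = SPECIAL_MODEL_URLS.get(model_id, DEFAULT_BASE_URL)
    return AsyncOpenAI(base_url=base_url, api_key=os.environ.get("NVIDIA_API_KEY", "NO KEY"))
```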

## Test Plan

```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -s -v \
  tests/client-sdk/inference/test_vision_inference.py \
  --inference-model=meta/llama-3.2-11b-vision-instruct -k image
```
Author: Matthew Farrellee
Date: 2025-03-18 13:54:10 -05:00
Commit: 706b4ca651 (parent f4dc290705)


@@ -6,6 +6,7 @@
 import logging
 import warnings
+from functools import lru_cache
 from typing import AsyncIterator, List, Optional, Union
 from openai import APIConnectionError, AsyncOpenAI, BadRequestError
@@ -82,12 +83,42 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # )
         self._config = config
-        # make sure the client lives longer than any async calls
-        self._client = AsyncOpenAI(
-            base_url=f"{self._config.url}/v1",
-            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
-            timeout=self._config.timeout,
-        )
+
+    @lru_cache  # noqa: B019
+    def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
+        """
+        For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
+        some models are hosted on different URLs. This function returns the appropriate client
+        for the given provider_model_id.
+
+        This relies on lru_cache to avoid creating a new client for each request or for each
+        model that is hosted on https://integrate.api.nvidia.com/v1.
+
+        :param provider_model_id: The provider model ID
+        :return: An OpenAI client
+        """
+
+        @lru_cache  # noqa: B019
+        def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
+            """
+            Maintain a single OpenAI client per base_url.
+            """
+            return AsyncOpenAI(
+                base_url=base_url,
+                api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
+                timeout=self._config.timeout,
+            )
+
+        special_model_urls = {
+            "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
+            "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
+        }
+
+        base_url = f"{self._config.url}/v1"
+        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
+            base_url = special_model_urls[provider_model_id]
+
+        return _get_client_for_base_url(base_url)

     async def completion(
         self,
@@ -105,9 +136,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         await check_health(self._config)  # this raises errors
+        provider_model_id = self.get_provider_model_id(model_id)
         request = convert_completion_request(
             request=CompletionRequest(
-                model=self.get_provider_model_id(model_id),
+                model=provider_model_id,
                 content=content,
                 sampling_params=sampling_params,
                 response_format=response_format,
@@ -118,7 +150,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
         try:
-            response = await self._client.completions.create(**request)
+            response = await self._get_client(provider_model_id).completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -206,6 +238,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         await check_health(self._config)  # this raises errors
+        provider_model_id = self.get_provider_model_id(model_id)
         request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
                 model=self.get_provider_model_id(model_id),
@@ -221,7 +254,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
         try:
-            response = await self._client.chat.completions.create(**request)
+            response = await self._get_client(provider_model_id).chat.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e