forked from phoenix-oss/llama-stack-mirror
feat: support nvidia hosted vision models (llama 3.2 11b/90b) (#1278)
# What does this PR do?

Support the NVIDIA-hosted Llama 3.2 11B/90B vision models. They are not hosted on the common https://integrate.api.nvidia.com/v1 endpoint; each model is served from its own URL.

## Test Plan

`LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -s -v tests/client-sdk/inference/test_vision_inference.py --inference-model=meta/llama-3.2-11b-vision-instruct -k image`
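For a quick manual check outside the test suite, the per-model endpoint can be exercised directly with the same OpenAI-compatible client the adapter uses. This is a minimal sketch, not part of the change itself: it assumes `NVIDIA_API_KEY` is set in the environment, a local `dog.png` test image, and that the endpoint accepts OpenAI-style `image_url` message parts.

```python
import asyncio
import base64
import os

from openai import AsyncOpenAI

# Per-model endpoint for the 11B vision model (see special_model_urls in the diff).
BASE_URL = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct"


async def main() -> None:
    client = AsyncOpenAI(
        base_url=BASE_URL,
        api_key=os.environ["NVIDIA_API_KEY"],  # assumed: key supplied via env var
    )
    # assumed: a local test image, sent inline as a base64 data URL
    with open("dog.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode()
    response = await client.chat.completions.create(
        model="meta/llama-3.2-11b-vision-instruct",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image in one sentence."},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
                ],
            }
        ],
        max_tokens=64,
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```

The 90B model works the same way against its own URL.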
This commit is contained in:
parent f4dc290705
commit 706b4ca651
1 changed file with 42 additions and 9 deletions
```diff
@@ -6,6 +6,7 @@
 
 import logging
 import warnings
+from functools import lru_cache
 from typing import AsyncIterator, List, Optional, Union
 
 from openai import APIConnectionError, AsyncOpenAI, BadRequestError
@@ -82,12 +83,42 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # )
 
         self._config = config
-        # make sure the client lives longer than any async calls
-        self._client = AsyncOpenAI(
-            base_url=f"{self._config.url}/v1",
-            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
-            timeout=self._config.timeout,
-        )
+
+    @lru_cache  # noqa: B019
+    def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
+        """
+        For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
+        some models are hosted on different URLs. This function returns the appropriate client
+        for the given provider_model_id.
+
+        This relies on lru_cache and self._default_client to avoid creating a new client for each request
+        or for each model that is hosted on https://integrate.api.nvidia.com/v1.
+
+        :param provider_model_id: The provider model ID
+        :return: An OpenAI client
+        """
+
+        @lru_cache  # noqa: B019
+        def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
+            """
+            Maintain a single OpenAI client per base_url.
+            """
+            return AsyncOpenAI(
+                base_url=base_url,
+                api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
+                timeout=self._config.timeout,
+            )
+
+        special_model_urls = {
+            "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
+            "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
+        }
+
+        base_url = f"{self._config.url}/v1"
+        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
+            base_url = special_model_urls[provider_model_id]
+
+        return _get_client_for_base_url(base_url)
 
     async def completion(
         self,
@@ -105,9 +136,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         await check_health(self._config)  # this raises errors
 
+        provider_model_id = self.get_provider_model_id(model_id)
         request = convert_completion_request(
             request=CompletionRequest(
-                model=self.get_provider_model_id(model_id),
+                model=provider_model_id,
                 content=content,
                 sampling_params=sampling_params,
                 response_format=response_format,
@@ -118,7 +150,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            response = await self._client.completions.create(**request)
+            response = await self._get_client(provider_model_id).completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
 
@@ -206,6 +238,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         await check_health(self._config)  # this raises errors
 
+        provider_model_id = self.get_provider_model_id(model_id)
         request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
                 model=self.get_provider_model_id(model_id),
@@ -221,7 +254,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            response = await self._client.chat.completions.create(**request)
+            response = await self._get_client(provider_model_id).chat.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
 
```
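On the caching design: both `_get_client` and the nested `_get_client_for_base_url` are wrapped in `lru_cache`, with the inner helper keyed on `base_url` so that, per its docstring, a single OpenAI client is maintained per base_url. A standalone sketch of that keyed-caching pattern, using illustrative names and the same "NO KEY" placeholder as the diff rather than importing the adapter:

```python
from functools import lru_cache

from openai import AsyncOpenAI


@lru_cache
def client_for(base_url: str) -> AsyncOpenAI:
    # One client per distinct base_url; repeat calls return the cached instance.
    return AsyncOpenAI(base_url=base_url, api_key="NO KEY")


a = client_for("https://integrate.api.nvidia.com/v1")
b = client_for("https://integrate.api.nvidia.com/v1")
c = client_for("https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct")
assert a is b      # same URL -> same cached client
assert a is not c  # the special vision URL gets its own client
```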