Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-22 20:40:00 +00:00
chore: remove vision model URL workarounds and simplify client creation
The vision models are now available at the standard URL, so the workaround code has been removed. This also simplifies the codebase by eliminating the need for per-model client caching.

- Remove special URL handling for meta/llama-3.2-11b/90b-vision-instruct models
- Convert _get_client method to _client property for cleaner API
- Remove unnecessary lru_cache decorator and functools import
- Simplify client creation logic to use single base URL for all models
Parent: 95fdc8ea94
Commit: 8cc3fe7669
2 changed files with 15 additions and 35 deletions
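A condensed sketch of the client creation that results from this change, assembled from the diff below. The property body and the config fields it reads (url, api_key, timeout, append_api_version) are taken from the hunks; the stub config class, its defaults, and the class name _AdapterSketch are hypothetical stand-ins added only to make the snippet self-contained.

```python
# Illustrative sketch only: _StubConfig is a hypothetical stand-in for the real
# NVIDIA provider config; the property body mirrors the new code in the diff below.
from dataclasses import dataclass

from openai import AsyncOpenAI
from pydantic import SecretStr


@dataclass
class _StubConfig:
    url: str = "https://integrate.api.nvidia.com"  # the standard hosted endpoint
    api_key: SecretStr | None = None
    timeout: int = 60
    append_api_version: bool = True


class _AdapterSketch:
    def __init__(self, config: _StubConfig) -> None:
        self._config = config

    @property
    def _client(self) -> AsyncOpenAI:
        # One client for every model; the per-model vision URLs and the
        # lru_cache-based client cache are no longer needed.
        base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
        return AsyncOpenAI(
            base_url=base_url,
            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
            timeout=self._config.timeout,
        )
```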
@@ -7,7 +7,6 @@
 import logging
 import warnings
 from collections.abc import AsyncIterator
-from functools import lru_cache
 from typing import Any
 
 from openai import APIConnectionError, AsyncOpenAI, BadRequestError
@@ -93,41 +92,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         self._config = config
 
-    @lru_cache  # noqa: B019
-    def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
+    @property
+    def _client(self) -> AsyncOpenAI:
         """
-        For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
-        some models are hosted on different URLs. This function returns the appropriate client
-        for the given provider_model_id.
-
-        This relies on lru_cache and self._default_client to avoid creating a new client for each request
-        or for each model that is hosted on https://integrate.api.nvidia.com/v1.
-
-        :param provider_model_id: The provider model ID
+        Returns an OpenAI client for the configured NVIDIA API endpoint.
+
         :return: An OpenAI client
         """
-
-        @lru_cache  # noqa: B019
-        def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
-            """
-            Maintain a single OpenAI client per base_url.
-            """
-            return AsyncOpenAI(
-                base_url=base_url,
-                api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
-                timeout=self._config.timeout,
-            )
-
-        special_model_urls = {
-            "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
-            "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
-        }
-
         base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
-
-        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
-            base_url = special_model_urls[provider_model_id]
-        return _get_client_for_base_url(base_url)
+        return AsyncOpenAI(
+            base_url=base_url,
+            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
+            timeout=self._config.timeout,
+        )
 
     async def _get_provider_model_id(self, model_id: str) -> str:
         if not self.model_store:
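For a quick illustration of the new property, a usage example building on the hypothetical _AdapterSketch/_StubConfig stub shown after the commit header above (purely illustrative, not part of the change):

```python
# Assumes the _AdapterSketch / _StubConfig sketch defined earlier in this page.
adapter = _AdapterSketch(_StubConfig())

# Every access builds a client pointed at the single configured base URL;
# vision models no longer get a special endpoint.
client = adapter._client
print(client.base_url)
```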
@@ -169,7 +148,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            response = await self._get_client(provider_model_id).completions.create(**request)
+            response = await self._client.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -222,7 +201,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             extra_body["input_type"] = task_type_options[task_type]
 
         try:
-            response = await self._get_client(provider_model_id).embeddings.create(
+            response = await self._client.embeddings.create(
                 model=provider_model_id,
                 input=input,
                 extra_body=extra_body,
@@ -283,7 +262,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            response = await self._get_client(provider_model_id).chat.completions.create(**request)
+            response = await self._client.chat.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -339,7 +318,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            return await self._get_client(provider_model_id).completions.create(**params)
+            return await self._client.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -398,7 +377,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            return await self._get_client(provider_model_id).chat.completions.create(**params)
+            return await self._client.chat.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -54,7 +54,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
         self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
         self.inference_mock_make_request = self.mock_client.chat.completions.create
         self.inference_make_request_patcher = patch(
-            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
+            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client",
+            new_callable=unittest.mock.PropertyMock,
             return_value=self.mock_client,
         )
         self.inference_make_request_patcher.start()
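Because _client is now a property rather than a method, the test patches it with unittest.mock.PropertyMock, as the hunk above shows. A minimal self-contained sketch of that pattern; the Adapter class and test here are hypothetical stand-ins, not the project's test code:

```python
import unittest
from unittest.mock import MagicMock, PropertyMock, patch


class Adapter:
    """Hypothetical stand-in for NVIDIAInferenceAdapter."""

    @property
    def _client(self):
        raise RuntimeError("would build a real AsyncOpenAI client here")


class TestClientPropertyPatch(unittest.TestCase):
    def test_property_is_patched(self) -> None:
        mock_client = MagicMock()
        # new_callable=PropertyMock is what the diff above adds: without it,
        # patching would replace the property with a plain MagicMock attribute,
        # and attribute access would not return mock_client directly.
        with patch.object(Adapter, "_client", new_callable=PropertyMock, return_value=mock_client):
            self.assertIs(Adapter()._client, mock_client)


if __name__ == "__main__":
    unittest.main()
```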