Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-18 02:42:31 +00:00
chore: remove vision model URL workarounds and simplify client creation (#2775)
The vision models are now available at the standard URL, so the workaround code has been removed. This also simplifies the codebase by eliminating the need for per-model client caching.

- Remove special URL handling for the meta/llama-3.2-11b/90b-vision-instruct models
- Convert the _get_client method to a _client property for a cleaner API
- Remove the now-unnecessary lru_cache decorator and functools import
- Simplify client creation to use a single base URL for all models
This commit is contained in: parent fa1bb9ae00, commit a3e249807b
2 changed files with 15 additions and 35 deletions
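For orientation before reading the diff: the change collapses a per-model client factory into a single property. Below is a minimal sketch of the new shape; _SketchConfig and _SketchAdapter are illustrative stand-ins, not the adapter's real classes, and the real code unwraps its API key from a secret value rather than a plain string.

from dataclasses import dataclass

from openai import AsyncOpenAI


@dataclass
class _SketchConfig:
    url: str = "https://integrate.api.nvidia.com"
    api_key: str | None = None
    timeout: int = 60
    append_api_version: bool = True


class _SketchAdapter:
    def __init__(self, config: _SketchConfig) -> None:
        self._config = config

    @property
    def _client(self) -> AsyncOpenAI:
        # Every model now shares one base URL; the old per-model routing
        # table and its lru_cache are gone.
        base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
        return AsyncOpenAI(
            base_url=base_url,
            api_key=self._config.api_key or "NO KEY",
            timeout=self._config.timeout,
        )

With one URL for all models there is nothing left to key a cache on, which is why the lru_cache decorator and the functools import go away.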
llama_stack/providers/remote/inference/nvidia/nvidia.py

@@ -7,7 +7,6 @@
 import logging
 import warnings
 from collections.abc import AsyncIterator
-from functools import lru_cache
 from typing import Any
 
 from openai import APIConnectionError, AsyncOpenAI, BadRequestError
@@ -93,42 +92,22 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         self._config = config
 
-    @lru_cache  # noqa: B019
-    def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
+    @property
+    def _client(self) -> AsyncOpenAI:
         """
-        For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
-        some models are hosted on different URLs. This function returns the appropriate client
-        for the given provider_model_id.
-
-        This relies on lru_cache and self._default_client to avoid creating a new client for each request
-        or for each model that is hosted on https://integrate.api.nvidia.com/v1.
-
-        :param provider_model_id: The provider model ID
+        Returns an OpenAI client for the configured NVIDIA API endpoint.
+
         :return: An OpenAI client
         """
-
-        @lru_cache  # noqa: B019
-        def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
-            """
-            Maintain a single OpenAI client per base_url.
-            """
-            return AsyncOpenAI(
-                base_url=base_url,
-                api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
-                timeout=self._config.timeout,
-            )
-
-        special_model_urls = {
-            "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
-            "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
-        }
-
-        base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
-
-        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
-            base_url = special_model_urls[provider_model_id]
-        return _get_client_for_base_url(base_url)
+        base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
+        return AsyncOpenAI(
+            base_url=base_url,
+            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
+            timeout=self._config.timeout,
+        )
 
     async def _get_provider_model_id(self, model_id: str) -> str:
         if not self.model_store:
             raise RuntimeError("Model store is not set")
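A note on the dropped lru_cache: decorating an instance method with functools.lru_cache is what forced the # noqa: B019 above, since ruff's B019 rule flags it for keeping self alive in a process-wide cache. Constructing an AsyncOpenAI client is cheap, so the new property simply builds one per access. If per-instance caching were ever wanted back, functools.cached_property would be the leak-free route; the sketch below shows that alternative, not what the adapter actually does.

from functools import cached_property

from openai import AsyncOpenAI


class _CachedClientSketch:
    def __init__(self, base_url: str, api_key: str) -> None:
        self._base_url = base_url
        self._api_key = api_key

    @cached_property
    def _client(self) -> AsyncOpenAI:
        # Built once on first access and stored on the instance, so the
        # cached client is garbage-collected together with the instance.
        return AsyncOpenAI(base_url=self._base_url, api_key=self._api_key)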
@@ -169,7 +148,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            response = await self._get_client(provider_model_id).completions.create(**request)
+            response = await self._client.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
 
@@ -222,7 +201,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             extra_body["input_type"] = task_type_options[task_type]
 
         try:
-            response = await self._get_client(provider_model_id).embeddings.create(
+            response = await self._client.embeddings.create(
                 model=provider_model_id,
                 input=input,
                 extra_body=extra_body,
@@ -283,7 +262,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            response = await self._get_client(provider_model_id).chat.completions.create(**request)
+            response = await self._client.chat.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
 
@@ -339,7 +318,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            return await self._get_client(provider_model_id).completions.create(**params)
+            return await self._client.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
 
@@ -398,7 +377,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            return await self._get_client(provider_model_id).chat.completions.create(**params)
+            return await self._client.chat.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
 
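All five call-site hunks above share one error-wrapping shape: APIConnectionError from the OpenAI client is re-raised as a plain ConnectionError naming the NIM endpoint. A self-contained sketch of that pattern, where the localhost URL and model name are placeholders rather than real configuration:

import asyncio

from openai import APIConnectionError, AsyncOpenAI


async def main() -> None:
    # Placeholder endpoint; with nothing listening there, the call below
    # raises APIConnectionError, which we surface as ConnectionError.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="NO KEY")
    try:
        await client.completions.create(model="placeholder-model", prompt="hello")
    except APIConnectionError as e:
        raise ConnectionError(f"Failed to connect to NVIDIA NIM at http://localhost:8000: {e}") from e


if __name__ == "__main__":
    asyncio.run(main())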
@@ -54,7 +54,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
         self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
         self.inference_mock_make_request = self.mock_client.chat.completions.create
         self.inference_make_request_patcher = patch(
-            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
+            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client",
+            new_callable=unittest.mock.PropertyMock,
             return_value=self.mock_client,
         )
         self.inference_make_request_patcher.start()
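The added new_callable=unittest.mock.PropertyMock line is what keeps this test working: patch replaces the attribute on the class, and with the default MagicMock the configured return_value would only surface if the code called self._client(). PropertyMock implements the descriptor protocol, so plain attribute access on _client, now a property, yields the configured mock. A minimal illustration with a hypothetical Service class:

from unittest.mock import PropertyMock, patch


class Service:
    @property
    def client(self) -> str:
        return "real-client"


# Patch the property on the class; plain attribute access returns the mock value.
with patch.object(Service, "client", new_callable=PropertyMock, return_value="mock-client"):
    assert Service().client == "mock-client"

assert Service().client == "real-client"  # original property restored on exit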