From 53a1698ec3e35f3dbdb02fc7bfaed716ac1fee71 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Wed, 26 Feb 2025 14:10:26 -0500
Subject: [PATCH] support nvidia hosted vision models

These models are not hosted on the common https://integrate.api.nvidia.com/v1;
they are hosted on their own individual URLs.
---
 .../remote/inference/nvidia/nvidia.py         | 53 +++++++++++++++----
 1 file changed, 43 insertions(+), 10 deletions(-)

diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index cc3bd85bb..837496ee1 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -6,6 +6,7 @@

 import logging
 import warnings
+from functools import lru_cache
 from typing import AsyncIterator, List, Optional, Union

 from openai import APIConnectionError, AsyncOpenAI
@@ -82,12 +83,42 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # )

         self._config = config
-        # make sure the client lives longer than any async calls
-        self._client = AsyncOpenAI(
-            base_url=f"{self._config.url}/v1",
-            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
-            timeout=self._config.timeout,
-        )
+
+    @lru_cache  # noqa: B019
+    def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
+        """
+        For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
+        some models are hosted on different URLs. This function returns the appropriate client
+        for the given provider_model_id.
+
+        This relies on lru_cache here and in _get_client_for_base_url to avoid creating a new
+        client for each request or for each model hosted on https://integrate.api.nvidia.com/v1.
+
+        :param provider_model_id: The provider model ID
+        :return: An OpenAI client
+        """
+
+        @lru_cache  # noqa: B019
+        def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
+            """
+            Maintain a single OpenAI client per base_url.
+            """
+            return AsyncOpenAI(
+                base_url=base_url,
+                api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
+                timeout=self._config.timeout,
+            )
+
+        special_model_urls = {
+            "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
+            "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
+        }
+
+        base_url = f"{self._config.url}/v1"
+        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
+            base_url = special_model_urls[provider_model_id]
+
+        return _get_client_for_base_url(base_url)

     async def completion(
         self,
@@ -103,9 +134,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):

         await check_health(self._config)  # this raises errors

+        provider_model_id = self.get_provider_model_id(model_id)
         request = convert_completion_request(
             request=CompletionRequest(
-                model=self.get_provider_model_id(model_id),
+                model=provider_model_id,
                 content=content,
                 sampling_params=sampling_params,
                 response_format=response_format,
@@ -116,7 +148,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )

         try:
-            response = await self._client.completions.create(**request)
+            response = await self._get_client(provider_model_id).completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

@@ -152,7 +184,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]

         model = self.get_provider_model_id(model_id)
-        response = await self._client.embeddings.create(
+        response = await self._get_client(model).embeddings.create(
             model=model,
             input=input,
             # extra_body={"input_type": "passage"|"query"},  # TODO(mf): how to tell caller's intent?
@@ -183,6 +215,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):

         await check_health(self._config)  # this raises errors

+        provider_model_id = self.get_provider_model_id(model_id)
         request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
                 model=self.get_provider_model_id(model_id),
@@ -198,7 +231,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )

         try:
-            response = await self._client.chat.completions.create(**request)
+            response = await self._get_client(provider_model_id).chat.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
