From 9845631d5187be1b9e9adb2c57e199a7255e3436 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Fri, 18 Apr 2025 04:16:43 -0400
Subject: [PATCH] feat: update nvidia inference provider to use model_store (#1988)

# What does this PR do?

The NVIDIA inference provider was using the ModelRegistryHelper to map input
model ids to provider model ids. This updates it to use the model_store
instead.

## Test Plan

`LLAMA_STACK_CONFIG=http://localhost:8321 uv run pytest -v tests/integration/inference/{test_embedding.py,test_text_inference.py,test_openai_completion.py} --embedding-model nvidia/llama-3.2-nv-embedqa-1b-v2 --text-model=meta-llama/Llama-3.1-70B-Instruct`
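For context, a standalone sketch of the new resolution path. The `InMemoryModelStore` class and the example id mapping below are illustrative stand-ins, not the stack's actual model_store implementation; only the body of `_get_provider_model_id` mirrors the helper added in this patch.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Model:
    identifier: str          # user-facing id, e.g. "meta-llama/Llama-3.1-70B-Instruct"
    provider_model_id: str   # id the NVIDIA endpoint serves (example mapping below)


class InMemoryModelStore:
    """Illustrative stand-in for the stack's model_store."""

    def __init__(self, models: list[Model]) -> None:
        self._models = {m.identifier: m for m in models}

    async def get_model(self, model_id: str) -> Model | None:
        return self._models.get(model_id)


async def _get_provider_model_id(store: InMemoryModelStore | None, model_id: str) -> str:
    # Same behavior as the helper added in this patch: fail loudly if the
    # store is missing or the model was never registered, otherwise return
    # the provider-side model id.
    if not store:
        raise RuntimeError("Model store is not set")
    model = await store.get_model(model_id)
    if model is None:
        raise ValueError(f"Model {model_id} is unknown")
    return model.provider_model_id


async def main() -> None:
    store = InMemoryModelStore(
        [Model("meta-llama/Llama-3.1-70B-Instruct", "meta/llama-3.1-70b-instruct")]
    )
    print(await _get_provider_model_id(store, "meta-llama/Llama-3.1-70B-Instruct"))


asyncio.run(main())
```

Because the lookup now goes through the model store, resolution happens per request against whatever models are currently registered, rather than against the static mapping the ModelRegistryHelper builds at provider initialization; this also makes the lookup async, which is why every call site below gains an `await`.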
---
 .../remote/inference/nvidia/nvidia.py         | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 15f0e72a1..c91b4d768 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -126,6 +126,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):

         return _get_client_for_base_url(base_url)

+    async def _get_provider_model_id(self, model_id: str) -> str:
+        if not self.model_store:
+            raise RuntimeError("Model store is not set")
+        model = await self.model_store.get_model(model_id)
+        if model is None:
+            raise ValueError(f"Model {model_id} is unknown")
+        return model.provider_model_id
+
     async def completion(
         self,
         model_id: str,
@@ -144,7 +152,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # removing this health check as NeMo customizer endpoint health check is returning 404
         # await check_health(self._config)  # this raises errors

-        provider_model_id = self.get_provider_model_id(model_id)
+        provider_model_id = await self._get_provider_model_id(model_id)
         request = convert_completion_request(
             request=CompletionRequest(
                 model=provider_model_id,
@@ -188,7 +196,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         #  flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
         input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]

-        model = self.get_provider_model_id(model_id)
+        provider_model_id = await self._get_provider_model_id(model_id)

         extra_body = {}

@@ -211,8 +219,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             extra_body["input_type"] = task_type_options[task_type]

         try:
-            response = await self._get_client(model).embeddings.create(
-                model=model,
+            response = await self._get_client(provider_model_id).embeddings.create(
+                model=provider_model_id,
                 input=input,
                 extra_body=extra_body,
             )
@@ -246,10 +254,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):

         # await check_health(self._config)  # this raises errors

-        provider_model_id = self.get_provider_model_id(model_id)
+        provider_model_id = await self._get_provider_model_id(model_id)
         request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
-                model=self.get_provider_model_id(model_id),
+                model=provider_model_id,
                 messages=messages,
                 sampling_params=sampling_params,
                 response_format=response_format,
@@ -294,7 +302,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         guided_choice: Optional[List[str]] = None,
         prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
-        provider_model_id = self.get_provider_model_id(model)
+        provider_model_id = await self._get_provider_model_id(model)

         params = await prepare_openai_completion_params(
             model=provider_model_id,
@@ -347,7 +355,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
-        provider_model_id = self.get_provider_model_id(model)
+        provider_model_id = await self._get_provider_model_id(model)

         params = await prepare_openai_completion_params(
             model=provider_model_id,