diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 15f0e72a1..c91b4d768 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -126,6 +126,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         return _get_client_for_base_url(base_url)
 
+    async def _get_provider_model_id(self, model_id: str) -> str:
+        if not self.model_store:
+            raise RuntimeError("Model store is not set")
+        model = await self.model_store.get_model(model_id)
+        if model is None:
+            raise ValueError(f"Model {model_id} is unknown")
+        return model.provider_model_id
+
     async def completion(
         self,
         model_id: str,
@@ -144,7 +152,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # removing this health check as NeMo customizer endpoint health check is returning 404
         # await check_health(self._config)  # this raises errors
 
-        provider_model_id = self.get_provider_model_id(model_id)
+        provider_model_id = await self._get_provider_model_id(model_id)
         request = convert_completion_request(
             request=CompletionRequest(
                 model=provider_model_id,
@@ -188,7 +196,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         # flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
         input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
-        model = self.get_provider_model_id(model_id)
+        provider_model_id = await self._get_provider_model_id(model_id)
 
         extra_body = {}
 
@@ -211,8 +219,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             extra_body["input_type"] = task_type_options[task_type]
 
         try:
-            response = await self._get_client(model).embeddings.create(
-                model=model,
+            response = await self._get_client(provider_model_id).embeddings.create(
+                model=provider_model_id,
                 input=input,
                 extra_body=extra_body,
             )
@@ -246,10 +254,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         # await check_health(self._config)  # this raises errors
 
-        provider_model_id = self.get_provider_model_id(model_id)
+        provider_model_id = await self._get_provider_model_id(model_id)
         request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
-                model=self.get_provider_model_id(model_id),
+                model=provider_model_id,
                 messages=messages,
                 sampling_params=sampling_params,
                 response_format=response_format,
@@ -294,7 +302,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         guided_choice: Optional[List[str]] = None,
         prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
-        provider_model_id = self.get_provider_model_id(model)
+        provider_model_id = await self._get_provider_model_id(model)
 
         params = await prepare_openai_completion_params(
             model=provider_model_id,
@@ -347,7 +355,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
-        provider_model_id = self.get_provider_model_id(model)
+        provider_model_id = await self._get_provider_model_id(model)
 
         params = await prepare_openai_completion_params(
             model=provider_model_id,
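
The new helper resolves ids through the async model store, which tracks models registered with the stack, instead of the synchronous `ModelRegistryHelper.get_provider_model_id` lookup; that is why every call site in the diff gains an `await`. Below is a minimal sketch of the lookup path under stated assumptions: `StubModel` and `StubModelStore` are hypothetical stand-ins, not llama_stack types, and the model id in `main` is just an example.

```python
# Sketch only: mirrors the _get_provider_model_id helper added above,
# written as a free function against a hypothetical stub model store.
import asyncio
from dataclasses import dataclass


@dataclass
class StubModel:
    provider_model_id: str


class StubModelStore:
    """Hypothetical stand-in for the adapter's async model_store."""

    def __init__(self, models: dict[str, StubModel]) -> None:
        self._models = models

    async def get_model(self, model_id: str) -> StubModel | None:
        return self._models.get(model_id)


async def get_provider_model_id(store: StubModelStore | None, model_id: str) -> str:
    # Same logic as the new adapter method: fail fast if the store is
    # missing, then map the user-facing id to the provider's id.
    if not store:
        raise RuntimeError("Model store is not set")
    model = await store.get_model(model_id)
    if model is None:
        raise ValueError(f"Model {model_id} is unknown")
    return model.provider_model_id


async def main() -> None:
    store = StubModelStore({"my-llama": StubModel("meta/llama-3.1-8b-instruct")})
    print(await get_provider_model_id(store, "my-llama"))  # -> meta/llama-3.1-8b-instruct


asyncio.run(main())
```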