# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import warnings
from typing import AsyncIterator, List, Optional, Union

from openai import APIConnectionError, AsyncOpenAI

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    ToolChoice,
    ToolConfig,
)
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

from . import NVIDIAConfig
from .models import _MODEL_ALIASES
from .openai_utils import (
    convert_chat_completion_request,
    convert_completion_request,
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
    convert_openai_completion_choice,
    convert_openai_completion_stream,
)
from .utils import _is_nvidia_hosted, check_health

logger = logging.getLogger(__name__)


class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
    """Inference adapter that forwards requests to an NVIDIA NIM endpoint
    through its OpenAI-compatible API."""

    def __init__(self, config: NVIDIAConfig) -> None:
        # TODO(mf): filter by available models
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)

        logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")

        if _is_nvidia_hosted(config):
            if not config.api_key:
                raise RuntimeError(
                    "API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
                )
        # elif self._config.api_key:
        #
        # we don't raise this warning because a user may have deployed their
        # self-hosted NIM with an API key requirement.
        #
        #     warnings.warn(
        #         "API key is not required for self-hosted NVIDIA NIM. "
        #         "Consider removing the api_key from the configuration."
        #     )
        self._config = config
        # make sure the client lives longer than any async calls
        self._client = AsyncOpenAI(
            base_url=f"{self._config.url}/v1",
            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
            timeout=self._config.timeout,
        )

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        if content_has_media(content):
            raise NotImplementedError("Media is not supported")

        await check_health(self._config)  # this raises errors

        request = convert_completion_request(
            request=CompletionRequest(
                model=self.get_provider_model_id(model_id),
                content=content,
                sampling_params=sampling_params,
                response_format=response_format,
                stream=stream,
                logprobs=logprobs,
            ),
            n=1,
        )

        try:
            response = await self._client.completions.create(**request)
        except APIConnectionError as e:
            raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

        if stream:
            return convert_openai_completion_stream(response)
        else:
            # we pass n=1 to get only one completion
            return convert_openai_completion_choice(response.choices[0])

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        if tool_prompt_format:
            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")

        await check_health(self._config)  # this raises errors

        request = await convert_chat_completion_request(
            request=ChatCompletionRequest(
                model=self.get_provider_model_id(model_id),
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                tools=tools,
                stream=stream,
                logprobs=logprobs,
                tool_config=tool_config,
            ),
            n=1,
        )

        try:
            response = await self._client.chat.completions.create(**request)
        except APIConnectionError as e:
            raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

        if stream:
            return convert_openai_chat_completion_stream(response)
        else:
            # we pass n=1 to get only one completion
            return convert_openai_chat_completion_choice(response.choices[0])
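

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the adapter). A minimal way to
# exercise chat_completion against a self-hosted NIM. The localhost URL, the
# model id, and the availability of `UserMessage` in llama_stack.apis.inference
# are assumptions here; adjust them to match your deployment.
#
#   import asyncio
#   from llama_stack.apis.inference import UserMessage
#
#   async def _demo() -> None:
#       config = NVIDIAConfig(url="http://localhost:8000")  # assumed endpoint
#       adapter = NVIDIAInferenceAdapter(config)
#       response = await adapter.chat_completion(
#           model_id="meta-llama/Llama-3.1-8B-Instruct",  # any registered model
#           messages=[UserMessage(content="Hello!")],
#       )
#       print(response.completion_message.content)
#
#   asyncio.run(_demo())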