diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 54d55e60e..c8d061f6c 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -150,4 +150,15 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="nvidia",
+                pip_packages=[
+                    "openai",
+                ],
+                module="llama_stack.providers.remote.inference.nvidia",
+                config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
+            ),
+        ),
     ]
diff --git a/llama_stack/providers/remote/inference/nvidia/__init__.py b/llama_stack/providers/remote/inference/nvidia/__init__.py
new file mode 100644
index 000000000..9c537d448
--- /dev/null
+++ b/llama_stack/providers/remote/inference/nvidia/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.inference import Inference
+
+from .config import NVIDIAConfig
+
+
+async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
+    # import dynamically so `llama stack build` does not fail due to missing dependencies
+    from .nvidia import NVIDIAInferenceAdapter
+
+    if not isinstance(config, NVIDIAConfig):
+        raise RuntimeError(f"Unexpected config type: {type(config)}")
+    adapter = NVIDIAInferenceAdapter(config)
+    return adapter
+
+
+__all__ = ["get_adapter_impl", "NVIDIAConfig"]
diff --git a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py
new file mode 100644
index 000000000..c50143043
--- /dev/null
+++ b/llama_stack/providers/remote/inference/nvidia/config.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class NVIDIAConfig(BaseModel):
+    """
+    Configuration for the NVIDIA NIM inference endpoint.
+
+    Attributes:
+        url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
+        api_key (str): The access key for the hosted NIM endpoints
+
+    There are two ways to access NVIDIA NIMs -
+     0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
+     1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
+
+    By default the configuration is set to use the hosted APIs. This requires
+    an API key which can be obtained from https://ngc.nvidia.com/.
+
+    By default the configuration will attempt to read the NVIDIA_API_KEY environment
+    variable to set the api_key. Please do not put your API key in code.
+
+    If you are using a self-hosted NVIDIA NIM, you can set the url to the
+    URL of your running NVIDIA NIM and do not need to set the api_key.
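+
+    For example (illustrative, values taken from the defaults above):
+        NVIDIAConfig()                              # hosted endpoint, api_key read from NVIDIA_API_KEY
+        NVIDIAConfig(url="http://localhost:8000")   # self-hosted NIM, no api_key needed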
+ """ + + url: str = Field( + default="https://integrate.api.nvidia.com", + description="A base url for accessing the NVIDIA NIM", + ) + api_key: Optional[str] = Field( + default_factory=lambda: os.getenv("NVIDIA_API_KEY"), + description="The NVIDIA API key, only needed of using the hosted service", + ) + timeout: int = Field( + default=60, + description="Timeout for the HTTP requests", + ) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py new file mode 100644 index 000000000..f38aa7112 --- /dev/null +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import warnings +from typing import AsyncIterator, List, Optional, Union + +from llama_models.datatypes import SamplingParams +from llama_models.llama3.api.datatypes import ( + InterleavedTextMedia, + Message, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_models.sku_list import CoreModelId +from openai import APIConnectionError, AsyncOpenAI + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + CompletionResponse, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + ResponseFormat, +) +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) + +from . import NVIDIAConfig +from .openai_utils import ( + convert_chat_completion_request, + convert_openai_chat_completion_choice, + convert_openai_chat_completion_stream, +) +from .utils import _is_nvidia_hosted, check_health + +_MODEL_ALIASES = [ + build_model_alias( + "meta/llama3-8b-instruct", + CoreModelId.llama3_8b_instruct.value, + ), + build_model_alias( + "meta/llama3-70b-instruct", + CoreModelId.llama3_70b_instruct.value, + ), + build_model_alias( + "meta/llama-3.1-8b-instruct", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "meta/llama-3.1-70b-instruct", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_model_alias( + "meta/llama-3.1-405b-instruct", + CoreModelId.llama3_1_405b_instruct.value, + ), + build_model_alias( + "meta/llama-3.2-1b-instruct", + CoreModelId.llama3_2_1b_instruct.value, + ), + build_model_alias( + "meta/llama-3.2-3b-instruct", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_model_alias( + "meta/llama-3.2-11b-vision-instruct", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), + build_model_alias( + "meta/llama-3.2-90b-vision-instruct", + CoreModelId.llama3_2_90b_vision_instruct.value, + ), + # TODO(mf): how do we handle Nemotron models? + # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct", +] + + +class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): + def __init__(self, config: NVIDIAConfig) -> None: + # TODO(mf): filter by available models + ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES) + + print(f"Initializing NVIDIAInferenceAdapter({config.url})...") + + if _is_nvidia_hosted(config): + if not config.api_key: + raise RuntimeError( + "API key is required for hosted NVIDIA NIM. " + "Either provide an API key or use a self-hosted NIM." 
+                )
+        # elif self._config.api_key:
+        #
+        # we don't raise this warning because a user may have deployed their
+        # self-hosted NIM with an API key requirement.
+        #
+        #     warnings.warn(
+        #         "API key is not required for self-hosted NVIDIA NIM. "
+        #         "Consider removing the api_key from the configuration."
+        #     )
+
+        self._config = config
+        # make sure the client lives longer than any async calls
+        self._client = AsyncOpenAI(
+            base_url=f"{self._config.url}/v1",
+            api_key=self._config.api_key or "NO KEY",
+            timeout=self._config.timeout,
+        )
+
+    def completion(
+        self,
+        model_id: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        raise NotImplementedError()
+
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[
+            ToolPromptFormat
+        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
+        if tool_prompt_format:
+            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
+
+        await check_health(self._config)  # this raises errors
+
+        request = convert_chat_completion_request(
+            request=ChatCompletionRequest(
+                model=self.get_provider_model_id(model_id),
+                messages=messages,
+                sampling_params=sampling_params,
+                response_format=response_format,
+                tools=tools,
+                tool_choice=tool_choice,
+                tool_prompt_format=tool_prompt_format,
+                stream=stream,
+                logprobs=logprobs,
+            ),
+            n=1,
+        )
+
+        try:
+            response = await self._client.chat.completions.create(**request)
+        except APIConnectionError as e:
+            raise ConnectionError(
+                f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
+            ) from e
+
+        if stream:
+            return convert_openai_chat_completion_stream(response)
+        else:
+            # we pass n=1 to get only one completion
+            return convert_openai_chat_completion_choice(response.choices[0])
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
new file mode 100644
index 000000000..b74aa05da
--- /dev/null
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -0,0 +1,581 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
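+#
+# Conversion helpers between llama-stack inference types and the OpenAI client
+# types used to reach NVIDIA NIMs: request conversion (messages, tools, and
+# sampling params via the `nvext` extension body), and conversion of responses
+# and streaming chunks back into ChatCompletionResponse /
+# ChatCompletionResponseStreamChunk.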
+
+import json
+import warnings
+from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
+
+from llama_models.llama3.api.datatypes import (
+    BuiltinTool,
+    CompletionMessage,
+    StopReason,
+    TokenLogProbs,
+    ToolCall,
+    ToolDefinition,
+)
+from openai import AsyncStream
+
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+    ChatCompletionChunk as OpenAIChatCompletionChunk,
+    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
+    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
+)
+from openai.types.chat.chat_completion import (
+    Choice as OpenAIChoice,
+    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
+)
+
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    Function as OpenAIFunction,
+)
+
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    JsonSchemaResponseFormat,
+    Message,
+    SystemMessage,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    ToolResponseMessage,
+    UserMessage,
+)
+
+
+def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
+    """
+    Convert a ToolDefinition to an OpenAI API-compatible dictionary.
+
+    ToolDefinition:
+        tool_name: str | BuiltinTool
+        description: Optional[str]
+        parameters: Optional[Dict[str, ToolParamDefinition]]
+
+    ToolParamDefinition:
+        param_type: str
+        description: Optional[str]
+        required: Optional[bool]
+        default: Optional[Any]
+
+
+    OpenAI spec -
+
+    {
+        "type": "function",
+        "function": {
+            "name": tool_name,
+            "description": description,
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    param_name: {
+                        "type": param_type,
+                        "description": description,
+                        "default": default,
+                    },
+                    ...
+                },
+                "required": [param_name, ...],
+            },
+        },
+    }
+    """
+    out = {
+        "type": "function",
+        "function": {},
+    }
+    function = out["function"]
+
+    if isinstance(tool.tool_name, BuiltinTool):
+        function.update(name=tool.tool_name.value)  # TODO(mf): is this sufficient?
+    else:
+        function.update(name=tool.tool_name)
+
+    if tool.description:
+        function.update(description=tool.description)
+
+    if tool.parameters:
+        parameters = {
+            "type": "object",
+            "properties": {},
+        }
+        properties = parameters["properties"]
+        required = []
+        for param_name, param in tool.parameters.items():
+            properties[param_name] = {"type": param.param_type}
+            if param.description:
+                properties[param_name].update(description=param.description)
+            if param.default:
+                properties[param_name].update(default=param.default)
+            if param.required:
+                required.append(param_name)
+
+        if required:
+            parameters.update(required=required)
+
+        function.update(parameters=parameters)
+
+    return out
+
+
+def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
+    """
+    Convert a Message to an OpenAI API-compatible dictionary.
+    """
+    # users can supply a dict instead of a Message object, we'll
+    # convert it to a Message object and proceed with some type safety.
+    if isinstance(message, dict):
+        if "role" not in message:
+            raise ValueError("role is required in message")
+        if message["role"] == "user":
+            message = UserMessage(**message)
+        elif message["role"] == "assistant":
+            message = CompletionMessage(**message)
+        elif message["role"] == "ipython":
+            message = ToolResponseMessage(**message)
+        elif message["role"] == "system":
+            message = SystemMessage(**message)
+        else:
+            raise ValueError(f"Unsupported message role: {message['role']}")
+
+    out: OpenAIChatCompletionMessage = None
+    if isinstance(message, UserMessage):
+        out = OpenAIChatCompletionUserMessage(
+            role="user",
+            content=message.content,  # TODO(mf): handle image content
+        )
+    elif isinstance(message, CompletionMessage):
+        out = OpenAIChatCompletionAssistantMessage(
+            role="assistant",
+            content=message.content,
+            tool_calls=[
+                OpenAIChatCompletionMessageToolCall(
+                    id=tool.call_id,
+                    function=OpenAIFunction(
+                        name=tool.tool_name,
+                        arguments=json.dumps(tool.arguments),
+                    ),
+                    type="function",
+                )
+                for tool in message.tool_calls
+            ],
+        )
+    elif isinstance(message, ToolResponseMessage):
+        out = OpenAIChatCompletionToolMessage(
+            role="tool",
+            tool_call_id=message.call_id,
+            content=message.content,
+        )
+    elif isinstance(message, SystemMessage):
+        out = OpenAIChatCompletionSystemMessage(
+            role="system",
+            content=message.content,
+        )
+    else:
+        raise ValueError(f"Unsupported message type: {type(message)}")
+
+    return out
+
+
+def convert_chat_completion_request(
+    request: ChatCompletionRequest,
+    n: int = 1,
+) -> dict:
+    """
+    Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
+    """
+    # model -> model
+    # messages -> messages
+    # sampling_params  TODO(mattf): review strategy
+    #  strategy=greedy -> nvext.top_k = -1, temperature = temperature
+    #  strategy=top_p -> nvext.top_k = -1, top_p = top_p
+    #  strategy=top_k -> nvext.top_k = top_k
+    #  temperature -> temperature
+    #  top_p -> top_p
+    #  top_k -> nvext.top_k
+    #  max_tokens -> max_tokens
+    #  repetition_penalty -> nvext.repetition_penalty
+    # response_format -> GrammarResponseFormat TODO(mf)
+    # response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
+    # tools -> tools
+    # tool_choice ("auto", "required") -> tool_choice
+    # tool_prompt_format -> TBD
+    # stream -> stream
+    # logprobs -> logprobs
+
+    if request.response_format and not isinstance(
+        request.response_format, JsonSchemaResponseFormat
+    ):
+        raise ValueError(
+            f"Unsupported response format: {request.response_format}. "
+            "Only JsonSchemaResponseFormat is supported."
+        )
+
+    nvext = {}
+    payload: Dict[str, Any] = dict(
+        model=request.model,
+        messages=[_convert_message(message) for message in request.messages],
+        stream=request.stream,
+        n=n,
+        extra_body=dict(nvext=nvext),
+        extra_headers={
+            b"User-Agent": b"llama-stack: nvidia-inference-adapter",
+        },
+    )
+
+    if request.response_format:
+        # server bug - setting guided_json changes the behavior of response_format resulting in an error
+        # payload.update(response_format="json_object")
+        nvext.update(guided_json=request.response_format.json_schema)
+
+    if request.tools:
+        payload.update(
+            tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
+        )
+        if request.tool_choice:
+            payload.update(
+                tool_choice=request.tool_choice.value
+            )  # we cannot include tool_choice w/o tools, server will complain
+
+    if request.logprobs:
+        payload.update(logprobs=True)
+        payload.update(top_logprobs=request.logprobs.top_k)
+
+    if request.sampling_params:
+        nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
+
+        if request.sampling_params.max_tokens:
+            payload.update(max_tokens=request.sampling_params.max_tokens)
+
+        if request.sampling_params.strategy == "top_p":
+            nvext.update(top_k=-1)
+            payload.update(top_p=request.sampling_params.top_p)
+        elif request.sampling_params.strategy == "top_k":
+            if (
+                request.sampling_params.top_k != -1
+                and request.sampling_params.top_k < 1
+            ):
+                warnings.warn("top_k must be -1 or >= 1")
+            nvext.update(top_k=request.sampling_params.top_k)
+        elif request.sampling_params.strategy == "greedy":
+            nvext.update(top_k=-1)
+            payload.update(temperature=request.sampling_params.temperature)
+
+    return payload
+
+
+def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
+    """
+    Convert an OpenAI chat completion finish_reason to a StopReason.
+
+    finish_reason: Literal["stop", "length", "tool_calls", ...]
+        - stop: model hit a natural stop point or a provided stop sequence
+        - length: maximum number of tokens specified in the request was reached
+        - tool_calls: model called a tool
+
+    ->
+
+    class StopReason(Enum):
+        end_of_turn = "end_of_turn"
+        end_of_message = "end_of_message"
+        out_of_tokens = "out_of_tokens"
+    """
+
+    # TODO(mf): are end_of_turn and end_of_message semantics correct?
+    return {
+        "stop": StopReason.end_of_turn,
+        "length": StopReason.out_of_tokens,
+        "tool_calls": StopReason.end_of_message,
+    }.get(finish_reason, StopReason.end_of_turn)
+
+
+def _convert_openai_tool_calls(
+    tool_calls: List[OpenAIChatCompletionMessageToolCall],
+) -> List[ToolCall]:
+    """
+    Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
+
+    OpenAI ChatCompletionMessageToolCall:
+        id: str
+        function: Function
+        type: Literal["function"]
+
+    OpenAI Function:
+        arguments: str
+        name: str
+
+    ->
+
+    ToolCall:
+        call_id: str
+        tool_name: str
+        arguments: Dict[str, ...]
+    """
+    if not tool_calls:
+        return []  # CompletionMessage tool_calls is not optional
+
+    return [
+        ToolCall(
+            call_id=call.id,
+            tool_name=call.function.name,
+            arguments=json.loads(call.function.arguments),
+        )
+        for call in tool_calls
+    ]
+
+
+def _convert_openai_logprobs(
+    logprobs: OpenAIChoiceLogprobs,
+) -> Optional[List[TokenLogProbs]]:
+    """
+    Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
+
+    OpenAI ChoiceLogprobs:
+        content: Optional[List[ChatCompletionTokenLogprob]]
+
+    OpenAI ChatCompletionTokenLogprob:
+        token: str
+        logprob: float
+        top_logprobs: List[TopLogprob]
+
+    OpenAI TopLogprob:
+        token: str
+        logprob: float
+
+    ->
+
+    TokenLogProbs:
+        logprobs_by_token: Dict[str, float]
+         - token, logprob
+
+    """
+    if not logprobs:
+        return None
+
+    return [
+        TokenLogProbs(
+            logprobs_by_token={
+                top.token: top.logprob for top in content.top_logprobs
+            }
+        )
+        for content in logprobs.content
+    ]
+
+
+def convert_openai_chat_completion_choice(
+    choice: OpenAIChoice,
+) -> ChatCompletionResponse:
+    """
+    Convert an OpenAI Choice into a ChatCompletionResponse.
+
+    OpenAI Choice:
+        message: ChatCompletionMessage
+        finish_reason: str
+        logprobs: Optional[ChoiceLogprobs]
+
+    OpenAI ChatCompletionMessage:
+        role: Literal["assistant"]
+        content: Optional[str]
+        tool_calls: Optional[List[ChatCompletionMessageToolCall]]
+
+    ->
+
+    ChatCompletionResponse:
+        completion_message: CompletionMessage
+        logprobs: Optional[List[TokenLogProbs]]
+
+    CompletionMessage:
+        role: Literal["assistant"]
+        content: str | ImageMedia | List[str | ImageMedia]
+        stop_reason: StopReason
+        tool_calls: List[ToolCall]
+
+    class StopReason(Enum):
+        end_of_turn = "end_of_turn"
+        end_of_message = "end_of_message"
+        out_of_tokens = "out_of_tokens"
+    """
+    assert (
+        hasattr(choice, "message") and choice.message
+    ), "error in server response: message not found"
+    assert (
+        hasattr(choice, "finish_reason") and choice.finish_reason
+    ), "error in server response: finish_reason not found"
+
+    return ChatCompletionResponse(
+        completion_message=CompletionMessage(
+            content=choice.message.content
+            or "",  # CompletionMessage content is not optional
+            stop_reason=_convert_openai_finish_reason(choice.finish_reason),
+            tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
+        ),
+        logprobs=_convert_openai_logprobs(choice.logprobs),
+    )
+
+
+async def convert_openai_chat_completion_stream(
+    stream: AsyncStream[OpenAIChatCompletionChunk],
+) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+    """
+    Convert a stream of OpenAI chat completion chunks into a stream
+    of ChatCompletionResponseStreamChunk.
+
+    OpenAI ChatCompletionChunk:
+        choices: List[Choice]
+
+    OpenAI Choice:  # different from the non-streamed Choice
+        delta: ChoiceDelta
+        finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
+        logprobs: Optional[ChoiceLogprobs]
+
+    OpenAI ChoiceDelta:
+        content: Optional[str]
+        role: Optional[Literal["system", "user", "assistant", "tool"]]
+        tool_calls: Optional[List[ChoiceDeltaToolCall]]
+
+    OpenAI ChoiceDeltaToolCall:
+        index: int
+        id: Optional[str]
+        function: Optional[ChoiceDeltaToolCallFunction]
+        type: Optional[Literal["function"]]
+
+    OpenAI ChoiceDeltaToolCallFunction:
+        name: Optional[str]
+        arguments: Optional[str]
+
+    ->
+
+    ChatCompletionResponseStreamChunk:
+        event: ChatCompletionResponseEvent
+
+    ChatCompletionResponseEvent:
+        event_type: ChatCompletionResponseEventType
+        delta: Union[str, ToolCallDelta]
+        logprobs: Optional[List[TokenLogProbs]]
+        stop_reason: Optional[StopReason]
+
+    ChatCompletionResponseEventType:
+        start = "start"
+        progress = "progress"
+        complete = "complete"
+
+    ToolCallDelta:
+        content: Union[str, ToolCall]
+        parse_status: ToolCallParseStatus
+
+    ToolCall:
+        call_id: str
+        tool_name: str
+        arguments: str
+
+    ToolCallParseStatus:
+        started = "started"
+        in_progress = "in_progress"
+        failure = "failure"
+        success = "success"
+
+    TokenLogProbs:
+        logprobs_by_token: Dict[str, float]
+         - token, logprob
+
+    StopReason:
+        end_of_turn = "end_of_turn"
+        end_of_message = "end_of_message"
+        out_of_tokens = "out_of_tokens"
+    """
+
+    # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
+    def _event_type_generator() -> (
+        Generator[ChatCompletionResponseEventType, None, None]
+    ):
+        yield ChatCompletionResponseEventType.start
+        while True:
+            yield ChatCompletionResponseEventType.progress
+
+    event_type = _event_type_generator()
+
+    # we implement NIM specific semantics, the main difference from OpenAI
+    # is that tool_calls are always produced as a complete call. there is no
+    # intermediate / partial tool call streamed. because of this, we can
+    # simplify the logic and not concern ourselves with parse_status of
+    # started/in_progress/failed. we can always assume success.
+    #
+    # a stream of ChatCompletionResponseStreamChunk consists of
+    #  0. a start event
+    #  1. zero or more progress events
+    #   - each progress event has a delta
+    #   - each progress event may have a stop_reason
+    #   - each progress event may have logprobs
+    #   - each progress event may have tool_calls
+    #     if a progress event has tool_calls,
+    #      it is fully formed and
+    #      can be emitted with a parse_status of success
+    #  2. a complete event
+
+    stop_reason = None
+
+    async for chunk in stream:
+        choice = chunk.choices[0]  # assuming only one choice per chunk
+
+        # we assume there's only one finish_reason in the stream
+        stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
+
+        # if there's a tool call, emit an event for each tool in the list
+        # if tool call and content, emit both separately
+
+        if choice.delta.tool_calls:
+            # the call may have content and a tool call.
+            # ChatCompletionResponseEvent does not support both, so we emit the content first
+            if choice.delta.content:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=next(event_type),
+                        delta=choice.delta.content,
+                        logprobs=_convert_openai_logprobs(choice.logprobs),
+                    )
+                )
+
+            # it is possible to have parallel tool calls in stream, but
+            # ChatCompletionResponseEvent only supports one per stream
+            if len(choice.delta.tool_calls) > 1:
+                warnings.warn(
+                    "multiple tool calls found in a single delta, using the first, ignoring the rest"
+                )
+
+            # NIM only produces fully formed tool calls, so we can assume success
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=next(event_type),
+                    delta=ToolCallDelta(
+                        content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
+                        parse_status=ToolCallParseStatus.success,
+                    ),
+                    logprobs=_convert_openai_logprobs(choice.logprobs),
+                )
+            )
+        else:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=next(event_type),
+                    delta=choice.delta.content or "",  # content is not optional
+                    logprobs=_convert_openai_logprobs(choice.logprobs),
+                )
+            )
+
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.complete,
+            delta="",
+            stop_reason=stop_reason,
+        )
+    )
diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py
new file mode 100644
index 000000000..0ec80e9dd
--- /dev/null
+++ b/llama_stack/providers/remote/inference/nvidia/utils.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Tuple
+
+import httpx
+
+from . import NVIDIAConfig
+
+
+def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
+    return "integrate.api.nvidia.com" in config.url
+
+
+async def _get_health(url: str) -> Tuple[bool, bool]:
+    """
+    Query {url}/v1/health/{live,ready} to check if the server is running and ready
+
+    Args:
+        url (str): URL of the server
+
+    Returns:
+        Tuple[bool, bool]: (is_live, is_ready)
+    """
+    async with httpx.AsyncClient() as client:
+        live = await client.get(f"{url}/v1/health/live")
+        ready = await client.get(f"{url}/v1/health/ready")
+        return live.status_code == 200, ready.status_code == 200
+
+
+async def check_health(config: NVIDIAConfig) -> None:
+    """
+    Check if the server is running and ready
+
+    Args:
+        config (NVIDIAConfig): The configuration for the NVIDIA NIM
+
+    Raises:
+        ConnectionError: If the server is not running or ready
+    """
+    if not _is_nvidia_hosted(config):
+        print("Checking NVIDIA NIM health...")
+        try:
+            is_live, is_ready = await _get_health(config.url)
+            if not is_live:
+                raise ConnectionError("NVIDIA NIM is not running")
+            if not is_ready:
+                raise ConnectionError("NVIDIA NIM is not ready")
+            # TODO(mf): should we wait for the server to be ready?
+        except httpx.ConnectError as e:
+            raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py
index d013d6a9e..7fe19b403 100644
--- a/llama_stack/providers/tests/inference/conftest.py
+++ b/llama_stack/providers/tests/inference/conftest.py
@@ -6,6 +6,8 @@
 
 import pytest
 
+from ..conftest import get_provider_fixture_overrides
+
 from .fixtures import INFERENCE_FIXTURES
 
 
@@ -67,11 +69,12 @@ def pytest_generate_tests(metafunc):
             indirect=True,
         )
     if "inference_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "inference_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in INFERENCE_FIXTURES
-            ],
-            indirect=True,
-        )
+        fixtures = INFERENCE_FIXTURES
+        if filtered_stacks := get_provider_fixture_overrides(
+            metafunc.config,
+            {
+                "inference": INFERENCE_FIXTURES,
+            },
+        ):
+            fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
+        metafunc.parametrize("inference_stack", fixtures, indirect=True)
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index a53ddf639..2007818e5 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -18,6 +18,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
 
 from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
+from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
@@ -142,6 +143,19 @@ def inference_bedrock() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_nvidia() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="nvidia",
+                provider_type="remote::nvidia",
+                config=NVIDIAConfig().model_dump(),
+            )
+        ],
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
@@ -175,6 +189,7 @@ INFERENCE_FIXTURES = [
     "vllm_remote",
     "remote",
     "bedrock",
+    "nvidia",
 ]
 
 
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 1a7f1870c..f0f1d0eb2 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -198,6 +198,7 @@ class TestInference:
             "remote::fireworks",
             "remote::tgi",
             "remote::together",
+            "remote::nvidia",
         ):
             pytest.skip("Other inference providers don't support structured output yet")
 
@@ -361,7 +362,10 @@ class TestInference:
             for chunk in grouped[ChatCompletionResponseEventType.progress]
         )
         first = grouped[ChatCompletionResponseEventType.progress][0]
-        assert first.event.delta.parse_status == ToolCallParseStatus.started
+        if not isinstance(
+            first.event.delta.content, ToolCall
+        ):  # first chunk may contain entire call
+            assert first.event.delta.parse_status == ToolCallParseStatus.started
 
         last = grouped[ChatCompletionResponseEventType.progress][-1]
         # assert last.event.stop_reason == expected_stop_reason
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 07225fac0..8dbfab14a 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -29,7 +29,6 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
     return ModelAlias(
         provider_model_id=provider_model_id,
         aliases=[
-            model_descriptor,
             get_huggingface_repo(model_descriptor),
         ],
         llama_model=model_descriptor,
@@ -57,6 +56,10 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
                 alias_obj.provider_model_id
            )
+            # ensure we can go from llama model to provider model id
+            self.alias_to_provider_id_map[alias_obj.llama_model] = (
+                alias_obj.provider_model_id
+            )
             self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
                 alias_obj.llama_model
             )
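
The new adapter is registered as `remote::nvidia` and configured by `NVIDIAConfig`, which defaults to the hosted endpoint and reads `NVIDIA_API_KEY` from the environment. The following is a minimal sketch (not part of the patch) of exercising the adapter directly via `get_adapter_impl`; it assumes `NVIDIA_API_KEY` is exported and that the `meta/llama-3.1-8b-instruct` alias from `_MODEL_ALIASES` is served by the target endpoint.

```python
# Hypothetical usage sketch for the NVIDIA inference adapter added in this diff.
import asyncio

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig, get_adapter_impl


async def main() -> None:
    # Defaults to https://integrate.api.nvidia.com with api_key from NVIDIA_API_KEY;
    # point `url` at a self-hosted NIM (e.g. http://localhost:8000) to skip the key.
    config = NVIDIAConfig()
    adapter = await get_adapter_impl(config, _deps=None)

    # Non-streaming chat completion; the adapter maps the model id through its
    # alias table and converts the request to the OpenAI client format.
    response = await adapter.chat_completion(
        model_id="meta/llama-3.1-8b-instruct",
        messages=[UserMessage(content="Hello, who are you?")],
        stream=False,
    )
    print(response.completion_message.content)


if __name__ == "__main__":
    asyncio.run(main())
```

In a deployed stack the adapter would normally be constructed by the provider registry rather than called this way; the sketch only illustrates the config-to-client flow introduced by the patch.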