From e21c8b6d8079c282568ed6fc4bfcf6a0f963d4db Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 1 Feb 2025 12:02:27 -0500 Subject: [PATCH] add image support to NVIDIA inference provider (#907) # What does this PR do? add support to the NVIDIA Inference provider for image inputs ## Test Plan 1. Run local [Llama 3.2 11b vision instruct](https://build.nvidia.com/meta/llama-3.2-11b-vision-instruct?snippet_tab=Docker) NIM 2. Start a stack, e.g. `llama stack run llama_stack/templates/nvidia/run.yaml --env NVIDIA_BASE_URL=http://localhost:8000` 3. Run image tests, e.g. `LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_inference.py --vision-inference-model meta-llama/Llama-3.2-11B-Vision-Instruct -k image` ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [x] Wrote necessary unit or integration tests. --- .../remote/inference/nvidia/nvidia.py | 2 +- .../remote/inference/nvidia/openai_utils.py | 46 +++++++++++++++++-- .../utils/inference/prompt_adapter.py | 1 + 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 81751e038..1395caf69 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -186,7 +186,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): await check_health(self._config) # this raises errors - request = convert_chat_completion_request( + request = await convert_chat_completion_request( request=ChatCompletionRequest( model=self.get_provider_model_id(model_id), messages=messages, diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index 43be0fc94..40228a4da 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -6,7 +6,7 @@ import json import warnings -from typing import Any, AsyncGenerator, Dict, Generator, List, Optional +from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union from llama_models.datatypes import ( GreedySamplingStrategy, @@ -23,6 +23,8 @@ from openai import AsyncStream from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, ChatCompletionChunk as OpenAIChatCompletionChunk, + ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, ChatCompletionMessageParam as OpenAIChatCompletionMessage, ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, @@ -33,6 +35,9 @@ from openai.types.chat.chat_completion import ( Choice as OpenAIChoice, ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs ) +from openai.types.chat.chat_completion_content_part_image_param import ( + ImageURL as OpenAIImageURL, +) from openai.types.chat.chat_completion_message_tool_call_param import ( Function as OpenAIFunction, ) @@ -40,6 +45,9 @@ from openai.types.completion import Completion as OpenAICompletion from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs from llama_stack.apis.common.content_types import ( + ImageContentItem, + InterleavedContent, + TextContentItem, TextDelta, ToolCallDelta, ToolCallParseStatus, @@ -62,6 +70,10 @@ from llama_stack.apis.inference import ( UserMessage, ) +from llama_stack.providers.utils.inference.prompt_adapter import ( + convert_image_content_to_url, +) + def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: """ @@ -139,7 +151,7 @@ def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: return out -def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage: +async def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage: """ Convert a Message to an OpenAI API-compatible dictionary. """ @@ -159,11 +171,35 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage: else: raise ValueError(f"Unsupported message role: {message['role']}") + # Map Llama Stack spec to OpenAI spec - + # str -> str + # {"type": "text", "text": ...} -> {"type": "text", "text": ...} + # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}} + # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}} + # List[...] -> List[...] + async def _convert_user_message_content( + content: InterleavedContent, + ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]: + # Llama Stack and OpenAI spec match for str and text input + if isinstance(content, str) or isinstance(content, TextContentItem): + return content + elif isinstance(content, ImageContentItem): + return OpenAIChatCompletionContentPartImageParam( + image_url=OpenAIImageURL( + url=await convert_image_content_to_url(content) + ), + type="image_url", + ) + elif isinstance(content, List): + return [await _convert_user_message_content(item) for item in content] + else: + raise ValueError(f"Unsupported content type: {type(content)}") + out: OpenAIChatCompletionMessage = None if isinstance(message, UserMessage): out = OpenAIChatCompletionUserMessage( role="user", - content=message.content, # TODO(mf): handle image content + content=await _convert_user_message_content(message.content), ) elif isinstance(message, CompletionMessage): out = OpenAIChatCompletionAssistantMessage( @@ -198,7 +234,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage: return out -def convert_chat_completion_request( +async def convert_chat_completion_request( request: ChatCompletionRequest, n: int = 1, ) -> dict: @@ -235,7 +271,7 @@ def convert_chat_completion_request( nvext = {} payload: Dict[str, Any] = dict( model=request.model, - messages=[_convert_message(message) for message in request.messages], + messages=[await _convert_message(message) for message in request.messages], stream=request.stream, n=n, extra_body=dict(nvext=nvext), diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index e49771980..89a41e97d 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -186,6 +186,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]: return content, format else: # data is a base64 encoded string, decode it to bytes first + # TODO(mf): do this more efficiently, decode less data_bytes = base64.b64decode(image.data) pil_image = PIL_Image.open(io.BytesIO(data_bytes)) return data_bytes, pil_image.format