add image support to NVIDIA inference provider (#907)

# What does this PR do?

Add support for image inputs to the NVIDIA Inference provider.


## Test Plan

1. Run a local [Llama 3.2 11B Vision
Instruct](https://build.nvidia.com/meta/llama-3.2-11b-vision-instruct?snippet_tab=Docker)
NIM
2. Start a stack, e.g. `llama stack run
llama_stack/templates/nvidia/run.yaml --env
NVIDIA_BASE_URL=http://localhost:8000`
3. Run the image tests, e.g. `LLAMA_STACK_BASE_URL=http://localhost:8321
pytest -v tests/client-sdk/inference/test_inference.py
--vision-inference-model meta-llama/Llama-3.2-11B-Vision-Instruct -k
image` (a sketch of the request shape these tests exercise follows below)
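
For reference, an image request against the stack looks roughly like the sketch below. This is illustrative only: the `LlamaStackClient` usage and the image URL are assumptions, not copied from the test suite; the message shape follows the Llama Stack content spec that the new converter handles.

```python
from llama_stack_client import LlamaStackClient  # assumed client SDK

client = LlamaStackClient(base_url="http://localhost:8321")

# User message mixing image and text content; the NVIDIA provider converts the
# "image" item into an OpenAI-style "image_url" content part before calling the NIM.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.2-11B-Vision-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image", "image": {"url": {"uri": "https://example.com/dog.png"}}},  # placeholder URL
                {"type": "text", "text": "Describe what you see in this image."},
            ],
        }
    ],
    stream=False,
)
print(response)
```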


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.
Author: Matthew Farrellee (2025-02-01 12:02:27 -05:00, committed by GitHub)
parent 439d0da84c
commit e21c8b6d80
3 changed files with 43 additions and 6 deletions


@@ -186,7 +186,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         await check_health(self._config)  # this raises errors

-        request = convert_chat_completion_request(
+        request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
                 model=self.get_provider_model_id(model_id),
                 messages=messages,


@@ -6,7 +6,7 @@

 import json
 import warnings
-from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
+from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union

 from llama_models.datatypes import (
     GreedySamplingStrategy,
@@ -23,6 +23,8 @@ from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
     ChatCompletionChunk as OpenAIChatCompletionChunk,
+    ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
+    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
     ChatCompletionMessageParam as OpenAIChatCompletionMessage,
     ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
     ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
@@ -33,6 +35,9 @@ from openai.types.chat.chat_completion import (
     Choice as OpenAIChoice,
     ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
 )
+from openai.types.chat.chat_completion_content_part_image_param import (
+    ImageURL as OpenAIImageURL,
+)
 from openai.types.chat.chat_completion_message_tool_call_param import (
     Function as OpenAIFunction,
 )
@@ -40,6 +45,9 @@ from openai.types.completion import Completion as OpenAICompletion
 from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs

 from llama_stack.apis.common.content_types import (
+    ImageContentItem,
+    InterleavedContent,
+    TextContentItem,
     TextDelta,
     ToolCallDelta,
     ToolCallParseStatus,
@@ -62,6 +70,10 @@ from llama_stack.apis.inference import (
     UserMessage,
 )

+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_content_to_url,
+)
+

 def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     """
@@ -139,7 +151,7 @@ def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     return out


-def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
+async def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     """
     Convert a Message to an OpenAI API-compatible dictionary.
     """
@@ -159,11 +171,35 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
         else:
             raise ValueError(f"Unsupported message role: {message['role']}")

+    # Map Llama Stack spec to OpenAI spec -
+    #  str -> str
+    #  {"type": "text", "text": ...} -> {"type": "text", "text": ...}
+    #  {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
+    #  {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
+    #  List[...] -> List[...]
+    async def _convert_user_message_content(
+        content: InterleavedContent,
+    ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
+        # Llama Stack and OpenAI spec match for str and text input
+        if isinstance(content, str) or isinstance(content, TextContentItem):
+            return content
+        elif isinstance(content, ImageContentItem):
+            return OpenAIChatCompletionContentPartImageParam(
+                image_url=OpenAIImageURL(
+                    url=await convert_image_content_to_url(content)
+                ),
+                type="image_url",
+            )
+        elif isinstance(content, List):
+            return [await _convert_user_message_content(item) for item in content]
+        else:
+            raise ValueError(f"Unsupported content type: {type(content)}")
+
     out: OpenAIChatCompletionMessage = None
     if isinstance(message, UserMessage):
         out = OpenAIChatCompletionUserMessage(
             role="user",
-            content=message.content,  # TODO(mf): handle image content
+            content=await _convert_user_message_content(message.content),
         )
     elif isinstance(message, CompletionMessage):
         out = OpenAIChatCompletionAssistantMessage(
@@ -198,7 +234,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     return out


-def convert_chat_completion_request(
+async def convert_chat_completion_request(
     request: ChatCompletionRequest,
     n: int = 1,
 ) -> dict:
@@ -235,7 +271,7 @@ def convert_chat_completion_request(
     nvext = {}
     payload: Dict[str, Any] = dict(
         model=request.model,
-        messages=[_convert_message(message) for message in request.messages],
+        messages=[await _convert_message(message) for message in request.messages],
         stream=request.stream,
         n=n,
         extra_body=dict(nvext=nvext),
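
For orientation, the mapping that `_convert_user_message_content` implements turns Llama Stack content items into OpenAI-style content parts roughly as in the hand-written before/after sketch below (the image URL is hypothetical). Plain strings and text items pass through unchanged; image items carrying base64 data are instead rewritten to a `data:` URL.

```python
# Llama Stack user message content (input side)
llama_stack_content = [
    {"type": "text", "text": "What is in this image?"},
    {"type": "image", "image": {"url": {"uri": "https://example.com/cat.png"}}},  # hypothetical URL
]

# Equivalent OpenAI chat-completions content parts (output side)
openai_content = [
    {"type": "text", "text": "What is in this image?"},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
]
```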


@@ -186,6 +186,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
         return content, format
     else:
         # data is a base64 encoded string, decode it to bytes first
+        # TODO(mf): do this more efficiently, decode less
         data_bytes = base64.b64decode(image.data)
         pil_image = PIL_Image.open(io.BytesIO(data_bytes))
         return data_bytes, pil_image.format
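
For context on the `data:image/?;base64,...` URLs mentioned in the converter comment: when image content carries raw data rather than a URL, `convert_image_content_to_url` ends up producing a data URL of roughly the shape shown below. The helper here is a standalone sketch for illustration, not the provider's actual implementation.

```python
import base64


def to_data_url(image_bytes: bytes, image_format: str) -> str:
    """Build a data URL like data:image/png;base64,<payload> from raw image bytes."""
    payload = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:image/{image_format.lower()};base64,{payload}"


# Example: to_data_url(png_bytes, "PNG") -> "data:image/png;base64,iVBORw0KGgo..."
```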