add image support to NVIDIA inference provider

Matthew Farrellee 2025-01-30 14:22:41 -05:00
parent 7fe2592795
commit bcd14cc2d3


@@ -6,7 +6,7 @@
 
 import json
 import warnings
-from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
+from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
 
 from llama_models.datatypes import (
     GreedySamplingStrategy,
@@ -23,6 +23,8 @@ from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
     ChatCompletionChunk as OpenAIChatCompletionChunk,
+    ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
+    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
     ChatCompletionMessageParam as OpenAIChatCompletionMessage,
     ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
     ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
@@ -33,6 +35,9 @@ from openai.types.chat.chat_completion import (
     Choice as OpenAIChoice,
     ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
 )
+from openai.types.chat.chat_completion_content_part_image_param import (
+    ImageURL as OpenAIImageURL,
+)
 from openai.types.chat.chat_completion_message_tool_call_param import (
     Function as OpenAIFunction,
 )
@@ -40,6 +45,9 @@ from openai.types.completion import Completion as OpenAICompletion
 from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
 
 from llama_stack.apis.common.content_types import (
+    ImageContentItem,
+    InterleavedContent,
+    TextContentItem,
     TextDelta,
     ToolCallDelta,
     ToolCallParseStatus,
@@ -159,11 +167,41 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
         else:
             raise ValueError(f"Unsupported message role: {message['role']}")
 
+    # Map Llama Stack spec to OpenAI spec -
+    #  str -> str
+    #  {"type": "text", "text": ...} -> {"type": "text", "text": ...}
+    #  {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
+    #  {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
+    #  List[...] -> List[...]
+    def _convert_user_message_content(
+        content: InterleavedContent,
+    ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
+        # Llama Stack and OpenAI spec match for str and text input
+        if isinstance(content, str) or isinstance(content, TextContentItem):
+            return content
+        elif isinstance(content, ImageContentItem):
+            if content.image.url:
+                return OpenAIChatCompletionContentPartImageParam(
+                    image_url=OpenAIImageURL(url=content.image.url.uri),
+                    type="image_url",
+                )
+            elif content.image.data:
+                return OpenAIChatCompletionContentPartImageParam(
+                    image_url=OpenAIImageURL(
+                        url=f"data:image/png;base64,{content.image.data.decode()}"  # TODO(mf): how do we know the type?
+                    ),
+                    type="image_url",
+                )
+        elif isinstance(content, List):
+            return [_convert_user_message_content(item) for item in content]
+        else:
+            raise ValueError(f"Unsupported content type: {type(content)}")
+
     out: OpenAIChatCompletionMessage = None
     if isinstance(message, UserMessage):
         out = OpenAIChatCompletionUserMessage(
             role="user",
-            content=message.content,  # TODO(mf): handle image content
+            content=_convert_user_message_content(message.content),
         )
     elif isinstance(message, CompletionMessage):
         out = OpenAIChatCompletionAssistantMessage(