use convert_image_content_to_url

This commit is contained in:
Matthew Farrellee 2025-01-31 11:18:05 -05:00
parent cef35bb3db
commit bb2143632b
3 changed files with 18 additions and 32 deletions

View file

@ -186,7 +186,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
await check_health(self._config) # this raises errors
request = convert_chat_completion_request(
request = await convert_chat_completion_request(
request=ChatCompletionRequest(
model=self.get_provider_model_id(model_id),
messages=messages,

View file

@ -4,10 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import json
import warnings
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
from llama_models.datatypes import (
@ -46,8 +44,6 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
from PIL import Image
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
@ -74,6 +70,10 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
)
def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
"""
@ -151,7 +151,7 @@ def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
return out
def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
async def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
"""
Convert a Message to an OpenAI API-compatible dictionary.
"""
@ -177,36 +177,21 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
# {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
# {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
# List[...] -> List[...]
def _convert_user_message_content(
async def _convert_user_message_content(
content: InterleavedContent,
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content, str) or isinstance(content, TextContentItem):
return content
elif isinstance(content, ImageContentItem):
if content.image.url:
return OpenAIChatCompletionContentPartImageParam(
image_url=OpenAIImageURL(url=content.image.url.uri),
type="image_url",
)
elif content.image.data:
mime_type = Image.MIME[
Image.open(
BytesIO(
base64.b64decode(
content.image.data
) # TODO(mf): do this more efficiently, decode less
)
).format
]
return OpenAIChatCompletionContentPartImageParam(
image_url=OpenAIImageURL(
url=f"data:{mime_type};base64,{content.image.data}"
),
type="image_url",
)
return OpenAIChatCompletionContentPartImageParam(
image_url=OpenAIImageURL(
url=await convert_image_content_to_url(content)
),
type="image_url",
)
elif isinstance(content, List):
return [_convert_user_message_content(item) for item in content]
return [await _convert_user_message_content(item) for item in content]
else:
raise ValueError(f"Unsupported content type: {type(content)}")
@ -214,7 +199,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
if isinstance(message, UserMessage):
out = OpenAIChatCompletionUserMessage(
role="user",
content=_convert_user_message_content(message.content),
content=await _convert_user_message_content(message.content),
)
elif isinstance(message, CompletionMessage):
out = OpenAIChatCompletionAssistantMessage(
@ -249,7 +234,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
return out
def convert_chat_completion_request(
async def convert_chat_completion_request(
request: ChatCompletionRequest,
n: int = 1,
) -> dict:
@ -286,7 +271,7 @@ def convert_chat_completion_request(
nvext = {}
payload: Dict[str, Any] = dict(
model=request.model,
messages=[_convert_message(message) for message in request.messages],
messages=[await _convert_message(message) for message in request.messages],
stream=request.stream,
n=n,
extra_body=dict(nvext=nvext),

View file

@ -186,6 +186,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
return content, format
else:
# data is a base64 encoded string, decode it to bytes first
# TODO(mf): do this more efficiently, decode less
data_bytes = base64.b64decode(image.data)
pil_image = PIL_Image.open(io.BytesIO(data_bytes))
return data_bytes, pil_image.format