use convert_image_content_to_url

Matthew Farrellee 2025-01-31 11:18:05 -05:00
parent cef35bb3db
commit bb2143632b
3 changed files with 18 additions and 32 deletions

View file

@@ -186,7 +186,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         await check_health(self._config)  # this raises errors

-        request = convert_chat_completion_request(
+        request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
                 model=self.get_provider_model_id(model_id),
                 messages=messages,
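
The adapter-side change is only the added await: convert_chat_completion_request is now a coroutine (see the next file), so calling it without awaiting would hand the client a coroutine object instead of the payload dict. A minimal, self-contained sketch of that failure mode; build_request is an illustrative stand-in, not code from this commit:

    import asyncio


    async def build_request() -> dict:
        # stand-in for the now-async convert_chat_completion_request
        return {"model": "example", "messages": []}


    async def caller() -> None:
        not_awaited = build_request()    # a coroutine object, not a dict
        payload = await build_request()  # the actual payload
        print(type(not_awaited).__name__, type(payload).__name__)  # coroutine dict
        not_awaited.close()              # silence the "never awaited" warning


    asyncio.run(caller())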

View file

@@ -4,10 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
 import json
 import warnings
-from io import BytesIO
 from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union

 from llama_models.datatypes import (
@@ -46,8 +44,6 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
 from openai.types.completion import Completion as OpenAICompletion
 from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
-from PIL import Image

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
@@ -74,6 +70,10 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_content_to_url,
+)


 def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     """
@@ -151,7 +151,7 @@ def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     return out


-def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
+async def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     """
     Convert a Message to an OpenAI API-compatible dictionary.
     """
@@ -177,36 +177,21 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
     # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
     # List[...] -> List[...]
-    def _convert_user_message_content(
+    async def _convert_user_message_content(
         content: InterleavedContent,
     ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
         # Llama Stack and OpenAI spec match for str and text input
         if isinstance(content, str) or isinstance(content, TextContentItem):
             return content
         elif isinstance(content, ImageContentItem):
-            if content.image.url:
-                return OpenAIChatCompletionContentPartImageParam(
-                    image_url=OpenAIImageURL(url=content.image.url.uri),
-                    type="image_url",
-                )
-            elif content.image.data:
-                mime_type = Image.MIME[
-                    Image.open(
-                        BytesIO(
-                            base64.b64decode(
-                                content.image.data
-                            )  # TODO(mf): do this more efficiently, decode less
-                        )
-                    ).format
-                ]
-                return OpenAIChatCompletionContentPartImageParam(
-                    image_url=OpenAIImageURL(
-                        url=f"data:{mime_type};base64,{content.image.data}"
-                    ),
-                    type="image_url",
-                )
+            return OpenAIChatCompletionContentPartImageParam(
+                image_url=OpenAIImageURL(
+                    url=await convert_image_content_to_url(content)
+                ),
+                type="image_url",
+            )
         elif isinstance(content, List):
-            return [_convert_user_message_content(item) for item in content]
+            return [await _convert_user_message_content(item) for item in content]
         else:
             raise ValueError(f"Unsupported content type: {type(content)}")
@@ -214,7 +199,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     if isinstance(message, UserMessage):
         out = OpenAIChatCompletionUserMessage(
             role="user",
-            content=_convert_user_message_content(message.content),
+            content=await _convert_user_message_content(message.content),
         )
     elif isinstance(message, CompletionMessage):
         out = OpenAIChatCompletionAssistantMessage(
@@ -249,7 +234,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     return out


-def convert_chat_completion_request(
+async def convert_chat_completion_request(
     request: ChatCompletionRequest,
     n: int = 1,
 ) -> dict:
@@ -286,7 +271,7 @@ def convert_chat_completion_request(
     nvext = {}
     payload: Dict[str, Any] = dict(
         model=request.model,
-        messages=[_convert_message(message) for message in request.messages],
+        messages=[await _convert_message(message) for message in request.messages],
         stream=request.stream,
         n=n,
         extra_body=dict(nvext=nvext),
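
The wire format is unchanged by this refactor: every image item in a user message still becomes an OpenAI image_url content part, exactly as the mapping comments at the top of the hunk describe; only the URL construction now goes through the awaited convert_image_content_to_url helper. A rough dict-level sketch of that mapping, with convert_image_content_to_url_stub standing in for the real helper (illustrative only, not the library's implementation):

    import asyncio
    from typing import Any, Dict


    async def convert_image_content_to_url_stub(image: Dict[str, Any]) -> str:
        # Stand-in for the real helper: pass URLs through, turn raw base64
        # data into a data: URL (format sniffing omitted here).
        if "url" in image:
            return image["url"]["uri"]
        return f"data:image/png;base64,{image['data']}"


    async def make_image_part(item: Dict[str, Any]) -> Dict[str, Any]:
        # {"type": "image", "image": {...}} -> {"type": "image_url", "image_url": {"url": ...}}
        return {
            "type": "image_url",
            "image_url": {"url": await convert_image_content_to_url_stub(item["image"])},
        }


    async def main() -> None:
        url_item = {"type": "image", "image": {"url": {"uri": "https://example.com/cat.png"}}}
        data_item = {"type": "image", "image": {"data": "aGVsbG8="}}
        print(await make_image_part(url_item))
        print(await make_image_part(data_item))


    asyncio.run(main())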

View file

@@ -186,6 +186,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
         return content, format
     else:
         # data is a base64 encoded string, decode it to bytes first
+        # TODO(mf): do this more efficiently, decode less
         data_bytes = base64.b64decode(image.data)
         pil_image = PIL_Image.open(io.BytesIO(data_bytes))
         return data_bytes, pil_image.format
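
The TODO added here flags the base64 round trip: the data path decodes the string only so PIL can sniff the image format, and a caller that wants a data: URL (as the new convert_image_content_to_url path in the previous file does) then re-encodes the same bytes. A self-contained sketch of that pattern under those assumptions; localize and to_data_url are illustrative names, not the library's API:

    import base64
    import io

    from PIL import Image as PIL_Image


    def localize(data_b64: str) -> tuple:
        # decode only so PIL can sniff the image format -- this is the
        # round trip the TODO(mf) comment points at
        data_bytes = base64.b64decode(data_b64)
        pil_image = PIL_Image.open(io.BytesIO(data_bytes))
        return data_bytes, pil_image.format.lower()


    def to_data_url(data_b64: str) -> str:
        content, fmt = localize(data_b64)
        # re-encode the bytes that were just decoded
        return f"data:image/{fmt};base64,{base64.b64encode(content).decode()}"


    # build a tiny valid PNG in memory for demonstration
    buf = io.BytesIO()
    PIL_Image.new("RGB", (1, 1), color=(255, 0, 0)).save(buf, format="PNG")
    print(to_data_url(base64.b64encode(buf.getvalue()).decode())[:40], "...")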