Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-06 10:42:39 +00:00
use convert_image_content_to_url
parent cef35bb3db
commit bb2143632b
3 changed files with 18 additions and 32 deletions
@@ -186,7 +186,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         await check_health(self._config)  # this raises errors
 
-        request = convert_chat_completion_request(
+        request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
                 model=self.get_provider_model_id(model_id),
                 messages=messages,
@@ -4,10 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import base64
 import json
 import warnings
-from io import BytesIO
 from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
 
 from llama_models.datatypes import (
@@ -46,8 +44,6 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
 from openai.types.completion import Completion as OpenAICompletion
 from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
-
-from PIL import Image
 
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
@@ -74,6 +70,10 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_content_to_url,
+)
+
 
 def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     """
@@ -151,7 +151,7 @@ def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     return out
 
 
-def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
+async def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     """
     Convert a Message to an OpenAI API-compatible dictionary.
     """
@@ -177,36 +177,21 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
     # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
     # List[...] -> List[...]
-    def _convert_user_message_content(
+    async def _convert_user_message_content(
         content: InterleavedContent,
     ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
         # Llama Stack and OpenAI spec match for str and text input
         if isinstance(content, str) or isinstance(content, TextContentItem):
             return content
         elif isinstance(content, ImageContentItem):
-            if content.image.url:
-                return OpenAIChatCompletionContentPartImageParam(
-                    image_url=OpenAIImageURL(url=content.image.url.uri),
-                    type="image_url",
-                )
-            elif content.image.data:
-                mime_type = Image.MIME[
-                    Image.open(
-                        BytesIO(
-                            base64.b64decode(
-                                content.image.data
-                            )  # TODO(mf): do this more efficiently, decode less
-                        )
-                    ).format
-                ]
-                return OpenAIChatCompletionContentPartImageParam(
-                    image_url=OpenAIImageURL(
-                        url=f"data:{mime_type};base64,{content.image.data}"
-                    ),
-                    type="image_url",
-                )
+            return OpenAIChatCompletionContentPartImageParam(
+                image_url=OpenAIImageURL(
+                    url=await convert_image_content_to_url(content)
+                ),
+                type="image_url",
+            )
         elif isinstance(content, List):
-            return [_convert_user_message_content(item) for item in content]
+            return [await _convert_user_message_content(item) for item in content]
         else:
             raise ValueError(f"Unsupported content type: {type(content)}")
 
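The `await convert_image_content_to_url(content)` call above is what ripples outward through this module: the nested helper, `_convert_message`, and `convert_chat_completion_request` all become coroutines in the surrounding hunks. A minimal, self-contained sketch of that propagation pattern, with illustrative names only (not code from this repo):

import asyncio

async def convert_image(data_b64: str) -> str:
    # Stand-in for convert_image_content_to_url: an await at the leaf...
    await asyncio.sleep(0)
    return f"data:image/png;base64,{data_b64}"

async def convert_message(data_b64: str) -> dict:
    # ...forces each caller to become a coroutine and await the result...
    return {"role": "user", "content": await convert_image(data_b64)}

async def convert_request(images_b64: list) -> dict:
    # ...all the way up to the request builder.
    return {"messages": [await convert_message(d) for d in images_b64]}

print(asyncio.run(convert_request(["iVBORw0KGgo="])))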
@@ -214,7 +199,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     if isinstance(message, UserMessage):
         out = OpenAIChatCompletionUserMessage(
             role="user",
-            content=_convert_user_message_content(message.content),
+            content=await _convert_user_message_content(message.content),
         )
     elif isinstance(message, CompletionMessage):
         out = OpenAIChatCompletionAssistantMessage(
@@ -249,7 +234,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     return out
 
 
-def convert_chat_completion_request(
+async def convert_chat_completion_request(
     request: ChatCompletionRequest,
     n: int = 1,
 ) -> dict:
@@ -286,7 +271,7 @@ def convert_chat_completion_request(
     nvext = {}
     payload: Dict[str, Any] = dict(
         model=request.model,
-        messages=[_convert_message(message) for message in request.messages],
+        messages=[await _convert_message(message) for message in request.messages],
         stream=request.stream,
         n=n,
         extra_body=dict(nvext=nvext),
@@ -186,6 +186,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
         return content, format
     else:
         # data is a base64 encoded string, decode it to bytes first
+        # TODO(mf): do this more efficiently, decode less
         data_bytes = base64.b64decode(image.data)
         pil_image = PIL_Image.open(io.BytesIO(data_bytes))
         return data_bytes, pil_image.format
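The prompt_adapter hunk above shows `localize_image_content` base64-decoding `image.data` and sniffing its format with PIL, which is presumably how `convert_image_content_to_url` can hand the adapter a ready-made `data:` URL (or pass a plain URI through), letting the NVIDIA conversion code drop its own base64/PIL handling. A rough, self-contained sketch of that kind of bytes-to-data-URL step, under that assumption (the helper name here is illustrative, not the actual API):

import base64
import io

from PIL import Image as PIL_Image

def bytes_to_data_url(data: bytes) -> str:
    # Sniff the image format from the raw bytes, then emit an
    # OpenAI-style data: URL, e.g. "data:image/png;base64,iVBORw0KGgo...".
    fmt = PIL_Image.open(io.BytesIO(data)).format.lower()
    encoded = base64.b64encode(data).decode("utf-8")
    return f"data:image/{fmt};base64,{encoded}"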