forked from phoenix-oss/llama-stack-mirror
Enable vision models for (Together, Fireworks, Meta-Reference, Ollama) (#376)
* Enable vision models for Together and Fireworks * Works with ollama 0.4.0 pre-release with the vision model * localize media for meta_reference inference * Fix
This commit is contained in:
parent
db30809141
commit
cde9bc1388
11 changed files with 465 additions and 81 deletions
|
@ -3,10 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from typing import Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from PIL import Image as PIL_Image
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
@ -24,6 +30,90 @@ from llama_models.sku_list import resolve_model
|
|||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
def content_has_media(content: InterleavedTextMedia):
|
||||
def _has_media_content(c):
|
||||
return isinstance(c, ImageMedia)
|
||||
|
||||
if isinstance(content, list):
|
||||
return any(_has_media_content(c) for c in content)
|
||||
else:
|
||||
return _has_media_content(content)
|
||||
|
||||
|
||||
def messages_have_media(messages: List[Message]):
|
||||
return any(content_has_media(m.content) for m in messages)
|
||||
|
||||
|
||||
def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
return messages_have_media(request.messages)
|
||||
else:
|
||||
return content_has_media(request.content)
|
||||
|
||||
|
||||
async def convert_image_media_to_url(
|
||||
media: ImageMedia, download: bool = False, include_format: bool = True
|
||||
) -> str:
|
||||
if isinstance(media.image, PIL_Image.Image):
|
||||
if media.image.format == "PNG":
|
||||
format = "png"
|
||||
elif media.image.format == "GIF":
|
||||
format = "gif"
|
||||
elif media.image.format == "JPEG":
|
||||
format = "jpeg"
|
||||
else:
|
||||
raise ValueError(f"Unsupported image format {media.image.format}")
|
||||
|
||||
bytestream = io.BytesIO()
|
||||
media.image.save(bytestream, format=media.image.format)
|
||||
bytestream.seek(0)
|
||||
content = bytestream.getvalue()
|
||||
else:
|
||||
if not download:
|
||||
return media.image.uri
|
||||
else:
|
||||
assert isinstance(media.image, URL)
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(media.image.uri)
|
||||
content = r.content
|
||||
content_type = r.headers.get("content-type")
|
||||
if content_type:
|
||||
format = content_type.split("/")[-1]
|
||||
else:
|
||||
format = "png"
|
||||
|
||||
if include_format:
|
||||
return f"data:image/{format};base64," + base64.b64encode(content).decode(
|
||||
"utf-8"
|
||||
)
|
||||
else:
|
||||
return base64.b64encode(content).decode("utf-8")
|
||||
|
||||
|
||||
async def convert_message_to_dict(message: Message) -> dict:
|
||||
async def _convert_content(content) -> dict:
|
||||
if isinstance(content, ImageMedia):
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": await convert_image_media_to_url(content),
|
||||
},
|
||||
}
|
||||
else:
|
||||
assert isinstance(content, str)
|
||||
return {"type": "text", "text": content}
|
||||
|
||||
if isinstance(message.content, list):
|
||||
content = [await _convert_content(c) for c in message.content]
|
||||
else:
|
||||
content = [await _convert_content(message.content)]
|
||||
|
||||
return {
|
||||
"role": message.role,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
|
||||
def completion_request_to_prompt(
|
||||
request: CompletionRequest, formatter: ChatFormat
|
||||
) -> str:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue