localize media for meta_reference inference

This commit is contained in:
Ashwin Bharambe 2024-11-05 15:35:10 -08:00
parent d543eb442b
commit 2da16bf852
5 changed files with 80 additions and 41 deletions

View file

@ -84,14 +84,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
async def _nonstream_completion(
self, request: CompletionRequest, client: Fireworks
) -> CompletionResponse:
params = self._get_params(request)
params = await self._get_params(request)
r = await client.completion.acreate(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(
self, request: CompletionRequest, client: Fireworks
) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)
stream = client.completion.acreate(**params)
async for chunk in process_completion_stream_response(stream, self.formatter):
@ -130,7 +130,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
) -> ChatCompletionResponse:
params = self._get_params(request)
params = await self._get_params(request)
if "messages" in params:
r = await client.chat.completions.acreate(**params)
else:
@ -140,7 +140,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)
if "messages" in params:
stream = client.chat.completions.acreate(**params)
@ -152,7 +152,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
):
yield chunk
def _get_params(
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
input_dict = {}
@ -161,7 +161,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
convert_message_to_dict(m) for m in request.messages
await convert_message_to_dict(m) for m in request.messages
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import io
from typing import AsyncGenerator
import httpx
@ -31,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
convert_image_media_to_url,
request_has_media,
)
@ -282,7 +281,11 @@ async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
if isinstance(content, ImageMedia):
return {
"role": message.role,
"images": [await convert_image_media_to_base64(content)],
"images": [
await convert_image_media_to_url(
content, download=True, include_format=False
)
],
}
else:
return {
@ -294,18 +297,3 @@ async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
return [await _convert_content(c) for c in message.content]
else:
return [await _convert_content(message.content)]
async def convert_image_media_to_base64(media: ImageMedia) -> str:
if isinstance(media.image, PIL_Image.Image):
bytestream = io.BytesIO()
media.image.save(bytestream, format=media.image.format)
bytestream.seek(0)
content = bytestream.getvalue()
else:
assert isinstance(media.image, URL)
async with httpx.AsyncClient() as client:
r = await client.get(media.image.uri)
content = r.content
return base64.b64encode(content).decode("utf-8")

View file

@ -99,12 +99,12 @@ class TogetherInferenceAdapter(
async def _nonstream_completion(
self, request: CompletionRequest
) -> ChatCompletionResponse:
params = self._get_params_for_completion(request)
params = await self._get_params(request)
r = self._get_client().completions.create(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
@ -165,7 +165,7 @@ class TogetherInferenceAdapter(
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
params = await self._get_params(request)
if "messages" in params:
r = self._get_client().chat.completions.create(**params)
else:
@ -175,7 +175,7 @@ class TogetherInferenceAdapter(
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
@ -192,7 +192,7 @@ class TogetherInferenceAdapter(
):
yield chunk
def _get_params(
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
input_dict = {}
@ -200,7 +200,7 @@ class TogetherInferenceAdapter(
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
convert_message_to_dict(m) for m in request.messages
await convert_message_to_dict(m) for m in request.messages
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(

View file

@ -14,6 +14,11 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
@ -388,3 +393,30 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def request_with_localized_media(
request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequest, CompletionRequest]:
if not request_has_media(request):
return request
async def _convert_single_content(content):
if isinstance(content, ImageMedia):
return await convert_image_media_to_url(content, download=True)
else:
return content
async def _convert_content(content):
if isinstance(content, list):
return [await _convert_single_content(c) for c in content]
else:
return await _convert_single_content(content)
if isinstance(request, ChatCompletionRequest):
for m in request.messages:
m.content = await _convert_content(m.content)
else:
request.content = await _convert_content(request.content)
return request

View file

@ -9,6 +9,8 @@ import io
import json
from typing import Tuple
import httpx
from llama_models.llama3.api.chat_format import ChatFormat
from PIL import Image as PIL_Image
from termcolor import cprint
@ -49,7 +51,9 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
return content_has_media(request.content)
def convert_image_media_to_url(media: ImageMedia) -> str:
async def convert_image_media_to_url(
media: ImageMedia, download: bool = False, include_format: bool = True
) -> str:
if isinstance(media.image, PIL_Image.Image):
if media.image.format == "PNG":
format = "png"
@ -63,21 +67,36 @@ def convert_image_media_to_url(media: ImageMedia) -> str:
bytestream = io.BytesIO()
media.image.save(bytestream, format=media.image.format)
bytestream.seek(0)
return f"data:image/{format};base64," + base64.b64encode(
bytestream.getvalue()
).decode("utf-8")
content = bytestream.getvalue()
else:
assert isinstance(media.image, URL)
return media.image.uri
if not download:
return media.image.uri
else:
assert isinstance(media.image, URL)
async with httpx.AsyncClient() as client:
r = await client.get(media.image.uri)
content = r.content
content_type = r.headers.get("content-type")
if content_type:
format = content_type.split("/")[-1]
else:
format = "png"
if include_format:
return f"data:image/{format};base64," + base64.b64encode(content).decode(
"utf-8"
)
else:
return base64.b64encode(content).decode("utf-8")
def convert_message_to_dict(message: Message) -> dict:
def _convert_content(content) -> dict:
async def convert_message_to_dict(message: Message) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageMedia):
return {
"type": "image_url",
"image_url": {
"url": convert_image_media_to_url(content),
"url": await convert_image_media_to_url(content),
},
}
else:
@ -85,9 +104,9 @@ def convert_message_to_dict(message: Message) -> dict:
return {"type": "text", "text": content}
if isinstance(message.content, list):
content = [_convert_content(c) for c in message.content]
content = [await _convert_content(c) for c in message.content]
else:
content = [_convert_content(message.content)]
content = [await _convert_content(message.content)]
return {
"role": message.role,