mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00

localize media for meta_reference inference

This commit is contained in:
  parent d543eb442b
  commit 2da16bf852

5 changed files with 80 additions and 41 deletions
@@ -84,14 +84,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_completion(
         self, request: CompletionRequest, client: Fireworks
     ) -> CompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         r = await client.completion.acreate(**params)
         return process_completion_response(r, self.formatter)

     async def _stream_completion(
         self, request: CompletionRequest, client: Fireworks
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         stream = client.completion.acreate(**params)
         async for chunk in process_completion_stream_response(stream, self.formatter):
@@ -130,7 +130,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         if "messages" in params:
             r = await client.chat.completions.acreate(**params)
         else:
@@ -140,7 +140,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         if "messages" in params:
             stream = client.chat.completions.acreate(**params)
@@ -152,7 +152,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         ):
             yield chunk

-    def _get_params(
+    async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
         input_dict = {}
@@ -161,7 +161,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 input_dict["messages"] = [
-                    convert_message_to_dict(m) for m in request.messages
+                    await convert_message_to_dict(m) for m in request.messages
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(

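For orientation (an illustration, not part of the diff): with `_get_params` now a coroutine, a media-bearing chat request keeps structured messages for the chat-completions path, while a text-only request is flattened into a prompt string; that is the distinction the `if "messages" in params` branches above test. A rough sketch of the two shapes, with made-up values:

# Illustrative only -- the shapes follow the branches above, all values are invented.
params_with_media = {
    "model": "fireworks-model-id",  # hypothetical model identifier
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
            ],
        }
    ],
}

params_text_only = {
    "model": "fireworks-model-id",
    "prompt": "...",  # produced by chat_completion_request_to_prompt
}
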
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import base64
-import io
 from typing import AsyncGenerator

 import httpx
@@ -31,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_image_media_to_url,
     request_has_media,
 )

@@ -282,7 +281,11 @@ async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
         if isinstance(content, ImageMedia):
             return {
                 "role": message.role,
-                "images": [await convert_image_media_to_base64(content)],
+                "images": [
+                    await convert_image_media_to_url(
+                        content, download=True, include_format=False
+                    )
+                ],
             }
         else:
             return {
@@ -294,18 +297,3 @@ async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
         return [await _convert_content(c) for c in message.content]
     else:
         return [await _convert_content(message.content)]
-
-
-async def convert_image_media_to_base64(media: ImageMedia) -> str:
-    if isinstance(media.image, PIL_Image.Image):
-        bytestream = io.BytesIO()
-        media.image.save(bytestream, format=media.image.format)
-        bytestream.seek(0)
-        content = bytestream.getvalue()
-    else:
-        assert isinstance(media.image, URL)
-        async with httpx.AsyncClient() as client:
-            r = await client.get(media.image.uri)
-            content = r.content
-
-    return base64.b64encode(content).decode("utf-8")

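Side note (an illustration, not part of the diff): by calling the shared helper with download=True and include_format=False, the Ollama converter still ends up with the bare base64 payload its API expects in the "images" field, roughly:

# Rough sketch of the dict returned for an image content item; the base64 string is
# truncated and invented. There is no "data:image/..." prefix because include_format=False.
ollama_image_message = {
    "role": "user",
    "images": ["iVBORw0KGgoAAAANSUhEUg..."],
}
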
@@ -99,12 +99,12 @@ class TogetherInferenceAdapter(
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
         r = self._get_client().completions.create(**params)
         return process_completion_response(r, self.formatter)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
@@ -165,7 +165,7 @@ class TogetherInferenceAdapter(
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         if "messages" in params:
             r = self._get_client().chat.completions.create(**params)
         else:
@@ -175,7 +175,7 @@ class TogetherInferenceAdapter(
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
@@ -192,7 +192,7 @@ class TogetherInferenceAdapter(
         ):
             yield chunk

-    def _get_params(
+    async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
         input_dict = {}
@@ -200,7 +200,7 @@ class TogetherInferenceAdapter(
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 input_dict["messages"] = [
-                    convert_message_to_dict(m) for m in request.messages
+                    await convert_message_to_dict(m) for m in request.messages
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(

@@ -14,6 +14,11 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate

+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_media_to_url,
+    request_has_media,
+)
+
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
@@ -388,3 +393,30 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+
+async def request_with_localized_media(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+) -> Union[ChatCompletionRequest, CompletionRequest]:
+    if not request_has_media(request):
+        return request
+
+    async def _convert_single_content(content):
+        if isinstance(content, ImageMedia):
+            return await convert_image_media_to_url(content, download=True)
+        else:
+            return content
+
+    async def _convert_content(content):
+        if isinstance(content, list):
+            return [await _convert_single_content(c) for c in content]
+        else:
+            return await _convert_single_content(content)
+
+    if isinstance(request, ChatCompletionRequest):
+        for m in request.messages:
+            m.content = await _convert_content(m.content)
+    else:
+        request.content = await _convert_content(request.content)
+
+    return request

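A hypothetical usage sketch (the call site is outside this diff): the meta-reference implementation would localize media before turning the request into model input, so generation never needs network access. The model name, URL, and message text below are invented, and the snippet assumes the datatypes imported by the module above; only request_with_localized_media comes from the change itself.

# Hypothetical example -- illustrative values only.
async def _localize_example() -> ChatCompletionRequest:
    request = ChatCompletionRequest(
        model="Llama3.2-11B-Vision-Instruct",  # illustrative model name
        messages=[
            UserMessage(
                content=[
                    ImageMedia(image=URL(uri="https://example.com/cat.png")),
                    "Describe this image.",
                ]
            )
        ],
    )
    # Each URL-backed ImageMedia is downloaded and replaced by a
    # "data:image/...;base64,..." string; plain text passes through unchanged.
    return await request_with_localized_media(request)
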
@@ -9,6 +9,8 @@ import io
 import json
 from typing import Tuple

+import httpx
+
 from llama_models.llama3.api.chat_format import ChatFormat
 from PIL import Image as PIL_Image
 from termcolor import cprint
|
@ -49,7 +51,9 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
|
||||||
return content_has_media(request.content)
|
return content_has_media(request.content)
|
||||||
|
|
||||||
|
|
||||||
def convert_image_media_to_url(media: ImageMedia) -> str:
|
async def convert_image_media_to_url(
|
||||||
|
media: ImageMedia, download: bool = False, include_format: bool = True
|
||||||
|
) -> str:
|
||||||
if isinstance(media.image, PIL_Image.Image):
|
if isinstance(media.image, PIL_Image.Image):
|
||||||
if media.image.format == "PNG":
|
if media.image.format == "PNG":
|
||||||
format = "png"
|
format = "png"
|
||||||
|
@@ -63,21 +67,36 @@ def convert_image_media_to_url(media: ImageMedia) -> str:
         bytestream = io.BytesIO()
         media.image.save(bytestream, format=media.image.format)
         bytestream.seek(0)
-        return f"data:image/{format};base64," + base64.b64encode(
-            bytestream.getvalue()
-        ).decode("utf-8")
+        content = bytestream.getvalue()
     else:
-        assert isinstance(media.image, URL)
-        return media.image.uri
+        if not download:
+            return media.image.uri
+        else:
+            assert isinstance(media.image, URL)
+            async with httpx.AsyncClient() as client:
+                r = await client.get(media.image.uri)
+                content = r.content
+                content_type = r.headers.get("content-type")
+                if content_type:
+                    format = content_type.split("/")[-1]
+                else:
+                    format = "png"
+
+    if include_format:
+        return f"data:image/{format};base64," + base64.b64encode(content).decode(
+            "utf-8"
+        )
+    else:
+        return base64.b64encode(content).decode("utf-8")


-def convert_message_to_dict(message: Message) -> dict:
-    def _convert_content(content) -> dict:
+async def convert_message_to_dict(message: Message) -> dict:
+    async def _convert_content(content) -> dict:
         if isinstance(content, ImageMedia):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": convert_image_media_to_url(content),
+                    "url": await convert_image_media_to_url(content),
                 },
             }
         else:
@@ -85,9 +104,9 @@ def convert_message_to_dict(message: Message) -> dict:
             return {"type": "text", "text": content}

     if isinstance(message.content, list):
-        content = [_convert_content(c) for c in message.content]
+        content = [await _convert_content(c) for c in message.content]
     else:
-        content = [_convert_content(message.content)]
+        content = [await _convert_content(message.content)]

     return {
         "role": message.role,
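To summarize the reworked helper (a sketch, not part of the diff): download controls whether URL-backed media is fetched over HTTP, and include_format controls whether the result is a full data URL or a bare base64 payload. Assuming an ImageMedia value named media:

from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_image_media_to_url,
)


async def _demo(media) -> None:
    # URL-backed media with no download: the original URI is returned unchanged.
    print(await convert_image_media_to_url(media))
    # download=True: remote images are fetched with httpx and inlined as a data URL
    # ("data:image/<format>;base64,..."); PIL images always come back in this form.
    print(await convert_image_media_to_url(media, download=True))
    # include_format=False: bare base64 payload, as used by the Ollama adapter.
    print(await convert_image_media_to_url(media, download=True, include_format=False))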