diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index 421706650..0070756d8 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -84,14 +84,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_completion(
         self, request: CompletionRequest, client: Fireworks
     ) -> CompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         r = await client.completion.acreate(**params)
         return process_completion_response(r, self.formatter)
 
     async def _stream_completion(
         self, request: CompletionRequest, client: Fireworks
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         stream = client.completion.acreate(**params)
         async for chunk in process_completion_stream_response(stream, self.formatter):
@@ -130,7 +130,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         if "messages" in params:
             r = await client.chat.completions.acreate(**params)
         else:
@@ -140,7 +140,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         if "messages" in params:
             stream = client.chat.completions.acreate(**params)
@@ -152,7 +152,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         ):
             yield chunk
 
-    def _get_params(
+    async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
         input_dict = {}
@@ -161,7 +161,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 input_dict["messages"] = [
-                    convert_message_to_dict(m) for m in request.messages
+                    await convert_message_to_dict(m) for m in request.messages
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index 4c12a967f..3530e1234 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import base64
-import io
 from typing import AsyncGenerator
 
 import httpx
@@ -31,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_image_media_to_url,
     request_has_media,
 )
 
@@ -282,7 +281,11 @@ async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
         if isinstance(content, ImageMedia):
             return {
                 "role": message.role,
-                "images": [await convert_image_media_to_base64(content)],
+                "images": [
+                    await convert_image_media_to_url(
+                        content, download=True, include_format=False
+                    )
+                ],
             }
         else:
             return {
@@ -294,18 +297,3 @@ async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
         return [await _convert_content(c) for c in message.content]
     else:
         return [await _convert_content(message.content)]
-
-
-async def convert_image_media_to_base64(media: ImageMedia) -> str:
-    if isinstance(media.image, PIL_Image.Image):
-        bytestream = io.BytesIO()
-        media.image.save(bytestream, format=media.image.format)
-        bytestream.seek(0)
-        content = bytestream.getvalue()
-    else:
-        assert isinstance(media.image, URL)
-        async with httpx.AsyncClient() as client:
-            r = await client.get(media.image.uri)
-            content = r.content
-
-    return base64.b64encode(content).decode("utf-8")
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 08f5a020f..28a566415 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -99,12 +99,12 @@ class TogetherInferenceAdapter(
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
         r = self._get_client().completions.create(**params)
         return process_completion_response(r, self.formatter)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
@@ -165,7 +165,7 @@ class TogetherInferenceAdapter(
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         if "messages" in params:
             r = self._get_client().chat.completions.create(**params)
         else:
@@ -175,7 +175,7 @@ class TogetherInferenceAdapter(
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
@@ -192,7 +192,7 @@ class TogetherInferenceAdapter(
         ):
             yield chunk
 
-    def _get_params(
+    async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
         input_dict = {}
@@ -200,7 +200,7 @@ class TogetherInferenceAdapter(
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 input_dict["messages"] = [
-                    convert_message_to_dict(m) for m in request.messages
+                    await convert_message_to_dict(m) for m in request.messages
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 5588be6c0..fb72e25ef 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -14,6 +14,11 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
 
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_media_to_url,
+    request_has_media,
+)
+
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
@@ -388,3 +393,30 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+
+async def request_with_localized_media(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+) -> Union[ChatCompletionRequest, CompletionRequest]:
+    if not request_has_media(request):
+        return request
+
+    async def _convert_single_content(content):
+        if isinstance(content, ImageMedia):
+            return await convert_image_media_to_url(content, download=True)
+        else:
+            return content
+
+    async def _convert_content(content):
+        if isinstance(content, list):
+            return [await _convert_single_content(c) for c in content]
+        else:
+            return await _convert_single_content(content)
+
+    if isinstance(request, ChatCompletionRequest):
+        for m in request.messages:
+            m.content = await _convert_content(m.content)
+    else:
+        request.content = await _convert_content(request.content)
+
+    return request
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 96bc3b799..9decf5a00 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -9,6 +9,8 @@ import io
 import json
 from typing import Tuple
 
+import httpx
+
 from llama_models.llama3.api.chat_format import ChatFormat
 from PIL import Image as PIL_Image
 from termcolor import cprint
@@ -49,7 +51,9 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
         return content_has_media(request.content)
 
 
-def convert_image_media_to_url(media: ImageMedia) -> str:
+async def convert_image_media_to_url(
+    media: ImageMedia, download: bool = False, include_format: bool = True
+) -> str:
     if isinstance(media.image, PIL_Image.Image):
         if media.image.format == "PNG":
             format = "png"
@@ -63,21 +67,36 @@ def convert_image_media_to_url(media: ImageMedia) -> str:
         bytestream = io.BytesIO()
         media.image.save(bytestream, format=media.image.format)
         bytestream.seek(0)
-        return f"data:image/{format};base64," + base64.b64encode(
-            bytestream.getvalue()
-        ).decode("utf-8")
+        content = bytestream.getvalue()
     else:
-        assert isinstance(media.image, URL)
-        return media.image.uri
+        if not download:
+            return media.image.uri
+        else:
+            assert isinstance(media.image, URL)
+            async with httpx.AsyncClient() as client:
+                r = await client.get(media.image.uri)
+                content = r.content
+                content_type = r.headers.get("content-type")
+                if content_type:
+                    format = content_type.split("/")[-1]
+                else:
+                    format = "png"
+
+    if include_format:
+        return f"data:image/{format};base64," + base64.b64encode(content).decode(
+            "utf-8"
+        )
+    else:
+        return base64.b64encode(content).decode("utf-8")
 
 
-def convert_message_to_dict(message: Message) -> dict:
-    def _convert_content(content) -> dict:
+async def convert_message_to_dict(message: Message) -> dict:
+    async def _convert_content(content) -> dict:
         if isinstance(content, ImageMedia):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": convert_image_media_to_url(content),
+                    "url": await convert_image_media_to_url(content),
                 },
             }
         else:
@@ -85,9 +104,9 @@ def convert_message_to_dict(message: Message) -> dict:
             return {"type": "text", "text": content}
 
     if isinstance(message.content, list):
-        content = [_convert_content(c) for c in message.content]
+        content = [await _convert_content(c) for c in message.content]
     else:
-        content = [_convert_content(message.content)]
+        content = [await _convert_content(message.content)]
 
     return {
         "role": message.role,
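
Note: for reference, here is a minimal standalone sketch of what the reworked helper does for the URL case with download=True, which is why it (and every _get_params that calls convert_message_to_dict) had to become async: fetch the bytes with httpx, guess the image format from the Content-Type header, and return either a data URL or the raw base64 string. The function name image_url_to_base64 and the example URL are illustrative only; they are not part of the patch.

# Illustrative sketch only -- mirrors the download=True path of the updated
# convert_image_media_to_url; not the library function itself.
import asyncio
import base64

import httpx


async def image_url_to_base64(uri: str, include_format: bool = True) -> str:
    # Download the image bytes asynchronously.
    async with httpx.AsyncClient() as client:
        r = await client.get(uri)
        content = r.content
        content_type = r.headers.get("content-type")
    # Fall back to "png" when the server does not report a content type.
    format = content_type.split("/")[-1] if content_type else "png"
    encoded = base64.b64encode(content).decode("utf-8")
    # include_format=True gives an OpenAI-style data URL (Fireworks/Together);
    # include_format=False gives raw base64, which the Ollama adapter requests.
    return f"data:image/{format};base64,{encoded}" if include_format else encoded


if __name__ == "__main__":
    # Hypothetical URL, for illustration only.
    print(asyncio.run(image_url_to_base64("https://example.com/cat.png"))[:60])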