Rework InterleavedContentMedia datatype so URL downloading is in llama-stack

This commit is contained in:
Ashwin Bharambe 2024-12-15 13:23:30 -08:00
parent c2f7905fa4
commit a9a041a1de
10 changed files with 368 additions and 146 deletions

View file

@ -24,7 +24,8 @@ from fairscale.nn.model_parallel.initialize import (
model_parallel_is_initialized,
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import RawContent, RawMessage
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@ -38,10 +39,6 @@ from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
)
from .config import (
Fp8QuantizationConfig,
@ -53,6 +50,14 @@ from .config import (
log = logging.getLogger(__name__)
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
messages: List[RawMessage]
class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
@ -206,7 +211,7 @@ class Llama:
@torch.inference_mode()
def generate(
self,
model_input: ModelInput,
model_input: LLMInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
@ -343,7 +348,7 @@ class Llama:
def completion(
self,
request: CompletionRequest,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
@ -354,10 +359,7 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = self.formatter.encode_content(content)
model_input = self.formatter.encode_content(request.content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
@ -374,10 +376,8 @@ class Llama:
def chat_completion(
self,
request: ChatCompletionRequest,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
messages = chat_completion_request_to_messages(request, self.llama_model)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
@ -389,7 +389,7 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
request.messages,
request.tool_prompt_format,
),
max_gen_len=max_gen_len,

View file

@ -7,25 +7,59 @@
import asyncio
import logging
from typing import AsyncGenerator, List
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import Model
from llama_models.llama3.api.datatypes import (
RawMessage,
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
)
from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
interleaved_content_convert_to_raw,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .generation import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
Llama,
)
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
@ -90,7 +124,7 @@ class MetaReferenceInferenceImpl(
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -99,6 +133,7 @@ class MetaReferenceInferenceImpl(
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content = augment_content_with_response_format_prompt(response_format, content)
request = CompletionRequest(
model=model_id,
content=content,
@ -108,7 +143,7 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
request = await convert_request_to_raw(request)
if request.stream:
return self._stream_completion(request)
@ -233,7 +268,13 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.model.core_model_id.value
)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
@ -406,29 +447,16 @@ class MetaReferenceInferenceImpl(
yield x
async def request_with_localized_media(
async def convert_request_to_raw(
request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequest, CompletionRequest]:
if not request_has_media(request):
return request
async def _convert_single_content(content):
if isinstance(content, ImageMedia):
url = await convert_image_media_to_url(content, download=True)
return ImageMedia(image=URL(uri=url))
else:
return content
async def _convert_content(content):
if isinstance(content, list):
return [await _convert_single_content(c) for c in content]
else:
return await _convert_single_content(content)
) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
if isinstance(request, ChatCompletionRequest):
messages = []
for m in request.messages:
m.content = await _convert_content(m.content)
content = await interleaved_content_convert_to_raw(m.content)
messages.append(RawMessage(**m.model_dump(), content=content))
request.messages = messages
else:
request.content = await _convert_content(request.content)
request.content = await interleaved_content_convert_to_raw(request.content)
return request

View file

@ -19,6 +19,7 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -29,7 +30,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
convert_message_to_dict,
interleaved_content_as_str,
request_has_media,
)
@ -108,7 +109,7 @@ class FireworksInferenceAdapter(
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -238,7 +239,7 @@ class FireworksInferenceAdapter(
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
await convert_message_to_dict(m) for m in request.messages
await convert_message_to_openai_dict(m) for m in request.messages
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
@ -265,7 +266,7 @@ class FireworksInferenceAdapter(
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
@ -277,7 +278,7 @@ class FireworksInferenceAdapter(
), "Fireworks does not support media for embeddings"
response = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_text_media_as_str(content) for content in contents],
input=[interleaved_content_as_str(content) for content in contents],
**kwargs,
)

View file

@ -37,7 +37,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
convert_image_media_to_url,
convert_image_content_to_url,
interleaved_content_as_str,
request_has_media,
)
@ -141,7 +142,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -234,7 +235,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
if isinstance(request, ChatCompletionRequest):
if media_present:
contents = [
await convert_message_to_dict_for_ollama(m)
await convert_message_to_openai_dict_for_ollama(m)
for m in request.messages
]
# flatten the list of lists
@ -320,7 +321,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
@ -329,7 +330,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
), "Ollama does not support media for embeddings"
response = await self.client.embed(
model=model.provider_resource_id,
input=[interleaved_text_media_as_str(content) for content in contents],
input=[interleaved_content_as_str(content) for content in contents],
)
embeddings = response["embeddings"]
@ -358,21 +359,23 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return model
async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
async def _convert_content(content) -> dict:
if isinstance(content, ImageMedia):
if isinstance(content, ImageContentItem):
return {
"role": message.role,
"images": [
await convert_image_media_to_url(
await convert_image_content_to_url(
content, download=True, include_format=False
)
],
}
else:
text = content.text if isinstance(content, TextContentItem) else content
assert isinstance(text, str)
return {
"role": message.role,
"content": content,
"content": text,
}
if isinstance(message.content, list):

View file

@ -22,6 +22,7 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -32,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
convert_message_to_dict,
interleaved_content_as_str,
request_has_media,
)
@ -92,7 +93,7 @@ class TogetherInferenceAdapter(
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -230,7 +231,7 @@ class TogetherInferenceAdapter(
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
await convert_message_to_dict(m) for m in request.messages
await convert_message_to_openai_dict(m) for m in request.messages
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
@ -252,7 +253,7 @@ class TogetherInferenceAdapter(
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(
@ -260,7 +261,7 @@ class TogetherInferenceAdapter(
), "Together does not support media for embeddings"
r = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_text_media_as_str(content) for content in contents],
input=[interleaved_content_as_str(content) for content in contents],
)
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)

View file

@ -22,6 +22,7 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -30,7 +31,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
convert_message_to_dict,
interleaved_content_as_str,
request_has_media,
)
@ -71,7 +72,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -163,7 +164,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if media_present:
# vllm does not seem to work well with image urls, so we download the images
input_dict["messages"] = [
await convert_message_to_dict(m, download=True)
await convert_message_to_openai_dict(m, download=True)
for m in request.messages
]
else:
@ -202,7 +203,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
@ -215,7 +216,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
), "VLLM does not support media for embeddings"
response = self.client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_text_media_as_str(content) for content in contents],
input=[interleaved_content_as_str(content) for content in contents],
**kwargs,
)

View file

@ -11,9 +11,12 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_stack.apis.inference import * # noqa: F403
from pydantic import BaseModel
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
)
class OpenAICompatCompletionChoiceDelta(BaseModel):
content: str
@ -246,3 +249,32 @@ async def process_chat_completion_stream_response(
stop_reason=stop_reason,
)
)
async def convert_message_to_openai_dict(
message: Message, download: bool = False
) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
return {
"type": "image_url",
"image_url": {
"url": await convert_image_content_to_url(
content, download=download
),
},
}
else:
text = content.text if isinstance(content, TextContentItem) else content
assert isinstance(text, str)
return {"type": "text", "text": text}
if isinstance(message.content, list):
content = [await _convert_content(c) for c in message.content]
else:
content = [await _convert_content(message.content)]
return {
"role": message.role,
"content": content,
}

View file

@ -4,19 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
import json
import logging
from typing import Tuple
import re
from typing import List, Optional, Tuple, Union
import httpx
from llama_models.datatypes import is_multimodal, ModelFamily
from llama_models.llama3.api.chat_format import ChatFormat
from PIL import Image as PIL_Image
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_models.datatypes import ModelFamily
from llama_models.llama3.api.datatypes import (
RawContent,
RawContentItem,
RawMediaItem,
RawTextItem,
Role,
ToolChoice,
ToolPromptFormat,
)
from llama_models.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
@ -25,15 +32,89 @@ from llama_models.llama3.prompt_templates import (
SystemDefaultGenerator,
)
from llama_models.sku_list import resolve_model
from PIL import Image as PIL_Image
from llama_stack.apis.common.deployment_types import URL
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
Message,
ResponseFormat,
ResponseFormatType,
SystemMessage,
TextContentItem,
UserMessage,
)
from llama_stack.providers.utils.inference import supported_inference_models
log = logging.getLogger(__name__)
def content_has_media(content: InterleavedTextMedia):
def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
def _process(c) -> str:
if isinstance(c, str):
return c
elif isinstance(c, ImageContentItem):
return "<image>"
elif isinstance(c, TextContentItem):
return c.text
else:
raise ValueError(f"Unsupported content type: {type(c)}")
if isinstance(content, list):
return sep.join(_process(c) for c in content)
else:
return _process(content)
async def interleaved_content_convert_to_raw(
content: InterleavedContent,
) -> RawContent:
"""Download content from URLs / files etc. so plain bytes can be sent to the model"""
async def _localize_single(c: str | InterleavedContentItem) -> str | RawContentItem:
if isinstance(c, str):
return RawTextItem(text=c)
elif isinstance(c, TextContentItem):
return RawTextItem(text=c.text)
elif isinstance(c, ImageContentItem):
# load image and return PIL version
img = c.data
if isinstance(img, URL):
if img.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", img.uri)
if not match:
raise ValueError("Invalid data URL format")
_, image_data = match.groups()
data = base64.b64decode(image_data)
elif img.uri.startswith("file://"):
path = img.uri[len("file://") :]
with open(path, "rb") as f:
data = f.read() # type: ignore
elif img.uri.startswith("http"):
async with httpx.AsyncClient() as client:
response = await client.get(img.uri)
data = response.content
else:
raise ValueError("Unsupported URL type")
return RawMediaItem(data=data)
else:
raise ValueError(f"Unsupported content type: {type(c)}")
if isinstance(content, list):
return await asyncio.gather(*(_localize_single(c) for c in content))
else:
return await _localize_single(content)
def content_has_media(content: InterleavedContent):
def _has_media_content(c):
return isinstance(c, ImageMedia)
return isinstance(c, ImageContentItem)
if isinstance(content, list):
return any(_has_media_content(c) for c in content)
@ -52,37 +133,29 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
return content_has_media(request.content)
async def convert_image_media_to_url(
media: ImageMedia, download: bool = False, include_format: bool = True
) -> str:
if isinstance(media.image, PIL_Image.Image):
if media.image.format == "PNG":
format = "png"
elif media.image.format == "GIF":
format = "gif"
elif media.image.format == "JPEG":
format = "jpeg"
else:
raise ValueError(f"Unsupported image format {media.image.format}")
bytestream = io.BytesIO()
media.image.save(bytestream, format=media.image.format)
bytestream.seek(0)
content = bytestream.getvalue()
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
if isinstance(media.data, URL) and media.data.uri.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(media.image.uri)
content = r.content
content_type = r.headers.get("content-type")
if content_type:
format = content_type.split("/")[-1]
else:
format = "png"
return content, format
else:
if not download:
return media.image.uri
else:
assert isinstance(media.image, URL)
async with httpx.AsyncClient() as client:
r = await client.get(media.image.uri)
content = r.content
content_type = r.headers.get("content-type")
if content_type:
format = content_type.split("/")[-1]
else:
format = "png"
image = PIL_Image.open(media.data)
return media.data, image.format
async def convert_image_content_to_url(
media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str:
if isinstance(media.data, URL) and not download:
return media.image.uri
content, format = await localize_image_content(media)
if include_format:
return f"data:image/{format};base64," + base64.b64encode(content).decode(
"utf-8"
@ -91,32 +164,6 @@ async def convert_image_media_to_url(
return base64.b64encode(content).decode("utf-8")
# TODO: name this function better! this is about OpenAI compatibile image
# media conversion of the message. this should probably go in openai_compat.py
async def convert_message_to_dict(message: Message, download: bool = False) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageMedia):
return {
"type": "image_url",
"image_url": {
"url": await convert_image_media_to_url(content, download=download),
},
}
else:
assert isinstance(content, str)
return {"type": "text", "text": content}
if isinstance(message.content, list):
content = [await _convert_content(c) for c in message.content]
else:
content = [await _convert_content(message.content)]
return {
"role": message.role,
"content": content,
}
def completion_request_to_prompt(
request: CompletionRequest, formatter: ChatFormat
) -> str:
@ -330,7 +377,7 @@ def augment_messages_for_tools_llama_3_2(
sys_content += "\n"
if existing_system_message:
sys_content += interleaved_text_media_as_str(
sys_content += interleaved_content_as_str(
existing_system_message.content, sep="\n"
)