Rework InterleavedContentMedia datatype so URL downloading is in llama-stack

Ashwin Bharambe 2024-12-15 13:23:30 -08:00
parent c2f7905fa4
commit a9a041a1de
10 changed files with 368 additions and 146 deletions

View file

@@ -7,13 +7,21 @@
 from enum import Enum
 from typing import Any, Dict, Optional

-from llama_models.llama3.api.datatypes import URL
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel


+@json_schema_type(
+    schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
+)
+class URL(BaseModel):
+    uri: str
+
+    def __str__(self) -> str:
+        return self.uri
+
+
 @json_schema_type
 class RestAPIMethod(Enum):
     GET = "GET"

View file

@@ -16,14 +16,23 @@ from typing import (
     Union,
 )

+from llama_models.llama3.api.datatypes import (
+    BuiltinTool,
+    SamplingParams,
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated

-from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
-from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.common.deployment_types import URL
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.apis.models import *  # noqa: F403
@@ -40,17 +49,17 @@ class QuantizationType(Enum):

 @json_schema_type
 class Fp8QuantizationConfig(BaseModel):
-    type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
+    type: Literal["fp8"] = "fp8"


 @json_schema_type
 class Bf16QuantizationConfig(BaseModel):
-    type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
+    type: Literal["bf16"] = "bf16"


 @json_schema_type
 class Int4QuantizationConfig(BaseModel):
-    type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
+    type: Literal["int4"] = "int4"
     scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
@@ -60,6 +69,98 @@ QuantizationConfig = Annotated[
 ]


+@json_schema_type
+class ImageContentItem(BaseModel):
+    type: Literal["image"] = "image"
+    data: Union[bytes, URL]
+
+
+@json_schema_type
+class TextContentItem(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+# other modalities can be added here
+InterleavedContentItem = Annotated[
+    Union[ImageContentItem, TextContentItem],
+    Field(discriminator="type"),
+]
+
+# accept a single "str" as a special case since it is common
+InterleavedContent = str | InterleavedContentItem | List[InterleavedContentItem]
+
+
+@json_schema_type
+class UserMessage(BaseModel):
+    role: Literal["user"] = "user"
+    content: InterleavedContent
+    context: Optional[InterleavedContent] = None
+
+
+@json_schema_type
+class SystemMessage(BaseModel):
+    role: Literal["system"] = "system"
+    content: InterleavedContent
+
+
+@json_schema_type
+class ToolResponseMessage(BaseModel):
+    role: Literal["ipython"] = "ipython"
+    # it was nice to re-use the ToolResponse type, but having all messages
+    # have a `content` type makes things nicer too
+    call_id: str
+    tool_name: Union[BuiltinTool, str]
+    content: InterleavedContent
+
+
+@json_schema_type
+class CompletionMessage(BaseModel):
+    role: Literal["assistant"] = "assistant"
+    content: InterleavedContent
+    stop_reason: StopReason
+    tool_calls: List[ToolCall] = Field(default_factory=list)
+
+
+Message = Annotated[
+    Union[
+        UserMessage,
+        SystemMessage,
+        ToolResponseMessage,
+        CompletionMessage,
+    ],
+    Field(discriminator="role"),
+]
+
+
+@json_schema_type
+class ToolResponse(BaseModel):
+    call_id: str
+    tool_name: Union[BuiltinTool, str]
+    content: InterleavedContent
+
+    @field_validator("tool_name", mode="before")
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinTool(v)
+            except ValueError:
+                return v
+        return v
+
+
+@json_schema_type
+class ToolChoice(Enum):
+    auto = "auto"
+    required = "required"
+
+
+@json_schema_type
+class TokenLogProbs(BaseModel):
+    logprobs_by_token: Dict[str, float]
+
+
 @json_schema_type
 class ChatCompletionResponseEventType(Enum):
     start = "start"
@@ -117,7 +218,7 @@ ResponseFormat = Annotated[

 @json_schema_type
 class CompletionRequest(BaseModel):
     model: str
-    content: InterleavedTextMedia
+    content: InterleavedContent
     sampling_params: Optional[SamplingParams] = SamplingParams()
     response_format: Optional[ResponseFormat] = None
@@ -230,7 +331,7 @@ class Inference(Protocol):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -258,5 +359,5 @@ class Inference(Protocol):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse: ...
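
Taken together, these types replace InterleavedTextMedia: content is either a plain string, a single discriminated item, or a list of items. A small construction sketch (illustrative, not part of the diff):

from llama_stack.apis.common.deployment_types import URL
from llama_stack.apis.inference import ImageContentItem, TextContentItem, UserMessage

# a bare string is still accepted as content
simple = UserMessage(content="What is the capital of France?")

# mixed text + image content uses the discriminated item types
mixed = UserMessage(
    content=[
        TextContentItem(text="What is in this image?"),
        ImageContentItem(data=URL(uri="https://example.com/photo.jpg")),
    ]
)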

View file

@@ -24,7 +24,8 @@ from fairscale.nn.model_parallel.initialize import (
     model_parallel_is_initialized,
 )
 from llama_models.llama3.api.args import ModelArgs
-from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
+from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
+from llama_models.llama3.api.datatypes import RawContent, RawMessage
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
@@ -38,10 +39,6 @@ from llama_stack.apis.inference import *  # noqa: F403
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    augment_content_with_response_format_prompt,
-    chat_completion_request_to_messages,
-)

 from .config import (
     Fp8QuantizationConfig,
@@ -53,6 +50,14 @@ from .config import (
 log = logging.getLogger(__name__)


+class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
+    messages: List[RawMessage]
+
+
+class CompletionRequestWithRawContent(CompletionRequest):
+    content: RawContent
+
+
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
@@ -206,7 +211,7 @@ class Llama:
     @torch.inference_mode()
     def generate(
         self,
-        model_input: ModelInput,
+        model_input: LLMInput,
         max_gen_len: int,
         temperature: float = 0.6,
         top_p: float = 0.9,
@@ -343,7 +348,7 @@ class Llama:
     def completion(
         self,
-        request: CompletionRequest,
+        request: CompletionRequestWithRawContent,
     ) -> Generator:
         sampling_params = request.sampling_params
         max_gen_len = sampling_params.max_tokens
@@ -354,10 +359,7 @@ class Llama:
         ):
             max_gen_len = self.model.params.max_seq_len - 1

-        content = augment_content_with_response_format_prompt(
-            request.response_format, request.content
-        )
-        model_input = self.formatter.encode_content(content)
+        model_input = self.formatter.encode_content(request.content)
         yield from self.generate(
             model_input=model_input,
             max_gen_len=max_gen_len,
@@ -374,10 +376,8 @@ class Llama:
     def chat_completion(
         self,
-        request: ChatCompletionRequest,
+        request: ChatCompletionRequestWithRawContent,
     ) -> Generator:
-        messages = chat_completion_request_to_messages(request, self.llama_model)
-
         sampling_params = request.sampling_params
         max_gen_len = sampling_params.max_tokens
         if (
@@ -389,7 +389,7 @@ class Llama:
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
-                messages,
+                request.messages,
                 request.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
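
The two request subclasses are the seam between the API layer and the generator: they narrow messages/content to the Raw* types from llama_models, so Llama.completion() and Llama.chat_completion() can assume every URL has already been resolved to bytes by the caller. A standalone sketch of the same field-narrowing pattern (hypothetical names, not from this commit):

from pydantic import BaseModel

class Attachment(BaseModel):
    uri: str

class Request(BaseModel):
    # content may still reference remote data
    content: str | Attachment

class RequestWithRawContent(Request):
    # narrowed: media must be downloaded before the request reaches the generator
    content: str | bytes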

View file

@@ -7,25 +7,59 @@
 import asyncio
 import logging
-from typing import AsyncGenerator, List
+from typing import AsyncGenerator, List, Optional, Union

+from llama_models.datatypes import Model
+from llama_models.llama3.api.datatypes import (
+    RawMessage,
+    SamplingParams,
+    StopReason,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_models.sku_list import resolve_model

-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.providers.utils.inference.model_registry import build_model_alias
-from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    Inference,
+    InterleavedContent,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    TokenLogProbs,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    ToolChoice,
+)
+from llama_stack.apis.models import ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    build_model_alias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    convert_image_media_to_url,
-    request_has_media,
+    augment_content_with_response_format_prompt,
+    chat_completion_request_to_messages,
+    interleaved_content_convert_to_raw,
 )

 from .config import MetaReferenceInferenceConfig
-from .generation import Llama
+from .generation import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+    Llama,
+)
 from .model_parallel import LlamaModelParallelGenerator

 log = logging.getLogger(__name__)
@@ -90,7 +124,7 @@ class MetaReferenceInferenceImpl(
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -99,6 +133,7 @@ class MetaReferenceInferenceImpl(
         if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

+        content = augment_content_with_response_format_prompt(response_format, content)
         request = CompletionRequest(
             model=model_id,
             content=content,
@@ -108,7 +143,7 @@ class MetaReferenceInferenceImpl(
             logprobs=logprobs,
         )
         self.check_model(request)
-        request = await request_with_localized_media(request)
+        request = await convert_request_to_raw(request)

         if request.stream:
             return self._stream_completion(request)
@@ -233,7 +268,13 @@ class MetaReferenceInferenceImpl(
             logprobs=logprobs,
         )
         self.check_model(request)
-        request = await request_with_localized_media(request)
+
+        # augment and rewrite messages depending on the model
+        request.messages = chat_completion_request_to_messages(
+            request, self.model.core_model_id.value
+        )
+        # download media and convert to raw content so we can send it to the model
+        request = await convert_request_to_raw(request)

         if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
@@ -406,29 +447,16 @@ class MetaReferenceInferenceImpl(
                    yield x


-async def request_with_localized_media(
+async def convert_request_to_raw(
     request: Union[ChatCompletionRequest, CompletionRequest],
-) -> Union[ChatCompletionRequest, CompletionRequest]:
-    if not request_has_media(request):
-        return request
-
-    async def _convert_single_content(content):
-        if isinstance(content, ImageMedia):
-            url = await convert_image_media_to_url(content, download=True)
-            return ImageMedia(image=URL(uri=url))
-        else:
-            return content
-
-    async def _convert_content(content):
-        if isinstance(content, list):
-            return [await _convert_single_content(c) for c in content]
-        else:
-            return await _convert_single_content(content)
-
+) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
     if isinstance(request, ChatCompletionRequest):
+        messages = []
         for m in request.messages:
-            m.content = await _convert_content(m.content)
+            content = await interleaved_content_convert_to_raw(m.content)
+            messages.append(RawMessage(**m.model_dump(), content=content))
+        request.messages = messages
     else:
-        request.content = await _convert_content(request.content)
+        request.content = await interleaved_content_convert_to_raw(request.content)

     return request
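
The chat path now works in two explicit steps before touching the generator: model-specific prompt rewriting on the high-level Message types, then media resolution. A condensed sketch of the completion path with an image supplied as a data: URL (illustrative only; convert_request_to_raw is the helper defined just above, and the file name and model id are placeholders):

import base64

from llama_stack.apis.common.deployment_types import URL
from llama_stack.apis.inference import CompletionRequest, ImageContentItem

async def example() -> None:
    # any local image; a data: URL avoids network access in the conversion step
    image_bytes = open("dog.png", "rb").read()
    uri = "data:image/png;base64," + base64.b64encode(image_bytes).decode()

    request = CompletionRequest(
        model="Llama3.2-11B-Vision-Instruct",
        content=ImageContentItem(data=URL(uri=uri)),
    )
    # convert_request_to_raw is the module-level helper added in the hunk above
    request = await convert_request_to_raw(request)
    # request.content should now be a RawMediaItem holding the decoded PNG bytes,
    # so the generator never performs any download or file I/O itself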

View file

@@ -19,6 +19,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    convert_message_to_openai_dict,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -29,7 +30,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
-    convert_message_to_dict,
+    interleaved_content_as_str,
     request_has_media,
 )
@@ -108,7 +109,7 @@ class FireworksInferenceAdapter(
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -238,7 +239,7 @@ class FireworksInferenceAdapter(
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 input_dict["messages"] = [
-                    await convert_message_to_dict(m) for m in request.messages
+                    await convert_message_to_openai_dict(m) for m in request.messages
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
@@ -265,7 +266,7 @@ class FireworksInferenceAdapter(
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
@@ -277,7 +278,7 @@ class FireworksInferenceAdapter(
         ), "Fireworks does not support media for embeddings"
         response = self._get_client().embeddings.create(
             model=model.provider_resource_id,
-            input=[interleaved_text_media_as_str(content) for content in contents],
+            input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
         )

View file

@@ -37,7 +37,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
-    convert_image_media_to_url,
+    convert_image_content_to_url,
+    interleaved_content_as_str,
     request_has_media,
 )
@@ -141,7 +142,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -234,7 +235,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 contents = [
-                    await convert_message_to_dict_for_ollama(m)
+                    await convert_message_to_openai_dict_for_ollama(m)
                     for m in request.messages
                 ]
                 # flatten the list of lists
@@ -320,7 +321,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
@@ -329,7 +330,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         ), "Ollama does not support media for embeddings"
         response = await self.client.embed(
             model=model.provider_resource_id,
-            input=[interleaved_text_media_as_str(content) for content in contents],
+            input=[interleaved_content_as_str(content) for content in contents],
         )
         embeddings = response["embeddings"]
@@ -358,21 +359,23 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     return model


-async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
+async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
-        if isinstance(content, ImageMedia):
+        if isinstance(content, ImageContentItem):
             return {
                 "role": message.role,
                 "images": [
-                    await convert_image_media_to_url(
+                    await convert_image_content_to_url(
                         content, download=True, include_format=False
                     )
                 ],
             }
         else:
+            text = content.text if isinstance(content, TextContentItem) else content
+            assert isinstance(text, str)
             return {
                 "role": message.role,
-                "content": content,
+                "content": text,
             }

     if isinstance(message.content, list):
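
For Ollama the helper returns one dict per content item rather than a single OpenAI-style message: text items keep a content field, while image items are emitted under images as bare base64 (include_format=False strips the data:image/...;base64, prefix). Roughly the shape produced for a user message with one text and one image item (illustrative, values abbreviated):

[
    {"role": "user", "content": "What is in this image?"},
    {"role": "user", "images": ["iVBORw0KGgoAAAANSUhEUgAA..."]},
]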

View file

@@ -22,6 +22,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    convert_message_to_openai_dict,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -32,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
-    convert_message_to_dict,
+    interleaved_content_as_str,
     request_has_media,
 )
@@ -92,7 +93,7 @@ class TogetherInferenceAdapter(
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -230,7 +231,7 @@ class TogetherInferenceAdapter(
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 input_dict["messages"] = [
-                    await convert_message_to_dict(m) for m in request.messages
+                    await convert_message_to_openai_dict(m) for m in request.messages
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
@@ -252,7 +253,7 @@ class TogetherInferenceAdapter(
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         assert all(
@@ -260,7 +261,7 @@ class TogetherInferenceAdapter(
         ), "Together does not support media for embeddings"
         r = self._get_client().embeddings.create(
             model=model.provider_resource_id,
-            input=[interleaved_text_media_as_str(content) for content in contents],
+            input=[interleaved_content_as_str(content) for content in contents],
         )
         embeddings = [item.embedding for item in r.data]
         return EmbeddingsResponse(embeddings=embeddings)

View file

@@ -22,6 +22,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    convert_message_to_openai_dict,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -30,7 +31,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
-    convert_message_to_dict,
+    interleaved_content_as_str,
     request_has_media,
 )
@@ -71,7 +72,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -163,7 +164,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             if media_present:
                 # vllm does not seem to work well with image urls, so we download the images
                 input_dict["messages"] = [
-                    await convert_message_to_dict(m, download=True)
+                    await convert_message_to_openai_dict(m, download=True)
                     for m in request.messages
                 ]
             else:
@@ -202,7 +203,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
@@ -215,7 +216,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         ), "VLLM does not support media for embeddings"
         response = self.client.embeddings.create(
             model=model.provider_resource_id,
-            input=[interleaved_text_media_as_str(content) for content in contents],
+            input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
         )

View file

@@ -11,9 +11,12 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason

 from llama_stack.apis.inference import *  # noqa: F403
 from pydantic import BaseModel

+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_content_to_url,
+)


 class OpenAICompatCompletionChoiceDelta(BaseModel):
     content: str
@@ -246,3 +249,32 @@ async def process_chat_completion_stream_response(
                 stop_reason=stop_reason,
             )
         )
+
+
+async def convert_message_to_openai_dict(
+    message: Message, download: bool = False
+) -> dict:
+    async def _convert_content(content) -> dict:
+        if isinstance(content, ImageContentItem):
+            return {
+                "type": "image_url",
+                "image_url": {
+                    "url": await convert_image_content_to_url(
+                        content, download=download
+                    ),
+                },
+            }
+        else:
+            text = content.text if isinstance(content, TextContentItem) else content
+            assert isinstance(text, str)
+            return {"type": "text", "text": text}
+
+    if isinstance(message.content, list):
+        content = [await _convert_content(c) for c in message.content]
+    else:
+        content = [await _convert_content(message.content)]
+
+    return {
+        "role": message.role,
+        "content": content,
+    }
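
This is the shared OpenAI-compatible conversion now used by the Fireworks, Together and vLLM adapters above. For a user message containing text plus an image URL it produces the familiar content-parts shape; with download=False the URL is meant to pass through untouched, while download=True inlines it as a data: URL (illustrative output, not from the diff):

{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
    ],
}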

View file

@@ -4,19 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import base64
-import io
 import json
 import logging
-from typing import Tuple
+import re
+from typing import List, Optional, Tuple, Union

 import httpx

+from llama_models.datatypes import is_multimodal, ModelFamily
 from llama_models.llama3.api.chat_format import ChatFormat
-from PIL import Image as PIL_Image
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
-from llama_models.datatypes import ModelFamily
+from llama_models.llama3.api.datatypes import (
+    RawContent,
+    RawContentItem,
+    RawMediaItem,
+    RawTextItem,
+    Role,
+    ToolChoice,
+    ToolPromptFormat,
+)
 from llama_models.llama3.prompt_templates import (
     BuiltinToolGenerator,
     FunctionTagCustomToolGenerator,
@@ -25,15 +32,89 @@ from llama_models.llama3.prompt_templates import (
     SystemDefaultGenerator,
 )
 from llama_models.sku_list import resolve_model
+from PIL import Image as PIL_Image
+
+from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    ImageContentItem,
+    InterleavedContent,
+    InterleavedContentItem,
+    Message,
+    ResponseFormat,
+    ResponseFormatType,
+    SystemMessage,
+    TextContentItem,
+    UserMessage,
+)
 from llama_stack.providers.utils.inference import supported_inference_models

 log = logging.getLogger(__name__)


-def content_has_media(content: InterleavedTextMedia):
+def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
+    def _process(c) -> str:
+        if isinstance(c, str):
+            return c
+        elif isinstance(c, ImageContentItem):
+            return "<image>"
+        elif isinstance(c, TextContentItem):
+            return c.text
+        else:
+            raise ValueError(f"Unsupported content type: {type(c)}")
+
+    if isinstance(content, list):
+        return sep.join(_process(c) for c in content)
+    else:
+        return _process(content)
+
+
+async def interleaved_content_convert_to_raw(
+    content: InterleavedContent,
+) -> RawContent:
+    """Download content from URLs / files etc. so plain bytes can be sent to the model"""
+
+    async def _localize_single(c: str | InterleavedContentItem) -> str | RawContentItem:
+        if isinstance(c, str):
+            return RawTextItem(text=c)
+        elif isinstance(c, TextContentItem):
+            return RawTextItem(text=c.text)
+        elif isinstance(c, ImageContentItem):
+            # load image and return PIL version
+            img = c.data
+            if isinstance(img, URL):
+                if img.uri.startswith("data"):
+                    match = re.match(r"data:image/(\w+);base64,(.+)", img.uri)
+                    if not match:
+                        raise ValueError("Invalid data URL format")
+                    _, image_data = match.groups()
+                    data = base64.b64decode(image_data)
+                elif img.uri.startswith("file://"):
+                    path = img.uri[len("file://") :]
+                    with open(path, "rb") as f:
+                        data = f.read()  # type: ignore
+                elif img.uri.startswith("http"):
+                    async with httpx.AsyncClient() as client:
+                        response = await client.get(img.uri)
+                        data = response.content
+                else:
+                    raise ValueError("Unsupported URL type")
+            return RawMediaItem(data=data)
+        else:
+            raise ValueError(f"Unsupported content type: {type(c)}")
+
+    if isinstance(content, list):
+        return await asyncio.gather(*(_localize_single(c) for c in content))
+    else:
+        return await _localize_single(content)
+
+
+def content_has_media(content: InterleavedContent):
     def _has_media_content(c):
-        return isinstance(c, ImageMedia)
+        return isinstance(c, ImageContentItem)

     if isinstance(content, list):
         return any(_has_media_content(c) for c in content)
@@ -52,28 +133,8 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
         return content_has_media(request.content)


-async def convert_image_media_to_url(
-    media: ImageMedia, download: bool = False, include_format: bool = True
-) -> str:
-    if isinstance(media.image, PIL_Image.Image):
-        if media.image.format == "PNG":
-            format = "png"
-        elif media.image.format == "GIF":
-            format = "gif"
-        elif media.image.format == "JPEG":
-            format = "jpeg"
-        else:
-            raise ValueError(f"Unsupported image format {media.image.format}")
-
-        bytestream = io.BytesIO()
-        media.image.save(bytestream, format=media.image.format)
-        bytestream.seek(0)
-        content = bytestream.getvalue()
-    else:
-        if not download:
-            return media.image.uri
-        else:
-            assert isinstance(media.image, URL)
+async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
+    if isinstance(media.data, URL) and media.data.uri.startswith("http"):
         async with httpx.AsyncClient() as client:
             r = await client.get(media.image.uri)
             content = r.content
@@ -82,7 +143,19 @@ async def convert_image_media_to_url(
                 format = content_type.split("/")[-1]
             else:
                 format = "png"
+
+        return content, format
+    else:
+        image = PIL_Image.open(media.data)
+        return media.data, image.format
+
+
+async def convert_image_content_to_url(
+    media: ImageContentItem, download: bool = False, include_format: bool = True
+) -> str:
+    if isinstance(media.data, URL) and not download:
+        return media.image.uri
+
+    content, format = await localize_image_content(media)
     if include_format:
         return f"data:image/{format};base64," + base64.b64encode(content).decode(
             "utf-8"
@@ -91,32 +164,6 @@ async def convert_image_media_to_url(
         return base64.b64encode(content).decode("utf-8")


-# TODO: name this function better! this is about OpenAI compatibile image
-# media conversion of the message. this should probably go in openai_compat.py
-async def convert_message_to_dict(message: Message, download: bool = False) -> dict:
-    async def _convert_content(content) -> dict:
-        if isinstance(content, ImageMedia):
-            return {
-                "type": "image_url",
-                "image_url": {
-                    "url": await convert_image_media_to_url(content, download=download),
-                },
-            }
-        else:
-            assert isinstance(content, str)
-            return {"type": "text", "text": content}
-
-    if isinstance(message.content, list):
-        content = [await _convert_content(c) for c in message.content]
-    else:
-        content = [await _convert_content(message.content)]
-
-    return {
-        "role": message.role,
-        "content": content,
-    }
-
-
 def completion_request_to_prompt(
     request: CompletionRequest, formatter: ChatFormat
 ) -> str:
@@ -330,7 +377,7 @@ def augment_messages_for_tools_llama_3_2(
         sys_content += "\n"

     if existing_system_message:
-        sys_content += interleaved_text_media_as_str(
+        sys_content += interleaved_content_as_str(
             existing_system_message.content, sep="\n"
         )
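
interleaved_content_as_str is the text-only flattening used by the embeddings paths and the system-prompt merging above; image items collapse to an <image> placeholder. A small sketch (illustrative, not part of the diff):

from llama_stack.apis.common.deployment_types import URL
from llama_stack.apis.inference import ImageContentItem, TextContentItem
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)

content = [
    TextContentItem(text="Describe"),
    ImageContentItem(data=URL(uri="https://example.com/cat.png")),
    TextContentItem(text="in one sentence."),
]
assert interleaved_content_as_str(content) == "Describe <image> in one sentence."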