mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
Rework InterleavedContentMedia datatype so URL downloading is in llama-stack
This commit is contained in:
parent
c2f7905fa4
commit
a9a041a1de
10 changed files with 368 additions and 146 deletions
|
@ -7,13 +7,21 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import URL
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@json_schema_type(
|
||||
schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
|
||||
)
|
||||
class URL(BaseModel):
|
||||
uri: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.uri
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RestAPIMethod(Enum):
|
||||
GET = "GET"
|
||||
|
|
|
@ -16,14 +16,23 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
BuiltinTool,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.apis.common.deployment_types import URL
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
|
||||
|
@ -40,17 +49,17 @@ class QuantizationType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class Fp8QuantizationConfig(BaseModel):
|
||||
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
|
||||
type: Literal["fp8"] = "fp8"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Bf16QuantizationConfig(BaseModel):
|
||||
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
|
||||
type: Literal["bf16"] = "bf16"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Int4QuantizationConfig(BaseModel):
|
||||
type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
|
||||
type: Literal["int4"] = "int4"
|
||||
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
|
@ -60,6 +69,98 @@ QuantizationConfig = Annotated[
|
|||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ImageContentItem(BaseModel):
|
||||
type: Literal["image"] = "image"
|
||||
data: Union[bytes, URL]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TextContentItem(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
# other modalities can be added here
|
||||
InterleavedContentItem = Annotated[
|
||||
Union[ImageContentItem, TextContentItem],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
# accept a single "str" as a special case since it is common
|
||||
InterleavedContent = str | InterleavedContentItem | List[InterleavedContentItem]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UserMessage(BaseModel):
|
||||
role: Literal["user"] = "user"
|
||||
content: InterleavedContent
|
||||
context: Optional[InterleavedContent] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SystemMessage(BaseModel):
|
||||
role: Literal["system"] = "system"
|
||||
content: InterleavedContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolResponseMessage(BaseModel):
|
||||
role: Literal["ipython"] = "ipython"
|
||||
# it was nice to re-use the ToolResponse type, but having all messages
|
||||
# have a `content` type makes things nicer too
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
content: InterleavedContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionMessage(BaseModel):
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: InterleavedContent
|
||||
stop_reason: StopReason
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
Message = Annotated[
|
||||
Union[
|
||||
UserMessage,
|
||||
SystemMessage,
|
||||
ToolResponseMessage,
|
||||
CompletionMessage,
|
||||
],
|
||||
Field(discriminator="role"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolResponse(BaseModel):
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
content: InterleavedContent
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolChoice(Enum):
|
||||
auto = "auto"
|
||||
required = "required"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TokenLogProbs(BaseModel):
|
||||
logprobs_by_token: Dict[str, float]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponseEventType(Enum):
|
||||
start = "start"
|
||||
|
@ -117,7 +218,7 @@ ResponseFormat = Annotated[
|
|||
@json_schema_type
|
||||
class CompletionRequest(BaseModel):
|
||||
model: str
|
||||
content: InterleavedTextMedia
|
||||
content: InterleavedContent
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
|
@ -230,7 +331,7 @@ class Inference(Protocol):
|
|||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
|
@ -258,5 +359,5 @@ class Inference(Protocol):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
contents: List[InterleavedContent],
|
||||
) -> EmbeddingsResponse: ...
|
||||
|
|
|
@ -24,7 +24,8 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
model_parallel_is_initialized,
|
||||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
|
||||
from llama_models.llama3.api.datatypes import RawContent, RawMessage
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3.reference_impl.model import Transformer
|
||||
from llama_models.llama3.reference_impl.multimodal.model import (
|
||||
|
@ -38,10 +39,6 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
augment_content_with_response_format_prompt,
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
Fp8QuantizationConfig,
|
||||
|
@ -53,6 +50,14 @@ from .config import (
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
||||
messages: List[RawMessage]
|
||||
|
||||
|
||||
class CompletionRequestWithRawContent(CompletionRequest):
|
||||
content: RawContent
|
||||
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model.descriptor()))
|
||||
|
||||
|
@ -206,7 +211,7 @@ class Llama:
|
|||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
model_input: ModelInput,
|
||||
model_input: LLMInput,
|
||||
max_gen_len: int,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
|
@ -343,7 +348,7 @@ class Llama:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
request: CompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
|
@ -354,10 +359,7 @@ class Llama:
|
|||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
content = augment_content_with_response_format_prompt(
|
||||
request.response_format, request.content
|
||||
)
|
||||
model_input = self.formatter.encode_content(content)
|
||||
model_input = self.formatter.encode_content(request.content)
|
||||
yield from self.generate(
|
||||
model_input=model_input,
|
||||
max_gen_len=max_gen_len,
|
||||
|
@ -374,10 +376,8 @@ class Llama:
|
|||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
messages = chat_completion_request_to_messages(request, self.llama_model)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if (
|
||||
|
@ -389,7 +389,7 @@ class Llama:
|
|||
|
||||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(
|
||||
messages,
|
||||
request.messages,
|
||||
request.tool_prompt_format,
|
||||
),
|
||||
max_gen_len=max_gen_len,
|
||||
|
|
|
@ -7,25 +7,59 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from typing import AsyncGenerator, List
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_models.datatypes import Model
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
RawMessage,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
TokenLogProbs,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolChoice,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import build_model_alias
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_media_to_url,
|
||||
request_has_media,
|
||||
augment_content_with_response_format_prompt,
|
||||
chat_completion_request_to_messages,
|
||||
interleaved_content_convert_to_raw,
|
||||
)
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama
|
||||
from .generation import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
Llama,
|
||||
)
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -90,7 +124,7 @@ class MetaReferenceInferenceImpl(
|
|||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
|
@ -99,6 +133,7 @@ class MetaReferenceInferenceImpl(
|
|||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
content = augment_content_with_response_format_prompt(response_format, content)
|
||||
request = CompletionRequest(
|
||||
model=model_id,
|
||||
content=content,
|
||||
|
@ -108,7 +143,7 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await request_with_localized_media(request)
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
if request.stream:
|
||||
return self._stream_completion(request)
|
||||
|
@ -233,7 +268,13 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await request_with_localized_media(request)
|
||||
|
||||
# augment and rewrite messages depending on the model
|
||||
request.messages = chat_completion_request_to_messages(
|
||||
request, self.model.core_model_id.value
|
||||
)
|
||||
# download media and convert to raw content so we can send it to the model
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
|
@ -406,29 +447,16 @@ class MetaReferenceInferenceImpl(
|
|||
yield x
|
||||
|
||||
|
||||
async def request_with_localized_media(
|
||||
async def convert_request_to_raw(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
) -> Union[ChatCompletionRequest, CompletionRequest]:
|
||||
if not request_has_media(request):
|
||||
return request
|
||||
|
||||
async def _convert_single_content(content):
|
||||
if isinstance(content, ImageMedia):
|
||||
url = await convert_image_media_to_url(content, download=True)
|
||||
return ImageMedia(image=URL(uri=url))
|
||||
else:
|
||||
return content
|
||||
|
||||
async def _convert_content(content):
|
||||
if isinstance(content, list):
|
||||
return [await _convert_single_content(c) for c in content]
|
||||
else:
|
||||
return await _convert_single_content(content)
|
||||
|
||||
) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
messages = []
|
||||
for m in request.messages:
|
||||
m.content = await _convert_content(m.content)
|
||||
content = await interleaved_content_convert_to_raw(m.content)
|
||||
messages.append(RawMessage(**m.model_dump(), content=content))
|
||||
request.messages = messages
|
||||
else:
|
||||
request.content = await _convert_content(request.content)
|
||||
request.content = await interleaved_content_convert_to_raw(request.content)
|
||||
|
||||
return request
|
||||
|
|
|
@ -19,6 +19,7 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
@ -29,7 +30,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_message_to_dict,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -108,7 +109,7 @@ class FireworksInferenceAdapter(
|
|||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
|
@ -238,7 +239,7 @@ class FireworksInferenceAdapter(
|
|||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_dict(m) for m in request.messages
|
||||
await convert_message_to_openai_dict(m) for m in request.messages
|
||||
]
|
||||
else:
|
||||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||
|
@ -265,7 +266,7 @@ class FireworksInferenceAdapter(
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
contents: List[InterleavedContent],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
|
@ -277,7 +278,7 @@ class FireworksInferenceAdapter(
|
|||
), "Fireworks does not support media for embeddings"
|
||||
response = self._get_client().embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_text_media_as_str(content) for content in contents],
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
|
|
@ -37,7 +37,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_image_media_to_url,
|
||||
convert_image_content_to_url,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -141,7 +142,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
|
@ -234,7 +235,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
contents = [
|
||||
await convert_message_to_dict_for_ollama(m)
|
||||
await convert_message_to_openai_dict_for_ollama(m)
|
||||
for m in request.messages
|
||||
]
|
||||
# flatten the list of lists
|
||||
|
@ -320,7 +321,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
contents: List[InterleavedContent],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
|
@ -329,7 +330,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
), "Ollama does not support media for embeddings"
|
||||
response = await self.client.embed(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_text_media_as_str(content) for content in contents],
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
embeddings = response["embeddings"]
|
||||
|
||||
|
@ -358,21 +359,23 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return model
|
||||
|
||||
|
||||
async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
|
||||
async def _convert_content(content) -> dict:
|
||||
if isinstance(content, ImageMedia):
|
||||
if isinstance(content, ImageContentItem):
|
||||
return {
|
||||
"role": message.role,
|
||||
"images": [
|
||||
await convert_image_media_to_url(
|
||||
await convert_image_content_to_url(
|
||||
content, download=True, include_format=False
|
||||
)
|
||||
],
|
||||
}
|
||||
else:
|
||||
text = content.text if isinstance(content, TextContentItem) else content
|
||||
assert isinstance(text, str)
|
||||
return {
|
||||
"role": message.role,
|
||||
"content": content,
|
||||
"content": text,
|
||||
}
|
||||
|
||||
if isinstance(message.content, list):
|
||||
|
|
|
@ -22,6 +22,7 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
@ -32,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_message_to_dict,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -92,7 +93,7 @@ class TogetherInferenceAdapter(
|
|||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
|
@ -230,7 +231,7 @@ class TogetherInferenceAdapter(
|
|||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_dict(m) for m in request.messages
|
||||
await convert_message_to_openai_dict(m) for m in request.messages
|
||||
]
|
||||
else:
|
||||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||
|
@ -252,7 +253,7 @@ class TogetherInferenceAdapter(
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
contents: List[InterleavedContent],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
assert all(
|
||||
|
@ -260,7 +261,7 @@ class TogetherInferenceAdapter(
|
|||
), "Together does not support media for embeddings"
|
||||
r = self._get_client().embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_text_media_as_str(content) for content in contents],
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
embeddings = [item.embedding for item in r.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
|
|
@ -22,6 +22,7 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
@ -30,7 +31,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_message_to_dict,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -71,7 +72,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
|
@ -163,7 +164,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
if media_present:
|
||||
# vllm does not seem to work well with image urls, so we download the images
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_dict(m, download=True)
|
||||
await convert_message_to_openai_dict(m, download=True)
|
||||
for m in request.messages
|
||||
]
|
||||
else:
|
||||
|
@ -202,7 +203,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
contents: List[InterleavedContent],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
|
@ -215,7 +216,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
), "VLLM does not support media for embeddings"
|
||||
response = self.client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_text_media_as_str(content) for content in contents],
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
|
|
@ -11,9 +11,12 @@ from llama_models.llama3.api.chat_format import ChatFormat
|
|||
from llama_models.llama3.api.datatypes import StopReason
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_content_to_url,
|
||||
)
|
||||
|
||||
|
||||
class OpenAICompatCompletionChoiceDelta(BaseModel):
|
||||
content: str
|
||||
|
@ -246,3 +249,32 @@ async def process_chat_completion_stream_response(
|
|||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict(
|
||||
message: Message, download: bool = False
|
||||
) -> dict:
|
||||
async def _convert_content(content) -> dict:
|
||||
if isinstance(content, ImageContentItem):
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": await convert_image_content_to_url(
|
||||
content, download=download
|
||||
),
|
||||
},
|
||||
}
|
||||
else:
|
||||
text = content.text if isinstance(content, TextContentItem) else content
|
||||
assert isinstance(text, str)
|
||||
return {"type": "text", "text": text}
|
||||
|
||||
if isinstance(message.content, list):
|
||||
content = [await _convert_content(c) for c in message.content]
|
||||
else:
|
||||
content = [await _convert_content(message.content)]
|
||||
|
||||
return {
|
||||
"role": message.role,
|
||||
"content": content,
|
||||
}
|
||||
|
|
|
@ -4,19 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import Tuple
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from llama_models.datatypes import is_multimodal, ModelFamily
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from PIL import Image as PIL_Image
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_models.datatypes import ModelFamily
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
RawContent,
|
||||
RawContentItem,
|
||||
RawMediaItem,
|
||||
RawTextItem,
|
||||
Role,
|
||||
ToolChoice,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.llama3.prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
|
@ -25,15 +32,89 @@ from llama_models.llama3.prompt_templates import (
|
|||
SystemDefaultGenerator,
|
||||
)
|
||||
from llama_models.sku_list import resolve_model
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack.apis.common.deployment_types import URL
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SystemMessage,
|
||||
TextContentItem,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def content_has_media(content: InterleavedTextMedia):
|
||||
def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
|
||||
def _process(c) -> str:
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
elif isinstance(c, ImageContentItem):
|
||||
return "<image>"
|
||||
elif isinstance(c, TextContentItem):
|
||||
return c.text
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(c)}")
|
||||
|
||||
if isinstance(content, list):
|
||||
return sep.join(_process(c) for c in content)
|
||||
else:
|
||||
return _process(content)
|
||||
|
||||
|
||||
async def interleaved_content_convert_to_raw(
|
||||
content: InterleavedContent,
|
||||
) -> RawContent:
|
||||
"""Download content from URLs / files etc. so plain bytes can be sent to the model"""
|
||||
|
||||
async def _localize_single(c: str | InterleavedContentItem) -> str | RawContentItem:
|
||||
if isinstance(c, str):
|
||||
return RawTextItem(text=c)
|
||||
elif isinstance(c, TextContentItem):
|
||||
return RawTextItem(text=c.text)
|
||||
elif isinstance(c, ImageContentItem):
|
||||
# load image and return PIL version
|
||||
img = c.data
|
||||
if isinstance(img, URL):
|
||||
if img.uri.startswith("data"):
|
||||
match = re.match(r"data:image/(\w+);base64,(.+)", img.uri)
|
||||
if not match:
|
||||
raise ValueError("Invalid data URL format")
|
||||
_, image_data = match.groups()
|
||||
data = base64.b64decode(image_data)
|
||||
elif img.uri.startswith("file://"):
|
||||
path = img.uri[len("file://") :]
|
||||
with open(path, "rb") as f:
|
||||
data = f.read() # type: ignore
|
||||
elif img.uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(img.uri)
|
||||
data = response.content
|
||||
else:
|
||||
raise ValueError("Unsupported URL type")
|
||||
return RawMediaItem(data=data)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(c)}")
|
||||
|
||||
if isinstance(content, list):
|
||||
return await asyncio.gather(*(_localize_single(c) for c in content))
|
||||
else:
|
||||
return await _localize_single(content)
|
||||
|
||||
|
||||
def content_has_media(content: InterleavedContent):
|
||||
def _has_media_content(c):
|
||||
return isinstance(c, ImageMedia)
|
||||
return isinstance(c, ImageContentItem)
|
||||
|
||||
if isinstance(content, list):
|
||||
return any(_has_media_content(c) for c in content)
|
||||
|
@ -52,37 +133,29 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
|
|||
return content_has_media(request.content)
|
||||
|
||||
|
||||
async def convert_image_media_to_url(
|
||||
media: ImageMedia, download: bool = False, include_format: bool = True
|
||||
) -> str:
|
||||
if isinstance(media.image, PIL_Image.Image):
|
||||
if media.image.format == "PNG":
|
||||
format = "png"
|
||||
elif media.image.format == "GIF":
|
||||
format = "gif"
|
||||
elif media.image.format == "JPEG":
|
||||
format = "jpeg"
|
||||
else:
|
||||
raise ValueError(f"Unsupported image format {media.image.format}")
|
||||
|
||||
bytestream = io.BytesIO()
|
||||
media.image.save(bytestream, format=media.image.format)
|
||||
bytestream.seek(0)
|
||||
content = bytestream.getvalue()
|
||||
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
|
||||
if isinstance(media.data, URL) and media.data.uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(media.image.uri)
|
||||
content = r.content
|
||||
content_type = r.headers.get("content-type")
|
||||
if content_type:
|
||||
format = content_type.split("/")[-1]
|
||||
else:
|
||||
format = "png"
|
||||
return content, format
|
||||
else:
|
||||
if not download:
|
||||
return media.image.uri
|
||||
else:
|
||||
assert isinstance(media.image, URL)
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(media.image.uri)
|
||||
content = r.content
|
||||
content_type = r.headers.get("content-type")
|
||||
if content_type:
|
||||
format = content_type.split("/")[-1]
|
||||
else:
|
||||
format = "png"
|
||||
image = PIL_Image.open(media.data)
|
||||
return media.data, image.format
|
||||
|
||||
|
||||
async def convert_image_content_to_url(
|
||||
media: ImageContentItem, download: bool = False, include_format: bool = True
|
||||
) -> str:
|
||||
if isinstance(media.data, URL) and not download:
|
||||
return media.image.uri
|
||||
|
||||
content, format = await localize_image_content(media)
|
||||
if include_format:
|
||||
return f"data:image/{format};base64," + base64.b64encode(content).decode(
|
||||
"utf-8"
|
||||
|
@ -91,32 +164,6 @@ async def convert_image_media_to_url(
|
|||
return base64.b64encode(content).decode("utf-8")
|
||||
|
||||
|
||||
# TODO: name this function better! this is about OpenAI compatibile image
|
||||
# media conversion of the message. this should probably go in openai_compat.py
|
||||
async def convert_message_to_dict(message: Message, download: bool = False) -> dict:
|
||||
async def _convert_content(content) -> dict:
|
||||
if isinstance(content, ImageMedia):
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": await convert_image_media_to_url(content, download=download),
|
||||
},
|
||||
}
|
||||
else:
|
||||
assert isinstance(content, str)
|
||||
return {"type": "text", "text": content}
|
||||
|
||||
if isinstance(message.content, list):
|
||||
content = [await _convert_content(c) for c in message.content]
|
||||
else:
|
||||
content = [await _convert_content(message.content)]
|
||||
|
||||
return {
|
||||
"role": message.role,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
|
||||
def completion_request_to_prompt(
|
||||
request: CompletionRequest, formatter: ChatFormat
|
||||
) -> str:
|
||||
|
@ -330,7 +377,7 @@ def augment_messages_for_tools_llama_3_2(
|
|||
sys_content += "\n"
|
||||
|
||||
if existing_system_message:
|
||||
sys_content += interleaved_text_media_as_str(
|
||||
sys_content += interleaved_content_as_str(
|
||||
existing_system_message.content, sep="\n"
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue