Mirror of https://github.com/meta-llama/llama-stack.git

Commit a9a041a1de (parent c2f7905fa4)
Rework InterleavedContentMedia datatype so URL downloading is in llama-stack

10 changed files with 368 additions and 146 deletions
llama_stack/apis/common/deployment_types.py:

```diff
@@ -7,13 +7,21 @@
 from enum import Enum
 from typing import Any, Dict, Optional
 
-from llama_models.llama3.api.datatypes import URL
 from llama_models.schema_utils import json_schema_type
 
 from pydantic import BaseModel
 
 
+@json_schema_type(
+    schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
+)
+class URL(BaseModel):
+    uri: str
+
+    def __str__(self) -> str:
+        return self.uri
+
+
 @json_schema_type
 class RestAPIMethod(Enum):
     GET = "GET"
```
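A minimal sketch of how the relocated URL type behaves (the `@json_schema_type` decorator is omitted here since it only annotates the generated JSON schema with the `^(https?://|file://|data:)` pattern; the runtime model just wraps a plain string). The URIs below are hypothetical:

```python
from pydantic import BaseModel


class URL(BaseModel):
    uri: str

    def __str__(self) -> str:
        return self.uri


remote = URL(uri="https://example.com/cat.png")  # hypothetical https URI
local = URL(uri="file:///tmp/cat.png")           # hypothetical file URI
print(str(remote))  # -> https://example.com/cat.png
```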
llama_stack/apis/inference/inference.py:

```diff
@@ -16,14 +16,23 @@ from typing import (
     Union,
 )
 
+from llama_models.llama3.api.datatypes import (
+    BuiltinTool,
+    SamplingParams,
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
 
-from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
-
-from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.common.deployment_types import URL
+
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.apis.models import *  # noqa: F403
@@ -40,17 +49,17 @@ class QuantizationType(Enum):
 
 @json_schema_type
 class Fp8QuantizationConfig(BaseModel):
-    type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
+    type: Literal["fp8"] = "fp8"
 
 
 @json_schema_type
 class Bf16QuantizationConfig(BaseModel):
-    type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
+    type: Literal["bf16"] = "bf16"
 
 
 @json_schema_type
 class Int4QuantizationConfig(BaseModel):
-    type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
+    type: Literal["int4"] = "int4"
     scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
@@ -60,6 +69,98 @@ QuantizationConfig = Annotated[
 ]
 
 
+@json_schema_type
+class ImageContentItem(BaseModel):
+    type: Literal["image"] = "image"
+    data: Union[bytes, URL]
+
+
+@json_schema_type
+class TextContentItem(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+# other modalities can be added here
+InterleavedContentItem = Annotated[
+    Union[ImageContentItem, TextContentItem],
+    Field(discriminator="type"),
+]
+
+# accept a single "str" as a special case since it is common
+InterleavedContent = str | InterleavedContentItem | List[InterleavedContentItem]
+
+
+@json_schema_type
+class UserMessage(BaseModel):
+    role: Literal["user"] = "user"
+    content: InterleavedContent
+    context: Optional[InterleavedContent] = None
+
+
+@json_schema_type
+class SystemMessage(BaseModel):
+    role: Literal["system"] = "system"
+    content: InterleavedContent
+
+
+@json_schema_type
+class ToolResponseMessage(BaseModel):
+    role: Literal["ipython"] = "ipython"
+    # it was nice to re-use the ToolResponse type, but having all messages
+    # have a `content` type makes things nicer too
+    call_id: str
+    tool_name: Union[BuiltinTool, str]
+    content: InterleavedContent
+
+
+@json_schema_type
+class CompletionMessage(BaseModel):
+    role: Literal["assistant"] = "assistant"
+    content: InterleavedContent
+    stop_reason: StopReason
+    tool_calls: List[ToolCall] = Field(default_factory=list)
+
+
+Message = Annotated[
+    Union[
+        UserMessage,
+        SystemMessage,
+        ToolResponseMessage,
+        CompletionMessage,
+    ],
+    Field(discriminator="role"),
+]
+
+
+@json_schema_type
+class ToolResponse(BaseModel):
+    call_id: str
+    tool_name: Union[BuiltinTool, str]
+    content: InterleavedContent
+
+    @field_validator("tool_name", mode="before")
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinTool(v)
+            except ValueError:
+                return v
+        return v
+
+
+@json_schema_type
+class ToolChoice(Enum):
+    auto = "auto"
+    required = "required"
+
+
+@json_schema_type
+class TokenLogProbs(BaseModel):
+    logprobs_by_token: Dict[str, float]
+
+
 @json_schema_type
 class ChatCompletionResponseEventType(Enum):
     start = "start"
@@ -117,7 +218,7 @@ ResponseFormat = Annotated[
 @json_schema_type
 class CompletionRequest(BaseModel):
     model: str
-    content: InterleavedTextMedia
+    content: InterleavedContent
     sampling_params: Optional[SamplingParams] = SamplingParams()
     response_format: Optional[ResponseFormat] = None
@@ -230,7 +331,7 @@ class Inference(Protocol):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -258,5 +359,5 @@ class Inference(Protocol):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse: ...
```
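A minimal, self-contained sketch of how the new content types compose, mirroring the definitions above (schema decorators omitted; the image URI is hypothetical). The `type` field drives pydantic's discriminated-union validation, and a bare string is still accepted as content:

```python
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class URL(BaseModel):
    uri: str


class ImageContentItem(BaseModel):
    type: Literal["image"] = "image"
    data: Union[bytes, URL]


class TextContentItem(BaseModel):
    type: Literal["text"] = "text"
    text: str


InterleavedContentItem = Annotated[
    Union[ImageContentItem, TextContentItem],
    Field(discriminator="type"),
]
InterleavedContent = str | InterleavedContentItem | List[InterleavedContentItem]


class UserMessage(BaseModel):
    role: Literal["user"] = "user"
    content: InterleavedContent
    context: Optional[InterleavedContent] = None


# a bare string still works ...
plain = UserMessage(content="hello")
# ... and so does a mixed text/image list; `type` discriminates each item
mixed = UserMessage(
    content=[
        TextContentItem(text="What is in this image?"),
        ImageContentItem(data=URL(uri="https://example.com/cat.png")),
    ]
)
print(mixed.model_dump()["content"][1]["data"])  # -> {'uri': 'https://example.com/cat.png'}
```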
Meta-reference inference provider, generation.py:

```diff
@@ -24,7 +24,8 @@ from fairscale.nn.model_parallel.initialize import (
     model_parallel_is_initialized,
 )
 from llama_models.llama3.api.args import ModelArgs
-from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
+from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
+from llama_models.llama3.api.datatypes import RawContent, RawMessage
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
@@ -38,10 +39,6 @@ from llama_stack.apis.inference import *  # noqa: F403
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    augment_content_with_response_format_prompt,
-    chat_completion_request_to_messages,
-)
 
 from .config import (
     Fp8QuantizationConfig,
@@ -53,6 +50,14 @@ from .config import (
 log = logging.getLogger(__name__)
 
 
+class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
+    messages: List[RawMessage]
+
+
+class CompletionRequestWithRawContent(CompletionRequest):
+    content: RawContent
+
+
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
@@ -206,7 +211,7 @@ class Llama:
     @torch.inference_mode()
     def generate(
         self,
-        model_input: ModelInput,
+        model_input: LLMInput,
         max_gen_len: int,
         temperature: float = 0.6,
         top_p: float = 0.9,
@@ -343,7 +348,7 @@ class Llama:
 
     def completion(
         self,
-        request: CompletionRequest,
+        request: CompletionRequestWithRawContent,
     ) -> Generator:
         sampling_params = request.sampling_params
         max_gen_len = sampling_params.max_tokens
@@ -354,10 +359,7 @@ class Llama:
         ):
             max_gen_len = self.model.params.max_seq_len - 1
 
-        content = augment_content_with_response_format_prompt(
-            request.response_format, request.content
-        )
-        model_input = self.formatter.encode_content(content)
+        model_input = self.formatter.encode_content(request.content)
         yield from self.generate(
             model_input=model_input,
             max_gen_len=max_gen_len,
@@ -374,10 +376,8 @@ class Llama:
 
     def chat_completion(
         self,
-        request: ChatCompletionRequest,
+        request: ChatCompletionRequestWithRawContent,
     ) -> Generator:
-        messages = chat_completion_request_to_messages(request, self.llama_model)
-
         sampling_params = request.sampling_params
         max_gen_len = sampling_params.max_tokens
         if (
@@ -389,7 +389,7 @@ class Llama:
 
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
-                messages,
+                request.messages,
                 request.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
```
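The two `...WithRawContent` subclasses rely on a pydantic pattern worth noting: a subclass may re-declare a field with a narrower annotation, so the generator only ever sees already-downloaded raw content. A tiny self-contained sketch of that pattern (names here are illustrative stand-ins, not the llama-stack classes):

```python
from pydantic import BaseModel


class RawText(BaseModel):
    text: str


class Request(BaseModel):
    content: str | RawText  # the public API accepts either


class RequestWithRawContent(Request):
    content: RawText  # narrowed: only raw content reaches the generator


r = RequestWithRawContent(content=RawText(text="hello"))
print(type(r.content).__name__)  # -> RawText
```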
Meta-reference inference provider, inference.py:

```diff
@@ -7,25 +7,59 @@
 import asyncio
 import logging
 
-from typing import AsyncGenerator, List
+from typing import AsyncGenerator, List, Optional, Union
 
+from llama_models.datatypes import Model
+
+from llama_models.llama3.api.datatypes import (
+    RawMessage,
+    SamplingParams,
+    StopReason,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_models.sku_list import resolve_model
 
-from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    Inference,
+    InterleavedContent,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    TokenLogProbs,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    ToolChoice,
+)
 
-from llama_stack.providers.utils.inference.model_registry import build_model_alias
-from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.models import ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    build_model_alias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    convert_image_media_to_url,
-    request_has_media,
+    augment_content_with_response_format_prompt,
+    chat_completion_request_to_messages,
+    interleaved_content_convert_to_raw,
 )
 from .config import MetaReferenceInferenceConfig
-from .generation import Llama
+from .generation import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+    Llama,
+)
 from .model_parallel import LlamaModelParallelGenerator
 
 log = logging.getLogger(__name__)
@@ -90,7 +124,7 @@ class MetaReferenceInferenceImpl(
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -99,6 +133,7 @@ class MetaReferenceInferenceImpl(
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 
+        content = augment_content_with_response_format_prompt(response_format, content)
         request = CompletionRequest(
             model=model_id,
             content=content,
@@ -108,7 +143,7 @@ class MetaReferenceInferenceImpl(
             logprobs=logprobs,
         )
         self.check_model(request)
-        request = await request_with_localized_media(request)
+        request = await convert_request_to_raw(request)
 
         if request.stream:
             return self._stream_completion(request)
@@ -233,7 +268,13 @@ class MetaReferenceInferenceImpl(
             logprobs=logprobs,
         )
         self.check_model(request)
-        request = await request_with_localized_media(request)
+
+        # augment and rewrite messages depending on the model
+        request.messages = chat_completion_request_to_messages(
+            request, self.model.core_model_id.value
+        )
+        # download media and convert to raw content so we can send it to the model
+        request = await convert_request_to_raw(request)
 
         if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
@@ -406,29 +447,16 @@ class MetaReferenceInferenceImpl(
                 yield x
 
 
-async def request_with_localized_media(
+async def convert_request_to_raw(
     request: Union[ChatCompletionRequest, CompletionRequest],
-) -> Union[ChatCompletionRequest, CompletionRequest]:
-    if not request_has_media(request):
-        return request
-
-    async def _convert_single_content(content):
-        if isinstance(content, ImageMedia):
-            url = await convert_image_media_to_url(content, download=True)
-            return ImageMedia(image=URL(uri=url))
-        else:
-            return content
-
-    async def _convert_content(content):
-        if isinstance(content, list):
-            return [await _convert_single_content(c) for c in content]
-        else:
-            return await _convert_single_content(content)
-
+) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
     if isinstance(request, ChatCompletionRequest):
+        messages = []
         for m in request.messages:
-            m.content = await _convert_content(m.content)
+            content = await interleaved_content_convert_to_raw(m.content)
+            d = m.model_dump()
+            d["content"] = content
+            messages.append(RawMessage(**d))
+        request.messages = messages
     else:
-        request.content = await _convert_content(request.content)
+        request.content = await interleaved_content_convert_to_raw(request.content)
 
     return request
```
Remote Fireworks adapter, fireworks.py:

```diff
@@ -19,6 +19,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    convert_message_to_openai_dict,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -29,7 +30,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
-    convert_message_to_dict,
+    interleaved_content_as_str,
     request_has_media,
 )
@@ -108,7 +109,7 @@ class FireworksInferenceAdapter(
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -238,7 +239,7 @@ class FireworksInferenceAdapter(
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 input_dict["messages"] = [
-                    await convert_message_to_dict(m) for m in request.messages
+                    await convert_message_to_openai_dict(m) for m in request.messages
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
@@ -265,7 +266,7 @@ class FireworksInferenceAdapter(
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
@@ -277,7 +278,7 @@ class FireworksInferenceAdapter(
         ), "Fireworks does not support media for embeddings"
         response = self._get_client().embeddings.create(
             model=model.provider_resource_id,
-            input=[interleaved_text_media_as_str(content) for content in contents],
+            input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
         )
```
Remote Ollama adapter, ollama.py:

```diff
@@ -37,7 +37,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
-    convert_image_media_to_url,
+    convert_image_content_to_url,
+    interleaved_content_as_str,
     request_has_media,
 )
@@ -141,7 +142,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -234,7 +235,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 contents = [
-                    await convert_message_to_dict_for_ollama(m)
+                    await convert_message_to_openai_dict_for_ollama(m)
                     for m in request.messages
                 ]
                 # flatten the list of lists
@@ -320,7 +321,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
@@ -329,7 +330,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         ), "Ollama does not support media for embeddings"
         response = await self.client.embed(
             model=model.provider_resource_id,
-            input=[interleaved_text_media_as_str(content) for content in contents],
+            input=[interleaved_content_as_str(content) for content in contents],
         )
         embeddings = response["embeddings"]
@@ -358,21 +359,23 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     return model
 
 
-async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
+async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
-        if isinstance(content, ImageMedia):
+        if isinstance(content, ImageContentItem):
             return {
                 "role": message.role,
                 "images": [
-                    await convert_image_media_to_url(
+                    await convert_image_content_to_url(
                         content, download=True, include_format=False
                     )
                 ],
             }
         else:
+            text = content.text if isinstance(content, TextContentItem) else content
+            assert isinstance(text, str)
             return {
                 "role": message.role,
-                "content": content,
+                "content": text,
             }
 
     if isinstance(message.content, list):
```
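For reference, the per-part dicts the renamed Ollama converter emits, sketched with stand-in values: with `include_format=False` the image becomes a bare base64 string (no `data:image/...;base64,` prefix), while text parts stay as plain `content`:

```python
# stand-in values, not captured output
image_part = {"role": "user", "images": ["iVBORw0KGgo..."]}  # bare base64 string
text_part = {"role": "user", "content": "describe this image"}
```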
Remote Together adapter, together.py:

```diff
@@ -22,6 +22,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    convert_message_to_openai_dict,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -32,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
-    convert_message_to_dict,
+    interleaved_content_as_str,
     request_has_media,
 )
@@ -92,7 +93,7 @@ class TogetherInferenceAdapter(
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -230,7 +231,7 @@ class TogetherInferenceAdapter(
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 input_dict["messages"] = [
-                    await convert_message_to_dict(m) for m in request.messages
+                    await convert_message_to_openai_dict(m) for m in request.messages
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
@@ -252,7 +253,7 @@ class TogetherInferenceAdapter(
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         assert all(
@@ -260,7 +261,7 @@ class TogetherInferenceAdapter(
         ), "Together does not support media for embeddings"
         r = self._get_client().embeddings.create(
             model=model.provider_resource_id,
-            input=[interleaved_text_media_as_str(content) for content in contents],
+            input=[interleaved_content_as_str(content) for content in contents],
         )
         embeddings = [item.embedding for item in r.data]
         return EmbeddingsResponse(embeddings=embeddings)
```
Remote vLLM adapter, vllm.py:

```diff
@@ -22,6 +22,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    convert_message_to_openai_dict,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -30,7 +31,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
-    convert_message_to_dict,
+    interleaved_content_as_str,
     request_has_media,
 )
@@ -71,7 +72,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -163,7 +164,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if media_present:
             # vllm does not seem to work well with image urls, so we download the images
             input_dict["messages"] = [
-                await convert_message_to_dict(m, download=True)
+                await convert_message_to_openai_dict(m, download=True)
                 for m in request.messages
             ]
         else:
@@ -202,7 +203,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
@@ -215,7 +216,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         ), "VLLM does not support media for embeddings"
         response = self.client.embeddings.create(
             model=model.provider_resource_id,
-            input=[interleaved_text_media_as_str(content) for content in contents],
+            input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
         )
```
llama_stack/providers/utils/inference/openai_compat.py:

```diff
@@ -11,9 +11,12 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 
 from llama_stack.apis.inference import *  # noqa: F403
 
 from pydantic import BaseModel
 
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_content_to_url,
+)
+
 
 class OpenAICompatCompletionChoiceDelta(BaseModel):
     content: str
@@ -246,3 +249,32 @@ async def process_chat_completion_stream_response(
                 stop_reason=stop_reason,
             )
         )
+
+
+async def convert_message_to_openai_dict(
+    message: Message, download: bool = False
+) -> dict:
+    async def _convert_content(content) -> dict:
+        if isinstance(content, ImageContentItem):
+            return {
+                "type": "image_url",
+                "image_url": {
+                    "url": await convert_image_content_to_url(
+                        content, download=download
+                    ),
+                },
+            }
+        else:
+            text = content.text if isinstance(content, TextContentItem) else content
+            assert isinstance(text, str)
+            return {"type": "text", "text": text}
+
+    if isinstance(message.content, list):
+        content = [await _convert_content(c) for c in message.content]
+    else:
+        content = [await _convert_content(message.content)]
+
+    return {
+        "role": message.role,
+        "content": content,
+    }
```
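The relocated converter produces the OpenAI-style typed-parts layout; sketched with stand-in values, a mixed text/image message becomes:

```python
# stand-in values, not captured output; the base64 payload is elided
openai_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    ],
}
```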
llama_stack/providers/utils/inference/prompt_adapter.py:

```diff
@@ -4,19 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import base64
-import io
 import json
 import logging
-from typing import Tuple
+import re
+from typing import List, Optional, Tuple, Union
 
 import httpx
+
+from llama_models.datatypes import is_multimodal, ModelFamily
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from PIL import Image as PIL_Image
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
-from llama_models.datatypes import ModelFamily
+from llama_models.llama3.api.datatypes import (
+    RawContent,
+    RawContentItem,
+    RawMediaItem,
+    RawTextItem,
+    Role,
+    ToolChoice,
+    ToolPromptFormat,
+)
 from llama_models.llama3.prompt_templates import (
     BuiltinToolGenerator,
     FunctionTagCustomToolGenerator,
@@ -25,15 +32,89 @@ from llama_models.llama3.prompt_templates import (
     SystemDefaultGenerator,
 )
 from llama_models.sku_list import resolve_model
+from PIL import Image as PIL_Image
+
+from llama_stack.apis.common.deployment_types import URL
+
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    ImageContentItem,
+    InterleavedContent,
+    InterleavedContentItem,
+    Message,
+    ResponseFormat,
+    ResponseFormatType,
+    SystemMessage,
+    TextContentItem,
+    UserMessage,
+)
+
 from llama_stack.providers.utils.inference import supported_inference_models
 
 log = logging.getLogger(__name__)
 
 
-def content_has_media(content: InterleavedTextMedia):
+def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
+    def _process(c) -> str:
+        if isinstance(c, str):
+            return c
+        elif isinstance(c, ImageContentItem):
+            return "<image>"
+        elif isinstance(c, TextContentItem):
+            return c.text
+        else:
+            raise ValueError(f"Unsupported content type: {type(c)}")
+
+    if isinstance(content, list):
+        return sep.join(_process(c) for c in content)
+    else:
+        return _process(content)
+
+
+async def interleaved_content_convert_to_raw(
+    content: InterleavedContent,
+) -> RawContent:
+    """Download content from URLs / files etc. so plain bytes can be sent to the model"""
+
+    async def _localize_single(c: str | InterleavedContentItem) -> str | RawContentItem:
+        if isinstance(c, str):
+            return RawTextItem(text=c)
+        elif isinstance(c, TextContentItem):
+            return RawTextItem(text=c.text)
+        elif isinstance(c, ImageContentItem):
+            # load image and return PIL version
+            img = c.data
+            if isinstance(img, URL):
+                if img.uri.startswith("data"):
+                    match = re.match(r"data:image/(\w+);base64,(.+)", img.uri)
+                    if not match:
+                        raise ValueError("Invalid data URL format")
+                    _, image_data = match.groups()
+                    data = base64.b64decode(image_data)
+                elif img.uri.startswith("file://"):
+                    path = img.uri[len("file://") :]
+                    with open(path, "rb") as f:
+                        data = f.read()  # type: ignore
+                elif img.uri.startswith("http"):
+                    async with httpx.AsyncClient() as client:
+                        response = await client.get(img.uri)
+                        data = response.content
+                else:
+                    raise ValueError("Unsupported URL type")
+            else:
+                data = img
+            return RawMediaItem(data=data)
+        else:
+            raise ValueError(f"Unsupported content type: {type(c)}")
+
+    if isinstance(content, list):
+        return await asyncio.gather(*(_localize_single(c) for c in content))
+    else:
+        return await _localize_single(content)
+
+
+def content_has_media(content: InterleavedContent):
     def _has_media_content(c):
-        return isinstance(c, ImageMedia)
+        return isinstance(c, ImageContentItem)
 
     if isinstance(content, list):
         return any(_has_media_content(c) for c in content)
@@ -52,37 +133,29 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
     return content_has_media(request.content)
 
 
-async def convert_image_media_to_url(
-    media: ImageMedia, download: bool = False, include_format: bool = True
-) -> str:
-    if isinstance(media.image, PIL_Image.Image):
-        if media.image.format == "PNG":
-            format = "png"
-        elif media.image.format == "GIF":
-            format = "gif"
-        elif media.image.format == "JPEG":
-            format = "jpeg"
-        else:
-            raise ValueError(f"Unsupported image format {media.image.format}")
-
-        bytestream = io.BytesIO()
-        media.image.save(bytestream, format=media.image.format)
-        bytestream.seek(0)
-        content = bytestream.getvalue()
-    else:
-        if not download:
-            return media.image.uri
-        else:
-            assert isinstance(media.image, URL)
-            async with httpx.AsyncClient() as client:
-                r = await client.get(media.image.uri)
-                content = r.content
-                content_type = r.headers.get("content-type")
-                if content_type:
-                    format = content_type.split("/")[-1]
-                else:
-                    format = "png"
-
+async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
+    if isinstance(media.data, URL) and media.data.uri.startswith("http"):
+        async with httpx.AsyncClient() as client:
+            r = await client.get(media.data.uri)
+            content = r.content
+            content_type = r.headers.get("content-type")
+            if content_type:
+                format = content_type.split("/")[-1]
+            else:
+                format = "png"
+            return content, format
+    else:
+        image = PIL_Image.open(media.data)
+        return media.data, image.format
+
+
+async def convert_image_content_to_url(
+    media: ImageContentItem, download: bool = False, include_format: bool = True
+) -> str:
+    if isinstance(media.data, URL) and not download:
+        return media.data.uri
+
+    content, format = await localize_image_content(media)
     if include_format:
         return f"data:image/{format};base64," + base64.b64encode(content).decode(
             "utf-8"
@@ -91,32 +164,6 @@ async def convert_image_media_to_url(
         return base64.b64encode(content).decode("utf-8")
 
 
-# TODO: name this function better! this is about OpenAI compatibile image
-# media conversion of the message. this should probably go in openai_compat.py
-async def convert_message_to_dict(message: Message, download: bool = False) -> dict:
-    async def _convert_content(content) -> dict:
-        if isinstance(content, ImageMedia):
-            return {
-                "type": "image_url",
-                "image_url": {
-                    "url": await convert_image_media_to_url(content, download=download),
-                },
-            }
-        else:
-            assert isinstance(content, str)
-            return {"type": "text", "text": content}
-
-    if isinstance(message.content, list):
-        content = [await _convert_content(c) for c in message.content]
-    else:
-        content = [await _convert_content(message.content)]
-
-    return {
-        "role": message.role,
-        "content": content,
-    }
-
-
 def completion_request_to_prompt(
     request: CompletionRequest, formatter: ChatFormat
 ) -> str:
@@ -330,7 +377,7 @@ def augment_messages_for_tools_llama_3_2(
         sys_content += "\n"
 
         if existing_system_message:
-            sys_content += interleaved_text_media_as_str(
+            sys_content += interleaved_content_as_str(
                 existing_system_message.content, sep="\n"
            )
```
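A self-contained round-trip check of the data-URL branch in `interleaved_content_convert_to_raw`, using the same regex the diff introduces (the payload bytes are a stand-in, not a real image):

```python
import base64
import re

payload = b"\x89PNG\r\n\x1a\n"  # stand-in image bytes
uri = "data:image/png;base64," + base64.b64encode(payload).decode("utf-8")

# same pattern as in interleaved_content_convert_to_raw
match = re.match(r"data:image/(\w+);base64,(.+)", uri)
assert match, "Invalid data URL format"
fmt, image_data = match.groups()
assert fmt == "png"
assert base64.b64decode(image_data) == payload  # bytes survive the round trip
```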