Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
# What does this PR do?

Delete ~2,000 lines of dead code from the old bespoke inference API that was replaced by the OpenAI-only API. This includes removing unused type conversion functions, dead provider methods, and event_logger.py. Clean up imports across the codebase to remove references to deleted types.

This eliminates unnecessary code and dependencies, helping isolate the API package as a self-contained module. This is the last interdependency between the .api package and "exterior" packages, meaning that now every other package in llama stack imports the API, not the other way around.

## Test Plan

This is a structural change; no tests are needed.

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
282 lines · 10 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import base64
import io
import json
import re
from typing import Any

import httpx
from PIL import Image as PIL_Image

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import (
    CompletionRequest,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIFile,
    OpenAIMessageParam,
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
    ResponseFormat,
    ResponseFormatType,
    ToolChoice,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
    RawContent,
    RawContentItem,
    RawMediaItem,
    RawMessage,
    RawTextItem,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal

log = get_logger(name=__name__, category="providers::utils")


class CompletionRequestWithRawContent(CompletionRequest):
    content: RawContent


def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
    formatter = ChatFormat(Tokenizer.get_instance())
    return formatter.decode_assistant_message_from_content(content, stop_reason)


def interleaved_content_as_str(
    content: Any,
    sep: str = " ",
) -> str:
    if content is None:
        return ""

    def _process(c) -> str:
        if isinstance(c, str):
            return c
        elif isinstance(c, TextContentItem) or isinstance(c, OpenAIChatCompletionContentPartTextParam):
            return c.text
        elif isinstance(c, ImageContentItem) or isinstance(c, OpenAIChatCompletionContentPartImageParam):
            return "<image>"
        elif isinstance(c, OpenAIFile):
            return "<file>"
        else:
            raise ValueError(f"Unsupported content type: {type(c)}")

    if isinstance(content, list):
        return sep.join(_process(c) for c in content)
    else:
        return _process(content)


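# Illustrative usage sketch (not part of the upstream module): strings and text items
# are concatenated with `sep`, while images and files collapse to placeholders.
#
#   interleaved_content_as_str(
#       ["Describe the attached image.", TextContentItem(text="Answer briefly.")],
#       sep="\n",
#   )
#   # -> "Describe the attached image.\nAnswer briefly."

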
async def interleaved_content_convert_to_raw(
    content: InterleavedContent,
) -> RawContent:
    """Download content from URLs / files etc. so plain bytes can be sent to the model"""

    async def _localize_single(c: str | InterleavedContentItem) -> str | RawContentItem:
        if isinstance(c, str):
            return RawTextItem(text=c)
        elif isinstance(c, TextContentItem):
            return RawTextItem(text=c.text)
        elif isinstance(c, ImageContentItem):
            image = c.image
            if image.url:
                # Load image bytes from URL
                if image.url.uri.startswith("data"):
                    match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                    if not match:
                        raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
                    _, image_data = match.groups()
                    data = base64.b64decode(image_data)
                elif image.url.uri.startswith("file://"):
                    path = image.url.uri[len("file://") :]
                    with open(path, "rb") as f:
                        data = f.read()  # type: ignore
                elif image.url.uri.startswith("http"):
                    async with httpx.AsyncClient() as client:
                        response = await client.get(image.url.uri)
                        data = response.content
                else:
                    raise ValueError("Unsupported URL type")
            elif image.data:
                # data is a base64 encoded string, decode it to bytes for RawMediaItem
                data = base64.b64decode(image.data)
            else:
                raise ValueError("No data or URL provided")

            return RawMediaItem(data=data)
        else:
            raise ValueError(f"Unsupported content type: {type(c)}")

    if isinstance(content, list):
        return await asyncio.gather(*(_localize_single(c) for c in content))
    else:
        return await _localize_single(content)


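# Illustrative usage sketch (hypothetical values; ImageContentItem/URL construction is
# assumed to coerce from dicts as elsewhere in llama-stack): text passes through as
# RawTextItem, while the image URL is fetched so the model receives raw bytes.
#
#   raw = await interleaved_content_convert_to_raw(
#       [
#           TextContentItem(text="What is in this picture?"),
#           ImageContentItem(image={"url": {"uri": "https://example.com/cat.png"}}),
#       ]
#   )
#   # -> [RawTextItem(text="What is in this picture?"), RawMediaItem(data=b"...")]

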
async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
    """Convert OpenAI message format to RawMessage format used by Llama formatters."""
    if isinstance(message, OpenAIUserMessageParam):
        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
        return RawMessage(role="user", content=content)
    elif isinstance(message, OpenAISystemMessageParam):
        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
        return RawMessage(role="system", content=content)
    elif isinstance(message, OpenAIAssistantMessageParam):
        content = await interleaved_content_convert_to_raw(message.content or "")  # type: ignore[arg-type]
        tool_calls = []
        if message.tool_calls:
            for tc in message.tool_calls:
                if tc.function:
                    tool_calls.append(
                        ToolCall(
                            call_id=tc.id or "",
                            tool_name=tc.function.name or "",
                            arguments=tc.function.arguments or "{}",
                        )
                    )
        return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
    elif isinstance(message, OpenAIToolMessageParam):
        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
        return RawMessage(role="tool", content=content)
    else:
        # Handle OpenAIDeveloperMessageParam if needed
        raise ValueError(f"Unsupported message type: {type(message)}")


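# Illustrative usage sketch (hypothetical message; field defaults assumed from the
# OpenAI-compatible message types imported above): an OpenAI-style user message is
# mapped onto the RawMessage representation the Llama chat formatters expect.
#
#   raw = await convert_openai_message_to_raw_message(
#       OpenAIUserMessageParam(content="Hello!")
#   )
#   # -> RawMessage(role="user", content=RawTextItem(text="Hello!"))

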
def content_has_media(content: InterleavedContent):
    def _has_media_content(c):
        return isinstance(c, ImageContentItem)

    if isinstance(content, list):
        return any(_has_media_content(c) for c in content)
    else:
        return _has_media_content(content)


async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
    if uri.startswith("http"):
        async with httpx.AsyncClient() as client:
            r = await client.get(uri)
            content = r.content
            content_type = r.headers.get("content-type")
            if content_type:
                format = content_type.split("/")[-1]
            else:
                format = "png"

        return content, format
    elif uri.startswith("data"):
        # data:image/{format};base64,{data}
        match = re.match(r"data:image/(\w+);base64,(.+)", uri)
        if not match:
            raise ValueError(f"Invalid data URL format, {uri[:40]}...")
        fmt, image_data = match.groups()
        content = base64.b64decode(image_data)
        return content, fmt
    else:
        return None


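# Illustrative usage sketch (hypothetical, truncated data URL): http(s) URLs and
# base64 data URLs both resolve to a (bytes, format) pair; any other scheme yields None.
#
#   await localize_image_content("data:image/png;base64,iVBORw0KGgo...")
#   # -> (b"\x89PNG...", "png")
#   await localize_image_content("ftp://example.com/cat.png")
#   # -> None

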
async def convert_image_content_to_url(
    media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str:
    image = media.image
    if image.url and (not download or image.url.uri.startswith("data")):
        return image.url.uri

    if image.data:
        # data is a base64 encoded string, decode it to bytes first
        # TODO(mf): do this more efficiently, decode less
        content = base64.b64decode(image.data)
        pil_image = PIL_Image.open(io.BytesIO(content))
        format = pil_image.format
    else:
        localize_result = await localize_image_content(image.url.uri)
        if localize_result is None:
            raise ValueError(f"Failed to localize image content from {image.url.uri}")
        content, format = localize_result

    if include_format:
        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
    else:
        return base64.b64encode(content).decode("utf-8")


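# Illustrative usage sketch (hypothetical item): an already-inlined data URL is returned
# untouched, while a remote URL is downloaded and re-encoded as a base64 data URL when
# download=True.
#
#   item = ImageContentItem(image={"url": {"uri": "https://example.com/cat.png"}})
#   await convert_image_content_to_url(item, download=True)
#   # -> "data:image/png;base64,...."

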
def augment_content_with_response_format_prompt(response_format, content):
    if fmt_prompt := response_format_prompt(response_format):
        if isinstance(content, list):
            return content + [TextContentItem(text=fmt_prompt)]
        elif isinstance(content, str):
            return [TextContentItem(text=content), TextContentItem(text=fmt_prompt)]
        else:
            return [content, TextContentItem(text=fmt_prompt)]

    return content


def response_format_prompt(fmt: ResponseFormat | None):
    if not fmt:
        return None

    if fmt.type == ResponseFormatType.json_schema.value:
        return f"Please respond in JSON format with the schema: {json.dumps(fmt.json_schema)}"
    elif fmt.type == ResponseFormatType.grammar.value:
        raise NotImplementedError("Grammar response format not supported yet")
    else:
        raise ValueError(f"Unknown response format {fmt.type}")


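# Illustrative usage sketch (the JsonSchemaResponseFormat name and its construction are
# assumptions, not shown in this file): a JSON-schema response format becomes an extra
# instruction appended to the user content.
#
#   fmt = JsonSchemaResponseFormat(json_schema={"type": "object"})
#   augment_content_with_response_format_prompt(fmt, "List three colors.")
#   # -> [TextContentItem(text="List three colors."),
#   #     TextContentItem(text='Please respond in JSON format with the schema: {"type": "object"}')]

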
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
    if tool_choice == ToolChoice.auto:
        return ""
    elif tool_choice == ToolChoice.required:
        return "You MUST use one of the provided functions/tools to answer the user query."
    elif tool_choice == ToolChoice.none:
        # tools are already not passed in
        return ""
    else:
        # specific tool
        return f"You MUST use the tool `{tool_choice}` to answer the user query."


def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
    llama_model = resolve_model(model)
    if llama_model is None:
        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
        return ToolPromptFormat.json

    if llama_model.model_family == ModelFamily.llama3_1 or (
        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
    ):
        # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
        return ToolPromptFormat.json
    elif llama_model.model_family in (
        ModelFamily.llama3_2,
        ModelFamily.llama3_3,
        ModelFamily.llama4,
    ):
        # llama3.2, llama3.3, and llama4 models follow the same tool prompt format
        return ToolPromptFormat.python_list
    else:
        return ToolPromptFormat.json
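# Illustrative usage sketch (model identifiers are assumed to be resolvable by
# resolve_model; exact descriptors may differ):
#
#   get_default_tool_prompt_format("meta-llama/Llama-3.1-8B-Instruct")
#   # -> ToolPromptFormat.json
#   get_default_tool_prompt_format("meta-llama/Llama-3.3-70B-Instruct")
#   # -> ToolPromptFormat.python_list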