refactor: remove dead inference API code and clean up imports (#4093)

# What does this PR do?

Delete ~2,000 lines of dead code from the old bespoke inference API that
was replaced by OpenAI-only API. This includes removing unused type
conversion functions, dead provider methods, and event_logger.py.

Clean up imports across the codebase to remove references to deleted
types. This eliminates unnecessary
code and dependencies, helping isolate the API package as a
self-contained module.

This is the last interdependency between the .api package and "exterior"
packages, meaning that now every other package in llama stack imports
the API, not the other way around.

## Test Plan

this is a structural change, no tests needed.

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-11-10 18:29:24 -05:00 committed by GitHub
parent 433438cfc0
commit 43adc23ef6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 593 additions and 2141 deletions

View file

@ -21,19 +21,18 @@ from llama_stack.apis.common.content_types import (
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIFile,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
ResponseFormat,
ResponseFormatType,
SystemMessage,
SystemMessageBehavior,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
@ -42,33 +41,19 @@ from llama_stack.models.llama.datatypes import (
RawMediaItem,
RawMessage,
RawTextItem,
Role,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
PythonListCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="providers::utils")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
messages: list[RawMessage]
class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent
@ -103,28 +88,6 @@ def interleaved_content_as_str(
return _process(content)
async def convert_request_to_raw(
request: ChatCompletionRequest | CompletionRequest,
) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
if isinstance(request, ChatCompletionRequest):
messages = []
for m in request.messages:
content = await interleaved_content_convert_to_raw(m.content)
d = m.model_dump()
d["content"] = content
messages.append(RawMessage(**d))
d = request.model_dump()
d["messages"] = messages
request = ChatCompletionRequestWithRawContent(**d)
else:
d = request.model_dump()
d["content"] = await interleaved_content_convert_to_raw(request.content)
request = CompletionRequestWithRawContent(**d)
return request
async def interleaved_content_convert_to_raw(
content: InterleavedContent,
) -> RawContent:
@ -171,6 +134,36 @@ async def interleaved_content_convert_to_raw(
return await _localize_single(content)
async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
"""Convert OpenAI message format to RawMessage format used by Llama formatters."""
if isinstance(message, OpenAIUserMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="user", content=content)
elif isinstance(message, OpenAISystemMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="system", content=content)
elif isinstance(message, OpenAIAssistantMessageParam):
content = await interleaved_content_convert_to_raw(message.content or "") # type: ignore[arg-type]
tool_calls = []
if message.tool_calls:
for tc in message.tool_calls:
if tc.function:
tool_calls.append(
ToolCall(
call_id=tc.id or "",
tool_name=tc.function.name or "",
arguments=tc.function.arguments or "{}",
)
)
return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
elif isinstance(message, OpenAIToolMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="tool", content=content)
else:
# Handle OpenAIDeveloperMessageParam if needed
raise ValueError(f"Unsupported message type: {type(message)}")
def content_has_media(content: InterleavedContent):
def _has_media_content(c):
return isinstance(c, ImageContentItem)
@ -181,17 +174,6 @@ def content_has_media(content: InterleavedContent):
return _has_media_content(content)
def messages_have_media(messages: list[Message]):
return any(content_has_media(m.content) for m in messages)
def request_has_media(request: ChatCompletionRequest | CompletionRequest):
if isinstance(request, ChatCompletionRequest):
return messages_have_media(request.messages)
else:
return content_has_media(request.content)
async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
if uri.startswith("http"):
async with httpx.AsyncClient() as client:
@ -253,79 +235,6 @@ def augment_content_with_response_format_prompt(response_format, content):
return content
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
)
return formatter.tokenizer.decode(model_input.tokens)
async def chat_completion_request_to_model_input_info(
request: ChatCompletionRequest, llama_model: str
) -> tuple[str, int]:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
)
return (
formatter.tokenizer.decode(model_input.tokens),
len(model_input.tokens),
)
def chat_completion_request_to_messages(
request: ChatCompletionRequest,
llama_model: str,
) -> list[Message]:
"""Reads chat completion request and augments the messages to handle tools.
For eg. for llama_3_1, add system message with the appropriate tools or
add user messsage for custom tools, etc.
"""
assert llama_model is not None, "llama_model is required"
model = resolve_model(llama_model)
if model is None:
log.error(f"Could not resolve model {llama_model}")
return request.messages
allowed_models = supported_inference_models()
descriptors = [m.descriptor() for m in allowed_models]
if model.descriptor() not in descriptors:
log.error(f"Unsupported inference model? {model.descriptor()}")
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_1(request)
elif model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
):
# llama3.2, llama3.3 follow the same tool prompt format
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
elif model.model_family == ModelFamily.llama4:
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
else:
messages = request.messages
if fmt_prompt := response_format_prompt(request.response_format):
messages.append(UserMessage(content=fmt_prompt))
return messages
def response_format_prompt(fmt: ResponseFormat | None):
if not fmt:
return None
@ -338,128 +247,6 @@ def response_format_prompt(fmt: ResponseFormat | None):
raise ValueError(f"Unknown response format {fmt.type}")
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
messages.append(SystemMessage(content=sys_content))
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif fmt == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages
def augment_messages_for_tools_llama(
request: ChatCompletionRequest,
custom_tool_prompt_generator,
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
sys_content = ""
custom_tools, builtin_tools = [], []
for t in request.tools:
if isinstance(t.tool_name, str):
custom_tools.append(t)
else:
builtin_tools.append(t)
if builtin_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools)
sys_content += tool_template.render()
sys_content += "\n"
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
system_prompt = None
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"
if existing_system_message and (
request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
):
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
return messages
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
if tool_choice == ToolChoice.auto:
return ""