refactor: remove dead inference API code and clean up imports (#4093)

# What does this PR do?

Delete ~2,000 lines of dead code from the old bespoke inference API that
was replaced by the OpenAI-only API. This includes removing unused type
conversion functions, dead provider methods, and event_logger.py.

Clean up imports across the codebase to remove references to deleted
types. This eliminates unnecessary
code and dependencies, helping isolate the API package as a
self-contained module.

This removes the last interdependency between the .api package and
"exterior" packages: now every other package in Llama Stack imports
the API, not the other way around.
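
For reference, a minimal sketch of the call shape against the OpenAI-only surface after this change, modeled on the warm-up request in the meta-reference provider in the diff below (the `impl` handle and `warmup` helper are illustrative placeholders, not part of this PR):

```python
# Minimal sketch, assuming llama_stack is installed and `impl` is an already
# constructed inference provider exposing the OpenAI-only surface.
from llama_stack.apis.inference import (
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIUserMessageParam,
)


async def warmup(impl, model_id: str) -> None:
    # All chat inference now flows through a single OpenAI-style request object;
    # the old bespoke ChatCompletionRequest/CompletionRequest path is gone.
    params = OpenAIChatCompletionRequestWithExtraBody(
        model=model_id,
        messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
        max_tokens=20,
    )
    # stream is not set, so this returns a full OpenAIChatCompletion
    response = await impl.openai_chat_completion(params)
    print(response.choices[0].message.content)
```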

## Test Plan

This is a structural change; no tests are needed.
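
That said, the import-direction claim above can be spot-checked mechanically with something like the sketch below; the `llama_stack/apis` path and the allowed-prefix list are assumptions about the package layout, so adjust them as needed:

```python
# Rough sketch: flag llama_stack imports inside the apis package that reach
# back into other llama_stack subpackages. Assumes a local checkout with the
# APIs under llama_stack/apis; ALLOWED_PREFIXES is an assumption to tweak.
import ast
import pathlib

APIS_DIR = pathlib.Path("llama_stack/apis")
ALLOWED_PREFIXES = ("llama_stack.apis",)


def exterior_imports(py_file: pathlib.Path) -> list[str]:
    """Return llama_stack imports in py_file that point outside the apis package."""
    tree = ast.parse(py_file.read_text())
    names: list[str] = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            names.extend(alias.name for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module:
            names.append(node.module)
    return [n for n in names if n.startswith("llama_stack") and not n.startswith(ALLOWED_PREFIXES)]


for path in sorted(APIS_DIR.rglob("*.py")):
    for name in exterior_imports(path):
        print(f"{path}: imports {name}")
```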

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Charlie Doern 2025-11-10 18:29:24 -05:00 committed by GitHub
parent 433438cfc0
commit 43adc23ef6
22 changed files with 593 additions and 2141 deletions


@@ -5,7 +5,6 @@
# the root directory of this source tree.
import math
from collections.abc import Generator
from typing import Optional
import torch
@@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.apis.inference import (
GreedySamplingStrategy,
JsonSchemaResponseFormat,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIResponseFormatJSONSchema,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TopPSamplingStrategy,
)
from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
get_default_tool_prompt_format,
)
from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig
@@ -106,14 +103,6 @@ def _infer_sampling_params(sampling_params: SamplingParams):
return temperature, top_p
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
tool_config = request.tool_config
if tool_config is not None and tool_config.tool_prompt_format is not None:
return tool_config.tool_prompt_format
else:
return get_default_tool_prompt_format(request.model)
class LlamaGenerator:
def __init__(
self,
@@ -157,55 +146,56 @@ class LlamaGenerator:
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
first_request.response_format,
),
)
def chat_completion(
self,
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
request: OpenAIChatCompletionRequestWithExtraBody,
raw_messages: list,
):
"""Generate chat completion using OpenAI request format.
Args:
request: OpenAI chat completion request
raw_messages: Pre-converted list of RawMessage objects
"""
# Determine tool prompt format
tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
# Prepare sampling params
sampling_params = SamplingParams()
if request.temperature is not None or request.top_p is not None:
sampling_params.strategy = TopPSamplingStrategy(
temperature=request.temperature if request.temperature is not None else 1.0,
top_p=request.top_p if request.top_p is not None else 1.0,
)
if request.max_tokens:
sampling_params.max_tokens = request.max_tokens
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
# Get logits processor for response format
logits_processor = None
if request.response_format:
if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
# Extract the actual schema from OpenAIJSONSchema TypedDict
schema_dict = request.response_format.json_schema.get("schema") or {}
json_schema_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema,
json_schema=schema_dict,
)
logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
# Generate
yield from self.inner_generator.generate(
llm_inputs=[
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
],
llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(first_request.logprobs),
logprobs=False,
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
first_request.response_format,
),
logits_processor=logits_processor,
)


@@ -5,12 +5,19 @@
# the root directory of this source tree.
import asyncio
import time
import uuid
from collections.abc import AsyncIterator
from llama_stack.apis.inference import (
InferenceProvider,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionUsage,
OpenAIChoice,
OpenAICompletionRequestWithExtraBody,
OpenAIUserMessageParam,
ToolChoice,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
@@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
JsonCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference")
SEMAPHORE = asyncio.Semaphore(1)
def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
"""Convert OpenAI tool format to ToolDefinition format."""
# OpenAI tools have function.name and function.parameters
return ToolDefinition(
tool_name=tool.function.name,
description=tool.function.description or "",
parameters=tool.function.parameters or {},
)
def _get_tool_choice_prompt(tool_choice, tools) -> str:
"""Generate prompt text for tool_choice behavior."""
if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
return ""
elif tool_choice == ToolChoice.required or tool_choice == "required":
return "You MUST use one of the provided functions/tools to answer the user query."
elif tool_choice == ToolChoice.none or tool_choice == "none":
return ""
else:
# Specific tool specified
return f"You MUST use the tool `{tool_choice}` to answer the user query."
def _raw_content_as_str(content) -> str:
"""Convert RawContent to string for system messages."""
if isinstance(content, str):
return content
elif isinstance(content, RawTextItem):
return content.text
elif isinstance(content, list):
return "\n".join(_raw_content_as_str(c) for c in content)
else:
return "<media>"
def _augment_raw_messages_for_tools_llama_3_1(
raw_messages: list[RawMessage],
tools: list,
tool_choice,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions for Llama 3.1 style models."""
messages = raw_messages.copy()
existing_system_message = None
if messages and messages[0].role == "system":
existing_system_message = messages.pop(0)
sys_content = ""
# Add tool definitions first (if present)
if tools:
# Convert OpenAI tools to ToolDefinitions
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
# For OpenAI format, all tools are custom (have string names)
tool_gen = JsonCustomToolGenerator()
tool_template = tool_gen.gen(tool_definitions)
sys_content += tool_template.render()
sys_content += "\n"
# Add default system prompt
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content += default_template.render()
# Add existing system message if present
if existing_system_message:
sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
# Add tool choice prompt if needed
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
sys_content += "\n" + tool_choice_prompt
# Create new system message
new_system_message = RawMessage(
role="system",
content=[RawTextItem(text=sys_content.strip())],
)
return [new_system_message] + messages
def _augment_raw_messages_for_tools_llama_4(
raw_messages: list[RawMessage],
tools: list,
tool_choice,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
messages = raw_messages.copy()
existing_system_message = None
if messages and messages[0].role == "system":
existing_system_message = messages.pop(0)
sys_content = ""
# Add tool definitions if present
if tools:
# Convert OpenAI tools to ToolDefinitions
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
# Use python_list format for Llama 4
tool_gen = PythonListCustomToolGeneratorLlama4()
system_prompt = None
if existing_system_message:
system_prompt = _raw_content_as_str(existing_system_message.content)
tool_template = tool_gen.gen(tool_definitions, system_prompt)
sys_content = tool_template.render()
elif existing_system_message:
# No tools, just use existing system message
sys_content = _raw_content_as_str(existing_system_message.content)
# Add tool choice prompt if needed
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
sys_content += "\n" + tool_choice_prompt
if sys_content:
new_system_message = RawMessage(
role="system",
content=[RawTextItem(text=sys_content.strip())],
)
return [new_system_message] + messages
return messages
def augment_raw_messages_for_tools(
raw_messages: list[RawMessage],
params: OpenAIChatCompletionRequestWithExtraBody,
llama_model,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions based on model family."""
if not params.tools:
return raw_messages
# Determine augmentation strategy based on model family
if llama_model.model_family == ModelFamily.llama3_1 or (
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
):
# Llama 3.1 and Llama 3.2 multimodal use JSON format
return _augment_raw_messages_for_tools_llama_3_1(
raw_messages,
params.tools,
params.tool_choice,
)
elif llama_model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# Llama 3.2/3.3/4 use python_list format
return _augment_raw_messages_for_tools_llama_4(
raw_messages,
params.tools,
params.tool_choice,
)
else:
# Default to Llama 3.1 style
return _augment_raw_messages_for_tools_llama_3_1(
raw_messages,
params.tools,
params.tool_choice,
)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
@@ -136,10 +315,13 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model
log.info("Warming up...")
await self.openai_chat_completion(
model=model_id,
messages=[{"role": "user", "content": "Hi how are you?"}],
max_tokens=20,
params=OpenAIChatCompletionRequestWithExtraBody(
model=model_id,
messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
max_tokens=20,
)
)
log.info("Warmed up!")
@@ -155,4 +337,207 @@ class MetaReferenceInferenceImpl(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
self.check_model(params)
# Convert OpenAI messages to RawMessages
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_openai_message_to_raw_message,
decode_assistant_message,
)
raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
# Augment messages with tool definitions if tools are present
raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
# Call generator's chat_completion method (works for both single-GPU and model-parallel)
if isinstance(self.generator, LlamaGenerator):
generator = self.generator.chat_completion(params, raw_messages)
else:
# Model parallel: submit task to process group
generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
# Check if streaming is requested
if params.stream:
return self._stream_chat_completion(generator, params)
# Non-streaming: collect all generated text
generated_text = ""
for result_batch in generator:
for result in result_batch:
if not result.ignore_token and result.source == "output":
generated_text += result.text
# Decode assistant message to extract tool calls and determine stop_reason
# Default to end_of_turn if generation completed normally
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
# Convert tool calls to OpenAI format
openai_tool_calls = None
if decoded_message.tool_calls:
from llama_stack.apis.inference import (
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
)
openai_tool_calls = [
OpenAIChatCompletionToolCall(
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
id=f"call_{uuid.uuid4().hex[:24]}",
type="function",
function=OpenAIChatCompletionToolCallFunction(
name=str(tc.tool_name),
arguments=tc.arguments,
),
)
for tc in decoded_message.tool_calls
]
# Determine finish_reason based on whether tool calls are present
finish_reason = "tool_calls" if openai_tool_calls else "stop"
# Extract content from decoded message
content = ""
if isinstance(decoded_message.content, str):
content = decoded_message.content
elif isinstance(decoded_message.content, list):
for item in decoded_message.content:
if isinstance(item, RawTextItem):
content += item.text
# Create OpenAI response
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
return OpenAIChatCompletion(
id=response_id,
object="chat.completion",
created=created,
model=params.model,
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant",
content=content,
tool_calls=openai_tool_calls,
),
finish_reason=finish_reason,
logprobs=None,
)
],
usage=OpenAIChatCompletionUsage(
prompt_tokens=0, # TODO: calculate properly
completion_tokens=0, # TODO: calculate properly
total_tokens=0, # TODO: calculate properly
),
)
async def _stream_chat_completion(
self,
generator,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""Stream chat completion chunks as they're generated."""
from llama_stack.apis.inference import (
OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoiceDelta,
OpenAIChunkChoice,
)
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
generated_text = ""
# Yield chunks as tokens are generated
for result_batch in generator:
for result in result_batch:
if result.ignore_token or result.source != "output":
continue
generated_text += result.text
# Yield delta chunk with the new text
chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(
role="assistant",
content=result.text,
),
finish_reason="",
logprobs=None,
)
],
)
yield chunk
# After generation completes, decode the full message to extract tool calls
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
# If tool calls are present, yield a final chunk with tool_calls
if decoded_message.tool_calls:
openai_tool_calls = [
OpenAIChatCompletionToolCall(
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
id=f"call_{uuid.uuid4().hex[:24]}",
type="function",
function=OpenAIChatCompletionToolCallFunction(
name=str(tc.tool_name),
arguments=tc.arguments,
),
)
for tc in decoded_message.tool_calls
]
# Yield chunk with tool_calls
chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(
role="assistant",
tool_calls=openai_tool_calls,
),
finish_reason="",
logprobs=None,
)
],
)
yield chunk
finish_reason = "tool_calls"
else:
finish_reason = "stop"
# Yield final chunk with finish_reason
final_chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(),
finish_reason=finish_reason,
logprobs=None,
)
],
)
yield final_chunk


@@ -4,17 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable, Generator
from copy import deepcopy
from collections.abc import Callable
from functools import partial
from typing import Any
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .parallel_utils import ModelParallelProcessGroup
@@ -23,12 +18,14 @@ class ModelRunner:
def __init__(self, llama):
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: Any):
if task[0] == "chat_completion":
return self.llama.chat_completion(task[1])
task_type = task[0]
if task_type == "chat_completion":
# task[1] is [params, raw_messages]
params, raw_messages = task[1]
return self.llama.chat_completion(params, raw_messages)
else:
raise ValueError(f"Unexpected task type {task[0]}")
raise ValueError(f"Unexpected task type {task_type}")
def init_model_cb(
@@ -78,19 +75,3 @@ class LlamaModelParallelGenerator:
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def completion(
self,
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("completion", req_obj))
yield from gen
def chat_completion(
self,
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("chat_completion", req_obj))
yield from gen


@@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
log = get_logger(name=__name__, category="inference")
@@ -69,10 +65,7 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: tuple[
str,
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
]
task: tuple[str, list]
class TaskResponse(BaseModel):
@@ -328,10 +321,7 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: tuple[
str,
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
],
req: tuple[str, list],
) -> Generator:
assert not self.running, "inference already running"


@@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig
@@ -32,7 +29,6 @@ log = get_logger(name=__name__, category="inference")
class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,


@@ -11,9 +11,7 @@ from collections.abc import AsyncIterator
import litellm
from llama_stack.apis.inference import (
ChatCompletionRequest,
InferenceProvider,
JsonSchemaResponseFormat,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
@@ -23,15 +21,11 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
ToolChoice,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool,
get_sampling_options,
prepare_openai_completion_params,
)
@@ -127,51 +121,6 @@ class LiteLLMOpenAIMixin(
return schema
async def _get_params(self, request: ChatCompletionRequest) -> dict:
from typing import Any
input_dict: dict[str, Any] = {}
input_dict["messages"] = [
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
# Convert to dict for manipulation
fmt_dict = dict(fmt.json_schema)
name = fmt_dict["title"]
del fmt_dict["title"]
fmt_dict["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt_dict,
"strict": self.json_schema_strict,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
return {
"model": request.model,
"api_key": self.get_api_key(),
"api_base": self.api_base,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
def get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field

File diff suppressed because it is too large.


@@ -21,19 +21,18 @@ from llama_stack.apis.common.content_types import (
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIFile,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
ResponseFormat,
ResponseFormatType,
SystemMessage,
SystemMessageBehavior,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
@@ -42,33 +41,19 @@ from llama_stack.models.llama.datatypes import (
RawMediaItem,
RawMessage,
RawTextItem,
Role,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
PythonListCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="providers::utils")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
messages: list[RawMessage]
class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent
@@ -103,28 +88,6 @@ def interleaved_content_as_str(
return _process(content)
async def convert_request_to_raw(
request: ChatCompletionRequest | CompletionRequest,
) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
if isinstance(request, ChatCompletionRequest):
messages = []
for m in request.messages:
content = await interleaved_content_convert_to_raw(m.content)
d = m.model_dump()
d["content"] = content
messages.append(RawMessage(**d))
d = request.model_dump()
d["messages"] = messages
request = ChatCompletionRequestWithRawContent(**d)
else:
d = request.model_dump()
d["content"] = await interleaved_content_convert_to_raw(request.content)
request = CompletionRequestWithRawContent(**d)
return request
async def interleaved_content_convert_to_raw(
content: InterleavedContent,
) -> RawContent:
@@ -171,6 +134,36 @@ async def interleaved_content_convert_to_raw(
return await _localize_single(content)
async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
"""Convert OpenAI message format to RawMessage format used by Llama formatters."""
if isinstance(message, OpenAIUserMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="user", content=content)
elif isinstance(message, OpenAISystemMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="system", content=content)
elif isinstance(message, OpenAIAssistantMessageParam):
content = await interleaved_content_convert_to_raw(message.content or "") # type: ignore[arg-type]
tool_calls = []
if message.tool_calls:
for tc in message.tool_calls:
if tc.function:
tool_calls.append(
ToolCall(
call_id=tc.id or "",
tool_name=tc.function.name or "",
arguments=tc.function.arguments or "{}",
)
)
return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
elif isinstance(message, OpenAIToolMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="tool", content=content)
else:
# Handle OpenAIDeveloperMessageParam if needed
raise ValueError(f"Unsupported message type: {type(message)}")
def content_has_media(content: InterleavedContent):
def _has_media_content(c):
return isinstance(c, ImageContentItem)
@@ -181,17 +174,6 @@ def content_has_media(content: InterleavedContent):
return _has_media_content(content)
def messages_have_media(messages: list[Message]):
return any(content_has_media(m.content) for m in messages)
def request_has_media(request: ChatCompletionRequest | CompletionRequest):
if isinstance(request, ChatCompletionRequest):
return messages_have_media(request.messages)
else:
return content_has_media(request.content)
async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
if uri.startswith("http"):
async with httpx.AsyncClient() as client:
@@ -253,79 +235,6 @@ def augment_content_with_response_format_prompt(response_format, content):
return content
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
)
return formatter.tokenizer.decode(model_input.tokens)
async def chat_completion_request_to_model_input_info(
request: ChatCompletionRequest, llama_model: str
) -> tuple[str, int]:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
)
return (
formatter.tokenizer.decode(model_input.tokens),
len(model_input.tokens),
)
def chat_completion_request_to_messages(
request: ChatCompletionRequest,
llama_model: str,
) -> list[Message]:
"""Reads chat completion request and augments the messages to handle tools.
For eg. for llama_3_1, add system message with the appropriate tools or
add user message for custom tools, etc.
"""
assert llama_model is not None, "llama_model is required"
model = resolve_model(llama_model)
if model is None:
log.error(f"Could not resolve model {llama_model}")
return request.messages
allowed_models = supported_inference_models()
descriptors = [m.descriptor() for m in allowed_models]
if model.descriptor() not in descriptors:
log.error(f"Unsupported inference model? {model.descriptor()}")
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_1(request)
elif model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
):
# llama3.2, llama3.3 follow the same tool prompt format
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
elif model.model_family == ModelFamily.llama4:
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
else:
messages = request.messages
if fmt_prompt := response_format_prompt(request.response_format):
messages.append(UserMessage(content=fmt_prompt))
return messages
def response_format_prompt(fmt: ResponseFormat | None):
if not fmt:
return None
@@ -338,128 +247,6 @@ def response_format_prompt(fmt: ResponseFormat | None):
raise ValueError(f"Unknown response format {fmt.type}")
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
messages.append(SystemMessage(content=sys_content))
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif fmt == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages
def augment_messages_for_tools_llama(
request: ChatCompletionRequest,
custom_tool_prompt_generator,
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
sys_content = ""
custom_tools, builtin_tools = [], []
for t in request.tools:
if isinstance(t.tool_name, str):
custom_tools.append(t)
else:
builtin_tools.append(t)
if builtin_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools)
sys_content += tool_template.render()
sys_content += "\n"
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
system_prompt = None
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"
if existing_system_message and (
request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
):
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
return messages
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
if tool_choice == ToolChoice.auto:
return ""