fix: address review comments

This commit addresses comments  regarding the
 OpenAI chat completion implementation in the meta_reference provider.

Tool Augmentation
 - Add `augment_raw_messages_for_tools()` to properly inject tool definitions into prompts
 - Support model-family-specific tool formats:
   * Llama 3.1/3.2 multimodal: JsonCustomToolGenerator with JSON format
   * Llama 3.2/3.3/4: PythonListCustomToolGenerator with Python list format
 - Handle tool_choice hints (auto/required/specific tool)
 - Preserve existing system messages while adding tool context

Streaming & Tool Call Detection
 - Implement streaming support via `params.stream` with `_stream_chat_completion()`
 - Add tool call detection by decoding assistant messages after generation
 - Set proper `finish_reason` based on content ("stop" vs "tool_calls")
 - Convert internal ToolCall format to OpenAI-compatible types
 - Stream chunks incrementally with proper delta formatting

Type Corrections
 - Fix response_format handling in generators.py to properly extract schema from
     OpenAIJSONSchema TypedDict and use correct ResponseFormatType enum
- Use correct OpenAI types: OpenAIChatCompletionToolCall, OpenAIChunkChoice,
     OpenAIChoiceDelta, OpenAIChatCompletionToolCallFunction

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-11-10 15:17:08 -05:00
parent dac1ff1f57
commit 1b77826aba
2 changed files with 345 additions and 9 deletions

View file

@ -14,7 +14,9 @@ from llama_stack.apis.inference import (
GreedySamplingStrategy,
JsonSchemaResponseFormat,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIResponseFormatJSONSchema,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TopPSamplingStrategy,
)
@ -163,7 +165,8 @@ class LlamaGenerator:
sampling_params = SamplingParams()
if request.temperature is not None or request.top_p is not None:
sampling_params.strategy = TopPSamplingStrategy(
temperature=request.temperature or 1.0, top_p=request.top_p or 1.0
temperature=request.temperature if request.temperature is not None else 1.0,
top_p=request.top_p if request.top_p is not None else 1.0,
)
if request.max_tokens:
sampling_params.max_tokens = request.max_tokens
@ -177,9 +180,12 @@ class LlamaGenerator:
# Get logits processor for response format
logits_processor = None
if request.response_format:
if isinstance(request.response_format, dict) and request.response_format.get("type") == "json_schema":
if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
# Extract the actual schema from OpenAIJSONSchema TypedDict
schema_dict = request.response_format.json_schema.get("schema") or {}
json_schema_format = JsonSchemaResponseFormat(
type="json_schema", json_schema=request.response_format.get("json_schema", {})
type=ResponseFormatType.json_schema,
json_schema=schema_dict,
)
logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)

View file

@ -16,6 +16,8 @@ from llama_stack.apis.inference import (
OpenAIChatCompletionUsage,
OpenAIChoice,
OpenAICompletionRequestWithExtraBody,
OpenAIUserMessageParam,
ToolChoice,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
@ -24,12 +26,20 @@ from llama_stack.apis.inference.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
JsonCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@ -49,6 +59,170 @@ log = get_logger(__name__, category="inference")
SEMAPHORE = asyncio.Semaphore(1)
def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
"""Convert OpenAI tool format to ToolDefinition format."""
# OpenAI tools have function.name and function.parameters
return ToolDefinition(
tool_name=tool.function.name,
description=tool.function.description or "",
parameters=tool.function.parameters or {},
)
def _get_tool_choice_prompt(tool_choice, tools) -> str:
"""Generate prompt text for tool_choice behavior."""
if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
return ""
elif tool_choice == ToolChoice.required or tool_choice == "required":
return "You MUST use one of the provided functions/tools to answer the user query."
elif tool_choice == ToolChoice.none or tool_choice == "none":
return ""
else:
# Specific tool specified
return f"You MUST use the tool `{tool_choice}` to answer the user query."
def _raw_content_as_str(content) -> str:
"""Convert RawContent to string for system messages."""
if isinstance(content, str):
return content
elif isinstance(content, RawTextItem):
return content.text
elif isinstance(content, list):
return "\n".join(_raw_content_as_str(c) for c in content)
else:
return "<media>"
def _augment_raw_messages_for_tools_llama_3_1(
raw_messages: list[RawMessage],
tools: list,
tool_choice,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions for Llama 3.1 style models."""
messages = raw_messages.copy()
existing_system_message = None
if messages and messages[0].role == "system":
existing_system_message = messages.pop(0)
sys_content = ""
# Add tool definitions first (if present)
if tools:
# Convert OpenAI tools to ToolDefinitions
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
# For OpenAI format, all tools are custom (have string names)
tool_gen = JsonCustomToolGenerator()
tool_template = tool_gen.gen(tool_definitions)
sys_content += tool_template.render()
sys_content += "\n"
# Add default system prompt
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content += default_template.render()
# Add existing system message if present
if existing_system_message:
sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
# Add tool choice prompt if needed
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
sys_content += "\n" + tool_choice_prompt
# Create new system message
new_system_message = RawMessage(
role="system",
content=[RawTextItem(text=sys_content.strip())],
)
return [new_system_message] + messages
def _augment_raw_messages_for_tools_llama_4(
raw_messages: list[RawMessage],
tools: list,
tool_choice,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
messages = raw_messages.copy()
existing_system_message = None
if messages and messages[0].role == "system":
existing_system_message = messages.pop(0)
sys_content = ""
# Add tool definitions if present
if tools:
# Convert OpenAI tools to ToolDefinitions
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
# Use python_list format for Llama 4
tool_gen = PythonListCustomToolGeneratorLlama4()
system_prompt = None
if existing_system_message:
system_prompt = _raw_content_as_str(existing_system_message.content)
tool_template = tool_gen.gen(tool_definitions, system_prompt)
sys_content = tool_template.render()
elif existing_system_message:
# No tools, just use existing system message
sys_content = _raw_content_as_str(existing_system_message.content)
# Add tool choice prompt if needed
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
sys_content += "\n" + tool_choice_prompt
if sys_content:
new_system_message = RawMessage(
role="system",
content=[RawTextItem(text=sys_content.strip())],
)
return [new_system_message] + messages
return messages
def augment_raw_messages_for_tools(
raw_messages: list[RawMessage],
params: OpenAIChatCompletionRequestWithExtraBody,
llama_model,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions based on model family."""
if not params.tools:
return raw_messages
# Determine augmentation strategy based on model family
if llama_model.model_family == ModelFamily.llama3_1 or (
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
):
# Llama 3.1 and Llama 3.2 multimodal use JSON format
return _augment_raw_messages_for_tools_llama_3_1(
raw_messages,
params.tools,
params.tool_choice,
)
elif llama_model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# Llama 3.2/3.3/4 use python_list format
return _augment_raw_messages_for_tools_llama_4(
raw_messages,
params.tools,
params.tool_choice,
)
else:
# Default to Llama 3.1 style
return _augment_raw_messages_for_tools_llama_3_1(
raw_messages,
params.tools,
params.tool_choice,
)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
@ -141,7 +315,6 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model
log.info("Warming up...")
from llama_stack.apis.inference import OpenAIUserMessageParam
await self.openai_chat_completion(
params=OpenAIChatCompletionRequestWithExtraBody(
@ -167,10 +340,17 @@ class MetaReferenceInferenceImpl(
self.check_model(params)
# Convert OpenAI messages to RawMessages
from llama_stack.providers.utils.inference.prompt_adapter import convert_openai_message_to_raw_message
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_openai_message_to_raw_message,
decode_assistant_message,
)
raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
# Augment messages with tool definitions if tools are present
raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
# Call generator's chat_completion method (works for both single-GPU and model-parallel)
if isinstance(self.generator, LlamaGenerator):
generator = self.generator.chat_completion(params, raw_messages)
@ -178,14 +358,56 @@ class MetaReferenceInferenceImpl(
# Model parallel: submit task to process group
generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
# Collect all generated text
# Check if streaming is requested
if params.stream:
return self._stream_chat_completion(generator, params)
# Non-streaming: collect all generated text
generated_text = ""
for result_batch in generator:
for result in result_batch:
if not result.ignore_token and result.source == "output":
generated_text += result.text
# Decode assistant message to extract tool calls and determine stop_reason
# Default to end_of_turn if generation completed normally
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
# Convert tool calls to OpenAI format
openai_tool_calls = None
if decoded_message.tool_calls:
from llama_stack.apis.inference import (
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
)
openai_tool_calls = [
OpenAIChatCompletionToolCall(
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
id=f"call_{uuid.uuid4().hex[:24]}",
type="function",
function=OpenAIChatCompletionToolCallFunction(
name=str(tc.tool_name),
arguments=tc.arguments,
),
)
for tc in decoded_message.tool_calls
]
# Determine finish_reason based on whether tool calls are present
finish_reason = "tool_calls" if openai_tool_calls else "stop"
# Extract content from decoded message
content = ""
if isinstance(decoded_message.content, str):
content = decoded_message.content
elif isinstance(decoded_message.content, list):
for item in decoded_message.content:
if isinstance(item, RawTextItem):
content += item.text
# Create OpenAI response
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
@ -199,9 +421,10 @@ class MetaReferenceInferenceImpl(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant",
content=generated_text,
content=content,
tool_calls=openai_tool_calls,
),
finish_reason="stop",
finish_reason=finish_reason,
logprobs=None,
)
],
@ -211,3 +434,110 @@ class MetaReferenceInferenceImpl(
total_tokens=0, # TODO: calculate properly
),
)
async def _stream_chat_completion(
self,
generator,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""Stream chat completion chunks as they're generated."""
from llama_stack.apis.inference import (
OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoiceDelta,
OpenAIChunkChoice,
)
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
generated_text = ""
# Yield chunks as tokens are generated
for result_batch in generator:
for result in result_batch:
if result.ignore_token or result.source != "output":
continue
generated_text += result.text
# Yield delta chunk with the new text
chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(
role="assistant",
content=result.text,
),
finish_reason="",
logprobs=None,
)
],
)
yield chunk
# After generation completes, decode the full message to extract tool calls
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
# If tool calls are present, yield a final chunk with tool_calls
if decoded_message.tool_calls:
openai_tool_calls = [
OpenAIChatCompletionToolCall(
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
id=f"call_{uuid.uuid4().hex[:24]}",
type="function",
function=OpenAIChatCompletionToolCallFunction(
name=str(tc.tool_name),
arguments=tc.arguments,
),
)
for tc in decoded_message.tool_calls
]
# Yield chunk with tool_calls
chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(
role="assistant",
tool_calls=openai_tool_calls,
),
finish_reason="",
logprobs=None,
)
],
)
yield chunk
finish_reason = "tool_calls"
else:
finish_reason = "stop"
# Yield final chunk with finish_reason
final_chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(),
finish_reason=finish_reason,
logprobs=None,
)
],
)
yield final_chunk