From 1b77826aba20725bbf5c75e879f7548c3e6c6973 Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Mon, 10 Nov 2025 15:17:08 -0500
Subject: [PATCH] fix: address review comments

This commit addresses review comments on the OpenAI chat completion
implementation in the meta_reference provider.

Tool Augmentation
- Add `augment_raw_messages_for_tools()` to properly inject tool definitions
  into prompts
- Support model-family-specific tool formats:
  * Llama 3.1 and Llama 3.2 multimodal: JsonCustomToolGenerator with JSON format
  * Llama 3.2/3.3/4: PythonListCustomToolGenerator with Python list format
- Handle tool_choice hints (auto/required/specific tool)
- Preserve existing system messages while adding tool context

Streaming & Tool Call Detection
- Implement streaming support via `params.stream` with `_stream_chat_completion()`
- Add tool call detection by decoding assistant messages after generation
- Set `finish_reason` based on the decoded content ("stop" vs "tool_calls")
- Convert internal ToolCall format to OpenAI-compatible types
- Stream chunks incrementally with proper delta formatting

Type Corrections
- Fix response_format handling in generators.py to extract the schema from the
  OpenAIJSONSchema TypedDict and use the correct ResponseFormatType enum
- Use correct OpenAI types: OpenAIChatCompletionToolCall, OpenAIChunkChoice,
  OpenAIChoiceDelta, OpenAIChatCompletionToolCallFunction

Signed-off-by: Charlie Doern
---
 .../inference/meta_reference/generators.py |  12 +-
 .../inference/meta_reference/inference.py  | 342 +++++++++++++++++-
 2 files changed, 345 insertions(+), 9 deletions(-)

diff --git a/src/llama_stack/providers/inline/inference/meta_reference/generators.py b/src/llama_stack/providers/inline/inference/meta_reference/generators.py
index 02494120b..51a2ddfad 100644
--- a/src/llama_stack/providers/inline/inference/meta_reference/generators.py
+++ b/src/llama_stack/providers/inline/inference/meta_reference/generators.py
@@ -14,7 +14,9 @@ from llama_stack.apis.inference import (
     GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIResponseFormatJSONSchema,
     ResponseFormat,
+    ResponseFormatType,
     SamplingParams,
     TopPSamplingStrategy,
 )
@@ -163,7 +165,8 @@ class LlamaGenerator:
         sampling_params = SamplingParams()
         if request.temperature is not None or request.top_p is not None:
             sampling_params.strategy = TopPSamplingStrategy(
-                temperature=request.temperature or 1.0, top_p=request.top_p or 1.0
+                temperature=request.temperature if request.temperature is not None else 1.0,
+                top_p=request.top_p if request.top_p is not None else 1.0,
             )
         if request.max_tokens:
             sampling_params.max_tokens = request.max_tokens
@@ -177,9 +180,12 @@ class LlamaGenerator:
         # Get logits processor for response format
         logits_processor = None
         if request.response_format:
-            if isinstance(request.response_format, dict) and request.response_format.get("type") == "json_schema":
+            if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
+                # Extract the actual schema from OpenAIJSONSchema TypedDict
+                schema_dict = request.response_format.json_schema.get("schema") or {}
                 json_schema_format = JsonSchemaResponseFormat(
-                    type="json_schema", json_schema=request.response_format.get("json_schema", {})
+                    type=ResponseFormatType.json_schema,
+                    json_schema=schema_dict,
                 )
                 logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
 
diff --git a/src/llama_stack/providers/inline/inference/meta_reference/inference.py b/src/llama_stack/providers/inline/inference/meta_reference/inference.py
index 5acf0e26d..ef21132a0 100644
--- a/src/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/src/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -16,6 +16,8 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletionUsage,
     OpenAIChoice,
     OpenAICompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+    ToolChoice,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
@@ -24,12 +26,20 @@
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
+from llama_stack.models.llama.llama3.prompt_templates import (
+    JsonCustomToolGenerator,
+    SystemDefaultGenerator,
+)
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
+from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
+    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
+)
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import ModelFamily
+from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -49,6 +59,170 @@ log = get_logger(__name__, category="inference")
 SEMAPHORE = asyncio.Semaphore(1)
 
 
+def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
+    """Convert OpenAI tool format to ToolDefinition format."""
+    # OpenAI tools have function.name and function.parameters
+    return ToolDefinition(
+        tool_name=tool.function.name,
+        description=tool.function.description or "",
+        parameters=tool.function.parameters or {},
+    )
+
+
+def _get_tool_choice_prompt(tool_choice, tools) -> str:
+    """Generate prompt text for tool_choice behavior."""
+    if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
+        return ""
+    elif tool_choice == ToolChoice.required or tool_choice == "required":
+        return "You MUST use one of the provided functions/tools to answer the user query."
+    elif tool_choice == ToolChoice.none or tool_choice == "none":
+        return ""
+    else:
+        # Specific tool specified
+        return f"You MUST use the tool `{tool_choice}` to answer the user query."
+
+
+def _raw_content_as_str(content) -> str:
+    """Convert RawContent to string for system messages."""
+    if isinstance(content, str):
+        return content
+    elif isinstance(content, RawTextItem):
+        return content.text
+    elif isinstance(content, list):
+        return "\n".join(_raw_content_as_str(c) for c in content)
+    else:
+        return ""
+
+
+def _augment_raw_messages_for_tools_llama_3_1(
+    raw_messages: list[RawMessage],
+    tools: list,
+    tool_choice,
+) -> list[RawMessage]:
+    """Augment raw messages with tool definitions for Llama 3.1 style models."""
+    messages = raw_messages.copy()
+    existing_system_message = None
+    if messages and messages[0].role == "system":
+        existing_system_message = messages.pop(0)
+
+    sys_content = ""
+
+    # Add tool definitions first (if present)
+    if tools:
+        # Convert OpenAI tools to ToolDefinitions
+        tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
+
+        # For OpenAI format, all tools are custom (have string names)
+        tool_gen = JsonCustomToolGenerator()
+        tool_template = tool_gen.gen(tool_definitions)
+        sys_content += tool_template.render()
+        sys_content += "\n"
+
+    # Add default system prompt
+    default_gen = SystemDefaultGenerator()
+    default_template = default_gen.gen()
+    sys_content += default_template.render()
+
+    # Add existing system message if present
+    if existing_system_message:
+        sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
+
+    # Add tool choice prompt if needed
+    if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
+        sys_content += "\n" + tool_choice_prompt
+
+    # Create new system message
+    new_system_message = RawMessage(
+        role="system",
+        content=[RawTextItem(text=sys_content.strip())],
+    )
+
+    return [new_system_message] + messages
+
+
+def _augment_raw_messages_for_tools_llama_4(
+    raw_messages: list[RawMessage],
+    tools: list,
+    tool_choice,
+) -> list[RawMessage]:
+    """Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
+    messages = raw_messages.copy()
+    existing_system_message = None
+    if messages and messages[0].role == "system":
+        existing_system_message = messages.pop(0)
+
+    sys_content = ""
+
+    # Add tool definitions if present
+    if tools:
+        # Convert OpenAI tools to ToolDefinitions
+        tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
+
+        # Use python_list format for Llama 4
+        tool_gen = PythonListCustomToolGeneratorLlama4()
+        system_prompt = None
+        if existing_system_message:
+            system_prompt = _raw_content_as_str(existing_system_message.content)
+
+        tool_template = tool_gen.gen(tool_definitions, system_prompt)
+        sys_content = tool_template.render()
+    elif existing_system_message:
+        # No tools, just use existing system message
+        sys_content = _raw_content_as_str(existing_system_message.content)
+
+    # Add tool choice prompt if needed
+    if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
+        sys_content += "\n" + tool_choice_prompt
+
+    if sys_content:
+        new_system_message = RawMessage(
+            role="system",
+            content=[RawTextItem(text=sys_content.strip())],
+        )
+        return [new_system_message] + messages
+
+    return messages
+
+
+def augment_raw_messages_for_tools(
+    raw_messages: list[RawMessage],
+    params: OpenAIChatCompletionRequestWithExtraBody,
+    llama_model,
+) -> list[RawMessage]:
+    """Augment raw messages with tool definitions based on model family."""
+    if not params.tools:
+        return raw_messages
+
+    # Determine augmentation strategy based on model family
+    if llama_model.model_family == ModelFamily.llama3_1 or (
+        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+    ):
+        # Llama 3.1 and Llama 3.2 multimodal use JSON format
+        return _augment_raw_messages_for_tools_llama_3_1(
+            raw_messages,
+            params.tools,
+            params.tool_choice,
+        )
+    elif llama_model.model_family in (
+        ModelFamily.llama3_2,
+        ModelFamily.llama3_3,
+        ModelFamily.llama4,
+    ):
+        # Llama 3.2/3.3/4 use python_list format
+        return _augment_raw_messages_for_tools_llama_4(
+            raw_messages,
+            params.tools,
+            params.tool_choice,
+        )
+    else:
+        # Default to Llama 3.1 style
+        return _augment_raw_messages_for_tools_llama_3_1(
+            raw_messages,
+            params.tools,
+            params.tool_choice,
+        )
+
+
 def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
     return LlamaGenerator(config, model_id, llama_model)
 
@@ -141,7 +315,6 @@
         self.llama_model = llama_model
 
         log.info("Warming up...")
-        from llama_stack.apis.inference import OpenAIUserMessageParam
 
         await self.openai_chat_completion(
             params=OpenAIChatCompletionRequestWithExtraBody(
@@ -167,10 +340,17 @@
         self.check_model(params)
 
         # Convert OpenAI messages to RawMessages
-        from llama_stack.providers.utils.inference.prompt_adapter import convert_openai_message_to_raw_message
+        from llama_stack.models.llama.datatypes import StopReason
+        from llama_stack.providers.utils.inference.prompt_adapter import (
+            convert_openai_message_to_raw_message,
+            decode_assistant_message,
+        )
 
         raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
 
+        # Augment messages with tool definitions if tools are present
+        raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
+
         # Call generator's chat_completion method (works for both single-GPU and model-parallel)
         if isinstance(self.generator, LlamaGenerator):
             generator = self.generator.chat_completion(params, raw_messages)
@@ -178,14 +358,56 @@
             # Model parallel: submit task to process group
             generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
 
-        # Collect all generated text
+        # Check if streaming is requested
+        if params.stream:
+            return self._stream_chat_completion(generator, params)
+
+        # Non-streaming: collect all generated text
         generated_text = ""
         for result_batch in generator:
             for result in result_batch:
                 if not result.ignore_token and result.source == "output":
                     generated_text += result.text
 
+        # Decode assistant message to extract tool calls and determine stop_reason
+        # Default to end_of_turn if generation completed normally
+        decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
+
+        # Convert tool calls to OpenAI format
+        openai_tool_calls = None
+        if decoded_message.tool_calls:
+            from llama_stack.apis.inference import (
+                OpenAIChatCompletionToolCall,
+                OpenAIChatCompletionToolCallFunction,
+            )
+
+            openai_tool_calls = [
+                OpenAIChatCompletionToolCall(
+                    # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
+ id=f"call_{uuid.uuid4().hex[:24]}", + type="function", + function=OpenAIChatCompletionToolCallFunction( + name=str(tc.tool_name), + arguments=tc.arguments, + ), + ) + for tc in decoded_message.tool_calls + ] + + # Determine finish_reason based on whether tool calls are present + finish_reason = "tool_calls" if openai_tool_calls else "stop" + + # Extract content from decoded message + content = "" + if isinstance(decoded_message.content, str): + content = decoded_message.content + elif isinstance(decoded_message.content, list): + for item in decoded_message.content: + if isinstance(item, RawTextItem): + content += item.text + # Create OpenAI response + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" created = int(time.time()) @@ -199,9 +421,10 @@ class MetaReferenceInferenceImpl( index=0, message=OpenAIAssistantMessageParam( role="assistant", - content=generated_text, + content=content, + tool_calls=openai_tool_calls, ), - finish_reason="stop", + finish_reason=finish_reason, logprobs=None, ) ], @@ -211,3 +434,110 @@ class MetaReferenceInferenceImpl( total_tokens=0, # TODO: calculate properly ), ) + + async def _stream_chat_completion( + self, + generator, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> AsyncIterator[OpenAIChatCompletionChunk]: + """Stream chat completion chunks as they're generated.""" + from llama_stack.apis.inference import ( + OpenAIChatCompletionChunk, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoiceDelta, + OpenAIChunkChoice, + ) + from llama_stack.models.llama.datatypes import StopReason + from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message + + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + generated_text = "" + + # Yield chunks as tokens are generated + for result_batch in generator: + for result in result_batch: + if result.ignore_token or result.source != "output": + continue + + generated_text += result.text + + # Yield delta chunk with the new text + chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta( + role="assistant", + content=result.text, + ), + finish_reason="", + logprobs=None, + ) + ], + ) + yield chunk + + # After generation completes, decode the full message to extract tool calls + decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn) + + # If tool calls are present, yield a final chunk with tool_calls + if decoded_message.tool_calls: + openai_tool_calls = [ + OpenAIChatCompletionToolCall( + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. 
+ id=f"call_{uuid.uuid4().hex[:24]}", + type="function", + function=OpenAIChatCompletionToolCallFunction( + name=str(tc.tool_name), + arguments=tc.arguments, + ), + ) + for tc in decoded_message.tool_calls + ] + + # Yield chunk with tool_calls + chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta( + role="assistant", + tool_calls=openai_tool_calls, + ), + finish_reason="", + logprobs=None, + ) + ], + ) + yield chunk + + finish_reason = "tool_calls" + else: + finish_reason = "stop" + + # Yield final chunk with finish_reason + final_chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta(), + finish_reason=finish_reason, + logprobs=None, + ) + ], + ) + yield final_chunk