mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 04:39:40 +00:00
fix: Resolve Llama4 tool calling 500 errors (Issue #2584)
This commit fixes the tool calling failures with Llama4 models that were
returning 500 errors while Together API worked correctly. The root cause
was that the system was using Llama3's JSON format for all models instead
of Llama4's python_list format.
Key changes:
- NEW: llama_stack/models/llama/llama4/interface.py - Complete Llama4 interface
with python_list tool format support
- MODIFIED: prompt_adapter.py - Added model-aware decode_assistant_message()
that uses Llama4ChatFormat for llama4 models and Llama3ChatFormat for others
- MODIFIED: openai_compat.py - Updated to pass model_id parameter to enable
model-specific format detection
- MODIFIED: sku_list.py - Enhanced with provider alias support for better
model resolution
- NEW: tests/unit/models/test_decode_assistant_message.py - Comprehensive unit
tests for the new decode_assistant_message function
The fix ensures that:
- Llama4 models (meta-llama/Llama-4-*) use python_list format: [func(args)]
- Other models continue using JSON format: {"type": "function", ...}
- Backward compatibility is maintained for existing models
- Tool calling works correctly across different model families
- Graceful fallback when Llama4 dependencies are unavailable
Testing:
- All 17 unit tests pass (9 original + 8 new)
- Conditional imports prevent torch dependency issues
- Comprehensive test coverage for different model types and scenarios
Fixes #2584
This commit is contained in:
parent
f731f369a2
commit
857496ea3e
5 changed files with 482 additions and 7 deletions
|
|
@ -299,7 +299,9 @@ def process_chat_completion_response(
|
|||
|
||||
# TODO: This does not work well with tool calls for vLLM remote provider
|
||||
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
|
||||
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
|
||||
raw_message = decode_assistant_message(
|
||||
text_from_choice(choice), get_stop_reason(choice.finish_reason), request.model
|
||||
)
|
||||
|
||||
# NOTE: If we do not set tools in chat-completion request, we should not
|
||||
# expect the ToolCall in the response. Instead, we should return the raw
|
||||
|
|
@ -448,7 +450,7 @@ async def process_chat_completion_stream_response(
|
|||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = decode_assistant_message(buffer, stop_reason)
|
||||
message = decode_assistant_message(buffer, stop_reason, request.model)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
|
|
|
|||
|
|
@ -51,9 +51,19 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
|||
SystemDefaultGenerator,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||
)
|
||||
|
||||
# Conditional imports to avoid heavy dependencies during module loading
|
||||
try:
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||
)
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
|
||||
LLAMA4_AVAILABLE = True
|
||||
except ImportError:
|
||||
# Llama4 dependencies not available (e.g., torch not installed)
|
||||
LLAMA4_AVAILABLE = False
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
|
@ -69,8 +79,20 @@ class CompletionRequestWithRawContent(CompletionRequest):
|
|||
content: RawContent
|
||||
|
||||
|
||||
def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
|
||||
formatter = ChatFormat(Tokenizer.get_instance())
|
||||
def decode_assistant_message(content: str, stop_reason: StopReason, model_id: str | None = None) -> RawMessage:
|
||||
"""Decode assistant message using the appropriate formatter for the model family."""
|
||||
if model_id and LLAMA4_AVAILABLE:
|
||||
model = resolve_model(model_id)
|
||||
if model and model.model_family == ModelFamily.llama4:
|
||||
# Use Llama4's ChatFormat for Llama4 models
|
||||
formatter = Llama4ChatFormat(Llama4Tokenizer.get_instance())
|
||||
else:
|
||||
# Use Llama3's ChatFormat for all other models (default)
|
||||
formatter = ChatFormat(Tokenizer.get_instance())
|
||||
else:
|
||||
# Default to Llama3 if no model specified or Llama4 not available
|
||||
formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
return formatter.decode_assistant_message_from_content(content, stop_reason)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue