feat: make multi-turn tool call tests work with llama4 (#1886)

Running full Tool Calling required some updates to work e2e.
- Remove `python_start` and `python_end` tags 
- Tool Call messages and Tool Resposne messages should end with
`<|eom|>`
- System prompt needed updates 
```
You are a helpful assisant who can can answer general questions or invoke tools when necessary.
In addition to tool calls, you should also augment your responses by using the tool outputs.
```

### Test Plan 
- Start server with meta-reference 
```
LLAMA_STACK_DISABLE_VERSION_CHECK=1 LLAMA_MODELS_DEBUG=1 INFERENCE_MODEL=meta-llama/$MODEL  llama stack run meta-reference-gpu 
``` 
- Added **NEW** tests with 5 test cases for multi-turn tool calls 
```
pytest -s -v --stack-config http://localhost:8321 tests/integration/inference/test_text_inference.py --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
``` 
- Also verified all vision and agent tests pass
This commit is contained in:
Hardik Shah 2025-04-06 19:14:21 -07:00 committed by GitHub
parent 5a31e66a91
commit 28e262ecdc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 468 additions and 18 deletions

View file

@ -279,6 +279,10 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{% endif -%}
{%- endfor %}
]
You can answer general questions or invoke tools when necessary.
In addition to tool calls, you should also augment your responses by using the tool outputs.
"""
)
return PromptTemplate(

View file

@ -216,9 +216,12 @@ class ChatFormat:
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
_process_content(content)
# Tool calls and Tool Response messages should be eom
eom = False
if message.role == "assistant":
eom = message.stop_reason == StopReason.end_of_message
eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
elif message.role == "tool":
eom = True
tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
return tokens, images

View file

@ -6,8 +6,11 @@
import asyncio
import logging
import os
from typing import AsyncGenerator, List, Optional, Union
from termcolor import cprint
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
@ -338,6 +341,9 @@ class MetaReferenceInferenceImpl(
stop_reason = None
for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")
tokens.append(token_result.token)
if token_result.token == tokenizer.eot_id:
@ -386,6 +392,9 @@ class MetaReferenceInferenceImpl(
ipython = False
for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):