feat: new system prompt for llama4 (#2031)

Tests:

LLAMA_STACK_CONFIG=http://localhost:5002 pytest -s -v
tests/integration/inference --safety-shield meta-llama/Llama-Guard-3-8B
--vision-model meta-llama/Llama-4-Scout-17B-16E-Instruct --text-model
meta-llama/Llama-4-Scout-17B-16E-Instruct

Co-authored-by: Eric Huang <erichuang@fb.com>
This commit is contained in:
ehhuang 2025-04-25 11:29:08 -07:00 committed by GitHub
parent 4bbd0c0693
commit 29072f40ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 154 additions and 5 deletions

View file

@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import (
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
@ -306,10 +309,11 @@ def chat_completion_request_to_messages(
elif model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_2(request)
# llama3.2, llama3.3 follow the same tool prompt format
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
elif model.model_family == ModelFamily.llama4:
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
else:
messages = request.messages
@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1(
return messages
def augment_messages_for_tools_llama_3_2(
def augment_messages_for_tools_llama(
request: ChatCompletionRequest,
custom_tool_prompt_generator,
) -> List[Message]:
existing_messages = request.messages
existing_system_message = None
@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2(
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"