mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat: new system prompt for llama4 (#2031)
Tests: LLAMA_STACK_CONFIG=http://localhost:5002 pytest -s -v tests/integration/inference --safety-shield meta-llama/Llama-Guard-3-8B --vision-model meta-llama/Llama-4-Scout-17B-16E-Instruct --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct Co-authored-by: Eric Huang <erichuang@fb.com>
This commit is contained in:
parent
4bbd0c0693
commit
29072f40ab
2 changed files with 154 additions and 5 deletions
|
@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
|||
SystemDefaultGenerator,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||
)
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
@ -306,10 +309,11 @@ def chat_completion_request_to_messages(
|
|||
elif model.model_family in (
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
):
|
||||
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama_3_2(request)
|
||||
# llama3.2, llama3.3 follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
|
||||
elif model.model_family == ModelFamily.llama4:
|
||||
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
|
||||
else:
|
||||
messages = request.messages
|
||||
|
||||
|
@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1(
|
|||
return messages
|
||||
|
||||
|
||||
def augment_messages_for_tools_llama_3_2(
|
||||
def augment_messages_for_tools_llama(
|
||||
request: ChatCompletionRequest,
|
||||
custom_tool_prompt_generator,
|
||||
) -> List[Message]:
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
|
@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2(
|
|||
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
||||
system_prompt = existing_system_message.content
|
||||
|
||||
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
|
||||
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
|
||||
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue