Support sys_prompt behavior in inference (#937)

# What does this PR do?

The current default system prompt for llama3.2 tends to overindex on
tool calling and doesn't work well when the prompt does not require tool
calling.

This PR adds an option to override the default system prompt, and
organizes tool-related configs into a new config object.

- [ ] Addresses issue (#issue)


## Test Plan

python -m unittest
llama_stack.providers.tests.inference.test_prompt_adapter


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/937).
* #938
* __->__ #937
This commit is contained in:
ehhuang 2025-02-03 23:35:16 -08:00 committed by GitHub
parent 62cd3c391e
commit c9ab72fa82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 308 additions and 48 deletions

View file

@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
SystemMessage,
ToolChoice,
UserMessage,
SystemMessageBehavior,
)
from llama_stack.providers.utils.inference import supported_inference_models
@ -309,7 +310,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
@ -354,7 +355,7 @@ def augment_messages_for_tools_llama_3_1(
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_prompt_format or ToolPromptFormat.json
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif fmt == ToolPromptFormat.function_tag:
@ -375,7 +376,7 @@ def augment_messages_for_tools_llama_3_1(
def augment_messages_for_tools_llama_3_2(
request: ChatCompletionRequest,
) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
@ -403,20 +404,25 @@ def augment_messages_for_tools_llama_3_2(
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
fmt = request.tool_prompt_format or ToolPromptFormat.python_list
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(f"Non supported ToolPromptFormat {request.tool_prompt_format}")
raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
tool_gen = PythonListCustomToolGenerator()
tool_template = tool_gen.gen(custom_tools)
system_prompt = None
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"
if existing_system_message:
if existing_system_message and (
request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
):
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
messages.append(SystemMessage(content=sys_content))
messages.append(SystemMessage(content=sys_content.strip("\n")))
# Add back existing messages from the request
messages += existing_messages