sys_prompt support in Agent (#938)

# What does this PR do?

The current default system prompt for llama3.2 tends to overindex on
tool calling and doesn't work well when the prompt does not require tool
calling.

This PR adds an option to override the default system prompt, and
organizes tool-related configs into a new config object.

- [ ] Addresses issue (#issue)


## Test Plan


LLAMA_STACK_CONFIG=together pytest
\-\-inference\-model=meta\-llama/Llama\-3\.3\-70B\-Instruct -s -v
tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
ehhuang 2025-02-05 21:11:32 -08:00 committed by GitHub
parent e777d965a1
commit 3922999118
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 126 additions and 4 deletions

View file

@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
ToolResponse,
ToolResponseMessage,
UserMessage,
ToolConfig,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
@ -153,11 +154,24 @@ class AgentConfigCommon(BaseModel):
output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
tool_config: Optional[ToolConfig] = Field(default=None)
max_infer_iters: Optional[int] = 10
def model_post_init(self, __context):
if self.tool_config:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
if self.tool_config is None:
self.tool_config = ToolConfig(
tool_choice=self.tool_choice,
tool_prompt_format=self.tool_prompt_format,
)
@json_schema_type
class AgentConfig(AgentConfigCommon):
@ -268,6 +282,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
toolgroups: Optional[List[AgentToolGroup]] = None
stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
@json_schema_type
@ -315,6 +330,7 @@ class Agents(Protocol):
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")