mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
sys_prompt support in Agent (#938)
# What does this PR do? The current default system prompt for llama3.2 tends to over-index on tool calling and does not work well when the prompt does not require tool calling. This PR adds an option to override the default system prompt, and organizes tool-related configs into a new config object. - [ ] Addresses issue (#issue) ## Test Plan `LLAMA_STACK_CONFIG=together pytest --inference-model=meta-llama/Llama-3.3-70B-Instruct -s -v tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
e777d965a1
commit
3922999118
6 changed files with 126 additions and 4 deletions
|
@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
|
|||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
|
@ -153,11 +154,24 @@ class AgentConfigCommon(BaseModel):
|
|||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
|
||||
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_config: Optional[ToolConfig] = Field(default=None)
|
||||
|
||||
max_infer_iters: Optional[int] = 10
|
||||
|
||||
def model_post_init(self, __context):
    """Reconcile the deprecated tool settings with ``tool_config``.

    When no ``tool_config`` was supplied, one is built from the deprecated
    ``tool_choice`` / ``tool_prompt_format`` fields. When ``tool_config`` was
    supplied, the deprecated fields are only allowed to repeat what it says.

    Args:
        __context: Context object passed by pydantic; unused here.

    Raises:
        ValueError: If a deprecated field disagrees with the corresponding
            value in ``tool_config``.
    """
    if self.tool_config is None:
        # Legacy path: synthesize the new config object from the old fields.
        self.tool_config = ToolConfig(
            tool_choice=self.tool_choice,
            tool_prompt_format=self.tool_prompt_format,
        )
        return

    # ``tool_choice`` has a non-None default, so its value alone cannot tell us
    # whether the caller set it. Only an explicitly passed value may conflict;
    # otherwise a ``tool_config`` with a non-default choice would be rejected
    # even though the deprecated field was never used.
    tool_choice_given = "tool_choice" in self.model_fields_set
    if (
        tool_choice_given
        and self.tool_choice
        and self.tool_config.tool_choice != self.tool_choice
    ):
        raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")

    if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
        raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentConfig(AgentConfigCommon):
|
||||
|
@ -268,6 +282,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
toolgroups: Optional[List[AgentToolGroup]] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
tool_config: Optional[ToolConfig] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -315,6 +330,7 @@ class Agents(Protocol):
|
|||
stream: Optional[bool] = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue