sys_prompt support in Agent (#938)

# What does this PR do?

The current default system prompt for llama3.2 tends to over-index on
tool calling and performs poorly when the prompt does not require any
tools.

This PR adds an option to override the default system prompt, and
organizes tool-related configs into a new config object.
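
The backward-compatibility shim this PR adds to `create_agent` can be sketched standalone. The field names below come from the diff; the dataclass shapes are illustrative stand-ins, not the real llama-stack models:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToolConfig:
    # New config object grouping the tool-related settings (illustrative shape).
    tool_choice: Optional[str] = None
    tool_prompt_format: Optional[str] = None


@dataclass
class AgentConfig:
    # Legacy top-level fields kept for backward compatibility.
    tool_choice: Optional[str] = None
    tool_prompt_format: Optional[str] = None
    tool_config: Optional[ToolConfig] = None


def ensure_tool_config(agent_config: AgentConfig) -> AgentConfig:
    # If the caller did not supply a ToolConfig, build one from the
    # legacy per-field settings, mirroring what create_agent now does.
    if agent_config.tool_config is None:
        agent_config.tool_config = ToolConfig(
            tool_choice=agent_config.tool_choice,
            tool_prompt_format=agent_config.tool_prompt_format,
        )
    return agent_config
```

Existing callers that only set `tool_choice`/`tool_prompt_format` keep working, while new callers can pass a `ToolConfig` directly.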

- [ ] Addresses issue (#issue)


## Test Plan


```
LLAMA_STACK_CONFIG=together pytest --inference-model=meta-llama/Llama-3.3-70B-Instruct -s -v tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior
```


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
Commit 3922999118 (parent e777d965a1) by ehhuang, 2025-02-05 21:11:32 -08:00, committed by GitHub.
6 changed files with 126 additions and 4 deletions.


```diff
@@ -25,7 +25,12 @@ from llama_stack.apis.agents import (
     Session,
     Turn,
 )
-from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
+from llama_stack.apis.inference import (
+    Inference,
+    ToolConfig,
+    ToolResponseMessage,
+    UserMessage,
+)
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
@@ -76,6 +81,12 @@ class MetaReferenceAgentsImpl(Agents):
     ) -> AgentCreateResponse:
         agent_id = str(uuid.uuid4())
+        if agent_config.tool_config is None:
+            agent_config.tool_config = ToolConfig(
+                tool_choice=agent_config.tool_choice,
+                tool_prompt_format=agent_config.tool_prompt_format,
+            )
         await self.persistence_store.set(
             key=f"agent:{agent_id}",
             value=agent_config.model_dump_json(),
@@ -140,6 +151,7 @@ class MetaReferenceAgentsImpl(Agents):
         toolgroups: Optional[List[AgentToolGroup]] = None,
         documents: Optional[List[Document]] = None,
         stream: Optional[bool] = False,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
@@ -148,6 +160,7 @@ class MetaReferenceAgentsImpl(Agents):
             stream=True,
             toolgroups=toolgroups,
             documents=documents,
+            tool_config=tool_config,
         )
         if stream:
             return self._create_agent_turn_streaming(request)
```
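
The diff above also threads a per-turn `tool_config` into `create_agent_turn`. A minimal sketch of the precedence this implies; the resolution logic and type shape here are assumptions for illustration, not the real llama-stack code:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToolConfig:
    # Illustrative stand-in for the real ToolConfig model.
    tool_choice: Optional[str] = None
    tool_prompt_format: Optional[str] = None


def resolve_tool_config(
    turn_tool_config: Optional[ToolConfig],
    agent_tool_config: ToolConfig,
) -> ToolConfig:
    # Assumed precedence: a ToolConfig passed to create_agent_turn wins;
    # when the turn omits it, fall back to the config stored on the agent
    # at create_agent time.
    if turn_tool_config is not None:
        return turn_tool_config
    return agent_tool_config
```

Under this reading, a turn can e.g. disable tools with `ToolConfig(tool_choice="none")` without touching the agent's default.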