mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
sys_prompt support in Agent (#938)
# What does this PR do? The current default system prompt for llama3.2 tends to overindex on tool calling and doesn't work well when the prompt does not require tool calling. This PR adds an option to override the default system prompt, and organizes tool-related configs into a new config object. - [ ] Addresses issue (#issue) ## Test Plan LLAMA_STACK_CONFIG=together pytest \-\-inference\-model=meta\-llama/Llama\-3\.3\-70B\-Instruct -s -v tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
e777d965a1
commit
3922999118
6 changed files with 126 additions and 4 deletions
|
@ -25,7 +25,12 @@ from llama_stack.apis.agents import (
|
|||
Session,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
|
@ -76,6 +81,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
) -> AgentCreateResponse:
|
||||
agent_id = str(uuid.uuid4())
|
||||
|
||||
if agent_config.tool_config is None:
|
||||
agent_config.tool_config = ToolConfig(
|
||||
tool_choice=agent_config.tool_choice,
|
||||
tool_prompt_format=agent_config.tool_prompt_format,
|
||||
)
|
||||
|
||||
await self.persistence_store.set(
|
||||
key=f"agent:{agent_id}",
|
||||
value=agent_config.model_dump_json(),
|
||||
|
@ -140,6 +151,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
|
@ -148,6 +160,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
stream=True,
|
||||
toolgroups=toolgroups,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue