forked from phoenix-oss/llama-stack-mirror
sys_prompt support in Agent (#938)
# What does this PR do? The current default system prompt for llama3.2 tends to overindex on tool calling and doesn't work well when the prompt does not require tool calling. This PR adds an option to override the default system prompt, and organizes tool-related configs into a new config object. - [ ] Addresses issue (#issue) ## Test Plan LLAMA_STACK_CONFIG=together pytest \-\-inference\-model=meta\-llama/Llama\-3\.3\-70B\-Instruct -s -v tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
e777d965a1
commit
3922999118
6 changed files with 126 additions and 4 deletions
|
@ -263,6 +263,88 @@ def test_custom_tool(llama_stack_client, agent_config):
|
|||
assert "CustomTool" in logs_str
|
||||
|
||||
|
||||
def test_override_system_message_behavior(llama_stack_client, agent_config):
|
||||
client_tool = TestClientTool()
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"instructions": "You are a pirate",
|
||||
"client_tools": [client_tool.get_tool_definition()],
|
||||
}
|
||||
|
||||
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "tell me a joke about bicycles",
|
||||
},
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
print(logs_str)
|
||||
# can't tell a joke: "I don't have a function"
|
||||
assert "function" in logs_str
|
||||
|
||||
# with system message behavior replace
|
||||
instructions = """
|
||||
You are a helpful assistant. You have access to functions, but you should only use them if they are required.
|
||||
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you may or may not need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, don't return [], instead answer the question directly without using functions. If the given question lacks the parameters required by the function,
|
||||
also point it out.
|
||||
|
||||
{{ function_description }}
|
||||
"""
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"instructions": instructions,
|
||||
"client_tools": [client_tool.get_tool_definition()],
|
||||
"tool_config": {
|
||||
"system_message_behavior": "replace",
|
||||
},
|
||||
}
|
||||
|
||||
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "tell me a joke about bicycles",
|
||||
},
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
print(logs_str)
|
||||
assert "bicycle" in logs_str
|
||||
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the boiling point of polyjuice?",
|
||||
},
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
print(logs_str)
|
||||
assert "-100" in logs_str
|
||||
assert "CustomTool" in logs_str
|
||||
|
||||
|
||||
def test_rag_agent(llama_stack_client, agent_config):
|
||||
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
|
||||
documents = [
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue