mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
add a test
This commit is contained in:
parent
4f94f5a708
commit
7bb5f2bf0e
1 changed files with 27 additions and 0 deletions
|
@ -278,6 +278,33 @@ def test_custom_tool(llama_stack_client, agent_config):
|
||||||
assert "get_boiling_point" in logs_str
|
assert "get_boiling_point" in logs_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
|
||||||
|
client_tool = get_boiling_point
|
||||||
|
agent_config = {
|
||||||
|
**agent_config,
|
||||||
|
"instructions": "You are a helpful assistant Always respond with tool calls no matter what. ",
|
||||||
|
"client_tools": [client_tool.get_tool_definition()],
|
||||||
|
"max_infer_iters": 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
|
||||||
|
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||||
|
|
||||||
|
response = agent.create_turn(
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Get the boiling point of polyjuice with a tool call.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
session_id=session_id,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_tool_calls = sum([1 if step.step_type == "tool_execution" else 0 for step in response.steps])
|
||||||
|
assert num_tool_calls <= 5
|
||||||
|
|
||||||
|
|
||||||
def test_tool_choice(llama_stack_client, agent_config):
|
def test_tool_choice(llama_stack_client, agent_config):
|
||||||
def run_agent(tool_choice):
|
def run_agent(tool_choice):
|
||||||
client_tool = get_boiling_point
|
client_tool = get_boiling_point
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue