precommit

This commit is contained in:
Ashwin Bharambe 2025-10-15 09:26:04 -07:00
parent 57b3d14895
commit eba2eb2b0e
3 changed files with 24 additions and 16 deletions

View file

@ -46,9 +46,7 @@ def collect_turn(
*,
extra_headers: dict[str, Any] | None = None,
):
chunks = list(
agent.create_turn(messages=messages, session_id=session_id, stream=True, extra_headers=extra_headers)
)
chunks = list(agent.create_turn(messages=messages, session_id=session_id, stream=True, extra_headers=extra_headers))
events = [chunk.event for chunk in chunks]
final_response = next((chunk.response for chunk in reversed(chunks) if chunk.response), None)
if final_response is None:
@ -250,9 +248,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent,
session_id,
messages=[
text_message(
"Write code and execute it to find the answer for: What is the 100th prime number?"
),
text_message("Write code and execute it to find the answer for: What is the 100th prime number?"),
],
)
logs = [str(log) for log in AgentEventLogger().log(chunks) if log is not None]
@ -305,7 +301,9 @@ def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
messages=[text_message("Get the boiling point of polyjuice with a tool call.")],
)
num_tool_calls = sum(1 for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution")
num_tool_calls = sum(
1 for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
)
assert num_tool_calls <= 5
@ -325,7 +323,9 @@ def test_tool_choice_get_boiling_point(llama_stack_client, agent_config):
pytest.xfail("NotImplemented for non-llama models")
tool_execution_steps = run_agent_with_tool_choice(llama_stack_client, agent_config, "get_boiling_point")
assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].result.tool_calls[0].tool_name == "get_boiling_point"
assert (
len(tool_execution_steps) >= 1 and tool_execution_steps[0].result.tool_calls[0].tool_name == "get_boiling_point"
)
def run_agent_with_tool_choice(client, agent_config, tool_choice):
@ -373,14 +373,18 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools):
messages=[text_message(input_prompt)],
)
tool_events = [event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"]
tool_events = [
event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
]
assert len(tool_events) >= 1
tool_exec = tool_events[0]
assert tool_exec.result.tool_calls[0].tool_name.startswith("get_boiling_point")
if expects_metadata:
assert tool_exec.result.tool_responses[0]["metadata"]["source"] == "https://www.google.com"
inference_events = [event for event in events if isinstance(event, StepCompleted) and event.step_type == "inference"]
inference_events = [
event for event in events if isinstance(event, StepCompleted) and event.step_type == "inference"
]
assert len(inference_events) >= 2
assert "polyjuice" in final_response.output_text.lower()
@ -407,7 +411,9 @@ def test_multi_tool_calls(llama_stack_client, agent_config):
],
)
tool_exec_events = [event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"]
tool_exec_events = [
event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
]
assert len(tool_exec_events) >= 1
tool_exec = tool_exec_events[0]
assert len(tool_exec.result.tool_calls) == 2

View file

@ -117,13 +117,13 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
assert final_response is not None
issued_calls = [
event
for event in events
if isinstance(event, StepProgress) and isinstance(event.delta, ToolCallIssuedDelta)
event for event in events if isinstance(event, StepProgress) and isinstance(event.delta, ToolCallIssuedDelta)
]
assert issued_calls and issued_calls[0].delta.tool_name == "greet_everyone"
tool_events = [event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"]
tool_events = [
event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
]
assert tool_events and tool_events[0].result.tool_calls[0].tool_name == "greet_everyone"
assert "hello" in final_response.output_text.lower()

View file

@ -416,7 +416,9 @@ class TestAgentWithMCPTools:
)
events = [chunk.event for chunk in chunks]
tool_execution_steps = [event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"]
tool_execution_steps = [
event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
]
for step in tool_execution_steps:
for tool_response in step.result.tool_responses: