mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
precommit
This commit is contained in:
parent
57b3d14895
commit
eba2eb2b0e
3 changed files with 24 additions and 16 deletions
|
|
@ -46,9 +46,7 @@ def collect_turn(
|
|||
*,
|
||||
extra_headers: dict[str, Any] | None = None,
|
||||
):
|
||||
chunks = list(
|
||||
agent.create_turn(messages=messages, session_id=session_id, stream=True, extra_headers=extra_headers)
|
||||
)
|
||||
chunks = list(agent.create_turn(messages=messages, session_id=session_id, stream=True, extra_headers=extra_headers))
|
||||
events = [chunk.event for chunk in chunks]
|
||||
final_response = next((chunk.response for chunk in reversed(chunks) if chunk.response), None)
|
||||
if final_response is None:
|
||||
|
|
@ -250,9 +248,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
|||
agent,
|
||||
session_id,
|
||||
messages=[
|
||||
text_message(
|
||||
"Write code and execute it to find the answer for: What is the 100th prime number?"
|
||||
),
|
||||
text_message("Write code and execute it to find the answer for: What is the 100th prime number?"),
|
||||
],
|
||||
)
|
||||
logs = [str(log) for log in AgentEventLogger().log(chunks) if log is not None]
|
||||
|
|
@ -305,7 +301,9 @@ def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
|
|||
messages=[text_message("Get the boiling point of polyjuice with a tool call.")],
|
||||
)
|
||||
|
||||
num_tool_calls = sum(1 for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution")
|
||||
num_tool_calls = sum(
|
||||
1 for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
|
||||
)
|
||||
assert num_tool_calls <= 5
|
||||
|
||||
|
||||
|
|
@ -325,7 +323,9 @@ def test_tool_choice_get_boiling_point(llama_stack_client, agent_config):
|
|||
pytest.xfail("NotImplemented for non-llama models")
|
||||
|
||||
tool_execution_steps = run_agent_with_tool_choice(llama_stack_client, agent_config, "get_boiling_point")
|
||||
assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].result.tool_calls[0].tool_name == "get_boiling_point"
|
||||
assert (
|
||||
len(tool_execution_steps) >= 1 and tool_execution_steps[0].result.tool_calls[0].tool_name == "get_boiling_point"
|
||||
)
|
||||
|
||||
|
||||
def run_agent_with_tool_choice(client, agent_config, tool_choice):
|
||||
|
|
@ -373,14 +373,18 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools):
|
|||
messages=[text_message(input_prompt)],
|
||||
)
|
||||
|
||||
tool_events = [event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"]
|
||||
tool_events = [
|
||||
event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
|
||||
]
|
||||
assert len(tool_events) >= 1
|
||||
tool_exec = tool_events[0]
|
||||
assert tool_exec.result.tool_calls[0].tool_name.startswith("get_boiling_point")
|
||||
if expects_metadata:
|
||||
assert tool_exec.result.tool_responses[0]["metadata"]["source"] == "https://www.google.com"
|
||||
|
||||
inference_events = [event for event in events if isinstance(event, StepCompleted) and event.step_type == "inference"]
|
||||
inference_events = [
|
||||
event for event in events if isinstance(event, StepCompleted) and event.step_type == "inference"
|
||||
]
|
||||
assert len(inference_events) >= 2
|
||||
assert "polyjuice" in final_response.output_text.lower()
|
||||
|
||||
|
|
@ -407,7 +411,9 @@ def test_multi_tool_calls(llama_stack_client, agent_config):
|
|||
],
|
||||
)
|
||||
|
||||
tool_exec_events = [event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"]
|
||||
tool_exec_events = [
|
||||
event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
|
||||
]
|
||||
assert len(tool_exec_events) >= 1
|
||||
tool_exec = tool_exec_events[0]
|
||||
assert len(tool_exec.result.tool_calls) == 2
|
||||
|
|
|
|||
|
|
@ -117,13 +117,13 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
|
|||
assert final_response is not None
|
||||
|
||||
issued_calls = [
|
||||
event
|
||||
for event in events
|
||||
if isinstance(event, StepProgress) and isinstance(event.delta, ToolCallIssuedDelta)
|
||||
event for event in events if isinstance(event, StepProgress) and isinstance(event.delta, ToolCallIssuedDelta)
|
||||
]
|
||||
assert issued_calls and issued_calls[0].delta.tool_name == "greet_everyone"
|
||||
|
||||
tool_events = [event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"]
|
||||
tool_events = [
|
||||
event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
|
||||
]
|
||||
assert tool_events and tool_events[0].result.tool_calls[0].tool_name == "greet_everyone"
|
||||
|
||||
assert "hello" in final_response.output_text.lower()
|
||||
|
|
|
|||
|
|
@ -416,7 +416,9 @@ class TestAgentWithMCPTools:
|
|||
)
|
||||
|
||||
events = [chunk.event for chunk in chunks]
|
||||
tool_execution_steps = [event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"]
|
||||
tool_execution_steps = [
|
||||
event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution"
|
||||
]
|
||||
|
||||
for step in tool_execution_steps:
|
||||
for tool_response in step.result.tool_responses:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue