forked from phoenix-oss/llama-stack-mirror
test: add test for Agent.create_turn non-streaming response (#1078)
Summary: This tests the fix to the SDK in https://github.com/meta-llama/llama-stack-client-python/pull/141 Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B
This commit is contained in:
parent
32d1e50a6f
commit
225dd38e5c
1 changed files with 32 additions and 2 deletions
|
@ -319,7 +319,7 @@ def test_custom_tool(llama_stack_client, agent_config):
|
||||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||||
logs_str = "".join(logs)
|
logs_str = "".join(logs)
|
||||||
assert "-100" in logs_str
|
assert "-100" in logs_str
|
||||||
assert "CustomTool" in logs_str
|
assert "get_boiling_point" in logs_str
|
||||||
|
|
||||||
|
|
||||||
# TODO: fix this flaky test
|
# TODO: fix this flaky test
|
||||||
|
@ -403,7 +403,7 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
|
||||||
logs_str = "".join(logs)
|
logs_str = "".join(logs)
|
||||||
print(logs_str)
|
print(logs_str)
|
||||||
assert "-100" in logs_str
|
assert "-100" in logs_str
|
||||||
assert "CustomTool" in logs_str
|
assert "get_boiling_point" in logs_str
|
||||||
|
|
||||||
|
|
||||||
def test_rag_agent(llama_stack_client, agent_config):
|
def test_rag_agent(llama_stack_client, agent_config):
|
||||||
|
@ -527,3 +527,33 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
|
||||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||||
logs_str = "".join(logs)
|
logs_str = "".join(logs)
|
||||||
assert f"Tool:{tool_name}" in logs_str
|
assert f"Tool:{tool_name}" in logs_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_turn_response(llama_stack_client, agent_config):
|
||||||
|
client_tool = TestClientTool()
|
||||||
|
agent_config = {
|
||||||
|
**agent_config,
|
||||||
|
"input_shields": [],
|
||||||
|
"output_shields": [],
|
||||||
|
"client_tools": [client_tool.get_tool_definition()],
|
||||||
|
}
|
||||||
|
|
||||||
|
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
|
||||||
|
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||||
|
|
||||||
|
response = agent.create_turn(
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the boiling point of polyjuice?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
session_id=session_id,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
steps = response.steps
|
||||||
|
assert len(steps) == 3
|
||||||
|
assert steps[0].step_type == "inference"
|
||||||
|
assert steps[1].step_type == "tool_execution"
|
||||||
|
assert steps[1].tool_calls[0].tool_name == "get_boiling_point"
|
||||||
|
assert steps[2].step_type == "inference"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue