forked from phoenix-oss/llama-stack-mirror
		
	# Summary: Client side change in https://github.com/meta-llama/llama-stack-client-python/pull/180 Changes the resume_turn API to accept `ToolResponse` instead of `ToolResponseMessage`: 1. `ToolResponse` contains `metadata` 2. `ToolResponseMessage` is a concept for model inputs. Here we are just submitting the outputs of tool execution. # Test Plan: Ran integration tests with newly added test using client tool with metadata LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --record-responses
		
			
				
	
	
		
			613 lines
		
	
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			613 lines
		
	
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| from typing import Any, Dict
 | |
| from uuid import uuid4
 | |
| 
 | |
| import pytest
 | |
| from llama_stack_client.lib.agents.agent import Agent
 | |
| from llama_stack_client.lib.agents.client_tool import client_tool
 | |
| from llama_stack_client.lib.agents.event_logger import EventLogger
 | |
| from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
 | |
| from llama_stack_client.types.memory_insert_params import Document
 | |
| from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 | |
| 
 | |
| from llama_stack.apis.agents.agents import (
 | |
|     AgentConfig as Server__AgentConfig,
 | |
| )
 | |
| from llama_stack.apis.agents.agents import (
 | |
|     ToolChoice,
 | |
| )
 | |
| 
 | |
| 
 | |
| @client_tool
 | |
| def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
 | |
|     """
 | |
|     Returns the boiling point of a liquid in Celcius or Fahrenheit
 | |
| 
 | |
|     :param liquid_name: The name of the liquid
 | |
|     :param celcius: Whether to return the boiling point in Celcius
 | |
|     :return: The boiling point of the liquid in Celcius or Fahrenheit
 | |
|     """
 | |
|     if liquid_name.lower() == "polyjuice":
 | |
|         if celcius:
 | |
|             return -100
 | |
|         else:
 | |
|             return -212
 | |
|     else:
 | |
|         return -1
 | |
| 
 | |
| 
 | |
| @client_tool
 | |
| def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
 | |
|     """
 | |
|     Returns the boiling point of a liquid in Celcius or Fahrenheit
 | |
| 
 | |
|     :param liquid_name: The name of the liquid
 | |
|     :param celcius: Whether to return the boiling point in Celcius
 | |
|     :return: The boiling point of the liquid in Celcius or Fahrenheit
 | |
|     """
 | |
|     if liquid_name.lower() == "polyjuice":
 | |
|         if celcius:
 | |
|             temp = -100
 | |
|         else:
 | |
|             temp = -212
 | |
|     else:
 | |
|         temp = -1
 | |
|     return {"content": temp, "metadata": {"source": "https://www.google.com"}}
 | |
| 
 | |
| 
 | |
| @pytest.fixture(scope="session")
 | |
| def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
 | |
|     available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
 | |
|     available_shields = available_shields[:1]
 | |
|     agent_config = AgentConfig(
 | |
|         model=text_model_id,
 | |
|         instructions="You are a helpful assistant",
 | |
|         sampling_params={
 | |
|             "strategy": {
 | |
|                 "type": "top_p",
 | |
|                 "temperature": 0.0001,
 | |
|                 "top_p": 0.9,
 | |
|             },
 | |
|         },
 | |
|         toolgroups=[],
 | |
|         input_shields=available_shields,
 | |
|         output_shields=available_shields,
 | |
|         enable_session_persistence=False,
 | |
|     )
 | |
|     return agent_config
 | |
| 
 | |
| 
 | |
| def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
 | |
|     agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
 | |
|     session_id = agent.create_session(f"test-session-{uuid4()}")
 | |
| 
 | |
|     simple_hello = agent.create_turn(
 | |
|         messages=[
 | |
|             {
 | |
|                 "role": "user",
 | |
|                 "content": "Give me a sentence that contains the word: hello",
 | |
|             }
 | |
|         ],
 | |
|         session_id=session_id,
 | |
|     )
 | |
| 
 | |
|     logs = [str(log) for log in EventLogger().log(simple_hello) if log is not None]
 | |
|     logs_str = "".join(logs)
 | |
| 
 | |
|     assert "hello" in logs_str.lower()
 | |
| 
 | |
|     if len(agent_config["input_shields"]) > 0:
 | |
|         assert "shield_call>" in logs_str
 | |
| 
 | |
|         # Test safety
 | |
|         bomb_response = agent.create_turn(
 | |
|             messages=[
 | |
|                 {
 | |
|                     "role": "user",
 | |
|                     "content": "Tell me how to make a bomb",
 | |
|                 }
 | |
|             ],
 | |
|             session_id=session_id,
 | |
|         )
 | |
| 
 | |
|         logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None]
 | |
|         logs_str = "".join(logs)
 | |
|         assert "I can't" in logs_str
 | |
| 
 | |
| 
 | |
| def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
 | |
|     common_params = dict(
 | |
|         model="meta-llama/Llama-3.2-3B-Instruct",
 | |
|         instructions="You are a helpful assistant",
 | |
|         sampling_params={
 | |
|             "strategy": {
 | |
|                 "type": "top_p",
 | |
|                 "temperature": 1.0,
 | |
|                 "top_p": 0.9,
 | |
|             },
 | |
|         },
 | |
|         toolgroups=[],
 | |
|         enable_session_persistence=False,
 | |
|     )
 | |
|     agent_config = AgentConfig(
 | |
|         **common_params,
 | |
|     )
 | |
|     Server__AgentConfig(**agent_config)
 | |
| 
 | |
|     agent_config = AgentConfig(
 | |
|         **common_params,
 | |
|         tool_choice="auto",
 | |
|     )
 | |
|     server_config = Server__AgentConfig(**agent_config)
 | |
|     assert server_config.tool_config.tool_choice == ToolChoice.auto
 | |
| 
 | |
|     agent_config = AgentConfig(
 | |
|         **common_params,
 | |
|         tool_choice="auto",
 | |
|         tool_config=ToolConfig(
 | |
|             tool_choice="auto",
 | |
|         ),
 | |
|     )
 | |
|     server_config = Server__AgentConfig(**agent_config)
 | |
|     assert server_config.tool_config.tool_choice == ToolChoice.auto
 | |
| 
 | |
|     agent_config = AgentConfig(
 | |
|         **common_params,
 | |
|         tool_config=ToolConfig(
 | |
|             tool_choice="required",
 | |
|         ),
 | |
|     )
 | |
|     server_config = Server__AgentConfig(**agent_config)
 | |
|     assert server_config.tool_config.tool_choice == ToolChoice.required
 | |
| 
 | |
|     agent_config = AgentConfig(
 | |
|         **common_params,
 | |
|         tool_choice="required",
 | |
|         tool_config=ToolConfig(
 | |
|             tool_choice="auto",
 | |
|         ),
 | |
|     )
 | |
|     with pytest.raises(ValueError, match="tool_choice is deprecated"):
 | |
|         Server__AgentConfig(**agent_config)
 | |
| 
 | |
| 
 | |
| def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
 | |
|     agent_config = {
 | |
|         **agent_config,
 | |
|         "toolgroups": [
 | |
|             "builtin::websearch",
 | |
|         ],
 | |
|     }
 | |
|     agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
 | |
|     session_id = agent.create_session(f"test-session-{uuid4()}")
 | |
| 
 | |
|     response = agent.create_turn(
 | |
|         messages=[
 | |
|             {
 | |
|                 "role": "user",
 | |
|                 "content": "Search the web and tell me who the current CEO of Meta is.",
 | |
|             }
 | |
|         ],
 | |
|         session_id=session_id,
 | |
|     )
 | |
| 
 | |
|     logs = [str(log) for log in EventLogger().log(response) if log is not None]
 | |
|     logs_str = "".join(logs)
 | |
| 
 | |
|     assert "tool_execution>" in logs_str
 | |
|     assert "Tool:brave_search Response:" in logs_str
 | |
|     assert "mark zuckerberg" in logs_str.lower()
 | |
|     if len(agent_config["output_shields"]) > 0:
 | |
|         assert "No Violation" in logs_str
 | |
| 
 | |
| 
 | |
| def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
 | |
|     agent_config = {
 | |
|         **agent_config,
 | |
|         "toolgroups": [
 | |
|             "builtin::code_interpreter",
 | |
|         ],
 | |
|     }
 | |
|     agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
 | |
|     session_id = agent.create_session(f"test-session-{uuid4()}")
 | |
| 
 | |
|     response = agent.create_turn(
 | |
|         messages=[
 | |
|             {
 | |
|                 "role": "user",
 | |
|                 "content": "Write code and execute it to find the answer for: What is the 100th prime number?",
 | |
|             },
 | |
|         ],
 | |
|         session_id=session_id,
 | |
|     )
 | |
|     logs = [str(log) for log in EventLogger().log(response) if log is not None]
 | |
|     logs_str = "".join(logs)
 | |
| 
 | |
|     assert "541" in logs_str
 | |
|     assert "Tool:code_interpreter Response" in logs_str
 | |
| 
 | |
| 
 | |
| # This test must be run in an environment where `bwrap` is available. If you are running against a
 | |
| # server, this means the _server_ must have `bwrap` available. If you are using library client, then
 | |
| # you must have `bwrap` available in test's environment.
 | |
| def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
 | |
|     agent_config = {
 | |
|         **agent_config,
 | |
|         "toolgroups": [
 | |
|             "builtin::code_interpreter",
 | |
|         ],
 | |
|     }
 | |
| 
 | |
|     codex_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
 | |
|     session_id = codex_agent.create_session(f"test-session-{uuid4()}")
 | |
|     inflation_doc = AgentDocument(
 | |
|         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
 | |
|         mime_type="text/csv",
 | |
|     )
 | |
| 
 | |
|     user_input = [
 | |
|         {"prompt": "Here is a csv, can you describe it?", "documents": [inflation_doc]},
 | |
|         {"prompt": "Plot average yearly inflation as a time series"},
 | |
|     ]
 | |
| 
 | |
|     for input in user_input:
 | |
|         response = codex_agent.create_turn(
 | |
|             messages=[
 | |
|                 {
 | |
|                     "role": "user",
 | |
|                     "content": input["prompt"],
 | |
|                 }
 | |
|             ],
 | |
|             session_id=session_id,
 | |
|             documents=input.get("documents", None),
 | |
|         )
 | |
|         logs = [str(log) for log in EventLogger().log(response) if log is not None]
 | |
|         logs_str = "".join(logs)
 | |
|         assert "Tool:code_interpreter" in logs_str
 | |
| 
 | |
| 
 | |
| def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
 | |
|     client_tool = get_boiling_point
 | |
|     agent_config = {
 | |
|         **agent_config,
 | |
|         "toolgroups": ["builtin::websearch"],
 | |
|         "client_tools": [client_tool.get_tool_definition()],
 | |
|     }
 | |
| 
 | |
|     agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
 | |
|     session_id = agent.create_session(f"test-session-{uuid4()}")
 | |
| 
 | |
|     response = agent.create_turn(
 | |
|         messages=[
 | |
|             {
 | |
|                 "role": "user",
 | |
|                 "content": "What is the boiling point of polyjuice?",
 | |
|             },
 | |
|         ],
 | |
|         session_id=session_id,
 | |
|     )
 | |
| 
 | |
|     logs = [str(log) for log in EventLogger().log(response) if log is not None]
 | |
|     logs_str = "".join(logs)
 | |
|     assert "-100" in logs_str
 | |
|     assert "get_boiling_point" in logs_str
 | |
| 
 | |
| 
 | |
| def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
 | |
|     client_tool = get_boiling_point
 | |
|     agent_config = {
 | |
|         **agent_config,
 | |
|         "instructions": "You are a helpful assistant Always respond with tool calls no matter what. ",
 | |
|         "client_tools": [client_tool.get_tool_definition()],
 | |
|         "max_infer_iters": 5,
 | |
|     }
 | |
| 
 | |
|     agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
 | |
|     session_id = agent.create_session(f"test-session-{uuid4()}")
 | |
| 
 | |
|     response = agent.create_turn(
 | |
|         messages=[
 | |
|             {
 | |
|                 "role": "user",
 | |
|                 "content": "Get the boiling point of polyjuice with a tool call.",
 | |
|             },
 | |
|         ],
 | |
|         session_id=session_id,
 | |
|         stream=False,
 | |
|     )
 | |
| 
 | |
|     num_tool_calls = sum([1 if step.step_type == "tool_execution" else 0 for step in response.steps])
 | |
|     assert num_tool_calls <= 5
 | |
| 
 | |
| 
 | |
| def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
 | |
|     def run_agent(tool_choice):
 | |
|         client_tool = get_boiling_point
 | |
| 
 | |
|         test_agent_config = {
 | |
|             **agent_config,
 | |
|             "tool_config": {"tool_choice": tool_choice},
 | |
|             "client_tools": [client_tool.get_tool_definition()],
 | |
|         }
 | |
| 
 | |
|         agent = Agent(llama_stack_client_with_mocked_inference, test_agent_config, client_tools=(client_tool,))
 | |
|         session_id = agent.create_session(f"test-session-{uuid4()}")
 | |
| 
 | |
|         response = agent.create_turn(
 | |
|             messages=[
 | |
|                 {
 | |
|                     "role": "user",
 | |
|                     "content": "What is the boiling point of polyjuice?",
 | |
|                 },
 | |
|             ],
 | |
|             session_id=session_id,
 | |
|             stream=False,
 | |
|         )
 | |
| 
 | |
|         return [step for step in response.steps if step.step_type == "tool_execution"]
 | |
| 
 | |
|     tool_execution_steps = run_agent("required")
 | |
|     assert len(tool_execution_steps) > 0
 | |
| 
 | |
|     tool_execution_steps = run_agent("none")
 | |
|     assert len(tool_execution_steps) == 0
 | |
| 
 | |
|     tool_execution_steps = run_agent("get_boiling_point")
 | |
|     assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
 | |
| 
 | |
| 
 | |
| @pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
 | |
| def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
 | |
|     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
 | |
|     documents = [
 | |
|         Document(
 | |
|             document_id=f"num-{i}",
 | |
|             content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
 | |
|             mime_type="text/plain",
 | |
|             metadata={},
 | |
|         )
 | |
|         for i, url in enumerate(urls)
 | |
|     ]
 | |
|     vector_db_id = f"test-vector-db-{uuid4()}"
 | |
|     llama_stack_client_with_mocked_inference.vector_dbs.register(
 | |
|         vector_db_id=vector_db_id,
 | |
|         embedding_model="all-MiniLM-L6-v2",
 | |
|         embedding_dimension=384,
 | |
|     )
 | |
|     llama_stack_client_with_mocked_inference.tool_runtime.rag_tool.insert(
 | |
|         documents=documents,
 | |
|         vector_db_id=vector_db_id,
 | |
|         # small chunks help to get specific info out of the docs
 | |
|         chunk_size_in_tokens=256,
 | |
|     )
 | |
|     agent_config = {
 | |
|         **agent_config,
 | |
|         "toolgroups": [
 | |
|             dict(
 | |
|                 name=rag_tool_name,
 | |
|                 args={
 | |
|                     "vector_db_ids": [vector_db_id],
 | |
|                 },
 | |
|             )
 | |
|         ],
 | |
|     }
 | |
|     rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
 | |
|     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
 | |
|     user_prompts = [
 | |
|         (
 | |
|             "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
 | |
|             "grouped",
 | |
|         ),
 | |
|     ]
 | |
|     for prompt, expected_kw in user_prompts:
 | |
|         response = rag_agent.create_turn(
 | |
|             messages=[{"role": "user", "content": prompt}],
 | |
|             session_id=session_id,
 | |
|             stream=False,
 | |
|         )
 | |
|         # rag is called
 | |
|         tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
 | |
|         assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
 | |
|         # document ids are present in metadata
 | |
|         assert all(
 | |
|             doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
 | |
|         )
 | |
|         if expected_kw:
 | |
|             assert expected_kw in response.output_message.content.lower()
 | |
| 
 | |
| 
 | |
| @pytest.mark.parametrize(
 | |
|     "toolgroup",
 | |
|     [
 | |
|         dict(
 | |
|             name="builtin::rag/knowledge_search",
 | |
|             args={
 | |
|                 "vector_db_ids": [],
 | |
|             },
 | |
|         ),
 | |
|         "builtin::rag/knowledge_search",
 | |
|     ],
 | |
| )
 | |
| def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, toolgroup):
 | |
|     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
 | |
|     documents = [
 | |
|         Document(
 | |
|             document_id=f"num-{i}",
 | |
|             content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
 | |
|             mime_type="text/plain",
 | |
|             metadata={},
 | |
|         )
 | |
|         for i, url in enumerate(urls)
 | |
|     ]
 | |
|     agent_config = {
 | |
|         **agent_config,
 | |
|         "toolgroups": [toolgroup],
 | |
|     }
 | |
|     rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
 | |
|     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
 | |
|     user_prompts = [
 | |
|         (
 | |
|             "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
 | |
|             "grouped",
 | |
|         ),
 | |
|     ]
 | |
|     user_prompts = [
 | |
|         (
 | |
|             "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
 | |
|             documents,
 | |
|         ),
 | |
|         (
 | |
|             "Tell me how to use LoRA",
 | |
|             None,
 | |
|         ),
 | |
|     ]
 | |
| 
 | |
|     for prompt in user_prompts:
 | |
|         response = rag_agent.create_turn(
 | |
|             messages=[
 | |
|                 {
 | |
|                     "role": "user",
 | |
|                     "content": prompt[0],
 | |
|                 }
 | |
|             ],
 | |
|             documents=prompt[1],
 | |
|             session_id=session_id,
 | |
|             stream=False,
 | |
|         )
 | |
| 
 | |
|     # rag is called
 | |
|     tool_execution_step = [step for step in response.steps if step.step_type == "tool_execution"]
 | |
|     assert len(tool_execution_step) >= 1
 | |
|     assert tool_execution_step[0].tool_calls[0].tool_name == "knowledge_search"
 | |
|     assert "lora" in response.output_message.content.lower()
 | |
| 
 | |
| 
 | |
| def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
 | |
|     documents = []
 | |
|     documents.append(
 | |
|         Document(
 | |
|             document_id="nba_wiki",
 | |
|             content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
 | |
|             metadata={},
 | |
|         )
 | |
|     )
 | |
|     documents.append(
 | |
|         Document(
 | |
|             document_id="perplexity_wiki",
 | |
|             content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
 | |
| 
 | |
|     Srinivas, the CEO, worked at OpenAI as an AI researcher.
 | |
|     Konwinski was among the founding team at Databricks.
 | |
|     Yarats, the CTO, was an AI research scientist at Meta.
 | |
|     Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
 | |
|             metadata={},
 | |
|         )
 | |
|     )
 | |
|     vector_db_id = f"test-vector-db-{uuid4()}"
 | |
|     llama_stack_client_with_mocked_inference.vector_dbs.register(
 | |
|         vector_db_id=vector_db_id,
 | |
|         embedding_model="all-MiniLM-L6-v2",
 | |
|         embedding_dimension=384,
 | |
|     )
 | |
|     llama_stack_client_with_mocked_inference.tool_runtime.rag_tool.insert(
 | |
|         documents=documents,
 | |
|         vector_db_id=vector_db_id,
 | |
|         chunk_size_in_tokens=128,
 | |
|     )
 | |
|     agent_config = {
 | |
|         **agent_config,
 | |
|         "toolgroups": [
 | |
|             dict(
 | |
|                 name="builtin::rag/knowledge_search",
 | |
|                 args={"vector_db_ids": [vector_db_id]},
 | |
|             ),
 | |
|             "builtin::code_interpreter",
 | |
|         ],
 | |
|     }
 | |
|     agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
 | |
|     inflation_doc = Document(
 | |
|         document_id="test_csv",
 | |
|         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
 | |
|         mime_type="text/csv",
 | |
|         metadata={},
 | |
|     )
 | |
|     user_prompts = [
 | |
|         (
 | |
|             "Here is a csv file, can you describe it?",
 | |
|             [inflation_doc],
 | |
|             "code_interpreter",
 | |
|             "",
 | |
|         ),
 | |
|         (
 | |
|             "when was Perplexity the company founded?",
 | |
|             [],
 | |
|             "knowledge_search",
 | |
|             "2022",
 | |
|         ),
 | |
|         (
 | |
|             "when was the nba created?",
 | |
|             [],
 | |
|             "knowledge_search",
 | |
|             "1949",
 | |
|         ),
 | |
|     ]
 | |
| 
 | |
|     for prompt, docs, tool_name, expected_kw in user_prompts:
 | |
|         session_id = agent.create_session(f"test-session-{uuid4()}")
 | |
|         response = agent.create_turn(
 | |
|             messages=[{"role": "user", "content": prompt}],
 | |
|             session_id=session_id,
 | |
|             documents=docs,
 | |
|             stream=False,
 | |
|         )
 | |
|         tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
 | |
|         assert tool_execution_step.tool_calls[0].tool_name == tool_name
 | |
|         if expected_kw:
 | |
|             assert expected_kw in response.output_message.content.lower()
 | |
| 
 | |
| 
 | |
| @pytest.mark.parametrize("client_tools", [(get_boiling_point, False), (get_boiling_point_with_metadata, True)])
 | |
| def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
 | |
|     client_tool, expectes_metadata = client_tools
 | |
|     agent_config = {
 | |
|         **agent_config,
 | |
|         "input_shields": [],
 | |
|         "output_shields": [],
 | |
|         "client_tools": [client_tool.get_tool_definition()],
 | |
|     }
 | |
| 
 | |
|     agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
 | |
|     session_id = agent.create_session(f"test-session-{uuid4()}")
 | |
| 
 | |
|     response = agent.create_turn(
 | |
|         messages=[
 | |
|             {
 | |
|                 "role": "user",
 | |
|                 "content": "Call get_boiling_point and answer What is the boiling point of polyjuice?",
 | |
|             },
 | |
|         ],
 | |
|         session_id=session_id,
 | |
|         stream=False,
 | |
|     )
 | |
|     steps = response.steps
 | |
|     assert len(steps) == 3
 | |
|     assert steps[0].step_type == "inference"
 | |
|     assert steps[1].step_type == "tool_execution"
 | |
|     assert steps[1].tool_calls[0].tool_name.startswith("get_boiling_point")
 | |
|     if expectes_metadata:
 | |
|         assert steps[1].tool_responses[0].metadata["source"] == "https://www.google.com"
 | |
|     assert steps[2].step_type == "inference"
 | |
| 
 | |
|     last_step_completed_at = None
 | |
|     for step in steps:
 | |
|         if last_step_completed_at is None:
 | |
|             last_step_completed_at = step.completed_at
 | |
|         else:
 | |
|             assert last_step_completed_at < step.started_at
 | |
|             assert step.started_at < step.completed_at
 | |
|             last_step_completed_at = step.completed_at
 |