forked from phoenix-oss/llama-stack-mirror
feat: support ClientTool output metadata (#1426)
# Summary:

Client-side change in https://github.com/meta-llama/llama-stack-client-python/pull/180. Changes the resume_turn API to accept `ToolResponse` instead of `ToolResponseMessage`:

1. `ToolResponse` contains `metadata`.
2. `ToolResponseMessage` is a concept for model inputs, whereas here we are only submitting the outputs of tool execution.

# Test Plan:

Ran the integration tests, including a newly added test that exercises a client tool with metadata:

```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --record-responses
```
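For readers coming from the client side, a minimal sketch of the distinction the summary draws. These dataclass shapes and the `resume_turn` call are illustrative assumptions based on the summary and the linked PR, not the exact definitions in llama-stack-client-python:

```python
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class ToolResponseMessage:
    # Illustrative shape: a tool result formatted as model *input*; no metadata slot.
    call_id: str
    content: str

@dataclass
class ToolResponse:
    # Illustrative shape: the raw *output* of tool execution, which carries
    # side-channel metadata (e.g. a source URL) that is not model input.
    call_id: str
    tool_name: str
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)

# Hypothetical call shape after this change: resuming a turn submits the
# outputs of tool execution directly, metadata included.
# agent.resume_turn(tool_responses=[
#     ToolResponse(call_id="c1", tool_name="get_boiling_point_with_metadata",
#                  content="-100", metadata={"source": "https://www.google.com"}),
# ])
```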
parent ac717f38dc
commit 6cf79437b3
10 changed files with 3984 additions and 2172 deletions
```diff
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Any, Dict
 from uuid import uuid4
 
 import pytest
@@ -40,6 +41,25 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
     return -1
 
 
+@client_tool
+def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
+    """
+    Returns the boiling point of a liquid in Celsius or Fahrenheit.
+
+    :param liquid_name: The name of the liquid
+    :param celcius: Whether to return the boiling point in Celsius
+    :return: The boiling point of the liquid in Celsius or Fahrenheit
+    """
+    if liquid_name.lower() == "polyjuice":
+        if celcius:
+            temp = -100
+        else:
+            temp = -212
+    else:
+        temp = -1
+    return {"content": temp, "metadata": {"source": "https://www.google.com"}}
+
+
 @pytest.fixture(scope="session")
 def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
     available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
@@ -551,8 +571,9 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
     assert expected_kw in response.output_message.content.lower()
 
 
-def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config):
-    client_tool = get_boiling_point
+@pytest.mark.parametrize("client_tools", [(get_boiling_point, False), (get_boiling_point_with_metadata, True)])
+def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
+    client_tool, expects_metadata = client_tools
     agent_config = {
         **agent_config,
         "input_shields": [],
@@ -577,7 +598,9 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config):
     assert len(steps) == 3
     assert steps[0].step_type == "inference"
     assert steps[1].step_type == "tool_execution"
-    assert steps[1].tool_calls[0].tool_name == "get_boiling_point"
+    assert steps[1].tool_calls[0].tool_name.startswith("get_boiling_point")
+    if expects_metadata:
+        assert steps[1].tool_responses[0].metadata["source"] == "https://www.google.com"
     assert steps[2].step_type == "inference"
 
     last_step_completed_at = None
```
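Taken together, the contract the parametrized test pins down: when a client tool returns `{"content": ..., "metadata": ...}`, the agent surfaces that metadata on the tool-execution step of the turn. A small helper sketch written against the attribute names used in the test above (`step_type`, `tool_responses`, `metadata`); treating a turn's `steps` as a plain iterable is an assumption:

```python
from typing import Any, Dict, Iterable

def collect_tool_metadata(steps: Iterable[Any]) -> Dict[str, Any]:
    """Walk a completed turn's steps and merge any tool-response metadata."""
    found: Dict[str, Any] = {}
    for step in steps:
        if step.step_type == "tool_execution":
            for resp in step.tool_responses:
                # Metadata flows through unchanged from the client tool's return value.
                found.update(resp.metadata or {})
    return found

# e.g. collect_tool_metadata(turn.steps)["source"] == "https://www.google.com"
```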