feat: support ClientTool output metadata (#1426)

# Summary:
Client side change in
https://github.com/meta-llama/llama-stack-client-python/pull/180
Changes the resume_turn API to accept `ToolResponse` instead of
`ToolResponseMessage`:
1. `ToolResponse` contains `metadata`
2. `ToolResponseMessage` is a concept for model inputs. Here we are just
submitting the outputs of tool execution.

# Test Plan:
Ran integration tests with newly added test using client tool with
metadata

LLAMA_STACK_CONFIG=fireworks pytest -s -v
tests/integration/agents/test_agents.py --safety-shield
meta-llama/Llama-Guard-3-8B --record-responses
This commit is contained in:
ehhuang 2025-03-05 14:30:27 -08:00 committed by GitHub
parent ac717f38dc
commit 6cf79437b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 3984 additions and 2172 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from uuid import uuid4
import pytest
@ -40,6 +41,25 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
return -1
@client_tool
def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
"""
Returns the boiling point of a liquid in Celcius or Fahrenheit
:param liquid_name: The name of the liquid
:param celcius: Whether to return the boiling point in Celcius
:return: The boiling point of the liquid in Celcius or Fahrenheit
"""
if liquid_name.lower() == "polyjuice":
if celcius:
temp = -100
else:
temp = -212
else:
temp = -1
return {"content": temp, "metadata": {"source": "https://www.google.com"}}
@pytest.fixture(scope="session")
def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
@ -551,8 +571,9 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
assert expected_kw in response.output_message.content.lower()
def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config):
client_tool = get_boiling_point
@pytest.mark.parametrize("client_tools", [(get_boiling_point, False), (get_boiling_point_with_metadata, True)])
def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
client_tool, expectes_metadata = client_tools
agent_config = {
**agent_config,
"input_shields": [],
@ -577,7 +598,9 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
assert len(steps) == 3
assert steps[0].step_type == "inference"
assert steps[1].step_type == "tool_execution"
assert steps[1].tool_calls[0].tool_name == "get_boiling_point"
assert steps[1].tool_calls[0].tool_name.startswith("get_boiling_point")
if expectes_metadata:
assert steps[1].tool_responses[0].metadata["source"] == "https://www.google.com"
assert steps[2].step_type == "inference"
last_step_completed_at = None