forked from phoenix-oss/llama-stack-mirror
feat: tool outputs metadata (#1155)
Summary: Allows tools to output metadata. This is useful for evaluating tool outputs; e.g., the RAG tool outputs document IDs, which can be used to score recall. A similar change is still needed on the client side so that ClientTool can output metadata as well. Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py
This commit is contained in:
parent
36162c8c82
commit
25fddccfd8
8 changed files with 141 additions and 28 deletions
|
@ -62,7 +62,7 @@ from llama_stack.apis.inference import (
|
|||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
|
@ -587,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
call_id="",
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
content=retrieved_context or [],
|
||||
metadata=result.metadata,
|
||||
)
|
||||
],
|
||||
),
|
||||
|
@ -795,13 +796,21 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now()
|
||||
result_messages = await execute_tool_call_maybe(
|
||||
tool_call = message.tool_calls[0]
|
||||
tool_result = await execute_tool_call_maybe(
|
||||
self.tool_runtime_api,
|
||||
session_id,
|
||||
[message],
|
||||
tool_call,
|
||||
toolgroup_args,
|
||||
tool_to_group,
|
||||
)
|
||||
result_messages = [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=tool_result.content,
|
||||
)
|
||||
]
|
||||
assert len(result_messages) == 1, "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
@ -820,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
call_id=result_message.call_id,
|
||||
tool_name=result_message.tool_name,
|
||||
content=result_message.content,
|
||||
metadata=tool_result.metadata,
|
||||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
|
@ -1058,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
async def execute_tool_call_maybe(
|
||||
tool_runtime_api: ToolRuntime,
|
||||
session_id: str,
|
||||
messages: List[CompletionMessage],
|
||||
tool_call: ToolCall,
|
||||
toolgroup_args: Dict[str, Dict[str, Any]],
|
||||
tool_to_group: Dict[str, str],
|
||||
) -> List[ToolResponseMessage]:
|
||||
# While Tools.run interface takes a list of messages,
|
||||
# All tools currently only run on a single message
|
||||
# When this changes, we can drop this assert
|
||||
# Whether to call tools on each message and aggregate
|
||||
# or aggregate and call tool once, remains to be seen.
|
||||
assert len(messages) == 1, "Expected single message"
|
||||
message = messages[0]
|
||||
|
||||
tool_call = message.tool_calls[0]
|
||||
) -> ToolInvocationResult:
|
||||
name = tool_call.tool_name
|
||||
group_name = tool_to_group.get(name, None)
|
||||
if group_name is None:
|
||||
|
@ -1091,14 +1092,7 @@ async def execute_tool_call_maybe(
|
|||
**tool_call_args,
|
||||
),
|
||||
)
|
||||
|
||||
return [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=result.content,
|
||||
)
|
||||
]
|
||||
return result
|
||||
|
||||
|
||||
def _interpret_content_as_attachment(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue