forked from phoenix-oss/llama-stack-mirror
feat: tool outputs metadata (#1155)
Summary: Allows tools to output metadata. This is useful for evaluating tool outputs, e.g. RAG tool will output document IDs, which can be used to score recall. Will need to make a similar change on the client side to support ClientTool outputting metadata. Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py
This commit is contained in:
parent
36162c8c82
commit
25fddccfd8
8 changed files with 141 additions and 28 deletions
78
docs/_static/llama-stack-spec.html
vendored
78
docs/_static/llama-stack-spec.html
vendored
|
@ -4521,6 +4521,31 @@
|
||||||
},
|
},
|
||||||
"content": {
|
"content": {
|
||||||
"$ref": "#/components/schemas/InterleavedContent"
|
"$ref": "#/components/schemas/InterleavedContent"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -6746,6 +6771,31 @@
|
||||||
},
|
},
|
||||||
"error_code": {
|
"error_code": {
|
||||||
"type": "integer"
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -7595,9 +7645,37 @@
|
||||||
"properties": {
|
"properties": {
|
||||||
"content": {
|
"content": {
|
||||||
"$ref": "#/components/schemas/InterleavedContent"
|
"$ref": "#/components/schemas/InterleavedContent"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"metadata"
|
||||||
|
],
|
||||||
"title": "RAGQueryResult"
|
"title": "RAGQueryResult"
|
||||||
},
|
},
|
||||||
"QueryChunksRequest": {
|
"QueryChunksRequest": {
|
||||||
|
|
32
docs/_static/llama-stack-spec.yaml
vendored
32
docs/_static/llama-stack-spec.yaml
vendored
|
@ -2945,6 +2945,16 @@ components:
|
||||||
- type: string
|
- type: string
|
||||||
content:
|
content:
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
$ref: '#/components/schemas/InterleavedContent'
|
||||||
|
metadata:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- call_id
|
- call_id
|
||||||
|
@ -4381,6 +4391,16 @@ components:
|
||||||
type: string
|
type: string
|
||||||
error_code:
|
error_code:
|
||||||
type: integer
|
type: integer
|
||||||
|
metadata:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- content
|
- content
|
||||||
|
@ -4954,7 +4974,19 @@ components:
|
||||||
properties:
|
properties:
|
||||||
content:
|
content:
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
$ref: '#/components/schemas/InterleavedContent'
|
||||||
|
metadata:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- metadata
|
||||||
title: RAGQueryResult
|
title: RAGQueryResult
|
||||||
QueryChunksRequest:
|
QueryChunksRequest:
|
||||||
type: object
|
type: object
|
||||||
|
|
|
@ -165,6 +165,7 @@ class ToolResponse(BaseModel):
|
||||||
call_id: str
|
call_id: str
|
||||||
tool_name: Union[BuiltinTool, str]
|
tool_name: Union[BuiltinTool, str]
|
||||||
content: InterleavedContent
|
content: InterleavedContent
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
@field_validator("tool_name", mode="before")
|
@field_validator("tool_name", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -26,6 +26,7 @@ class RAGDocument(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RAGQueryResult(BaseModel):
|
class RAGQueryResult(BaseModel):
|
||||||
content: Optional[InterleavedContent] = None
|
content: Optional[InterleavedContent] = None
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel):
|
||||||
content: InterleavedContent
|
content: InterleavedContent
|
||||||
error_message: Optional[str] = None
|
error_message: Optional[str] = None
|
||||||
error_code: Optional[int] = None
|
error_code: Optional[int] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class ToolStore(Protocol):
|
class ToolStore(Protocol):
|
||||||
|
|
|
@ -62,7 +62,7 @@ from llama_stack.apis.inference import (
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
|
@ -587,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
call_id="",
|
call_id="",
|
||||||
tool_name=MEMORY_QUERY_TOOL,
|
tool_name=MEMORY_QUERY_TOOL,
|
||||||
content=retrieved_context or [],
|
content=retrieved_context or [],
|
||||||
|
metadata=result.metadata,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
@ -795,13 +796,21 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
},
|
},
|
||||||
) as span:
|
) as span:
|
||||||
tool_execution_start_time = datetime.now()
|
tool_execution_start_time = datetime.now()
|
||||||
result_messages = await execute_tool_call_maybe(
|
tool_call = message.tool_calls[0]
|
||||||
|
tool_result = await execute_tool_call_maybe(
|
||||||
self.tool_runtime_api,
|
self.tool_runtime_api,
|
||||||
session_id,
|
session_id,
|
||||||
[message],
|
tool_call,
|
||||||
toolgroup_args,
|
toolgroup_args,
|
||||||
tool_to_group,
|
tool_to_group,
|
||||||
)
|
)
|
||||||
|
result_messages = [
|
||||||
|
ToolResponseMessage(
|
||||||
|
call_id=tool_call.call_id,
|
||||||
|
tool_name=tool_call.tool_name,
|
||||||
|
content=tool_result.content,
|
||||||
|
)
|
||||||
|
]
|
||||||
assert len(result_messages) == 1, "Currently not supporting multiple messages"
|
assert len(result_messages) == 1, "Currently not supporting multiple messages"
|
||||||
result_message = result_messages[0]
|
result_message = result_messages[0]
|
||||||
span.set_attribute("output", result_message.model_dump_json())
|
span.set_attribute("output", result_message.model_dump_json())
|
||||||
|
@ -820,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
call_id=result_message.call_id,
|
call_id=result_message.call_id,
|
||||||
tool_name=result_message.tool_name,
|
tool_name=result_message.tool_name,
|
||||||
content=result_message.content,
|
content=result_message.content,
|
||||||
|
metadata=tool_result.metadata,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
started_at=tool_execution_start_time,
|
started_at=tool_execution_start_time,
|
||||||
|
@ -1058,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
||||||
async def execute_tool_call_maybe(
|
async def execute_tool_call_maybe(
|
||||||
tool_runtime_api: ToolRuntime,
|
tool_runtime_api: ToolRuntime,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
messages: List[CompletionMessage],
|
tool_call: ToolCall,
|
||||||
toolgroup_args: Dict[str, Dict[str, Any]],
|
toolgroup_args: Dict[str, Dict[str, Any]],
|
||||||
tool_to_group: Dict[str, str],
|
tool_to_group: Dict[str, str],
|
||||||
) -> List[ToolResponseMessage]:
|
) -> ToolInvocationResult:
|
||||||
# While Tools.run interface takes a list of messages,
|
|
||||||
# All tools currently only run on a single message
|
|
||||||
# When this changes, we can drop this assert
|
|
||||||
# Whether to call tools on each message and aggregate
|
|
||||||
# or aggregate and call tool once, reamins to be seen.
|
|
||||||
assert len(messages) == 1, "Expected single message"
|
|
||||||
message = messages[0]
|
|
||||||
|
|
||||||
tool_call = message.tool_calls[0]
|
|
||||||
name = tool_call.tool_name
|
name = tool_call.tool_name
|
||||||
group_name = tool_to_group.get(name, None)
|
group_name = tool_to_group.get(name, None)
|
||||||
if group_name is None:
|
if group_name is None:
|
||||||
|
@ -1091,14 +1092,7 @@ async def execute_tool_call_maybe(
|
||||||
**tool_call_args,
|
**tool_call_args,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
return result
|
||||||
return [
|
|
||||||
ToolResponseMessage(
|
|
||||||
call_id=tool_call.call_id,
|
|
||||||
tool_name=tool_call.tool_name,
|
|
||||||
content=result.content,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _interpret_content_as_attachment(
|
def _interpret_content_as_attachment(
|
||||||
|
|
|
@ -119,10 +119,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
|
|
||||||
# sort by score
|
# sort by score
|
||||||
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
|
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
|
||||||
|
chunks = chunks[: query_config.max_chunks]
|
||||||
tokens = 0
|
tokens = 0
|
||||||
picked = []
|
picked = []
|
||||||
for c in chunks[: query_config.max_chunks]:
|
for c in chunks:
|
||||||
metadata = c.metadata
|
metadata = c.metadata
|
||||||
tokens += metadata["token_count"]
|
tokens += metadata["token_count"]
|
||||||
if tokens > query_config.max_tokens_in_context:
|
if tokens > query_config.max_tokens_in_context:
|
||||||
|
@ -146,6 +146,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
text="\n=== END-RETRIEVED-CONTEXT ===\n",
|
text="\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
metadata={
|
||||||
|
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
|
|
|
@ -457,6 +457,7 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
|
provider_id="faiss",
|
||||||
)
|
)
|
||||||
llama_stack_client.tool_runtime.rag_tool.insert(
|
llama_stack_client.tool_runtime.rag_tool.insert(
|
||||||
documents=documents,
|
documents=documents,
|
||||||
|
@ -492,11 +493,13 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
response = rag_agent.create_turn(
|
response = rag_agent.create_turn(
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
stream=False,
|
||||||
)
|
)
|
||||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
# rag is called
|
||||||
logs_str = "".join(logs)
|
assert response.steps[0].tool_calls[0].tool_name == "query_from_memory"
|
||||||
assert "Tool:query_from_memory" in logs_str
|
# document ids are present in metadata
|
||||||
assert expected_kw in logs_str.lower()
|
assert "num-0" in response.steps[0].tool_responses[0].metadata["document_ids"]
|
||||||
|
assert expected_kw in response.output_message.content
|
||||||
|
|
||||||
|
|
||||||
def test_rag_and_code_agent(llama_stack_client, agent_config):
|
def test_rag_and_code_agent(llama_stack_client, agent_config):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue