forked from phoenix-oss/llama-stack-mirror
feat: remove special handling of builtin::rag tool (#1015)
Summary: Lets the model decide which tool it needs to call to respond to a query. Test Plan: ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B ``` Also evaluated on a small benchmark with 20 questions from HotpotQA. With this PR and some prompting, the performance is 77% recall compared to 50% currently. --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1015). * #1268 * #1239 * __->__ #1015
This commit is contained in:
parent
c64f0d5888
commit
bb2690f176
4 changed files with 94 additions and 133 deletions
|
@ -10,6 +10,8 @@ import secrets
|
|||
import string
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
|
@ -23,6 +25,7 @@ from llama_stack.apis.tools import (
|
|||
RAGToolRuntime,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
|
@ -120,9 +123,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
# sort by score
|
||||
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
|
||||
chunks = chunks[: query_config.max_chunks]
|
||||
|
||||
tokens = 0
|
||||
picked = []
|
||||
for c in chunks:
|
||||
picked = [
|
||||
TextContentItem(
|
||||
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
|
||||
)
|
||||
]
|
||||
for i, c in enumerate(chunks):
|
||||
metadata = c.metadata
|
||||
tokens += metadata["token_count"]
|
||||
if tokens > query_config.max_tokens_in_context:
|
||||
|
@ -132,20 +140,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
break
|
||||
picked.append(
|
||||
TextContentItem(
|
||||
text=f"id:{metadata['document_id']}; content:{c.content}",
|
||||
text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
|
||||
)
|
||||
)
|
||||
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
|
||||
return RAGQueryResult(
|
||||
content=[
|
||||
TextContentItem(
|
||||
text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
),
|
||||
*picked,
|
||||
TextContentItem(
|
||||
text="\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
),
|
||||
],
|
||||
content=picked,
|
||||
metadata={
|
||||
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
|
||||
},
|
||||
|
@ -158,17 +159,40 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
# by the LLM. The method is only implemented so things like /tools can list without
|
||||
# encountering fatals.
|
||||
return [
|
||||
ToolDef(
|
||||
name="query_from_memory",
|
||||
description="Retrieve context from memory",
|
||||
),
|
||||
ToolDef(
|
||||
name="insert_into_memory",
|
||||
description="Insert documents into memory",
|
||||
),
|
||||
ToolDef(
|
||||
name="knowledge_search",
|
||||
description="Search for information in a database.",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for. Can be a natural language sentence or keywords.",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
raise RuntimeError(
|
||||
"This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
|
||||
vector_db_ids = kwargs.get("vector_db_ids", [])
|
||||
query_config = kwargs.get("query_config")
|
||||
if query_config:
|
||||
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
||||
else:
|
||||
# handle someone passing an empty dict
|
||||
query_config = RAGQueryConfig()
|
||||
|
||||
query = kwargs["query"]
|
||||
result = await self.query(
|
||||
content=query,
|
||||
vector_db_ids=vector_db_ids,
|
||||
query_config=query_config,
|
||||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=result.content,
|
||||
metadata=result.metadata,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue