RAG Agent test passes

Ashwin Bharambe 2025-01-21 15:16:17 -08:00
parent 2f76de1643
commit a1433c0899
19 changed files with 157 additions and 76 deletions

View file

@@ -74,17 +74,17 @@ class RAGQueryConfig(BaseModel):
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
@webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
async def insert_documents(
self,
documents: List[RAGDocument],
vector_db_ids: List[str],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
"""Index documents so they can be used by the RAG system"""
...
@webmethod(route="/tool-runtime/rag-tool/query", method="POST")
@webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
async def query_context(
self,
content: InterleavedContent,
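
A minimal caller-side sketch of the protocol after these renames; the vector DB id, document list, and query string are placeholders, and RAGQueryConfig is assumed to construct with usable defaults.

from typing import List

from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, RAGToolRuntime

async def index_then_query(rag_tool: RAGToolRuntime, docs: List[RAGDocument]):
    # insert-documents now targets a single vector_db_id (previously a list)
    await rag_tool.insert_documents(
        documents=docs,
        vector_db_id="my-vector-db",
        chunk_size_in_tokens=512,
    )
    # query-context returns a RAGQueryResult; callers read the retrieved text
    # from result.content (see the agent change further down)
    result = await rag_tool.query_context(
        content="what does the indexed material say about RAG?",
        query_config=RAGQueryConfig(),
        vector_db_ids=["my-vector-db"],
    )
    return result.content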

View file

@@ -36,7 +36,14 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import RoutingTable
@@ -400,6 +407,33 @@ class EvalRouter(Eval):
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def query_context(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl(
"rag_tool.query_context"
).query_context(content, query_config, vector_db_ids)
async def insert_documents(
self,
documents: List[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
return await self.routing_table.get_provider_impl(
"rag_tool.insert_documents"
).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
def __init__(
self,
routing_table: RoutingTable,
@@ -408,7 +442,7 @@ class ToolRuntimeRouter(ToolRuntime):
# TODO: this should be in sync with "get_all_api_endpoints()"
# TODO: make sure rag_tool vs builtin::memory is correct everywhere
self.rag_tool = self.routing_table.get_provider_impl("builtin::memory")
self.rag_tool = self.RagToolImpl(routing_table)
setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
@@ -418,10 +452,10 @@ class ToolRuntimeRouter(ToolRuntime):
async def shutdown(self) -> None:
pass
async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
args=args,
kwargs=kwargs,
)
async def list_runtime_tools(
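
A hedged usage sketch for the router surface above; `router` stands for an already-constructed ToolRuntimeRouter, and the tool name, query, and vector DB id are placeholders rather than values taken from this diff.

from typing import Any, Dict

async def call_through_router(router, query_config) -> Dict[str, Any]:
    # invoke_tool forwards a kwargs dict (the parameter was renamed from `args`)
    tool_result = await router.invoke_tool(
        tool_name="wolfram_alpha",
        kwargs={"query": "integrate x^2"},
    )
    # RAG calls go through the nested RagToolImpl, which dispatches on the
    # routing-table entries "rag_tool.query_context" / "rag_tool.insert_documents"
    rag_result = await router.rag_tool.query_context(
        content="what changed in this commit?",
        query_config=query_config,
        vector_db_ids=["my-vector-db"],
    )
    return {"tool": tool_result, "rag": rag_result}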

View file

@@ -9,7 +9,7 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.tools import SpecialToolGroups
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroups
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
@@ -24,21 +24,29 @@ class ApiEndpoint(BaseModel):
name: str
def toolgroup_protocol_map():
return {
SpecialToolGroups.rag_tool: RAGToolRuntime,
}
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = api_protocol_map()
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
if api == Api.tool_runtime:
for tool_group in SpecialToolGroups:
print(f"tool_group: {tool_group}")
sub_protocol = getattr(protocol, tool_group.value)
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(
sub_protocol, predicate=inspect.isfunction
)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
@@ -47,7 +55,6 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
webmethod = method.__webmethod__
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
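
A small sketch of what the loop above does for a special toolgroup: it collects every protocol method carrying a __webmethod__ marker and registers its route under the toolgroup-prefixed name. The helper below mirrors that membership test; the "v1" version prefix is an assumption (the real value comes from LLAMA_STACK_API_VERSION).

import inspect

def routes_for(protocol, version: str = "v1") -> list:
    # Collect (verb, route, handler-name) for every decorated method,
    # using the same hasattr(__webmethod__) check as get_all_api_endpoints()
    routes = []
    for name, method in inspect.getmembers(protocol, predicate=inspect.isfunction):
        webmethod = getattr(method, "__webmethod__", None)
        if webmethod is None:
            continue
        route = f"/{version}/{webmethod.route.lstrip('/')}"
        routes.append((webmethod.method, route, name))
    return routes

# e.g. routes_for(RAGToolRuntime) should surface the two POST routes added
# above: .../tool-runtime/rag-tool/insert-documents and .../query-context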

View file

@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v5"
KEY_VERSION = "v6"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
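
For illustration, the key shape produced by the bumped format string; the type and identifier values below are made up.

REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v6"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

# Example (values are illustrative, not from the diff):
print(KEY_FORMAT.format(type="vector_db", identifier="my-vector-db"))
# distributions:registry:v6::vector_db:my-vector-db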

View file

@@ -373,21 +373,30 @@ class ChatAgent(ShieldRunnerMixin):
documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
# TODO: simplify all of this code, it can be simpler
toolgroup_args = {}
toolgroups = set()
for toolgroup in self.agent_config.toolgroups:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
else:
toolgroups.add(toolgroup)
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
else:
toolgroups.add(toolgroup)
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
if "builtin::memory" in toolgroup_args and len(input_messages) > 0:
if "builtin::memory" in toolgroups and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@@ -399,7 +408,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
args = toolgroup_args["builtin::memory"]
args = toolgroup_args.get("builtin::memory", {})
vector_db_ids = args.get("vector_db_ids", [])
session_info = await self.storage.get_session_info(session_id)
@@ -423,7 +432,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
retrieved_context = await self.tool_runtime_api.rag_tool.query_context(
result = await self.tool_runtime_api.rag_tool.query_context(
content=concat_interleaved_content(
[msg.content for msg in input_messages]
),
@@ -434,6 +443,7 @@ class ChatAgent(ShieldRunnerMixin):
),
vector_db_ids=vector_db_ids,
)
retrieved_context = result.content
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@@ -862,7 +872,7 @@ class ChatAgent(ShieldRunnerMixin):
async def add_to_session_memory_bank(
self, session_id: str, data: List[Document]
) -> None:
bank_id = await self._ensure_memory_bank(session_id)
vector_db_id = await self._ensure_memory_bank(session_id)
documents = [
RAGDocument(
document_id=str(uuid.uuid4()),
@@ -874,7 +884,7 @@ class ChatAgent(ShieldRunnerMixin):
]
await self.tool_runtime_api.rag_tool.insert_documents(
documents=documents,
vector_db_ids=[bank_id],
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
)
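
A hedged sketch of the agent configuration this turn logic consumes: a toolgroup may be a plain name or an AgentToolGroupWithArgs whose args carry the vector_db_ids read above. The import path and constructor keywords are assumptions; the vector DB id is a placeholder.

from llama_stack.apis.agents import AgentToolGroupWithArgs  # assumed import path

toolgroups_for_turn = [
    AgentToolGroupWithArgs(
        name="builtin::memory",
        args={"vector_db_ids": ["session-db-123"]},
    ),
    "builtin::code_interpreter",  # bare toolgroup names are accepted too
]
# _run() splits these into the `toolgroups` set and the `toolgroup_args` dict,
# so the memory branch can call toolgroup_args.get("builtin::memory", {})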

View file

@@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
script = args["code"]
script = kwargs["code"]
req = CodeExecutionRequest(scripts=[script])
res = self.code_executor.execute(req)
pieces = [res["process_status"]]

View file

@@ -5,68 +5,64 @@
# the root directory of this source tree.
from typing import List
from jinja2 import Template
from pydantic import BaseModel
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
RAGQueryGenerator,
RAGQueryGeneratorConfig,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
async def generate_rag_query(
config: MemoryQueryGeneratorConfig,
messages: List[InterleavedContent],
config: RAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
"""
Generates a query that will be used for
retrieving relevant information from the memory bank.
"""
if config.type == MemoryQueryGenerator.default.value:
query = await default_rag_query_generator(config, messages, **kwargs)
elif config.type == MemoryQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, messages, **kwargs)
if config.type == RAGQueryGenerator.default.value:
query = await default_rag_query_generator(config, content, **kwargs)
elif config.type == RAGQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, content, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig,
messages: List[InterleavedContent],
config: DefaultRAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
return config.separator.join(interleaved_content_as_str(m) for m in messages)
return interleaved_content_as_str(content, sep=config.separator)
async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig,
messages: List[InterleavedContent],
config: LLMRAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
m_dict = {
"messages": [
message.model_dump() if isinstance(message, BaseModel) else message
for message in messages
]
}
messages = []
if isinstance(content, list):
messages = [interleaved_content_as_str(m) for m in content]
else:
messages = [interleaved_content_as_str(content)]
template = Template(config.template)
content = template.render(m_dict)
content = template.render({"messages": messages})
model = config.model
message = UserMessage(content=content)
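
A usage sketch as it might sit alongside this module; it assumes DefaultRAGQueryGeneratorConfig accepts a separator keyword (the field is read as config.separator above) and that a plain string is valid InterleavedContent.

import asyncio

async def _demo_query_generation() -> None:
    # Default generator: flattens the content using the configured separator
    query = await generate_rag_query(
        config=DefaultRAGQueryGeneratorConfig(separator=" "),
        content="How do I configure the vector_io provider?",
    )
    print(query)

if __name__ == "__main__":
    asyncio.run(_demo_query_generation())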

View file

@@ -27,6 +27,10 @@ from llama_stack.apis.tools import (
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
make_overlapped_chunks,
)
from .config import MemoryToolRuntimeConfig
from .context_retriever import generate_rag_query
@@ -60,10 +64,28 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def insert_documents(
self,
documents: List[RAGDocument],
vector_db_ids: List[str],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
pass
chunks = []
for doc in documents:
content = await content_from_doc(doc)
chunks.extend(
make_overlapped_chunks(
doc.document_id,
content,
chunk_size_in_tokens,
chunk_size_in_tokens // 4,
)
)
if not chunks:
return
await self.vector_io_api.insert_chunks(
chunks=chunks,
vector_db_id=vector_db_id,
)
async def query_context(
self,
@@ -104,13 +126,18 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
tokens = 0
picked = []
for c in chunks[: query_config.max_chunks]:
tokens += c.token_count
metadata = c.metadata
tokens += metadata["token_count"]
if tokens > query_config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
picked.append(
TextContentItem(
text=f"id:{metadata['document_id']}; content:{c.content}",
)
)
return RAGQueryResult(
content=[
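
A quick note on the chunking arithmetic used by insert_documents above: the window is chunk_size_in_tokens and the overlap passed to make_overlapped_chunks is a quarter of it, so (on the assumption that consecutive windows advance by window minus overlap) each new chunk starts three quarters of a window after the previous one.

chunk_size_in_tokens = 512
overlap = chunk_size_in_tokens // 4        # 128 tokens shared with the neighbour
stride = chunk_size_in_tokens - overlap    # 384 tokens between chunk starts (assumption)
print(overlap, stride)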

View file

@@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl(
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
headers = {
@@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl(
"count": self.config.top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": args["query"],
"q": kwargs["query"],
}
response = requests.get(

View file

@@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl(
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search"
@@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl(
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": args["query"]}
payload = {"q": kwargs["query"]}
response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status()
results = self._clean_brave_response(response.json())

View file

@@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
return tools
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
@@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async with sse_client(endpoint) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, args)
result = await session.call_tool(tool.identifier, kwargs)
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),

View file

@@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl(
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": api_key, "query": args["query"]},
json={"api_key": api_key, "query": kwargs["query"]},
)
return ToolInvocationResult(

View file

@@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl(
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
params = {
"input": args["query"],
"input": kwargs["query"],
"appid": api_key,
"format": "plaintext",
"output": "json",

View file

@@ -12,10 +12,10 @@ from ..conftest import (
get_test_config_for_api,
)
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from ..vector_io.fixtures import VECTOR_IO_FIXTURES
from .fixtures import AGENTS_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
@@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
"vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "ollama",
"safety": "llama_guard",
"memory": "faiss",
"vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"inference": "together",
"safety": "llama_guard",
# make this work with Weaviate which is what the together distro supports
"memory": "faiss",
"vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "fireworks",
"safety": "llama_guard",
"memory": "faiss",
"vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "remote",
"safety": "remote",
"memory": "remote",
"vector_io": "remote",
"agents": "remote",
"tool_runtime": "memory_and_search",
},
@@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc):
available_fixtures = {
"inference": INFERENCE_FIXTURES,
"safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES,
"vector_io": VECTOR_IO_FIXTURES,
"agents": AGENTS_FIXTURES,
"tool_runtime": TOOL_RUNTIME_FIXTURES,
}

View file

@@ -69,7 +69,7 @@ async def agents_stack(
providers = {}
provider_data = {}
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
@@ -118,7 +118,7 @@ async def agents_stack(
)
test_stack = await construct_stack_for_test(
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
[Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
providers,
provider_data,
models=models,

View file

@@ -214,9 +214,11 @@ class TestAgents:
turn_response = [
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
# FIXME: we need to check the content of the turn response and ensure
# RAG actually worked
@pytest.mark.asyncio
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params

View file

@@ -153,7 +153,13 @@ def make_overlapped_chunks(
chunk = tokenizer.decode(toks)
# chunk is a string
chunks.append(
Chunk(content=chunk, token_count=len(toks), document_id=document_id)
Chunk(
content=chunk,
metadata={
"token_count": len(toks),
"document_id": document_id,
},
)
)
return chunks
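
Per the query_context hunk earlier, consumers now read the per-chunk bookkeeping from Chunk.metadata rather than top-level fields; a tiny sketch with made-up values:

from llama_stack.apis.vector_io import Chunk

chunk = Chunk(
    content="example chunk text",
    metadata={"token_count": 3, "document_id": "doc-1"},
)
assert chunk.metadata["token_count"] == 3
assert chunk.metadata["document_id"] == "doc-1"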

View file

@@ -4,7 +4,7 @@ distribution_spec:
providers:
inference:
- remote::together
memory:
vector_io:
- inline::faiss
- remote::chromadb
- remote::pgvector

View file

@@ -5,7 +5,7 @@ apis:
- datasetio
- eval
- inference
- memory
- vector_io
- safety
- scoring
- telemetry
@@ -20,7 +20,7 @@ providers:
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
memory:
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
@@ -145,7 +145,6 @@ models:
model_type: embedding
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []