RAG Agent test passes

Ashwin Bharambe 2025-01-21 15:16:17 -08:00
parent 2f76de1643
commit a1433c0899
19 changed files with 157 additions and 76 deletions

View file

@@ -74,17 +74,17 @@ class RAGQueryConfig(BaseModel):
 @runtime_checkable
 @trace_protocol
 class RAGToolRuntime(Protocol):
-    @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
+    @webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
     async def insert_documents(
         self,
         documents: List[RAGDocument],
-        vector_db_ids: List[str],
+        vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
         """Index documents so they can be used by the RAG system"""
         ...

-    @webmethod(route="/tool-runtime/rag-tool/query", method="POST")
+    @webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
     async def query_context(
         self,
         content: InterleavedContent,
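Below is a minimal sketch of how a caller might exercise the renamed endpoints. Only the method names and signatures come from the change above; the `tool_runtime` handle, the "my-docs" vector database, the `RAGDocument` field values, and the assumption that `RAGQueryConfig()` constructs with defaults are illustrative.

# Hypothetical caller; assumes `tool_runtime` implements RAGToolRuntime and that
# a vector database named "my-docs" is already registered.
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig

async def index_and_query(tool_runtime) -> None:
    docs = [
        RAGDocument(
            document_id="doc-1",
            content="Llama Stack ships a RAG tool runtime.",
            mime_type="text/plain",
            metadata={},
        )
    ]
    # insert_documents now takes a single vector_db_id instead of a list
    await tool_runtime.insert_documents(
        documents=docs,
        vector_db_id="my-docs",
        chunk_size_in_tokens=512,
    )
    # query_context returns a RAGQueryResult; its .content carries the retrieved text
    result = await tool_runtime.query_context(
        content="What does the RAG tool runtime do?",
        query_config=RAGQueryConfig(),
        vector_db_ids=["my-docs"],
    )
    print(result.content)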

View file

@@ -36,7 +36,14 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import ToolDef, ToolRuntime
+from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
+    ToolDef,
+    ToolRuntime,
+)
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
@@ -400,6 +407,33 @@ class EvalRouter(Eval):
 class ToolRuntimeRouter(ToolRuntime):
+    class RagToolImpl(RAGToolRuntime):
+        def __init__(
+            self,
+            routing_table: RoutingTable,
+        ) -> None:
+            self.routing_table = routing_table
+
+        async def query_context(
+            self,
+            content: InterleavedContent,
+            query_config: RAGQueryConfig,
+            vector_db_ids: List[str],
+        ) -> RAGQueryResult:
+            return await self.routing_table.get_provider_impl(
+                "rag_tool.query_context"
+            ).query_context(content, query_config, vector_db_ids)
+
+        async def insert_documents(
+            self,
+            documents: List[RAGDocument],
+            vector_db_id: str,
+            chunk_size_in_tokens: int = 512,
+        ) -> None:
+            return await self.routing_table.get_provider_impl(
+                "rag_tool.insert_documents"
+            ).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
+
     def __init__(
         self,
         routing_table: RoutingTable,
@@ -408,7 +442,7 @@ class ToolRuntimeRouter(ToolRuntime):
         # TODO: this should be in sync with "get_all_api_endpoints()"
         # TODO: make sure rag_tool vs builtin::memory is correct everywhere
-        self.rag_tool = self.routing_table.get_provider_impl("builtin::memory")
+        self.rag_tool = self.RagToolImpl(routing_table)
         setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
         setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
@@ -418,10 +452,10 @@ class ToolRuntimeRouter(ToolRuntime):
     async def shutdown(self) -> None:
         pass

-    async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
+    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
-            args=args,
+            kwargs=kwargs,
         )

     async def list_runtime_tools(
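The setattr calls above work because instance attributes are plain dictionary entries, so a dotted string such as "rag_tool.query_context" is a valid key as long as the lookup side calls getattr with the exact same string (normal attribute syntax would try to traverse a `rag_tool` attribute first). A self-contained illustration, not llama-stack code:

class Router:
    def __init__(self):
        # Stored verbatim in __dict__; reachable only via getattr()/setattr(),
        # never as router.rag_tool.query_context.
        setattr(self, "rag_tool.query_context", lambda q: f"answer to {q!r}")

router = Router()
handler = getattr(router, "rag_tool.query_context")
print(handler("hello"))                             # answer to 'hello'
print("rag_tool.query_context" in vars(router))     # True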

View file

@@ -9,7 +9,7 @@ from typing import Dict, List
 from pydantic import BaseModel

-from llama_stack.apis.tools import SpecialToolGroups
+from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroups
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
@@ -24,21 +24,29 @@ class ApiEndpoint(BaseModel):
     name: str

+def toolgroup_protocol_map():
+    return {
+        SpecialToolGroups.rag_tool: RAGToolRuntime,
+    }
+
 def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     apis = {}

     protocols = api_protocol_map()
+    toolgroup_protocols = toolgroup_protocol_map()
     for api, protocol in protocols.items():
         endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

         if api == Api.tool_runtime:
             for tool_group in SpecialToolGroups:
-                print(f"tool_group: {tool_group}")
-                sub_protocol = getattr(protocol, tool_group.value)
+                sub_protocol = toolgroup_protocols[tool_group]
                 sub_protocol_methods = inspect.getmembers(
                     sub_protocol, predicate=inspect.isfunction
                 )
                 for name, method in sub_protocol_methods:
+                    if not hasattr(method, "__webmethod__"):
+                        continue
                     protocol_methods.append((f"{tool_group.value}.{name}", method))

         for name, method in protocol_methods:
@@ -47,7 +55,6 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
             webmethod = method.__webmethod__
             route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
             if webmethod.method == "GET":
                 method = "get"
             elif webmethod.method == "DELETE":
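The new hasattr guard only registers sub-protocol methods that were tagged by the @webmethod decorator. A simplified, self-contained sketch of that filtering; the toy webmethod below merely sets the attribute and is not the real decorator, which attaches a richer object:

import inspect

def webmethod(route: str, method: str):
    # Toy stand-in for the real decorator: tag the function and return it.
    def wrap(fn):
        fn.__webmethod__ = {"route": route, "method": method}
        return fn
    return wrap

class ToyRAGToolRuntime:
    @webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
    async def insert_documents(self): ...

    async def not_an_endpoint(self): ...

for name, fn in inspect.getmembers(ToyRAGToolRuntime, predicate=inspect.isfunction):
    if not hasattr(fn, "__webmethod__"):
        continue  # the same guard added above skips helpers like not_an_endpoint
    print(f"rag_tool.{name}", fn.__webmethod__["route"])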

View file

@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v5"
+KEY_VERSION = "v6"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@@ -373,21 +373,30 @@ class ChatAgent(ShieldRunnerMixin):
         documents: Optional[List[Document]] = None,
         toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> AsyncGenerator:
+        # TODO: simplify all of this code, it can be simpler
         toolgroup_args = {}
+        toolgroups = set()
         for toolgroup in self.agent_config.toolgroups:
             if isinstance(toolgroup, AgentToolGroupWithArgs):
+                toolgroups.add(toolgroup.name)
                 toolgroup_args[toolgroup.name] = toolgroup.args
+            else:
+                toolgroups.add(toolgroup)
         if toolgroups_for_turn:
             for toolgroup in toolgroups_for_turn:
                 if isinstance(toolgroup, AgentToolGroupWithArgs):
+                    toolgroups.add(toolgroup.name)
                     toolgroup_args[toolgroup.name] = toolgroup.args
+                else:
+                    toolgroups.add(toolgroup)

         tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
         if documents:
             await self.handle_documents(
                 session_id, documents, input_messages, tool_defs
             )
-        if "builtin::memory" in toolgroup_args and len(input_messages) > 0:
+
+        if "builtin::memory" in toolgroups and len(input_messages) > 0:
             with tracing.span(MEMORY_QUERY_TOOL) as span:
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
@@ -399,7 +408,7 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
-                args = toolgroup_args["builtin::memory"]
+                args = toolgroup_args.get("builtin::memory", {})
                 vector_db_ids = args.get("vector_db_ids", [])
                 session_info = await self.storage.get_session_info(session_id)
@@ -423,7 +432,7 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
-                retrieved_context = await self.tool_runtime_api.rag_tool.query_context(
+                result = await self.tool_runtime_api.rag_tool.query_context(
                     content=concat_interleaved_content(
                         [msg.content for msg in input_messages]
                     ),
@@ -434,6 +443,7 @@ class ChatAgent(ShieldRunnerMixin):
                     ),
                     vector_db_ids=vector_db_ids,
                 )
+                retrieved_context = result.content

                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
@@ -862,7 +872,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def add_to_session_memory_bank(
         self, session_id: str, data: List[Document]
     ) -> None:
-        bank_id = await self._ensure_memory_bank(session_id)
+        vector_db_id = await self._ensure_memory_bank(session_id)
         documents = [
             RAGDocument(
                 document_id=str(uuid.uuid4()),
@@ -874,7 +884,7 @@ class ChatAgent(ShieldRunnerMixin):
         ]
         await self.tool_runtime_api.rag_tool.insert_documents(
             documents=documents,
-            vector_db_ids=[bank_id],
+            vector_db_id=vector_db_id,
             chunk_size_in_tokens=512,
         )
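The turn-handling change keys the RAG step on a `toolgroups` name set rather than on `toolgroup_args`, so a toolgroup enabled without explicit arguments still triggers retrieval. A standalone sketch of that normalization; the stand-in class below only mirrors the `.name`/`.args` shape of AgentToolGroupWithArgs, and the toolgroup names and vector db id are illustrative:

class ToolGroupWithArgs:
    # Stand-in mirroring AgentToolGroupWithArgs (.name / .args) for illustration.
    def __init__(self, name, args):
        self.name, self.args = name, args

configured = [
    "builtin::code_interpreter",
    ToolGroupWithArgs("builtin::memory", {"vector_db_ids": ["my-docs"]}),
]

toolgroups, toolgroup_args = set(), {}
for tg in configured:
    if isinstance(tg, ToolGroupWithArgs):
        toolgroups.add(tg.name)
        toolgroup_args[tg.name] = tg.args
    else:
        toolgroups.add(tg)

assert "builtin::memory" in toolgroups                  # the RAG query step runs
assert toolgroup_args.get("builtin::memory", {}) == {"vector_db_ids": ["my-docs"]}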

View file

@@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        script = args["code"]
+        script = kwargs["code"]
         req = CodeExecutionRequest(scripts=[script])
         res = self.code_executor.execute(req)
         pieces = [res["process_status"]]

View file

@@ -5,68 +5,64 @@
 # the root directory of this source tree.

-from typing import List
 from jinja2 import Template
-from pydantic import BaseModel

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.tools.rag_tool import (
+    DefaultRAGQueryGeneratorConfig,
+    LLMRAGQueryGeneratorConfig,
+    RAGQueryGenerator,
+    RAGQueryGeneratorConfig,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

-from .config import (
-    DefaultMemoryQueryGeneratorConfig,
-    LLMMemoryQueryGeneratorConfig,
-    MemoryQueryGenerator,
-    MemoryQueryGeneratorConfig,
-)

 async def generate_rag_query(
-    config: MemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: RAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     """
     Generates a query that will be used for
     retrieving relevant information from the memory bank.
     """
-    if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, messages, **kwargs)
-    elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, messages, **kwargs)
+    if config.type == RAGQueryGenerator.default.value:
+        query = await default_rag_query_generator(config, content, **kwargs)
+    elif config.type == RAGQueryGenerator.llm.value:
+        query = await llm_rag_query_generator(config, content, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query

 async def default_rag_query_generator(
-    config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: DefaultRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
-    return config.separator.join(interleaved_content_as_str(m) for m in messages)
+    return interleaved_content_as_str(content, sep=config.separator)

 async def llm_rag_query_generator(
-    config: LLMMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: LLMRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]

-    m_dict = {
-        "messages": [
-            message.model_dump() if isinstance(message, BaseModel) else message
-            for message in messages
-        ]
-    }
+    messages = []
+    if isinstance(content, list):
+        messages = [interleaved_content_as_str(m) for m in content]
+    else:
+        messages = [interleaved_content_as_str(content)]

     template = Template(config.template)
-    content = template.render(m_dict)
+    content = template.render({"messages": messages})

     model = config.model
     message = UserMessage(content=content)
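For the LLM-based generator, the Jinja2 template is now rendered with a plain list of message strings rather than model dumps. A small illustration of that rendering; the template text and messages here are made up, not the project's default:

from jinja2 import Template

messages = [
    "What changed in the RAG tool runtime?",
    "Focus on insert_documents and query_context.",
]
# Rendered exactly like template.render({"messages": messages}) above.
template = Template("Write a search query for: {{ messages | join(' ') }}")
print(template.render({"messages": messages}))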

View file

@@ -27,6 +27,10 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.utils.memory.vector_store import (
+    content_from_doc,
+    make_overlapped_chunks,
+)

 from .config import MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query
@@ -60,10 +64,28 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     async def insert_documents(
         self,
         documents: List[RAGDocument],
-        vector_db_ids: List[str],
+        vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
-        pass
+        chunks = []
+        for doc in documents:
+            content = await content_from_doc(doc)
+            chunks.extend(
+                make_overlapped_chunks(
+                    doc.document_id,
+                    content,
+                    chunk_size_in_tokens,
+                    chunk_size_in_tokens // 4,
+                )
+            )
+
+        if not chunks:
+            return
+
+        await self.vector_io_api.insert_chunks(
+            chunks=chunks,
+            vector_db_id=vector_db_id,
+        )

     async def query_context(
         self,
@@ -104,13 +126,18 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         tokens = 0
         picked = []
         for c in chunks[: query_config.max_chunks]:
-            tokens += c.token_count
+            metadata = c.metadata
+            tokens += metadata["token_count"]
             if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
-            picked.append(f"id:{c.document_id}; content:{c.content}")
+            picked.append(
+                TextContentItem(
+                    text=f"id:{metadata['document_id']}; content:{c.content}",
+                )
+            )

         return RAGQueryResult(
             content=[

View file

@@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         headers = {
@@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl(
             "count": self.config.top_k,
             "textDecorations": True,
             "textFormat": "HTML",
-            "q": args["query"],
+            "q": kwargs["query"],
         }
         response = requests.get(

View file

@@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         url = "https://api.search.brave.com/res/v1/web/search"
@@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl(
             "Accept-Encoding": "gzip",
             "Accept": "application/json",
         }
-        payload = {"q": args["query"]}
+        payload = {"q": kwargs["query"]}
         response = requests.get(url=url, params=payload, headers=headers)
         response.raise_for_status()
         results = self._clean_brave_response(response.json())

View file

@@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         return tools

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
         if tool.metadata is None or tool.metadata.get("endpoint") is None:
@@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         async with sse_client(endpoint) as streams:
             async with ClientSession(*streams) as session:
                 await session.initialize()
-                result = await session.call_tool(tool.identifier, args)
+                result = await session.call_tool(tool.identifier, kwargs)

         return ToolInvocationResult(
             content="\n".join([result.model_dump_json() for result in result.content]),

View file

@@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         response = requests.post(
             "https://api.tavily.com/search",
-            json={"api_key": api_key, "query": args["query"]},
+            json={"api_key": api_key, "query": kwargs["query"]},
         )

         return ToolInvocationResult(

View file

@@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         params = {
-            "input": args["query"],
+            "input": kwargs["query"],
             "appid": api_key,
             "format": "plaintext",
             "output": "json",

View file

@@ -12,10 +12,10 @@ from ..conftest import (
     get_test_config_for_api,
 )
 from ..inference.fixtures import INFERENCE_FIXTURES
-from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
 from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
+from ..vector_io.fixtures import VECTOR_IO_FIXTURES
 from .fixtures import AGENTS_FIXTURES

 DEFAULT_PROVIDER_COMBINATIONS = [
@@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "meta_reference",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },
@@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "ollama",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },
@@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "inference": "together",
             "safety": "llama_guard",
             # make this work with Weaviate which is what the together distro supports
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },
@@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "fireworks",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },
@@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "remote",
             "safety": "remote",
-            "memory": "remote",
+            "vector_io": "remote",
             "agents": "remote",
             "tool_runtime": "memory_and_search",
         },
@@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc):
     available_fixtures = {
         "inference": INFERENCE_FIXTURES,
         "safety": SAFETY_FIXTURES,
-        "memory": MEMORY_FIXTURES,
+        "vector_io": VECTOR_IO_FIXTURES,
         "agents": AGENTS_FIXTURES,
         "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }

View file

@@ -69,7 +69,7 @@ async def agents_stack(
     providers = {}
     provider_data = {}
-    for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
+    for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if key == "inference":
@@ -118,7 +118,7 @@ async def agents_stack(
     )
     test_stack = await construct_stack_for_test(
-        [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
+        [Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
         providers,
         provider_data,
         models=models,

View file

@@ -214,9 +214,11 @@ class TestAgents:
         turn_response = [
             chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
         ]

         assert len(turn_response) > 0
+        # FIXME: we need to check the content of the turn response and ensure
+        # RAG actually worked

     @pytest.mark.asyncio
     async def test_create_agent_turn_with_tavily_search(
         self, agents_stack, search_query_messages, common_params

View file

@@ -153,7 +153,13 @@ def make_overlapped_chunks(
         chunk = tokenizer.decode(toks)
         # chunk is a string
         chunks.append(
-            Chunk(content=chunk, token_count=len(toks), document_id=document_id)
+            Chunk(
+                content=chunk,
+                metadata={
+                    "token_count": len(toks),
+                    "document_id": document_id,
+                },
+            )
         )

     return chunks
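For reference, the caller added in this commit passes chunk_size_in_tokens and chunk_size_in_tokens // 4 as the window and overlap. A simplified sketch of one common sliding-window scheme, with whitespace "tokens" standing in for the real tokenizer; the actual stepping logic lives in make_overlapped_chunks and is not shown in this diff:

def overlapped_chunks_sketch(document_id, text, window_len, overlap_len):
    # Whitespace tokens stand in for the tokenizer used by make_overlapped_chunks.
    toks = text.split()
    chunks = []
    for i in range(0, len(toks), window_len - overlap_len):
        window = toks[i : i + window_len]
        chunks.append(
            {
                "content": " ".join(window),
                "metadata": {"token_count": len(window), "document_id": document_id},
            }
        )
    return chunks

# Windows of 4 tokens advancing by 3, so consecutive chunks share one token.
print(overlapped_chunks_sketch("doc-1", "one two three four five six", 4, 1))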

View file

@@ -4,7 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::together
-    memory:
+    vector_io:
     - inline::faiss
     - remote::chromadb
     - remote::pgvector

View file

@@ -5,7 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
-- memory
+- vector_io
 - safety
 - scoring
 - telemetry
@@ -20,7 +20,7 @@ providers:
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
-  memory:
+  vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
     config:
@@ -145,7 +145,6 @@ models:
   model_type: embedding
 shields:
 - shield_id: meta-llama/Llama-Guard-3-8B
-memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []