Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 10:42:39 +00:00)

Commit: a1433c0899 ("RAG Agent test passes")
Parent: 2f76de1643
19 changed files with 157 additions and 76 deletions

@@ -74,17 +74,17 @@ class RAGQueryConfig(BaseModel):
 @runtime_checkable
 @trace_protocol
 class RAGToolRuntime(Protocol):
-    @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
+    @webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
     async def insert_documents(
         self,
         documents: List[RAGDocument],
-        vector_db_ids: List[str],
+        vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
         """Index documents so they can be used by the RAG system"""
         ...

-    @webmethod(route="/tool-runtime/rag-tool/query", method="POST")
+    @webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
     async def query_context(
         self,
         content: InterleavedContent,

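Note: the renamed routes line up with the protocol methods above (insert-documents -> insert_documents, query-context -> query_context), and indexing now targets a single vector_db_id. A minimal client-side sketch of the new surface, assuming a handle that implements RAGToolRuntime, an already-registered vector DB, and that RAGQueryConfig() has usable defaults:

    from llama_stack.apis.tools import RAGDocument, RAGQueryConfig

    async def index_and_query(rag_tool, vector_db_id: str):
        # POST /tool-runtime/rag-tool/insert-documents
        await rag_tool.insert_documents(
            documents=[RAGDocument(document_id="doc-1", content="Llama Stack ships a RAG tool.")],
            vector_db_id=vector_db_id,
            chunk_size_in_tokens=512,
        )
        # POST /tool-runtime/rag-tool/query-context
        result = await rag_tool.query_context(
            content="What does Llama Stack ship?",
            query_config=RAGQueryConfig(),
            vector_db_ids=[vector_db_id],
        )
        return result.content
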
@@ -36,7 +36,14 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import ToolDef, ToolRuntime
+from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
+    ToolDef,
+    ToolRuntime,
+)
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable

@@ -400,6 +407,33 @@ class EvalRouter(Eval):


 class ToolRuntimeRouter(ToolRuntime):
+    class RagToolImpl(RAGToolRuntime):
+        def __init__(
+            self,
+            routing_table: RoutingTable,
+        ) -> None:
+            self.routing_table = routing_table
+
+        async def query_context(
+            self,
+            content: InterleavedContent,
+            query_config: RAGQueryConfig,
+            vector_db_ids: List[str],
+        ) -> RAGQueryResult:
+            return await self.routing_table.get_provider_impl(
+                "rag_tool.query_context"
+            ).query_context(content, query_config, vector_db_ids)
+
+        async def insert_documents(
+            self,
+            documents: List[RAGDocument],
+            vector_db_id: str,
+            chunk_size_in_tokens: int = 512,
+        ) -> None:
+            return await self.routing_table.get_provider_impl(
+                "rag_tool.insert_documents"
+            ).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
+
     def __init__(
         self,
         routing_table: RoutingTable,

@@ -408,7 +442,7 @@ class ToolRuntimeRouter(ToolRuntime):

         # TODO: this should be in sync with "get_all_api_endpoints()"
         # TODO: make sure rag_tool vs builtin::memory is correct everywhere
-        self.rag_tool = self.routing_table.get_provider_impl("builtin::memory")
+        self.rag_tool = self.RagToolImpl(routing_table)
         setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
         setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)

@@ -418,10 +452,10 @@ class ToolRuntimeRouter(ToolRuntime):
     async def shutdown(self) -> None:
         pass

-    async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
+    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
-            args=args,
+            kwargs=kwargs,
         )

     async def list_runtime_tools(

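Note: the router now exposes a nested RagToolImpl under the `rag_tool` attribute, and generic tool calls pass `kwargs` rather than `args` (a rename that repeats across the provider implementations further down). A rough usage sketch against a constructed router; the tool name and vector DB id below are illustrative:

    from llama_stack.apis.tools import RAGQueryConfig

    async def demo(router):
        # routed to whichever provider is registered for "rag_tool.query_context"
        rag_result = await router.rag_tool.query_context(
            content="what changed in this commit?",
            query_config=RAGQueryConfig(),
            vector_db_ids=["session-db"],
        )

        # generic tools go through invoke_tool; note the renamed keyword `kwargs`
        search_result = await router.invoke_tool(
            tool_name="web_search",
            kwargs={"query": "llama stack rag tool"},
        )
        return rag_result, search_result
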
@@ -9,7 +9,7 @@ from typing import Dict, List

 from pydantic import BaseModel

-from llama_stack.apis.tools import SpecialToolGroups
+from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroups

 from llama_stack.apis.version import LLAMA_STACK_API_VERSION

@@ -24,21 +24,29 @@ class ApiEndpoint(BaseModel):
     name: str


+def toolgroup_protocol_map():
+    return {
+        SpecialToolGroups.rag_tool: RAGToolRuntime,
+    }
+
+
 def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     apis = {}

     protocols = api_protocol_map()
+    toolgroup_protocols = toolgroup_protocol_map()
     for api, protocol in protocols.items():
         endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
         if api == Api.tool_runtime:
             for tool_group in SpecialToolGroups:
                 print(f"tool_group: {tool_group}")
-                sub_protocol = getattr(protocol, tool_group.value)
+                sub_protocol = toolgroup_protocols[tool_group]
                 sub_protocol_methods = inspect.getmembers(
                     sub_protocol, predicate=inspect.isfunction
                 )
                 for name, method in sub_protocol_methods:
                     if not hasattr(method, "__webmethod__"):
                         continue
                     protocol_methods.append((f"{tool_group.value}.{name}", method))

         for name, method in protocol_methods:

@@ -47,7 +55,6 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:

             webmethod = method.__webmethod__
             route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
-
             if webmethod.method == "GET":
                 method = "get"
             elif webmethod.method == "DELETE":

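Note: `toolgroup_protocol_map()` is what lets the server enumerate RAGToolRuntime's webmethods as sub-routes of the tool-runtime API. A simplified, standalone sketch of that enumeration (a hypothetical helper, not the repo function):

    import inspect

    def collect_toolgroup_routes(toolgroup_protocols: dict) -> list:
        routes = []
        for tool_group, sub_protocol in toolgroup_protocols.items():
            for name, method in inspect.getmembers(sub_protocol, predicate=inspect.isfunction):
                if not hasattr(method, "__webmethod__"):
                    continue
                # e.g. ("rag_tool.query_context", "/tool-runtime/rag-tool/query-context")
                routes.append((f"{tool_group.value}.{name}", method.__webmethod__.route))
        return routes
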
@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):


 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v5"
+KEY_VERSION = "v6"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


@@ -373,21 +373,30 @@ class ChatAgent(ShieldRunnerMixin):
         documents: Optional[List[Document]] = None,
         toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> AsyncGenerator:
+        # TODO: simplify all of this code, it can be simpler
         toolgroup_args = {}
+        toolgroups = set()
         for toolgroup in self.agent_config.toolgroups:
             if isinstance(toolgroup, AgentToolGroupWithArgs):
+                toolgroups.add(toolgroup.name)
                 toolgroup_args[toolgroup.name] = toolgroup.args
+            else:
+                toolgroups.add(toolgroup)
         if toolgroups_for_turn:
             for toolgroup in toolgroups_for_turn:
                 if isinstance(toolgroup, AgentToolGroupWithArgs):
+                    toolgroups.add(toolgroup.name)
                     toolgroup_args[toolgroup.name] = toolgroup.args
+                else:
+                    toolgroups.add(toolgroup)

         tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
         if documents:
             await self.handle_documents(
                 session_id, documents, input_messages, tool_defs
             )
-        if "builtin::memory" in toolgroup_args and len(input_messages) > 0:
+
+        if "builtin::memory" in toolgroups and len(input_messages) > 0:
             with tracing.span(MEMORY_QUERY_TOOL) as span:
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(

@@ -399,7 +408,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )

-            args = toolgroup_args["builtin::memory"]
+            args = toolgroup_args.get("builtin::memory", {})
             vector_db_ids = args.get("vector_db_ids", [])
             session_info = await self.storage.get_session_info(session_id)

@@ -423,7 +432,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-            retrieved_context = await self.tool_runtime_api.rag_tool.query_context(
+            result = await self.tool_runtime_api.rag_tool.query_context(
                 content=concat_interleaved_content(
                     [msg.content for msg in input_messages]
                 ),

@@ -434,6 +443,7 @@ class ChatAgent(ShieldRunnerMixin):
                 ),
                 vector_db_ids=vector_db_ids,
             )
+            retrieved_context = result.content

             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(

@@ -862,7 +872,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def add_to_session_memory_bank(
         self, session_id: str, data: List[Document]
     ) -> None:
-        bank_id = await self._ensure_memory_bank(session_id)
+        vector_db_id = await self._ensure_memory_bank(session_id)
         documents = [
             RAGDocument(
                 document_id=str(uuid.uuid4()),

@@ -874,7 +884,7 @@ class ChatAgent(ShieldRunnerMixin):
         ]
         await self.tool_runtime_api.rag_tool.insert_documents(
             documents=documents,
-            vector_db_ids=[bank_id],
+            vector_db_id=vector_db_id,
             chunk_size_in_tokens=512,
         )

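Note: the agent now treats the query result as an object and reads the retrieved text off `result.content`. A condensed sketch of the retrieval step, assuming the helpers already imported in this module (concat_interleaved_content, RAGQueryConfig) are in scope and ignoring the streaming and telemetry around it:

    async def retrieve_context(tool_runtime_api, input_messages, toolgroup_args):
        args = toolgroup_args.get("builtin::memory", {})
        vector_db_ids = args.get("vector_db_ids", [])
        result = await tool_runtime_api.rag_tool.query_context(
            content=concat_interleaved_content([msg.content for msg in input_messages]),
            query_config=RAGQueryConfig(),  # simplification; the agent builds this from its config
            vector_db_ids=vector_db_ids,
        )
        return result.content
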
@@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        script = args["code"]
+        script = kwargs["code"]
         req = CodeExecutionRequest(scripts=[script])
         res = self.code_executor.execute(req)
         pieces = [res["process_status"]]

@@ -5,68 +5,64 @@
 # the root directory of this source tree.


 from typing import List

 from jinja2 import Template
 from pydantic import BaseModel

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.tools.rag_tool import (
+    DefaultRAGQueryGeneratorConfig,
+    LLMRAGQueryGeneratorConfig,
+    RAGQueryGenerator,
+    RAGQueryGeneratorConfig,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

-from .config import (
-    DefaultMemoryQueryGeneratorConfig,
-    LLMMemoryQueryGeneratorConfig,
-    MemoryQueryGenerator,
-    MemoryQueryGeneratorConfig,
-)
-

 async def generate_rag_query(
-    config: MemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: RAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     """
     Generates a query that will be used for
     retrieving relevant information from the memory bank.
     """
-    if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, messages, **kwargs)
-    elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, messages, **kwargs)
+    if config.type == RAGQueryGenerator.default.value:
+        query = await default_rag_query_generator(config, content, **kwargs)
+    elif config.type == RAGQueryGenerator.llm.value:
+        query = await llm_rag_query_generator(config, content, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query


 async def default_rag_query_generator(
-    config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: DefaultRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
-    return config.separator.join(interleaved_content_as_str(m) for m in messages)
+    return interleaved_content_as_str(content, sep=config.separator)


 async def llm_rag_query_generator(
-    config: LLMMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: LLMRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]

-    m_dict = {
-        "messages": [
-            message.model_dump() if isinstance(message, BaseModel) else message
-            for message in messages
-        ]
-    }
+    messages = []
+    if isinstance(content, list):
+        messages = [interleaved_content_as_str(m) for m in content]
+    else:
+        messages = [interleaved_content_as_str(content)]

     template = Template(config.template)
-    content = template.render(m_dict)
+    content = template.render({"messages": messages})

     model = config.model
     message = UserMessage(content=content)

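Note: the query-generator helpers now take a single InterleavedContent value instead of a list of messages. A rough standalone approximation of the default generator's behavior, with str() standing in for interleaved_content_as_str (not the library function):

    def default_query(content, separator: str = " ") -> str:
        # join the stringified pieces with the configured separator,
        # mirroring default_rag_query_generator above
        items = content if isinstance(content, list) else [content]
        return separator.join(str(item) for item in items)

    print(default_query(["tell me about", "the RAG tool"]))  # "tell me about the RAG tool"
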
@@ -27,6 +27,10 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.utils.memory.vector_store import (
+    content_from_doc,
+    make_overlapped_chunks,
+)

 from .config import MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query

@@ -60,10 +64,28 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     async def insert_documents(
         self,
         documents: List[RAGDocument],
-        vector_db_ids: List[str],
+        vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
-        pass
+        chunks = []
+        for doc in documents:
+            content = await content_from_doc(doc)
+            chunks.extend(
+                make_overlapped_chunks(
+                    doc.document_id,
+                    content,
+                    chunk_size_in_tokens,
+                    chunk_size_in_tokens // 4,
+                )
+            )
+
+        if not chunks:
+            return
+
+        await self.vector_io_api.insert_chunks(
+            chunks=chunks,
+            vector_db_id=vector_db_id,
+        )

     async def query_context(
         self,

@@ -104,13 +126,18 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         tokens = 0
         picked = []
         for c in chunks[: query_config.max_chunks]:
-            tokens += c.token_count
+            metadata = c.metadata
+            tokens += metadata["token_count"]
             if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
-            picked.append(f"id:{c.document_id}; content:{c.content}")
+            picked.append(
+                TextContentItem(
+                    text=f"id:{metadata['document_id']}; content:{c.content}",
+                )
+            )

         return RAGQueryResult(
             content=[

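Note: per-chunk bookkeeping (token_count, document_id) now lives under Chunk.metadata, so the context-budget loop reads it from there. A standalone sketch of the same budgeting idea, with plain dicts standing in for Chunk objects:

    def pick_chunks(chunks, max_chunks: int, max_tokens_in_context: int):
        picked, tokens = [], 0
        for c in chunks[:max_chunks]:
            tokens += c["metadata"]["token_count"]
            if tokens > max_tokens_in_context:
                break  # budget exhausted; stop adding context
            picked.append(f"id:{c['metadata']['document_id']}; content:{c['content']}")
        return picked

    chunks = [
        {"content": "alpha", "metadata": {"document_id": "d1", "token_count": 300}},
        {"content": "beta", "metadata": {"document_id": "d2", "token_count": 300}},
    ]
    print(pick_chunks(chunks, max_chunks=5, max_tokens_in_context=512))  # only "d1" fits
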
@@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         headers = {

@@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl(
             "count": self.config.top_k,
             "textDecorations": True,
             "textFormat": "HTML",
-            "q": args["query"],
+            "q": kwargs["query"],
         }

         response = requests.get(

@@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         url = "https://api.search.brave.com/res/v1/web/search"

@@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl(
             "Accept-Encoding": "gzip",
             "Accept": "application/json",
         }
-        payload = {"q": args["query"]}
+        payload = {"q": kwargs["query"]}
         response = requests.get(url=url, params=payload, headers=headers)
         response.raise_for_status()
         results = self._clean_brave_response(response.json())

@@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         return tools

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
         if tool.metadata is None or tool.metadata.get("endpoint") is None:

@@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         async with sse_client(endpoint) as streams:
             async with ClientSession(*streams) as session:
                 await session.initialize()
-                result = await session.call_tool(tool.identifier, args)
+                result = await session.call_tool(tool.identifier, kwargs)

         return ToolInvocationResult(
             content="\n".join([result.model_dump_json() for result in result.content]),

@@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         response = requests.post(
             "https://api.tavily.com/search",
-            json={"api_key": api_key, "query": args["query"]},
+            json={"api_key": api_key, "query": kwargs["query"]},
         )

         return ToolInvocationResult(

@@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         params = {
-            "input": args["query"],
+            "input": kwargs["query"],
             "appid": api_key,
             "format": "plaintext",
             "output": "json",

@@ -12,10 +12,10 @@ from ..conftest import (
     get_test_config_for_api,
 )
 from ..inference.fixtures import INFERENCE_FIXTURES
-from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
 from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
+from ..vector_io.fixtures import VECTOR_IO_FIXTURES
 from .fixtures import AGENTS_FIXTURES

 DEFAULT_PROVIDER_COMBINATIONS = [

@@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "meta_reference",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },

@@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "ollama",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },

@@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "inference": "together",
             "safety": "llama_guard",
             # make this work with Weaviate which is what the together distro supports
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },

@@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "fireworks",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },

@@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "remote",
             "safety": "remote",
-            "memory": "remote",
+            "vector_io": "remote",
             "agents": "remote",
             "tool_runtime": "memory_and_search",
         },

@@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc):
     available_fixtures = {
         "inference": INFERENCE_FIXTURES,
         "safety": SAFETY_FIXTURES,
-        "memory": MEMORY_FIXTURES,
+        "vector_io": VECTOR_IO_FIXTURES,
         "agents": AGENTS_FIXTURES,
         "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }

@@ -69,7 +69,7 @@ async def agents_stack(

     providers = {}
     provider_data = {}
-    for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
+    for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if key == "inference":

@@ -118,7 +118,7 @@ async def agents_stack(
     )

     test_stack = await construct_stack_for_test(
-        [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
+        [Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
         providers,
         provider_data,
         models=models,

@@ -214,9 +214,11 @@ class TestAgents:
         turn_response = [
             chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
         ]

         assert len(turn_response) > 0

+        # FIXME: we need to check the content of the turn response and ensure
+        # RAG actually worked

     @pytest.mark.asyncio
     async def test_create_agent_turn_with_tavily_search(
         self, agents_stack, search_query_messages, common_params

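Note: the FIXME above flags that the test only asserts that chunks were streamed. A loose follow-up check one could add, sketched here rather than taken from the project (it assumes the marker phrase appears in the documents the test indexes):

    # assumption: `marker` is a phrase known to appear in the inserted test documents
    marker = "Llama"
    assert any(marker in str(chunk) for chunk in turn_response), (
        "expected RAG-retrieved content to show up somewhere in the agent turn"
    )
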
@@ -153,7 +153,13 @@ def make_overlapped_chunks(
         chunk = tokenizer.decode(toks)
         # chunk is a string
         chunks.append(
-            Chunk(content=chunk, token_count=len(toks), document_id=document_id)
+            Chunk(
+                content=chunk,
+                metadata={
+                    "token_count": len(toks),
+                    "document_id": document_id,
+                },
+            )
         )

     return chunks

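Note: insert_documents (earlier in this diff) calls make_overlapped_chunks with an overlap of chunk_size_in_tokens // 4. A toy, self-contained illustration of that windowing and of the new metadata-carrying chunk shape, using plain dicts instead of the Chunk model (the real function works on tokenizer tokens and may differ at the edges):

    def toy_overlapped_chunks(tokens, document_id: str, window: int):
        overlap = window // 4
        step = window - overlap  # consecutive windows share `overlap` tokens
        return [
            {
                "content": tokens[i : i + window],
                "metadata": {"token_count": len(tokens[i : i + window]), "document_id": document_id},
            }
            for i in range(0, len(tokens), step)
        ]

    chunks = toy_overlapped_chunks(list(range(10)), "doc-1", window=4)
    # windows: [0..3], [3..6], [6..9], [9]; each carries token_count/document_id in metadata
    assert chunks[0]["metadata"]["token_count"] == 4
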
@@ -4,7 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::together
-    memory:
+    vector_io:
     - inline::faiss
     - remote::chromadb
     - remote::pgvector

@@ -5,7 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
-- memory
+- vector_io
 - safety
 - scoring
 - telemetry

@@ -20,7 +20,7 @@ providers:
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
-  memory:
+  vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
     config:

@@ -145,7 +145,6 @@ models:
   model_type: embedding
 shields:
 - shield_id: meta-llama/Llama-Guard-3-8B
-memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []