Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 02:32:40 +00:00)

Commit a1433c0899 (parent 2f76de1643): RAG Agent test passes
19 changed files with 157 additions and 76 deletions

@@ -74,17 +74,17 @@ class RAGQueryConfig(BaseModel):
 @runtime_checkable
 @trace_protocol
 class RAGToolRuntime(Protocol):
-    @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
+    @webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
     async def insert_documents(
         self,
         documents: List[RAGDocument],
-        vector_db_ids: List[str],
+        vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
         """Index documents so they can be used by the RAG system"""
         ...

-    @webmethod(route="/tool-runtime/rag-tool/query", method="POST")
+    @webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
     async def query_context(
         self,
         content: InterleavedContent,
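Not part of the diff: a minimal sketch of how a caller would use the renamed protocol methods. The RAGDocument field names and the assumption that RAGQueryConfig can be built with defaults are not confirmed by this hunk; the point is only that insertion now takes a single vector_db_id while querying still takes a list of vector_db_ids and returns a result whose content is read directly.

from llama_stack.apis.tools import RAGDocument, RAGQueryConfig

async def index_then_query(rag_tool, vector_db_id: str):
    # rag_tool is any RAGToolRuntime implementation; field names below are assumed.
    await rag_tool.insert_documents(
        documents=[RAGDocument(document_id="doc-1", content="hello world")],
        vector_db_id=vector_db_id,          # was vector_db_ids: List[str]
        chunk_size_in_tokens=512,
    )
    result = await rag_tool.query_context(
        content="what does the document say?",
        query_config=RAGQueryConfig(),      # assumes defaults are valid
        vector_db_ids=[vector_db_id],
    )
    return result.content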
@@ -36,7 +36,14 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import ToolDef, ToolRuntime
+from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
+    ToolDef,
+    ToolRuntime,
+)
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable

@@ -400,6 +407,33 @@ class EvalRouter(Eval):


 class ToolRuntimeRouter(ToolRuntime):
+    class RagToolImpl(RAGToolRuntime):
+        def __init__(
+            self,
+            routing_table: RoutingTable,
+        ) -> None:
+            self.routing_table = routing_table
+
+        async def query_context(
+            self,
+            content: InterleavedContent,
+            query_config: RAGQueryConfig,
+            vector_db_ids: List[str],
+        ) -> RAGQueryResult:
+            return await self.routing_table.get_provider_impl(
+                "rag_tool.query_context"
+            ).query_context(content, query_config, vector_db_ids)
+
+        async def insert_documents(
+            self,
+            documents: List[RAGDocument],
+            vector_db_id: str,
+            chunk_size_in_tokens: int = 512,
+        ) -> None:
+            return await self.routing_table.get_provider_impl(
+                "rag_tool.insert_documents"
+            ).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
+
     def __init__(
         self,
         routing_table: RoutingTable,
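Not part of the diff: a self-contained sketch of the dispatch pattern the nested RagToolImpl relies on. FakeRoutingTable and FakeProvider below are hypothetical stand-ins; they only illustrate that get_provider_impl() is keyed by the synthetic names "rag_tool.query_context" and "rag_tool.insert_documents" rather than by a registered tool group.

import asyncio


class FakeProvider:
    async def query_context(self, content, query_config, vector_db_ids):
        # A real provider would run retrieval; this stub just echoes its inputs.
        return f"queried {vector_db_ids} for: {content}"


class FakeRoutingTable:
    def get_provider_impl(self, key: str):
        # The router addresses RAG methods as "rag_tool.<method>".
        assert key in ("rag_tool.query_context", "rag_tool.insert_documents")
        return FakeProvider()


async def main():
    table = FakeRoutingTable()
    provider = table.get_provider_impl("rag_tool.query_context")
    print(await provider.query_context("what is in db-1?", None, ["db-1"]))


asyncio.run(main())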
@@ -408,7 +442,7 @@ class ToolRuntimeRouter(ToolRuntime):

         # TODO: this should be in sync with "get_all_api_endpoints()"
         # TODO: make sure rag_tool vs builtin::memory is correct everywhere
-        self.rag_tool = self.routing_table.get_provider_impl("builtin::memory")
+        self.rag_tool = self.RagToolImpl(routing_table)
         setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
         setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)

@@ -418,10 +452,10 @@ class ToolRuntimeRouter(ToolRuntime):
     async def shutdown(self) -> None:
         pass

-    async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
+    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
-            args=args,
+            kwargs=kwargs,
         )

     async def list_runtime_tools(
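Not part of the diff: the same args to kwargs rename is applied to every tool runtime later in this commit (code interpreter, Bing, Brave, MCP, Tavily, Wolfram Alpha). A hedged sketch of a call through the router after the rename; the tool name is hypothetical.

from typing import Any, Dict


async def search(tool_runtime, query: str) -> Any:
    # After this change the argument dict is passed as `kwargs`, not `args`.
    payload: Dict[str, Any] = {"query": query}
    return await tool_runtime.invoke_tool(
        tool_name="brave_search",   # hypothetical registered tool name
        kwargs=payload,
    )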
@@ -9,7 +9,7 @@ from typing import Dict, List

 from pydantic import BaseModel

-from llama_stack.apis.tools import SpecialToolGroups
+from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroups

 from llama_stack.apis.version import LLAMA_STACK_API_VERSION

@@ -24,21 +24,29 @@ class ApiEndpoint(BaseModel):
     name: str


+def toolgroup_protocol_map():
+    return {
+        SpecialToolGroups.rag_tool: RAGToolRuntime,
+    }
+
+
 def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     apis = {}

     protocols = api_protocol_map()
+    toolgroup_protocols = toolgroup_protocol_map()
     for api, protocol in protocols.items():
         endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
         if api == Api.tool_runtime:
             for tool_group in SpecialToolGroups:
-                print(f"tool_group: {tool_group}")
-                sub_protocol = getattr(protocol, tool_group.value)
+                sub_protocol = toolgroup_protocols[tool_group]
                 sub_protocol_methods = inspect.getmembers(
                     sub_protocol, predicate=inspect.isfunction
                 )
                 for name, method in sub_protocol_methods:
+                    if not hasattr(method, "__webmethod__"):
+                        continue
                     protocol_methods.append((f"{tool_group.value}.{name}", method))

     for name, method in protocol_methods:
@@ -47,7 +55,6 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:

         webmethod = method.__webmethod__
         route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"

         if webmethod.method == "GET":
             method = "get"
         elif webmethod.method == "DELETE":
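Not part of the diff: a small standalone sketch of the discovery rule the new hasattr check enforces, using a toy class; only functions carrying a __webmethod__ attribute become endpoints, and they are namespaced by tool group (class and route below are hypothetical).

import inspect


class ToyRagTool:
    def insert_documents(self):
        ...

    def _internal_helper(self):
        ...


# Simulate the @webmethod decorator by tagging the function directly.
ToyRagTool.insert_documents.__webmethod__ = {"route": "/tool-runtime/rag-tool/insert-documents"}

endpoints = [
    (f"rag_tool.{name}", fn)
    for name, fn in inspect.getmembers(ToyRagTool, predicate=inspect.isfunction)
    if hasattr(fn, "__webmethod__")
]
print(endpoints)  # only insert_documents survives the filter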
@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):


 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v5"
+KEY_VERSION = "v6"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


@@ -373,21 +373,30 @@ class ChatAgent(ShieldRunnerMixin):
         documents: Optional[List[Document]] = None,
         toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> AsyncGenerator:
+        # TODO: simplify all of this code, it can be simpler
         toolgroup_args = {}
+        toolgroups = set()
         for toolgroup in self.agent_config.toolgroups:
             if isinstance(toolgroup, AgentToolGroupWithArgs):
+                toolgroups.add(toolgroup.name)
                 toolgroup_args[toolgroup.name] = toolgroup.args
+            else:
+                toolgroups.add(toolgroup)
         if toolgroups_for_turn:
             for toolgroup in toolgroups_for_turn:
                 if isinstance(toolgroup, AgentToolGroupWithArgs):
+                    toolgroups.add(toolgroup.name)
                     toolgroup_args[toolgroup.name] = toolgroup.args
+                else:
+                    toolgroups.add(toolgroup)

         tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
         if documents:
             await self.handle_documents(
                 session_id, documents, input_messages, tool_defs
             )
-        if "builtin::memory" in toolgroup_args and len(input_messages) > 0:
+
+        if "builtin::memory" in toolgroups and len(input_messages) > 0:
             with tracing.span(MEMORY_QUERY_TOOL) as span:
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
@@ -399,7 +408,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )

-                args = toolgroup_args["builtin::memory"]
+                args = toolgroup_args.get("builtin::memory", {})
                 vector_db_ids = args.get("vector_db_ids", [])
                 session_info = await self.storage.get_session_info(session_id)

@@ -423,7 +432,7 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
-                retrieved_context = await self.tool_runtime_api.rag_tool.query_context(
+                result = await self.tool_runtime_api.rag_tool.query_context(
                     content=concat_interleaved_content(
                         [msg.content for msg in input_messages]
                     ),
@@ -434,6 +443,7 @@ class ChatAgent(ShieldRunnerMixin):
                     ),
                     vector_db_ids=vector_db_ids,
                 )
+                retrieved_context = result.content

                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
@@ -862,7 +872,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def add_to_session_memory_bank(
         self, session_id: str, data: List[Document]
     ) -> None:
-        bank_id = await self._ensure_memory_bank(session_id)
+        vector_db_id = await self._ensure_memory_bank(session_id)
         documents = [
             RAGDocument(
                 document_id=str(uuid.uuid4()),
@@ -874,7 +884,7 @@ class ChatAgent(ShieldRunnerMixin):
         ]
         await self.tool_runtime_api.rag_tool.insert_documents(
             documents=documents,
-            vector_db_ids=[bank_id],
+            vector_db_id=vector_db_id,
             chunk_size_in_tokens=512,
         )

@@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        script = args["code"]
+        script = kwargs["code"]
         req = CodeExecutionRequest(scripts=[script])
         res = self.code_executor.execute(req)
         pieces = [res["process_status"]]
@@ -5,68 +5,64 @@
 # the root directory of this source tree.


-from typing import List

 from jinja2 import Template
-from pydantic import BaseModel

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import UserMessage

+from llama_stack.apis.tools.rag_tool import (
+    DefaultRAGQueryGeneratorConfig,
+    LLMRAGQueryGeneratorConfig,
+    RAGQueryGenerator,
+    RAGQueryGeneratorConfig,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

-from .config import (
-    DefaultMemoryQueryGeneratorConfig,
-    LLMMemoryQueryGeneratorConfig,
-    MemoryQueryGenerator,
-    MemoryQueryGeneratorConfig,
-)


 async def generate_rag_query(
-    config: MemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: RAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     """
     Generates a query that will be used for
     retrieving relevant information from the memory bank.
     """
-    if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, messages, **kwargs)
-    elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, messages, **kwargs)
+    if config.type == RAGQueryGenerator.default.value:
+        query = await default_rag_query_generator(config, content, **kwargs)
+    elif config.type == RAGQueryGenerator.llm.value:
+        query = await llm_rag_query_generator(config, content, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query


 async def default_rag_query_generator(
-    config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: DefaultRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
-    return config.separator.join(interleaved_content_as_str(m) for m in messages)
+    return interleaved_content_as_str(content, sep=config.separator)


 async def llm_rag_query_generator(
-    config: LLMMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: LLMRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]

-    m_dict = {
-        "messages": [
-            message.model_dump() if isinstance(message, BaseModel) else message
-            for message in messages
-        ]
-    }
+    messages = []
+    if isinstance(content, list):
+        messages = [interleaved_content_as_str(m) for m in content]
+    else:
+        messages = [interleaved_content_as_str(content)]

     template = Template(config.template)
-    content = template.render(m_dict)
+    content = template.render({"messages": messages})

     model = config.model
     message = UserMessage(content=content)

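Not part of the diff: the LLM query generator now renders the Jinja template from a list of plain strings (via interleaved_content_as_str) instead of message dicts. A minimal sketch with a hypothetical template string:

from jinja2 import Template

# `messages` is now a list of strings rather than serialized message objects.
messages = ["What is Llama Stack?", "Focus on the RAG tool runtime."]
template = Template("Rewrite as one search query: {{ messages | join(' ') }}")
print(template.render({"messages": messages}))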
@@ -27,6 +27,10 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.utils.memory.vector_store import (
+    content_from_doc,
+    make_overlapped_chunks,
+)

 from .config import MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query
@@ -60,10 +64,28 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     async def insert_documents(
         self,
         documents: List[RAGDocument],
-        vector_db_ids: List[str],
+        vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
-        pass
+        chunks = []
+        for doc in documents:
+            content = await content_from_doc(doc)
+            chunks.extend(
+                make_overlapped_chunks(
+                    doc.document_id,
+                    content,
+                    chunk_size_in_tokens,
+                    chunk_size_in_tokens // 4,
+                )
+            )
+
+        if not chunks:
+            return
+
+        await self.vector_io_api.insert_chunks(
+            chunks=chunks,
+            vector_db_id=vector_db_id,
+        )

     async def query_context(
         self,
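Not part of the diff: the overlap passed to make_overlapped_chunks is a quarter of the chunk size, so with the default of 512 tokens consecutive chunks share 128 tokens. The stride calculation below assumes the helper advances by chunk size minus overlap, which this hunk does not show.

chunk_size_in_tokens = 512
overlap_len = chunk_size_in_tokens // 4        # 128 tokens shared between neighbours
stride = chunk_size_in_tokens - overlap_len    # 384 new tokens per chunk, if the helper steps this way
print(chunk_size_in_tokens, overlap_len, stride)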
@@ -104,13 +126,18 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         tokens = 0
         picked = []
         for c in chunks[: query_config.max_chunks]:
-            tokens += c.token_count
+            metadata = c.metadata
+            tokens += metadata["token_count"]
             if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
-            picked.append(f"id:{c.document_id}; content:{c.content}")
+            picked.append(
+                TextContentItem(
+                    text=f"id:{metadata['document_id']}; content:{c.content}",
+                )
+            )

         return RAGQueryResult(
             content=[
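Not part of the diff: picked chunks are now wrapped as TextContentItem objects instead of raw strings, so consumers read .text. A sketch, assuming TextContentItem lives in llama_stack.apis.common.content_types (the import is not shown in this hunk):

from llama_stack.apis.common.content_types import TextContentItem  # assumed import path

item = TextContentItem(text="id:doc-1; content:some chunk text")
print(item.text)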
@@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         headers = {

@@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl(
             "count": self.config.top_k,
             "textDecorations": True,
             "textFormat": "HTML",
-            "q": args["query"],
+            "q": kwargs["query"],
         }

         response = requests.get(
@@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         url = "https://api.search.brave.com/res/v1/web/search"

@@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl(
             "Accept-Encoding": "gzip",
             "Accept": "application/json",
         }
-        payload = {"q": args["query"]}
+        payload = {"q": kwargs["query"]}
         response = requests.get(url=url, params=payload, headers=headers)
         response.raise_for_status()
         results = self._clean_brave_response(response.json())
@@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         return tools

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
         if tool.metadata is None or tool.metadata.get("endpoint") is None:

@@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         async with sse_client(endpoint) as streams:
             async with ClientSession(*streams) as session:
                 await session.initialize()
-                result = await session.call_tool(tool.identifier, args)
+                result = await session.call_tool(tool.identifier, kwargs)

         return ToolInvocationResult(
             content="\n".join([result.model_dump_json() for result in result.content]),
@@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         response = requests.post(
             "https://api.tavily.com/search",
-            json={"api_key": api_key, "query": args["query"]},
+            json={"api_key": api_key, "query": kwargs["query"]},
         )

         return ToolInvocationResult(
@@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         params = {
-            "input": args["query"],
+            "input": kwargs["query"],
             "appid": api_key,
             "format": "plaintext",
             "output": "json",
@@ -12,10 +12,10 @@ from ..conftest import (
     get_test_config_for_api,
 )
 from ..inference.fixtures import INFERENCE_FIXTURES
-from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield

 from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
+from ..vector_io.fixtures import VECTOR_IO_FIXTURES
 from .fixtures import AGENTS_FIXTURES

 DEFAULT_PROVIDER_COMBINATIONS = [

@@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "meta_reference",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },

@@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "ollama",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },

@@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "inference": "together",
             "safety": "llama_guard",
             # make this work with Weaviate which is what the together distro supports
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },

@@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "fireworks",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },

@@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "remote",
             "safety": "remote",
-            "memory": "remote",
+            "vector_io": "remote",
             "agents": "remote",
             "tool_runtime": "memory_and_search",
         },

@@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc):
     available_fixtures = {
         "inference": INFERENCE_FIXTURES,
         "safety": SAFETY_FIXTURES,
-        "memory": MEMORY_FIXTURES,
+        "vector_io": VECTOR_IO_FIXTURES,
         "agents": AGENTS_FIXTURES,
         "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }
@@ -69,7 +69,7 @@ async def agents_stack(

     providers = {}
     provider_data = {}
-    for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
+    for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if key == "inference":

@@ -118,7 +118,7 @@ async def agents_stack(
     )

     test_stack = await construct_stack_for_test(
-        [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
+        [Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
         providers,
         provider_data,
         models=models,
@@ -214,9 +214,11 @@ class TestAgents:
         turn_response = [
             chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
         ]

         assert len(turn_response) > 0

+        # FIXME: we need to check the content of the turn response and ensure
+        # RAG actually worked
+
     @pytest.mark.asyncio
     async def test_create_agent_turn_with_tavily_search(
         self, agents_stack, search_query_messages, common_params
@@ -153,7 +153,13 @@ def make_overlapped_chunks(
         chunk = tokenizer.decode(toks)
         # chunk is a string
         chunks.append(
-            Chunk(content=chunk, token_count=len(toks), document_id=document_id)
+            Chunk(
+                content=chunk,
+                metadata={
+                    "token_count": len(toks),
+                    "document_id": document_id,
+                },
+            )
         )

     return chunks
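Not part of the diff: token_count and document_id now travel in Chunk.metadata rather than as top-level fields, which is what the rewritten query_context earlier in this commit reads back. A sketch mirroring the constructor in the hunk above (the values are illustrative):

from llama_stack.apis.vector_io import Chunk

chunk = Chunk(
    content="some decoded text",
    metadata={"token_count": 57, "document_id": "doc-1"},
)
print(chunk.metadata["token_count"], chunk.metadata["document_id"])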
@@ -4,7 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::together
-    memory:
+    vector_io:
     - inline::faiss
     - remote::chromadb
     - remote::pgvector
@@ -5,7 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
-- memory
+- vector_io
 - safety
 - scoring
 - telemetry
@@ -20,7 +20,7 @@ providers:
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
-  memory:
+  vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
     config:
@@ -145,7 +145,6 @@ models:
   model_type: embedding
 shields:
 - shield_id: meta-llama/Llama-Guard-3-8B
-memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []