forked from phoenix-oss/llama-stack-mirror
[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. Third part: - we need to make `tool_runtime.rag_tool.query_context()` and `tool_runtime.rag_tool.insert_documents()` methods work smoothly with complete type safety. To that end, we introduce a sub-resource path `tool-runtime/rag-tool/` and make changes to the resolver to make things work. - the PR updates the agents implementation to directly call these typed APIs for memory accesses rather than going through the complex, untyped "invoke_tool" API. the code looks much nicer and simpler (expectedly.) - there are a number of hacks in the server resolver implementation still, we will live with some and fix some Note that we must make sure the client SDKs are able to handle this subresource complexity also. Stainless has support for subresources, so this should be possible but beware. ## Test Plan Our RAG test is sad (doesn't actually test for actual RAG output) but I verified that the implementation works. I will work on fixing the RAG test afterwards. ```bash pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B ```
This commit is contained in:
parent
78a481bb22
commit
1a7490470a
33 changed files with 1648 additions and 1345 deletions
|
@ -172,10 +172,16 @@ def _get_endpoint_functions(
|
|||
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
|
||||
"Find the class in which a member function is first defined in a class inheritance hierarchy."
|
||||
|
||||
# This import must be dynamic here
|
||||
from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
|
||||
|
||||
# iterate in reverse member resolution order to find most specific class first
|
||||
for cls in reversed(inspect.getmro(derived_cls)):
|
||||
for name, _ in inspect.getmembers(cls, inspect.isfunction):
|
||||
if name == member_fn:
|
||||
# HACK ALERT
|
||||
if cls == RAGToolRuntime:
|
||||
return ToolRuntime
|
||||
return cls
|
||||
|
||||
raise ValidationError(
|
||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -5,3 +5,4 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .tools import * # noqa: F401 F403
|
||||
from .rag_tool import * # noqa: F401 F403
|
||||
|
|
95
llama_stack/apis/tools/rag_tool.py
Normal file
95
llama_stack/apis/tools/rag_tool.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated, Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGDocument(BaseModel):
|
||||
document_id: str
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str | None = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryResult(BaseModel):
|
||||
content: Optional[InterleavedContent] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryGenerator(Enum):
|
||||
default = "default"
|
||||
llm = "llm"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DefaultRAGQueryGeneratorConfig(BaseModel):
|
||||
type: Literal["default"] = "default"
|
||||
separator: str = " "
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LLMRAGQueryGeneratorConfig(BaseModel):
|
||||
type: Literal["llm"] = "llm"
|
||||
model: str
|
||||
template: str
|
||||
|
||||
|
||||
RAGQueryGeneratorConfig = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="RAGQueryGeneratorConfig",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryConfig(BaseModel):
|
||||
# This config defines how a query is generated using the messages
|
||||
# for memory bank retrieval.
|
||||
query_generator_config: RAGQueryGeneratorConfig = Field(
|
||||
default=DefaultRAGQueryGeneratorConfig()
|
||||
)
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 5
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class RAGToolRuntime(Protocol):
|
||||
@webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
|
||||
async def insert_documents(
|
||||
self,
|
||||
documents: List[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
"""Index documents so they can be used by the RAG system"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
|
||||
async def query_context(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
query_config: RAGQueryConfig,
|
||||
vector_db_ids: List[str],
|
||||
) -> RAGQueryResult:
|
||||
"""Query the RAG system for context; typically invoked by the agent"""
|
||||
...
|
|
@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL
|
|||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
from .rag_tool import RAGToolRuntime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolParameter(BaseModel):
|
||||
|
@ -130,11 +132,17 @@ class ToolGroups(Protocol):
|
|||
...
|
||||
|
||||
|
||||
class SpecialToolGroup(Enum):
|
||||
rag_tool = "rag_tool"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore
|
||||
|
||||
rag_tool: RAGToolRuntime
|
||||
|
||||
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||
async def list_runtime_tools(
|
||||
|
@ -143,7 +151,7 @@ class ToolRuntime(Protocol):
|
|||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
self, tool_name: str, kwargs: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
"""Run a tool with the given arguments"""
|
||||
...
|
||||
|
|
|
@ -333,6 +333,8 @@ async def instantiate_provider(
|
|||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
# TODO: check compliance for special tool groups
|
||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||
if (
|
||||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||
|
|
|
@ -36,7 +36,14 @@ from llama_stack.apis.scoring import (
|
|||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.tools import ToolDef, ToolRuntime
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
RAGToolRuntime,
|
||||
ToolDef,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
|
@ -400,22 +407,55 @@ class EvalRouter(Eval):
|
|||
|
||||
|
||||
class ToolRuntimeRouter(ToolRuntime):
|
||||
class RagToolImpl(RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def query_context(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
query_config: RAGQueryConfig,
|
||||
vector_db_ids: List[str],
|
||||
) -> RAGQueryResult:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
"rag_tool.query_context"
|
||||
).query_context(content, query_config, vector_db_ids)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
documents: List[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
"rag_tool.insert_documents"
|
||||
).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||
# TODO: make sure rag_tool vs builtin::memory is correct everywhere
|
||||
self.rag_tool = self.RagToolImpl(routing_table)
|
||||
setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
|
||||
setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
|
|
|
@ -9,6 +9,8 @@ from typing import Dict, List
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
|
||||
from llama_stack.distribution.resolver import api_protocol_map
|
||||
|
@ -22,21 +24,39 @@ class ApiEndpoint(BaseModel):
|
|||
name: str
|
||||
|
||||
|
||||
def toolgroup_protocol_map():
|
||||
return {
|
||||
SpecialToolGroup.rag_tool: RAGToolRuntime,
|
||||
}
|
||||
|
||||
|
||||
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
apis = {}
|
||||
|
||||
protocols = api_protocol_map()
|
||||
toolgroup_protocols = toolgroup_protocol_map()
|
||||
for api, protocol in protocols.items():
|
||||
endpoints = []
|
||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
|
||||
# HACK ALERT
|
||||
if api == Api.tool_runtime:
|
||||
for tool_group in SpecialToolGroup:
|
||||
sub_protocol = toolgroup_protocols[tool_group]
|
||||
sub_protocol_methods = inspect.getmembers(
|
||||
sub_protocol, predicate=inspect.isfunction
|
||||
)
|
||||
for name, method in sub_protocol_methods:
|
||||
if not hasattr(method, "__webmethod__"):
|
||||
continue
|
||||
protocol_methods.append((f"{tool_group.value}.{name}", method))
|
||||
|
||||
for name, method in protocol_methods:
|
||||
if not hasattr(method, "__webmethod__"):
|
||||
continue
|
||||
|
||||
webmethod = method.__webmethod__
|
||||
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||
|
||||
if webmethod.method == "GET":
|
||||
method = "get"
|
||||
elif webmethod.method == "DELETE":
|
||||
|
|
|
@ -29,7 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
|||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
|
@ -62,6 +62,7 @@ class LlamaStack(
|
|||
Inspect,
|
||||
ToolGroups,
|
||||
ToolRuntime,
|
||||
RAGToolRuntime,
|
||||
):
|
||||
pass
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
|
|||
|
||||
|
||||
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v5"
|
||||
KEY_VERSION = "v6"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
|
|
|
@ -19,9 +19,8 @@ async def get_provider_impl(
|
|||
impl = MetaReferenceAgentsImpl(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.memory],
|
||||
deps[Api.vector_io],
|
||||
deps[Api.safety],
|
||||
deps[Api.memory_banks],
|
||||
deps[Api.tool_runtime],
|
||||
deps[Api.tool_groups],
|
||||
)
|
||||
|
|
|
@ -59,13 +59,18 @@ from llama_stack.apis.inference import (
|
|||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import Memory, MemoryBankDocument
|
||||
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.tools import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
ToolGroups,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
from .persistence import AgentPersistence
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
||||
|
@ -79,7 +84,7 @@ def make_random_string(length: int = 8):
|
|||
|
||||
|
||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
MEMORY_QUERY_TOOL = "query_memory"
|
||||
MEMORY_QUERY_TOOL = "rag_tool.query_context"
|
||||
WEB_SEARCH_TOOL = "web_search"
|
||||
MEMORY_GROUP = "builtin::memory"
|
||||
|
||||
|
@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
agent_config: AgentConfig,
|
||||
tempdir: str,
|
||||
inference_api: Inference,
|
||||
memory_api: Memory,
|
||||
memory_banks_api: MemoryBanks,
|
||||
safety_api: Safety,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
vector_io_api: VectorIO,
|
||||
persistence_store: KVStore,
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
self.tempdir = tempdir
|
||||
self.inference_api = inference_api
|
||||
self.memory_api = memory_api
|
||||
self.memory_banks_api = memory_banks_api
|
||||
self.safety_api = safety_api
|
||||
self.vector_io_api = vector_io_api
|
||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
|
@ -370,24 +373,30 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
documents: Optional[List[Document]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# TODO: simplify all of this code, it can be simpler
|
||||
toolgroup_args = {}
|
||||
toolgroups = set()
|
||||
for toolgroup in self.agent_config.toolgroups:
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroups.add(toolgroup.name)
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
if toolgroups_for_turn:
|
||||
for toolgroup in toolgroups_for_turn:
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroups.add(toolgroup.name)
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
|
||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||
if documents:
|
||||
await self.handle_documents(
|
||||
session_id, documents, input_messages, tool_defs
|
||||
)
|
||||
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
|
||||
memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
|
||||
if memory_tool_group is None:
|
||||
raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
|
||||
|
||||
if MEMORY_GROUP in toolgroups and len(input_messages) > 0:
|
||||
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -398,17 +407,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
query_args = {
|
||||
"messages": [msg.content for msg in input_messages],
|
||||
**toolgroup_args.get(memory_tool_group, {}),
|
||||
}
|
||||
|
||||
args = toolgroup_args.get(MEMORY_GROUP, {})
|
||||
vector_db_ids = args.get("vector_db_ids", [])
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info.memory_bank_id:
|
||||
if "memory_bank_ids" not in query_args:
|
||||
query_args["memory_bank_ids"] = []
|
||||
query_args["memory_bank_ids"].append(session_info.memory_bank_id)
|
||||
vector_db_ids.append(session_info.memory_bank_id)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
|
@ -425,10 +432,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
args=query_args,
|
||||
result = await self.tool_runtime_api.rag_tool.query_context(
|
||||
content=concat_interleaved_content(
|
||||
[msg.content for msg in input_messages]
|
||||
),
|
||||
query_config=RAGQueryConfig(
|
||||
query_generator_config=DefaultRAGQueryGeneratorConfig(),
|
||||
max_tokens_in_context=4096,
|
||||
max_chunks=5,
|
||||
),
|
||||
vector_db_ids=vector_db_ids,
|
||||
)
|
||||
retrieved_context = result.content
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -449,7 +464,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
ToolResponse(
|
||||
call_id="",
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
content=result.content,
|
||||
content=retrieved_context or [],
|
||||
)
|
||||
],
|
||||
),
|
||||
|
@ -459,13 +474,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
span.set_attribute("output", result.content)
|
||||
span.set_attribute("error_code", result.error_code)
|
||||
span.set_attribute("error_message", result.error_message)
|
||||
span.set_attribute("output", retrieved_context)
|
||||
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
|
||||
if result.error_code == 0:
|
||||
if retrieved_context:
|
||||
last_message = input_messages[-1]
|
||||
last_message.context = result.content
|
||||
last_message.context = retrieved_context
|
||||
|
||||
output_attachments = []
|
||||
|
||||
|
@ -842,12 +855,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if session_info.memory_bank_id is None:
|
||||
bank_id = f"memory_bank_{session_id}"
|
||||
await self.memory_banks_api.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
|
||||
# TODO: the semantic for registration is definitely not "creation"
|
||||
# so we need to fix it if we expect the agent to create a new vector db
|
||||
# for each session
|
||||
await self.vector_io_api.register_vector_db(
|
||||
vector_db_id=bank_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
)
|
||||
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
|
@ -858,9 +872,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def add_to_session_memory_bank(
|
||||
self, session_id: str, data: List[Document]
|
||||
) -> None:
|
||||
bank_id = await self._ensure_memory_bank(session_id)
|
||||
vector_db_id = await self._ensure_memory_bank(session_id)
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id=str(uuid.uuid4()),
|
||||
content=a.content,
|
||||
mime_type=a.mime_type,
|
||||
|
@ -868,9 +882,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
for a in data
|
||||
]
|
||||
await self.memory_api.insert_documents(
|
||||
bank_id=bank_id,
|
||||
await self.tool_runtime_api.rag_tool.insert_documents(
|
||||
documents=documents,
|
||||
vector_db_id=vector_db_id,
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
|
||||
|
@ -955,7 +970,7 @@ async def execute_tool_call_maybe(
|
|||
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
args=dict(
|
||||
kwargs=dict(
|
||||
session_id=session_id,
|
||||
**tool_call_args,
|
||||
),
|
||||
|
|
|
@ -26,10 +26,9 @@ from llama_stack.apis.agents import (
|
|||
Turn,
|
||||
)
|
||||
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
|
||||
from .agent_instance import ChatAgent
|
||||
|
@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
config: MetaReferenceAgentsImplConfig,
|
||||
inference_api: Inference,
|
||||
memory_api: Memory,
|
||||
vector_io_api: VectorIO,
|
||||
safety_api: Safety,
|
||||
memory_banks_api: MemoryBanks,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.memory_api = memory_api
|
||||
self.vector_io_api = vector_io_api
|
||||
self.safety_api = safety_api
|
||||
self.memory_banks_api = memory_banks_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
|
||||
|
@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tempdir=self.tempdir,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
memory_api=self.memory_api,
|
||||
memory_banks_api=self.memory_banks_api,
|
||||
vector_io_api=self.vector_io_api,
|
||||
tool_runtime_api=self.tool_runtime_api,
|
||||
tool_groups_api=self.tool_groups_api,
|
||||
persistence_store=(
|
||||
|
|
|
@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
self, tool_name: str, kwargs: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
script = args["code"]
|
||||
script = kwargs["code"]
|
||||
req = CodeExecutionRequest(scripts=[script])
|
||||
res = self.code_executor.execute(req)
|
||||
pieces = [res["process_status"]]
|
||||
|
|
|
@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl
|
|||
|
||||
|
||||
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
|
||||
impl = MemoryToolRuntimeImpl(
|
||||
config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
|
||||
)
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,87 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class _MemoryBankConfigCommon(BaseModel):
|
||||
bank_id: str
|
||||
|
||||
|
||||
class VectorMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["vector"] = "vector"
|
||||
|
||||
|
||||
class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["keyvalue"] = "keyvalue"
|
||||
keys: List[str] # what keys to focus on
|
||||
|
||||
|
||||
class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["keyword"] = "keyword"
|
||||
|
||||
|
||||
class GraphMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal["graph"] = "graph"
|
||||
entities: List[str] # what entities to focus on
|
||||
|
||||
|
||||
MemoryBankConfig = Annotated[
|
||||
Union[
|
||||
VectorMemoryBankConfig,
|
||||
KeyValueMemoryBankConfig,
|
||||
KeywordMemoryBankConfig,
|
||||
GraphMemoryBankConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryQueryGenerator(Enum):
|
||||
default = "default"
|
||||
llm = "llm"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
class DefaultMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.default.value] = (
|
||||
MemoryQueryGenerator.default.value
|
||||
)
|
||||
sep: str = " "
|
||||
|
||||
|
||||
class LLMMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
|
||||
model: str
|
||||
template: str
|
||||
|
||||
|
||||
class CustomMemoryQueryGeneratorConfig(BaseModel):
|
||||
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
|
||||
|
||||
|
||||
MemoryQueryGeneratorConfig = Annotated[
|
||||
Union[
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
CustomMemoryQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryToolConfig(BaseModel):
|
||||
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MemoryToolRuntimeConfig(BaseModel):
|
||||
# This config defines how a query is generated using the messages
|
||||
# for memory bank retrieval.
|
||||
query_generator_config: MemoryQueryGeneratorConfig = Field(
|
||||
default=DefaultMemoryQueryGeneratorConfig()
|
||||
)
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 5
|
||||
pass
|
||||
|
|
|
@ -5,68 +5,64 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import UserMessage
|
||||
|
||||
from llama_stack.apis.tools.rag_tool import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
RAGQueryGenerator,
|
||||
RAGQueryGeneratorConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
MemoryQueryGenerator,
|
||||
MemoryQueryGeneratorConfig,
|
||||
)
|
||||
|
||||
|
||||
async def generate_rag_query(
|
||||
config: MemoryQueryGeneratorConfig,
|
||||
messages: List[InterleavedContent],
|
||||
config: RAGQueryGeneratorConfig,
|
||||
content: InterleavedContent,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Generates a query that will be used for
|
||||
retrieving relevant information from the memory bank.
|
||||
"""
|
||||
if config.type == MemoryQueryGenerator.default.value:
|
||||
query = await default_rag_query_generator(config, messages, **kwargs)
|
||||
elif config.type == MemoryQueryGenerator.llm.value:
|
||||
query = await llm_rag_query_generator(config, messages, **kwargs)
|
||||
if config.type == RAGQueryGenerator.default.value:
|
||||
query = await default_rag_query_generator(config, content, **kwargs)
|
||||
elif config.type == RAGQueryGenerator.llm.value:
|
||||
query = await llm_rag_query_generator(config, content, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
|
||||
return query
|
||||
|
||||
|
||||
async def default_rag_query_generator(
|
||||
config: DefaultMemoryQueryGeneratorConfig,
|
||||
messages: List[InterleavedContent],
|
||||
config: DefaultRAGQueryGeneratorConfig,
|
||||
content: InterleavedContent,
|
||||
**kwargs,
|
||||
):
|
||||
return config.sep.join(interleaved_content_as_str(m) for m in messages)
|
||||
return interleaved_content_as_str(content, sep=config.separator)
|
||||
|
||||
|
||||
async def llm_rag_query_generator(
|
||||
config: LLMMemoryQueryGeneratorConfig,
|
||||
messages: List[InterleavedContent],
|
||||
config: LLMRAGQueryGeneratorConfig,
|
||||
content: InterleavedContent,
|
||||
**kwargs,
|
||||
):
|
||||
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
|
||||
inference_api = kwargs["inference_api"]
|
||||
|
||||
m_dict = {
|
||||
"messages": [
|
||||
message.model_dump() if isinstance(message, BaseModel) else message
|
||||
for message in messages
|
||||
]
|
||||
}
|
||||
messages = []
|
||||
if isinstance(content, list):
|
||||
messages = [interleaved_content_as_str(m) for m in content]
|
||||
else:
|
||||
messages = [interleaved_content_as_str(content)]
|
||||
|
||||
template = Template(config.template)
|
||||
content = template.render(m_dict)
|
||||
content = template.render({"messages": messages})
|
||||
|
||||
model = config.model
|
||||
message = UserMessage(content=content)
|
||||
|
|
|
@ -10,20 +10,29 @@ import secrets
|
|||
import string
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
RAGToolRuntime,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
make_overlapped_chunks,
|
||||
)
|
||||
|
||||
from .config import MemoryToolConfig, MemoryToolRuntimeConfig
|
||||
from .config import MemoryToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -35,65 +44,79 @@ def make_random_string(length: int = 8):
|
|||
)
|
||||
|
||||
|
||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
config: MemoryToolRuntimeConfig,
|
||||
memory_api: Memory,
|
||||
memory_banks_api: MemoryBanks,
|
||||
vector_io_api: VectorIO,
|
||||
inference_api: Inference,
|
||||
):
|
||||
self.config = config
|
||||
self.memory_api = memory_api
|
||||
self.memory_banks_api = memory_banks_api
|
||||
self.vector_io_api = vector_io_api
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="query_memory",
|
||||
description="Retrieve context from memory",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="messages",
|
||||
description="The input messages to search for",
|
||||
parameter_type="array",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
documents: List[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
chunks = []
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
chunks.extend(
|
||||
make_overlapped_chunks(
|
||||
doc.document_id,
|
||||
content,
|
||||
chunk_size_in_tokens,
|
||||
chunk_size_in_tokens // 4,
|
||||
)
|
||||
)
|
||||
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
await self.vector_io_api.insert_chunks(
|
||||
chunks=chunks,
|
||||
vector_db_id=vector_db_id,
|
||||
)
|
||||
|
||||
async def query_context(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
query_config: RAGQueryConfig,
|
||||
vector_db_ids: List[str],
|
||||
) -> RAGQueryResult:
|
||||
if not vector_db_ids:
|
||||
return RAGQueryResult(content=None)
|
||||
|
||||
async def _retrieve_context(
|
||||
self, input_messages: List[InterleavedContent], bank_ids: List[str]
|
||||
) -> Optional[List[InterleavedContent]]:
|
||||
if not bank_ids:
|
||||
return None
|
||||
query = await generate_rag_query(
|
||||
self.config.query_generator_config,
|
||||
input_messages,
|
||||
query_config.query_generator_config,
|
||||
content,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
tasks = [
|
||||
self.memory_api.query_documents(
|
||||
bank_id=bank_id,
|
||||
self.vector_io_api.query_chunks(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query,
|
||||
params={
|
||||
"max_chunks": self.config.max_chunks,
|
||||
"max_chunks": query_config.max_chunks,
|
||||
},
|
||||
)
|
||||
for bank_id in bank_ids
|
||||
for vector_db_id in vector_db_ids
|
||||
]
|
||||
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
|
||||
results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
|
||||
chunks = [c for r in results for c in r.chunks]
|
||||
scores = [s for r in results for s in r.scores]
|
||||
|
||||
if not chunks:
|
||||
return None
|
||||
return RAGQueryResult(content=None)
|
||||
|
||||
# sort by score
|
||||
chunks, scores = zip(
|
||||
|
@ -102,45 +125,52 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
|
||||
tokens = 0
|
||||
picked = []
|
||||
for c in chunks[: self.config.max_chunks]:
|
||||
tokens += c.token_count
|
||||
if tokens > self.config.max_tokens_in_context:
|
||||
for c in chunks[: query_config.max_chunks]:
|
||||
metadata = c.metadata
|
||||
tokens += metadata["token_count"]
|
||||
if tokens > query_config.max_tokens_in_context:
|
||||
log.error(
|
||||
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
||||
)
|
||||
break
|
||||
picked.append(f"id:{c.document_id}; content:{c.content}")
|
||||
picked.append(
|
||||
TextContentItem(
|
||||
text=f"id:{metadata['document_id']}; content:{c.content}",
|
||||
)
|
||||
)
|
||||
|
||||
return RAGQueryResult(
|
||||
content=[
|
||||
TextContentItem(
|
||||
text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
),
|
||||
*picked,
|
||||
TextContentItem(
|
||||
text="\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
# Parameters are not listed since these methods are not yet invoked automatically
|
||||
# by the LLM. The method is only implemented so things like /tools can list without
|
||||
# encountering fatals.
|
||||
return [
|
||||
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
*picked,
|
||||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
ToolDef(
|
||||
name="rag_tool.query_context",
|
||||
description="Retrieve context from memory",
|
||||
),
|
||||
ToolDef(
|
||||
name="rag_tool.insert_documents",
|
||||
description="Insert documents into memory",
|
||||
),
|
||||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
self, tool_name: str, kwargs: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
|
||||
final_args = tool_group.args or {}
|
||||
final_args.update(args)
|
||||
config = MemoryToolConfig()
|
||||
if tool.metadata and tool.metadata.get("config") is not None:
|
||||
config = MemoryToolConfig(**tool.metadata["config"])
|
||||
if "memory_bank_ids" in final_args:
|
||||
bank_ids = final_args["memory_bank_ids"]
|
||||
else:
|
||||
bank_ids = [
|
||||
bank_config.bank_id for bank_config in config.memory_bank_configs
|
||||
]
|
||||
if "messages" not in final_args:
|
||||
raise ValueError("messages are required")
|
||||
context = await self._retrieve_context(
|
||||
final_args["messages"],
|
||||
bank_ids,
|
||||
)
|
||||
if context is None:
|
||||
context = []
|
||||
return ToolInvocationResult(
|
||||
content=concat_interleaved_content(context), error_code=0
|
||||
raise RuntimeError(
|
||||
"This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
|
||||
)
|
||||
|
|
|
@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.tool_runtime.memory",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
|
||||
api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
|
||||
api_dependencies=[Api.vector_io, Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
|
|
|
@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl(
|
|||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
self, tool_name: str, kwargs: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
headers = {
|
||||
|
@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl(
|
|||
"count": self.config.top_k,
|
||||
"textDecorations": True,
|
||||
"textFormat": "HTML",
|
||||
"q": args["query"],
|
||||
"q": kwargs["query"],
|
||||
}
|
||||
|
||||
response = requests.get(
|
||||
|
|
|
@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl(
|
|||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
self, tool_name: str, kwargs: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
|
@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl(
|
|||
"Accept-Encoding": "gzip",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"q": args["query"]}
|
||||
payload = {"q": kwargs["query"]}
|
||||
response = requests.get(url=url, params=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
results = self._clean_brave_response(response.json())
|
||||
|
|
|
@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
return tools
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
self, tool_name: str, kwargs: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
||||
|
@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
async with sse_client(endpoint) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(tool.identifier, args)
|
||||
result = await session.call_tool(tool.identifier, kwargs)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content="\n".join([result.model_dump_json() for result in result.content]),
|
||||
|
|
|
@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl(
|
|||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
self, tool_name: str, kwargs: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
response = requests.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={"api_key": api_key, "query": args["query"]},
|
||||
json={"api_key": api_key, "query": kwargs["query"]},
|
||||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
|
|
|
@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl(
|
|||
]
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
self, tool_name: str, kwargs: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
params = {
|
||||
"input": args["query"],
|
||||
"input": kwargs["query"],
|
||||
"appid": api_key,
|
||||
"format": "plaintext",
|
||||
"output": "json",
|
||||
|
|
|
@ -12,10 +12,10 @@ from ..conftest import (
|
|||
get_test_config_for_api,
|
||||
)
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from ..memory.fixtures import MEMORY_FIXTURES
|
||||
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
||||
|
||||
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
|
||||
from ..vector_io.fixtures import VECTOR_IO_FIXTURES
|
||||
from .fixtures import AGENTS_FIXTURES
|
||||
|
||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
||||
|
@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"vector_io": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
|
@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "ollama",
|
||||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"vector_io": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
|
@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"inference": "together",
|
||||
"safety": "llama_guard",
|
||||
# make this work with Weaviate which is what the together distro supports
|
||||
"memory": "faiss",
|
||||
"vector_io": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
|
@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "fireworks",
|
||||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"vector_io": "faiss",
|
||||
"agents": "meta_reference",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
|
@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "remote",
|
||||
"safety": "remote",
|
||||
"memory": "remote",
|
||||
"vector_io": "remote",
|
||||
"agents": "remote",
|
||||
"tool_runtime": "memory_and_search",
|
||||
},
|
||||
|
@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc):
|
|||
available_fixtures = {
|
||||
"inference": INFERENCE_FIXTURES,
|
||||
"safety": SAFETY_FIXTURES,
|
||||
"memory": MEMORY_FIXTURES,
|
||||
"vector_io": VECTOR_IO_FIXTURES,
|
||||
"agents": AGENTS_FIXTURES,
|
||||
"tool_runtime": TOOL_RUNTIME_FIXTURES,
|
||||
}
|
||||
|
|
|
@ -69,7 +69,7 @@ async def agents_stack(
|
|||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
|
||||
for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = fixture.providers
|
||||
if key == "inference":
|
||||
|
@ -118,7 +118,7 @@ async def agents_stack(
|
|||
)
|
||||
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
|
||||
[Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
|
||||
providers,
|
||||
provider_data,
|
||||
models=models,
|
||||
|
|
|
@ -214,9 +214,11 @@ class TestAgents:
|
|||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
|
||||
# FIXME: we need to check the content of the turn response and ensure
|
||||
# RAG actually worked
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_tavily_search(
|
||||
self, agents_stack, search_query_messages, common_params
|
||||
|
|
|
@ -8,13 +8,12 @@ import uuid
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.tools import RAGDocument
|
||||
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
make_overlapped_chunks,
|
||||
MemoryBankDocument,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
|
@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
@pytest.fixture(scope="session")
|
||||
def sample_chunks():
|
||||
docs = [
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id="doc1",
|
||||
content="Python is a high-level programming language.",
|
||||
metadata={"category": "programming", "difficulty": "beginner"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id="doc2",
|
||||
content="Machine learning is a subset of artificial intelligence.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id="doc3",
|
||||
content="Data structures are fundamental to computer science.",
|
||||
metadata={"category": "computer science", "difficulty": "intermediate"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
RAGDocument(
|
||||
document_id="doc4",
|
||||
content="Neural networks are inspired by biological neural networks.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
|
|
|
@ -19,7 +19,6 @@ import numpy as np
|
|||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pypdf import PdfReader
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
|
|||
TextContentItem,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.tools import RAGDocument
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryBankDocument(BaseModel):
|
||||
document_id: str
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str | None = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
def parse_pdf(data: bytes) -> str:
|
||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||
pdf_bytes = io.BytesIO(data)
|
||||
|
@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> Interleaved
|
|||
return ret
|
||||
|
||||
|
||||
async def content_from_doc(doc: MemoryBankDocument) -> str:
|
||||
async def content_from_doc(doc: RAGDocument) -> str:
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
return content_from_data(doc.content.uri)
|
||||
|
@ -161,7 +153,13 @@ def make_overlapped_chunks(
|
|||
chunk = tokenizer.decode(toks)
|
||||
# chunk is a string
|
||||
chunks.append(
|
||||
Chunk(content=chunk, token_count=len(toks), document_id=document_id)
|
||||
Chunk(
|
||||
content=chunk,
|
||||
metadata={
|
||||
"token_count": len(toks),
|
||||
"document_id": document_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
|
105
llama_stack/scripts/test_rag_via_curl.py
Normal file
105
llama_stack/scripts/test_rag_via_curl.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.tools import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
)
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str
|
||||
|
||||
|
||||
class TestRAGToolEndpoints:
|
||||
@pytest.fixture
|
||||
def base_url(self) -> str:
|
||||
return "http://localhost:8321/v1" # Adjust port if needed
|
||||
|
||||
@pytest.fixture
|
||||
def sample_documents(self) -> List[RAGDocument]:
|
||||
return [
|
||||
RAGDocument(
|
||||
document_id="doc1",
|
||||
content="Python is a high-level programming language.",
|
||||
metadata={"category": "programming", "difficulty": "beginner"},
|
||||
),
|
||||
RAGDocument(
|
||||
document_id="doc2",
|
||||
content="Machine learning is a subset of artificial intelligence.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
RAGDocument(
|
||||
document_id="doc3",
|
||||
content="Data structures are fundamental to computer science.",
|
||||
metadata={"category": "computer science", "difficulty": "intermediate"},
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_workflow(
|
||||
self, base_url: str, sample_documents: List[RAGDocument]
|
||||
):
|
||||
vector_db_payload = {
|
||||
"vector_db_id": "test_vector_db",
|
||||
"embedding_model": "all-MiniLM-L6-v2",
|
||||
"embedding_dimension": 384,
|
||||
}
|
||||
|
||||
response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload)
|
||||
assert response.status_code == 200
|
||||
vector_db = VectorDB(**response.json())
|
||||
|
||||
insert_payload = {
|
||||
"documents": [
|
||||
json.loads(doc.model_dump_json()) for doc in sample_documents
|
||||
],
|
||||
"vector_db_id": vector_db.identifier,
|
||||
"chunk_size_in_tokens": 512,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{base_url}/tool-runtime/rag-tool/insert-documents",
|
||||
json=insert_payload,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
query = "What is Python?"
|
||||
query_config = RAGQueryConfig(
|
||||
query_generator_config=DefaultRAGQueryGeneratorConfig(),
|
||||
max_tokens_in_context=4096,
|
||||
max_chunks=2,
|
||||
)
|
||||
|
||||
query_payload = {
|
||||
"content": query,
|
||||
"query_config": json.loads(query_config.model_dump_json()),
|
||||
"vector_db_ids": [vector_db.identifier],
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{base_url}/tool-runtime/rag-tool/query-context",
|
||||
json=query_payload,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
result = TypeAdapter(RAGQueryResult).validate_python(result)
|
||||
|
||||
content_str = interleaved_content_as_str(result.content)
|
||||
print(f"content: {content_str}")
|
||||
assert len(content_str) > 0
|
||||
assert "Python" in content_str
|
||||
|
||||
# Clean up: Delete the vector DB
|
||||
response = requests.delete(f"{base_url}/vector-dbs/{vector_db.identifier}")
|
||||
assert response.status_code == 200
|
|
@ -4,7 +4,7 @@ distribution_spec:
|
|||
providers:
|
||||
inference:
|
||||
- remote::together
|
||||
memory:
|
||||
vector_io:
|
||||
- inline::faiss
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
|
|
|
@ -5,7 +5,7 @@ apis:
|
|||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- memory
|
||||
- vector_io
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
|
@ -20,7 +20,7 @@ providers:
|
|||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
memory:
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
|
@ -145,7 +145,6 @@ models:
|
|||
model_type: embedding
|
||||
shields:
|
||||
- shield_id: meta-llama/Llama-Guard-3-8B
|
||||
memory_banks: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
eval_tasks: []
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue