Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 02:32:40 +00:00)

Introduce RAGToolRuntime as a specialized sub-protocol

Parent: 78a481bb22
Commit: 2f76de1643
16 changed files with 260 additions and 224 deletions
@@ -5,3 +5,4 @@
 # the root directory of this source tree.
 
 from .tools import *  # noqa: F401 F403
+from .rag_tool import *  # noqa: F401 F403
llama_stack/apis/tools/rag_tool.py (new file, 95 lines)
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, Protocol, runtime_checkable
+
+from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+
+
+@json_schema_type
+class RAGDocument(BaseModel):
+    document_id: str
+    content: InterleavedContent | URL
+    mime_type: str | None = None
+    metadata: Dict[str, Any] = Field(default_factory=dict)
+
+
+@json_schema_type
+class RAGQueryResult(BaseModel):
+    content: Optional[InterleavedContent] = None
+
+
+@json_schema_type
+class RAGQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+@json_schema_type
+class DefaultRAGQueryGeneratorConfig(BaseModel):
+    type: Literal["default"] = "default"
+    separator: str = " "
+
+
+@json_schema_type
+class LLMRAGQueryGeneratorConfig(BaseModel):
+    type: Literal["llm"] = "llm"
+    model: str
+    template: str
+
+
+RAGQueryGeneratorConfig = register_schema(
+    Annotated[
+        Union[
+            DefaultRAGQueryGeneratorConfig,
+            LLMRAGQueryGeneratorConfig,
+        ],
+        Field(discriminator="type"),
+    ],
+    name="RAGQueryGeneratorConfig",
+)
+
+
+@json_schema_type
+class RAGQueryConfig(BaseModel):
+    # This config defines how a query is generated using the messages
+    # for memory bank retrieval.
+    query_generator_config: RAGQueryGeneratorConfig = Field(
+        default=DefaultRAGQueryGeneratorConfig()
+    )
+    max_tokens_in_context: int = 4096
+    max_chunks: int = 5
+
+
+@runtime_checkable
+@trace_protocol
+class RAGToolRuntime(Protocol):
+    @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_ids: List[str],
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        """Index documents so they can be used by the RAG system"""
+        ...
+
+    @webmethod(route="/tool-runtime/rag-tool/query", method="POST")
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        """Query the RAG system for context; typically invoked by the agent"""
+        ...
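For orientation, here is a minimal usage sketch of the new protocol (not part of the commit). It assumes an object rag_tool implementing RAGToolRuntime and an already-registered vector DB; the id "my-vector-db" and document values are placeholders.

from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, RAGToolRuntime


async def index_and_query(rag_tool: RAGToolRuntime) -> None:
    # Index one document into a pre-registered vector DB (placeholder id).
    await rag_tool.insert_documents(
        documents=[
            RAGDocument(
                document_id="doc-1",
                content="Python is a high-level programming language.",
                mime_type="text/plain",
            )
        ],
        vector_db_ids=["my-vector-db"],
        chunk_size_in_tokens=512,
    )

    # Retrieve context; RAGQueryConfig() falls back to DefaultRAGQueryGeneratorConfig.
    result = await rag_tool.query_context(
        content="What is Python?",
        query_config=RAGQueryConfig(),
        vector_db_ids=["my-vector-db"],
    )
    print(result.content)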
@@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
+from .rag_tool import RAGToolRuntime
+
 
 @json_schema_type
 class ToolParameter(BaseModel):
@@ -130,11 +132,17 @@ class ToolGroups(Protocol):
         ...
 
 
+class SpecialToolGroups(Enum):
+    rag_tool = "rag_tool"
+
+
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
     tool_store: ToolStore
 
+    rag_tool: RAGToolRuntime
+
     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
@@ -143,7 +151,7 @@ class ToolRuntime(Protocol):
 
     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
    ) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...
@@ -333,6 +333,8 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config
 
+    # TODO: check compliance for special tool groups
+    # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
     if (
         not isinstance(provider_spec, AutoRoutedProviderSpec)
@@ -406,6 +406,12 @@ class ToolRuntimeRouter(ToolRuntime):
     ) -> None:
         self.routing_table = routing_table
 
+        # TODO: this should be in sync with "get_all_api_endpoints()"
+        # TODO: make sure rag_tool vs builtin::memory is correct everywhere
+        self.rag_tool = self.routing_table.get_provider_impl("builtin::memory")
+        setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
+        setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
+
     async def initialize(self) -> None:
         pass
 
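One detail worth flagging in the router change above: the setattr calls register the dotted strings "rag_tool.query_context" and "rag_tool.insert_documents" as literal attribute names, which a string-keyed dispatcher can later look up with getattr (normal dotted attribute syntax resolves through self.rag_tool instead). A tiny illustration of the setattr/getattr mechanics, not llama-stack code:

class Router:
    pass


r = Router()
setattr(r, "rag_tool.query_context", lambda: "ok")

# The dotted string is stored as a single key in r.__dict__, so it is read
# back with getattr using the exact same string.
print(getattr(r, "rag_tool.query_context")())  # prints: ok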
@@ -9,6 +9,8 @@ from typing import Dict, List
 
 from pydantic import BaseModel
 
+from llama_stack.apis.tools import SpecialToolGroups
+
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
 
 from llama_stack.distribution.resolver import api_protocol_map
@@ -29,6 +31,15 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     for api, protocol in protocols.items():
         endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
+        if api == Api.tool_runtime:
+            for tool_group in SpecialToolGroups:
+                print(f"tool_group: {tool_group}")
+                sub_protocol = getattr(protocol, tool_group.value)
+                sub_protocol_methods = inspect.getmembers(
+                    sub_protocol, predicate=inspect.isfunction
+                )
+                for name, method in sub_protocol_methods:
+                    protocol_methods.append((f"{tool_group.value}.{name}", method))
 
         for name, method in protocol_methods:
             if not hasattr(method, "__webmethod__"):
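To make the effect of the endpoint enumeration above concrete, a rough sketch of the (name, route) pairs it should now produce for the rag_tool sub-protocol, assuming the @webmethod routes declared in rag_tool.py; this is illustrative, not code from the commit:

# Hypothetical expected endpoints for Api.tool_runtime after this change; the
# route strings come from the @webmethod decorators on RAGToolRuntime.
expected_rag_endpoints = {
    "rag_tool.insert_documents": ("POST", "/tool-runtime/rag-tool/insert"),
    "rag_tool.query_context": ("POST", "/tool-runtime/rag-tool/query"),
}
for name, (http_method, route) in expected_rag_endpoints.items():
    print(f"{name}: {http_method} {route}")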
@@ -19,9 +19,8 @@ async def get_provider_impl(
     impl = MetaReferenceAgentsImpl(
         config,
         deps[Api.inference],
-        deps[Api.memory],
+        deps[Api.vector_io],
         deps[Api.safety],
-        deps[Api.memory_banks],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
     )
@@ -59,13 +59,18 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import (
+    DefaultRAGQueryGeneratorConfig,
+    RAGDocument,
+    RAGQueryConfig,
+    ToolGroups,
+    ToolRuntime,
+)
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
 from .safety import SafetyException, ShieldRunnerMixin
@@ -79,7 +84,7 @@ def make_random_string(length: int = 8):
 
 
 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "query_memory"
+MEMORY_QUERY_TOOL = "rag_tool.query_context"
 WEB_SEARCH_TOOL = "web_search"
 MEMORY_GROUP = "builtin::memory"
 
@@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin):
         agent_config: AgentConfig,
         tempdir: str,
         inference_api: Inference,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        vector_io_api: VectorIO,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config
         self.tempdir = tempdir
         self.inference_api = inference_api
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
         self.safety_api = safety_api
+        self.vector_io_api = vector_io_api
         self.storage = AgentPersistence(agent_id, persistence_store)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -384,10 +387,7 @@ class ChatAgent(ShieldRunnerMixin):
         await self.handle_documents(
             session_id, documents, input_messages, tool_defs
         )
-        if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
-            memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
-            if memory_tool_group is None:
-                raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
+        if "builtin::memory" in toolgroup_args and len(input_messages) > 0:
             with tracing.span(MEMORY_QUERY_TOOL) as span:
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
@@ -398,17 +398,15 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
-                query_args = {
-                    "messages": [msg.content for msg in input_messages],
-                    **toolgroup_args.get(memory_tool_group, {}),
-                }
-
+                args = toolgroup_args["builtin::memory"]
+                vector_db_ids = args.get("vector_db_ids", [])
                 session_info = await self.storage.get_session_info(session_id)
 
                 # if the session has a memory bank id, let the memory tool use it
                 if session_info.memory_bank_id:
-                    if "memory_bank_ids" not in query_args:
-                        query_args["memory_bank_ids"] = []
-                    query_args["memory_bank_ids"].append(session_info.memory_bank_id)
+                    vector_db_ids.append(session_info.memory_bank_id)
 
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
@@ -425,9 +423,16 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
-                result = await self.tool_runtime_api.invoke_tool(
-                    tool_name=MEMORY_QUERY_TOOL,
-                    args=query_args,
+                retrieved_context = await self.tool_runtime_api.rag_tool.query_context(
+                    content=concat_interleaved_content(
+                        [msg.content for msg in input_messages]
+                    ),
+                    query_config=RAGQueryConfig(
+                        query_generator_config=DefaultRAGQueryGeneratorConfig(),
+                        max_tokens_in_context=4096,
+                        max_chunks=5,
+                    ),
+                    vector_db_ids=vector_db_ids,
                 )
 
                 yield AgentTurnResponseStreamChunk(
@@ -449,7 +454,7 @@ class ChatAgent(ShieldRunnerMixin):
                             ToolResponse(
                                 call_id="",
                                 tool_name=MEMORY_QUERY_TOOL,
-                                content=result.content,
+                                content=retrieved_context or [],
                             )
                         ],
                     ),
@@ -459,13 +464,11 @@ class ChatAgent(ShieldRunnerMixin):
                 span.set_attribute(
                     "input", [m.model_dump_json() for m in input_messages]
                 )
-                span.set_attribute("output", result.content)
-                span.set_attribute("error_code", result.error_code)
-                span.set_attribute("error_message", result.error_message)
-                span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
-                if result.error_code == 0:
+                span.set_attribute("output", retrieved_context)
+                span.set_attribute("tool_name", "builtin::memory")
+                if retrieved_context:
                     last_message = input_messages[-1]
-                    last_message.context = result.content
+                    last_message.context = retrieved_context
 
         output_attachments = []
 
@@ -842,12 +845,13 @@ class ChatAgent(ShieldRunnerMixin):
 
         if session_info.memory_bank_id is None:
             bank_id = f"memory_bank_{session_id}"
-            await self.memory_banks_api.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+
+            # TODO: the semantic for registration is definitely not "creation"
+            # so we need to fix it if we expect the agent to create a new vector db
+            # for each session
+            await self.vector_io_api.register_vector_db(
+                vector_db_id=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
             )
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
@@ -860,7 +864,7 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> None:
         bank_id = await self._ensure_memory_bank(session_id)
         documents = [
-            MemoryBankDocument(
+            RAGDocument(
                 document_id=str(uuid.uuid4()),
                 content=a.content,
                 mime_type=a.mime_type,
@@ -868,9 +872,10 @@ class ChatAgent(ShieldRunnerMixin):
             )
             for a in data
         ]
-        await self.memory_api.insert_documents(
-            bank_id=bank_id,
+        await self.tool_runtime_api.rag_tool.insert_documents(
             documents=documents,
+            vector_db_ids=[bank_id],
+            chunk_size_in_tokens=512,
         )
 
@@ -26,10 +26,9 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 
 from .agent_instance import ChatAgent
@@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         config: MetaReferenceAgentsImplConfig,
         inference_api: Inference,
-        memory_api: Memory,
+        vector_io_api: VectorIO,
         safety_api: Safety,
-        memory_banks_api: MemoryBanks,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
     ):
         self.config = config
         self.inference_api = inference_api
-        self.memory_api = memory_api
+        self.vector_io_api = vector_io_api
         self.safety_api = safety_api
-        self.memory_banks_api = memory_banks_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
 
@@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
             tempdir=self.tempdir,
             inference_api=self.inference_api,
             safety_api=self.safety_api,
-            memory_api=self.memory_api,
-            memory_banks_api=self.memory_banks_api,
+            vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
             persistence_store=(
@@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl
 
 
 async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
-    impl = MemoryToolRuntimeImpl(
-        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
-    )
+    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
     await impl.initialize()
     return impl
@@ -4,87 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from enum import Enum
-from typing import Annotated, List, Literal, Union
-
-from pydantic import BaseModel, Field
-
-
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-
-class VectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-
-
-class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-
-
-class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-
-
-class GraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        VectorMemoryBankConfig,
-        KeyValueMemoryBankConfig,
-        KeywordMemoryBankConfig,
-        GraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryToolConfig(BaseModel):
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+from pydantic import BaseModel
 
 
 class MemoryToolRuntimeConfig(BaseModel):
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 5
+    pass
@@ -47,7 +47,7 @@ async def default_rag_query_generator(
     messages: List[InterleavedContent],
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m) for m in messages)
+    return config.separator.join(interleaved_content_as_str(m) for m in messages)
 
 
 async def llm_rag_query_generator(
@@ -10,20 +10,25 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional
 
-from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.memory import Memory, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
+from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
     ToolDef,
     ToolInvocationResult,
-    ToolParameter,
     ToolRuntime,
 )
+from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 
-from .config import MemoryToolConfig, MemoryToolRuntimeConfig
+from .config import MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query
 
 log = logging.getLogger(__name__)
@@ -35,65 +40,61 @@ def make_random_string(length: int = 8):
     )
 
 
-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
         config: MemoryToolRuntimeConfig,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
+        vector_io_api: VectorIO,
         inference_api: Inference,
     ):
         self.config = config
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
+        self.vector_io_api = vector_io_api
         self.inference_api = inference_api
 
     async def initialize(self):
         pass
 
-    async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
-    ) -> List[ToolDef]:
-        return [
-            ToolDef(
-                name="query_memory",
-                description="Retrieve context from memory",
-                parameters=[
-                    ToolParameter(
-                        name="messages",
-                        description="The input messages to search for",
-                        parameter_type="array",
-                    ),
-                ],
-            )
-        ]
+    async def shutdown(self):
+        pass
 
-    async def _retrieve_context(
-        self, input_messages: List[InterleavedContent], bank_ids: List[str]
-    ) -> Optional[List[InterleavedContent]]:
-        if not bank_ids:
-            return None
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_ids: List[str],
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        pass
+
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        if not vector_db_ids:
+            return RAGQueryResult(content=None)
+
         query = await generate_rag_query(
-            self.config.query_generator_config,
-            input_messages,
+            query_config.query_generator_config,
+            content,
             inference_api=self.inference_api,
         )
         tasks = [
-            self.memory_api.query_documents(
-                bank_id=bank_id,
+            self.vector_io_api.query_chunks(
+                vector_db_id=vector_db_id,
                 query=query,
                 params={
-                    "max_chunks": self.config.max_chunks,
+                    "max_chunks": query_config.max_chunks,
                 },
             )
-            for bank_id in bank_ids
+            for vector_db_id in vector_db_ids
         ]
-        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+        results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
         chunks = [c for r in results for c in r.chunks]
         scores = [s for r in results for s in r.scores]
 
         if not chunks:
-            return None
+            return RAGQueryResult(content=None)
 
         # sort by score
         chunks, scores = zip(
@@ -102,45 +103,47 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
 
         tokens = 0
         picked = []
-        for c in chunks[: self.config.max_chunks]:
+        for c in chunks[: query_config.max_chunks]:
             tokens += c.token_count
-            if tokens > self.config.max_tokens_in_context:
+            if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
             picked.append(f"id:{c.document_id}; content:{c.content}")
 
+        return RAGQueryResult(
+            content=[
+                TextContentItem(
+                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+                ),
+                *picked,
+                TextContentItem(
+                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
+                ),
+            ],
+        )
+
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        # Parameters are not listed since these methods are not yet invoked automatically
+        # by the LLM. The method is only implemented so things like /tools can list without
+        # encountering fatals.
         return [
-            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-            *picked,
-            "\n=== END-RETRIEVED-CONTEXT ===\n",
+            ToolDef(
+                name="rag_tool.query_context",
+                description="Retrieve context from memory",
+            ),
+            ToolDef(
+                name="rag_tool.insert_documents",
+                description="Insert documents into memory",
+            ),
         ]
 
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        tool = await self.tool_store.get_tool(tool_name)
-        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
-        final_args = tool_group.args or {}
-        final_args.update(args)
-        config = MemoryToolConfig()
-        if tool.metadata and tool.metadata.get("config") is not None:
-            config = MemoryToolConfig(**tool.metadata["config"])
-        if "memory_bank_ids" in final_args:
-            bank_ids = final_args["memory_bank_ids"]
-        else:
-            bank_ids = [
-                bank_config.bank_id for bank_config in config.memory_bank_configs
-            ]
-        if "messages" not in final_args:
-            raise ValueError("messages are required")
-        context = await self._retrieve_context(
-            final_args["messages"],
-            bank_ids,
-        )
-        if context is None:
-            context = []
-        return ToolInvocationResult(
-            content=concat_interleaved_content(context), error_code=0
+        raise RuntimeError(
+            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
         )
@@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.inline.tool_runtime.memory",
             config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
-            api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
+            api_dependencies=[Api.vector_io, Api.inference],
         ),
         InlineProviderSpec(
             api=Api.tool_runtime,
@@ -8,13 +8,12 @@ import uuid
 
 import pytest
 
+from llama_stack.apis.tools import RAGDocument
+
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
 from llama_stack.apis.vector_io import QueryChunksResponse
 
-from llama_stack.providers.utils.memory.vector_store import (
-    make_overlapped_chunks,
-    MemoryBankDocument,
-)
+from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
 
 # How to run this test:
 #
@@ -26,22 +25,22 @@
 @pytest.fixture(scope="session")
 def sample_chunks():
     docs = [
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc1",
             content="Python is a high-level programming language.",
             metadata={"category": "programming", "difficulty": "beginner"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc2",
             content="Machine learning is a subset of artificial intelligence.",
             metadata={"category": "AI", "difficulty": "advanced"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc3",
             content="Data structures are fundamental to computer science.",
             metadata={"category": "computer science", "difficulty": "intermediate"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc4",
             content="Neural networks are inspired by biological neural networks.",
             metadata={"category": "AI", "difficulty": "advanced"},
@@ -19,7 +19,6 @@ import numpy as np
 from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray
 
-from pydantic import BaseModel, Field
 from pypdf import PdfReader
 
 from llama_stack.apis.common.content_types import (
@@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
     URL,
 )
+from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
 from llama_stack.providers.datatypes import Api
@@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 
 
 log = logging.getLogger(__name__)
 
 
-class MemoryBankDocument(BaseModel):
-    document_id: str
-    content: InterleavedContent | URL
-    mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
-
-
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
     pdf_bytes = io.BytesIO(data)
@@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> InterleavedContent:
     return ret
 
 
-async def content_from_doc(doc: MemoryBankDocument) -> str:
+async def content_from_doc(doc: RAGDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
             return content_from_data(doc.content.uri)