Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-06 10:42:39 +00:00
Introduce RAGToolRuntime as a specialized sub-protocol
This commit is contained in:
parent 78a481bb22
commit 2f76de1643
16 changed files with 260 additions and 224 deletions
@@ -5,3 +5,4 @@
 # the root directory of this source tree.

 from .tools import *  # noqa: F401 F403
+from .rag_tool import *  # noqa: F401 F403
llama_stack/apis/tools/rag_tool.py (new file, 95 lines)
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, Protocol, runtime_checkable
+
+from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+
+
+@json_schema_type
+class RAGDocument(BaseModel):
+    document_id: str
+    content: InterleavedContent | URL
+    mime_type: str | None = None
+    metadata: Dict[str, Any] = Field(default_factory=dict)
+
+
+@json_schema_type
+class RAGQueryResult(BaseModel):
+    content: Optional[InterleavedContent] = None
+
+
+@json_schema_type
+class RAGQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+@json_schema_type
+class DefaultRAGQueryGeneratorConfig(BaseModel):
+    type: Literal["default"] = "default"
+    separator: str = " "
+
+
+@json_schema_type
+class LLMRAGQueryGeneratorConfig(BaseModel):
+    type: Literal["llm"] = "llm"
+    model: str
+    template: str
+
+
+RAGQueryGeneratorConfig = register_schema(
+    Annotated[
+        Union[
+            DefaultRAGQueryGeneratorConfig,
+            LLMRAGQueryGeneratorConfig,
+        ],
+        Field(discriminator="type"),
+    ],
+    name="RAGQueryGeneratorConfig",
+)
+
+
+@json_schema_type
+class RAGQueryConfig(BaseModel):
+    # This config defines how a query is generated using the messages
+    # for memory bank retrieval.
+    query_generator_config: RAGQueryGeneratorConfig = Field(
+        default=DefaultRAGQueryGeneratorConfig()
+    )
+    max_tokens_in_context: int = 4096
+    max_chunks: int = 5
+
+
+@runtime_checkable
+@trace_protocol
+class RAGToolRuntime(Protocol):
+    @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_ids: List[str],
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        """Index documents so they can be used by the RAG system"""
+        ...
+
+    @webmethod(route="/tool-runtime/rag-tool/query", method="POST")
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        """Query the RAG system for context; typically invoked by the agent"""
+        ...
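For orientation (an illustrative sketch, not part of this diff): the new sub-protocol is used like any other async API. The snippet assumes `rag_tool` is some implementation of RAGToolRuntime and that a vector DB named "my-db" is already registered; the document text and IDs are made up.

    docs = [
        RAGDocument(
            document_id="doc-1",
            # A plain string is a valid InterleavedContent.
            content="Llama Stack now exposes RAG as a specialized tool runtime.",
            metadata={"source": "notes"},
        )
    ]
    # Index the documents, then retrieve context for a query.
    await rag_tool.insert_documents(
        documents=docs, vector_db_ids=["my-db"], chunk_size_in_tokens=512
    )
    result = await rag_tool.query_context(
        content="How does Llama Stack expose RAG?",
        query_config=RAGQueryConfig(),  # defaults: default generator, 4096 tokens, 5 chunks
        vector_db_ids=["my-db"],
    )
    print(result.content)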
@@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+
+from .rag_tool import RAGToolRuntime


 @json_schema_type
 class ToolParameter(BaseModel):
@@ -130,11 +132,17 @@ class ToolGroups(Protocol):
         ...


+class SpecialToolGroups(Enum):
+    rag_tool = "rag_tool"
+
+
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
     tool_store: ToolStore

+    rag_tool: RAGToolRuntime
+
     # TODO: This needs to be renamed once the OpenAPI generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
@@ -143,7 +151,7 @@ class ToolRuntime(Protocol):

     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...
@@ -333,6 +333,8 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config

+    # TODO: check compliance for special tool groups
+    # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
     if (
         not isinstance(provider_spec, AutoRoutedProviderSpec)
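Context for the TODO above: RAGToolRuntime is declared @runtime_checkable, so a structural isinstance() check is one plausible building block for a sub-protocol compliance pass (a standalone sketch under that assumption; check_protocol_compliance's internals are not shown in this diff).

    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class Pinger(Protocol):
        def ping(self) -> str: ...

    class Impl:
        def ping(self) -> str:
            return "pong"

    # isinstance() on a runtime_checkable Protocol verifies method presence,
    # not signatures; a dedicated compliance check is therefore still useful.
    assert isinstance(Impl(), Pinger)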
@@ -406,6 +406,12 @@ class ToolRuntimeRouter(ToolRuntime):
     ) -> None:
         self.routing_table = routing_table

+        # TODO: this should be in sync with "get_all_api_endpoints()"
+        # TODO: make sure rag_tool vs builtin::memory is correct everywhere
+        self.rag_tool = self.routing_table.get_provider_impl("builtin::memory")
+        setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
+        setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
+
     async def initialize(self) -> None:
         pass
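A note on the setattr calls above (standalone illustration, not from the commit): setattr(obj, "a.b", value) does not create a nested attribute; it creates a single attribute whose name contains a literal dot, reachable only via getattr. That dotted spelling mirrors the endpoint names the server derives for sub-protocol methods (see the endpoint enumeration below).

    class Router:
        pass

    r = Router()
    setattr(r, "rag_tool.query_context", lambda: "ok")
    assert getattr(r, "rag_tool.query_context")() == "ok"
    # Plain r.rag_tool.query_context would raise AttributeError in this sketch;
    # in the router it works because self.rag_tool is also assigned directly.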
@@ -9,6 +9,8 @@ from typing import Dict, List

 from pydantic import BaseModel

+from llama_stack.apis.tools import SpecialToolGroups
+
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION

 from llama_stack.distribution.resolver import api_protocol_map
@@ -29,6 +31,15 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     for api, protocol in protocols.items():
         endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
+        if api == Api.tool_runtime:
+            for tool_group in SpecialToolGroups:
+                print(f"tool_group: {tool_group}")
+                sub_protocol = getattr(protocol, tool_group.value)
+                sub_protocol_methods = inspect.getmembers(
+                    sub_protocol, predicate=inspect.isfunction
+                )
+                for name, method in sub_protocol_methods:
+                    protocol_methods.append((f"{tool_group.value}.{name}", method))

         for name, method in protocol_methods:
             if not hasattr(method, "__webmethod__"):
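The upshot (derived from the @webmethod routes declared in rag_tool.py and tools.py above): the Api.tool_runtime endpoint table now also carries the sub-protocol methods under dotted names, roughly:

    # method name                  route (from @webmethod)
    # rag_tool.insert_documents -> POST /tool-runtime/rag-tool/insert
    # rag_tool.query_context    -> POST /tool-runtime/rag-tool/query
    # list_runtime_tools        -> GET  /tool-runtime/list-tools
    # invoke_tool               -> POST /tool-runtime/invoke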
@@ -19,9 +19,8 @@ async def get_provider_impl(
     impl = MetaReferenceAgentsImpl(
         config,
         deps[Api.inference],
-        deps[Api.memory],
+        deps[Api.vector_io],
         deps[Api.safety],
-        deps[Api.memory_banks],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
     )
@@ -59,13 +59,18 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import (
+    DefaultRAGQueryGeneratorConfig,
+    RAGDocument,
+    RAGQueryConfig,
+    ToolGroups,
+    ToolRuntime,
+)
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing

 from .persistence import AgentPersistence
 from .safety import SafetyException, ShieldRunnerMixin
@@ -79,7 +84,7 @@ def make_random_string(length: int = 8):


 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "query_memory"
+MEMORY_QUERY_TOOL = "rag_tool.query_context"
 WEB_SEARCH_TOOL = "web_search"
 MEMORY_GROUP = "builtin::memory"
@@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin):
         agent_config: AgentConfig,
         tempdir: str,
         inference_api: Inference,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        vector_io_api: VectorIO,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config
         self.tempdir = tempdir
         self.inference_api = inference_api
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
         self.safety_api = safety_api
+        self.vector_io_api = vector_io_api
         self.storage = AgentPersistence(agent_id, persistence_store)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -384,10 +387,7 @@ class ChatAgent(ShieldRunnerMixin):
             await self.handle_documents(
                 session_id, documents, input_messages, tool_defs
             )
-        if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
-            memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
-            if memory_tool_group is None:
-                raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
+        if "builtin::memory" in toolgroup_args and len(input_messages) > 0:
             with tracing.span(MEMORY_QUERY_TOOL) as span:
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
@@ -398,17 +398,15 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
-            query_args = {
-                "messages": [msg.content for msg in input_messages],
-                **toolgroup_args.get(memory_tool_group, {}),
-            }

+            args = toolgroup_args["builtin::memory"]
+            vector_db_ids = args.get("vector_db_ids", [])
             session_info = await self.storage.get_session_info(session_id)

             # if the session has a memory bank id, let the memory tool use it
             if session_info.memory_bank_id:
-                if "memory_bank_ids" not in query_args:
-                    query_args["memory_bank_ids"] = []
-                query_args["memory_bank_ids"].append(session_info.memory_bank_id)
+                vector_db_ids.append(session_info.memory_bank_id)

             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepProgressPayload(
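Concretely (values illustrative), the agent now expects the vector DBs to arrive through the builtin::memory toolgroup's args rather than through the old query_args dict:

    toolgroup_args = {
        "builtin::memory": {"vector_db_ids": ["my-db"]},
    }
    args = toolgroup_args["builtin::memory"]
    vector_db_ids = args.get("vector_db_ids", [])  # -> ["my-db"]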
@@ -425,9 +423,16 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
-            result = await self.tool_runtime_api.invoke_tool(
-                tool_name=MEMORY_QUERY_TOOL,
-                args=query_args,
+            retrieved_context = await self.tool_runtime_api.rag_tool.query_context(
+                content=concat_interleaved_content(
+                    [msg.content for msg in input_messages]
+                ),
+                query_config=RAGQueryConfig(
+                    query_generator_config=DefaultRAGQueryGeneratorConfig(),
+                    max_tokens_in_context=4096,
+                    max_chunks=5,
+                ),
+                vector_db_ids=vector_db_ids,
             )

             yield AgentTurnResponseStreamChunk(
@@ -449,7 +454,7 @@ class ChatAgent(ShieldRunnerMixin):
                         ToolResponse(
                             call_id="",
                             tool_name=MEMORY_QUERY_TOOL,
-                            content=result.content,
+                            content=retrieved_context or [],
                         )
                     ],
                 ),
@@ -459,13 +464,11 @@ class ChatAgent(ShieldRunnerMixin):
                 span.set_attribute(
                     "input", [m.model_dump_json() for m in input_messages]
                 )
-                span.set_attribute("output", result.content)
-                span.set_attribute("error_code", result.error_code)
-                span.set_attribute("error_message", result.error_message)
-                span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
-            if result.error_code == 0:
+                span.set_attribute("output", retrieved_context)
+                span.set_attribute("tool_name", "builtin::memory")
+            if retrieved_context:
                 last_message = input_messages[-1]
-                last_message.context = result.content
+                last_message.context = retrieved_context

         output_attachments = []
@@ -842,12 +845,13 @@ class ChatAgent(ShieldRunnerMixin):

         if session_info.memory_bank_id is None:
             bank_id = f"memory_bank_{session_id}"
-            await self.memory_banks_api.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+
+            # TODO: the semantic for registration is definitely not "creation"
+            # so we need to fix it if we expect the agent to create a new vector db
+            # for each session
+            await self.vector_io_api.register_vector_db(
+                vector_db_id=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
             )
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
@@ -860,7 +864,7 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> None:
         bank_id = await self._ensure_memory_bank(session_id)
         documents = [
-            MemoryBankDocument(
+            RAGDocument(
                 document_id=str(uuid.uuid4()),
                 content=a.content,
                 mime_type=a.mime_type,
@@ -868,9 +872,10 @@ class ChatAgent(ShieldRunnerMixin):
             )
             for a in data
         ]
-        await self.memory_api.insert_documents(
-            bank_id=bank_id,
+        await self.tool_runtime_api.rag_tool.insert_documents(
             documents=documents,
+            vector_db_ids=[bank_id],
+            chunk_size_in_tokens=512,
         )
@@ -26,10 +26,9 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl

 from .agent_instance import ChatAgent
@@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         config: MetaReferenceAgentsImplConfig,
         inference_api: Inference,
-        memory_api: Memory,
+        vector_io_api: VectorIO,
         safety_api: Safety,
-        memory_banks_api: MemoryBanks,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
     ):
         self.config = config
         self.inference_api = inference_api
-        self.memory_api = memory_api
+        self.vector_io_api = vector_io_api
         self.safety_api = safety_api
-        self.memory_banks_api = memory_banks_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
             tempdir=self.tempdir,
             inference_api=self.inference_api,
             safety_api=self.safety_api,
-            memory_api=self.memory_api,
-            memory_banks_api=self.memory_banks_api,
+            vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
             persistence_store=(
@@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl


 async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
-    impl = MemoryToolRuntimeImpl(
-        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
-    )
+    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
     await impl.initialize()
     return impl
@@ -4,87 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
-from typing import Annotated, List, Literal, Union
-
-from pydantic import BaseModel, Field
-
-
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-
-class VectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-
-
-class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-
-
-class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-
-
-class GraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        VectorMemoryBankConfig,
-        KeyValueMemoryBankConfig,
-        KeywordMemoryBankConfig,
-        GraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryToolConfig(BaseModel):
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+from pydantic import BaseModel


 class MemoryToolRuntimeConfig(BaseModel):
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 5
+    pass
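The effect of this rewrite: retrieval tuning moves out of provider startup config and into the per-request RAGQueryConfig introduced in rag_tool.py, so callers can vary it per query. Note also that the new RAGQueryGeneratorConfig union drops the custom generator variant, even though RAGQueryGenerator still enumerates "custom". A sketch of a per-request override (model id and template are hypothetical):

    query_config = RAGQueryConfig(
        query_generator_config=LLMRAGQueryGeneratorConfig(
            model="llama3.1-8b-instruct",  # hypothetical model id
            template="Rewrite as a search query: {messages}",  # hypothetical template
        ),
        max_tokens_in_context=2048,
        max_chunks=3,
    )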
@@ -47,7 +47,7 @@ async def default_rag_query_generator(
     messages: List[InterleavedContent],
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m) for m in messages)
+    return config.separator.join(interleaved_content_as_str(m) for m in messages)


 async def llm_rag_query_generator(
@@ -10,20 +10,25 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional

-from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.memory import Memory, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
+from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
     ToolDef,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
+from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content

-from .config import MemoryToolConfig, MemoryToolRuntimeConfig
+from .config import MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query

 log = logging.getLogger(__name__)
@@ -35,65 +40,61 @@ def make_random_string(length: int = 8):
     )


-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
         config: MemoryToolRuntimeConfig,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
+        vector_io_api: VectorIO,
         inference_api: Inference,
     ):
         self.config = config
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
+        self.vector_io_api = vector_io_api
         self.inference_api = inference_api

     async def initialize(self):
         pass

-    async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
-    ) -> List[ToolDef]:
-        return [
-            ToolDef(
-                name="query_memory",
-                description="Retrieve context from memory",
-                parameters=[
-                    ToolParameter(
-                        name="messages",
-                        description="The input messages to search for",
-                        parameter_type="array",
-                    ),
-                ],
-            )
-        ]
+    async def shutdown(self):
+        pass
+
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_ids: List[str],
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        pass
+
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        if not vector_db_ids:
+            return RAGQueryResult(content=None)

-    async def _retrieve_context(
-        self, input_messages: List[InterleavedContent], bank_ids: List[str]
-    ) -> Optional[List[InterleavedContent]]:
-        if not bank_ids:
-            return None
         query = await generate_rag_query(
-            self.config.query_generator_config,
-            input_messages,
+            query_config.query_generator_config,
+            content,
             inference_api=self.inference_api,
         )
         tasks = [
-            self.memory_api.query_documents(
-                bank_id=bank_id,
+            self.vector_io_api.query_chunks(
+                vector_db_id=vector_db_id,
                 query=query,
                 params={
-                    "max_chunks": self.config.max_chunks,
+                    "max_chunks": query_config.max_chunks,
                 },
             )
-            for bank_id in bank_ids
+            for vector_db_id in vector_db_ids
         ]
-        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+        results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
         chunks = [c for r in results for c in r.chunks]
         scores = [s for r in results for s in r.scores]

         if not chunks:
-            return None
+            return RAGQueryResult(content=None)

         # sort by score
         chunks, scores = zip(
@@ -102,45 +103,47 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):

         tokens = 0
         picked = []
-        for c in chunks[: self.config.max_chunks]:
+        for c in chunks[: query_config.max_chunks]:
             tokens += c.token_count
-            if tokens > self.config.max_tokens_in_context:
+            if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
             picked.append(f"id:{c.document_id}; content:{c.content}")

+        return RAGQueryResult(
+            content=[
+                TextContentItem(
+                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+                ),
+                *picked,
+                TextContentItem(
+                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
+                ),
+            ],
+        )
+
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        # Parameters are not listed since these methods are not yet invoked automatically
+        # by the LLM. The method is only implemented so things like /tools can list without
+        # encountering fatals.
         return [
-            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-            *picked,
-            "\n=== END-RETRIEVED-CONTEXT ===\n",
+            ToolDef(
+                name="rag_tool.query_context",
+                description="Retrieve context from memory",
+            ),
+            ToolDef(
+                name="rag_tool.insert_documents",
+                description="Insert documents into memory",
+            ),
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        tool = await self.tool_store.get_tool(tool_name)
-        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
-        final_args = tool_group.args or {}
-        final_args.update(args)
-        config = MemoryToolConfig()
-        if tool.metadata and tool.metadata.get("config") is not None:
-            config = MemoryToolConfig(**tool.metadata["config"])
-        if "memory_bank_ids" in final_args:
-            bank_ids = final_args["memory_bank_ids"]
-        else:
-            bank_ids = [
-                bank_config.bank_id for bank_config in config.memory_bank_configs
-            ]
-        if "messages" not in final_args:
-            raise ValueError("messages are required")
-        context = await self._retrieve_context(
-            final_args["messages"],
-            bank_ids,
-        )
-        if context is None:
-            context = []
-        return ToolInvocationResult(
-            content=concat_interleaved_content(context), error_code=0
+        raise RuntimeError(
+            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
         )
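A standalone restatement of the selection logic above (illustrative only; plain dicts stand in for the real chunk objects): results from every queried vector DB are pooled, sorted by score, then greedily packed until the token budget is exceeded.

    def select_chunks(chunks, scores, max_chunks=5, max_tokens=4096):
        # Rank chunks across all vector DBs, best score first.
        ranked = sorted(zip(chunks, scores), key=lambda pair: pair[1], reverse=True)
        picked, tokens = [], 0
        for chunk, _score in ranked[:max_chunks]:
            tokens += chunk["token_count"]
            if tokens > max_tokens:
                break  # the overflowing chunk is dropped, as in the code above
            picked.append(f"id:{chunk['document_id']}; content:{chunk['content']}")
        return picked

    picked = select_chunks(
        chunks=[
            {"document_id": "d1", "content": "alpha", "token_count": 4000},
            {"document_id": "d2", "content": "beta", "token_count": 200},
        ],
        scores=[0.4, 0.9],
    )
    # -> ["id:d2; content:beta"]  (d1 would blow the 4096-token budget)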
@@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.inline.tool_runtime.memory",
             config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
-            api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
+            api_dependencies=[Api.vector_io, Api.inference],
         ),
         InlineProviderSpec(
             api=Api.tool_runtime,
@@ -8,13 +8,12 @@ import uuid

 import pytest

+from llama_stack.apis.tools import RAGDocument
+
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
 from llama_stack.apis.vector_io import QueryChunksResponse

-from llama_stack.providers.utils.memory.vector_store import (
-    make_overlapped_chunks,
-    MemoryBankDocument,
-)
+from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks

 # How to run this test:
 #
@@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import (
 @pytest.fixture(scope="session")
 def sample_chunks():
     docs = [
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc1",
             content="Python is a high-level programming language.",
             metadata={"category": "programming", "difficulty": "beginner"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc2",
             content="Machine learning is a subset of artificial intelligence.",
             metadata={"category": "AI", "difficulty": "advanced"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc3",
             content="Data structures are fundamental to computer science.",
             metadata={"category": "computer science", "difficulty": "intermediate"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc4",
             content="Neural networks are inspired by biological neural networks.",
             metadata={"category": "AI", "difficulty": "advanced"},
@@ -19,7 +19,6 @@ import numpy as np
 from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray

-from pydantic import BaseModel, Field
 from pypdf import PdfReader

 from llama_stack.apis.common.content_types import (
@@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
     URL,
 )
+from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
 from llama_stack.providers.datatypes import Api
@@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )


 log = logging.getLogger(__name__)


-class MemoryBankDocument(BaseModel):
-    document_id: str
-    content: InterleavedContent | URL
-    mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
-
-
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
     pdf_bytes = io.BytesIO(data)
@@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> InterleavedContent:
     return ret


-async def content_from_doc(doc: MemoryBankDocument) -> str:
+async def content_from_doc(doc: RAGDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
             return content_from_data(doc.content.uri)