Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 19:04:19 +00:00)
[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. Third part:

- We need `tool_runtime.rag_tool.query_context()` and `tool_runtime.rag_tool.insert_documents()` to work smoothly with complete type safety. To that end, we introduce a sub-resource path `tool-runtime/rag-tool/` and make changes to the resolver to make things work.
- The PR updates the agents implementation to call these typed APIs directly for memory accesses, rather than going through the complex, untyped `invoke_tool` API. The code is, expectedly, much nicer and simpler.
- There are still a number of hacks in the server resolver implementation; we will live with some and fix others.

Note that the client SDKs must be able to handle this sub-resource complexity as well. Stainless has support for sub-resources, so this should be possible, but beware.

## Test Plan

Our RAG test is weak (it does not actually check the RAG output), but I verified that the implementation works. I will fix the RAG test afterwards.

```bash
pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B
```
Parent: 78a481bb22 · Commit: 1a7490470a
33 changed files with 1648 additions and 1345 deletions
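For orientation, here is a minimal sketch of what the new typed sub-resource API enables. The `RAGDocument`, `RAGQueryConfig`, and `DefaultRAGQueryGeneratorConfig` names come from this PR; the surrounding setup function and how `tool_runtime` is obtained are my assumptions, not part of the diff:

```python
from llama_stack.apis.tools import (
    DefaultRAGQueryGeneratorConfig,
    RAGDocument,
    RAGQueryConfig,
)


# `tool_runtime` is assumed to be a resolved ToolRuntime implementation,
# e.g. obtained from a running stack's API registry.
async def index_and_query(tool_runtime, vector_db_id: str):
    # Index a document through the typed sub-resource, instead of the old
    # untyped invoke_tool("rag_tool.insert_documents", {...}) pattern.
    await tool_runtime.rag_tool.insert_documents(
        documents=[
            RAGDocument(
                document_id="doc-1",
                content="Llama Stack composes inference, memory and safety APIs.",
                mime_type="text/plain",
            )
        ],
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=512,
    )

    # Query for context with a fully typed config object.
    result = await tool_runtime.rag_tool.query_context(
        content="What does Llama Stack compose?",
        query_config=RAGQueryConfig(
            query_generator_config=DefaultRAGQueryGeneratorConfig(),
            max_tokens_in_context=4096,
            max_chunks=5,
        ),
        vector_db_ids=[vector_db_id],
    )
    return result.content
```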
```diff
@@ -172,10 +172,16 @@ def _get_endpoint_functions(
 def _get_defining_class(member_fn: str, derived_cls: type) -> type:
     "Find the class in which a member function is first defined in a class inheritance hierarchy."
+
+    # This import must be dynamic here
+    from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
+
     # iterate in reverse member resolution order to find most specific class first
     for cls in reversed(inspect.getmro(derived_cls)):
         for name, _ in inspect.getmembers(cls, inspect.isfunction):
             if name == member_fn:
+                # HACK ALERT
+                if cls == RAGToolRuntime:
+                    return ToolRuntime
                 return cls

     raise ValidationError(
```
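The hack above exists because the OpenAPI generator attributes each endpoint to the class that first defines it; with the new mixin, `query_context` and `insert_documents` would otherwise be attributed to `RAGToolRuntime` instead of being grouped under `ToolRuntime`. A self-contained sketch of the MRO-based lookup (class names here are illustrative, not from the PR):

```python
import inspect


class Base:
    def ping(self): ...


class Sub(Base):
    def extra(self): ...


def defining_class(member_fn: str, derived_cls: type) -> type:
    # Walk the MRO from least to most derived, so the first match is the
    # class that originally introduced the member.
    for cls in reversed(inspect.getmro(derived_cls)):
        for name, _ in inspect.getmembers(cls, inspect.isfunction):
            if name == member_fn:
                return cls
    raise ValueError(member_fn)


assert defining_class("ping", Sub) is Base   # inherited: found on Base first
assert defining_class("extra", Sub) is Sub   # introduced by Sub itself
```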
Two file diffs suppressed because they are too large.
```diff
@@ -5,3 +5,4 @@
 # the root directory of this source tree.

 from .tools import *  # noqa: F401 F403
+from .rag_tool import *  # noqa: F401 F403
```
llama_stack/apis/tools/rag_tool.py (new file, 95 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable

from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


@json_schema_type
class RAGDocument(BaseModel):
    document_id: str
    content: InterleavedContent | URL
    mime_type: str | None = None
    metadata: Dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class RAGQueryResult(BaseModel):
    content: Optional[InterleavedContent] = None


@json_schema_type
class RAGQueryGenerator(Enum):
    default = "default"
    llm = "llm"
    custom = "custom"


@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
    type: Literal["default"] = "default"
    separator: str = " "


@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
    type: Literal["llm"] = "llm"
    model: str
    template: str


RAGQueryGeneratorConfig = register_schema(
    Annotated[
        Union[
            DefaultRAGQueryGeneratorConfig,
            LLMRAGQueryGeneratorConfig,
        ],
        Field(discriminator="type"),
    ],
    name="RAGQueryGeneratorConfig",
)


@json_schema_type
class RAGQueryConfig(BaseModel):
    # This config defines how a query is generated using the messages
    # for memory bank retrieval.
    query_generator_config: RAGQueryGeneratorConfig = Field(
        default=DefaultRAGQueryGeneratorConfig()
    )
    max_tokens_in_context: int = 4096
    max_chunks: int = 5


@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
    @webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
    async def insert_documents(
        self,
        documents: List[RAGDocument],
        vector_db_id: str,
        chunk_size_in_tokens: int = 512,
    ) -> None:
        """Index documents so they can be used by the RAG system"""
        ...

    @webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
    async def query_context(
        self,
        content: InterleavedContent,
        query_config: RAGQueryConfig,
        vector_db_ids: List[str],
    ) -> RAGQueryResult:
        """Query the RAG system for context; typically invoked by the agent"""
        ...
```
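`RAGQueryGeneratorConfig` is a discriminated union keyed on the `type` field, so a serialized config round-trips to the right concrete class. A small sketch of that behavior, assuming `register_schema` returns the annotated type unchanged (which is how it is used above) and using pydantic v2's `TypeAdapter`, which is my choice of harness rather than anything in the PR:

```python
from pydantic import TypeAdapter

from llama_stack.apis.tools.rag_tool import (
    DefaultRAGQueryGeneratorConfig,
    LLMRAGQueryGeneratorConfig,
    RAGQueryGeneratorConfig,
)

adapter = TypeAdapter(RAGQueryGeneratorConfig)

# "type" selects the variant: "default" -> DefaultRAGQueryGeneratorConfig
cfg = adapter.validate_python({"type": "default", "separator": "\n"})
assert isinstance(cfg, DefaultRAGQueryGeneratorConfig)

# "llm" -> LLMRAGQueryGeneratorConfig, which requires model and template
# (the model name and template text here are illustrative)
cfg = adapter.validate_python(
    {"type": "llm", "model": "meta-llama/Llama-3.1-8B-Instruct", "template": "{{ messages }}"}
)
assert isinstance(cfg, LLMRAGQueryGeneratorConfig)
```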
```diff
@@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

+from .rag_tool import RAGToolRuntime
+

 @json_schema_type
 class ToolParameter(BaseModel):
@@ -130,11 +132,17 @@ class ToolGroups(Protocol):
     ...


+class SpecialToolGroup(Enum):
+    rag_tool = "rag_tool"
+
+
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
     tool_store: ToolStore

+    rag_tool: RAGToolRuntime
+
     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
@@ -143,7 +151,7 @@ class ToolRuntime(Protocol):

     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...
```
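The rename of `invoke_tool`'s second parameter from `args` to `kwargs` ripples through every provider and call site below. For callers that pass the dict by keyword, this is a breaking change; a sketch of the before/after (the tool name and payload are illustrative):

```python
# before this PR
result = await tool_runtime.invoke_tool(
    tool_name="web_search",
    args={"query": "llama stack"},
)

# after this PR
result = await tool_runtime.invoke_tool(
    tool_name="web_search",
    kwargs={"query": "llama stack"},
)
```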
```diff
@@ -333,6 +333,8 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config

+    # TODO: check compliance for special tool groups
+    # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
     if (
         not isinstance(provider_spec, AutoRoutedProviderSpec)
```
```diff
@@ -36,7 +36,14 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import ToolDef, ToolRuntime
+from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
+    ToolDef,
+    ToolRuntime,
+)
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
```
```diff
@@ -400,22 +407,55 @@ class EvalRouter(Eval):


 class ToolRuntimeRouter(ToolRuntime):
+    class RagToolImpl(RAGToolRuntime):
+        def __init__(
+            self,
+            routing_table: RoutingTable,
+        ) -> None:
+            self.routing_table = routing_table
+
+        async def query_context(
+            self,
+            content: InterleavedContent,
+            query_config: RAGQueryConfig,
+            vector_db_ids: List[str],
+        ) -> RAGQueryResult:
+            return await self.routing_table.get_provider_impl(
+                "rag_tool.query_context"
+            ).query_context(content, query_config, vector_db_ids)
+
+        async def insert_documents(
+            self,
+            documents: List[RAGDocument],
+            vector_db_id: str,
+            chunk_size_in_tokens: int = 512,
+        ) -> None:
+            return await self.routing_table.get_provider_impl(
+                "rag_tool.insert_documents"
+            ).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
+
     def __init__(
         self,
         routing_table: RoutingTable,
     ) -> None:
         self.routing_table = routing_table

+        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
+        # TODO: make sure rag_tool vs builtin::memory is correct everywhere
+        self.rag_tool = self.RagToolImpl(routing_table)
+        setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
+        setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
+
     async def initialize(self) -> None:
         pass

     async def shutdown(self) -> None:
         pass

-    async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
+    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
-            args=args,
+            kwargs=kwargs,
         )

     async def list_runtime_tools(
```
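The `setattr` calls above register the sub-resource methods under dotted names like `"rag_tool.query_context"`. Such names are not reachable with normal `obj.attr` syntax, but `getattr` with the full string finds them, which is presumably how the server dispatches the composed endpoint names produced by `get_all_api_endpoints()`. A minimal standalone illustration (not PR code):

```python
class Router:
    def __init__(self):
        # A dotted attribute name: invalid as a Python identifier,
        # but perfectly legal as an instance-dict key via setattr.
        setattr(self, "rag_tool.query_context", lambda: "routed!")


r = Router()
fn = getattr(r, "rag_tool.query_context")
assert fn() == "routed!"
# Note: `r.rag_tool.query_context` would fail in this toy example, since
# there is no `rag_tool` attribute, only the literal dotted key in r.__dict__.
```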
```diff
@@ -9,6 +9,8 @@ from typing import Dict, List

 from pydantic import BaseModel

+from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
+
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION

 from llama_stack.distribution.resolver import api_protocol_map
@@ -22,21 +24,39 @@ class ApiEndpoint(BaseModel):
     name: str


+def toolgroup_protocol_map():
+    return {
+        SpecialToolGroup.rag_tool: RAGToolRuntime,
+    }
+
+
 def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     apis = {}

     protocols = api_protocol_map()
+    toolgroup_protocols = toolgroup_protocol_map()
     for api, protocol in protocols.items():
         endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

+        # HACK ALERT
+        if api == Api.tool_runtime:
+            for tool_group in SpecialToolGroup:
+                sub_protocol = toolgroup_protocols[tool_group]
+                sub_protocol_methods = inspect.getmembers(
+                    sub_protocol, predicate=inspect.isfunction
+                )
+                for name, method in sub_protocol_methods:
+                    if not hasattr(method, "__webmethod__"):
+                        continue
+                    protocol_methods.append((f"{tool_group.value}.{name}", method))
+
         for name, method in protocol_methods:
             if not hasattr(method, "__webmethod__"):
                 continue

             webmethod = method.__webmethod__
             route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"

             if webmethod.method == "GET":
                 method = "get"
             elif webmethod.method == "DELETE":
```
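For the `rag_tool` group this loop yields method names such as `rag_tool.query_context` and `rag_tool.insert_documents`, while each route still comes from the method's own `@webmethod` decorator. A condensed sketch of the composition, assuming `LLAMA_STACK_API_VERSION` is `"v1"` (my assumption; the constant's value is not shown in this diff):

```python
# Illustrative only: mirrors the naming scheme used above.
tool_group_value = "rag_tool"
method_name = "query_context"
webmethod_route = "/tool-runtime/rag-tool/query-context"
api_version = "v1"  # assumed value of LLAMA_STACK_API_VERSION

endpoint_name = f"{tool_group_value}.{method_name}"
route = f"/{api_version}/{webmethod_route.lstrip('/')}"

assert endpoint_name == "rag_tool.query_context"
assert route == "/v1/tool-runtime/rag-tool/query-context"
```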
```diff
@@ -29,7 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
-from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.distribution.datatypes import StackRunConfig
@@ -62,6 +62,7 @@ class LlamaStack(
     Inspect,
     ToolGroups,
     ToolRuntime,
+    RAGToolRuntime,
 ):
     pass
```
```diff
@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):


 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v5"
+KEY_VERSION = "v6"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
```
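Since the version string is baked into every key via `KEY_FORMAT`, bumping `KEY_VERSION` effectively invalidates all previously persisted registry entries and forces re-registration; presumably this is needed here because the shapes of registered resources changed in this refactor (memory banks giving way to vector DBs).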
```diff
@@ -19,9 +19,8 @@ async def get_provider_impl(
     impl = MetaReferenceAgentsImpl(
         config,
         deps[Api.inference],
-        deps[Api.memory],
+        deps[Api.vector_io],
         deps[Api.safety],
-        deps[Api.memory_banks],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
     )
```
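Note that these dependencies are passed positionally, so their order must match the `MetaReferenceAgentsImpl` constructor shown later in this diff (`config, inference_api, vector_io_api, safety_api, tool_runtime_api, tool_groups_api`): `deps[Api.vector_io]` takes the slot previously occupied by `deps[Api.memory]`, and the `deps[Api.memory_banks]` argument is dropped.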
```diff
@@ -59,13 +59,18 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import (
+    DefaultRAGQueryGeneratorConfig,
+    RAGDocument,
+    RAGQueryConfig,
+    ToolGroups,
+    ToolRuntime,
+)
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing

 from .persistence import AgentPersistence
 from .safety import SafetyException, ShieldRunnerMixin

@@ -79,7 +84,7 @@ def make_random_string(length: int = 8):


 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "query_memory"
+MEMORY_QUERY_TOOL = "rag_tool.query_context"
 WEB_SEARCH_TOOL = "web_search"
 MEMORY_GROUP = "builtin::memory"
```
```diff
@@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin):
         agent_config: AgentConfig,
         tempdir: str,
         inference_api: Inference,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        vector_io_api: VectorIO,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config
         self.tempdir = tempdir
         self.inference_api = inference_api
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
         self.safety_api = safety_api
+        self.vector_io_api = vector_io_api
         self.storage = AgentPersistence(agent_id, persistence_store)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
```
```diff
@@ -370,24 +373,30 @@ class ChatAgent(ShieldRunnerMixin):
         documents: Optional[List[Document]] = None,
         toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> AsyncGenerator:
+        # TODO: simplify all of this code, it can be simpler
         toolgroup_args = {}
+        toolgroups = set()
         for toolgroup in self.agent_config.toolgroups:
             if isinstance(toolgroup, AgentToolGroupWithArgs):
+                toolgroups.add(toolgroup.name)
                 toolgroup_args[toolgroup.name] = toolgroup.args
+            else:
+                toolgroups.add(toolgroup)
         if toolgroups_for_turn:
             for toolgroup in toolgroups_for_turn:
                 if isinstance(toolgroup, AgentToolGroupWithArgs):
+                    toolgroups.add(toolgroup.name)
                     toolgroup_args[toolgroup.name] = toolgroup.args
+                else:
+                    toolgroups.add(toolgroup)

         tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
         if documents:
             await self.handle_documents(
                 session_id, documents, input_messages, tool_defs
             )
-        if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
-            memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
-            if memory_tool_group is None:
-                raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
+        if MEMORY_GROUP in toolgroups and len(input_messages) > 0:
             with tracing.span(MEMORY_QUERY_TOOL) as span:
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
```
```diff
@@ -398,17 +407,15 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-            query_args = {
-                "messages": [msg.content for msg in input_messages],
-                **toolgroup_args.get(memory_tool_group, {}),
-            }

+            args = toolgroup_args.get(MEMORY_GROUP, {})
+            vector_db_ids = args.get("vector_db_ids", [])
             session_info = await self.storage.get_session_info(session_id)

             # if the session has a memory bank id, let the memory tool use it
             if session_info.memory_bank_id:
-                if "memory_bank_ids" not in query_args:
-                    query_args["memory_bank_ids"] = []
-                query_args["memory_bank_ids"].append(session_info.memory_bank_id)
+                vector_db_ids.append(session_info.memory_bank_id)
+
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepProgressPayload(
@@ -425,10 +432,18 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-            result = await self.tool_runtime_api.invoke_tool(
-                tool_name=MEMORY_QUERY_TOOL,
-                args=query_args,
+            result = await self.tool_runtime_api.rag_tool.query_context(
+                content=concat_interleaved_content(
+                    [msg.content for msg in input_messages]
+                ),
+                query_config=RAGQueryConfig(
+                    query_generator_config=DefaultRAGQueryGeneratorConfig(),
+                    max_tokens_in_context=4096,
+                    max_chunks=5,
+                ),
+                vector_db_ids=vector_db_ids,
             )
+            retrieved_context = result.content

             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
@@ -449,7 +464,7 @@ class ChatAgent(ShieldRunnerMixin):
                             ToolResponse(
                                 call_id="",
                                 tool_name=MEMORY_QUERY_TOOL,
-                                content=result.content,
+                                content=retrieved_context or [],
                             )
                         ],
                     ),
@@ -459,13 +474,11 @@ class ChatAgent(ShieldRunnerMixin):
                 span.set_attribute(
                     "input", [m.model_dump_json() for m in input_messages]
                 )
-                span.set_attribute("output", result.content)
-                span.set_attribute("error_code", result.error_code)
-                span.set_attribute("error_message", result.error_message)
+                span.set_attribute("output", retrieved_context)
                 span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
-            if result.error_code == 0:
+            if retrieved_context:
                 last_message = input_messages[-1]
-                last_message.context = result.content
+                last_message.context = retrieved_context

         output_attachments = []

@@ -842,12 +855,13 @@ class ChatAgent(ShieldRunnerMixin):

         if session_info.memory_bank_id is None:
             bank_id = f"memory_bank_{session_id}"
-            await self.memory_banks_api.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+            # TODO: the semantic for registration is definitely not "creation"
+            # so we need to fix it if we expect the agent to create a new vector db
+            # for each session
+            await self.vector_io_api.register_vector_db(
+                vector_db_id=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
             )
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
@@ -858,9 +872,9 @@ class ChatAgent(ShieldRunnerMixin):
     async def add_to_session_memory_bank(
         self, session_id: str, data: List[Document]
     ) -> None:
-        bank_id = await self._ensure_memory_bank(session_id)
+        vector_db_id = await self._ensure_memory_bank(session_id)
         documents = [
-            MemoryBankDocument(
+            RAGDocument(
                 document_id=str(uuid.uuid4()),
                 content=a.content,
                 mime_type=a.mime_type,
@@ -868,9 +882,10 @@ class ChatAgent(ShieldRunnerMixin):
             )
             for a in data
         ]
-        await self.memory_api.insert_documents(
-            bank_id=bank_id,
+        await self.tool_runtime_api.rag_tool.insert_documents(
             documents=documents,
+            vector_db_id=vector_db_id,
+            chunk_size_in_tokens=512,
         )
@@ -955,7 +970,7 @@ async def execute_tool_call_maybe(

     result = await tool_runtime_api.invoke_tool(
         tool_name=name,
-        args=dict(
+        kwargs=dict(
             session_id=session_id,
             **tool_call_args,
         ),
```
```diff
@@ -26,10 +26,9 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl

 from .agent_instance import ChatAgent
@@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         config: MetaReferenceAgentsImplConfig,
         inference_api: Inference,
-        memory_api: Memory,
+        vector_io_api: VectorIO,
         safety_api: Safety,
-        memory_banks_api: MemoryBanks,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
     ):
         self.config = config
         self.inference_api = inference_api
-        self.memory_api = memory_api
+        self.vector_io_api = vector_io_api
         self.safety_api = safety_api
-        self.memory_banks_api = memory_banks_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api

@@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
             tempdir=self.tempdir,
             inference_api=self.inference_api,
             safety_api=self.safety_api,
-            memory_api=self.memory_api,
-            memory_banks_api=self.memory_banks_api,
+            vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
             persistence_store=(
```
```diff
@@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        script = args["code"]
+        script = kwargs["code"]
         req = CodeExecutionRequest(scripts=[script])
         res = self.code_executor.execute(req)
         pieces = [res["process_status"]]
```
```diff
@@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl


 async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
-    impl = MemoryToolRuntimeImpl(
-        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
-    )
+    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
     await impl.initialize()
     return impl
```
```diff
@@ -4,87 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
-from typing import Annotated, List, Literal, Union
-
-from pydantic import BaseModel, Field
-
-
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-
-class VectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-
-
-class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-
-
-class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-
-
-class GraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        VectorMemoryBankConfig,
-        KeyValueMemoryBankConfig,
-        KeywordMemoryBankConfig,
-        GraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryToolConfig(BaseModel):
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+from pydantic import BaseModel


 class MemoryToolRuntimeConfig(BaseModel):
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 5
+    pass
```
```diff
@@ -5,68 +5,64 @@
 # the root directory of this source tree.


-from typing import List
-
 from jinja2 import Template
-from pydantic import BaseModel

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.tools.rag_tool import (
+    DefaultRAGQueryGeneratorConfig,
+    LLMRAGQueryGeneratorConfig,
+    RAGQueryGenerator,
+    RAGQueryGeneratorConfig,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

-from .config import (
-    DefaultMemoryQueryGeneratorConfig,
-    LLMMemoryQueryGeneratorConfig,
-    MemoryQueryGenerator,
-    MemoryQueryGeneratorConfig,
-)
-

 async def generate_rag_query(
-    config: MemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: RAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     """
     Generates a query that will be used for
     retrieving relevant information from the memory bank.
     """
-    if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, messages, **kwargs)
-    elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, messages, **kwargs)
+    if config.type == RAGQueryGenerator.default.value:
+        query = await default_rag_query_generator(config, content, **kwargs)
+    elif config.type == RAGQueryGenerator.llm.value:
+        query = await llm_rag_query_generator(config, content, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query


 async def default_rag_query_generator(
-    config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: DefaultRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m) for m in messages)
+    return interleaved_content_as_str(content, sep=config.separator)


 async def llm_rag_query_generator(
-    config: LLMMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: LLMRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]

-    m_dict = {
-        "messages": [
-            message.model_dump() if isinstance(message, BaseModel) else message
-            for message in messages
-        ]
-    }
+    messages = []
+    if isinstance(content, list):
+        messages = [interleaved_content_as_str(m) for m in content]
+    else:
+        messages = [interleaved_content_as_str(content)]

     template = Template(config.template)
-    content = template.render(m_dict)
+    content = template.render({"messages": messages})

     model = config.model
     message = UserMessage(content=content)
```
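The LLM-based generator stringifies the interleaved content into a `messages` list and renders it through the configured Jinja2 template before asking the model to produce the search query. A standalone sketch of just the rendering step (the template text is illustrative, not the PR's):

```python
from jinja2 import Template

template = Template(
    "Summarize the following into a single search query:\n"
    "{% for m in messages %}- {{ m }}\n{% endfor %}"
)
prompt = template.render({"messages": ["What is RAG?", "How do I index documents?"]})
print(prompt)
# Summarize the following into a single search query:
# - What is RAG?
# - How do I index documents?
```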
```diff
@@ -10,20 +10,29 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional

-from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.memory import Memory, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
+from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
     ToolDef,
     ToolInvocationResult,
-    ToolParameter,
     ToolRuntime,
 )
+from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
+from llama_stack.providers.utils.memory.vector_store import (
+    content_from_doc,
+    make_overlapped_chunks,
+)

-from .config import MemoryToolConfig, MemoryToolRuntimeConfig
+from .config import MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query

 log = logging.getLogger(__name__)
```
```diff
@@ -35,65 +44,79 @@ def make_random_string(length: int = 8):
     )


-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
         config: MemoryToolRuntimeConfig,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
+        vector_io_api: VectorIO,
         inference_api: Inference,
     ):
         self.config = config
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
+        self.vector_io_api = vector_io_api
         self.inference_api = inference_api

     async def initialize(self):
         pass

-    async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
-    ) -> List[ToolDef]:
-        return [
-            ToolDef(
-                name="query_memory",
-                description="Retrieve context from memory",
-                parameters=[
-                    ToolParameter(
-                        name="messages",
-                        description="The input messages to search for",
-                        parameter_type="array",
-                    ),
-                ],
-            )
-        ]
+    async def shutdown(self):
+        pass

-    async def _retrieve_context(
-        self, input_messages: List[InterleavedContent], bank_ids: List[str]
-    ) -> Optional[List[InterleavedContent]]:
-        if not bank_ids:
-            return None
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_id: str,
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        chunks = []
+        for doc in documents:
+            content = await content_from_doc(doc)
+            chunks.extend(
+                make_overlapped_chunks(
+                    doc.document_id,
+                    content,
+                    chunk_size_in_tokens,
+                    chunk_size_in_tokens // 4,
+                )
+            )
+
+        if not chunks:
+            return
+
+        await self.vector_io_api.insert_chunks(
+            chunks=chunks,
+            vector_db_id=vector_db_id,
+        )
+
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        if not vector_db_ids:
+            return RAGQueryResult(content=None)
+
         query = await generate_rag_query(
-            self.config.query_generator_config,
-            input_messages,
+            query_config.query_generator_config,
+            content,
             inference_api=self.inference_api,
         )
         tasks = [
-            self.memory_api.query_documents(
-                bank_id=bank_id,
+            self.vector_io_api.query_chunks(
+                vector_db_id=vector_db_id,
                 query=query,
                 params={
-                    "max_chunks": self.config.max_chunks,
+                    "max_chunks": query_config.max_chunks,
                 },
             )
-            for bank_id in bank_ids
+            for vector_db_id in vector_db_ids
         ]
-        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+        results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
         chunks = [c for r in results for c in r.chunks]
         scores = [s for r in results for s in r.scores]

         if not chunks:
-            return None
+            return RAGQueryResult(content=None)

         # sort by score
         chunks, scores = zip(
```
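`insert_documents` chunks each document with an overlap of one quarter of the chunk size (512-token chunks overlap by 128 tokens), so consecutive chunks share context at their boundaries. A sketch of the stride arithmetic; this is my own illustration of the parameters passed above, while `make_overlapped_chunks`'s exact token accounting lives in `llama_stack.providers.utils.memory.vector_store`:

```python
chunk_size_in_tokens = 512
overlap = chunk_size_in_tokens // 4  # 128 tokens, as passed above

# With overlap, the window advances by (size - overlap) tokens each step.
stride = chunk_size_in_tokens - overlap  # 384
doc_tokens = 2000
starts = list(range(0, doc_tokens, stride))
print(starts)  # [0, 384, 768, 1152, 1536, 1920]
```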
```diff
@@ -102,45 +125,52 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):

         tokens = 0
         picked = []
-        for c in chunks[: self.config.max_chunks]:
-            tokens += c.token_count
-            if tokens > self.config.max_tokens_in_context:
+        for c in chunks[: query_config.max_chunks]:
+            metadata = c.metadata
+            tokens += metadata["token_count"]
+            if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
-            picked.append(f"id:{c.document_id}; content:{c.content}")
+            picked.append(
+                TextContentItem(
+                    text=f"id:{metadata['document_id']}; content:{c.content}",
+                )
+            )

-        return [
-            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-            *picked,
-            "\n=== END-RETRIEVED-CONTEXT ===\n",
-        ]
+        return RAGQueryResult(
+            content=[
+                TextContentItem(
+                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+                ),
+                *picked,
+                TextContentItem(
+                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
+                ),
+            ],
+        )
+
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        # Parameters are not listed since these methods are not yet invoked automatically
+        # by the LLM. The method is only implemented so things like /tools can list without
+        # encountering fatals.
+        return [
+            ToolDef(
+                name="rag_tool.query_context",
+                description="Retrieve context from memory",
+            ),
+            ToolDef(
+                name="rag_tool.insert_documents",
+                description="Insert documents into memory",
+            ),
+        ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        tool = await self.tool_store.get_tool(tool_name)
-        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
-        final_args = tool_group.args or {}
-        final_args.update(args)
-        config = MemoryToolConfig()
-        if tool.metadata and tool.metadata.get("config") is not None:
-            config = MemoryToolConfig(**tool.metadata["config"])
-        if "memory_bank_ids" in final_args:
-            bank_ids = final_args["memory_bank_ids"]
-        else:
-            bank_ids = [
-                bank_config.bank_id for bank_config in config.memory_bank_configs
-            ]
-        if "messages" not in final_args:
-            raise ValueError("messages are required")
-        context = await self._retrieve_context(
-            final_args["messages"],
-            bank_ids,
-        )
-        if context is None:
-            context = []
-        return ToolInvocationResult(
-            content=concat_interleaved_content(context), error_code=0
-        )
+        raise RuntimeError(
+            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
+        )
```
```diff
@@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.inline.tool_runtime.memory",
             config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
-            api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
+            api_dependencies=[Api.vector_io, Api.inference],
         ),
         InlineProviderSpec(
             api=Api.tool_runtime,
```
```diff
@@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         headers = {
@@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl(
             "count": self.config.top_k,
             "textDecorations": True,
             "textFormat": "HTML",
-            "q": args["query"],
+            "q": kwargs["query"],
         }

         response = requests.get(
```
```diff
@@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         url = "https://api.search.brave.com/res/v1/web/search"
@@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl(
             "Accept-Encoding": "gzip",
             "Accept": "application/json",
         }
-        payload = {"q": args["query"]}
+        payload = {"q": kwargs["query"]}
         response = requests.get(url=url, params=payload, headers=headers)
         response.raise_for_status()
         results = self._clean_brave_response(response.json())
```
```diff
@@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         return tools

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
         if tool.metadata is None or tool.metadata.get("endpoint") is None:
@@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         async with sse_client(endpoint) as streams:
             async with ClientSession(*streams) as session:
                 await session.initialize()
-                result = await session.call_tool(tool.identifier, args)
+                result = await session.call_tool(tool.identifier, kwargs)

         return ToolInvocationResult(
             content="\n".join([result.model_dump_json() for result in result.content]),
```
```diff
@@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         response = requests.post(
             "https://api.tavily.com/search",
-            json={"api_key": api_key, "query": args["query"]},
+            json={"api_key": api_key, "query": kwargs["query"]},
         )

         return ToolInvocationResult(
```
```diff
@@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl(
         ]

     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         params = {
-            "input": args["query"],
+            "input": kwargs["query"],
             "appid": api_key,
             "format": "plaintext",
             "output": "json",
```
@ -12,10 +12,10 @@ from ..conftest import (
|
||||||
get_test_config_for_api,
|
get_test_config_for_api,
|
||||||
)
|
)
|
||||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||||
from ..memory.fixtures import MEMORY_FIXTURES
|
|
||||||
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
||||||
|
|
||||||
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
|
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
|
||||||
|
from ..vector_io.fixtures import VECTOR_IO_FIXTURES
|
||||||
from .fixtures import AGENTS_FIXTURES
|
from .fixtures import AGENTS_FIXTURES
|
||||||
|
|
||||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
|
@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
{
|
{
|
||||||
"inference": "meta_reference",
|
"inference": "meta_reference",
|
||||||
"safety": "llama_guard",
|
"safety": "llama_guard",
|
||||||
"memory": "faiss",
|
"vector_io": "faiss",
|
||||||
"agents": "meta_reference",
|
"agents": "meta_reference",
|
||||||
"tool_runtime": "memory_and_search",
|
"tool_runtime": "memory_and_search",
|
||||||
},
|
},
|
||||||
|
@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
{
|
{
|
||||||
"inference": "ollama",
|
"inference": "ollama",
|
||||||
"safety": "llama_guard",
|
"safety": "llama_guard",
|
||||||
"memory": "faiss",
|
"vector_io": "faiss",
|
||||||
"agents": "meta_reference",
|
"agents": "meta_reference",
|
||||||
"tool_runtime": "memory_and_search",
|
"tool_runtime": "memory_and_search",
|
||||||
},
|
},
|
||||||
|
@@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "inference": "together",
             "safety": "llama_guard",
             # make this work with Weaviate which is what the together distro supports
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },
@@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "fireworks",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },
@@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "remote",
             "safety": "remote",
-            "memory": "remote",
+            "vector_io": "remote",
             "agents": "remote",
             "tool_runtime": "memory_and_search",
         },
@@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc):
     available_fixtures = {
         "inference": INFERENCE_FIXTURES,
         "safety": SAFETY_FIXTURES,
-        "memory": MEMORY_FIXTURES,
+        "vector_io": VECTOR_IO_FIXTURES,
         "agents": AGENTS_FIXTURES,
         "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }
@@ -69,7 +69,7 @@ async def agents_stack(

     providers = {}
     provider_data = {}
-    for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
+    for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if key == "inference":
@@ -118,7 +118,7 @@ async def agents_stack(
     )

     test_stack = await construct_stack_for_test(
-        [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
+        [Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
         providers,
         provider_data,
         models=models,
@@ -214,9 +214,11 @@ class TestAgents:
         turn_response = [
             chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
         ]

         assert len(turn_response) > 0
+        # FIXME: we need to check the content of the turn response and ensure
+        # RAG actually worked

     @pytest.mark.asyncio
     async def test_create_agent_turn_with_tavily_search(
         self, agents_stack, search_query_messages, common_params
@@ -8,13 +8,12 @@ import uuid

 import pytest

+from llama_stack.apis.tools import RAGDocument
+
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
 from llama_stack.apis.vector_io import QueryChunksResponse

-from llama_stack.providers.utils.memory.vector_store import (
-    make_overlapped_chunks,
-    MemoryBankDocument,
-)
+from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks

 # How to run this test:
 #
@@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import (
 @pytest.fixture(scope="session")
 def sample_chunks():
     docs = [
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc1",
             content="Python is a high-level programming language.",
             metadata={"category": "programming", "difficulty": "beginner"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc2",
             content="Machine learning is a subset of artificial intelligence.",
             metadata={"category": "AI", "difficulty": "advanced"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc3",
             content="Data structures are fundamental to computer science.",
             metadata={"category": "computer science", "difficulty": "intermediate"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc4",
             content="Neural networks are inspired by biological neural networks.",
             metadata={"category": "AI", "difficulty": "advanced"},
@@ -19,7 +19,6 @@ import numpy as np
 from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray

-from pydantic import BaseModel, Field
 from pypdf import PdfReader

 from llama_stack.apis.common.content_types import (
|
@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
|
||||||
TextContentItem,
|
TextContentItem,
|
||||||
URL,
|
URL,
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.tools import RAGDocument
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
@@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

 log = logging.getLogger(__name__)


-class MemoryBankDocument(BaseModel):
-    document_id: str
-    content: InterleavedContent | URL
-    mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
-
-
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
     pdf_bytes = io.BytesIO(data)
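The deleted `MemoryBankDocument` model lives on as `RAGDocument` in `llama_stack.apis.tools`, imported above. A minimal sketch of constructing one, mirroring the fixture usages elsewhere in this diff; the exact field set of the new model is assumed to match the deleted one:

```python
from llama_stack.apis.tools import RAGDocument

# Same shape the deleted model had: an id, inline text (or a URL), and
# free-form metadata; mime_type is assumed to remain optional.
doc = RAGDocument(
    document_id="doc1",
    content="Python is a high-level programming language.",
    metadata={"category": "programming", "difficulty": "beginner"},
)
```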
@@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> InterleavedContent:
     return ret


-async def content_from_doc(doc: MemoryBankDocument) -> str:
+async def content_from_doc(doc: RAGDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
             return content_from_data(doc.content.uri)
@@ -161,7 +153,13 @@ def make_overlapped_chunks(
         chunk = tokenizer.decode(toks)
         # chunk is a string
         chunks.append(
-            Chunk(content=chunk, token_count=len(toks), document_id=document_id)
+            Chunk(
+                content=chunk,
+                metadata={
+                    "token_count": len(toks),
+                    "document_id": document_id,
+                },
+            )
         )

     return chunks
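With this change, `token_count` and `document_id` ride along in the chunk's `metadata` dict instead of being dedicated `Chunk` fields, so consumers index into `metadata` to recover them. A small sketch under that assumption:

```python
from typing import List

from llama_stack.apis.vector_io import Chunk


def total_token_count(chunks: List[Chunk]) -> int:
    # Each chunk produced by make_overlapped_chunks carries its token count
    # and originating document id in the metadata dict.
    return sum(chunk.metadata["token_count"] for chunk in chunks)
```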
llama_stack/scripts/test_rag_via_curl.py (new file, 105 lines)

@@ -0,0 +1,105 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from typing import List
+
+import pytest
+import requests
+from pydantic import TypeAdapter
+
+from llama_stack.apis.tools import (
+    DefaultRAGQueryGeneratorConfig,
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+)
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str
+
+
+class TestRAGToolEndpoints:
+    @pytest.fixture
+    def base_url(self) -> str:
+        return "http://localhost:8321/v1"  # Adjust port if needed
+
+    @pytest.fixture
+    def sample_documents(self) -> List[RAGDocument]:
+        return [
+            RAGDocument(
+                document_id="doc1",
+                content="Python is a high-level programming language.",
+                metadata={"category": "programming", "difficulty": "beginner"},
+            ),
+            RAGDocument(
+                document_id="doc2",
+                content="Machine learning is a subset of artificial intelligence.",
+                metadata={"category": "AI", "difficulty": "advanced"},
+            ),
+            RAGDocument(
+                document_id="doc3",
+                content="Data structures are fundamental to computer science.",
+                metadata={"category": "computer science", "difficulty": "intermediate"},
+            ),
+        ]
+
+    @pytest.mark.asyncio
+    async def test_rag_workflow(
+        self, base_url: str, sample_documents: List[RAGDocument]
+    ):
+        vector_db_payload = {
+            "vector_db_id": "test_vector_db",
+            "embedding_model": "all-MiniLM-L6-v2",
+            "embedding_dimension": 384,
+        }
+
+        response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload)
+        assert response.status_code == 200
+        vector_db = VectorDB(**response.json())
+
+        insert_payload = {
+            "documents": [
+                json.loads(doc.model_dump_json()) for doc in sample_documents
+            ],
+            "vector_db_id": vector_db.identifier,
+            "chunk_size_in_tokens": 512,
+        }
+
+        response = requests.post(
+            f"{base_url}/tool-runtime/rag-tool/insert-documents",
+            json=insert_payload,
+        )
+        assert response.status_code == 200
+
+        query = "What is Python?"
+        query_config = RAGQueryConfig(
+            query_generator_config=DefaultRAGQueryGeneratorConfig(),
+            max_tokens_in_context=4096,
+            max_chunks=2,
+        )
+
+        query_payload = {
+            "content": query,
+            "query_config": json.loads(query_config.model_dump_json()),
+            "vector_db_ids": [vector_db.identifier],
+        }
+
+        response = requests.post(
+            f"{base_url}/tool-runtime/rag-tool/query-context",
+            json=query_payload,
+        )
+        assert response.status_code == 200
+        result = response.json()
+        result = TypeAdapter(RAGQueryResult).validate_python(result)
+
+        content_str = interleaved_content_as_str(result.content)
+        print(f"content: {content_str}")
+        assert len(content_str) > 0
+        assert "Python" in content_str
+
+        # Clean up: Delete the vector DB
+        response = requests.delete(f"{base_url}/vector-dbs/{vector_db.identifier}")
+        assert response.status_code == 200
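The script above drives the new sub-resource over raw HTTP; the same two endpoints map to the typed methods on the runtime (`insert-documents` → `rag_tool.insert_documents`, `query-context` → `rag_tool.query_context`). A rough in-process sketch, assuming an already-resolved `tool_runtime` implementation; the argument names simply mirror the JSON payloads above:

```python
from llama_stack.apis.tools import (
    DefaultRAGQueryGeneratorConfig,
    RAGDocument,
    RAGQueryConfig,
)


async def rag_roundtrip(tool_runtime, documents: List[RAGDocument], vector_db_id: str):
    # Typed counterpart of POST /tool-runtime/rag-tool/insert-documents
    await tool_runtime.rag_tool.insert_documents(
        documents=documents,
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=512,
    )
    # Typed counterpart of POST /tool-runtime/rag-tool/query-context
    return await tool_runtime.rag_tool.query_context(
        content="What is Python?",
        query_config=RAGQueryConfig(
            query_generator_config=DefaultRAGQueryGeneratorConfig(),
            max_tokens_in_context=4096,
            max_chunks=2,
        ),
        vector_db_ids=[vector_db_id],
    )
```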
@@ -4,7 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::together
-    memory:
+    vector_io:
     - inline::faiss
     - remote::chromadb
     - remote::pgvector
@@ -5,7 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
-- memory
+- vector_io
 - safety
 - scoring
 - telemetry
@@ -20,7 +20,7 @@ providers:
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
-  memory:
+  vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
     config:
@@ -145,7 +145,6 @@ models:
   model_type: embedding
 shields:
 - shield_id: meta-llama/Llama-Guard-3-8B
-memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []