forked from phoenix-oss/llama-stack-mirror
[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. Third part: - we need to make `tool_runtime.rag_tool.query_context()` and `tool_runtime.rag_tool.insert_documents()` methods work smoothly with complete type safety. To that end, we introduce a sub-resource path `tool-runtime/rag-tool/` and make changes to the resolver to make things work. - the PR updates the agents implementation to directly call these typed APIs for memory accesses rather than going through the complex, untyped "invoke_tool" API. the code looks much nicer and simpler (expectedly.) - there are a number of hacks in the server resolver implementation still, we will live with some and fix some Note that we must make sure the client SDKs are able to handle this subresource complexity also. Stainless has support for subresources, so this should be possible but beware. ## Test Plan Our RAG test is sad (doesn't actually test for actual RAG output) but I verified that the implementation works. I will work on fixing the RAG test afterwards. ```bash pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B ```
This commit is contained in:
parent
78a481bb22
commit
1a7490470a
33 changed files with 1648 additions and 1345 deletions
95
llama_stack/apis/tools/rag_tool.py
Normal file
95
llama_stack/apis/tools/rag_tool.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated, Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGDocument(BaseModel):
|
||||
document_id: str
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str | None = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryResult(BaseModel):
|
||||
content: Optional[InterleavedContent] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryGenerator(Enum):
|
||||
default = "default"
|
||||
llm = "llm"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DefaultRAGQueryGeneratorConfig(BaseModel):
|
||||
type: Literal["default"] = "default"
|
||||
separator: str = " "
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LLMRAGQueryGeneratorConfig(BaseModel):
|
||||
type: Literal["llm"] = "llm"
|
||||
model: str
|
||||
template: str
|
||||
|
||||
|
||||
RAGQueryGeneratorConfig = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="RAGQueryGeneratorConfig",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGQueryConfig(BaseModel):
|
||||
# This config defines how a query is generated using the messages
|
||||
# for memory bank retrieval.
|
||||
query_generator_config: RAGQueryGeneratorConfig = Field(
|
||||
default=DefaultRAGQueryGeneratorConfig()
|
||||
)
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 5
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class RAGToolRuntime(Protocol):
|
||||
@webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
|
||||
async def insert_documents(
|
||||
self,
|
||||
documents: List[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
"""Index documents so they can be used by the RAG system"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
|
||||
async def query_context(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
query_config: RAGQueryConfig,
|
||||
vector_db_ids: List[str],
|
||||
) -> RAGQueryResult:
|
||||
"""Query the RAG system for context; typically invoked by the agent"""
|
||||
...
|
Loading…
Add table
Add a link
Reference in a new issue