mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
chore: Rename RagTool FileSearchTool
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
4566eebe05
commit
2d9163529a
288 changed files with 16985 additions and 2071 deletions
|
|
@ -11,11 +11,11 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.tools import (
|
||||
FileSearchConfig,
|
||||
FileSearchResult,
|
||||
FileSearchToolRuntime,
|
||||
ListToolDefsResponse,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
RAGToolRuntime,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -26,21 +26,21 @@ logger = get_logger(name=__name__, category="core::routers")
|
|||
|
||||
|
||||
class ToolRuntimeRouter(ToolRuntime):
|
||||
class RagToolImpl(RAGToolRuntime):
|
||||
class FileSearchToolImpl(FileSearchToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: ToolGroupsRoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
logger.debug("Initializing ToolRuntimeRouter.FileSearchToolImpl")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
vector_store_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
|
||||
query_config: FileSearchConfig | None = None,
|
||||
) -> FileSearchResult:
|
||||
logger.debug(f"ToolRuntimeRouter.FileSearchToolImpl.query: {vector_store_ids}")
|
||||
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
||||
return await provider.query(content, vector_store_ids, query_config)
|
||||
|
||||
|
|
@ -51,7 +51,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
f"ToolRuntimeRouter.FileSearchToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
||||
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
|
||||
|
|
@ -64,7 +64,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self.routing_table = routing_table
|
||||
|
||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||
self.rag_tool = self.RagToolImpl(routing_table)
|
||||
self.rag_tool = self.FileSearchToolImpl(routing_table)
|
||||
for method in ("query", "insert"):
|
||||
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ logger = get_logger(name=__name__, category="core::routing_tables")
|
|||
|
||||
|
||||
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
|
||||
# handle the funny case like "builtin::rag/knowledge_search"
|
||||
# handle the funny case like "builtin::file_search/knowledge_search"
|
||||
parts = toolgroup_name_with_maybe_tool_name.split("/")
|
||||
if len(parts) == 2:
|
||||
return parts[0]
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from aiohttp import hdrs
|
|||
from starlette.routing import Route
|
||||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
from llama_stack.apis.tools import FileSearchToolRuntime, SpecialToolGroup
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
|
|||
|
||||
def toolgroup_protocol_map():
|
||||
return {
|
||||
SpecialToolGroup.rag_tool: RAGToolRuntime,
|
||||
SpecialToolGroup.rag_tool: FileSearchToolRuntime,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
|||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.tools import FileSearchToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
|
||||
|
|
@ -80,7 +80,7 @@ class LlamaStack(
|
|||
Inspect,
|
||||
ToolGroups,
|
||||
ToolRuntime,
|
||||
RAGToolRuntime,
|
||||
FileSearchToolRuntime,
|
||||
Files,
|
||||
Prompts,
|
||||
Conversations,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def main():
|
|||
|
||||
# Playground pages
|
||||
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
|
||||
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
|
||||
rag_page = st.Page("page/playground/file_search.py", title="RAG", icon="💬", default=False)
|
||||
tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)
|
||||
|
||||
# Distribution pages
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ def tool_chat_page():
|
|||
help="List of built-in tools from your llama stack server.",
|
||||
)
|
||||
|
||||
if "builtin::rag" in toolgroup_selection:
|
||||
if "builtin::file_search" in toolgroup_selection:
|
||||
vector_stores = llama_stack_api.client.vector_stores.list() or []
|
||||
if not vector_stores:
|
||||
st.info("No vector databases available for selection.")
|
||||
|
|
@ -115,9 +115,9 @@ def tool_chat_page():
|
|||
)
|
||||
|
||||
for i, tool_name in enumerate(toolgroup_selection):
|
||||
if tool_name == "builtin::rag":
|
||||
if tool_name == "builtin::file_search":
|
||||
tool_dict = dict(
|
||||
name="builtin::rag",
|
||||
name="builtin::file_search",
|
||||
args={
|
||||
"vector_store_ids": list(selected_vector_stores),
|
||||
},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue