diff --git a/llama_stack/providers/inline/tool_runtime/memory/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py similarity index 77% rename from llama_stack/providers/inline/tool_runtime/memory/__init__.py rename to llama_stack/providers/inline/tool_runtime/rag/__init__.py index 42a0a6b01..542872091 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py @@ -8,11 +8,11 @@ from typing import Any, Dict from llama_stack.providers.datatypes import Api -from .config import MemoryToolRuntimeConfig +from .config import RagToolRuntimeConfig from .memory import MemoryToolRuntimeImpl -async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]): +async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]): impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/memory/config.py b/llama_stack/providers/inline/tool_runtime/rag/config.py similarity index 85% rename from llama_stack/providers/inline/tool_runtime/memory/config.py rename to llama_stack/providers/inline/tool_runtime/rag/config.py index 4a20c986c..2d0d2f595 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/config.py +++ b/llama_stack/providers/inline/tool_runtime/rag/config.py @@ -7,5 +7,5 @@ from pydantic import BaseModel -class MemoryToolRuntimeConfig(BaseModel): +class RagToolRuntimeConfig(BaseModel): pass diff --git a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py similarity index 100% rename from llama_stack/providers/inline/tool_runtime/memory/context_retriever.py rename to llama_stack/providers/inline/tool_runtime/rag/context_retriever.py diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py similarity index 98% rename from llama_stack/providers/inline/tool_runtime/memory/memory.py rename to llama_stack/providers/inline/tool_runtime/rag/memory.py index 7798ed711..9a2687925 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -32,7 +32,7 @@ from llama_stack.providers.utils.memory.vector_store import ( make_overlapped_chunks, ) -from .config import MemoryToolRuntimeConfig +from .config import RagToolRuntimeConfig from .context_retriever import generate_rag_query log = logging.getLogger(__name__) @@ -47,7 +47,7 @@ def make_random_string(length: int = 8): class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): def __init__( self, - config: MemoryToolRuntimeConfig, + config: RagToolRuntimeConfig, vector_io_api: VectorIO, inference_api: Inference, ): diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 927ca1886..33d880f30 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -21,8 +21,8 @@ def available_providers() -> List[ProviderSpec]: api=Api.tool_runtime, provider_type="inline::rag-runtime", pip_packages=[], - module="llama_stack.providers.inline.tool_runtime.memory", - config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig", + module="llama_stack.providers.inline.tool_runtime.rag", + config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig", api_dependencies=[Api.vector_io, Api.inference], ), InlineProviderSpec(