Move tool_runtime.memory -> tool_runtime.rag

Ashwin Bharambe 2025-01-22 20:25:02 -08:00
parent f3d8864c36
commit 0bff6e1658
5 changed files with 7 additions and 7 deletions

@@ -8,11 +8,11 @@ from typing import Any, Dict
 
 from llama_stack.providers.datatypes import Api
 
-from .config import MemoryToolRuntimeConfig
+from .config import RagToolRuntimeConfig
 from .memory import MemoryToolRuntimeImpl
 
 
-async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
     impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
     await impl.initialize()
     return impl

@@ -7,5 +7,5 @@
 from pydantic import BaseModel
 
 
-class MemoryToolRuntimeConfig(BaseModel):
+class RagToolRuntimeConfig(BaseModel):
     pass

@@ -32,7 +32,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     make_overlapped_chunks,
 )
 
-from .config import MemoryToolRuntimeConfig
+from .config import RagToolRuntimeConfig
 from .context_retriever import generate_rag_query
 
 log = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ def make_random_string(length: int = 8):
 class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
-        config: MemoryToolRuntimeConfig,
+        config: RagToolRuntimeConfig,
         vector_io_api: VectorIO,
         inference_api: Inference,
     ):

@@ -21,8 +21,8 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.tool_runtime,
             provider_type="inline::rag-runtime",
             pip_packages=[],
-            module="llama_stack.providers.inline.tool_runtime.memory",
-            config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
+            module="llama_stack.providers.inline.tool_runtime.rag",
+            config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
             api_dependencies=[Api.vector_io, Api.inference],
         ),
         InlineProviderSpec(
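
Context for the registry hunk above: `module` and `config_class` are dotted import paths that are resolved when the provider is loaded, which is why the rename has to touch these strings and not just the package directory. A minimal sketch of that kind of resolution, assuming an importlib-based loader (the function below is illustrative, not llama_stack's actual code):

import importlib
from typing import Any, Dict


async def load_inline_provider(
    module_path: str,        # e.g. "llama_stack.providers.inline.tool_runtime.rag"
    config_class_path: str,  # e.g. "...tool_runtime.rag.config.RagToolRuntimeConfig"
    config_data: Dict[str, Any],
    deps: Dict[str, Any],
):
    # Import the package named by `module`; its __init__ exposes
    # get_provider_impl(config, deps) as shown in the first hunk of this commit.
    provider_module = importlib.import_module(module_path)

    # Resolve the class named by `config_class` and build the provider config.
    cfg_module_path, cfg_class_name = config_class_path.rsplit(".", 1)
    cfg_cls = getattr(importlib.import_module(cfg_module_path), cfg_class_name)
    config = cfg_cls(**config_data)

    # Hand the resolved dependencies (here Api.vector_io and Api.inference)
    # to the provider's entry point.
    return await provider_module.get_provider_impl(config, deps)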