forked from phoenix-oss/llama-stack-mirror
Move tool_runtime.memory -> tool_runtime.rag
This commit is contained in:
parent
f3d8864c36
commit
0bff6e1658
5 changed files with 7 additions and 7 deletions
|
@ -8,11 +8,11 @@ from typing import Any, Dict
|
|||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import MemoryToolRuntimeConfig
|
||||
from .config import RagToolRuntimeConfig
|
||||
from .memory import MemoryToolRuntimeImpl
|
||||
|
||||
|
||||
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
|
||||
async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -7,5 +7,5 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MemoryToolRuntimeConfig(BaseModel):
|
||||
class RagToolRuntimeConfig(BaseModel):
|
||||
pass
|
|
@ -32,7 +32,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
make_overlapped_chunks,
|
||||
)
|
||||
|
||||
from .config import MemoryToolRuntimeConfig
|
||||
from .config import RagToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -47,7 +47,7 @@ def make_random_string(length: int = 8):
|
|||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
config: MemoryToolRuntimeConfig,
|
||||
config: RagToolRuntimeConfig,
|
||||
vector_io_api: VectorIO,
|
||||
inference_api: Inference,
|
||||
):
|
|
@ -21,8 +21,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.tool_runtime,
|
||||
provider_type="inline::rag-runtime",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.tool_runtime.memory",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
|
||||
module="llama_stack.providers.inline.tool_runtime.rag",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
|
||||
api_dependencies=[Api.vector_io, Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue