diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 63d0920fb..20cb8f828 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import MemoryBank
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -133,8 +132,6 @@ class Session(BaseModel):
     turns: List[Turn]
     started_at: datetime
 
-    memory_bank: Optional[MemoryBank] = None
-
 
 class AgentToolGroupWithArgs(BaseModel):
     name: str
diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py
index 52c429a2b..ccc395b80 100644
--- a/llama_stack/apis/datatypes.py
+++ b/llama_stack/apis/datatypes.py
@@ -14,7 +14,7 @@ class Api(Enum):
     inference = "inference"
     safety = "safety"
     agents = "agents"
-    memory = "memory"
+    vector_io = "vector_io"
     datasetio = "datasetio"
     scoring = "scoring"
     eval = "eval"
@@ -25,7 +25,7 @@ class Api(Enum):
 
     models = "models"
     shields = "shields"
-    memory_banks = "memory_banks"
+    vector_dbs = "vector_dbs"
     datasets = "datasets"
     scoring_functions = "scoring_functions"
     eval_tasks = "eval_tasks"
diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
deleted file mode 100644
index ec8ba824b..000000000
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ /dev/null
@@ -1,161 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import Enum
-from typing import (
-    Annotated,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    runtime_checkable,
-    Union,
-)
-
-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
-from pydantic import BaseModel, Field
-
-from llama_stack.apis.resource import Resource, ResourceType
-from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
-
-
-@json_schema_type
-class MemoryBankType(Enum):
-    vector = "vector"
-    keyvalue = "keyvalue"
-    keyword = "keyword"
-    graph = "graph"
-
-
-# define params for each type of memory bank, this leads to a tagged union
-# accepted as input from the API or from the config.
-@json_schema_type
-class VectorMemoryBankParams(BaseModel):
-    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
-    embedding_model: str
-    chunk_size_in_tokens: int
-    overlap_size_in_tokens: Optional[int] = None
-
-
-@json_schema_type
-class KeyValueMemoryBankParams(BaseModel):
-    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
-        MemoryBankType.keyvalue.value
-    )
-
-
-@json_schema_type
-class KeywordMemoryBankParams(BaseModel):
-    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
-        MemoryBankType.keyword.value
-    )
-
-
-@json_schema_type
-class GraphMemoryBankParams(BaseModel):
-    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
-
-
-BankParams = Annotated[
-    Union[
-        VectorMemoryBankParams,
-        KeyValueMemoryBankParams,
-        KeywordMemoryBankParams,
-        GraphMemoryBankParams,
-    ],
-    Field(discriminator="memory_bank_type"),
-]
-
-
-# Some common functionality for memory banks.
-class MemoryBankResourceMixin(Resource):
-    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
-
-    @property
-    def memory_bank_id(self) -> str:
-        return self.identifier
-
-    @property
-    def provider_memory_bank_id(self) -> str:
-        return self.provider_resource_id
-
-
-@json_schema_type
-class VectorMemoryBank(MemoryBankResourceMixin):
-    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
-    embedding_model: str
-    chunk_size_in_tokens: int
-    embedding_dimension: Optional[int] = 384  # default to minilm-l6-v2
-    overlap_size_in_tokens: Optional[int] = None
-
-
-@json_schema_type
-class KeyValueMemoryBank(MemoryBankResourceMixin):
-    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
-        MemoryBankType.keyvalue.value
-    )
-
-
-# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
-@json_schema_type
-class KeywordMemoryBank(MemoryBankResourceMixin):
-    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
-        MemoryBankType.keyword.value
-    )
-
-
-@json_schema_type
-class GraphMemoryBank(MemoryBankResourceMixin):
-    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
-
-
-MemoryBank = register_schema(
-    Annotated[
-        Union[
-            VectorMemoryBank,
-            KeyValueMemoryBank,
-            KeywordMemoryBank,
-            GraphMemoryBank,
-        ],
-        Field(discriminator="memory_bank_type"),
-    ],
-    name="MemoryBank",
-)
-
-
-class MemoryBankInput(BaseModel):
-    memory_bank_id: str
-    params: BankParams
-    provider_memory_bank_id: Optional[str] = None
-
-
-class ListMemoryBanksResponse(BaseModel):
-    data: List[MemoryBank]
-
-
-@runtime_checkable
-@trace_protocol
-class MemoryBanks(Protocol):
-    @webmethod(route="/memory-banks", method="GET")
-    async def list_memory_banks(self) -> ListMemoryBanksResponse: ...
-
-    @webmethod(route="/memory-banks/{memory_bank_id}", method="GET")
-    async def get_memory_bank(
-        self,
-        memory_bank_id: str,
-    ) -> Optional[MemoryBank]: ...
-
-    @webmethod(route="/memory-banks", method="POST")
-    async def register_memory_bank(
-        self,
-        memory_bank_id: str,
-        params: BankParams,
-        provider_id: Optional[str] = None,
-        provider_memory_bank_id: Optional[str] = None,
-    ) -> MemoryBank: ...
-
-    @webmethod(route="/memory-banks/{memory_bank_id}", method="DELETE")
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py
index a85f5a31c..dfe3ddb24 100644
--- a/llama_stack/apis/resource.py
+++ b/llama_stack/apis/resource.py
@@ -14,7 +14,7 @@ from pydantic import BaseModel, Field
 class ResourceType(Enum):
     model = "model"
     shield = "shield"
-    memory_bank = "memory_bank"
+    vector_db = "vector_db"
     dataset = "dataset"
     scoring_function = "scoring_function"
     eval_task = "eval_task"
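A note on the `ResourceType` change above: the enum value doubles as the pydantic discriminator that each resource subclass pins with a `Literal`. A minimal, self-contained sketch of the pattern (these classes mirror, but do not import, the ones in this diff):

```python
# Sketch of the Literal-tag pattern used by resource subclasses in this diff.
from enum import Enum
from typing import Literal, Optional

from pydantic import BaseModel


class ResourceType(Enum):
    vector_db = "vector_db"


class Resource(BaseModel):
    identifier: str
    provider_id: Optional[str] = None
    provider_resource_id: Optional[str] = None


class VectorDB(Resource):
    # the Literal both defaults the field and rejects any other tag value
    type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value
    embedding_model: str
    embedding_dimension: int


db = VectorDB(identifier="docs", embedding_model="all-MiniLM-L6-v2", embedding_dimension=384)
assert db.type == "vector_db"
```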
diff --git a/llama_stack/apis/memory_banks/__init__.py b/llama_stack/apis/vector_dbs/__init__.py
similarity index 81%
rename from llama_stack/apis/memory_banks/__init__.py
rename to llama_stack/apis/vector_dbs/__init__.py
index 7511677ab..158241a6d 100644
--- a/llama_stack/apis/memory_banks/__init__.py
+++ b/llama_stack/apis/vector_dbs/__init__.py
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .memory_banks import *  # noqa: F401 F403
+from .vector_dbs import *  # noqa: F401 F403
diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py
new file mode 100644
index 000000000..4b782e2d5
--- /dev/null
+++ b/llama_stack/apis/vector_dbs/vector_dbs.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List, Literal, Optional, Protocol, runtime_checkable
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel
+
+from llama_stack.apis.resource import Resource, ResourceType
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+
+
+@json_schema_type
+class VectorDB(Resource):
+    type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value
+
+    embedding_model: str
+    embedding_dimension: int
+
+    @property
+    def vector_db_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_vector_db_id(self) -> str:
+        return self.provider_resource_id
+
+
+class VectorDBInput(BaseModel):
+    vector_db_id: str
+    embedding_model: str
+    embedding_dimension: int
+    provider_vector_db_id: Optional[str] = None
+
+
+class ListVectorDBsResponse(BaseModel):
+    data: List[VectorDB]
+
+
+@runtime_checkable
+@trace_protocol
+class VectorDBs(Protocol):
+    @webmethod(route="/vector-dbs", method="GET")
+    async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
+
+    @webmethod(route="/vector-dbs/{vector_db_id}", method="GET")
+    async def get_vector_db(
+        self,
+        vector_db_id: str,
+    ) -> Optional[VectorDB]: ...
+
+    @webmethod(route="/vector-dbs", method="POST")
+    async def register_vector_db(
+        self,
+        vector_db_id: str,
+        embedding_model: str,
+        embedding_dimension: Optional[int] = 384,
+        provider_id: Optional[str] = None,
+        provider_vector_db_id: Optional[str] = None,
+    ) -> VectorDB: ...
+
+    @webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE")
+    async def unregister_vector_db(self, vector_db_id: str) -> None: ...
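For orientation, here is how a caller might exercise the routes declared by the new `VectorDBs` protocol. The route paths, methods, and parameter names come straight from the `@webmethod` decorators above; the server address and the exact JSON wire format are assumptions in this sketch:

```python
# Hedged sketch: exercising the /vector-dbs routes over plain HTTP.
import requests

BASE = "http://localhost:5000"  # assumed server address/port

# POST /vector-dbs  ->  register_vector_db(...)
resp = requests.post(
    f"{BASE}/vector-dbs",
    json={
        "vector_db_id": "my-docs",
        "embedding_model": "all-MiniLM-L6-v2",
        "embedding_dimension": 384,  # optional; the protocol defaults to 384
    },
)
vector_db = resp.json()

# GET /vector-dbs  ->  list_vector_dbs()
print(requests.get(f"{BASE}/vector-dbs").json()["data"])

# DELETE /vector-dbs/{vector_db_id}  ->  unregister_vector_db(...)
requests.delete(f"{BASE}/vector-dbs/my-docs")
```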
diff --git a/llama_stack/apis/memory/__init__.py b/llama_stack/apis/vector_io/__init__.py
similarity index 82%
rename from llama_stack/apis/memory/__init__.py
rename to llama_stack/apis/vector_io/__init__.py
index 260862228..3fe4fa4b6 100644
--- a/llama_stack/apis/memory/__init__.py
+++ b/llama_stack/apis/vector_io/__init__.py
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .memory import *  # noqa: F401 F403
+from .vector_io import *  # noqa: F401 F403
diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/vector_io/vector_io.py
similarity index 61%
rename from llama_stack/apis/memory/memory.py
rename to llama_stack/apis/vector_io/vector_io.py
index 6e6fcf697..5371b8918 100644
--- a/llama_stack/apis/memory/memory.py
+++ b/llama_stack/apis/vector_io/vector_io.py
@@ -13,55 +13,45 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 
-from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.memory_banks import MemoryBank
+from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
 
-@json_schema_type
-class MemoryBankDocument(BaseModel):
-    document_id: str
-    content: InterleavedContent | URL
-    mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
-
-
 class Chunk(BaseModel):
     content: InterleavedContent
-    token_count: int
-    document_id: str
+    metadata: Dict[str, Any] = Field(default_factory=dict)
 
 
 @json_schema_type
-class QueryDocumentsResponse(BaseModel):
+class QueryChunksResponse(BaseModel):
     chunks: List[Chunk]
     scores: List[float]
 
 
-class MemoryBankStore(Protocol):
-    def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
+class VectorDBStore(Protocol):
+    def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
 
 
 @runtime_checkable
 @trace_protocol
-class Memory(Protocol):
-    memory_bank_store: MemoryBankStore
+class VectorIO(Protocol):
+    vector_db_store: VectorDBStore
 
     # this will just block now until documents are inserted, but it should
     # probably return a Job instance which can be polled for completion
-    @webmethod(route="/memory/insert", method="POST")
-    async def insert_documents(
+    @webmethod(route="/vector-io/insert", method="POST")
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None: ...
 
-    @webmethod(route="/memory/query", method="POST")
-    async def query_documents(
+    @webmethod(route="/vector-io/query", method="POST")
+    async def query_chunks(
        self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse: ...
+    ) -> QueryChunksResponse: ...
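The old document-level API (`insert_documents`/`query_documents` plus `MemoryBankDocument`) becomes a chunk-level one: `Chunk` drops `document_id` and `token_count` in favor of a free-form `metadata` dict, so callers that need those values can stash them there. A hedged usage sketch, assuming `vector_io` is any object implementing the `VectorIO` protocol above (the `params` keys are an assumption):

```python
# Sketch of the new chunk-level flow; runs under any asyncio event loop.
from llama_stack.apis.vector_io import Chunk


async def demo(vector_io, vector_db_id: str) -> None:
    chunks = [
        Chunk(content="The capital of France is Paris.",
              metadata={"document_id": "doc-1"}),
        Chunk(content="Llamas are members of the camelid family.",
              metadata={"document_id": "doc-2"}),
    ]
    await vector_io.insert_chunks(vector_db_id=vector_db_id, chunks=chunks)

    response = await vector_io.query_chunks(
        vector_db_id=vector_db_id,
        query="What is the capital of France?",
        params={"max_chunks": 2},  # params dict contents are an assumption
    )
    for chunk, score in zip(response.chunks, response.scores):
        print(score, chunk.metadata.get("document_id"))
```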
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index c1a91cf6c..99ffeb346 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -13,14 +13,14 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
 from llama_stack.apis.eval import Eval
 from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
 from llama_stack.apis.inference import Inference
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankInput
 from llama_stack.apis.models import Model, ModelInput
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
 from llama_stack.apis.shields import Shield, ShieldInput
 from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
+from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.datatypes import Api, ProviderSpec
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig
@@ -34,7 +34,7 @@ RoutingKey = Union[str, List[str]]
 RoutableObject = Union[
     Model,
     Shield,
-    MemoryBank,
+    VectorDB,
     Dataset,
     ScoringFn,
     EvalTask,
@@ -47,7 +47,7 @@ RoutableObjectWithProvider = Annotated[
     Union[
         Model,
         Shield,
-        MemoryBank,
+        VectorDB,
         Dataset,
         ScoringFn,
         EvalTask,
@@ -60,7 +60,7 @@ RoutableObjectWithProvider = Annotated[
 RoutedProtocol = Union[
     Inference,
     Safety,
-    Memory,
+    VectorIO,
     DatasetIO,
     Scoring,
     Eval,
@@ -153,7 +153,7 @@ a default SQLite store will be used.""",
     # registry of "resources" in the distribution
     models: List[ModelInput] = Field(default_factory=list)
     shields: List[ShieldInput] = Field(default_factory=list)
-    memory_banks: List[MemoryBankInput] = Field(default_factory=list)
+    vector_dbs: List[VectorDBInput] = Field(default_factory=list)
     datasets: List[DatasetInput] = Field(default_factory=list)
     scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
     eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 4183d92cd..b02d0fb6c 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -32,8 +32,8 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
             router_api=Api.safety,
         ),
         AutoRoutedApiInfo(
-            routing_table_api=Api.memory_banks,
-            router_api=Api.memory,
+            routing_table_api=Api.vector_dbs,
+            router_api=Api.vector_io,
         ),
         AutoRoutedApiInfo(
             routing_table_api=Api.datasets,
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 204555b16..bd5a9ae98 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -15,8 +15,6 @@ from llama_stack.apis.eval import Eval
 from llama_stack.apis.eval_tasks import EvalTasks
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.safety import Safety
@@ -25,6 +23,8 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_dbs import VectorDBs
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.distribution.client import get_client_impl
 from llama_stack.distribution.datatypes import (
     AutoRoutedProviderSpec,
@@ -40,7 +40,6 @@ from llama_stack.providers.datatypes import (
     DatasetsProtocolPrivate,
     EvalTasksProtocolPrivate,
     InlineProviderSpec,
-    MemoryBanksProtocolPrivate,
     ModelsProtocolPrivate,
     ProviderSpec,
     RemoteProviderConfig,
@@ -48,6 +47,7 @@ from llama_stack.providers.datatypes import (
     ScoringFunctionsProtocolPrivate,
     ShieldsProtocolPrivate,
     ToolsProtocolPrivate,
+    VectorDBsProtocolPrivate,
 )
 
 log = logging.getLogger(__name__)
@@ -62,8 +62,8 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.agents: Agents,
         Api.inference: Inference,
         Api.inspect: Inspect,
-        Api.memory: Memory,
-        Api.memory_banks: MemoryBanks,
+        Api.vector_io: VectorIO,
+        Api.vector_dbs: VectorDBs,
         Api.models: Models,
         Api.safety: Safety,
         Api.shields: Shields,
@@ -84,7 +84,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
         Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
-        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
+        Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
         Api.scoring: (
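To make the wiring above concrete: after this change, a provider serving `Api.vector_io` must implement the public `VectorIO` protocol, and, per `additional_protocols_map`, the private `VectorDBsProtocolPrivate` hooks so the `vector_dbs` routing table can push registrations down to it. An illustrative check (not the resolver's actual logic):

```python
# Illustrative only: what the updated protocol maps require of a provider.
from llama_stack.apis.vector_io import VectorIO


def check_vector_io_provider(impl: object) -> None:
    # VectorIO is @runtime_checkable, so isinstance verifies the public surface
    if not isinstance(impl, VectorIO):
        raise ValueError(f"{impl!r} does not implement VectorIO")
    # private hooks from VectorDBsProtocolPrivate (not runtime_checkable,
    # so probe by attribute)
    for method in ("register_vector_db", "unregister_vector_db"):
        if not hasattr(impl, method):
            raise ValueError(f"{impl!r} is missing {method}")
```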
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 889bd4624..1d035d878 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -12,13 +12,6 @@ from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
 from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
-from llama_stack.apis.memory_banks import (
-    BankParams,
-    ListMemoryBanksResponse,
-    MemoryBank,
-    MemoryBanks,
-    MemoryBankType,
-)
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.scoring_functions import (
@@ -36,6 +29,7 @@ from llama_stack.apis.tools import (
     ToolGroups,
     ToolHost,
 )
+from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
 from llama_stack.distribution.datatypes import (
     RoutableObject,
     RoutableObjectWithProvider,
@@ -59,8 +53,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
         return await p.register_model(obj)
     elif api == Api.safety:
         return await p.register_shield(obj)
-    elif api == Api.memory:
-        return await p.register_memory_bank(obj)
+    elif api == Api.vector_io:
+        return await p.register_vector_db(obj)
     elif api == Api.datasetio:
         return await p.register_dataset(obj)
     elif api == Api.scoring:
@@ -75,8 +69,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)
-    if api == Api.memory:
-        return await p.unregister_memory_bank(obj.identifier)
+    if api == Api.vector_io:
+        return await p.unregister_vector_db(obj.identifier)
     elif api == Api.inference:
         return await p.unregister_model(obj.identifier)
     elif api == Api.datasetio:
@@ -120,8 +114,8 @@ class CommonRoutingTableImpl(RoutingTable):
                 p.model_store = self
             elif api == Api.safety:
                 p.shield_store = self
-            elif api == Api.memory:
-                p.memory_bank_store = self
+            elif api == Api.vector_io:
+                p.vector_db_store = self
             elif api == Api.datasetio:
                 p.dataset_store = self
             elif api == Api.scoring:
@@ -145,8 +139,8 @@ class CommonRoutingTableImpl(RoutingTable):
             return ("Inference", "model")
         elif isinstance(self, ShieldsRoutingTable):
             return ("Safety", "shield")
-        elif isinstance(self, MemoryBanksRoutingTable):
-            return ("Memory", "memory_bank")
+        elif isinstance(self, VectorDBsRoutingTable):
+            return ("VectorIO", "vector_db")
         elif isinstance(self, DatasetsRoutingTable):
             return ("DatasetIO", "dataset")
         elif isinstance(self, ScoringFunctionsRoutingTable):
@@ -196,9 +190,6 @@ class CommonRoutingTableImpl(RoutingTable):
     async def register_object(
         self, obj: RoutableObjectWithProvider
     ) -> RoutableObjectWithProvider:
-        # Get existing objects from registry
-        existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
-
         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
             obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -311,22 +302,23 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
         return shield
 
 
-class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
-    async def list_memory_banks(self) -> ListMemoryBanksResponse:
-        return ListMemoryBanksResponse(data=await self.get_all_with_type("memory_bank"))
+class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
+    async def list_vector_dbs(self) -> ListVectorDBsResponse:
+        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
 
-    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
-        return await self.get_object_by_identifier("memory_bank", memory_bank_id)
+    async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
+        return await self.get_object_by_identifier("vector_db", vector_db_id)
 
-    async def register_memory_bank(
+    async def register_vector_db(
         self,
-        memory_bank_id: str,
-        params: BankParams,
+        vector_db_id: str,
+        embedding_model: str,
+        embedding_dimension: Optional[int] = 384,
         provider_id: Optional[str] = None,
-        provider_memory_bank_id: Optional[str] = None,
-    ) -> MemoryBank:
-        if provider_memory_bank_id is None:
-            provider_memory_bank_id = memory_bank_id
+        provider_vector_db_id: Optional[str] = None,
+    ) -> VectorDB:
+        if provider_vector_db_id is None:
+            provider_vector_db_id = vector_db_id
         if provider_id is None:
             # If provider_id not specified, use the only provider if it supports this shield type
             if len(self.impls_by_provider_id) == 1:
@@ -335,44 +327,39 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
             raise ValueError(
                 "No provider specified and multiple providers available. Please specify a provider_id."
             )
-        model = await self.get_object_by_identifier("model", params.embedding_model)
+        model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
-            if params.embedding_model == "all-MiniLM-L6-v2":
+            if embedding_model == "all-MiniLM-L6-v2":
                 raise ValueError(
                     "Embeddings are now served via Inference providers. "
                     "Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
                     "See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
                 )
             else:
-                raise ValueError(f"Model {params.embedding_model} not found")
+                raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
-            raise ValueError(
-                f"Model {params.embedding_model} is not an embedding model"
-            )
+            raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
             raise ValueError(
-                f"Model {params.embedding_model} does not have an embedding dimension"
+                f"Model {embedding_model} does not have an embedding dimension"
             )
-        memory_bank_data = {
-            "identifier": memory_bank_id,
-            "type": ResourceType.memory_bank.value,
+        vector_db_data = {
+            "identifier": vector_db_id,
+            "type": ResourceType.vector_db.value,
             "provider_id": provider_id,
-            "provider_resource_id": provider_memory_bank_id,
-            **params.model_dump(),
+            "provider_resource_id": provider_vector_db_id,
+            "embedding_model": embedding_model,
+            "embedding_dimension": model.metadata["embedding_dimension"],
         }
-        if params.memory_bank_type == MemoryBankType.vector.value:
-            memory_bank_data["embedding_dimension"] = model.metadata[
-                "embedding_dimension"
-            ]
-        memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
-        await self.register_object(memory_bank)
-        return memory_bank
+        vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
+        await self.register_object(vector_db)
+        return vector_db
 
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
-        existing_bank = await self.get_memory_bank(memory_bank_id)
-        if existing_bank is None:
-            raise ValueError(f"Memory bank {memory_bank_id} not found")
-        await self.unregister_object(existing_bank)
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        existing_vector_db = await self.get_vector_db(vector_db_id)
+        if existing_vector_db is None:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
+        await self.unregister_object(existing_vector_db)
 
 
 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index ad7bcd234..180ec0ecc 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -21,8 +21,6 @@ from llama_stack.apis.eval import Eval
 from llama_stack.apis.eval_tasks import EvalTasks
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.safety import Safety
@@ -32,6 +30,8 @@ from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_dbs import VectorDBs
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
@@ -42,7 +42,7 @@ log = logging.getLogger(__name__)
 
 
 class LlamaStack(
-    MemoryBanks,
+    VectorDBs,
     Inference,
     BatchInference,
     Agents,
@@ -51,7 +51,7 @@ class LlamaStack(
     Datasets,
     Telemetry,
     PostTraining,
-    Memory,
+    VectorIO,
     Eval,
     EvalTasks,
     Scoring,
@@ -69,7 +69,7 @@ class LlamaStack(
 RESOURCES = [
     ("models", Api.models, "register_model", "list_models"),
     ("shields", Api.shields, "register_shield", "list_shields"),
-    ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
+    ("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
     ("datasets", Api.datasets, "register_dataset", "list_datasets"),
     (
         "scoring_fns",
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 4815754d2..d0c448f8c 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -14,11 +14,11 @@ from llama_stack.apis.datasets import Dataset
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.eval_tasks import EvalTask
-from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
 from llama_stack.apis.tools import Tool
+from llama_stack.apis.vector_dbs import VectorDB
 
 
 class ModelsProtocolPrivate(Protocol):
@@ -31,10 +31,10 @@ class ShieldsProtocolPrivate(Protocol):
     async def register_shield(self, shield: Shield) -> None: ...
 
 
-class MemoryBanksProtocolPrivate(Protocol):
-    async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
+class VectorDBsProtocolPrivate(Protocol):
+    async def register_vector_db(self, vector_db: VectorDB) -> None: ...
 
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
+    async def unregister_vector_db(self, vector_db_id: str) -> None: ...
 
 
 class DatasetsProtocolPrivate(Protocol):
diff --git a/llama_stack/providers/inline/memory/__init__.py b/llama_stack/providers/inline/vector_io/__init__.py
similarity index 100%
rename from llama_stack/providers/inline/memory/__init__.py
rename to llama_stack/providers/inline/vector_io/__init__.py
diff --git a/llama_stack/providers/inline/memory/chroma/__init__.py b/llama_stack/providers/inline/vector_io/chroma/__init__.py
similarity index 100%
rename from llama_stack/providers/inline/memory/chroma/__init__.py
rename to llama_stack/providers/inline/vector_io/chroma/__init__.py
diff --git a/llama_stack/providers/inline/memory/chroma/config.py b/llama_stack/providers/inline/vector_io/chroma/config.py
similarity index 100%
rename from llama_stack/providers/inline/memory/chroma/config.py
rename to llama_stack/providers/inline/vector_io/chroma/config.py
diff --git a/llama_stack/providers/inline/memory/faiss/__init__.py b/llama_stack/providers/inline/vector_io/faiss/__init__.py
similarity index 100%
rename from llama_stack/providers/inline/memory/faiss/__init__.py
rename to llama_stack/providers/inline/vector_io/faiss/__init__.py
diff --git a/llama_stack/providers/inline/memory/faiss/config.py b/llama_stack/providers/inline/vector_io/faiss/config.py
similarity index 100%
rename from llama_stack/providers/inline/memory/faiss/config.py
rename to llama_stack/providers/inline/vector_io/faiss/config.py
diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
similarity index 100%
rename from llama_stack/providers/inline/memory/faiss/faiss.py
rename to llama_stack/providers/inline/vector_io/faiss/faiss.py
diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/vector_io.py
similarity index 64%
rename from llama_stack/providers/registry/memory.py
rename to llama_stack/providers/registry/vector_io.py
index 6867a9186..df7b7f4b3 100644
--- a/llama_stack/providers/registry/memory.py
+++ b/llama_stack/providers/registry/vector_io.py
@@ -38,78 +38,78 @@ EMBEDDING_DEPS = [
 
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
-            api=Api.memory,
+            api=Api.vector_io,
             provider_type="inline::meta-reference",
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
-            module="llama_stack.providers.inline.memory.faiss",
-            config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            module="llama_stack.providers.inline.vector_io.faiss",
+            config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
             api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
-            api=Api.memory,
+            api=Api.vector_io,
             provider_type="inline::faiss",
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
-            module="llama_stack.providers.inline.memory.faiss",
-            config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            module="llama_stack.providers.inline.vector_io.faiss",
+            config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
             api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
-            Api.memory,
+            Api.vector_io,
             AdapterSpec(
                 adapter_type="chromadb",
                 pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
-                module="llama_stack.providers.remote.memory.chroma",
-                config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
+                module="llama_stack.providers.remote.vector_io.chroma",
+                config_class="llama_stack.providers.remote.vector_io.chroma.ChromaRemoteImplConfig",
             ),
             api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
-            api=Api.memory,
+            api=Api.vector_io,
             provider_type="inline::chromadb",
             pip_packages=EMBEDDING_DEPS + ["chromadb"],
-            module="llama_stack.providers.inline.memory.chroma",
-            config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
+            module="llama_stack.providers.inline.vector_io.chroma",
+            config_class="llama_stack.providers.inline.vector_io.chroma.ChromaInlineImplConfig",
             api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
-            Api.memory,
+            Api.vector_io,
             AdapterSpec(
                 adapter_type="pgvector",
                 pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
-                module="llama_stack.providers.remote.memory.pgvector",
-                config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
+                module="llama_stack.providers.remote.vector_io.pgvector",
+                config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorConfig",
             ),
             api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
-            Api.memory,
+            Api.vector_io,
             AdapterSpec(
                 adapter_type="weaviate",
                 pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
-                module="llama_stack.providers.remote.memory.weaviate",
-                config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
-                provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
+                module="llama_stack.providers.remote.vector_io.weaviate",
+                config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateConfig",
+                provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
             ),
             api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
-            api=Api.memory,
+            api=Api.vector_io,
             adapter=AdapterSpec(
                 adapter_type="sample",
                 pip_packages=[],
-                module="llama_stack.providers.remote.memory.sample",
-                config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
+                module="llama_stack.providers.remote.vector_io.sample",
+                config_class="llama_stack.providers.remote.vector_io.sample.SampleConfig",
             ),
             api_dependencies=[],
         ),
         remote_provider_spec(
-            Api.memory,
+            Api.vector_io,
             AdapterSpec(
                 adapter_type="qdrant",
                 pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
-                module="llama_stack.providers.remote.memory.qdrant",
-                config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
+                module="llama_stack.providers.remote.vector_io.qdrant",
+                config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantConfig",
             ),
             api_dependencies=[Api.inference],
         ),
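Every `module`/`config_class` string in the registry above is a dotted import path, so the rename is load-bearing: a stale `llama_stack.providers.inline.memory.*` path would now fail at import time. A rough sketch of the resolution step (an illustrative helper, not distribution code):

```python
# Hypothetical helper showing how a spec's dotted paths resolve at build time.
import importlib


def load_provider(module_path: str, config_class_path: str):
    module = importlib.import_module(module_path)
    cfg_module_path, _, cfg_name = config_class_path.rpartition(".")
    config_cls = getattr(importlib.import_module(cfg_module_path), cfg_name)
    return module, config_cls


# e.g. load_provider(
#     "llama_stack.providers.inline.vector_io.faiss",
#     "llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
# )
```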
diff --git a/llama_stack/providers/remote/memory/__init__.py b/llama_stack/providers/remote/vector_io/__init__.py
similarity index 100%
rename from llama_stack/providers/remote/memory/__init__.py
rename to llama_stack/providers/remote/vector_io/__init__.py
diff --git a/llama_stack/providers/remote/memory/chroma/__init__.py b/llama_stack/providers/remote/vector_io/chroma/__init__.py
similarity index 100%
rename from llama_stack/providers/remote/memory/chroma/__init__.py
rename to llama_stack/providers/remote/vector_io/chroma/__init__.py
diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
similarity index 100%
rename from llama_stack/providers/remote/memory/chroma/chroma.py
rename to llama_stack/providers/remote/vector_io/chroma/chroma.py
diff --git a/llama_stack/providers/remote/memory/chroma/config.py b/llama_stack/providers/remote/vector_io/chroma/config.py
similarity index 100%
rename from llama_stack/providers/remote/memory/chroma/config.py
rename to llama_stack/providers/remote/vector_io/chroma/config.py
diff --git a/llama_stack/providers/remote/memory/pgvector/__init__.py b/llama_stack/providers/remote/vector_io/pgvector/__init__.py
similarity index 100%
rename from llama_stack/providers/remote/memory/pgvector/__init__.py
rename to llama_stack/providers/remote/vector_io/pgvector/__init__.py
diff --git a/llama_stack/providers/remote/memory/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py
similarity index 100%
rename from llama_stack/providers/remote/memory/pgvector/config.py
rename to llama_stack/providers/remote/vector_io/pgvector/config.py
diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
similarity index 100%
rename from llama_stack/providers/remote/memory/pgvector/pgvector.py
rename to llama_stack/providers/remote/vector_io/pgvector/pgvector.py
diff --git a/llama_stack/providers/remote/memory/qdrant/__init__.py b/llama_stack/providers/remote/vector_io/qdrant/__init__.py
similarity index 100%
rename from llama_stack/providers/remote/memory/qdrant/__init__.py
rename to llama_stack/providers/remote/vector_io/qdrant/__init__.py
diff --git a/llama_stack/providers/remote/memory/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py
similarity index 100%
rename from llama_stack/providers/remote/memory/qdrant/config.py
rename to llama_stack/providers/remote/vector_io/qdrant/config.py
diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
similarity index 100%
rename from llama_stack/providers/remote/memory/qdrant/qdrant.py
rename to llama_stack/providers/remote/vector_io/qdrant/qdrant.py
diff --git a/llama_stack/providers/remote/memory/sample/__init__.py b/llama_stack/providers/remote/vector_io/sample/__init__.py
similarity index 100%
rename from llama_stack/providers/remote/memory/sample/__init__.py
rename to llama_stack/providers/remote/vector_io/sample/__init__.py
diff --git a/llama_stack/providers/remote/memory/sample/config.py b/llama_stack/providers/remote/vector_io/sample/config.py
similarity index 100%
rename from llama_stack/providers/remote/memory/sample/config.py
rename to llama_stack/providers/remote/vector_io/sample/config.py
diff --git a/llama_stack/providers/remote/memory/sample/sample.py b/llama_stack/providers/remote/vector_io/sample/sample.py
similarity index 100%
rename from llama_stack/providers/remote/memory/sample/sample.py
rename to llama_stack/providers/remote/vector_io/sample/sample.py
diff --git a/llama_stack/providers/remote/memory/weaviate/__init__.py b/llama_stack/providers/remote/vector_io/weaviate/__init__.py
similarity index 100%
rename from llama_stack/providers/remote/memory/weaviate/__init__.py
rename to llama_stack/providers/remote/vector_io/weaviate/__init__.py
diff --git a/llama_stack/providers/remote/memory/weaviate/config.py b/llama_stack/providers/remote/vector_io/weaviate/config.py
similarity index 100%
rename from llama_stack/providers/remote/memory/weaviate/config.py
rename to llama_stack/providers/remote/vector_io/weaviate/config.py
diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
similarity index 100%
rename from llama_stack/providers/remote/memory/weaviate/weaviate.py
rename to llama_stack/providers/remote/vector_io/weaviate/weaviate.py
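Finally, a quick check of the resource shape that `register_vector_db` builds in the routing table above, using the same `TypeAdapter` validation path; the values below are placeholders.

```python
# Validates the vector_db_data dict shape from the routing table change above.
from pydantic import TypeAdapter

from llama_stack.apis.vector_dbs import VectorDB

vector_db_data = {
    "identifier": "my-docs",
    "type": "vector_db",
    "provider_id": "faiss",
    "provider_resource_id": "my-docs",
    "embedding_model": "all-MiniLM-L6-v2",
    "embedding_dimension": 384,
}
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
assert vector_db.vector_db_id == "my-docs"
assert vector_db.provider_vector_db_id == "my-docs"
```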