forked from phoenix-oss/llama-stack-mirror
[memory refactor][1/n] Rename Memory -> VectorIO, MemoryBanks -> VectorDBs (#828)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. This is the first part: - delete the other kinds of memory banks (keyvalue, keyword, graph) for now; we will introduce a key-value store API as part of this design, but will not use it in the RAG tool yet. - rename the APIs (Memory -> VectorIO, MemoryBanks -> VectorDBs)
This commit is contained in:
parent
35a00d004a
commit
3ae8585b65
37 changed files with 175 additions and 296 deletions
|
@ -12,13 +12,6 @@ from llama_stack.apis.common.content_types import URL
|
|||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
|
||||
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
|
||||
from llama_stack.apis.memory_banks import (
|
||||
BankParams,
|
||||
ListMemoryBanksResponse,
|
||||
MemoryBank,
|
||||
MemoryBanks,
|
||||
MemoryBankType,
|
||||
)
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
|
@ -36,6 +29,7 @@ from llama_stack.apis.tools import (
|
|||
ToolGroups,
|
||||
ToolHost,
|
||||
)
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.distribution.datatypes import (
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
|
@ -59,8 +53,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
return await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
return await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
return await p.register_memory_bank(obj)
|
||||
elif api == Api.vector_io:
|
||||
return await p.register_vector_db(obj)
|
||||
elif api == Api.datasetio:
|
||||
return await p.register_dataset(obj)
|
||||
elif api == Api.scoring:
|
||||
|
@ -75,8 +69,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
|
||||
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.memory:
|
||||
return await p.unregister_memory_bank(obj.identifier)
|
||||
if api == Api.vector_io:
|
||||
return await p.unregister_vector_db(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
elif api == Api.datasetio:
|
||||
|
@ -120,8 +114,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
p.model_store = self
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
elif api == Api.vector_io:
|
||||
p.vector_db_store = self
|
||||
elif api == Api.datasetio:
|
||||
p.dataset_store = self
|
||||
elif api == Api.scoring:
|
||||
|
@ -145,8 +139,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return ("Inference", "model")
|
||||
elif isinstance(self, ShieldsRoutingTable):
|
||||
return ("Safety", "shield")
|
||||
elif isinstance(self, MemoryBanksRoutingTable):
|
||||
return ("Memory", "memory_bank")
|
||||
elif isinstance(self, VectorDBsRoutingTable):
|
||||
return ("VectorIO", "vector_db")
|
||||
elif isinstance(self, DatasetsRoutingTable):
|
||||
return ("DatasetIO", "dataset")
|
||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||
|
@ -196,9 +190,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
async def register_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider:
|
||||
# Get existing objects from registry
|
||||
existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
|
||||
|
||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
@ -311,22 +302,23 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
return shield
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> ListMemoryBanksResponse:
|
||||
return ListMemoryBanksResponse(data=await self.get_all_with_type("memory_bank"))
|
||||
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
|
||||
|
||||
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
|
||||
return await self.get_object_by_identifier("memory_bank", memory_bank_id)
|
||||
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
|
||||
return await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||
|
||||
async def register_memory_bank(
|
||||
async def register_vector_db(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
params: BankParams,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: Optional[int] = 384,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_memory_bank_id: Optional[str] = None,
|
||||
) -> MemoryBank:
|
||||
if provider_memory_bank_id is None:
|
||||
provider_memory_bank_id = memory_bank_id
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> VectorDB:
|
||||
if provider_vector_db_id is None:
|
||||
provider_vector_db_id = vector_db_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this shield type
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
|
@ -335,44 +327,39 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
model = await self.get_object_by_identifier("model", params.embedding_model)
|
||||
model = await self.get_object_by_identifier("model", embedding_model)
|
||||
if model is None:
|
||||
if params.embedding_model == "all-MiniLM-L6-v2":
|
||||
if embedding_model == "all-MiniLM-L6-v2":
|
||||
raise ValueError(
|
||||
"Embeddings are now served via Inference providers. "
|
||||
"Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
|
||||
"See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model {params.embedding_model} not found")
|
||||
raise ValueError(f"Model {embedding_model} not found")
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ValueError(
|
||||
f"Model {params.embedding_model} is not an embedding model"
|
||||
)
|
||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(
|
||||
f"Model {params.embedding_model} does not have an embedding dimension"
|
||||
f"Model {embedding_model} does not have an embedding dimension"
|
||||
)
|
||||
memory_bank_data = {
|
||||
"identifier": memory_bank_id,
|
||||
"type": ResourceType.memory_bank.value,
|
||||
vector_db_data = {
|
||||
"identifier": vector_db_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": provider_memory_bank_id,
|
||||
**params.model_dump(),
|
||||
"provider_resource_id": provider_vector_db_id,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
}
|
||||
if params.memory_bank_type == MemoryBankType.vector.value:
|
||||
memory_bank_data["embedding_dimension"] = model.metadata[
|
||||
"embedding_dimension"
|
||||
]
|
||||
memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
|
||||
await self.register_object(memory_bank)
|
||||
return memory_bank
|
||||
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
existing_bank = await self.get_memory_bank(memory_bank_id)
|
||||
if existing_bank is None:
|
||||
raise ValueError(f"Memory bank {memory_bank_id} not found")
|
||||
await self.unregister_object(existing_bank)
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
existing_vector_db = await self.get_vector_db(vector_db_id)
|
||||
if existing_vector_db is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
await self.unregister_object(existing_vector_db)
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue