[memory refactor][1/n] Rename Memory -> VectorIO, MemoryBanks -> VectorDBs (#828)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

This is the first part:

- delete other kinds of memory banks (keyvalue, keyword, graph) for now;
  we will introduce a keyvalue store API as part of this design but not
  use it in the RAG tool yet.
- rename the APIs (Memory -> VectorIO, MemoryBanks -> VectorDBs); a rough
  before/after usage sketch follows below.
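As a rough before/after sketch of what the rename means for callers (the client object here is hypothetical; the method names and parameters are the ones defined by the protocols in this diff):

# Before: Memory / MemoryBanks
bank = await client.memory_banks.register_memory_bank(
    memory_bank_id="my_docs",
    params=VectorMemoryBankParams(
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
    ),
)
resp = await client.memory.query_documents(bank_id="my_docs", query="hello")

# After: VectorIO / VectorDBs (callers now chunk documents themselves; see vector_io below)
db = await client.vector_dbs.register_vector_db(
    vector_db_id="my_docs",
    embedding_model="all-MiniLM-L6-v2",
)
resp = await client.vector_io.query_chunks(vector_db_id="my_docs", query="hello")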
Ashwin Bharambe committed 2025-01-22 09:59:30 -08:00 (via GitHub)
commit 3ae8585b65, parent 35a00d004a
37 changed files with 175 additions and 296 deletions

View file

@@ -33,7 +33,6 @@ from llama_stack.apis.inference (
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -133,8 +132,6 @@ class Session(BaseModel):
turns: List[Turn]
started_at: datetime
memory_bank: Optional[MemoryBank] = None
class AgentToolGroupWithArgs(BaseModel):
name: str

View file

@@ -14,7 +14,7 @@ class Api(Enum):
inference = "inference"
safety = "safety"
agents = "agents"
memory = "memory"
vector_io = "vector_io"
datasetio = "datasetio"
scoring = "scoring"
eval = "eval"
@@ -25,7 +25,7 @@ class Api(Enum):
models = "models"
shields = "shields"
memory_banks = "memory_banks"
vector_dbs = "vector_dbs"
datasets = "datasets"
scoring_functions = "scoring_functions"
eval_tasks = "eval_tasks"

View file

@@ -1,161 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import (
    Annotated,
    List,
    Literal,
    Optional,
    Protocol,
    runtime_checkable,
    Union,
)

from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


@json_schema_type
class MemoryBankType(Enum):
    vector = "vector"
    keyvalue = "keyvalue"
    keyword = "keyword"
    graph = "graph"


# define params for each type of memory bank, this leads to a tagged union
# accepted as input from the API or from the config.
@json_schema_type
class VectorMemoryBankParams(BaseModel):
    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
    embedding_model: str
    chunk_size_in_tokens: int
    overlap_size_in_tokens: Optional[int] = None


@json_schema_type
class KeyValueMemoryBankParams(BaseModel):
    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
        MemoryBankType.keyvalue.value
    )


@json_schema_type
class KeywordMemoryBankParams(BaseModel):
    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
        MemoryBankType.keyword.value
    )


@json_schema_type
class GraphMemoryBankParams(BaseModel):
    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value


BankParams = Annotated[
    Union[
        VectorMemoryBankParams,
        KeyValueMemoryBankParams,
        KeywordMemoryBankParams,
        GraphMemoryBankParams,
    ],
    Field(discriminator="memory_bank_type"),
]


# Some common functionality for memory banks.
class MemoryBankResourceMixin(Resource):
    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value

    @property
    def memory_bank_id(self) -> str:
        return self.identifier

    @property
    def provider_memory_bank_id(self) -> str:
        return self.provider_resource_id


@json_schema_type
class VectorMemoryBank(MemoryBankResourceMixin):
    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
    embedding_model: str
    chunk_size_in_tokens: int
    embedding_dimension: Optional[int] = 384  # default to minilm-l6-v2
    overlap_size_in_tokens: Optional[int] = None


@json_schema_type
class KeyValueMemoryBank(MemoryBankResourceMixin):
    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
        MemoryBankType.keyvalue.value
    )


# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
@json_schema_type
class KeywordMemoryBank(MemoryBankResourceMixin):
    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
        MemoryBankType.keyword.value
    )


@json_schema_type
class GraphMemoryBank(MemoryBankResourceMixin):
    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value


MemoryBank = register_schema(
    Annotated[
        Union[
            VectorMemoryBank,
            KeyValueMemoryBank,
            KeywordMemoryBank,
            GraphMemoryBank,
        ],
        Field(discriminator="memory_bank_type"),
    ],
    name="MemoryBank",
)


class MemoryBankInput(BaseModel):
    memory_bank_id: str
    params: BankParams
    provider_memory_bank_id: Optional[str] = None


class ListMemoryBanksResponse(BaseModel):
    data: List[MemoryBank]


@runtime_checkable
@trace_protocol
class MemoryBanks(Protocol):
    @webmethod(route="/memory-banks", method="GET")
    async def list_memory_banks(self) -> ListMemoryBanksResponse: ...

    @webmethod(route="/memory-banks/{memory_bank_id}", method="GET")
    async def get_memory_bank(
        self,
        memory_bank_id: str,
    ) -> Optional[MemoryBank]: ...

    @webmethod(route="/memory-banks", method="POST")
    async def register_memory_bank(
        self,
        memory_bank_id: str,
        params: BankParams,
        provider_id: Optional[str] = None,
        provider_memory_bank_id: Optional[str] = None,
    ) -> MemoryBank: ...

    @webmethod(route="/memory-banks/{memory_bank_id}", method="DELETE")
    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
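For readers skimming the deleted file above: BankParams and MemoryBank are Pydantic discriminated (tagged) unions keyed on the memory_bank_type field. A minimal self-contained sketch of that mechanism in plain pydantic v2, independent of llama-stack:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter

class VectorParams(BaseModel):
    memory_bank_type: Literal["vector"] = "vector"
    embedding_model: str

class GraphParams(BaseModel):
    memory_bank_type: Literal["graph"] = "graph"

Params = Annotated[
    Union[VectorParams, GraphParams],
    Field(discriminator="memory_bank_type"),
]

# The discriminator value selects the concrete model during validation.
parsed = TypeAdapter(Params).validate_python(
    {"memory_bank_type": "vector", "embedding_model": "all-MiniLM-L6-v2"}
)
assert isinstance(parsed, VectorParams)

With only the vector variant surviving this refactor, the union collapses, and the VectorDB type introduced below no longer needs a discriminator at all.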

View file

@@ -14,7 +14,7 @@ from pydantic import BaseModel, Field
class ResourceType(Enum):
model = "model"
shield = "shield"
memory_bank = "memory_bank"
vector_db = "vector_db"
dataset = "dataset"
scoring_function = "scoring_function"
eval_task = "eval_task"

View file

@@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .memory_banks import * # noqa: F401 F403
from .vector_dbs import * # noqa: F401 F403

View file

@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Literal, Optional, Protocol, runtime_checkable

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


@json_schema_type
class VectorDB(Resource):
    type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value

    embedding_model: str
    embedding_dimension: int

    @property
    def vector_db_id(self) -> str:
        return self.identifier

    @property
    def provider_vector_db_id(self) -> str:
        return self.provider_resource_id


class VectorDBInput(BaseModel):
    vector_db_id: str
    embedding_model: str
    embedding_dimension: int
    provider_vector_db_id: Optional[str] = None


class ListVectorDBsResponse(BaseModel):
    data: List[VectorDB]


@runtime_checkable
@trace_protocol
class VectorDBs(Protocol):
    @webmethod(route="/vector-dbs", method="GET")
    async def list_vector_dbs(self) -> ListVectorDBsResponse: ...

    @webmethod(route="/vector-dbs/{vector_db_id}", method="GET")
    async def get_vector_db(
        self,
        vector_db_id: str,
    ) -> Optional[VectorDB]: ...

    @webmethod(route="/vector-dbs", method="POST")
    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: Optional[int] = 384,
        provider_id: Optional[str] = None,
        provider_vector_db_id: Optional[str] = None,
    ) -> VectorDB: ...

    @webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE")
    async def unregister_vector_db(self, vector_db_id: str) -> None: ...
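A minimal usage sketch against this new protocol (vector_dbs here stands in for any VectorDBs implementation, e.g. the routing table rewritten later in this diff):

db = await vector_dbs.register_vector_db(
    vector_db_id="my_docs",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)
listing = await vector_dbs.list_vector_dbs()
assert any(v.vector_db_id == "my_docs" for v in listing.data)
await vector_dbs.unregister_vector_db("my_docs")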

View file

@@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .memory import * # noqa: F401 F403
from .vector_io import * # noqa: F401 F403

View file

@@ -13,55 +13,45 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory_banks import MemoryBank
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type
class MemoryBankDocument(BaseModel):
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
metadata: Dict[str, Any] = Field(default_factory=dict)
class Chunk(BaseModel):
content: InterleavedContent
token_count: int
document_id: str
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class QueryDocumentsResponse(BaseModel):
class QueryChunksResponse(BaseModel):
chunks: List[Chunk]
scores: List[float]
class MemoryBankStore(Protocol):
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
class VectorDBStore(Protocol):
def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
@runtime_checkable
@trace_protocol
class Memory(Protocol):
memory_bank_store: MemoryBankStore
class VectorIO(Protocol):
vector_db_store: VectorDBStore
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/memory/insert", method="POST")
async def insert_documents(
@webmethod(route="/vector-io/insert", method="POST")
async def insert_chunks(
self,
bank_id: str,
documents: List[MemoryBankDocument],
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None: ...
@webmethod(route="/memory/query", method="POST")
async def query_documents(
@webmethod(route="/vector-io/query", method="POST")
async def query_chunks(
self,
bank_id: str,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
) -> QueryChunksResponse: ...
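Note the shift in responsibility here: the old insert_documents accepted MemoryBankDocument objects and chunked them server-side, while insert_chunks takes pre-chunked content, so callers (e.g. the RAG tool) now do the chunking. A sketch of the new call pattern; split_into_chunks and count_tokens are hypothetical helpers, and the params keys are illustrative since params is a free-form dict:

chunks = [
    Chunk(content=piece, token_count=count_tokens(piece), document_id="doc-1")
    for piece in split_into_chunks(document_text, chunk_size_in_tokens=512)
]
await vector_io.insert_chunks(vector_db_id="my_docs", chunks=chunks)

resp = await vector_io.query_chunks(
    vector_db_id="my_docs",
    query="What changed in the memory refactor?",
    params={"max_chunks": 5},  # illustrative; params is provider-interpreted
)
for chunk, score in zip(resp.chunks, resp.scores):
    print(score, chunk.document_id)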

View file

@@ -13,14 +13,14 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankInput
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
@@ -34,7 +34,7 @@ RoutingKey = Union[str, List[str]]
RoutableObject = Union[
Model,
Shield,
MemoryBank,
VectorDB,
Dataset,
ScoringFn,
EvalTask,
@@ -47,7 +47,7 @@ RoutableObjectWithProvider = Annotated[
Union[
Model,
Shield,
MemoryBank,
VectorDB,
Dataset,
ScoringFn,
EvalTask,
@@ -60,7 +60,7 @@ RoutableObjectWithProvider = Annotated[
RoutedProtocol = Union[
Inference,
Safety,
Memory,
VectorIO,
DatasetIO,
Scoring,
Eval,
@@ -153,7 +153,7 @@ a default SQLite store will be used.""",
# registry of "resources" in the distribution
models: List[ModelInput] = Field(default_factory=list)
shields: List[ShieldInput] = Field(default_factory=list)
memory_banks: List[MemoryBankInput] = Field(default_factory=list)
vector_dbs: List[VectorDBInput] = Field(default_factory=list)
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
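For distributions that pre-register resources in their run config, the input type changes from MemoryBankInput (which carried the full BankParams union) to the flat VectorDBInput; a sketch, constructed directly in Python rather than via run.yaml:

from llama_stack.apis.vector_dbs import VectorDBInput

vector_db_input = VectorDBInput(
    vector_db_id="my_docs",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)
# run_config.vector_dbs.append(vector_db_input)  # picked up at stack startup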

View file

@@ -32,8 +32,8 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
router_api=Api.safety,
),
AutoRoutedApiInfo(
routing_table_api=Api.memory_banks,
router_api=Api.memory,
routing_table_api=Api.vector_dbs,
router_api=Api.vector_io,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,

View file

@@ -15,8 +15,6 @@ from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTasks
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.safety import Safety
@@ -25,6 +23,8 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import (
AutoRoutedProviderSpec,
@@ -40,7 +40,6 @@ from llama_stack.providers.datatypes import (
DatasetsProtocolPrivate,
EvalTasksProtocolPrivate,
InlineProviderSpec,
MemoryBanksProtocolPrivate,
ModelsProtocolPrivate,
ProviderSpec,
RemoteProviderConfig,
@@ -48,6 +47,7 @@ from llama_stack.providers.datatypes import (
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolsProtocolPrivate,
VectorDBsProtocolPrivate,
)
log = logging.getLogger(__name__)
@@ -62,8 +62,8 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
Api.memory: Memory,
Api.memory_banks: MemoryBanks,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
@@ -84,7 +84,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (

View file

@@ -12,13 +12,6 @@ from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
from llama_stack.apis.memory_banks import (
BankParams,
ListMemoryBanksResponse,
MemoryBank,
MemoryBanks,
MemoryBankType,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
@@ -36,6 +29,7 @@ from llama_stack.apis.tools import (
ToolGroups,
ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
@@ -59,8 +53,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
return await p.register_model(obj)
elif api == Api.safety:
return await p.register_shield(obj)
elif api == Api.memory:
return await p.register_memory_bank(obj)
elif api == Api.vector_io:
return await p.register_vector_db(obj)
elif api == Api.datasetio:
return await p.register_dataset(obj)
elif api == Api.scoring:
@@ -75,8 +69,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.memory:
return await p.unregister_memory_bank(obj.identifier)
if api == Api.vector_io:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
@@ -120,8 +114,8 @@ class CommonRoutingTableImpl(RoutingTable):
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.memory:
p.memory_bank_store = self
elif api == Api.vector_io:
p.vector_db_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
@@ -145,8 +139,8 @@ class CommonRoutingTableImpl(RoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, MemoryBanksRoutingTable):
return ("Memory", "memory_bank")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
@@ -196,9 +190,6 @@ class CommonRoutingTableImpl(RoutingTable):
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
# Get existing objects from registry
existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -311,22 +302,23 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return shield
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> ListMemoryBanksResponse:
return ListMemoryBanksResponse(data=await self.get_all_with_type("memory_bank"))
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return await self.get_object_by_identifier("memory_bank", memory_bank_id)
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
return await self.get_object_by_identifier("vector_db", vector_db_id)
async def register_memory_bank(
async def register_vector_db(
self,
memory_bank_id: str,
params: BankParams,
vector_db_id: str,
embedding_model: str,
embedding_dimension: Optional[int] = 384,
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank:
if provider_memory_bank_id is None:
provider_memory_bank_id = memory_bank_id
provider_vector_db_id: Optional[str] = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this vector DB type
if len(self.impls_by_provider_id) == 1:
@@ -335,44 +327,39 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
model = await self.get_object_by_identifier("model", params.embedding_model)
model = await self.get_object_by_identifier("model", embedding_model)
if model is None:
if params.embedding_model == "all-MiniLM-L6-v2":
if embedding_model == "all-MiniLM-L6-v2":
raise ValueError(
"Embeddings are now served via Inference providers. "
"Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
"See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
)
else:
raise ValueError(f"Model {params.embedding_model} not found")
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding:
raise ValueError(
f"Model {params.embedding_model} is not an embedding model"
)
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
raise ValueError(
f"Model {params.embedding_model} does not have an embedding dimension"
f"Model {embedding_model} does not have an embedding dimension"
)
memory_bank_data = {
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
vector_db_data = {
"identifier": vector_db_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
"provider_resource_id": provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
}
if params.memory_bank_type == MemoryBankType.vector.value:
memory_bank_data["embedding_dimension"] = model.metadata[
"embedding_dimension"
]
memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
await self.register_object(memory_bank)
return memory_bank
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
existing_bank = await self.get_memory_bank(memory_bank_id)
if existing_bank is None:
raise ValueError(f"Memory bank {memory_bank_id} not found")
await self.unregister_object(existing_bank)
async def unregister_vector_db(self, vector_db_id: str) -> None:
existing_vector_db = await self.get_vector_db(vector_db_id)
if existing_vector_db is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
await self.unregister_object(existing_vector_db)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
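To summarize the rewritten registration path above (illustrative expectations, not a test from this commit): with a single configured provider and a registered embedding model, a minimal call fills in everything else:

db = await routing_table.register_vector_db(
    vector_db_id="my_docs",
    embedding_model="all-MiniLM-L6-v2",
)
# provider_vector_db_id defaulted to "my_docs"
# provider_id defaulted to the sole configured provider
# embedding_dimension was read from the model's metadata
assert db.embedding_dimension == 384  # assuming the model reports 384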

View file

@@ -21,8 +21,6 @@ from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTasks
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.safety import Safety
@@ -32,6 +30,8 @@ from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
@@ -42,7 +42,7 @@ log = logging.getLogger(__name__)
class LlamaStack(
MemoryBanks,
VectorDBs,
Inference,
BatchInference,
Agents,
@@ -51,7 +51,7 @@ class LlamaStack(
Datasets,
Telemetry,
PostTraining,
Memory,
VectorIO,
Eval,
EvalTasks,
Scoring,
@@ -69,7 +69,7 @@ class LlamaStack(
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",

View file

@@ -14,11 +14,11 @@ from llama_stack.apis.datasets import Dataset
from llama_stack.apis.datatypes import Api
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.memory_banks.memory_banks import MemoryBank
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool
from llama_stack.apis.vector_dbs import VectorDB
class ModelsProtocolPrivate(Protocol):
@@ -31,10 +31,10 @@ class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...
class MemoryBanksProtocolPrivate(Protocol):
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
class VectorDBsProtocolPrivate(Protocol):
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
class DatasetsProtocolPrivate(Protocol):

View file

@@ -38,78 +38,78 @@ EMBEDDING_DEPS = [
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.memory,
api=Api.vector_io,
provider_type="inline::meta-reference",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
api_dependencies=[Api.inference],
),
InlineProviderSpec(
api=Api.memory,
api=Api.vector_io,
provider_type="inline::faiss",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.memory,
Api.vector_io,
AdapterSpec(
adapter_type="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
module="llama_stack.providers.remote.memory.chroma",
config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaRemoteImplConfig",
),
api_dependencies=[Api.inference],
),
InlineProviderSpec(
api=Api.memory,
api=Api.vector_io,
provider_type="inline::chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb"],
module="llama_stack.providers.inline.memory.chroma",
config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaInlineImplConfig",
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.memory,
Api.vector_io,
AdapterSpec(
adapter_type="pgvector",
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
module="llama_stack.providers.remote.memory.pgvector",
config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorConfig",
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.memory,
Api.vector_io,
AdapterSpec(
adapter_type="weaviate",
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
module="llama_stack.providers.remote.memory.weaviate",
config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
api=Api.memory,
api=Api.vector_io,
adapter=AdapterSpec(
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.remote.memory.sample",
config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
module="llama_stack.providers.remote.vector_io.sample",
config_class="llama_stack.providers.remote.vector_io.sample.SampleConfig",
),
api_dependencies=[],
),
remote_provider_spec(
Api.memory,
Api.vector_io,
AdapterSpec(
adapter_type="qdrant",
pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
module="llama_stack.providers.remote.memory.qdrant",
config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantConfig",
),
api_dependencies=[Api.inference],
),
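On the provider side, every vector_io provider registered above must now satisfy VectorDBsProtocolPrivate rather than MemoryBanksProtocolPrivate. A minimal sketch of that shape (hypothetical in-memory provider; only the private-protocol methods from this diff are shown):

from llama_stack.apis.vector_dbs import VectorDB

class MyVectorIOProvider:
    """Hypothetical provider satisfying VectorDBsProtocolPrivate."""

    def __init__(self) -> None:
        self.vector_dbs: dict[str, VectorDB] = {}

    async def register_vector_db(self, vector_db: VectorDB) -> None:
        # e.g. create an index sized to vector_db.embedding_dimension
        self.vector_dbs[vector_db.identifier] = vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        self.vector_dbs.pop(vector_db_id, None)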