forked from phoenix-oss/llama-stack-mirror
[memory refactor][1/n] Rename Memory -> VectorIO, MemoryBanks -> VectorDBs (#828)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. This is the first part: - delete other kinds of memory banks (keyvalue, keyword, graph) for now; we will introduce a keyvalue store API as part of this design but not use it in the RAG tool yet. - renaming of the APIs
This commit is contained in:
parent
35a00d004a
commit
3ae8585b65
37 changed files with 175 additions and 296 deletions
|
@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
|
|||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import MemoryBank
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
@ -133,8 +132,6 @@ class Session(BaseModel):
|
|||
turns: List[Turn]
|
||||
started_at: datetime
|
||||
|
||||
memory_bank: Optional[MemoryBank] = None
|
||||
|
||||
|
||||
class AgentToolGroupWithArgs(BaseModel):
|
||||
name: str
|
||||
|
|
|
@ -14,7 +14,7 @@ class Api(Enum):
|
|||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
memory = "memory"
|
||||
vector_io = "vector_io"
|
||||
datasetio = "datasetio"
|
||||
scoring = "scoring"
|
||||
eval = "eval"
|
||||
|
@ -25,7 +25,7 @@ class Api(Enum):
|
|||
|
||||
models = "models"
|
||||
shields = "shields"
|
||||
memory_banks = "memory_banks"
|
||||
vector_dbs = "vector_dbs"
|
||||
datasets = "datasets"
|
||||
scoring_functions = "scoring_functions"
|
||||
eval_tasks = "eval_tasks"
|
||||
|
|
|
@ -1,161 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBankType(Enum):
|
||||
vector = "vector"
|
||||
keyvalue = "keyvalue"
|
||||
keyword = "keyword"
|
||||
graph = "graph"
|
||||
|
||||
|
||||
# define params for each type of memory bank, this leads to a tagged union
|
||||
# accepted as input from the API or from the config.
|
||||
@json_schema_type
|
||||
class VectorMemoryBankParams(BaseModel):
|
||||
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
embedding_model: str
|
||||
chunk_size_in_tokens: int
|
||||
overlap_size_in_tokens: Optional[int] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class KeyValueMemoryBankParams(BaseModel):
|
||||
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
|
||||
MemoryBankType.keyvalue.value
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class KeywordMemoryBankParams(BaseModel):
|
||||
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
|
||||
MemoryBankType.keyword.value
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GraphMemoryBankParams(BaseModel):
|
||||
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
|
||||
|
||||
BankParams = Annotated[
|
||||
Union[
|
||||
VectorMemoryBankParams,
|
||||
KeyValueMemoryBankParams,
|
||||
KeywordMemoryBankParams,
|
||||
GraphMemoryBankParams,
|
||||
],
|
||||
Field(discriminator="memory_bank_type"),
|
||||
]
|
||||
|
||||
|
||||
# Some common functionality for memory banks.
|
||||
class MemoryBankResourceMixin(Resource):
|
||||
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
|
||||
|
||||
@property
|
||||
def memory_bank_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_memory_bank_id(self) -> str:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorMemoryBank(MemoryBankResourceMixin):
|
||||
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
embedding_model: str
|
||||
chunk_size_in_tokens: int
|
||||
embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2
|
||||
overlap_size_in_tokens: Optional[int] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class KeyValueMemoryBank(MemoryBankResourceMixin):
|
||||
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
|
||||
MemoryBankType.keyvalue.value
|
||||
)
|
||||
|
||||
|
||||
# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
|
||||
@json_schema_type
|
||||
class KeywordMemoryBank(MemoryBankResourceMixin):
|
||||
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
|
||||
MemoryBankType.keyword.value
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GraphMemoryBank(MemoryBankResourceMixin):
|
||||
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
|
||||
|
||||
MemoryBank = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
VectorMemoryBank,
|
||||
KeyValueMemoryBank,
|
||||
KeywordMemoryBank,
|
||||
GraphMemoryBank,
|
||||
],
|
||||
Field(discriminator="memory_bank_type"),
|
||||
],
|
||||
name="MemoryBank",
|
||||
)
|
||||
|
||||
|
||||
class MemoryBankInput(BaseModel):
|
||||
memory_bank_id: str
|
||||
params: BankParams
|
||||
provider_memory_bank_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListMemoryBanksResponse(BaseModel):
|
||||
data: List[MemoryBank]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class MemoryBanks(Protocol):
|
||||
@webmethod(route="/memory-banks", method="GET")
|
||||
async def list_memory_banks(self) -> ListMemoryBanksResponse: ...
|
||||
|
||||
@webmethod(route="/memory-banks/{memory_bank_id}", method="GET")
|
||||
async def get_memory_bank(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
) -> Optional[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory-banks", method="POST")
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
params: BankParams,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_memory_bank_id: Optional[str] = None,
|
||||
) -> MemoryBank: ...
|
||||
|
||||
@webmethod(route="/memory-banks/{memory_bank_id}", method="DELETE")
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
|
@ -14,7 +14,7 @@ from pydantic import BaseModel, Field
|
|||
class ResourceType(Enum):
|
||||
model = "model"
|
||||
shield = "shield"
|
||||
memory_bank = "memory_bank"
|
||||
vector_db = "vector_db"
|
||||
dataset = "dataset"
|
||||
scoring_function = "scoring_function"
|
||||
eval_task = "eval_task"
|
||||
|
|
|
@ -4,4 +4,4 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .memory_banks import * # noqa: F401 F403
|
||||
from .vector_dbs import * # noqa: F401 F403
|
66
llama_stack/apis/vector_dbs/vector_dbs.py
Normal file
66
llama_stack/apis/vector_dbs/vector_dbs.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorDB(Resource):
|
||||
type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value
|
||||
|
||||
embedding_model: str
|
||||
embedding_dimension: int
|
||||
|
||||
@property
|
||||
def vector_db_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_vector_db_id(self) -> str:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
class VectorDBInput(BaseModel):
|
||||
vector_db_id: str
|
||||
embedding_model: str
|
||||
embedding_dimension: int
|
||||
provider_vector_db_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListVectorDBsResponse(BaseModel):
|
||||
data: List[VectorDB]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class VectorDBs(Protocol):
|
||||
@webmethod(route="/vector-dbs", method="GET")
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id}", method="GET")
|
||||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
) -> Optional[VectorDB]: ...
|
||||
|
||||
@webmethod(route="/vector-dbs", method="POST")
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: Optional[int] = 384,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> VectorDB: ...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE")
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
|
|
@ -4,4 +4,4 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .memory import * # noqa: F401 F403
|
||||
from .vector_io import * # noqa: F401 F403
|
|
@ -13,55 +13,45 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
|||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.memory_banks import MemoryBank
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBankDocument(BaseModel):
|
||||
document_id: str
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str | None = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
content: InterleavedContent
|
||||
token_count: int
|
||||
document_id: str
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryDocumentsResponse(BaseModel):
|
||||
class QueryChunksResponse(BaseModel):
|
||||
chunks: List[Chunk]
|
||||
scores: List[float]
|
||||
|
||||
|
||||
class MemoryBankStore(Protocol):
|
||||
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
||||
class VectorDBStore(Protocol):
|
||||
def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Memory(Protocol):
|
||||
memory_bank_store: MemoryBankStore
|
||||
class VectorIO(Protocol):
|
||||
vector_db_store: VectorDBStore
|
||||
|
||||
# this will just block now until documents are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
@webmethod(route="/memory/insert", method="POST")
|
||||
async def insert_documents(
|
||||
@webmethod(route="/vector-io/insert", method="POST")
|
||||
async def insert_chunks(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
vector_db_id: str,
|
||||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/query", method="POST")
|
||||
async def query_documents(
|
||||
@webmethod(route="/vector-io/query", method="POST")
|
||||
async def query_chunks(
|
||||
self,
|
||||
bank_id: str,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse: ...
|
||||
) -> QueryChunksResponse: ...
|
|
@ -13,14 +13,14 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
|
|||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankInput
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
|
||||
|
||||
|
@ -34,7 +34,7 @@ RoutingKey = Union[str, List[str]]
|
|||
RoutableObject = Union[
|
||||
Model,
|
||||
Shield,
|
||||
MemoryBank,
|
||||
VectorDB,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
EvalTask,
|
||||
|
@ -47,7 +47,7 @@ RoutableObjectWithProvider = Annotated[
|
|||
Union[
|
||||
Model,
|
||||
Shield,
|
||||
MemoryBank,
|
||||
VectorDB,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
EvalTask,
|
||||
|
@ -60,7 +60,7 @@ RoutableObjectWithProvider = Annotated[
|
|||
RoutedProtocol = Union[
|
||||
Inference,
|
||||
Safety,
|
||||
Memory,
|
||||
VectorIO,
|
||||
DatasetIO,
|
||||
Scoring,
|
||||
Eval,
|
||||
|
@ -153,7 +153,7 @@ a default SQLite store will be used.""",
|
|||
# registry of "resources" in the distribution
|
||||
models: List[ModelInput] = Field(default_factory=list)
|
||||
shields: List[ShieldInput] = Field(default_factory=list)
|
||||
memory_banks: List[MemoryBankInput] = Field(default_factory=list)
|
||||
vector_dbs: List[VectorDBInput] = Field(default_factory=list)
|
||||
datasets: List[DatasetInput] = Field(default_factory=list)
|
||||
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
|
||||
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
|
||||
|
|
|
@ -32,8 +32,8 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
|||
router_api=Api.safety,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.memory_banks,
|
||||
router_api=Api.memory,
|
||||
routing_table_api=Api.vector_dbs,
|
||||
router_api=Api.vector_io,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.datasets,
|
||||
|
|
|
@ -15,8 +15,6 @@ from llama_stack.apis.eval import Eval
|
|||
from llama_stack.apis.eval_tasks import EvalTasks
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
@ -25,6 +23,8 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
|||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.client import get_client_impl
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AutoRoutedProviderSpec,
|
||||
|
@ -40,7 +40,6 @@ from llama_stack.providers.datatypes import (
|
|||
DatasetsProtocolPrivate,
|
||||
EvalTasksProtocolPrivate,
|
||||
InlineProviderSpec,
|
||||
MemoryBanksProtocolPrivate,
|
||||
ModelsProtocolPrivate,
|
||||
ProviderSpec,
|
||||
RemoteProviderConfig,
|
||||
|
@ -48,6 +47,7 @@ from llama_stack.providers.datatypes import (
|
|||
ScoringFunctionsProtocolPrivate,
|
||||
ShieldsProtocolPrivate,
|
||||
ToolsProtocolPrivate,
|
||||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -62,8 +62,8 @@ def api_protocol_map() -> Dict[Api, Any]:
|
|||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
Api.memory: Memory,
|
||||
Api.memory_banks: MemoryBanks,
|
||||
Api.vector_io: VectorIO,
|
||||
Api.vector_dbs: VectorDBs,
|
||||
Api.models: Models,
|
||||
Api.safety: Safety,
|
||||
Api.shields: Shields,
|
||||
|
@ -84,7 +84,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
|
|||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
|
||||
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
|
||||
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||
Api.scoring: (
|
||||
|
|
|
@ -12,13 +12,6 @@ from llama_stack.apis.common.content_types import URL
|
|||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
|
||||
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
|
||||
from llama_stack.apis.memory_banks import (
|
||||
BankParams,
|
||||
ListMemoryBanksResponse,
|
||||
MemoryBank,
|
||||
MemoryBanks,
|
||||
MemoryBankType,
|
||||
)
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
|
@ -36,6 +29,7 @@ from llama_stack.apis.tools import (
|
|||
ToolGroups,
|
||||
ToolHost,
|
||||
)
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.distribution.datatypes import (
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
|
@ -59,8 +53,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
return await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
return await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
return await p.register_memory_bank(obj)
|
||||
elif api == Api.vector_io:
|
||||
return await p.register_vector_db(obj)
|
||||
elif api == Api.datasetio:
|
||||
return await p.register_dataset(obj)
|
||||
elif api == Api.scoring:
|
||||
|
@ -75,8 +69,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
|
||||
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.memory:
|
||||
return await p.unregister_memory_bank(obj.identifier)
|
||||
if api == Api.vector_io:
|
||||
return await p.unregister_vector_db(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
elif api == Api.datasetio:
|
||||
|
@ -120,8 +114,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
p.model_store = self
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
elif api == Api.vector_io:
|
||||
p.vector_db_store = self
|
||||
elif api == Api.datasetio:
|
||||
p.dataset_store = self
|
||||
elif api == Api.scoring:
|
||||
|
@ -145,8 +139,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return ("Inference", "model")
|
||||
elif isinstance(self, ShieldsRoutingTable):
|
||||
return ("Safety", "shield")
|
||||
elif isinstance(self, MemoryBanksRoutingTable):
|
||||
return ("Memory", "memory_bank")
|
||||
elif isinstance(self, VectorDBsRoutingTable):
|
||||
return ("VectorIO", "vector_db")
|
||||
elif isinstance(self, DatasetsRoutingTable):
|
||||
return ("DatasetIO", "dataset")
|
||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||
|
@ -196,9 +190,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
async def register_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider:
|
||||
# Get existing objects from registry
|
||||
existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
|
||||
|
||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
@ -311,22 +302,23 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
return shield
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> ListMemoryBanksResponse:
|
||||
return ListMemoryBanksResponse(data=await self.get_all_with_type("memory_bank"))
|
||||
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
|
||||
|
||||
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
|
||||
return await self.get_object_by_identifier("memory_bank", memory_bank_id)
|
||||
async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
|
||||
return await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||
|
||||
async def register_memory_bank(
|
||||
async def register_vector_db(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
params: BankParams,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: Optional[int] = 384,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_memory_bank_id: Optional[str] = None,
|
||||
) -> MemoryBank:
|
||||
if provider_memory_bank_id is None:
|
||||
provider_memory_bank_id = memory_bank_id
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> VectorDB:
|
||||
if provider_vector_db_id is None:
|
||||
provider_vector_db_id = vector_db_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this shield type
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
|
@ -335,44 +327,39 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
model = await self.get_object_by_identifier("model", params.embedding_model)
|
||||
model = await self.get_object_by_identifier("model", embedding_model)
|
||||
if model is None:
|
||||
if params.embedding_model == "all-MiniLM-L6-v2":
|
||||
if embedding_model == "all-MiniLM-L6-v2":
|
||||
raise ValueError(
|
||||
"Embeddings are now served via Inference providers. "
|
||||
"Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
|
||||
"See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model {params.embedding_model} not found")
|
||||
raise ValueError(f"Model {embedding_model} not found")
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ValueError(
|
||||
f"Model {params.embedding_model} is not an embedding model"
|
||||
)
|
||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(
|
||||
f"Model {params.embedding_model} does not have an embedding dimension"
|
||||
f"Model {embedding_model} does not have an embedding dimension"
|
||||
)
|
||||
memory_bank_data = {
|
||||
"identifier": memory_bank_id,
|
||||
"type": ResourceType.memory_bank.value,
|
||||
vector_db_data = {
|
||||
"identifier": vector_db_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": provider_memory_bank_id,
|
||||
**params.model_dump(),
|
||||
"provider_resource_id": provider_vector_db_id,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
}
|
||||
if params.memory_bank_type == MemoryBankType.vector.value:
|
||||
memory_bank_data["embedding_dimension"] = model.metadata[
|
||||
"embedding_dimension"
|
||||
]
|
||||
memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
|
||||
await self.register_object(memory_bank)
|
||||
return memory_bank
|
||||
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
existing_bank = await self.get_memory_bank(memory_bank_id)
|
||||
if existing_bank is None:
|
||||
raise ValueError(f"Memory bank {memory_bank_id} not found")
|
||||
await self.unregister_object(existing_bank)
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
existing_vector_db = await self.get_vector_db(vector_db_id)
|
||||
if existing_vector_db is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
await self.unregister_object(existing_vector_db)
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
|
|
|
@ -21,8 +21,6 @@ from llama_stack.apis.eval import Eval
|
|||
from llama_stack.apis.eval_tasks import EvalTasks
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
@ -32,6 +30,8 @@ from llama_stack.apis.shields import Shields
|
|||
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
|
@ -42,7 +42,7 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class LlamaStack(
|
||||
MemoryBanks,
|
||||
VectorDBs,
|
||||
Inference,
|
||||
BatchInference,
|
||||
Agents,
|
||||
|
@ -51,7 +51,7 @@ class LlamaStack(
|
|||
Datasets,
|
||||
Telemetry,
|
||||
PostTraining,
|
||||
Memory,
|
||||
VectorIO,
|
||||
Eval,
|
||||
EvalTasks,
|
||||
Scoring,
|
||||
|
@ -69,7 +69,7 @@ class LlamaStack(
|
|||
RESOURCES = [
|
||||
("models", Api.models, "register_model", "list_models"),
|
||||
("shields", Api.shields, "register_shield", "list_shields"),
|
||||
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
|
||||
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
|
||||
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
||||
(
|
||||
"scoring_fns",
|
||||
|
|
|
@ -14,11 +14,11 @@ from llama_stack.apis.datasets import Dataset
|
|||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.eval_tasks import EvalTask
|
||||
from llama_stack.apis.memory_banks.memory_banks import MemoryBank
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.tools import Tool
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
|
||||
|
||||
class ModelsProtocolPrivate(Protocol):
|
||||
|
@ -31,10 +31,10 @@ class ShieldsProtocolPrivate(Protocol):
|
|||
async def register_shield(self, shield: Shield) -> None: ...
|
||||
|
||||
|
||||
class MemoryBanksProtocolPrivate(Protocol):
|
||||
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
|
||||
class VectorDBsProtocolPrivate(Protocol):
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
|
||||
|
||||
|
||||
class DatasetsProtocolPrivate(Protocol):
|
||||
|
|
|
@ -38,78 +38,78 @@ EMBEDDING_DEPS = [
|
|||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
module="llama_stack.providers.inline.vector_io.faiss",
|
||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
|
||||
deprecation_warning="Please use the `inline::faiss` provider instead.",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::faiss",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
module="llama_stack.providers.inline.vector_io.faiss",
|
||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="chromadb",
|
||||
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
||||
module="llama_stack.providers.remote.memory.chroma",
|
||||
config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaRemoteImplConfig",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::chromadb",
|
||||
pip_packages=EMBEDDING_DEPS + ["chromadb"],
|
||||
module="llama_stack.providers.inline.memory.chroma",
|
||||
config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
|
||||
module="llama_stack.providers.inline.vector_io.chroma",
|
||||
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaInlineImplConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="pgvector",
|
||||
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
|
||||
module="llama_stack.providers.remote.memory.pgvector",
|
||||
config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorConfig",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="weaviate",
|
||||
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
|
||||
module="llama_stack.providers.remote.memory.weaviate",
|
||||
config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.memory,
|
||||
api=Api.vector_io,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.memory.sample",
|
||||
config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
|
||||
module="llama_stack.providers.remote.vector_io.sample",
|
||||
config_class="llama_stack.providers.remote.vector_io.sample.SampleConfig",
|
||||
),
|
||||
api_dependencies=[],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="qdrant",
|
||||
pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
|
||||
module="llama_stack.providers.remote.memory.qdrant",
|
||||
config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantConfig",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
Loading…
Add table
Add a link
Reference in a new issue