Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 02:53:30 +00:00.
[memory refactor][1/n] Rename Memory -> VectorIO, MemoryBanks -> VectorDBs (#828)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. This is the first part:
- delete the other kinds of memory banks (keyvalue, keyword, graph) for now; we will introduce a keyvalue store API as part of this design, but it is not used by the RAG tool yet.
- rename the APIs.
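At a glance, the rename amounts to the following mapping (a summary distilled from the diff below, not code from the commit):

# Old name -> new name, as introduced by this commit.
RENAMES = {
    # API enum values and routing tables
    "Api.memory": "Api.vector_io",
    "Api.memory_banks": "Api.vector_dbs",
    "ResourceType.memory_bank": "ResourceType.vector_db",
    # protocols and resource types
    "Memory": "VectorIO",
    "MemoryBanks": "VectorDBs",
    "MemoryBank": "VectorDB",
    # HTTP routes
    "/memory-banks": "/vector-dbs",
    "/memory/insert": "/vector-io/insert",
    "/memory/query": "/vector-io/query",
    # methods
    "insert_documents": "insert_chunks",
    "query_documents": "query_chunks",
}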
Parent: 35a00d004a
Commit: 3ae8585b65
37 changed files with 175 additions and 296 deletions
@@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import MemoryBank
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -133,8 +132,6 @@ class Session(BaseModel):
     turns: List[Turn]
     started_at: datetime
 
-    memory_bank: Optional[MemoryBank] = None
-
 
 class AgentToolGroupWithArgs(BaseModel):
     name: str
@@ -14,7 +14,7 @@ class Api(Enum):
     inference = "inference"
     safety = "safety"
     agents = "agents"
-    memory = "memory"
+    vector_io = "vector_io"
     datasetio = "datasetio"
     scoring = "scoring"
     eval = "eval"
@@ -25,7 +25,7 @@ class Api(Enum):
 
     models = "models"
     shields = "shields"
-    memory_banks = "memory_banks"
+    vector_dbs = "vector_dbs"
     datasets = "datasets"
     scoring_functions = "scoring_functions"
     eval_tasks = "eval_tasks"
@@ -1,161 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import Enum
-from typing import (
-    Annotated,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    runtime_checkable,
-    Union,
-)
-
-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
-from pydantic import BaseModel, Field
-
-from llama_stack.apis.resource import Resource, ResourceType
-from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
-
-
-@json_schema_type
-class MemoryBankType(Enum):
-    vector = "vector"
-    keyvalue = "keyvalue"
-    keyword = "keyword"
-    graph = "graph"
-
-
-# define params for each type of memory bank, this leads to a tagged union
-# accepted as input from the API or from the config.
-@json_schema_type
-class VectorMemoryBankParams(BaseModel):
-    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
-    embedding_model: str
-    chunk_size_in_tokens: int
-    overlap_size_in_tokens: Optional[int] = None
-
-
-@json_schema_type
-class KeyValueMemoryBankParams(BaseModel):
-    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
-        MemoryBankType.keyvalue.value
-    )
-
-
-@json_schema_type
-class KeywordMemoryBankParams(BaseModel):
-    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
-        MemoryBankType.keyword.value
-    )
-
-
-@json_schema_type
-class GraphMemoryBankParams(BaseModel):
-    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
-
-
-BankParams = Annotated[
-    Union[
-        VectorMemoryBankParams,
-        KeyValueMemoryBankParams,
-        KeywordMemoryBankParams,
-        GraphMemoryBankParams,
-    ],
-    Field(discriminator="memory_bank_type"),
-]
-
-
-# Some common functionality for memory banks.
-class MemoryBankResourceMixin(Resource):
-    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
-
-    @property
-    def memory_bank_id(self) -> str:
-        return self.identifier
-
-    @property
-    def provider_memory_bank_id(self) -> str:
-        return self.provider_resource_id
-
-
-@json_schema_type
-class VectorMemoryBank(MemoryBankResourceMixin):
-    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
-    embedding_model: str
-    chunk_size_in_tokens: int
-    embedding_dimension: Optional[int] = 384  # default to minilm-l6-v2
-    overlap_size_in_tokens: Optional[int] = None
-
-
-@json_schema_type
-class KeyValueMemoryBank(MemoryBankResourceMixin):
-    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
-        MemoryBankType.keyvalue.value
-    )
-
-
-# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
-@json_schema_type
-class KeywordMemoryBank(MemoryBankResourceMixin):
-    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
-        MemoryBankType.keyword.value
-    )
-
-
-@json_schema_type
-class GraphMemoryBank(MemoryBankResourceMixin):
-    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
-
-
-MemoryBank = register_schema(
-    Annotated[
-        Union[
-            VectorMemoryBank,
-            KeyValueMemoryBank,
-            KeywordMemoryBank,
-            GraphMemoryBank,
-        ],
-        Field(discriminator="memory_bank_type"),
-    ],
-    name="MemoryBank",
-)
-
-
-class MemoryBankInput(BaseModel):
-    memory_bank_id: str
-    params: BankParams
-    provider_memory_bank_id: Optional[str] = None
-
-
-class ListMemoryBanksResponse(BaseModel):
-    data: List[MemoryBank]
-
-
-@runtime_checkable
-@trace_protocol
-class MemoryBanks(Protocol):
-    @webmethod(route="/memory-banks", method="GET")
-    async def list_memory_banks(self) -> ListMemoryBanksResponse: ...
-
-    @webmethod(route="/memory-banks/{memory_bank_id}", method="GET")
-    async def get_memory_bank(
-        self,
-        memory_bank_id: str,
-    ) -> Optional[MemoryBank]: ...
-
-    @webmethod(route="/memory-banks", method="POST")
-    async def register_memory_bank(
-        self,
-        memory_bank_id: str,
-        params: BankParams,
-        provider_id: Optional[str] = None,
-        provider_memory_bank_id: Optional[str] = None,
-    ) -> MemoryBank: ...
-
-    @webmethod(route="/memory-banks/{memory_bank_id}", method="DELETE")
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
@@ -14,7 +14,7 @@ from pydantic import BaseModel, Field
 class ResourceType(Enum):
     model = "model"
     shield = "shield"
-    memory_bank = "memory_bank"
+    vector_db = "vector_db"
     dataset = "dataset"
     scoring_function = "scoring_function"
     eval_task = "eval_task"
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .memory_banks import *  # noqa: F401 F403
+from .vector_dbs import *  # noqa: F401 F403
llama_stack/apis/vector_dbs/vector_dbs.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List, Literal, Optional, Protocol, runtime_checkable
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel
+
+from llama_stack.apis.resource import Resource, ResourceType
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+
+
+@json_schema_type
+class VectorDB(Resource):
+    type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value
+
+    embedding_model: str
+    embedding_dimension: int
+
+    @property
+    def vector_db_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_vector_db_id(self) -> str:
+        return self.provider_resource_id
+
+
+class VectorDBInput(BaseModel):
+    vector_db_id: str
+    embedding_model: str
+    embedding_dimension: int
+    provider_vector_db_id: Optional[str] = None
+
+
+class ListVectorDBsResponse(BaseModel):
+    data: List[VectorDB]
+
+
+@runtime_checkable
+@trace_protocol
+class VectorDBs(Protocol):
+    @webmethod(route="/vector-dbs", method="GET")
+    async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
+
+    @webmethod(route="/vector-dbs/{vector_db_id}", method="GET")
+    async def get_vector_db(
+        self,
+        vector_db_id: str,
+    ) -> Optional[VectorDB]: ...
+
+    @webmethod(route="/vector-dbs", method="POST")
+    async def register_vector_db(
+        self,
+        vector_db_id: str,
+        embedding_model: str,
+        embedding_dimension: Optional[int] = 384,
+        provider_id: Optional[str] = None,
+        provider_vector_db_id: Optional[str] = None,
+    ) -> VectorDB: ...
+
+    @webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE")
+    async def unregister_vector_db(self, vector_db_id: str) -> None: ...
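For orientation, a minimal usage sketch of the new VectorDBs registry surface; `vector_dbs` stands in for any implementation of the protocol (the routing table further down is one), and the ids and model name are illustrative:

async def demo(vector_dbs) -> None:
    # Register a vector DB; embedding_dimension defaults to 384.
    db = await vector_dbs.register_vector_db(
        vector_db_id="my-docs",
        embedding_model="all-MiniLM-L6-v2",
    )
    assert db.vector_db_id == "my-docs"

    # Retrieve it again and enumerate everything registered.
    assert await vector_dbs.get_vector_db("my-docs") is not None
    print([v.identifier for v in (await vector_dbs.list_vector_dbs()).data])

    # Tear it down when done.
    await vector_dbs.unregister_vector_db("my-docs")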
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .memory import *  # noqa: F401 F403
+from .vector_io import *  # noqa: F401 F403
@@ -13,55 +13,45 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 
-from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.memory_banks import MemoryBank
+from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
 
-@json_schema_type
-class MemoryBankDocument(BaseModel):
-    document_id: str
-    content: InterleavedContent | URL
-    mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
-
-
 class Chunk(BaseModel):
     content: InterleavedContent
-    token_count: int
-    document_id: str
+    metadata: Dict[str, Any] = Field(default_factory=dict)
 
 
 @json_schema_type
-class QueryDocumentsResponse(BaseModel):
+class QueryChunksResponse(BaseModel):
     chunks: List[Chunk]
     scores: List[float]
 
 
-class MemoryBankStore(Protocol):
-    def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
+class VectorDBStore(Protocol):
+    def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
 
 
 @runtime_checkable
 @trace_protocol
-class Memory(Protocol):
-    memory_bank_store: MemoryBankStore
+class VectorIO(Protocol):
+    vector_db_store: VectorDBStore
 
     # this will just block now until documents are inserted, but it should
     # probably return a Job instance which can be polled for completion
-    @webmethod(route="/memory/insert", method="POST")
-    async def insert_documents(
+    @webmethod(route="/vector-io/insert", method="POST")
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None: ...
 
-    @webmethod(route="/memory/query", method="POST")
-    async def query_documents(
+    @webmethod(route="/vector-io/query", method="POST")
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse: ...
+    ) -> QueryChunksResponse: ...
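And the corresponding chunk-level I/O, sketched against the renamed VectorIO protocol; `vector_io` is again a stand-in implementation, and carrying the document id in `metadata` is just one possible convention now that Chunk has free-form metadata instead of token_count/document_id:

from llama_stack.apis.vector_io import Chunk

async def roundtrip(vector_io) -> None:
    # Insert a single chunk; the document id travels in metadata.
    await vector_io.insert_chunks(
        vector_db_id="my-docs",
        chunks=[
            Chunk(
                content="The llama is a domesticated South American camelid.",
                metadata={"document_id": "doc-1"},
            )
        ],
    )
    # Query returns parallel lists of chunks and similarity scores.
    resp = await vector_io.query_chunks(vector_db_id="my-docs", query="What is a llama?")
    for chunk, score in zip(resp.chunks, resp.scores):
        print(score, chunk.metadata)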
@@ -13,14 +13,14 @@ from llama_stack.apis.datasets import Dataset, DatasetInput
 from llama_stack.apis.eval import Eval
 from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
 from llama_stack.apis.inference import Inference
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankInput
 from llama_stack.apis.models import Model, ModelInput
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
 from llama_stack.apis.shields import Shield, ShieldInput
 from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
+from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.datatypes import Api, ProviderSpec
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig
@@ -34,7 +34,7 @@ RoutingKey = Union[str, List[str]]
 RoutableObject = Union[
     Model,
     Shield,
-    MemoryBank,
+    VectorDB,
     Dataset,
     ScoringFn,
     EvalTask,
@@ -47,7 +47,7 @@ RoutableObjectWithProvider = Annotated[
     Union[
         Model,
         Shield,
-        MemoryBank,
+        VectorDB,
         Dataset,
         ScoringFn,
         EvalTask,
@@ -60,7 +60,7 @@ RoutableObjectWithProvider = Annotated[
 RoutedProtocol = Union[
     Inference,
     Safety,
-    Memory,
+    VectorIO,
     DatasetIO,
     Scoring,
     Eval,
@@ -153,7 +153,7 @@ a default SQLite store will be used.""",
     # registry of "resources" in the distribution
     models: List[ModelInput] = Field(default_factory=list)
     shields: List[ShieldInput] = Field(default_factory=list)
-    memory_banks: List[MemoryBankInput] = Field(default_factory=list)
+    vector_dbs: List[VectorDBInput] = Field(default_factory=list)
     datasets: List[DatasetInput] = Field(default_factory=list)
     scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
     eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
@@ -32,8 +32,8 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
             router_api=Api.safety,
         ),
         AutoRoutedApiInfo(
-            routing_table_api=Api.memory_banks,
-            router_api=Api.memory,
+            routing_table_api=Api.vector_dbs,
+            router_api=Api.vector_io,
         ),
         AutoRoutedApiInfo(
             routing_table_api=Api.datasets,
@@ -15,8 +15,6 @@ from llama_stack.apis.eval import Eval
 from llama_stack.apis.eval_tasks import EvalTasks
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.safety import Safety
@@ -25,6 +23,8 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_dbs import VectorDBs
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.distribution.client import get_client_impl
 from llama_stack.distribution.datatypes import (
     AutoRoutedProviderSpec,
@@ -40,7 +40,6 @@ from llama_stack.providers.datatypes import (
     DatasetsProtocolPrivate,
     EvalTasksProtocolPrivate,
     InlineProviderSpec,
-    MemoryBanksProtocolPrivate,
     ModelsProtocolPrivate,
     ProviderSpec,
     RemoteProviderConfig,
@@ -48,6 +47,7 @@ from llama_stack.providers.datatypes import (
     ScoringFunctionsProtocolPrivate,
     ShieldsProtocolPrivate,
     ToolsProtocolPrivate,
+    VectorDBsProtocolPrivate,
 )
 
 log = logging.getLogger(__name__)
@@ -62,8 +62,8 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.agents: Agents,
         Api.inference: Inference,
         Api.inspect: Inspect,
-        Api.memory: Memory,
-        Api.memory_banks: MemoryBanks,
+        Api.vector_io: VectorIO,
+        Api.vector_dbs: VectorDBs,
         Api.models: Models,
         Api.safety: Safety,
         Api.shields: Shields,
@@ -84,7 +84,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
         Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
-        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
+        Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
         Api.scoring: (
@@ -12,13 +12,6 @@ from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
 from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
-from llama_stack.apis.memory_banks import (
-    BankParams,
-    ListMemoryBanksResponse,
-    MemoryBank,
-    MemoryBanks,
-    MemoryBankType,
-)
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.scoring_functions import (
@@ -36,6 +29,7 @@ from llama_stack.apis.tools import (
     ToolGroups,
     ToolHost,
 )
+from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
 from llama_stack.distribution.datatypes import (
     RoutableObject,
     RoutableObjectWithProvider,
@@ -59,8 +53,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
         return await p.register_model(obj)
     elif api == Api.safety:
         return await p.register_shield(obj)
-    elif api == Api.memory:
-        return await p.register_memory_bank(obj)
+    elif api == Api.vector_io:
+        return await p.register_vector_db(obj)
     elif api == Api.datasetio:
         return await p.register_dataset(obj)
     elif api == Api.scoring:
@@ -75,8 +69,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
 
 async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)
-    if api == Api.memory:
-        return await p.unregister_memory_bank(obj.identifier)
+    if api == Api.vector_io:
+        return await p.unregister_vector_db(obj.identifier)
     elif api == Api.inference:
         return await p.unregister_model(obj.identifier)
     elif api == Api.datasetio:
@@ -120,8 +114,8 @@ class CommonRoutingTableImpl(RoutingTable):
                 p.model_store = self
             elif api == Api.safety:
                 p.shield_store = self
-            elif api == Api.memory:
-                p.memory_bank_store = self
+            elif api == Api.vector_io:
+                p.vector_db_store = self
             elif api == Api.datasetio:
                 p.dataset_store = self
             elif api == Api.scoring:
@@ -145,8 +139,8 @@ class CommonRoutingTableImpl(RoutingTable):
             return ("Inference", "model")
         elif isinstance(self, ShieldsRoutingTable):
             return ("Safety", "shield")
-        elif isinstance(self, MemoryBanksRoutingTable):
-            return ("Memory", "memory_bank")
+        elif isinstance(self, VectorDBsRoutingTable):
+            return ("VectorIO", "vector_db")
         elif isinstance(self, DatasetsRoutingTable):
             return ("DatasetIO", "dataset")
         elif isinstance(self, ScoringFunctionsRoutingTable):
@@ -196,9 +190,6 @@ class CommonRoutingTableImpl(RoutingTable):
     async def register_object(
         self, obj: RoutableObjectWithProvider
     ) -> RoutableObjectWithProvider:
-        # Get existing objects from registry
-        existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
-
         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
             obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -311,22 +302,23 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
         return shield
 
 
-class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
-    async def list_memory_banks(self) -> ListMemoryBanksResponse:
-        return ListMemoryBanksResponse(data=await self.get_all_with_type("memory_bank"))
+class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
+    async def list_vector_dbs(self) -> ListVectorDBsResponse:
+        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
 
-    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
-        return await self.get_object_by_identifier("memory_bank", memory_bank_id)
+    async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
+        return await self.get_object_by_identifier("vector_db", vector_db_id)
 
-    async def register_memory_bank(
+    async def register_vector_db(
         self,
-        memory_bank_id: str,
-        params: BankParams,
+        vector_db_id: str,
+        embedding_model: str,
+        embedding_dimension: Optional[int] = 384,
         provider_id: Optional[str] = None,
-        provider_memory_bank_id: Optional[str] = None,
-    ) -> MemoryBank:
-        if provider_memory_bank_id is None:
-            provider_memory_bank_id = memory_bank_id
+        provider_vector_db_id: Optional[str] = None,
+    ) -> VectorDB:
+        if provider_vector_db_id is None:
+            provider_vector_db_id = vector_db_id
         if provider_id is None:
             # If provider_id not specified, use the only provider if it supports this shield type
             if len(self.impls_by_provider_id) == 1:
@@ -335,44 +327,39 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
                 raise ValueError(
                     "No provider specified and multiple providers available. Please specify a provider_id."
                 )
-        model = await self.get_object_by_identifier("model", params.embedding_model)
+        model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
-            if params.embedding_model == "all-MiniLM-L6-v2":
+            if embedding_model == "all-MiniLM-L6-v2":
                 raise ValueError(
                     "Embeddings are now served via Inference providers. "
                     "Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
                     "See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
                 )
             else:
-                raise ValueError(f"Model {params.embedding_model} not found")
+                raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
-            raise ValueError(
-                f"Model {params.embedding_model} is not an embedding model"
-            )
+            raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
             raise ValueError(
-                f"Model {params.embedding_model} does not have an embedding dimension"
+                f"Model {embedding_model} does not have an embedding dimension"
             )
-        memory_bank_data = {
-            "identifier": memory_bank_id,
-            "type": ResourceType.memory_bank.value,
+        vector_db_data = {
+            "identifier": vector_db_id,
+            "type": ResourceType.vector_db.value,
             "provider_id": provider_id,
-            "provider_resource_id": provider_memory_bank_id,
-            **params.model_dump(),
+            "provider_resource_id": provider_vector_db_id,
+            "embedding_model": embedding_model,
+            "embedding_dimension": model.metadata["embedding_dimension"],
         }
-        if params.memory_bank_type == MemoryBankType.vector.value:
-            memory_bank_data["embedding_dimension"] = model.metadata[
-                "embedding_dimension"
-            ]
-        memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
-        await self.register_object(memory_bank)
-        return memory_bank
+        vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
+        await self.register_object(vector_db)
+        return vector_db
 
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
-        existing_bank = await self.get_memory_bank(memory_bank_id)
-        if existing_bank is None:
-            raise ValueError(f"Memory bank {memory_bank_id} not found")
-        await self.unregister_object(existing_bank)
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        existing_vector_db = await self.get_vector_db(vector_db_id)
+        if existing_vector_db is None:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
+        await self.unregister_object(existing_vector_db)
 
 
 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
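The routing-table hunk above is the one place with behavior beyond renaming: `embedding_dimension` is no longer caller-supplied via `BankParams` but read from the registered embedding model's metadata. A sketch of the precondition this implies, with hypothetical `models` and `vector_dbs` client objects:

async def setup(models, vector_dbs) -> None:
    # register_vector_db raises ValueError unless the embedding model is
    # already registered as an embedding model whose metadata carries
    # "embedding_dimension" -- so the model must be registered first.
    await models.register_model(
        model_id="all-MiniLM-L6-v2",
        model_type="embedding",  # i.e. ModelType.embedding
        metadata={"embedding_dimension": 384},
    )
    # embedding_dimension on the resulting VectorDB comes from the model's
    # metadata, not from the caller.
    db = await vector_dbs.register_vector_db(
        vector_db_id="my-docs",
        embedding_model="all-MiniLM-L6-v2",
    )
    print(db.embedding_dimension)  # -> 384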
@@ -21,8 +21,6 @@ from llama_stack.apis.eval import Eval
 from llama_stack.apis.eval_tasks import EvalTasks
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.safety import Safety
@@ -32,6 +30,8 @@ from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_dbs import VectorDBs
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
@@ -42,7 +42,7 @@ log = logging.getLogger(__name__)
 
 
 class LlamaStack(
-    MemoryBanks,
+    VectorDBs,
     Inference,
     BatchInference,
     Agents,
@@ -51,7 +51,7 @@ class LlamaStack(
     Datasets,
     Telemetry,
     PostTraining,
-    Memory,
+    VectorIO,
     Eval,
     EvalTasks,
     Scoring,
@@ -69,7 +69,7 @@ class LlamaStack(
 RESOURCES = [
     ("models", Api.models, "register_model", "list_models"),
     ("shields", Api.shields, "register_shield", "list_shields"),
-    ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
+    ("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
     ("datasets", Api.datasets, "register_dataset", "list_datasets"),
     (
         "scoring_fns",
@@ -14,11 +14,11 @@ from llama_stack.apis.datasets import Dataset
 
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.eval_tasks import EvalTask
-from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
 from llama_stack.apis.tools import Tool
+from llama_stack.apis.vector_dbs import VectorDB
 
 
 class ModelsProtocolPrivate(Protocol):
@@ -31,10 +31,10 @@ class ShieldsProtocolPrivate(Protocol):
     async def register_shield(self, shield: Shield) -> None: ...
 
 
-class MemoryBanksProtocolPrivate(Protocol):
-    async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
+class VectorDBsProtocolPrivate(Protocol):
+    async def register_vector_db(self, vector_db: VectorDB) -> None: ...
 
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
+    async def unregister_vector_db(self, vector_db_id: str) -> None: ...
 
 
 class DatasetsProtocolPrivate(Protocol):
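On the provider side, the private protocol shrinks to a register/unregister pair. A toy in-memory class showing the shape a provider must now satisfy (illustrative only, not a real provider):

from typing import Dict

from llama_stack.apis.vector_dbs import VectorDB

class ToyVectorDBProvider:
    """Structurally satisfies VectorDBsProtocolPrivate; for illustration only."""

    def __init__(self) -> None:
        self._dbs: Dict[str, VectorDB] = {}

    async def register_vector_db(self, vector_db: VectorDB) -> None:
        self._dbs[vector_db.identifier] = vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        self._dbs.pop(vector_db_id, None)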
@@ -38,78 +38,78 @@ EMBEDDING_DEPS = [
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
-            api=Api.memory,
+            api=Api.vector_io,
             provider_type="inline::meta-reference",
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
-            module="llama_stack.providers.inline.memory.faiss",
-            config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            module="llama_stack.providers.inline.vector_io.faiss",
+            config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
             api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
-            api=Api.memory,
+            api=Api.vector_io,
             provider_type="inline::faiss",
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
-            module="llama_stack.providers.inline.memory.faiss",
-            config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            module="llama_stack.providers.inline.vector_io.faiss",
+            config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
             api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
-            Api.memory,
+            Api.vector_io,
             AdapterSpec(
                 adapter_type="chromadb",
                 pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
-                module="llama_stack.providers.remote.memory.chroma",
-                config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
+                module="llama_stack.providers.remote.vector_io.chroma",
+                config_class="llama_stack.providers.remote.vector_io.chroma.ChromaRemoteImplConfig",
             ),
             api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
-            api=Api.memory,
+            api=Api.vector_io,
             provider_type="inline::chromadb",
             pip_packages=EMBEDDING_DEPS + ["chromadb"],
-            module="llama_stack.providers.inline.memory.chroma",
-            config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
+            module="llama_stack.providers.inline.vector_io.chroma",
+            config_class="llama_stack.providers.inline.vector_io.chroma.ChromaInlineImplConfig",
             api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
-            Api.memory,
+            Api.vector_io,
             AdapterSpec(
                 adapter_type="pgvector",
                 pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
-                module="llama_stack.providers.remote.memory.pgvector",
-                config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
+                module="llama_stack.providers.remote.vector_io.pgvector",
+                config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorConfig",
             ),
             api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
-            Api.memory,
+            Api.vector_io,
             AdapterSpec(
                 adapter_type="weaviate",
                 pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
-                module="llama_stack.providers.remote.memory.weaviate",
-                config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
-                provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
+                module="llama_stack.providers.remote.vector_io.weaviate",
+                config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateConfig",
+                provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
             ),
             api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
-            api=Api.memory,
+            api=Api.vector_io,
             adapter=AdapterSpec(
                 adapter_type="sample",
                 pip_packages=[],
-                module="llama_stack.providers.remote.memory.sample",
-                config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
+                module="llama_stack.providers.remote.vector_io.sample",
+                config_class="llama_stack.providers.remote.vector_io.sample.SampleConfig",
            ),
             api_dependencies=[],
         ),
         remote_provider_spec(
-            Api.memory,
+            Api.vector_io,
             AdapterSpec(
                 adapter_type="qdrant",
                 pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
-                module="llama_stack.providers.remote.memory.qdrant",
-                config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
+                module="llama_stack.providers.remote.vector_io.qdrant",
+                config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantConfig",
             ),
             api_dependencies=[Api.inference],
         ),