Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-18 16:29:47 +00:00)
remove mixin and test fixes
This commit is contained in:
parent 5bbeb985ca
commit 0e451525e5

9 changed files with 140 additions and 69 deletions
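Every hunk below makes the same substitution: the InferenceEmbeddingMixin base class is dropped from the four remote memory adapters (Chroma, PGVector, Qdrant, Weaviate), and each call to the mixin's _create_bank_with_index helper becomes a direct BankWithIndex construction passing the adapter's own inference_api. A minimal sketch of the before/after shapes, with the mixin helper reconstructed from its call sites in this diff rather than copied from vector_store.py:

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class BankWithIndex:
        # Fields inferred from the call sites below: (bank, index, inference_api).
        bank: Any
        index: Any
        inference_api: Any

    class InferenceEmbeddingMixin:
        inference_api: Any  # set by each adapter's __init__

        def _create_bank_with_index(self, bank: Any, index: Any) -> BankWithIndex:
            # The helper only spared callers from spelling out inference_api.
            return BankWithIndex(bank, index, self.inference_api)

After the change, each adapter writes self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api) explicitly, which keeps the inference dependency visible at every call site.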
@@ -15,12 +15,10 @@ from numpy.typing import NDArray
 from pydantic import parse_obj_as
 
 from llama_stack.apis.memory import *  # noqa: F403
-
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )
 
 log = logging.getLogger(__name__)
@@ -72,7 +70,7 @@ class ChromaIndex(EmbeddingIndex):
         await self.client.delete_collection(self.collection.name)
 
 
-class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
+class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, url: str, inference_api: Api.inference) -> None:
         log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
         url = url.rstrip("/")
@@ -111,8 +109,8 @@ class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
             name=memory_bank.identifier,
             metadata={"bank": memory_bank.model_dump_json()},
         )
-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
-            memory_bank, ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, ChromaIndex(self.client, collection), self.inference_api
         )
 
     async def list_memory_banks(self) -> List[MemoryBank]:
@@ -125,9 +123,10 @@ class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
                 log.exception(f"Failed to parse bank: {collection.metadata}")
                 continue
 
-            self.cache[bank.identifier] = self._create_bank_with_index(
+            self.cache[bank.identifier] = BankWithIndex(
                 bank,
                 ChromaIndex(self.client, collection),
+                self.inference_api,
             )
 
         return [i.bank for i in self.cache.values()]
@@ -166,6 +165,8 @@ class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
         collection = await self.client.get_collection(bank_id)
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = self._create_bank_with_index(bank, ChromaIndex(self.client, collection))
+        index = BankWithIndex(
+            bank, ChromaIndex(self.client, collection), self.inference_api
+        )
         self.cache[bank_id] = index
         return index
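All four Chroma call sites now feed a ChromaIndex into BankWithIndex. Distilled into a standalone coroutine, the lookup-then-cache pattern of the last hunk reads as follows (a sketch: the function name and the in-cache fast path are illustrative, not from the repository):

    async def get_or_build_index(cache, client, bank, bank_id, inference_api):
        if bank_id in cache:  # assumed fast path: reuse a previously built entry
            return cache[bank_id]
        collection = await client.get_collection(bank_id)
        if not collection:
            raise ValueError(f"Bank {bank_id} not found in Chroma")
        index = BankWithIndex(bank, ChromaIndex(client, collection), inference_api)
        cache[bank_id] = index  # memoize so later calls skip the Chroma round-trip
        return index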
@@ -21,7 +21,6 @@ from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )
 
 from .config import PGVectorConfig
@@ -120,9 +119,7 @@ class PGVectorIndex(EmbeddingIndex):
         self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
 
 
-class PGVectorMemoryAdapter(
-    InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
-):
+class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api
@@ -171,8 +168,8 @@ class PGVectorMemoryAdapter(
 
         upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
         index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
-            memory_bank, index
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, index, self.inference_api
         )
 
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
@@ -183,9 +180,10 @@ class PGVectorMemoryAdapter(
         banks = load_models(self.cursor, VectorMemoryBank)
         for bank in banks:
             if bank.identifier not in self.cache:
-                index = self._create_bank_with_index(
+                index = BankWithIndex(
                     bank,
                     PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
+                    self.inference_api,
                 )
                 self.cache[bank.identifier] = index
         return banks
@@ -216,5 +214,5 @@ class PGVectorMemoryAdapter(
 
         bank = await self.memory_bank_store.get_memory_bank(bank_id)
         index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
-        self.cache[bank_id] = self._create_bank_with_index(bank, index)
+        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
         return self.cache[bank_id]
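The PGVector adapter applies the same substitution at four call sites, always building the index as PGVectorIndex(bank, bank.embedding_dimension, self.cursor) first. A practical payoff, given the test fixes in the commit title: with the mixin gone, a BankWithIndex can be assembled around a stub inference API. A hypothetical test-style sketch (FakeInference and its embeddings method are assumptions for illustration, not APIs shown in this diff):

    class FakeInference:
        async def embeddings(self, *args, **kwargs):  # assumed method name
            raise NotImplementedError("stub: tests should not hit real inference")

    def make_test_bank(bank, cursor):
        # Mirrors the adapter's construction, but with the stub injected.
        index = PGVectorIndex(bank, bank.embedding_dimension, cursor)
        return BankWithIndex(bank, index, FakeInference())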
@@ -21,7 +21,6 @@ from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )
 
 log = logging.getLogger(__name__)
@@ -101,9 +100,7 @@ class QdrantIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class QdrantVectorMemoryAdapter(
-    InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
-):
+class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
@@ -124,9 +121,10 @@ class QdrantVectorMemoryAdapter(
             memory_bank.memory_bank_type == MemoryBankType.vector
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        index = self._create_bank_with_index(
+        index = BankWithIndex(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
+            inference_api=self.inference_api,
         )
 
         self.cache[memory_bank.identifier] = index
@@ -144,9 +142,10 @@ class QdrantVectorMemoryAdapter(
         if not bank:
             raise ValueError(f"Bank {bank_id} not found")
 
-        index = self._create_bank_with_index(
+        index = BankWithIndex(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
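The Qdrant adapter passes the same three values by keyword. Note that InferenceEmbeddingMixin always sat first in each adapter's base list, so under Python's method resolution order its attributes shadowed same-named ones from later bases; removing it also removes that shadowing potential. A toy MRO illustration (the names here are invented for the example):

    class Mixin:
        def helper(self):
            return "mixin"

    class Base:
        def helper(self):
            return "base"

    class WithMixin(Mixin, Base):
        pass

    class WithoutMixin(Base):
        pass

    print(WithMixin().helper())     # "mixin": the first base wins the MRO
    print(WithoutMixin().helper())  # "base": nothing left to shadow it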
@@ -19,7 +19,6 @@ from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )
 
 from .config import WeaviateConfig, WeaviateRequestProviderData
@@ -83,7 +82,6 @@ class WeaviateIndex(EmbeddingIndex):
 
 
 class WeaviateMemoryAdapter(
-    InferenceEmbeddingMixin,
     Memory,
     NeedsRequestProviderData,
     MemoryBanksProtocolPrivate,
@@ -140,9 +138,10 @@ class WeaviateMemoryAdapter(
             ],
         )
 
-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+        self.cache[memory_bank.identifier] = BankWithIndex(
             memory_bank,
             WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+            self.inference_api,
         )
 
     async def list_memory_banks(self) -> List[MemoryBank]:
@@ -164,9 +163,10 @@ class WeaviateMemoryAdapter(
         if not client.collections.exists(bank.identifier):
             raise ValueError(f"Collection with name `{bank.identifier}` not found")
 
-        index = self._create_bank_with_index(
+        index = BankWithIndex(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
+            inference_api=self.inference_api,
        )
         self.cache[bank_id] = index
         return index
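Weaviate is the only adapter with extra bases (it also mixes in NeedsRequestProviderData to read per-request provider data), so only the mixin line is dropped from its multi-line class statement. Reconstructed by deleting the removed line from the hunk above, the post-commit statement reads (body elided):

    class WeaviateMemoryAdapter(
        Memory,
        NeedsRequestProviderData,
        MemoryBanksProtocolPrivate,
    ):
        ...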