Mirror of https://github.com/meta-llama/llama-stack.git

commit 0e451525e5 (parent 5bbeb985ca)

    remove mixin and test fixes

9 changed files with 140 additions and 69 deletions
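The same mechanical change repeats across all five memory providers: the `InferenceEmbeddingMixin` helper is deleted from `llama_stack/providers/utils/memory/vector_store.py`, and each provider now constructs `BankWithIndex` directly, passing its own `inference_api` handle as the extra argument. A condensed before/after sketch of the pattern, assembled from the faiss hunks below (not a standalone module; `_load_bank` is an illustrative stand-in for the real call sites in `initialize` and `register_memory_bank`):

    # Before: the provider inherited the mixin and used its factory helper.
    class FaissMemoryImpl(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
        async def _load_bank(self, bank):
            faiss_index = await FaissIndex.create(
                bank.embedding_dimension, self.kvstore, bank.identifier
            )
            return self._create_bank_with_index(bank, faiss_index)

    # After: the provider builds BankWithIndex itself, threading inference_api through.
    class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
        async def _load_bank(self, bank):
            faiss_index = await FaissIndex.create(
                bank.embedding_dimension, self.kvstore, bank.identifier
            )
            return BankWithIndex(bank, faiss_index, self.inference_api)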
llama_stack/providers/inline/memory/faiss/faiss.py

@@ -23,8 +23,8 @@ from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from llama_stack.providers.utils.memory.vector_store import (
+    BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )
 
 from .config import FaissImplConfig
@@ -131,7 +131,7 @@ class FaissIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class FaissMemoryImpl(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
+class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api
@@ -147,11 +147,12 @@ class FaissMemoryImpl(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
 
         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
-            index = self._create_bank_with_index(
+            index = BankWithIndex(
                 bank,
                 await FaissIndex.create(
                     bank.embedding_dimension, self.kvstore, bank.identifier
                 ),
+                self.inference_api,
             )
             self.cache[bank.identifier] = index
 
@@ -175,11 +176,12 @@ class FaissMemoryImpl(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
         )
 
         # Store in cache
-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+        self.cache[memory_bank.identifier] = BankWithIndex(
             memory_bank,
             await FaissIndex.create(
                 memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
             ),
+            self.inference_api,
         )
 
     async def list_memory_banks(self) -> List[MemoryBank]:
llama_stack/providers/remote/memory/chroma/chroma.py

@@ -15,12 +15,10 @@ from numpy.typing import NDArray
 from pydantic import parse_obj_as
 
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )
 
 log = logging.getLogger(__name__)
@@ -72,7 +70,7 @@ class ChromaIndex(EmbeddingIndex):
         await self.client.delete_collection(self.collection.name)
 
 
-class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
+class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
    def __init__(self, url: str, inference_api: Api.inference) -> None:
        log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
        url = url.rstrip("/")
@@ -111,8 +109,8 @@ class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
             name=memory_bank.identifier,
             metadata={"bank": memory_bank.model_dump_json()},
         )
-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
-            memory_bank, ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, ChromaIndex(self.client, collection), self.inference_api
         )
 
     async def list_memory_banks(self) -> List[MemoryBank]:
@@ -125,9 +123,10 @@ class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
                 log.exception(f"Failed to parse bank: {collection.metadata}")
                 continue
 
-            self.cache[bank.identifier] = self._create_bank_with_index(
+            self.cache[bank.identifier] = BankWithIndex(
                 bank,
                 ChromaIndex(self.client, collection),
+                self.inference_api,
             )
 
         return [i.bank for i in self.cache.values()]
@@ -166,6 +165,8 @@ class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
         collection = await self.client.get_collection(bank_id)
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = self._create_bank_with_index(bank, ChromaIndex(self.client, collection))
+        index = BankWithIndex(
+            bank, ChromaIndex(self.client, collection), self.inference_api
+        )
         self.cache[bank_id] = index
         return index
llama_stack/providers/remote/memory/pgvector/pgvector.py

@@ -21,7 +21,6 @@ from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )
 
 from .config import PGVectorConfig
@@ -120,9 +119,7 @@ class PGVectorIndex(EmbeddingIndex):
         self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
 
 
-class PGVectorMemoryAdapter(
-    InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
-):
+class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api
@@ -171,8 +168,8 @@ class PGVectorMemoryAdapter(
 
         upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
         index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
-            memory_bank, index
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, index, self.inference_api
         )
 
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
@@ -183,9 +180,10 @@ class PGVectorMemoryAdapter(
         banks = load_models(self.cursor, VectorMemoryBank)
         for bank in banks:
             if bank.identifier not in self.cache:
-                index = self._create_bank_with_index(
+                index = BankWithIndex(
                     bank,
                     PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
+                    self.inference_api,
                 )
                 self.cache[bank.identifier] = index
         return banks
@@ -216,5 +214,5 @@ class PGVectorMemoryAdapter(
 
         bank = await self.memory_bank_store.get_memory_bank(bank_id)
         index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
-        self.cache[bank_id] = self._create_bank_with_index(bank, index)
+        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
         return self.cache[bank_id]
llama_stack/providers/remote/memory/qdrant/qdrant.py

@@ -21,7 +21,6 @@ from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )
 
 log = logging.getLogger(__name__)
@@ -101,9 +100,7 @@ class QdrantIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class QdrantVectorMemoryAdapter(
-    InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
-):
+class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
@@ -124,9 +121,10 @@ class QdrantVectorMemoryAdapter(
             memory_bank.memory_bank_type == MemoryBankType.vector
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        index = self._create_bank_with_index(
+        index = BankWithIndex(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
+            inference_api=self.inference_api,
         )
 
         self.cache[memory_bank.identifier] = index
@@ -144,9 +142,10 @@ class QdrantVectorMemoryAdapter(
         if not bank:
             raise ValueError(f"Bank {bank_id} not found")
 
-        index = self._create_bank_with_index(
+        index = BankWithIndex(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
llama_stack/providers/remote/memory/weaviate/weaviate.py

@@ -19,7 +19,6 @@ from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )
 
 from .config import WeaviateConfig, WeaviateRequestProviderData
@@ -83,7 +82,6 @@ class WeaviateIndex(EmbeddingIndex):
 
 
 class WeaviateMemoryAdapter(
-    InferenceEmbeddingMixin,
     Memory,
     NeedsRequestProviderData,
     MemoryBanksProtocolPrivate,
@@ -140,9 +138,10 @@ class WeaviateMemoryAdapter(
             ],
         )
 
-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+        self.cache[memory_bank.identifier] = BankWithIndex(
             memory_bank,
             WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+            self.inference_api,
         )
 
     async def list_memory_banks(self) -> List[MemoryBank]:
@@ -164,9 +163,10 @@ class WeaviateMemoryAdapter(
         if not client.collections.exists(bank.identifier):
             raise ValueError(f"Collection with name `{bank.identifier}` not found")
 
-        index = self._create_bank_with_index(
+        index = BankWithIndex(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
llama_stack/providers/tests/memory/conftest.py

@@ -6,9 +6,65 @@
 
 import pytest
 
+from ..conftest import get_provider_fixture_overrides
+
+from ..inference.fixtures import INFERENCE_FIXTURES
 from .fixtures import MEMORY_FIXTURES
 
 
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "meta_reference",
+            "memory": "faiss",
+        },
+        id="meta_reference",
+        marks=pytest.mark.meta_reference,
+    ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "memory": "pgvector",
+        },
+        id="ollama",
+        marks=pytest.mark.ollama,
+    ),
+    pytest.param(
+        {
+            "inference": "together",
+            "memory": "chroma",
+        },
+        id="chroma",
+        marks=pytest.mark.chroma,
+    ),
+    pytest.param(
+        {
+            "inference": "bedrock",
+            "memory": "qdrant",
+        },
+        id="qdrant",
+        marks=pytest.mark.qdrant,
+    ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "memory": "weaviate",
+        },
+        id="weaviate",
+        marks=pytest.mark.weaviate,
+    ),
+]
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=None,
+        help="Specify the embedding model to use for testing",
+    )
+
+
 def pytest_configure(config):
     for fixture_name in MEMORY_FIXTURES:
         config.addinivalue_line(
@@ -18,12 +74,22 @@ def pytest_configure(config):
 
 
 def pytest_generate_tests(metafunc):
+    if "embedding_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--embedding-model")
+        if not model:
+            raise ValueError(
+                "No embedding model specified. Please provide a valid embedding model."
+            )
+        params = [pytest.param(model, id="")]
+
+        metafunc.parametrize("embedding_model", params, indirect=True)
     if "memory_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "memory_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in MEMORY_FIXTURES
-            ],
-            indirect=True,
-        )
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("memory_stack", combinations, indirect=True)
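Net effect: `pytest_generate_tests` now parametrizes `memory_stack` over (inference, memory) provider pairs. Explicit overrides from `get_provider_fixture_overrides` take precedence; otherwise the run falls back to `DEFAULT_PROVIDER_COMBINATIONS` above, with each pairing selectable through its pytest mark (for example `-m ollama` for the ollama/pgvector combination).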
llama_stack/providers/tests/memory/fixtures.py

@@ -10,6 +10,8 @@ import tempfile
 import pytest
 import pytest_asyncio
 
+from llama_stack.apis.inference import ModelInput, ModelType
+
 from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
@@ -97,14 +99,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(request):
-    fixture_name = request.param
-    fixture = request.getfixturevalue(f"memory_{fixture_name}")
+async def memory_stack(embedding_model, request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "memory"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
 
     test_stack = await construct_stack_for_test(
-        [Api.memory],
-        {"memory": fixture.providers},
-        fixture.provider_data,
+        [Api.memory, Api.inference],
+        providers,
+        provider_data,
+        models=[
+            ModelInput(
+                model_id=embedding_model,
+                model_type=ModelType.embedding_model,
+                metadata={
+                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
+                },
+            )
+        ],
     )
 
     return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
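With the fixture now registering an embedding model on the test stack, the memory tests need two inputs: the model id via the new `--embedding-model` option (the suite raises "No embedding model specified" without it) and the vector size via the `EMBEDDING_DIMENSION` environment variable, read through `get_env_or_fail`. A plausible invocation, with an assumed dimension value and test path not taken from this commit, would be `EMBEDDING_DIMENSION=384 pytest -m meta_reference --embedding-model all-MiniLM-L6-v2 llama_stack/providers/tests/memory/test_memory.py`.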
llama_stack/providers/tests/memory/test_memory.py

@@ -45,12 +45,14 @@ def sample_documents():
     ]
 
 
-async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
+async def register_memory_bank(
+    banks_impl: MemoryBanks, embedding_model: str
+) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model="all-MiniLM-L6-v2",
+            embedding_model=embedding_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
 
 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack):
+    async def test_banks_list(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack
 
         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)
 
         try:
             # Verify our bank shows up in list
@@ -84,7 +86,7 @@ class TestMemory:
         )
 
     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack):
+    async def test_banks_register(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack
 
         bank_id = f"test_bank_{uuid.uuid4().hex}"
@@ -94,7 +96,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
+                embedding_model=embedding_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),
@@ -109,7 +111,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
+                embedding_model=embedding_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),
@@ -126,13 +128,15 @@ class TestMemory:
             await banks_impl.unregister_memory_bank(bank_id)
 
     @pytest.mark.asyncio
-    async def test_query_documents(self, memory_stack, sample_documents):
+    async def test_query_documents(
+        self, memory_stack, embedding_model, sample_documents
+    ):
         memory_impl, banks_impl = memory_stack
 
         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)
 
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )
llama_stack/providers/utils/memory/vector_store.py

@@ -198,20 +198,3 @@ class BankWithIndex:
         )
         query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
-
-
-class InferenceEmbeddingMixin:
-    inference_api: Api.inference
-
-    def __init__(self, inference_api: Api.inference):
-        self.inference_api = inference_api
-
-    def _create_bank_with_index(
-        self, bank: VectorMemoryBank, index: EmbeddingIndex
-    ) -> BankWithIndex:
-
-        return BankWithIndex(
-            bank=bank,
-            index=index,
-            inference_api=self.inference_api,
-        )
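For context, `BankWithIndex` (the class the providers now instantiate directly) pairs a vector memory bank with an `EmbeddingIndex` and uses the `inference_api` handle to embed queries before searching. A minimal sketch of that query path, reconstructed from the context lines above; the `query_documents` signature and the `embeddings()` call shape are assumptions, and only the last two statements are verbatim from the diff:

    import numpy as np


    class BankWithIndex:
        def __init__(self, bank, index, inference_api):
            self.bank = bank  # VectorMemoryBank
            self.index = index  # EmbeddingIndex
            self.inference_api = inference_api  # embedding provider handle

        async def query_documents(self, query, k=3, score_threshold=0.0):
            # Embed the query text with the inference provider (assumed call shape).
            embeddings_response = await self.inference_api.embeddings(
                self.bank.embedding_model, [query]
            )
            query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
            # Run a vector search against the per-bank index.
            return await self.index.query(query_vector, k, score_threshold)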