Vector store inference api (#598)

# What does this PR do?
Moves all the memory providers to use the inference API and improved the
memory tests to setup the inference stack correctly and use the
embedding models


## Test Plan
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.2-3B-Instruct"
--embedding-model="sentence-transformers/all-MiniLM-L6-v2"
llama_stack/providers/tests/inference/test_embeddings.py --env
EMBEDDING_DIMENSION=384


pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=weaviate"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> --env
WEAVIATE_API_KEY=foo --env WEAVIATE_CLUSTER_URL=bar
 
pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=chroma"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>--env
CHROMA_HOST=localhost --env CHROMA_PORT=8000

pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=pgvector"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env
PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>

pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=faiss"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
This commit is contained in:
Dinesh Yeduguru 2024-12-12 11:16:54 -08:00 committed by GitHub
parent db7b26a8c9
commit 4f8b73b9e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 235 additions and 118 deletions

View file

@ -4,16 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import FaissImplConfig
async def get_provider_impl(config: FaissImplConfig, _deps):
async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissMemoryImpl
assert isinstance(
config, FaissImplConfig
), f"Unexpected config type: {type(config)}"
impl = FaissMemoryImpl(config)
impl = FaissMemoryImpl(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -19,11 +19,10 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex,
)
@ -32,7 +31,8 @@ from .config import FaissImplConfig
logger = logging.getLogger(__name__)
MEMORY_BANKS_PREFIX = "memory_banks:v1::"
MEMORY_BANKS_PREFIX = "memory_banks:v2::"
FAISS_INDEX_PREFIX = "faiss_index:v2::"
class FaissIndex(EmbeddingIndex):
@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
if not self.kvstore:
return
index_key = f"faiss_index:v1::{self.bank_id}"
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
stored_data = await self.kvstore.get(index_key)
if stored_data:
@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex):
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
}
index_key = f"faiss_index:v1::{self.bank_id}"
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
await self.kvstore.set(key=index_key, value=json.dumps(data))
async def delete(self):
if not self.kvstore or not self.bank_id:
return
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
# Add dimension check
embedding_dim = (
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
)
if embedding_dim != self.index.d:
raise ValueError(
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
)
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: FaissImplConfig) -> None:
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.cache = {}
self.kvstore = None
@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex(
bank=bank,
index=await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
bank,
await FaissIndex.create(
bank.embedding_dimension, self.kvstore, bank.identifier
),
self.inference_api,
)
self.cache[bank.identifier] = index
@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
)
# Store in cache
index = BankWithIndex(
bank=memory_bank,
index=await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank,
await FaissIndex.create(
memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
),
self.inference_api,
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]:
return [i.bank for i in self.cache.values()]

View file

@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
api_dependencies=[Api.inference],
),
InlineProviderSpec(
api=Api.memory,
@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.memory,
@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.chroma",
config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.memory,
@ -64,6 +67,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.pgvector",
config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.memory,
@ -74,6 +78,7 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
api=Api.memory,
@ -83,6 +88,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.sample",
config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
),
api_dependencies=[],
),
remote_provider_spec(
Api.memory,
@ -92,5 +98,6 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.qdrant",
config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
),
api_dependencies=[Api.inference],
),
]

View file

@ -4,12 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import RemoteProviderConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
async def get_adapter_impl(config: RemoteProviderConfig, deps: Dict[Api, ProviderSpec]):
from .chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config.url)
impl = ChromaMemoryAdapter(config.url, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -15,8 +15,7 @@ from numpy.typing import NDArray
from pydantic import parse_obj_as
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
@ -72,7 +71,7 @@ class ChromaIndex(EmbeddingIndex):
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, url: str) -> None:
def __init__(self, url: str, inference_api: Api.inference) -> None:
log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
parsed = urlparse(url)
@ -82,6 +81,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
self.host = parsed.hostname
self.port = parsed.port
self.inference_api = inference_api
self.client = None
self.cache = {}
@ -109,10 +109,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
name=memory_bank.identifier,
metadata={"bank": memory_bank.model_dump_json()},
)
bank_index = BankWithIndex(
bank=memory_bank, index=ChromaIndex(self.client, collection)
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank, ChromaIndex(self.client, collection), self.inference_api
)
self.cache[memory_bank.identifier] = bank_index
async def list_memory_banks(self) -> List[MemoryBank]:
collections = await self.client.list_collections()
@ -124,11 +123,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
log.exception(f"Failed to parse bank: {collection.metadata}")
continue
index = BankWithIndex(
bank=bank,
index=ChromaIndex(self.client, collection),
self.cache[bank.identifier] = BankWithIndex(
bank,
ChromaIndex(self.client, collection),
self.inference_api,
)
self.cache[bank.identifier] = index
return [i.bank for i in self.cache.values()]
@ -166,6 +165,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
collection = await self.client.get_collection(bank_id)
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
index = BankWithIndex(
bank, ChromaIndex(self.client, collection), self.inference_api
)
self.cache[bank_id] = index
return index

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import PGVectorConfig
async def get_adapter_impl(config: PGVectorConfig, _deps):
async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorMemoryAdapter
impl = PGVectorMemoryAdapter(config)
impl = PGVectorMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex,
)
@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: PGVectorConfig) -> None:
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.cursor = None
self.conn = None
self.cache = {}
@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def shutdown(self) -> None:
pass
async def register_memory_bank(
self,
memory_bank: MemoryBank,
) -> None:
async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
upsert_models(
self.cursor,
[
(memory_bank.identifier, memory_bank),
],
upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank, index, self.inference_api
)
index = BankWithIndex(
bank=memory_bank,
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[memory_bank.identifier] = index
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
@ -190,8 +181,9 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
for bank in banks:
if bank.identifier not in self.cache:
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
bank,
PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
self.inference_api,
)
self.cache[bank.identifier] = index
return banks
@ -214,14 +206,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
self.inference_api = inference_api
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return index
index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
return self.cache[bank_id]

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import QdrantConfig
async def get_adapter_impl(config: QdrantConfig, _deps):
async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorMemoryAdapter
impl = QdrantVectorMemoryAdapter(config)
impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex):
class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: QdrantConfig) -> None:
def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
self.config = config
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
self.cache = {}
self.inference_api = inference_api
async def initialize(self) -> None:
pass
@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
index = BankWithIndex(
bank=memory_bank,
index=QdrantIndex(self.client, memory_bank.identifier),
inference_api=self.inference_api,
)
self.cache[memory_bank.identifier] = index
@ -143,6 +145,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
index = BankWithIndex(
bank=bank,
index=QdrantIndex(client=self.client, collection_name=bank_id),
inference_api=self.inference_api,
)
self.cache[bank_id] = index
return index

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
async def get_adapter_impl(config: WeaviateConfig, _deps):
async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
from .weaviate import WeaviateMemoryAdapter
impl = WeaviateMemoryAdapter(config)
impl = WeaviateMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -12,10 +12,11 @@ import weaviate
import weaviate.classes as wvc
from numpy.typing import NDArray
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def delete(self, chunk_ids: List[str]) -> None:
collection = self.client.collections.get(self.collection_name)
collection.data.delete_many(
where=Filter.by_property("id").contains_any(chunk_ids)
)
class WeaviateMemoryAdapter(
Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
Memory,
NeedsRequestProviderData,
MemoryBanksProtocolPrivate,
):
def __init__(self, config: WeaviateConfig) -> None:
def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.client_cache = {}
self.cache = {}
@ -117,7 +127,7 @@ class WeaviateMemoryAdapter(
memory_bank: MemoryBank,
) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
client = self._get_client()
@ -135,11 +145,11 @@ class WeaviateMemoryAdapter(
],
)
index = BankWithIndex(
bank=memory_bank,
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank,
WeaviateIndex(client=client, collection_name=memory_bank.identifier),
self.inference_api,
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]:
# TODO: right now the Llama Stack is the source of truth for these banks. That is
@ -163,6 +173,7 @@ class WeaviateMemoryAdapter(
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(client=client, collection_name=bank_id),
inference_api=self.inference_api,
)
self.cache[bank_id] = index
return index

View file

@ -6,9 +6,65 @@
import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import MEMORY_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "meta_reference",
"memory": "faiss",
},
id="meta_reference",
marks=pytest.mark.meta_reference,
),
pytest.param(
{
"inference": "ollama",
"memory": "pgvector",
},
id="ollama",
marks=pytest.mark.ollama,
),
pytest.param(
{
"inference": "together",
"memory": "chroma",
},
id="chroma",
marks=pytest.mark.chroma,
),
pytest.param(
{
"inference": "bedrock",
"memory": "qdrant",
},
id="qdrant",
marks=pytest.mark.qdrant,
),
pytest.param(
{
"inference": "fireworks",
"memory": "weaviate",
},
id="weaviate",
marks=pytest.mark.weaviate,
),
]
def pytest_addoption(parser):
parser.addoption(
"--embedding-model",
action="store",
default=None,
help="Specify the embedding model to use for testing",
)
def pytest_configure(config):
for fixture_name in MEMORY_FIXTURES:
config.addinivalue_line(
@ -18,12 +74,22 @@ def pytest_configure(config):
def pytest_generate_tests(metafunc):
if "embedding_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--embedding-model")
if not model:
raise ValueError(
"No embedding model specified. Please provide a valid embedding model."
)
params = [pytest.param(model, id="")]
metafunc.parametrize("embedding_model", params, indirect=True)
if "memory_stack" in metafunc.fixturenames:
metafunc.parametrize(
"memory_stack",
[
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
for fixture_name in MEMORY_FIXTURES
],
indirect=True,
available_fixtures = {
"inference": INFERENCE_FIXTURES,
"memory": MEMORY_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("memory_stack", combinations, indirect=True)

View file

@ -10,6 +10,8 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.inference import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
@ -97,14 +99,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
@pytest_asyncio.fixture(scope="session")
async def memory_stack(request):
fixture_name = request.param
fixture = request.getfixturevalue(f"memory_{fixture_name}")
async def memory_stack(embedding_model, request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["inference", "memory"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.memory],
{"memory": fixture.providers},
fixture.provider_data,
[Api.memory, Api.inference],
providers,
provider_data,
models=[
ModelInput(
model_id=embedding_model,
model_type=ModelType.embedding_model,
metadata={
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
},
)
],
)
return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]

View file

@ -45,12 +45,14 @@ def sample_documents():
]
async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
async def register_memory_bank(
banks_impl: MemoryBanks, embedding_model: str
) -> MemoryBank:
bank_id = f"test_bank_{uuid.uuid4().hex}"
return await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
embedding_model=embedding_model,
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
class TestMemory:
@pytest.mark.asyncio
async def test_banks_list(self, memory_stack):
async def test_banks_list(self, memory_stack, embedding_model):
_, banks_impl = memory_stack
# Register a test bank
registered_bank = await register_memory_bank(banks_impl)
registered_bank = await register_memory_bank(banks_impl, embedding_model)
try:
# Verify our bank shows up in list
@ -84,7 +86,7 @@ class TestMemory:
)
@pytest.mark.asyncio
async def test_banks_register(self, memory_stack):
async def test_banks_register(self, memory_stack, embedding_model):
_, banks_impl = memory_stack
bank_id = f"test_bank_{uuid.uuid4().hex}"
@ -94,7 +96,7 @@ class TestMemory:
await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
embedding_model=embedding_model,
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
@ -109,7 +111,7 @@ class TestMemory:
await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
embedding_model=embedding_model,
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
@ -126,13 +128,15 @@ class TestMemory:
await banks_impl.unregister_memory_bank(bank_id)
@pytest.mark.asyncio
async def test_query_documents(self, memory_stack, sample_documents):
async def test_query_documents(
self, memory_stack, embedding_model, sample_documents
):
memory_impl, banks_impl = memory_stack
with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents)
registered_bank = await register_memory_bank(banks_impl)
registered_bank = await register_memory_bank(banks_impl, embedding_model)
await memory_impl.insert_documents(
registered_bank.memory_bank_id, sample_documents
)
@ -165,13 +169,13 @@ class TestMemory:
# Test case 5: Query with threshold on similarity score
query5 = "quantum computing" # Not directly related to any document
params5 = {"score_threshold": 0.2}
params5 = {"score_threshold": 0.01}
response5 = await memory_impl.query_documents(
registered_bank.memory_bank_id, query5, params5
)
assert_valid_response(response5)
print("The scores are:", response5.scores)
assert all(score >= 0.2 for score in response5.scores)
assert all(score >= 0.01 for score in response5.scores)
def assert_valid_response(response: QueryDocumentsResponse):

View file

@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODELS = {}
def get_embedding_model(model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS
loaded_model = EMBEDDING_MODELS.get(model)
if loaded_model is not None:
return loaded_model
log.info(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer
loaded_model = SentenceTransformer(model)
EMBEDDING_MODELS[model] = loaded_model
return loaded_model
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
class BankWithIndex:
bank: VectorMemoryBank
index: EmbeddingIndex
inference_api: Api.inference
async def insert_documents(
self,
documents: List[MemoryBankDocument],
) -> None:
model = get_embedding_model(self.bank.embedding_model)
for doc in documents:
content = await content_from_doc(doc)
chunks = make_overlapped_chunks(
@ -183,7 +165,10 @@ class BankWithIndex:
)
if not chunks:
continue
embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
embeddings_response = await self.inference_api.embeddings(
self.bank.embedding_model, [x.content for x in chunks]
)
embeddings = np.array(embeddings_response.embeddings)
await self.index.add_chunks(chunks, embeddings)
@ -208,6 +193,8 @@ class BankWithIndex:
else:
query_str = _process(query)
model = get_embedding_model(self.bank.embedding_model)
query_vector = model.encode([query_str])[0].astype(np.float32)
embeddings_response = await self.inference_api.embeddings(
self.bank.embedding_model, [query_str]
)
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
return await self.index.query(query_vector, k, score_threshold)