forked from phoenix-oss/llama-stack-mirror
[memory refactor][2/n] Update faiss and make it pass tests (#830)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. Second part: - updates routing table / router code - updates the faiss implementation ## Test Plan ``` pytest -s -v -k sentence test_vector_io.py --env EMBEDDING_DIMENSION=384 ```
This commit is contained in:
parent
3ae8585b65
commit
78a481bb22
19 changed files with 343 additions and 353 deletions
|
@ -14,11 +14,11 @@ from llama_stack.providers.datatypes import Api, RoutingTable
|
|||
from .routing_tables import (
|
||||
DatasetsRoutingTable,
|
||||
EvalTasksRoutingTable,
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ScoringFunctionsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
ToolGroupsRoutingTable,
|
||||
VectorDBsRoutingTable,
|
||||
)
|
||||
|
||||
|
||||
|
@ -29,7 +29,7 @@ async def get_routing_table_impl(
|
|||
dist_registry: DistributionRegistry,
|
||||
) -> Any:
|
||||
api_to_tables = {
|
||||
"memory_banks": MemoryBanksRoutingTable,
|
||||
"vector_dbs": VectorDBsRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
"shields": ShieldsRoutingTable,
|
||||
"datasets": DatasetsRoutingTable,
|
||||
|
@ -51,14 +51,14 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
|
|||
DatasetIORouter,
|
||||
EvalRouter,
|
||||
InferenceRouter,
|
||||
MemoryRouter,
|
||||
SafetyRouter,
|
||||
ScoringRouter,
|
||||
ToolRuntimeRouter,
|
||||
VectorIORouter,
|
||||
)
|
||||
|
||||
api_to_routers = {
|
||||
"memory": MemoryRouter,
|
||||
"vector_io": VectorIORouter,
|
||||
"inference": InferenceRouter,
|
||||
"safety": SafetyRouter,
|
||||
"datasetio": DatasetIORouter,
|
||||
|
|
|
@ -27,8 +27,6 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks.memory_banks import BankParams
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
|
@ -39,11 +37,12 @@ from llama_stack.apis.scoring import (
|
|||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.tools import ToolDef, ToolRuntime
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
"""Routes to an provider based on the memory bank identifier"""
|
||||
class VectorIORouter(VectorIO):
|
||||
"""Routes to an provider based on the vector db identifier"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -57,38 +56,40 @@ class MemoryRouter(Memory):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_memory_bank(
|
||||
async def register_vector_db(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
params: BankParams,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: Optional[int] = 384,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_memorybank_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> None:
|
||||
await self.routing_table.register_memory_bank(
|
||||
memory_bank_id,
|
||||
params,
|
||||
await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
embedding_dimension,
|
||||
provider_id,
|
||||
provider_memorybank_id,
|
||||
provider_vector_db_id,
|
||||
)
|
||||
|
||||
async def insert_documents(
|
||||
async def insert_chunks(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
vector_db_id: str,
|
||||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
return await self.routing_table.get_provider_impl(bank_id).insert_documents(
|
||||
bank_id, documents, ttl_seconds
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
|
||||
vector_db_id, chunks, ttl_seconds
|
||||
)
|
||||
|
||||
async def query_documents(
|
||||
async def query_chunks(
|
||||
self,
|
||||
bank_id: str,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
return await self.routing_table.get_provider_impl(bank_id).query_documents(
|
||||
bank_id, query, params
|
||||
) -> QueryChunksResponse:
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
|
||||
vector_db_id, query, params
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -11,12 +11,12 @@ from .config import FaissImplConfig
|
|||
|
||||
|
||||
async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from .faiss import FaissMemoryImpl
|
||||
from .faiss import FaissVectorIOImpl
|
||||
|
||||
assert isinstance(
|
||||
config, FaissImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = FaissMemoryImpl(config, deps[Api.inference])
|
||||
impl = FaissVectorIOImpl(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -17,35 +17,28 @@ import numpy as np
|
|||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.memory import (
|
||||
Chunk,
|
||||
Memory,
|
||||
MemoryBankDocument,
|
||||
QueryDocumentsResponse,
|
||||
)
|
||||
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
from .config import FaissImplConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_BANKS_PREFIX = "memory_banks:v2::"
|
||||
VECTOR_DBS_PREFIX = "vector_dbs:v2::"
|
||||
FAISS_INDEX_PREFIX = "faiss_index:v2::"
|
||||
|
||||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
id_by_index: Dict[int, str]
|
||||
chunk_by_index: Dict[int, str]
|
||||
|
||||
def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
|
||||
self.index = faiss.IndexFlatL2(dimension)
|
||||
self.id_by_index = {}
|
||||
self.chunk_by_index = {}
|
||||
self.kvstore = kvstore
|
||||
self.bank_id = bank_id
|
||||
|
@ -65,7 +58,6 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
if stored_data:
|
||||
data = json.loads(stored_data)
|
||||
self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
|
||||
self.chunk_by_index = {
|
||||
int(k): Chunk.model_validate_json(v)
|
||||
for k, v in data["chunk_by_index"].items()
|
||||
|
@ -82,7 +74,6 @@ class FaissIndex(EmbeddingIndex):
|
|||
buffer = io.BytesIO()
|
||||
np.savetxt(buffer, np_index)
|
||||
data = {
|
||||
"id_by_index": self.id_by_index,
|
||||
"chunk_by_index": {
|
||||
k: v.model_dump_json() for k, v in self.chunk_by_index.items()
|
||||
},
|
||||
|
@ -108,10 +99,9 @@ class FaissIndex(EmbeddingIndex):
|
|||
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
|
||||
)
|
||||
|
||||
indexlen = len(self.id_by_index)
|
||||
indexlen = len(self.chunk_by_index)
|
||||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
self.id_by_index[indexlen + i] = chunk.document_id
|
||||
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
|
||||
|
@ -120,7 +110,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
async def query(
|
||||
self, embedding: NDArray, k: int, score_threshold: float
|
||||
) -> QueryDocumentsResponse:
|
||||
) -> QueryChunksResponse:
|
||||
distances, indices = self.index.search(
|
||||
embedding.reshape(1, -1).astype(np.float32), k
|
||||
)
|
||||
|
@ -133,10 +123,10 @@ class FaissIndex(EmbeddingIndex):
|
|||
chunks.append(self.chunk_by_index[int(i)])
|
||||
scores.append(1.0 / float(d))
|
||||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||
class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
|
@ -146,77 +136,74 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
# Load existing banks from kvstore
|
||||
start_key = MEMORY_BANKS_PREFIX
|
||||
end_key = f"{MEMORY_BANKS_PREFIX}\xff"
|
||||
stored_banks = await self.kvstore.range(start_key, end_key)
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.range(start_key, end_key)
|
||||
|
||||
for bank_data in stored_banks:
|
||||
bank = VectorMemoryBank.model_validate_json(bank_data)
|
||||
index = BankWithIndex(
|
||||
bank,
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
await FaissIndex.create(
|
||||
bank.embedding_dimension, self.kvstore, bank.identifier
|
||||
vector_db.embedding_dimension, self.kvstore, vector_db.identifier
|
||||
),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# Cleanup if needed
|
||||
pass
|
||||
|
||||
async def register_memory_bank(
|
||||
async def register_vector_db(
|
||||
self,
|
||||
memory_bank: MemoryBank,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
# Store in kvstore
|
||||
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=memory_bank.model_dump_json(),
|
||||
value=vector_db.model_dump_json(),
|
||||
)
|
||||
|
||||
# Store in cache
|
||||
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||
memory_bank,
|
||||
await FaissIndex.create(
|
||||
memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=await FaissIndex.create(
|
||||
vector_db.embedding_dimension, self.kvstore, vector_db.identifier
|
||||
),
|
||||
self.inference_api,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
return [i.bank for i in self.cache.values()]
|
||||
async def list_vector_dbs(self) -> List[VectorDB]:
|
||||
return [i.vector_db for i in self.cache.values()]
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
await self.cache[memory_bank_id].index.delete()
|
||||
del self.cache[memory_bank_id]
|
||||
await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
|
||||
async def insert_documents(
|
||||
async def insert_chunks(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
vector_db_id: str,
|
||||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = self.cache.get(bank_id)
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")
|
||||
raise ValueError(
|
||||
f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}"
|
||||
)
|
||||
|
||||
await index.insert_documents(documents)
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_documents(
|
||||
async def query_chunks(
|
||||
self,
|
||||
bank_id: str,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = self.cache.get(bank_id)
|
||||
) -> QueryChunksResponse:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
return await index.query_documents(query, params)
|
||||
return await index.query_chunks(query, params)
|
||||
|
|
|
@ -33,8 +33,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.safety,
|
||||
Api.memory,
|
||||
Api.memory_banks,
|
||||
Api.vector_io,
|
||||
Api.vector_dbs,
|
||||
Api.tool_runtime,
|
||||
Api.tool_groups,
|
||||
],
|
||||
|
|
|
@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.tool_runtime.memory",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
|
||||
api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
|
||||
api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
|
|
|
@ -302,7 +302,7 @@ def pytest_collection_modifyitems(session, config, items):
|
|||
pytest_plugins = [
|
||||
"llama_stack.providers.tests.inference.fixtures",
|
||||
"llama_stack.providers.tests.safety.fixtures",
|
||||
"llama_stack.providers.tests.memory.fixtures",
|
||||
"llama_stack.providers.tests.vector_io.fixtures",
|
||||
"llama_stack.providers.tests.agents.fixtures",
|
||||
"llama_stack.providers.tests.datasetio.fixtures",
|
||||
"llama_stack.providers.tests.scoring.fixtures",
|
||||
|
|
|
@ -1,192 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.memory import MemoryBankDocument, QueryDocumentsResponse
|
||||
|
||||
from llama_stack.apis.memory_banks import (
|
||||
MemoryBank,
|
||||
MemoryBanks,
|
||||
VectorMemoryBankParams,
|
||||
)
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/memory/test_memory.py
|
||||
# -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
|
||||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_documents():
|
||||
return [
|
||||
MemoryBankDocument(
|
||||
document_id="doc1",
|
||||
content="Python is a high-level programming language.",
|
||||
metadata={"category": "programming", "difficulty": "beginner"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc2",
|
||||
content="Machine learning is a subset of artificial intelligence.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc3",
|
||||
content="Data structures are fundamental to computer science.",
|
||||
metadata={"category": "computer science", "difficulty": "intermediate"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc4",
|
||||
content="Neural networks are inspired by biological neural networks.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def register_memory_bank(
|
||||
banks_impl: MemoryBanks, embedding_model: str
|
||||
) -> MemoryBank:
|
||||
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
||||
return await banks_impl.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model=embedding_model,
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestMemory:
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_list(self, memory_stack, embedding_model):
|
||||
_, banks_impl = memory_stack
|
||||
|
||||
# Register a test bank
|
||||
registered_bank = await register_memory_bank(banks_impl, embedding_model)
|
||||
|
||||
try:
|
||||
# Verify our bank shows up in list
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert any(
|
||||
bank.memory_bank_id == registered_bank.memory_bank_id
|
||||
for bank in response
|
||||
)
|
||||
finally:
|
||||
# Clean up
|
||||
await banks_impl.unregister_memory_bank(registered_bank.memory_bank_id)
|
||||
|
||||
# Verify our bank was removed
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert all(
|
||||
bank.memory_bank_id != registered_bank.memory_bank_id for bank in response
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_register(self, memory_stack, embedding_model):
|
||||
_, banks_impl = memory_stack
|
||||
|
||||
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
||||
|
||||
try:
|
||||
# Register initial bank
|
||||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model=embedding_model,
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
|
||||
# Verify our bank exists
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert any(bank.memory_bank_id == bank_id for bank in response)
|
||||
|
||||
# Try registering same bank again
|
||||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model=embedding_model,
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
|
||||
# Verify still only one instance of our bank
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert (
|
||||
len([bank for bank in response if bank.memory_bank_id == bank_id]) == 1
|
||||
)
|
||||
finally:
|
||||
# Clean up
|
||||
await banks_impl.unregister_memory_bank(bank_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(
|
||||
self, memory_stack, embedding_model, sample_documents
|
||||
):
|
||||
memory_impl, banks_impl = memory_stack
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
registered_bank = await register_memory_bank(banks_impl, embedding_model)
|
||||
await memory_impl.insert_documents(
|
||||
registered_bank.memory_bank_id, sample_documents
|
||||
)
|
||||
|
||||
query1 = "programming language"
|
||||
response1 = await memory_impl.query_documents(
|
||||
registered_bank.memory_bank_id, query1
|
||||
)
|
||||
assert_valid_response(response1)
|
||||
assert any("Python" in chunk.content for chunk in response1.chunks)
|
||||
|
||||
# Test case 3: Query with semantic similarity
|
||||
query3 = "AI and brain-inspired computing"
|
||||
response3 = await memory_impl.query_documents(
|
||||
registered_bank.memory_bank_id, query3
|
||||
)
|
||||
assert_valid_response(response3)
|
||||
assert any(
|
||||
"neural networks" in chunk.content.lower() for chunk in response3.chunks
|
||||
)
|
||||
|
||||
# Test case 4: Query with limit on number of results
|
||||
query4 = "computer"
|
||||
params4 = {"max_chunks": 2}
|
||||
response4 = await memory_impl.query_documents(
|
||||
registered_bank.memory_bank_id, query4, params4
|
||||
)
|
||||
assert_valid_response(response4)
|
||||
assert len(response4.chunks) <= 2
|
||||
|
||||
# Test case 5: Query with threshold on similarity score
|
||||
query5 = "quantum computing" # Not directly related to any document
|
||||
params5 = {"score_threshold": 0.01}
|
||||
response5 = await memory_impl.query_documents(
|
||||
registered_bank.memory_bank_id, query5, params5
|
||||
)
|
||||
assert_valid_response(response5)
|
||||
print("The scores are:", response5.scores)
|
||||
assert all(score >= 0.01 for score in response5.scores)
|
||||
|
||||
|
||||
def assert_valid_response(response: QueryDocumentsResponse):
|
||||
assert isinstance(response, QueryDocumentsResponse)
|
||||
assert len(response.chunks) > 0
|
||||
assert len(response.scores) > 0
|
||||
assert len(response.chunks) == len(response.scores)
|
||||
for chunk in response.chunks:
|
||||
assert isinstance(chunk.content, str)
|
||||
assert chunk.document_id is not None
|
|
@ -12,11 +12,11 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.apis.datasets import DatasetInput
|
||||
from llama_stack.apis.eval_tasks import EvalTaskInput
|
||||
from llama_stack.apis.memory_banks import MemoryBankInput
|
||||
from llama_stack.apis.models import ModelInput
|
||||
from llama_stack.apis.scoring_functions import ScoringFnInput
|
||||
from llama_stack.apis.shields import ShieldInput
|
||||
from llama_stack.apis.tools import ToolGroupInput
|
||||
from llama_stack.apis.vector_dbs import VectorDBInput
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
|
@ -39,7 +39,7 @@ async def construct_stack_for_test(
|
|||
provider_data: Optional[Dict[str, Any]] = None,
|
||||
models: Optional[List[ModelInput]] = None,
|
||||
shields: Optional[List[ShieldInput]] = None,
|
||||
memory_banks: Optional[List[MemoryBankInput]] = None,
|
||||
vector_dbs: Optional[List[VectorDBInput]] = None,
|
||||
datasets: Optional[List[DatasetInput]] = None,
|
||||
scoring_fns: Optional[List[ScoringFnInput]] = None,
|
||||
eval_tasks: Optional[List[EvalTaskInput]] = None,
|
||||
|
@ -53,7 +53,7 @@ async def construct_stack_for_test(
|
|||
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
|
||||
models=models or [],
|
||||
shields=shields or [],
|
||||
memory_banks=memory_banks or [],
|
||||
vector_dbs=vector_dbs or [],
|
||||
datasets=datasets or [],
|
||||
scoring_fns=scoring_fns or [],
|
||||
eval_tasks=eval_tasks or [],
|
||||
|
|
|
@ -13,14 +13,14 @@ from ..conftest import (
|
|||
)
|
||||
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from .fixtures import MEMORY_FIXTURES
|
||||
from .fixtures import VECTOR_IO_FIXTURES
|
||||
|
||||
|
||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "sentence_transformers",
|
||||
"memory": "faiss",
|
||||
"vector_io": "faiss",
|
||||
},
|
||||
id="sentence_transformers",
|
||||
marks=pytest.mark.sentence_transformers,
|
||||
|
@ -28,7 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "ollama",
|
||||
"memory": "faiss",
|
||||
"vector_io": "faiss",
|
||||
},
|
||||
id="ollama",
|
||||
marks=pytest.mark.ollama,
|
||||
|
@ -36,7 +36,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "sentence_transformers",
|
||||
"memory": "chroma",
|
||||
"vector_io": "chroma",
|
||||
},
|
||||
id="chroma",
|
||||
marks=pytest.mark.chroma,
|
||||
|
@ -44,7 +44,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "bedrock",
|
||||
"memory": "qdrant",
|
||||
"vector_io": "qdrant",
|
||||
},
|
||||
id="qdrant",
|
||||
marks=pytest.mark.qdrant,
|
||||
|
@ -52,7 +52,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
pytest.param(
|
||||
{
|
||||
"inference": "fireworks",
|
||||
"memory": "weaviate",
|
||||
"vector_io": "weaviate",
|
||||
},
|
||||
id="weaviate",
|
||||
marks=pytest.mark.weaviate,
|
||||
|
@ -61,7 +61,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
|
||||
|
||||
def pytest_configure(config):
|
||||
for fixture_name in MEMORY_FIXTURES:
|
||||
for fixture_name in VECTOR_IO_FIXTURES:
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
f"{fixture_name}: marks tests as {fixture_name} specific",
|
||||
|
@ -69,7 +69,7 @@ def pytest_configure(config):
|
|||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
test_config = get_test_config_for_api(metafunc.config, "memory")
|
||||
test_config = get_test_config_for_api(metafunc.config, "vector_io")
|
||||
if "embedding_model" in metafunc.fixturenames:
|
||||
model = getattr(test_config, "embedding_model", None)
|
||||
# Fall back to the default if not specified by the config file
|
||||
|
@ -81,16 +81,16 @@ def pytest_generate_tests(metafunc):
|
|||
|
||||
metafunc.parametrize("embedding_model", params, indirect=True)
|
||||
|
||||
if "memory_stack" in metafunc.fixturenames:
|
||||
if "vector_io_stack" in metafunc.fixturenames:
|
||||
available_fixtures = {
|
||||
"inference": INFERENCE_FIXTURES,
|
||||
"memory": MEMORY_FIXTURES,
|
||||
"vector_io": VECTOR_IO_FIXTURES,
|
||||
}
|
||||
combinations = (
|
||||
get_provider_fixture_overrides_from_test_config(
|
||||
metafunc.config, "memory", DEFAULT_PROVIDER_COMBINATIONS
|
||||
metafunc.config, "vector_io", DEFAULT_PROVIDER_COMBINATIONS
|
||||
)
|
||||
or get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
||||
or DEFAULT_PROVIDER_COMBINATIONS
|
||||
)
|
||||
metafunc.parametrize("memory_stack", combinations, indirect=True)
|
||||
metafunc.parametrize("vector_io_stack", combinations, indirect=True)
|
|
@ -12,11 +12,12 @@ import pytest_asyncio
|
|||
|
||||
from llama_stack.apis.models import ModelInput, ModelType
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
|
||||
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
||||
from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig
|
||||
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
||||
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
||||
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
|
||||
from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
@ -32,12 +33,12 @@ def embedding_model(request):
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_remote() -> ProviderFixture:
|
||||
def vector_io_remote() -> ProviderFixture:
|
||||
return remote_stack_fixture()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_faiss() -> ProviderFixture:
|
||||
def vector_io_faiss() -> ProviderFixture:
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
|
@ -53,7 +54,7 @@ def memory_faiss() -> ProviderFixture:
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_pgvector() -> ProviderFixture:
|
||||
def vector_io_pgvector() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
|
@ -72,7 +73,7 @@ def memory_pgvector() -> ProviderFixture:
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_weaviate() -> ProviderFixture:
|
||||
def vector_io_weaviate() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
|
@ -89,7 +90,7 @@ def memory_weaviate() -> ProviderFixture:
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_chroma() -> ProviderFixture:
|
||||
def vector_io_chroma() -> ProviderFixture:
|
||||
url = os.getenv("CHROMA_URL")
|
||||
if url:
|
||||
config = ChromaRemoteImplConfig(url=url)
|
||||
|
@ -110,23 +111,23 @@ def memory_chroma() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
|
||||
VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def memory_stack(embedding_model, request):
|
||||
async def vector_io_stack(embedding_model, request):
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["inference", "memory"]:
|
||||
for key in ["inference", "vector_io"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = fixture.providers
|
||||
if fixture.provider_data:
|
||||
provider_data.update(fixture.provider_data)
|
||||
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.memory, Api.inference],
|
||||
[Api.vector_io, Api.inference],
|
||||
providers,
|
||||
provider_data,
|
||||
models=[
|
||||
|
@ -140,4 +141,4 @@ async def memory_stack(embedding_model, request):
|
|||
],
|
||||
)
|
||||
|
||||
return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
|
||||
return test_stack.impls[Api.vector_io], test_stack.impls[Api.vector_dbs]
|
200
llama_stack/providers/tests/vector_io/test_vector_io.py
Normal file
200
llama_stack/providers/tests/vector_io/test_vector_io.py
Normal file
|
@ -0,0 +1,200 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
make_overlapped_chunks,
|
||||
MemoryBankDocument,
|
||||
)
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/memory/test_memory.py
|
||||
# -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
|
||||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sample_chunks():
|
||||
docs = [
|
||||
MemoryBankDocument(
|
||||
document_id="doc1",
|
||||
content="Python is a high-level programming language.",
|
||||
metadata={"category": "programming", "difficulty": "beginner"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc2",
|
||||
content="Machine learning is a subset of artificial intelligence.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc3",
|
||||
content="Data structures are fundamental to computer science.",
|
||||
metadata={"category": "computer science", "difficulty": "intermediate"},
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="doc4",
|
||||
content="Neural networks are inspired by biological neural networks.",
|
||||
metadata={"category": "AI", "difficulty": "advanced"},
|
||||
),
|
||||
]
|
||||
chunks = []
|
||||
for doc in docs:
|
||||
chunks.extend(
|
||||
make_overlapped_chunks(
|
||||
doc.document_id, doc.content, window_len=512, overlap_len=64
|
||||
)
|
||||
)
|
||||
return chunks
|
||||
|
||||
|
||||
async def register_vector_db(vector_dbs_impl: VectorDB, embedding_model: str):
|
||||
vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
|
||||
return await vector_dbs_impl.register_vector_db(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=384,
|
||||
)
|
||||
|
||||
|
||||
class TestVectorIO:
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_list(self, vector_io_stack, embedding_model):
|
||||
_, vector_dbs_impl = vector_io_stack
|
||||
|
||||
# Register a test bank
|
||||
registered_vector_db = await register_vector_db(
|
||||
vector_dbs_impl, embedding_model
|
||||
)
|
||||
|
||||
try:
|
||||
# Verify our bank shows up in list
|
||||
response = await vector_dbs_impl.list_vector_dbs()
|
||||
assert isinstance(response, ListVectorDBsResponse)
|
||||
assert any(
|
||||
vector_db.vector_db_id == registered_vector_db.vector_db_id
|
||||
for vector_db in response.data
|
||||
)
|
||||
finally:
|
||||
# Clean up
|
||||
await vector_dbs_impl.unregister_vector_db(
|
||||
registered_vector_db.vector_db_id
|
||||
)
|
||||
|
||||
# Verify our bank was removed
|
||||
response = await vector_dbs_impl.list_vector_dbs()
|
||||
assert isinstance(response, ListVectorDBsResponse)
|
||||
assert all(
|
||||
vector_db.vector_db_id != registered_vector_db.vector_db_id
|
||||
for vector_db in response.data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_register(self, vector_io_stack, embedding_model):
|
||||
_, vector_dbs_impl = vector_io_stack
|
||||
|
||||
vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
|
||||
|
||||
try:
|
||||
# Register initial bank
|
||||
await vector_dbs_impl.register_vector_db(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=384,
|
||||
)
|
||||
|
||||
# Verify our bank exists
|
||||
response = await vector_dbs_impl.list_vector_dbs()
|
||||
assert isinstance(response, ListVectorDBsResponse)
|
||||
assert any(
|
||||
vector_db.vector_db_id == vector_db_id for vector_db in response.data
|
||||
)
|
||||
|
||||
# Try registering same bank again
|
||||
await vector_dbs_impl.register_vector_db(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=384,
|
||||
)
|
||||
|
||||
# Verify still only one instance of our bank
|
||||
response = await vector_dbs_impl.list_vector_dbs()
|
||||
assert isinstance(response, ListVectorDBsResponse)
|
||||
assert (
|
||||
len(
|
||||
[
|
||||
vector_db
|
||||
for vector_db in response.data
|
||||
if vector_db.vector_db_id == vector_db_id
|
||||
]
|
||||
)
|
||||
== 1
|
||||
)
|
||||
finally:
|
||||
# Clean up
|
||||
await vector_dbs_impl.unregister_vector_db(vector_db_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(
|
||||
self, vector_io_stack, embedding_model, sample_chunks
|
||||
):
|
||||
vector_io_impl, vector_dbs_impl = vector_io_stack
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_impl.insert_chunks("test_vector_db", sample_chunks)
|
||||
|
||||
registered_db = await register_vector_db(vector_dbs_impl, embedding_model)
|
||||
await vector_io_impl.insert_chunks(registered_db.vector_db_id, sample_chunks)
|
||||
|
||||
query1 = "programming language"
|
||||
response1 = await vector_io_impl.query_chunks(
|
||||
registered_db.vector_db_id, query1
|
||||
)
|
||||
assert_valid_response(response1)
|
||||
assert any("Python" in chunk.content for chunk in response1.chunks)
|
||||
|
||||
# Test case 3: Query with semantic similarity
|
||||
query3 = "AI and brain-inspired computing"
|
||||
response3 = await vector_io_impl.query_chunks(
|
||||
registered_db.vector_db_id, query3
|
||||
)
|
||||
assert_valid_response(response3)
|
||||
assert any(
|
||||
"neural networks" in chunk.content.lower() for chunk in response3.chunks
|
||||
)
|
||||
|
||||
# Test case 4: Query with limit on number of results
|
||||
query4 = "computer"
|
||||
params4 = {"max_chunks": 2}
|
||||
response4 = await vector_io_impl.query_chunks(
|
||||
registered_db.vector_db_id, query4, params4
|
||||
)
|
||||
assert_valid_response(response4)
|
||||
assert len(response4.chunks) <= 2
|
||||
|
||||
# Test case 5: Query with threshold on similarity score
|
||||
query5 = "quantum computing" # Not directly related to any document
|
||||
params5 = {"score_threshold": 0.01}
|
||||
response5 = await vector_io_impl.query_chunks(
|
||||
registered_db.vector_db_id, query5, params5
|
||||
)
|
||||
assert_valid_response(response5)
|
||||
print("The scores are:", response5.scores)
|
||||
assert all(score >= 0.01 for score in response5.scores)
|
||||
|
||||
|
||||
def assert_valid_response(response: QueryChunksResponse):
|
||||
assert len(response.chunks) > 0
|
||||
assert len(response.scores) > 0
|
||||
assert len(response.chunks) == len(response.scores)
|
||||
for chunk in response.chunks:
|
||||
assert isinstance(chunk.content, str)
|
|
@ -11,8 +11,11 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.memory.memory import MemoryBankDocument, URL
|
||||
from llama_stack.providers.utils.memory.vector_store import content_from_doc
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
MemoryBankDocument,
|
||||
URL,
|
||||
)
|
||||
|
||||
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
|
||||
|
|
@ -18,6 +18,8 @@ import numpy as np
|
|||
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pypdf import PdfReader
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
@ -25,16 +27,24 @@ from llama_stack.apis.common.content_types import (
|
|||
TextContentItem,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.memory import Chunk, MemoryBankDocument, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks import VectorMemoryBank
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryBankDocument(BaseModel):
|
||||
document_id: str
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str | None = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
def parse_pdf(data: bytes) -> str:
|
||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||
pdf_bytes = io.BytesIO(data)
|
||||
|
@ -165,7 +175,7 @@ class EmbeddingIndex(ABC):
|
|||
@abstractmethod
|
||||
async def query(
|
||||
self, embedding: NDArray, k: int, score_threshold: float
|
||||
) -> QueryDocumentsResponse:
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
|
@ -174,56 +184,35 @@ class EmbeddingIndex(ABC):
|
|||
|
||||
|
||||
@dataclass
|
||||
class BankWithIndex:
|
||||
bank: VectorMemoryBank
|
||||
class VectorDBWithIndex:
|
||||
vector_db: VectorDB
|
||||
index: EmbeddingIndex
|
||||
inference_api: Api.inference
|
||||
|
||||
async def insert_documents(
|
||||
async def insert_chunks(
|
||||
self,
|
||||
documents: List[MemoryBankDocument],
|
||||
chunks: List[Chunk],
|
||||
) -> None:
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
chunks = make_overlapped_chunks(
|
||||
doc.document_id,
|
||||
content,
|
||||
self.bank.chunk_size_in_tokens,
|
||||
self.bank.overlap_size_in_tokens
|
||||
or (self.bank.chunk_size_in_tokens // 4),
|
||||
)
|
||||
if not chunks:
|
||||
continue
|
||||
embeddings_response = await self.inference_api.embeddings(
|
||||
self.bank.embedding_model, [x.content for x in chunks]
|
||||
self.vector_db.embedding_model, [x.content for x in chunks]
|
||||
)
|
||||
embeddings = np.array(embeddings_response.embeddings)
|
||||
|
||||
await self.index.add_chunks(chunks, embeddings)
|
||||
|
||||
async def query_documents(
|
||||
async def query_chunks(
|
||||
self,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
) -> QueryChunksResponse:
|
||||
if params is None:
|
||||
params = {}
|
||||
k = params.get("max_chunks", 3)
|
||||
score_threshold = params.get("score_threshold", 0.0)
|
||||
|
||||
def _process(c) -> str:
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
else:
|
||||
return "<media>"
|
||||
|
||||
if isinstance(query, list):
|
||||
query_str = " ".join([_process(c) for c in query])
|
||||
else:
|
||||
query_str = _process(query)
|
||||
|
||||
query_str = interleaved_content_as_str(query)
|
||||
embeddings_response = await self.inference_api.embeddings(
|
||||
self.bank.embedding_model, [query_str]
|
||||
self.vector_db.embedding_model, [query_str]
|
||||
)
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
return await self.index.query(query_vector, k, score_threshold)
|
||||
|
|
|
@ -32,6 +32,7 @@ def pytest_addoption(parser):
|
|||
TEXT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
INFERENCE_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def provider_data():
|
||||
# check env for tavily secret, brave secret and inject all into provider data
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue