[memory refactor][2/n] Update faiss and make it pass tests (#830)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Second part:

- updates routing table / router code 
- updates the faiss implementation


## Test Plan

```
pytest -s -v -k sentence test_vector_io.py --env EMBEDDING_DIMENSION=384
```
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:02:15 -08:00 committed by GitHub
parent 3ae8585b65
commit 78a481bb22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 343 additions and 353 deletions

View file

@ -27,8 +27,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory_banks.memory_banks import BankParams
from llama_stack.apis.models import ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
@ -39,11 +37,12 @@ from llama_stack.apis.scoring import (
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import RoutingTable
class MemoryRouter(Memory):
"""Routes to an provider based on the memory bank identifier"""
class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier"""
def __init__(
self,
@ -57,38 +56,40 @@ class MemoryRouter(Memory):
async def shutdown(self) -> None:
pass
async def register_memory_bank(
async def register_vector_db(
self,
memory_bank_id: str,
params: BankParams,
vector_db_id: str,
embedding_model: str,
embedding_dimension: Optional[int] = 384,
provider_id: Optional[str] = None,
provider_memorybank_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
) -> None:
await self.routing_table.register_memory_bank(
memory_bank_id,
params,
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
provider_memorybank_id,
provider_vector_db_id,
)
async def insert_documents(
async def insert_chunks(
self,
bank_id: str,
documents: List[MemoryBankDocument],
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
return await self.routing_table.get_provider_impl(bank_id).insert_documents(
bank_id, documents, ttl_seconds
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
vector_db_id, chunks, ttl_seconds
)
async def query_documents(
async def query_chunks(
self,
bank_id: str,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
return await self.routing_table.get_provider_impl(bank_id).query_documents(
bank_id, query, params
) -> QueryChunksResponse:
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
vector_db_id, query, params
)