memory client works

This commit is contained in:
Ashwin Bharambe 2024-08-24 18:43:49 -07:00
parent a08958c000
commit 8d14d4228b
8 changed files with 164 additions and 86 deletions

View file

@ -5,4 +5,4 @@
# the root directory of this source tree.
from .config import FaissImplConfig # noqa
from .memory import get_provider_impl # noqa
from .faiss import get_provider_impl # noqa

View file

@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import faiss
import httpx
import numpy as np
from sentence_transformers import SentenceTransformer
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
@ -33,7 +33,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL):
async with httpx.AsyncClient() as client:
return await client.get(doc.content).text
r = await client.get(doc.content.uri)
return r.text
def _process(c):
if isinstance(c, str):
@ -62,16 +63,17 @@ def make_overlapped_chunks(
return chunks
class BankState(BaseModel):
@dataclass
class BankState:
bank: MemoryBank
index: Optional[faiss.IndexFlatL2] = None
doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
id_by_index: Dict[int, str] = Field(default_factory=dict)
chunk_by_index: Dict[int, str] = Field(default_factory=dict)
doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
id_by_index: Dict[int, str] = field(default_factory=dict)
chunk_by_index: Dict[int, str] = field(default_factory=dict)
async def insert_documents(
self,
model: SentenceTransformer,
model: "SentenceTransformer",
documents: List[MemoryBankDocument],
) -> None:
tokenizer = Tokenizer.get_instance()
@ -97,21 +99,44 @@ class BankState(BaseModel):
content=chunk[0],
token_count=chunk[1],
)
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
self.id_by_index[indexlen + i] = doc.document_id
async def query_documents(
self, model: SentenceTransformer, query: str, params: Dict[str, Any]
) -> Tuple[List[Chunk], List[float]]:
self,
model: "SentenceTransformer",
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
if params is None:
params = {}
k = params.get("max_chunks", 3)
query_vector = model.encode([query])[0]
def _process(c) -> str:
if isinstance(c, str):
return c
else:
return "<media>"
if isinstance(query, list):
query_str = " ".join([_process(c) for c in query])
else:
query_str = _process(query)
query_vector = model.encode([query_str])[0]
distances, indices = self.index.search(
query_vector.reshape(1, -1).astype(np.float32), k
)
chunks = [self.chunk_by_index[int(i)] for i in indices[0]]
scores = [1.0 / float(d) for d in distances[0]]
chunks = []
scores = []
for d, i in zip(distances[0], indices[0]):
if i < 0:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(1.0 / float(d))
return chunks, scores
return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
if self.index is None:
@ -122,7 +147,7 @@ class BankState(BaseModel):
class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.model = SentenceTransformer("all-MiniLM-L6-v2")
self.model = None
self.states = {}
async def initialize(self) -> None: ...
@ -135,20 +160,21 @@ class FaissMemoryImpl(Memory):
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
print("Creating memory bank")
assert url is None, "URL is not supported for this implementation"
assert (
config.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}"
id = str(uuid.uuid4())
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=id,
bank_id=bank_id,
name=name,
config=config,
url=url,
)
state = BankState(bank=bank)
self.states[id] = state
self.states[bank_id] = state
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
@ -164,7 +190,7 @@ class FaissMemoryImpl(Memory):
assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id]
await state.insert_documents(self.model, documents)
await state.insert_documents(self.get_model(), documents)
async def query_documents(
self,
@ -175,5 +201,13 @@ class FaissMemoryImpl(Memory):
assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id]
chunks, scores = await state.query_documents(self.model, query, params)
return QueryDocumentsResponse(chunk=chunks, scores=scores)
return await state.query_documents(self.get_model(), query, params)
def get_model(self) -> "SentenceTransformer":
from sentence_transformers import SentenceTransformer
if self.model is None:
print("Loading sentence transformer")
self.model = SentenceTransformer("all-MiniLM-L6-v2")
return self.model