mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 21:48:36 +00:00
memory client works
This commit is contained in:
parent
a08958c000
commit
8d14d4228b
8 changed files with 164 additions and 86 deletions
|
@ -5,4 +5,4 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import FaissImplConfig # noqa
|
||||
from .memory import get_provider_impl # noqa
|
||||
from .faiss import get_provider_impl # noqa
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import faiss
|
||||
import httpx
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
@ -33,7 +33,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
|
|||
async def content_from_doc(doc: MemoryBankDocument) -> str:
|
||||
if isinstance(doc.content, URL):
|
||||
async with httpx.AsyncClient() as client:
|
||||
return await client.get(doc.content).text
|
||||
r = await client.get(doc.content.uri)
|
||||
return r.text
|
||||
|
||||
def _process(c):
|
||||
if isinstance(c, str):
|
||||
|
@ -62,16 +63,17 @@ def make_overlapped_chunks(
|
|||
return chunks
|
||||
|
||||
|
||||
class BankState(BaseModel):
|
||||
@dataclass
|
||||
class BankState:
|
||||
bank: MemoryBank
|
||||
index: Optional[faiss.IndexFlatL2] = None
|
||||
doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
|
||||
id_by_index: Dict[int, str] = Field(default_factory=dict)
|
||||
chunk_by_index: Dict[int, str] = Field(default_factory=dict)
|
||||
doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
|
||||
id_by_index: Dict[int, str] = field(default_factory=dict)
|
||||
chunk_by_index: Dict[int, str] = field(default_factory=dict)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
model: SentenceTransformer,
|
||||
model: "SentenceTransformer",
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
@ -97,21 +99,44 @@ class BankState(BaseModel):
|
|||
content=chunk[0],
|
||||
token_count=chunk[1],
|
||||
)
|
||||
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
|
||||
self.id_by_index[indexlen + i] = doc.document_id
|
||||
|
||||
async def query_documents(
|
||||
self, model: SentenceTransformer, query: str, params: Dict[str, Any]
|
||||
) -> Tuple[List[Chunk], List[float]]:
|
||||
self,
|
||||
model: "SentenceTransformer",
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
if params is None:
|
||||
params = {}
|
||||
k = params.get("max_chunks", 3)
|
||||
query_vector = model.encode([query])[0]
|
||||
|
||||
def _process(c) -> str:
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
else:
|
||||
return "<media>"
|
||||
|
||||
if isinstance(query, list):
|
||||
query_str = " ".join([_process(c) for c in query])
|
||||
else:
|
||||
query_str = _process(query)
|
||||
|
||||
query_vector = model.encode([query_str])[0]
|
||||
distances, indices = self.index.search(
|
||||
query_vector.reshape(1, -1).astype(np.float32), k
|
||||
)
|
||||
|
||||
chunks = [self.chunk_by_index[int(i)] for i in indices[0]]
|
||||
scores = [1.0 / float(d) for d in distances[0]]
|
||||
chunks = []
|
||||
scores = []
|
||||
for d, i in zip(distances[0], indices[0]):
|
||||
if i < 0:
|
||||
continue
|
||||
chunks.append(self.chunk_by_index[int(i)])
|
||||
scores.append(1.0 / float(d))
|
||||
|
||||
return chunks, scores
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
|
||||
if self.index is None:
|
||||
|
@ -122,7 +147,7 @@ class BankState(BaseModel):
|
|||
class FaissMemoryImpl(Memory):
|
||||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
self.config = config
|
||||
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
self.model = None
|
||||
self.states = {}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
@ -135,20 +160,21 @@ class FaissMemoryImpl(Memory):
|
|||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
print("Creating memory bank")
|
||||
assert url is None, "URL is not supported for this implementation"
|
||||
assert (
|
||||
config.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {config.type}"
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=id,
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
state = BankState(bank=bank)
|
||||
self.states[id] = state
|
||||
self.states[bank_id] = state
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
|
@ -164,7 +190,7 @@ class FaissMemoryImpl(Memory):
|
|||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
||||
state = self.states[bank_id]
|
||||
|
||||
await state.insert_documents(self.model, documents)
|
||||
await state.insert_documents(self.get_model(), documents)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
|
@ -175,5 +201,13 @@ class FaissMemoryImpl(Memory):
|
|||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
||||
state = self.states[bank_id]
|
||||
|
||||
chunks, scores = await state.query_documents(self.model, query, params)
|
||||
return QueryDocumentsResponse(chunk=chunks, scores=scores)
|
||||
return await state.query_documents(self.get_model(), query, params)
|
||||
|
||||
def get_model(self) -> "SentenceTransformer":
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
if self.model is None:
|
||||
print("Loading sentence transformer")
|
||||
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
|
||||
return self.model
|
Loading…
Add table
Add a link
Reference in a new issue