memory client works

This commit is contained in:
Ashwin Bharambe 2024-08-24 18:43:49 -07:00
parent a08958c000
commit 8d14d4228b
8 changed files with 164 additions and 86 deletions

View file

@@ -20,7 +20,7 @@ class MemoryBankDocument(BaseModel):
document_id: str
content: InterleavedTextMedia | URL
mime_type: str
metadata: Dict[str, Any]
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type
@@ -103,7 +103,7 @@ class Memory(Protocol):
@webmethod(route="/memory_banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get")
@webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory_banks/drop", method="DELETE")
@@ -136,14 +136,14 @@ class Memory(Protocol):
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@webmethod(route="/memory_bank/documents/get")
@webmethod(route="/memory_bank/documents/get", method="GET")
async def get_documents(
self,
bank_id: str,
document_ids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/documents/delete")
@webmethod(route="/memory_bank/documents/delete", method="DELETE")
async def delete_documents(
self,
bank_id: str,

View file

@@ -34,19 +34,19 @@ class MemoryClient(Memory):
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
async with client.get(
r = await client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
d = r.json()
if len(d) == 0:
return None
return MemoryBank(**d)
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def create_memory_bank(
self,
@@ -55,21 +55,21 @@ class MemoryClient(Memory):
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
async with client.post(
r = await client.post(
f"{self.base_url}/memory_banks/create",
data={
json={
"name": name,
"config": config.dict(),
"url": url,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
d = r.json()
if len(d) == 0:
return None
return MemoryBank(**d)
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def insert_documents(
self,
@@ -77,16 +77,16 @@ class MemoryClient(Memory):
documents: List[MemoryBankDocument],
) -> None:
async with httpx.AsyncClient() as client:
async with client.post(
r = await client.post(
f"{self.base_url}/memory_bank/insert",
data={
json={
"bank_id": bank_id,
"documents": documents,
"documents": [d.dict() for d in documents],
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
)
r.raise_for_status()
async def query_documents(
self,
@@ -95,18 +95,18 @@ class MemoryClient(Memory):
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
async with httpx.AsyncClient() as client:
async with client.post(
r = await client.post(
f"{self.base_url}/memory_bank/query",
data={
json={
"bank_id": bank_id,
"query": query,
"params": params,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
return QueryDocumentsResponse(**r.json())
)
r.raise_for_status()
return QueryDocumentsResponse(**r.json())
async def run_main(host: str, port: int, stream: bool):
@@ -126,31 +126,53 @@ async def run_main(host: str, port: int, stream: bool):
retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None
assert retrieved_bank.embedding_model == "dragon-roberta-query-2"
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
documents = [
MemoryBankDocument(
document_id=f"num-{i}",
content=URL(
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
),
mime_type="text/plain",
)
for i, url in enumerate(urls)
]
# insert some documents
await client.insert_documents(
bank_id=bank.bank_id,
documents=[
MemoryBankDocument(
document_id="1",
content="hello world",
),
MemoryBankDocument(
document_id="2",
content="goodbye world",
),
],
documents=documents,
)
# query the documents
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"hello world",
"How do I use Lora?",
],
)
print(response)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"Tell me more about llama3 and torchtune",
],
)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
def main(host: str, port: int, stream: bool = True):

View file

@@ -5,4 +5,4 @@
# the root directory of this source tree.
from .config import FaissImplConfig # noqa
from .memory import get_provider_impl # noqa
from .faiss import get_provider_impl # noqa

View file

@@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import faiss
import httpx
import numpy as np
from sentence_transformers import SentenceTransformer
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
@@ -33,7 +33,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL):
async with httpx.AsyncClient() as client:
return await client.get(doc.content).text
r = await client.get(doc.content.uri)
return r.text
def _process(c):
if isinstance(c, str):
@@ -62,16 +63,17 @@ def make_overlapped_chunks(
return chunks
class BankState(BaseModel):
@dataclass
class BankState:
bank: MemoryBank
index: Optional[faiss.IndexFlatL2] = None
doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
id_by_index: Dict[int, str] = Field(default_factory=dict)
chunk_by_index: Dict[int, str] = Field(default_factory=dict)
doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
id_by_index: Dict[int, str] = field(default_factory=dict)
chunk_by_index: Dict[int, str] = field(default_factory=dict)
async def insert_documents(
self,
model: SentenceTransformer,
model: "SentenceTransformer",
documents: List[MemoryBankDocument],
) -> None:
tokenizer = Tokenizer.get_instance()
@@ -97,21 +99,44 @@ class BankState(BaseModel):
content=chunk[0],
token_count=chunk[1],
)
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
self.id_by_index[indexlen + i] = doc.document_id
async def query_documents(
self, model: SentenceTransformer, query: str, params: Dict[str, Any]
) -> Tuple[List[Chunk], List[float]]:
self,
model: "SentenceTransformer",
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
if params is None:
params = {}
k = params.get("max_chunks", 3)
query_vector = model.encode([query])[0]
def _process(c) -> str:
if isinstance(c, str):
return c
else:
return "<media>"
if isinstance(query, list):
query_str = " ".join([_process(c) for c in query])
else:
query_str = _process(query)
query_vector = model.encode([query_str])[0]
distances, indices = self.index.search(
query_vector.reshape(1, -1).astype(np.float32), k
)
chunks = [self.chunk_by_index[int(i)] for i in indices[0]]
scores = [1.0 / float(d) for d in distances[0]]
chunks = []
scores = []
for d, i in zip(distances[0], indices[0]):
if i < 0:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(1.0 / float(d))
return chunks, scores
return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
if self.index is None:
@@ -122,7 +147,7 @@ class BankState(BaseModel):
class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.model = SentenceTransformer("all-MiniLM-L6-v2")
self.model = None
self.states = {}
async def initialize(self) -> None: ...
@@ -135,20 +160,21 @@ class FaissMemoryImpl(Memory):
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
print("Creating memory bank")
assert url is None, "URL is not supported for this implementation"
assert (
config.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}"
id = str(uuid.uuid4())
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=id,
bank_id=bank_id,
name=name,
config=config,
url=url,
)
state = BankState(bank=bank)
self.states[id] = state
self.states[bank_id] = state
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
@@ -164,7 +190,7 @@ class FaissMemoryImpl(Memory):
assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id]
await state.insert_documents(self.model, documents)
await state.insert_documents(self.get_model(), documents)
async def query_documents(
self,
@@ -175,5 +201,13 @@ class FaissMemoryImpl(Memory):
assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id]
chunks, scores = await state.query_documents(self.model, query, params)
return QueryDocumentsResponse(chunk=chunks, scores=scores)
return await state.query_documents(self.get_model(), query, params)
def get_model(self) -> "SentenceTransformer":
from sentence_transformers import SentenceTransformer
if self.model is None:
print("Loading sentence transformer")
self.model = SentenceTransformer("all-MiniLM-L6-v2")
return self.model

View file

@@ -9,13 +9,14 @@ from typing import List
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_inference_providers() -> List[ProviderSpec]:
def available_memory_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.memory,
provider_id="meta-reference-faiss",
pip_packages=[
"faiss",
"blobfile",
"faiss-cpu",
"sentence-transformers",
],
module="llama_toolchain.memory.meta_reference.faiss",