more memory-related fixes; memory.client now works

Authored by Ashwin Bharambe on 2024-10-06 22:10:24 -07:00; committed by Ashwin Bharambe
parent 3725e74906
commit 862f8ddb8d
3 changed files with 24 additions and 76 deletions

llama_stack/apis/memory/client.py

@@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional

 import fire
 import httpx
 from termcolor import cprint

-from llama_stack.distribution.datatypes import RemoteProviderConfig
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks.client import MemoryBanksClient
 from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@@ -35,44 +35,16 @@ class MemoryClient(Memory):
     async def shutdown(self) -> None:
         pass

-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
+    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
         async with httpx.AsyncClient() as client:
-            r = await client.get(
-                f"{self.base_url}/memory/get",
-                params={
-                    "bank_id": bank_id,
-                },
-                headers={"Content-Type": "application/json"},
-                timeout=20,
-            )
-            r.raise_for_status()
-            d = r.json()
-            if not d:
-                return None
-            return MemoryBank(**d)
-
-    async def create_memory_bank(
-        self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        async with httpx.AsyncClient() as client:
-            r = await client.post(
-                f"{self.base_url}/memory/create",
+            response = await client.post(
+                f"{self.base_url}/memory/register_memory_bank",
                 json={
-                    "name": name,
-                    "config": config.dict(),
-                    "url": url,
+                    "memory_bank": json.loads(memory_bank.json()),
                 },
                 headers={"Content-Type": "application/json"},
                 timeout=20,
             )
-            r.raise_for_status()
-            d = r.json()
-            if not d:
-                return None
-            return MemoryBank(**d)
+            response.raise_for_status()

     async def insert_documents(
         self,
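A detail worth calling out in the new method: `json.loads(memory_bank.json())` is not a roundabout `memory_bank.dict()`. Pydantic's `.json()` runs the model through its JSON encoders first, so enum, URL, and datetime fields arrive as plain JSON types that httpx can serialize. A minimal sketch of the difference, using a hypothetical model that is not from this codebase:

import json
from datetime import datetime

from pydantic import BaseModel


class Doc(BaseModel):
    created: datetime


doc = Doc(created=datetime(2024, 10, 6))

try:
    json.dumps(doc.dict())  # .dict() keeps the raw datetime object
except TypeError as err:
    print(f"not JSON-serializable: {err}")

print(json.dumps(json.loads(doc.json())))  # .json() pre-encodes it as an ISO string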
@@ -114,22 +86,20 @@ class MemoryClient(Memory):

 async def run_main(host: str, port: int, stream: bool):
     client = MemoryClient(f"http://{host}:{port}")
+    banks_client = MemoryBanksClient(f"http://{host}:{port}")

     # create a memory bank
-    bank = await client.create_memory_bank(
-        name="test_bank",
-        config=VectorMemoryBankConfig(
-            bank_id="test_bank",
-            embedding_model="all-MiniLM-L6-v2",
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        ),
+    bank = VectorMemoryBankDef(
+        identifier="test_bank",
+        provider_id="",
+        embedding_model="all-MiniLM-L6-v2",
+        chunk_size_in_tokens=512,
+        overlap_size_in_tokens=64,
     )
     cprint(json.dumps(bank.dict(), indent=4), "green")
+    await client.register_memory_bank(bank)

-    retrieved_bank = await client.get_memory_bank(bank.bank_id)
+    retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
     assert retrieved_bank is not None
-    assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
+    assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2"

     urls = [
         "memory_optimizations.rst",
@@ -162,13 +132,13 @@ async def run_main(host: str, port: int, stream: bool):

     # insert some documents
     await client.insert_documents(
-        bank_id=bank.bank_id,
+        bank_id=bank.identifier,
         documents=documents,
     )

     # query the documents
     response = await client.query_documents(
-        bank_id=bank.bank_id,
+        bank_id=bank.identifier,
         query=[
             "How do I use Lora?",
         ],
@@ -178,7 +148,7 @@ async def run_main(host: str, port: int, stream: bool):
         print(f"Chunk:\n========\n{chunk}\n========\n")

     response = await client.query_documents(
-        bank_id=bank.bank_id,
+        bank_id=bank.identifier,
         query=[
             "Tell me more about llama3 and torchtune",
         ],
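The module's entrypoint is untouched by this commit and therefore absent from the diff. The conventional wiring for these llama-stack client smoke tests (an assumption here, not confirmed by the commit; `fire` is visible in the imports above, `asyncio` is assumed) would be:

def main(host: str, port: int, stream: bool = True) -> None:
    asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)

With that wiring, something like `python -m llama_stack.apis.memory.client localhost 5000` (module path inferred from the imports) exercises the register/insert/query path end to end against a running server.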

llama_stack/apis/memory_banks/client.py

@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import asyncio
-import json

 from typing import Any, Dict, List, Optional
@@ -70,31 +69,10 @@ class MemoryBanksClient(MemoryBanks):
             j = response.json()
             return deserialize_memory_bank_def(j)

-    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
-        async with httpx.AsyncClient() as client:
-            response = await client.post(
-                f"{self.base_url}/memory/register_memory_bank",
-                json={
-                    "memory_bank": json.loads(memory_bank.json()),
-                },
-                headers={"Content-Type": "application/json"},
-            )
-            response.raise_for_status()
-

 async def run_main(host: str, port: int, stream: bool):
     client = MemoryBanksClient(f"http://{host}:{port}")

-    await client.register_memory_bank(
-        VectorMemoryBankDef(
-            identifier="test_bank",
-            provider_id="",
-            embedding_model="all-MiniLM-L6-v2",
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        ),
-    )
-
     response = await client.list_memory_banks()
     cprint(f"list_memory_banks response={response}", "green")

llama_stack/providers/utils/memory/vector_store.py

@@ -153,15 +153,15 @@ class BankWithIndex:
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.config.embedding_model)
+        model = get_embedding_model(self.bank.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
                 doc.document_id,
                 content,
-                self.bank.config.chunk_size_in_tokens,
-                self.bank.config.overlap_size_in_tokens
-                or (self.bank.config.chunk_size_in_tokens // 4),
+                self.bank.chunk_size_in_tokens,
+                self.bank.overlap_size_in_tokens
+                or (self.bank.chunk_size_in_tokens // 4),
             )
             if not chunks:
                 continue
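The `or` fallback above deserves a second look: an unset overlap falls back to a quarter of the chunk size, but so would an explicit 0, since `or` tests truthiness rather than `is None`. The arithmetic for the bank registered in the client above:

chunk_size_in_tokens = 512

overlap = 64 or (chunk_size_in_tokens // 4)     # explicit overlap wins -> 64
fallback = None or (chunk_size_in_tokens // 4)  # unset falls back -> 128
zero = 0 or (chunk_size_in_tokens // 4)         # 0 is falsy, also -> 128

print(overlap, fallback, zero)  # 64 128 128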
@@ -189,6 +189,6 @@ class BankWithIndex:
         else:
             query_str = _process(query)

-        model = get_embedding_model(self.bank.config.embedding_model)
+        model = get_embedding_model(self.bank.embedding_model)
         query_vector = model.encode([query_str])[0].astype(np.float32)
         return await self.index.query(query_vector, k)
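The pattern across all three files is the same flattening: MemoryBankDef carries its settings directly instead of nesting them under a config object. Sketched with hypothetical pydantic models (field names are taken from the diffs; types and optionality are assumptions):

from typing import Optional

from pydantic import BaseModel


class VectorMemoryBankConfig(BaseModel):
    # old shape: accessed as bank.config.<field>, keyed by bank_id
    bank_id: str
    embedding_model: str
    chunk_size_in_tokens: int
    overlap_size_in_tokens: Optional[int] = None


class VectorMemoryBankDef(BaseModel):
    # new shape: flat, accessed as bank.<field>, addressed by identifier
    identifier: str
    provider_id: str
    embedding_model: str
    chunk_size_in_tokens: int
    overlap_size_in_tokens: Optional[int] = None

Hence every `self.bank.config.X` in this file became `self.bank.X`, with no behavioral change beyond the new attribute paths.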