more memory related fixes; memory.client now works

Ashwin Bharambe 2024-10-06 22:10:24 -07:00 committed by Ashwin Bharambe
parent 3725e74906
commit 862f8ddb8d
3 changed files with 24 additions and 76 deletions

View file

@@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional

 import fire
 import httpx
-from termcolor import cprint

 from llama_stack.distribution.datatypes import RemoteProviderConfig

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks.client import MemoryBanksClient
 from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@@ -35,44 +35,16 @@ class MemoryClient(Memory):
     async def shutdown(self) -> None:
         pass

-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
+    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
         async with httpx.AsyncClient() as client:
-            r = await client.get(
-                f"{self.base_url}/memory/get",
-                params={
-                    "bank_id": bank_id,
-                },
-                headers={"Content-Type": "application/json"},
-                timeout=20,
-            )
-            r.raise_for_status()
-            d = r.json()
-            if not d:
-                return None
-            return MemoryBank(**d)
-
-    async def create_memory_bank(
-        self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        async with httpx.AsyncClient() as client:
-            r = await client.post(
-                f"{self.base_url}/memory/create",
+            response = await client.post(
+                f"{self.base_url}/memory/register_memory_bank",
                 json={
-                    "name": name,
-                    "config": config.dict(),
-                    "url": url,
+                    "memory_bank": json.loads(memory_bank.json()),
                 },
                 headers={"Content-Type": "application/json"},
-                timeout=20,
             )
-            r.raise_for_status()
-            d = r.json()
-            if not d:
-                return None
-            return MemoryBank(**d)

     async def insert_documents(
         self,
@@ -114,22 +86,20 @@ class MemoryClient(Memory):
 async def run_main(host: str, port: int, stream: bool):
     client = MemoryClient(f"http://{host}:{port}")
+    banks_client = MemoryBanksClient(f"http://{host}:{port}")

-    # create a memory bank
-    bank = await client.create_memory_bank(
-        name="test_bank",
-        config=VectorMemoryBankConfig(
-            bank_id="test_bank",
-            embedding_model="all-MiniLM-L6-v2",
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        ),
+    bank = VectorMemoryBankDef(
+        identifier="test_bank",
+        provider_id="",
+        embedding_model="all-MiniLM-L6-v2",
+        chunk_size_in_tokens=512,
+        overlap_size_in_tokens=64,
     )
-    cprint(json.dumps(bank.dict(), indent=4), "green")
+    await client.register_memory_bank(bank)

-    retrieved_bank = await client.get_memory_bank(bank.bank_id)
+    retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
     assert retrieved_bank is not None
-    assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
+    assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2"

     urls = [
         "memory_optimizations.rst",
@@ -162,13 +132,13 @@ async def run_main(host: str, port: int, stream: bool):

     # insert some documents
     await client.insert_documents(
-        bank_id=bank.bank_id,
+        bank_id=bank.identifier,
         documents=documents,
     )

     # query the documents
     response = await client.query_documents(
-        bank_id=bank.bank_id,
+        bank_id=bank.identifier,
         query=[
             "How do I use Lora?",
         ],
@@ -178,7 +148,7 @@ async def run_main(host: str, port: int, stream: bool):
         print(f"Chunk:\n========\n{chunk}\n========\n")

     response = await client.query_documents(
-        bank_id=bank.bank_id,
+        bank_id=bank.identifier,
         query=[
             "Tell me more about llama3 and torchtune",
         ],
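
A minimal sketch of the reworked client flow that run_main above exercises: registration now goes through MemoryClient.register_memory_bank, while lookups go through MemoryBanksClient. The MemoryClient import path and the host/port are illustrative assumptions, not taken from this diff.

import asyncio

from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.apis.memory.client import MemoryClient  # assumed module path
from llama_stack.apis.memory_banks.client import MemoryBanksClient


async def demo() -> None:
    base_url = "http://localhost:5000"  # placeholder host/port

    client = MemoryClient(base_url)
    banks_client = MemoryBanksClient(base_url)

    # Register a vector bank through the Memory API client (new in this commit).
    bank = VectorMemoryBankDef(
        identifier="test_bank",
        provider_id="",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
        overlap_size_in_tokens=64,
    )
    await client.register_memory_bank(bank)

    # Read it back through the MemoryBanks API client.
    retrieved = await banks_client.get_memory_bank(bank.identifier)
    assert retrieved is not None
    assert retrieved.embedding_model == "all-MiniLM-L6-v2"


asyncio.run(demo())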

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import asyncio
-import json

 from typing import Any, Dict, List, Optional
@@ -70,31 +69,10 @@ class MemoryBanksClient(MemoryBanks):
             j = response.json()
             return deserialize_memory_bank_def(j)

-    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
-        async with httpx.AsyncClient() as client:
-            response = await client.post(
-                f"{self.base_url}/memory/register_memory_bank",
-                json={
-                    "memory_bank": json.loads(memory_bank.json()),
-                },
-                headers={"Content-Type": "application/json"},
-            )
-            response.raise_for_status()
-

 async def run_main(host: str, port: int, stream: bool):
     client = MemoryBanksClient(f"http://{host}:{port}")

-    await client.register_memory_bank(
-        VectorMemoryBankDef(
-            identifier="test_bank",
-            provider_id="",
-            embedding_model="all-MiniLM-L6-v2",
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        ),
-    )
-
     response = await client.list_memory_banks()
     cprint(f"list_memory_banks response={response}", "green")

View file

@@ -153,15 +153,15 @@ class BankWithIndex:
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.config.embedding_model)
+        model = get_embedding_model(self.bank.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
                 doc.document_id,
                 content,
-                self.bank.config.chunk_size_in_tokens,
-                self.bank.config.overlap_size_in_tokens
-                or (self.bank.config.chunk_size_in_tokens // 4),
+                self.bank.chunk_size_in_tokens,
+                self.bank.overlap_size_in_tokens
+                or (self.bank.chunk_size_in_tokens // 4),
             )
             if not chunks:
                 continue
@@ -189,6 +189,6 @@ class BankWithIndex:
         else:
             query_str = _process(query)

-        model = get_embedding_model(self.bank.config.embedding_model)
+        model = get_embedding_model(self.bank.embedding_model)
         query_vector = model.encode([query_str])[0].astype(np.float32)
         return await self.index.query(query_vector, k)
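
The BankWithIndex changes are a straight rename: embedding and chunking parameters now live directly on the bank definition instead of under bank.config. A tiny sketch of the new attribute access, reusing the test bank values from the client diff above; the // 4 fallback mirrors the default used in insert_documents:

from llama_stack.apis.memory import *  # noqa: F403

bank = VectorMemoryBankDef(
    identifier="test_bank",
    provider_id="",
    embedding_model="all-MiniLM-L6-v2",
    chunk_size_in_tokens=512,
    overlap_size_in_tokens=64,
)

# Fields that used to be read as bank.config.<field> are now read as bank.<field>.
overlap = bank.overlap_size_in_tokens or (bank.chunk_size_in_tokens // 4)
print(bank.embedding_model, bank.chunk_size_in_tokens, overlap)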