memory client works

This commit is contained in:
Ashwin Bharambe 2024-08-24 18:43:49 -07:00
parent a08958c000
commit 8d14d4228b
8 changed files with 164 additions and 86 deletions

View file

@ -34,19 +34,19 @@ class MemoryClient(Memory):
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
async with client.get(
r = await client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
d = r.json()
if len(d) == 0:
return None
return MemoryBank(**d)
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def create_memory_bank(
self,
@ -55,21 +55,21 @@ class MemoryClient(Memory):
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
async with client.post(
r = await client.post(
f"{self.base_url}/memory_banks/create",
data={
json={
"name": name,
"config": config.dict(),
"url": url,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
d = r.json()
if len(d) == 0:
return None
return MemoryBank(**d)
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def insert_documents(
self,
@ -77,16 +77,16 @@ class MemoryClient(Memory):
documents: List[MemoryBankDocument],
) -> None:
async with httpx.AsyncClient() as client:
async with client.post(
r = await client.post(
f"{self.base_url}/memory_bank/insert",
data={
json={
"bank_id": bank_id,
"documents": documents,
"documents": [d.dict() for d in documents],
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
)
r.raise_for_status()
async def query_documents(
self,
@ -95,18 +95,18 @@ class MemoryClient(Memory):
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
async with httpx.AsyncClient() as client:
async with client.post(
r = await client.post(
f"{self.base_url}/memory_bank/query",
data={
json={
"bank_id": bank_id,
"query": query,
"params": params,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
return QueryDocumentsResponse(**r.json())
)
r.raise_for_status()
return QueryDocumentsResponse(**r.json())
async def run_main(host: str, port: int, stream: bool):
@ -126,31 +126,53 @@ async def run_main(host: str, port: int, stream: bool):
retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None
assert retrieved_bank.embedding_model == "dragon-roberta-query-2"
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
documents = [
MemoryBankDocument(
document_id=f"num-{i}",
content=URL(
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
),
mime_type="text/plain",
)
for i, url in enumerate(urls)
]
# insert some documents
await client.insert_documents(
bank_id=bank.bank_id,
documents=[
MemoryBankDocument(
document_id="1",
content="hello world",
),
MemoryBankDocument(
document_id="2",
content="goodbye world",
),
],
documents=documents,
)
# query the documents
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"hello world",
"How do I use Lora?",
],
)
print(response)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"Tell me more about llama3 and torchtune",
],
)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
def main(host: str, port: int, stream: bool = True):