mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-06 04:34:57 +00:00
memory client works
This commit is contained in:
parent
a08958c000
commit
8d14d4228b
8 changed files with 164 additions and 86 deletions
|
@ -34,19 +34,19 @@ class MemoryClient(Memory):
|
|||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.get(
|
||||
r = await client.get(
|
||||
f"{self.base_url}/memory_banks/get",
|
||||
params={
|
||||
"bank_id": bank_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if len(d) == 0:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if not d:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
|
@ -55,21 +55,21 @@ class MemoryClient(Memory):
|
|||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.post(
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory_banks/create",
|
||||
data={
|
||||
json={
|
||||
"name": name,
|
||||
"config": config.dict(),
|
||||
"url": url,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if len(d) == 0:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if not d:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@ -77,16 +77,16 @@ class MemoryClient(Memory):
|
|||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.post(
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory_bank/insert",
|
||||
data={
|
||||
json={
|
||||
"bank_id": bank_id,
|
||||
"documents": documents,
|
||||
"documents": [d.dict() for d in documents],
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
|
@ -95,18 +95,18 @@ class MemoryClient(Memory):
|
|||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.post(
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory_bank/query",
|
||||
data={
|
||||
json={
|
||||
"bank_id": bank_id,
|
||||
"query": query,
|
||||
"params": params,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
return QueryDocumentsResponse(**r.json())
|
||||
)
|
||||
r.raise_for_status()
|
||||
return QueryDocumentsResponse(**r.json())
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
|
@ -126,31 +126,53 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
|
||||
retrieved_bank = await client.get_memory_bank(bank.bank_id)
|
||||
assert retrieved_bank is not None
|
||||
assert retrieved_bank.embedding_model == "dragon-roberta-query-2"
|
||||
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
document_id=f"num-{i}",
|
||||
content=URL(
|
||||
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
|
||||
),
|
||||
mime_type="text/plain",
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
# insert some documents
|
||||
await client.insert_documents(
|
||||
bank_id=bank.bank_id,
|
||||
documents=[
|
||||
MemoryBankDocument(
|
||||
document_id="1",
|
||||
content="hello world",
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="2",
|
||||
content="goodbye world",
|
||||
),
|
||||
],
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
# query the documents
|
||||
response = await client.query_documents(
|
||||
bank_id=bank.bank_id,
|
||||
query=[
|
||||
"hello world",
|
||||
"How do I use Lora?",
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
for chunk, score in zip(response.chunks, response.scores):
|
||||
print(f"Score: {score}")
|
||||
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||
|
||||
response = await client.query_documents(
|
||||
bank_id=bank.bank_id,
|
||||
query=[
|
||||
"Tell me more about llama3 and torchtune",
|
||||
],
|
||||
)
|
||||
for chunk, score in zip(response.chunks, response.scores):
|
||||
print(f"Score: {score}")
|
||||
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||
|
||||
|
||||
def main(host: str, port: int, stream: bool = True):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue