mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
more memory-related fixes; memory.client now works
This commit is contained in:
parent
3725e74906
commit
862f8ddb8d
3 changed files with 24 additions and 76 deletions
|
@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import httpx
|
import httpx
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
from llama_stack.apis.memory_banks.client import MemoryBanksClient
|
||||||
from llama_stack.providers.utils.memory.file_utils import data_url_from_file
|
from llama_stack.providers.utils.memory.file_utils import data_url_from_file
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,44 +35,16 @@ class MemoryClient(Memory):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.get(
|
response = await client.post(
|
||||||
f"{self.base_url}/memory/get",
|
f"{self.base_url}/memory/register_memory_bank",
|
||||||
params={
|
|
||||||
"bank_id": bank_id,
|
|
||||||
},
|
|
||||||
headers={"Content-Type": "application/json"},
|
|
||||||
timeout=20,
|
|
||||||
)
|
|
||||||
r.raise_for_status()
|
|
||||||
d = r.json()
|
|
||||||
if not d:
|
|
||||||
return None
|
|
||||||
return MemoryBank(**d)
|
|
||||||
|
|
||||||
async def create_memory_bank(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
config: MemoryBankConfig,
|
|
||||||
url: Optional[URL] = None,
|
|
||||||
) -> MemoryBank:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
r = await client.post(
|
|
||||||
f"{self.base_url}/memory/create",
|
|
||||||
json={
|
json={
|
||||||
"name": name,
|
"memory_bank": json.loads(memory_bank.json()),
|
||||||
"config": config.dict(),
|
|
||||||
"url": url,
|
|
||||||
},
|
},
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
response.raise_for_status()
|
||||||
d = r.json()
|
|
||||||
if not d:
|
|
||||||
return None
|
|
||||||
return MemoryBank(**d)
|
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -114,22 +86,20 @@ class MemoryClient(Memory):
|
||||||
|
|
||||||
async def run_main(host: str, port: int, stream: bool):
|
async def run_main(host: str, port: int, stream: bool):
|
||||||
client = MemoryClient(f"http://{host}:{port}")
|
client = MemoryClient(f"http://{host}:{port}")
|
||||||
|
banks_client = MemoryBanksClient(f"http://{host}:{port}")
|
||||||
|
|
||||||
# create a memory bank
|
bank = VectorMemoryBankDef(
|
||||||
bank = await client.create_memory_bank(
|
identifier="test_bank",
|
||||||
name="test_bank",
|
provider_id="",
|
||||||
config=VectorMemoryBankConfig(
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
bank_id="test_bank",
|
chunk_size_in_tokens=512,
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
overlap_size_in_tokens=64,
|
||||||
chunk_size_in_tokens=512,
|
|
||||||
overlap_size_in_tokens=64,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
cprint(json.dumps(bank.dict(), indent=4), "green")
|
await client.register_memory_bank(bank)
|
||||||
|
|
||||||
retrieved_bank = await client.get_memory_bank(bank.bank_id)
|
retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
|
||||||
assert retrieved_bank is not None
|
assert retrieved_bank is not None
|
||||||
assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
|
assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2"
|
||||||
|
|
||||||
urls = [
|
urls = [
|
||||||
"memory_optimizations.rst",
|
"memory_optimizations.rst",
|
||||||
|
@ -162,13 +132,13 @@ async def run_main(host: str, port: int, stream: bool):
|
||||||
|
|
||||||
# insert some documents
|
# insert some documents
|
||||||
await client.insert_documents(
|
await client.insert_documents(
|
||||||
bank_id=bank.bank_id,
|
bank_id=bank.identifier,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
)
|
)
|
||||||
|
|
||||||
# query the documents
|
# query the documents
|
||||||
response = await client.query_documents(
|
response = await client.query_documents(
|
||||||
bank_id=bank.bank_id,
|
bank_id=bank.identifier,
|
||||||
query=[
|
query=[
|
||||||
"How do I use Lora?",
|
"How do I use Lora?",
|
||||||
],
|
],
|
||||||
|
@ -178,7 +148,7 @@ async def run_main(host: str, port: int, stream: bool):
|
||||||
print(f"Chunk:\n========\n{chunk}\n========\n")
|
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||||
|
|
||||||
response = await client.query_documents(
|
response = await client.query_documents(
|
||||||
bank_id=bank.bank_id,
|
bank_id=bank.identifier,
|
||||||
query=[
|
query=[
|
||||||
"Tell me more about llama3 and torchtune",
|
"Tell me more about llama3 and torchtune",
|
||||||
],
|
],
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
@ -70,31 +69,10 @@ class MemoryBanksClient(MemoryBanks):
|
||||||
j = response.json()
|
j = response.json()
|
||||||
return deserialize_memory_bank_def(j)
|
return deserialize_memory_bank_def(j)
|
||||||
|
|
||||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.post(
|
|
||||||
f"{self.base_url}/memory/register_memory_bank",
|
|
||||||
json={
|
|
||||||
"memory_bank": json.loads(memory_bank.json()),
|
|
||||||
},
|
|
||||||
headers={"Content-Type": "application/json"},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
|
|
||||||
async def run_main(host: str, port: int, stream: bool):
|
async def run_main(host: str, port: int, stream: bool):
|
||||||
client = MemoryBanksClient(f"http://{host}:{port}")
|
client = MemoryBanksClient(f"http://{host}:{port}")
|
||||||
|
|
||||||
await client.register_memory_bank(
|
|
||||||
VectorMemoryBankDef(
|
|
||||||
identifier="test_bank",
|
|
||||||
provider_id="",
|
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
|
||||||
chunk_size_in_tokens=512,
|
|
||||||
overlap_size_in_tokens=64,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.list_memory_banks()
|
response = await client.list_memory_banks()
|
||||||
cprint(f"list_memory_banks response={response}", "green")
|
cprint(f"list_memory_banks response={response}", "green")
|
||||||
|
|
||||||
|
|
|
@ -153,15 +153,15 @@ class BankWithIndex:
|
||||||
self,
|
self,
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
) -> None:
|
) -> None:
|
||||||
model = get_embedding_model(self.bank.config.embedding_model)
|
model = get_embedding_model(self.bank.embedding_model)
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
content = await content_from_doc(doc)
|
content = await content_from_doc(doc)
|
||||||
chunks = make_overlapped_chunks(
|
chunks = make_overlapped_chunks(
|
||||||
doc.document_id,
|
doc.document_id,
|
||||||
content,
|
content,
|
||||||
self.bank.config.chunk_size_in_tokens,
|
self.bank.chunk_size_in_tokens,
|
||||||
self.bank.config.overlap_size_in_tokens
|
self.bank.overlap_size_in_tokens
|
||||||
or (self.bank.config.chunk_size_in_tokens // 4),
|
or (self.bank.chunk_size_in_tokens // 4),
|
||||||
)
|
)
|
||||||
if not chunks:
|
if not chunks:
|
||||||
continue
|
continue
|
||||||
|
@ -189,6 +189,6 @@ class BankWithIndex:
|
||||||
else:
|
else:
|
||||||
query_str = _process(query)
|
query_str = _process(query)
|
||||||
|
|
||||||
model = get_embedding_model(self.bank.config.embedding_model)
|
model = get_embedding_model(self.bank.embedding_model)
|
||||||
query_vector = model.encode([query_str])[0].astype(np.float32)
|
query_vector = model.encode([query_str])[0].astype(np.float32)
|
||||||
return await self.index.query(query_vector, k)
|
return await self.index.query(query_vector, k)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue