diff --git a/llama_stack/apis/memory/client.py b/llama_stack/apis/memory/client.py index 04c2dab5b..89f7cac99 100644 --- a/llama_stack/apis/memory/client.py +++ b/llama_stack/apis/memory/client.py @@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional import fire import httpx -from termcolor import cprint from llama_stack.distribution.datatypes import RemoteProviderConfig from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.memory_banks.client import MemoryBanksClient from llama_stack.providers.utils.memory.file_utils import data_url_from_file @@ -35,44 +35,16 @@ class MemoryClient(Memory): async def shutdown(self) -> None: pass - async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: + async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: async with httpx.AsyncClient() as client: - r = await client.get( - f"{self.base_url}/memory/get", - params={ - "bank_id": bank_id, - }, - headers={"Content-Type": "application/json"}, - timeout=20, - ) - r.raise_for_status() - d = r.json() - if not d: - return None - return MemoryBank(**d) - - async def create_memory_bank( - self, - name: str, - config: MemoryBankConfig, - url: Optional[URL] = None, - ) -> MemoryBank: - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.base_url}/memory/create", + response = await client.post( + f"{self.base_url}/memory/register_memory_bank", json={ - "name": name, - "config": config.dict(), - "url": url, + "memory_bank": json.loads(memory_bank.json()), }, headers={"Content-Type": "application/json"}, - timeout=20, ) - r.raise_for_status() - d = r.json() - if not d: - return None - return MemoryBank(**d) + response.raise_for_status() async def insert_documents( self, @@ -114,22 +86,20 @@ class MemoryClient(Memory): async def run_main(host: str, port: int, stream: bool): client = MemoryClient(f"http://{host}:{port}") + banks_client = MemoryBanksClient(f"http://{host}:{port}") - # create a memory bank - bank = await client.create_memory_bank( - name="test_bank", - config=VectorMemoryBankConfig( - bank_id="test_bank", - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ), + bank = VectorMemoryBankDef( + identifier="test_bank", + provider_id="", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, ) - cprint(json.dumps(bank.dict(), indent=4), "green") + await client.register_memory_bank(bank) - retrieved_bank = await client.get_memory_bank(bank.bank_id) + retrieved_bank = await banks_client.get_memory_bank(bank.identifier) assert retrieved_bank is not None - assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2" + assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2" urls = [ "memory_optimizations.rst", @@ -162,13 +132,13 @@ async def run_main(host: str, port: int, stream: bool): # insert some documents await client.insert_documents( - bank_id=bank.bank_id, + bank_id=bank.identifier, documents=documents, ) # query the documents response = await client.query_documents( - bank_id=bank.bank_id, + bank_id=bank.identifier, query=[ "How do I use Lora?", ], @@ -178,7 +148,7 @@ async def run_main(host: str, port: int, stream: bool): print(f"Chunk:\n========\n{chunk}\n========\n") response = await client.query_documents( - bank_id=bank.bank_id, + bank_id=bank.identifier, query=[ "Tell me more about llama3 and torchtune", ], diff --git a/llama_stack/apis/memory_banks/client.py b/llama_stack/apis/memory_banks/client.py index 3b763d1f3..6a6e28133 100644 --- a/llama_stack/apis/memory_banks/client.py +++ b/llama_stack/apis/memory_banks/client.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import json from typing import Any, Dict, List, Optional @@ -70,31 +69,10 @@ class MemoryBanksClient(MemoryBanks): j = response.json() return deserialize_memory_bank_def(j) - async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.base_url}/memory/register_memory_bank", - json={ - "memory_bank": json.loads(memory_bank.json()), - }, - headers={"Content-Type": "application/json"}, - ) - response.raise_for_status() - async def run_main(host: str, port: int, stream: bool): client = MemoryBanksClient(f"http://{host}:{port}") - await client.register_memory_bank( - VectorMemoryBankDef( - identifier="test_bank", - provider_id="", - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ), - ) - response = await client.list_memory_banks() cprint(f"list_memory_banks response={response}", "green") diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 0540cdf60..d0a7aed54 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -153,15 +153,15 @@ class BankWithIndex: self, documents: List[MemoryBankDocument], ) -> None: - model = get_embedding_model(self.bank.config.embedding_model) + model = get_embedding_model(self.bank.embedding_model) for doc in documents: content = await content_from_doc(doc) chunks = make_overlapped_chunks( doc.document_id, content, - self.bank.config.chunk_size_in_tokens, - self.bank.config.overlap_size_in_tokens - or (self.bank.config.chunk_size_in_tokens // 4), + self.bank.chunk_size_in_tokens, + self.bank.overlap_size_in_tokens + or (self.bank.chunk_size_in_tokens // 4), ) if not chunks: continue @@ -189,6 +189,6 @@ class BankWithIndex: else: query_str = _process(query) - model = get_embedding_model(self.bank.config.embedding_model) + model = get_embedding_model(self.bank.embedding_model) query_vector = model.encode([query_str])[0].astype(np.float32) return await self.index.query(query_vector, k)