mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 13:44:38 +00:00
memory client works
This commit is contained in:
parent
a08958c000
commit
8d14d4228b
8 changed files with 164 additions and 86 deletions
|
@ -20,7 +20,7 @@ class MemoryBankDocument(BaseModel):
|
|||
document_id: str
|
||||
content: InterleavedTextMedia | URL
|
||||
mime_type: str
|
||||
metadata: Dict[str, Any]
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -103,7 +103,7 @@ class Memory(Protocol):
|
|||
@webmethod(route="/memory_banks/list", method="GET")
|
||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/get")
|
||||
@webmethod(route="/memory_banks/get", method="GET")
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/drop", method="DELETE")
|
||||
|
@ -136,14 +136,14 @@ class Memory(Protocol):
|
|||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse: ...
|
||||
|
||||
@webmethod(route="/memory_bank/documents/get")
|
||||
@webmethod(route="/memory_bank/documents/get", method="GET")
|
||||
async def get_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
document_ids: List[str],
|
||||
) -> List[MemoryBankDocument]: ...
|
||||
|
||||
@webmethod(route="/memory_bank/documents/delete")
|
||||
@webmethod(route="/memory_bank/documents/delete", method="DELETE")
|
||||
async def delete_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
|
|
|
@ -34,19 +34,19 @@ class MemoryClient(Memory):
|
|||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.get(
|
||||
r = await client.get(
|
||||
f"{self.base_url}/memory_banks/get",
|
||||
params={
|
||||
"bank_id": bank_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if len(d) == 0:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if not d:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
|
@ -55,21 +55,21 @@ class MemoryClient(Memory):
|
|||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.post(
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory_banks/create",
|
||||
data={
|
||||
json={
|
||||
"name": name,
|
||||
"config": config.dict(),
|
||||
"url": url,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if len(d) == 0:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if not d:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@ -77,16 +77,16 @@ class MemoryClient(Memory):
|
|||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.post(
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory_bank/insert",
|
||||
data={
|
||||
json={
|
||||
"bank_id": bank_id,
|
||||
"documents": documents,
|
||||
"documents": [d.dict() for d in documents],
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
|
@ -95,18 +95,18 @@ class MemoryClient(Memory):
|
|||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.post(
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory_bank/query",
|
||||
data={
|
||||
json={
|
||||
"bank_id": bank_id,
|
||||
"query": query,
|
||||
"params": params,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
return QueryDocumentsResponse(**r.json())
|
||||
)
|
||||
r.raise_for_status()
|
||||
return QueryDocumentsResponse(**r.json())
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
|
@ -126,31 +126,53 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
|
||||
retrieved_bank = await client.get_memory_bank(bank.bank_id)
|
||||
assert retrieved_bank is not None
|
||||
assert retrieved_bank.embedding_model == "dragon-roberta-query-2"
|
||||
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
document_id=f"num-{i}",
|
||||
content=URL(
|
||||
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
|
||||
),
|
||||
mime_type="text/plain",
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
# insert some documents
|
||||
await client.insert_documents(
|
||||
bank_id=bank.bank_id,
|
||||
documents=[
|
||||
MemoryBankDocument(
|
||||
document_id="1",
|
||||
content="hello world",
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="2",
|
||||
content="goodbye world",
|
||||
),
|
||||
],
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
# query the documents
|
||||
response = await client.query_documents(
|
||||
bank_id=bank.bank_id,
|
||||
query=[
|
||||
"hello world",
|
||||
"How do I use Lora?",
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
for chunk, score in zip(response.chunks, response.scores):
|
||||
print(f"Score: {score}")
|
||||
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||
|
||||
response = await client.query_documents(
|
||||
bank_id=bank.bank_id,
|
||||
query=[
|
||||
"Tell me more about llama3 and torchtune",
|
||||
],
|
||||
)
|
||||
for chunk, score in zip(response.chunks, response.scores):
|
||||
print(f"Score: {score}")
|
||||
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||
|
||||
|
||||
def main(host: str, port: int, stream: bool = True):
|
||||
|
|
|
@ -5,4 +5,4 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import FaissImplConfig # noqa
|
||||
from .memory import get_provider_impl # noqa
|
||||
from .faiss import get_provider_impl # noqa
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import faiss
|
||||
import httpx
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
@ -33,7 +33,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
|
|||
async def content_from_doc(doc: MemoryBankDocument) -> str:
|
||||
if isinstance(doc.content, URL):
|
||||
async with httpx.AsyncClient() as client:
|
||||
return await client.get(doc.content).text
|
||||
r = await client.get(doc.content.uri)
|
||||
return r.text
|
||||
|
||||
def _process(c):
|
||||
if isinstance(c, str):
|
||||
|
@ -62,16 +63,17 @@ def make_overlapped_chunks(
|
|||
return chunks
|
||||
|
||||
|
||||
class BankState(BaseModel):
|
||||
@dataclass
|
||||
class BankState:
|
||||
bank: MemoryBank
|
||||
index: Optional[faiss.IndexFlatL2] = None
|
||||
doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
|
||||
id_by_index: Dict[int, str] = Field(default_factory=dict)
|
||||
chunk_by_index: Dict[int, str] = Field(default_factory=dict)
|
||||
doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
|
||||
id_by_index: Dict[int, str] = field(default_factory=dict)
|
||||
chunk_by_index: Dict[int, str] = field(default_factory=dict)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
model: SentenceTransformer,
|
||||
model: "SentenceTransformer",
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
@ -97,21 +99,44 @@ class BankState(BaseModel):
|
|||
content=chunk[0],
|
||||
token_count=chunk[1],
|
||||
)
|
||||
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
|
||||
self.id_by_index[indexlen + i] = doc.document_id
|
||||
|
||||
async def query_documents(
|
||||
self, model: SentenceTransformer, query: str, params: Dict[str, Any]
|
||||
) -> Tuple[List[Chunk], List[float]]:
|
||||
self,
|
||||
model: "SentenceTransformer",
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
if params is None:
|
||||
params = {}
|
||||
k = params.get("max_chunks", 3)
|
||||
query_vector = model.encode([query])[0]
|
||||
|
||||
def _process(c) -> str:
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
else:
|
||||
return "<media>"
|
||||
|
||||
if isinstance(query, list):
|
||||
query_str = " ".join([_process(c) for c in query])
|
||||
else:
|
||||
query_str = _process(query)
|
||||
|
||||
query_vector = model.encode([query_str])[0]
|
||||
distances, indices = self.index.search(
|
||||
query_vector.reshape(1, -1).astype(np.float32), k
|
||||
)
|
||||
|
||||
chunks = [self.chunk_by_index[int(i)] for i in indices[0]]
|
||||
scores = [1.0 / float(d) for d in distances[0]]
|
||||
chunks = []
|
||||
scores = []
|
||||
for d, i in zip(distances[0], indices[0]):
|
||||
if i < 0:
|
||||
continue
|
||||
chunks.append(self.chunk_by_index[int(i)])
|
||||
scores.append(1.0 / float(d))
|
||||
|
||||
return chunks, scores
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
|
||||
if self.index is None:
|
||||
|
@ -122,7 +147,7 @@ class BankState(BaseModel):
|
|||
class FaissMemoryImpl(Memory):
|
||||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
self.config = config
|
||||
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
self.model = None
|
||||
self.states = {}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
@ -135,20 +160,21 @@ class FaissMemoryImpl(Memory):
|
|||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
print("Creating memory bank")
|
||||
assert url is None, "URL is not supported for this implementation"
|
||||
assert (
|
||||
config.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {config.type}"
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=id,
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
state = BankState(bank=bank)
|
||||
self.states[id] = state
|
||||
self.states[bank_id] = state
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
|
@ -164,7 +190,7 @@ class FaissMemoryImpl(Memory):
|
|||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
||||
state = self.states[bank_id]
|
||||
|
||||
await state.insert_documents(self.model, documents)
|
||||
await state.insert_documents(self.get_model(), documents)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
|
@ -175,5 +201,13 @@ class FaissMemoryImpl(Memory):
|
|||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
||||
state = self.states[bank_id]
|
||||
|
||||
chunks, scores = await state.query_documents(self.model, query, params)
|
||||
return QueryDocumentsResponse(chunk=chunks, scores=scores)
|
||||
return await state.query_documents(self.get_model(), query, params)
|
||||
|
||||
def get_model(self) -> "SentenceTransformer":
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
if self.model is None:
|
||||
print("Loading sentence transformer")
|
||||
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
|
||||
return self.model
|
|
@ -9,13 +9,14 @@ from typing import List
|
|||
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_inference_providers() -> List[ProviderSpec]:
|
||||
def available_memory_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
provider_id="meta-reference-faiss",
|
||||
pip_packages=[
|
||||
"faiss",
|
||||
"blobfile",
|
||||
"faiss-cpu",
|
||||
"sentence-transformers",
|
||||
],
|
||||
module="llama_toolchain.memory.meta_reference.faiss",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue