memory client works

Ashwin Bharambe 2024-08-24 18:43:49 -07:00
parent a08958c000
commit 8d14d4228b
8 changed files with 164 additions and 86 deletions

View file

@@ -11,6 +11,8 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.memory.api.endpoints import Memory
from llama_toolchain.memory.providers import available_memory_providers
from llama_toolchain.safety.api.endpoints import Safety
from llama_toolchain.safety.providers import available_safety_providers
@@ -47,6 +49,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.inference: Inference,
Api.safety: Safety,
Api.agentic_system: AgenticSystem,
Api.memory: Memory,
}
for api, protocol in protocols.items():
@@ -60,9 +63,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
webmethod = method.__webmethod__
route = webmethod.route
# use `post` for all methods right now until we fix up the `webmethod` openapi
# annotation and write our own openapi generator
endpoints.append(ApiEndpoint(route=route, method="post", name=name))
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints
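
For context, the dispatch above only relies on each protocol method carrying `__webmethod__` metadata with a `route` and an optional `method`. A minimal sketch of such a decorator (hypothetical; the real `webmethod` in llama_toolchain may differ):

```python
# Hypothetical sketch of a webmethod decorator compatible with the dispatch
# above; the real llama_toolchain implementation may differ.
from dataclasses import dataclass
from typing import Callable, Optional, TypeVar

F = TypeVar("F", bound=Callable)

@dataclass
class WebMethod:
    route: str
    method: Optional[str] = None  # None falls through to "post" in the dispatch

def webmethod(route: str, method: Optional[str] = None) -> Callable[[F], F]:
    def wrap(fn: F) -> F:
        fn.__webmethod__ = WebMethod(route=route, method=method)  # type: ignore[attr-defined]
        return fn
    return wrap
```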
@@ -82,4 +89,5 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id,
Api.memory: {a.provider_id: a for a in available_memory_providers()},
}

View file

@@ -53,6 +53,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
),
DistributionSpec(
spec_id="test-memory",
description="Just a test distribution spec for testing memory bank APIs",
provider_specs={
Api.memory: providers[Api.memory]["meta-reference-faiss"],
},

View file

@@ -5,8 +5,10 @@
# the root directory of this source tree.
import asyncio
import inspect
import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
@@ -28,12 +30,13 @@ import fire
import httpx
import yaml
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
@@ -66,6 +69,7 @@ def create_sse_event(data: Any) -> str:
async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(
@@ -155,9 +159,8 @@ def create_dynamic_passthrough(
return endpoint
def create_dynamic_typed_route(func: Any):
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
request_model = next(iter(hints.values()))
response_model = hints["return"]
# NOTE: I think it is better to just add a method within each Api
@@ -168,7 +171,7 @@ def create_dynamic_typed_route(func: Any):
if is_streaming:
async def endpoint(request: request_model):
async def endpoint(**kwargs):
async def sse_generator(event_gen):
try:
async for item in event_gen:
@@ -178,10 +181,7 @@ def create_dynamic_typed_route(func: Any):
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
print(e)
import traceback
traceback.print_exc()
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
@@ -191,25 +191,36 @@ def create_dynamic_typed_route(func: Any):
)
return StreamingResponse(
sse_generator(func(request)), media_type="text/event-stream"
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
async def endpoint(request: request_model):
async def endpoint(**kwargs):
try:
return (
await func(request)
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(request)
else func(**kwargs)
)
except Exception as e:
print(e)
import traceback
traceback.print_exc()
traceback.print_exception(e)
raise translate_exception(e) from e
sig = inspect.signature(func)
if method == "post":
# make sure every parameter is annotated with Body() so FastAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
endpoint.__signature__ = sig.replace(
parameters=[
param.replace(annotation=Annotated[param.annotation, Body()])
for param in sig.parameters.values()
]
)
else:
endpoint.__signature__ = sig
return endpoint
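
Why the `Body()` rewrite matters: FastAPI treats bare scalar parameters as query parameters by default, so a POST route with `bank_id: str` would otherwise expect `?bank_id=...`. A hand-written equivalent of what the signature rewrite produces (route name and handler are illustrative, not from this commit):

```python
# Illustrative only: hand-written equivalent of the Body() signature rewrite.
# With multiple Body() parameters, FastAPI expects them as keys of one JSON body.
from typing import Annotated, List
from fastapi import Body, FastAPI

app = FastAPI()

@app.post("/demo/documents/get")
async def get_documents_demo(
    bank_id: Annotated[str, Body()],
    document_ids: Annotated[List[str], Body()],
) -> List[str]:
    # Client sends: {"bank_id": "...", "document_ids": ["..."]}
    return document_ids
```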
@@ -296,7 +307,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method)
create_dynamic_typed_route(impl_method, endpoint.method)
)
for route in app.routes:
@@ -307,6 +318,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
attrs=["bold"],
)
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
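
Side note on the exception-handler registrations above: `app.exception_handler(...)` returns a decorator, so the immediate double call is equivalent to the usual decorated form (sketch):

```python
# Equivalent decorator form of app.exception_handler(Exception)(handler).
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

@app.exception_handler(Exception)
async def handle_uncaught(request: Request, exc: Exception) -> JSONResponse:
    return JSONResponse(status_code=500, content={"error": {"message": str(exc)}})
```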

View file

@@ -20,7 +20,7 @@ class MemoryBankDocument(BaseModel):
document_id: str
content: InterleavedTextMedia | URL
mime_type: str
metadata: Dict[str, Any]
metadata: Dict[str, Any] = Field(default_factory=dict)
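
With `default_factory`, `metadata` becomes optional and every instance gets its own fresh dict; a minimal check:

```python
# Minimal check of the Field(default_factory=dict) change.
from typing import Any, Dict
from pydantic import BaseModel, Field

class Doc(BaseModel):
    document_id: str
    metadata: Dict[str, Any] = Field(default_factory=dict)

a, b = Doc(document_id="1"), Doc(document_id="2")
a.metadata["k"] = "v"
print(b.metadata)  # {} -- the field is optional and not shared across instances
```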
@json_schema_type
@@ -103,7 +103,7 @@ class Memory(Protocol):
@webmethod(route="/memory_banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get")
@webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory_banks/drop", method="DELETE")
@@ -136,14 +136,14 @@ class Memory(Protocol):
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@webmethod(route="/memory_bank/documents/get")
@webmethod(route="/memory_bank/documents/get", method="GET")
async def get_documents(
self,
bank_id: str,
document_ids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/documents/delete")
@webmethod(route="/memory_bank/documents/delete", method="DELETE")
async def delete_documents(
self,
bank_id: str,

View file

@@ -34,19 +34,19 @@ class MemoryClient(Memory):
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
async with client.get(
r = await client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
d = r.json()
if len(d) == 0:
return None
return MemoryBank(**d)
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
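
The rewrite above follows httpx semantics: `AsyncClient.get` is a plain awaitable returning a `Response`, not an async context manager (only `client.stream(...)` is used with `async with`). A minimal usage sketch:

```python
# Minimal httpx sketch: request methods are awaited directly.
import httpx

async def fetch_bank(base_url: str, bank_id: str) -> dict:
    async with httpx.AsyncClient() as client:
        r = await client.get(f"{base_url}/memory_banks/get", params={"bank_id": bank_id})
        r.raise_for_status()
        return r.json()
```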
async def create_memory_bank(
self,
@@ -55,21 +55,21 @@ class MemoryClient(Memory):
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
async with client.post(
r = await client.post(
f"{self.base_url}/memory_banks/create",
data={
json={
"name": name,
"config": config.dict(),
"url": url,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
d = r.json()
if len(d) == 0:
return None
return MemoryBank(**d)
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
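
The switch from `data=` to `json=` also matters in httpx: `data=` sends form-encoded content, while `json=` serializes the payload and sets `Content-Type: application/json` itself (making the explicit header redundant but harmless). A sketch, assuming a server on localhost and illustrative payload fields:

```python
# Sketch of the json= behavior; URL and payload fields are assumptions.
import asyncio
import httpx

async def create_bank_demo() -> None:
    async with httpx.AsyncClient() as client:
        r = await client.post(
            "http://localhost:5000/memory_banks/create",
            json={"name": "demo", "config": {"type": "vector"}, "url": None},
            timeout=20,
        )
        r.raise_for_status()
        print(r.json())

# asyncio.run(create_bank_demo())
```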
async def insert_documents(
self,
@@ -77,16 +77,16 @@
documents: List[MemoryBankDocument],
) -> None:
async with httpx.AsyncClient() as client:
async with client.post(
r = await client.post(
f"{self.base_url}/memory_bank/insert",
data={
json={
"bank_id": bank_id,
"documents": documents,
"documents": [d.dict() for d in documents],
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
)
r.raise_for_status()
async def query_documents(
self,
@@ -95,18 +95,18 @@
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
async with httpx.AsyncClient() as client:
async with client.post(
r = await client.post(
f"{self.base_url}/memory_bank/query",
data={
json={
"bank_id": bank_id,
"query": query,
"params": params,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
return QueryDocumentsResponse(**r.json())
)
r.raise_for_status()
return QueryDocumentsResponse(**r.json())
async def run_main(host: str, port: int, stream: bool):
@@ -126,31 +126,53 @@ async def run_main(host: str, port: int, stream: bool):
retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None
assert retrieved_bank.embedding_model == "dragon-roberta-query-2"
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
documents = [
MemoryBankDocument(
document_id=f"num-{i}",
content=URL(
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
),
mime_type="text/plain",
)
for i, url in enumerate(urls)
]
# insert some documents
await client.insert_documents(
bank_id=bank.bank_id,
documents=[
MemoryBankDocument(
document_id="1",
content="hello world",
),
MemoryBankDocument(
document_id="2",
content="goodbye world",
),
],
documents=documents,
)
# query the documents
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"hello world",
"How do I use Lora?",
],
)
print(response)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"Tell me more about llama3 and torchtune",
],
)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
def main(host: str, port: int, stream: bool = True):

View file

@@ -5,4 +5,4 @@
# the root directory of this source tree.
from .config import FaissImplConfig # noqa
from .memory import get_provider_impl # noqa
from .faiss import get_provider_impl # noqa

View file

@@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import faiss
import httpx
import numpy as np
from sentence_transformers import SentenceTransformer
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
@@ -33,7 +33,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL):
async with httpx.AsyncClient() as client:
return await client.get(doc.content).text
r = await client.get(doc.content.uri)
return r.text
def _process(c):
if isinstance(c, str):
@@ -62,16 +63,17 @@ def make_overlapped_chunks(
return chunks
class BankState(BaseModel):
@dataclass
class BankState:
bank: MemoryBank
index: Optional[faiss.IndexFlatL2] = None
doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
id_by_index: Dict[int, str] = Field(default_factory=dict)
chunk_by_index: Dict[int, str] = Field(default_factory=dict)
doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
id_by_index: Dict[int, str] = field(default_factory=dict)
chunk_by_index: Dict[int, str] = field(default_factory=dict)
async def insert_documents(
self,
model: SentenceTransformer,
model: "SentenceTransformer",
documents: List[MemoryBankDocument],
) -> None:
tokenizer = Tokenizer.get_instance()
@@ -97,21 +99,44 @@ class BankState(BaseModel):
content=chunk[0],
token_count=chunk[1],
)
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
self.id_by_index[indexlen + i] = doc.document_id
async def query_documents(
self, model: SentenceTransformer, query: str, params: Dict[str, Any]
) -> Tuple[List[Chunk], List[float]]:
self,
model: "SentenceTransformer",
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
if params is None:
params = {}
k = params.get("max_chunks", 3)
query_vector = model.encode([query])[0]
def _process(c) -> str:
if isinstance(c, str):
return c
else:
return "<media>"
if isinstance(query, list):
query_str = " ".join([_process(c) for c in query])
else:
query_str = _process(query)
query_vector = model.encode([query_str])[0]
distances, indices = self.index.search(
query_vector.reshape(1, -1).astype(np.float32), k
)
chunks = [self.chunk_by_index[int(i)] for i in indices[0]]
scores = [1.0 / float(d) for d in distances[0]]
chunks = []
scores = []
for d, i in zip(distances[0], indices[0]):
if i < 0:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(1.0 / float(d))
return chunks, scores
return QueryDocumentsResponse(chunks=chunks, scores=scores)
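
The `i < 0` guard matches faiss semantics: `IndexFlatL2.search` returns squared L2 distances and pads missing neighbors with index -1 when the index holds fewer than `k` vectors. A standalone sketch:

```python
# faiss padding behavior that the guard above handles.
import faiss
import numpy as np

dim = 8
index = faiss.IndexFlatL2(dim)
index.add(np.random.rand(2, dim).astype(np.float32))  # only two vectors stored

distances, indices = index.search(np.random.rand(1, dim).astype(np.float32), 3)
print(indices)    # e.g. [[0 1 -1]] -- third slot padded with -1
print(distances)  # squared L2 distances; smaller means closer
```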
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
if self.index is None:
@@ -122,7 +147,7 @@ class BankState(BaseModel):
class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.model = SentenceTransformer("all-MiniLM-L6-v2")
self.model = None
self.states = {}
async def initialize(self) -> None: ...
@@ -135,20 +160,21 @@ class FaissMemoryImpl(Memory):
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
print("Creating memory bank")
assert url is None, "URL is not supported for this implementation"
assert (
config.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}"
id = str(uuid.uuid4())
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=id,
bank_id=bank_id,
name=name,
config=config,
url=url,
)
state = BankState(bank=bank)
self.states[id] = state
self.states[bank_id] = state
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
@@ -164,7 +190,7 @@ class FaissMemoryImpl(Memory):
assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id]
await state.insert_documents(self.model, documents)
await state.insert_documents(self.get_model(), documents)
async def query_documents(
self,
@@ -175,5 +201,13 @@ class FaissMemoryImpl(Memory):
assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id]
chunks, scores = await state.query_documents(self.model, query, params)
return QueryDocumentsResponse(chunk=chunks, scores=scores)
return await state.query_documents(self.get_model(), query, params)
def get_model(self) -> "SentenceTransformer":
from sentence_transformers import SentenceTransformer
if self.model is None:
print("Loading sentence transformer")
self.model = SentenceTransformer("all-MiniLM-L6-v2")
return self.model
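
The lazy import keeps server startup cheap until the first memory operation. For reference, `encode` returns one embedding row per input string, 384-dimensional for this model; a quick check:

```python
# Quick check of the model that get_model() loads lazily.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
vecs = model.encode(["hello world", "goodbye world"])
print(vecs.shape)  # (2, 384)
```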

View file

@@ -9,13 +9,14 @@ from typing import List
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_inference_providers() -> List[ProviderSpec]:
def available_memory_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.memory,
provider_id="meta-reference-faiss",
pip_packages=[
"faiss",
"blobfile",
"faiss-cpu",
"sentence-transformers",
],
module="llama_toolchain.memory.meta_reference.faiss",