memory client works

This commit is contained in:
Ashwin Bharambe 2024-08-24 18:43:49 -07:00
parent a08958c000
commit 8d14d4228b
8 changed files with 164 additions and 86 deletions

View file

@ -11,6 +11,8 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.api.endpoints import Inference from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.inference.providers import available_inference_providers from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.memory.api.endpoints import Memory
from llama_toolchain.memory.providers import available_memory_providers
from llama_toolchain.safety.api.endpoints import Safety from llama_toolchain.safety.api.endpoints import Safety
from llama_toolchain.safety.providers import available_safety_providers from llama_toolchain.safety.providers import available_safety_providers
@ -47,6 +49,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.inference: Inference, Api.inference: Inference,
Api.safety: Safety, Api.safety: Safety,
Api.agentic_system: AgenticSystem, Api.agentic_system: AgenticSystem,
Api.memory: Memory,
} }
for api, protocol in protocols.items(): for api, protocol in protocols.items():
@ -60,9 +63,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
webmethod = method.__webmethod__ webmethod = method.__webmethod__
route = webmethod.route route = webmethod.route
# use `post` for all methods right now until we fix up the `webmethod` openapi if webmethod.method == "GET":
# annotation and write our own openapi generator method = "get"
endpoints.append(ApiEndpoint(route=route, method="post", name=name)) elif webmethod.method == "DELETE":
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints apis[api] = endpoints
@ -82,4 +89,5 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
Api.inference: inference_providers_by_id, Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id, Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id, Api.agentic_system: agentic_system_providers_by_id,
Api.memory: {a.provider_id: a for a in available_memory_providers()},
} }

View file

@ -53,6 +53,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
), ),
DistributionSpec( DistributionSpec(
spec_id="test-memory", spec_id="test-memory",
description="Just a test distribution spec for testing memory bank APIs",
provider_specs={ provider_specs={
Api.memory: providers[Api.memory]["meta-reference-faiss"], Api.memory: providers[Api.memory]["meta-reference-faiss"],
}, },

View file

@ -5,8 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import inspect
import json import json
import signal import signal
import traceback
from collections.abc import ( from collections.abc import (
AsyncGenerator as AsyncGeneratorABC, AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC, AsyncIterator as AsyncIteratorABC,
@ -28,12 +30,13 @@ import fire
import httpx import httpx
import yaml import yaml
from fastapi import FastAPI, HTTPException, Request, Response from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
from typing_extensions import Annotated
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints from .distribution import api_endpoints
@ -66,6 +69,7 @@ def create_sse_event(data: Any) -> str:
async def global_exception_handler(request: Request, exc: Exception): async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc) http_exc = translate_exception(exc)
return JSONResponse( return JSONResponse(
@ -155,9 +159,8 @@ def create_dynamic_passthrough(
return endpoint return endpoint
def create_dynamic_typed_route(func: Any): def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func) hints = get_type_hints(func)
request_model = next(iter(hints.values()))
response_model = hints["return"] response_model = hints["return"]
# NOTE: I think it is better to just add a method within each Api # NOTE: I think it is better to just add a method within each Api
@ -168,7 +171,7 @@ def create_dynamic_typed_route(func: Any):
if is_streaming: if is_streaming:
async def endpoint(request: request_model): async def endpoint(**kwargs):
async def sse_generator(event_gen): async def sse_generator(event_gen):
try: try:
async for item in event_gen: async for item in event_gen:
@ -178,10 +181,7 @@ def create_dynamic_typed_route(func: Any):
print("Generator cancelled") print("Generator cancelled")
await event_gen.aclose() await event_gen.aclose()
except Exception as e: except Exception as e:
print(e) traceback.print_exception(e)
import traceback
traceback.print_exc()
yield create_sse_event( yield create_sse_event(
{ {
"error": { "error": {
@ -191,25 +191,36 @@ def create_dynamic_typed_route(func: Any):
) )
return StreamingResponse( return StreamingResponse(
sse_generator(func(request)), media_type="text/event-stream" sse_generator(func(**kwargs)), media_type="text/event-stream"
) )
else: else:
async def endpoint(request: request_model): async def endpoint(**kwargs):
try: try:
return ( return (
await func(request) await func(**kwargs)
if asyncio.iscoroutinefunction(func) if asyncio.iscoroutinefunction(func)
else func(request) else func(**kwargs)
) )
except Exception as e: except Exception as e:
print(e) traceback.print_exception(e)
import traceback
traceback.print_exc()
raise translate_exception(e) from e raise translate_exception(e) from e
sig = inspect.signature(func)
if method == "post":
# make sure every parameter is annotated with Body() so FastAPI doesn't
# try to be too clever and ask for some parameters in the query string
# and others in the request body
endpoint.__signature__ = sig.replace(
parameters=[
param.replace(annotation=Annotated[param.annotation, Body()])
for param in sig.parameters.values()
]
)
else:
endpoint.__signature__ = sig
return endpoint return endpoint
@ -296,7 +307,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
impl_method = getattr(impl, endpoint.name) impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)( getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method) create_dynamic_typed_route(impl_method, endpoint.method)
) )
for route in app.routes: for route in app.routes:
@ -307,6 +318,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
attrs=["bold"], attrs=["bold"],
) )
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint) signal.signal(signal.SIGINT, handle_sigint)

View file

@ -20,7 +20,7 @@ class MemoryBankDocument(BaseModel):
document_id: str document_id: str
content: InterleavedTextMedia | URL content: InterleavedTextMedia | URL
mime_type: str mime_type: str
metadata: Dict[str, Any] metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type @json_schema_type
@ -103,7 +103,7 @@ class Memory(Protocol):
@webmethod(route="/memory_banks/list", method="GET") @webmethod(route="/memory_banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ... async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get") @webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ... async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory_banks/drop", method="DELETE") @webmethod(route="/memory_banks/drop", method="DELETE")
@ -136,14 +136,14 @@ class Memory(Protocol):
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ... ) -> QueryDocumentsResponse: ...
@webmethod(route="/memory_bank/documents/get") @webmethod(route="/memory_bank/documents/get", method="GET")
async def get_documents( async def get_documents(
self, self,
bank_id: str, bank_id: str,
document_ids: List[str], document_ids: List[str],
) -> List[MemoryBankDocument]: ... ) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/documents/delete") @webmethod(route="/memory_bank/documents/delete", method="DELETE")
async def delete_documents( async def delete_documents(
self, self,
bank_id: str, bank_id: str,

View file

@ -34,19 +34,19 @@ class MemoryClient(Memory):
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
async with client.get( r = await client.get(
f"{self.base_url}/memory_banks/get", f"{self.base_url}/memory_banks/get",
params={ params={
"bank_id": bank_id, "bank_id": bank_id,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) as r: )
r.raise_for_status() r.raise_for_status()
d = r.json() d = r.json()
if len(d) == 0: if not d:
return None return None
return MemoryBank(**d) return MemoryBank(**d)
async def create_memory_bank( async def create_memory_bank(
self, self,
@ -55,21 +55,21 @@ class MemoryClient(Memory):
url: Optional[URL] = None, url: Optional[URL] = None,
) -> MemoryBank: ) -> MemoryBank:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
async with client.post( r = await client.post(
f"{self.base_url}/memory_banks/create", f"{self.base_url}/memory_banks/create",
data={ json={
"name": name, "name": name,
"config": config.dict(), "config": config.dict(),
"url": url, "url": url,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) as r: )
r.raise_for_status() r.raise_for_status()
d = r.json() d = r.json()
if len(d) == 0: if not d:
return None return None
return MemoryBank(**d) return MemoryBank(**d)
async def insert_documents( async def insert_documents(
self, self,
@ -77,16 +77,16 @@ class MemoryClient(Memory):
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
) -> None: ) -> None:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
async with client.post( r = await client.post(
f"{self.base_url}/memory_bank/insert", f"{self.base_url}/memory_bank/insert",
data={ json={
"bank_id": bank_id, "bank_id": bank_id,
"documents": documents, "documents": [d.dict() for d in documents],
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) as r: )
r.raise_for_status() r.raise_for_status()
async def query_documents( async def query_documents(
self, self,
@ -95,18 +95,18 @@ class MemoryClient(Memory):
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ) -> QueryDocumentsResponse:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
async with client.post( r = await client.post(
f"{self.base_url}/memory_bank/query", f"{self.base_url}/memory_bank/query",
data={ json={
"bank_id": bank_id, "bank_id": bank_id,
"query": query, "query": query,
"params": params, "params": params,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) as r: )
r.raise_for_status() r.raise_for_status()
return QueryDocumentsResponse(**r.json()) return QueryDocumentsResponse(**r.json())
async def run_main(host: str, port: int, stream: bool): async def run_main(host: str, port: int, stream: bool):
@ -126,31 +126,53 @@ async def run_main(host: str, port: int, stream: bool):
retrieved_bank = await client.get_memory_bank(bank.bank_id) retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None assert retrieved_bank is not None
assert retrieved_bank.embedding_model == "dragon-roberta-query-2" assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
documents = [
MemoryBankDocument(
document_id=f"num-{i}",
content=URL(
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
),
mime_type="text/plain",
)
for i, url in enumerate(urls)
]
# insert some documents # insert some documents
await client.insert_documents( await client.insert_documents(
bank_id=bank.bank_id, bank_id=bank.bank_id,
documents=[ documents=documents,
MemoryBankDocument(
document_id="1",
content="hello world",
),
MemoryBankDocument(
document_id="2",
content="goodbye world",
),
],
) )
# query the documents # query the documents
response = await client.query_documents( response = await client.query_documents(
bank_id=bank.bank_id, bank_id=bank.bank_id,
query=[ query=[
"hello world", "How do I use Lora?",
], ],
) )
print(response) for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"Tell me more about llama3 and torchtune",
],
)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
def main(host: str, port: int, stream: bool = True): def main(host: str, port: int, stream: bool = True):

View file

@ -5,4 +5,4 @@
# the root directory of this source tree. # the root directory of this source tree.
from .config import FaissImplConfig # noqa from .config import FaissImplConfig # noqa
from .memory import get_provider_impl # noqa from .faiss import get_provider_impl # noqa

View file

@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import faiss import faiss
import httpx import httpx
import numpy as np import numpy as np
from sentence_transformers import SentenceTransformer
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
@ -33,7 +33,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
async def content_from_doc(doc: MemoryBankDocument) -> str: async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL): if isinstance(doc.content, URL):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
return await client.get(doc.content).text r = await client.get(doc.content.uri)
return r.text
def _process(c): def _process(c):
if isinstance(c, str): if isinstance(c, str):
@ -62,16 +63,17 @@ def make_overlapped_chunks(
return chunks return chunks
class BankState(BaseModel): @dataclass
class BankState:
bank: MemoryBank bank: MemoryBank
index: Optional[faiss.IndexFlatL2] = None index: Optional[faiss.IndexFlatL2] = None
doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict) doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
id_by_index: Dict[int, str] = Field(default_factory=dict) id_by_index: Dict[int, str] = field(default_factory=dict)
chunk_by_index: Dict[int, str] = Field(default_factory=dict) chunk_by_index: Dict[int, str] = field(default_factory=dict)
async def insert_documents( async def insert_documents(
self, self,
model: SentenceTransformer, model: "SentenceTransformer",
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
) -> None: ) -> None:
tokenizer = Tokenizer.get_instance() tokenizer = Tokenizer.get_instance()
@ -97,21 +99,44 @@ class BankState(BaseModel):
content=chunk[0], content=chunk[0],
token_count=chunk[1], token_count=chunk[1],
) )
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
self.id_by_index[indexlen + i] = doc.document_id self.id_by_index[indexlen + i] = doc.document_id
async def query_documents( async def query_documents(
self, model: SentenceTransformer, query: str, params: Dict[str, Any] self,
) -> Tuple[List[Chunk], List[float]]: model: "SentenceTransformer",
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
if params is None:
params = {}
k = params.get("max_chunks", 3) k = params.get("max_chunks", 3)
query_vector = model.encode([query])[0]
def _process(c) -> str:
if isinstance(c, str):
return c
else:
return "<media>"
if isinstance(query, list):
query_str = " ".join([_process(c) for c in query])
else:
query_str = _process(query)
query_vector = model.encode([query_str])[0]
distances, indices = self.index.search( distances, indices = self.index.search(
query_vector.reshape(1, -1).astype(np.float32), k query_vector.reshape(1, -1).astype(np.float32), k
) )
chunks = [self.chunk_by_index[int(i)] for i in indices[0]] chunks = []
scores = [1.0 / float(d) for d in distances[0]] scores = []
for d, i in zip(distances[0], indices[0]):
if i < 0:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(1.0 / float(d))
return chunks, scores return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2: async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
if self.index is None: if self.index is None:
@ -122,7 +147,7 @@ class BankState(BaseModel):
class FaissMemoryImpl(Memory): class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None: def __init__(self, config: FaissImplConfig) -> None:
self.config = config self.config = config
self.model = SentenceTransformer("all-MiniLM-L6-v2") self.model = None
self.states = {} self.states = {}
async def initialize(self) -> None: ... async def initialize(self) -> None: ...
@ -135,20 +160,21 @@ class FaissMemoryImpl(Memory):
config: MemoryBankConfig, config: MemoryBankConfig,
url: Optional[URL] = None, url: Optional[URL] = None,
) -> MemoryBank: ) -> MemoryBank:
print("Creating memory bank")
assert url is None, "URL is not supported for this implementation" assert url is None, "URL is not supported for this implementation"
assert ( assert (
config.type == MemoryBankType.vector.value config.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}" ), f"Only vector banks are supported {config.type}"
id = str(uuid.uuid4()) bank_id = str(uuid.uuid4())
bank = MemoryBank( bank = MemoryBank(
bank_id=id, bank_id=bank_id,
name=name, name=name,
config=config, config=config,
url=url, url=url,
) )
state = BankState(bank=bank) state = BankState(bank=bank)
self.states[id] = state self.states[bank_id] = state
return bank return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
@ -164,7 +190,7 @@ class FaissMemoryImpl(Memory):
assert bank_id in self.states, f"Bank {bank_id} not found" assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id] state = self.states[bank_id]
await state.insert_documents(self.model, documents) await state.insert_documents(self.get_model(), documents)
async def query_documents( async def query_documents(
self, self,
@ -175,5 +201,13 @@ class FaissMemoryImpl(Memory):
assert bank_id in self.states, f"Bank {bank_id} not found" assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id] state = self.states[bank_id]
chunks, scores = await state.query_documents(self.model, query, params) return await state.query_documents(self.get_model(), query, params)
return QueryDocumentsResponse(chunk=chunks, scores=scores)
def get_model(self) -> "SentenceTransformer":
    """Return the sentence-embedding model, loading it on first use.

    The loaded model is stored on ``self.model`` so later calls return the
    same instance without constructing it again.

    Returns:
        The ``SentenceTransformer`` instance held in ``self.model``.
    """
    if self.model is None:
        # Import only on the load path: the package is needed solely to
        # construct the model, so the cached path should not touch it.
        from sentence_transformers import SentenceTransformer

        print("Loading sentence transformer")
        self.model = SentenceTransformer("all-MiniLM-L6-v2")
    return self.model

View file

@ -9,13 +9,14 @@ from typing import List
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_inference_providers() -> List[ProviderSpec]: def available_memory_providers() -> List[ProviderSpec]:
return [ return [
InlineProviderSpec( InlineProviderSpec(
api=Api.memory, api=Api.memory,
provider_id="meta-reference-faiss", provider_id="meta-reference-faiss",
pip_packages=[ pip_packages=[
"faiss", "blobfile",
"faiss-cpu",
"sentence-transformers", "sentence-transformers",
], ],
module="llama_toolchain.memory.meta_reference.faiss", module="llama_toolchain.memory.meta_reference.faiss",