diff --git a/llama_toolchain/distribution/distribution.py b/llama_toolchain/distribution/distribution.py
index f96d0cac6..7294392a2 100644
--- a/llama_toolchain/distribution/distribution.py
+++ b/llama_toolchain/distribution/distribution.py
@@ -11,6 +11,8 @@
 from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
 from llama_toolchain.agentic_system.providers import available_agentic_system_providers
 from llama_toolchain.inference.api.endpoints import Inference
 from llama_toolchain.inference.providers import available_inference_providers
+from llama_toolchain.memory.api.endpoints import Memory
+from llama_toolchain.memory.providers import available_memory_providers
 from llama_toolchain.safety.api.endpoints import Safety
 from llama_toolchain.safety.providers import available_safety_providers
@@ -47,6 +49,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.inference: Inference,
         Api.safety: Safety,
         Api.agentic_system: AgenticSystem,
+        Api.memory: Memory,
     }
 
     for api, protocol in protocols.items():
@@ -60,9 +63,13 @@
             webmethod = method.__webmethod__
             route = webmethod.route
 
-            # use `post` for all methods right now until we fix up the `webmethod` openapi
-            # annotation and write our own openapi generator
-            endpoints.append(ApiEndpoint(route=route, method="post", name=name))
+            if webmethod.method == "GET":
+                method = "get"
+            elif webmethod.method == "DELETE":
+                method = "delete"
+            else:
+                method = "post"
+            endpoints.append(ApiEndpoint(route=route, method=method, name=name))
 
         apis[api] = endpoints
 
@@ -82,4 +89,5 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
         Api.inference: inference_providers_by_id,
         Api.safety: safety_providers_by_id,
         Api.agentic_system: agentic_system_providers_by_id,
+        Api.memory: {a.provider_id: a for a in available_memory_providers()},
     }
diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py
index acba4e874..33d6e8e2a 100644
--- a/llama_toolchain/distribution/registry.py
+++ b/llama_toolchain/distribution/registry.py
@@ -53,6 +53,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
         ),
         DistributionSpec(
             spec_id="test-memory",
+            description="Just a test distribution spec for testing memory bank APIs",
             provider_specs={
                 Api.memory: providers[Api.memory]["meta-reference-faiss"],
             },
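
Note (not part of the patch): the `api_endpoints()` change above is the crux of the routing work — webmethods annotated with `method="GET"` or `method="DELETE"` now register under the matching HTTP verb instead of unconditionally using POST. A minimal, runnable sketch of that mapping, assuming only that a webmethod exposes `route` and an optional `method` attribute (the `WebMethod` class below is a hypothetical stand-in for the real decorator metadata):

from dataclasses import dataclass
from typing import Optional


@dataclass
class WebMethod:  # hypothetical stand-in for @webmethod metadata
    route: str
    method: Optional[str] = None


def http_verb(webmethod: WebMethod) -> str:
    # GET and DELETE pass through; everything else (including older,
    # unannotated webmethods) keeps the previous POST behavior.
    if webmethod.method == "GET":
        return "get"
    elif webmethod.method == "DELETE":
        return "delete"
    return "post"


assert http_verb(WebMethod("/memory_banks/list", "GET")) == "get"
assert http_verb(WebMethod("/memory_banks/drop", "DELETE")) == "delete"
assert http_verb(WebMethod("/memory_bank/insert")) == "post"
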
diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py
index 8707fa9ed..90c5a9a0f 100644
--- a/llama_toolchain/distribution/server.py
+++ b/llama_toolchain/distribution/server.py
@@ -5,8 +5,10 @@
 # the root directory of this source tree.
 
 import asyncio
+import inspect
 import json
 import signal
+import traceback
 from collections.abc import (
     AsyncGenerator as AsyncGeneratorABC,
     AsyncIterator as AsyncIteratorABC,
 )
@@ -28,12 +30,13 @@
 import fire
 import httpx
 import yaml
-from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import APIRoute
 from pydantic import BaseModel, ValidationError
 from termcolor import cprint
+from typing_extensions import Annotated
 
 from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
 from .distribution import api_endpoints
@@ -66,6 +69,7 @@ def create_sse_event(data: Any) -> str:
 
 
 async def global_exception_handler(request: Request, exc: Exception):
+    traceback.print_exception(exc)
     http_exc = translate_exception(exc)
 
     return JSONResponse(
@@ -155,9 +159,8 @@ def create_dynamic_passthrough(
     return endpoint
 
 
-def create_dynamic_typed_route(func: Any):
+def create_dynamic_typed_route(func: Any, method: str):
     hints = get_type_hints(func)
-    request_model = next(iter(hints.values()))
     response_model = hints["return"]
 
     # NOTE: I think it is better to just add a method within each Api
@@ -168,7 +171,7 @@
 
     if is_streaming:
 
-        async def endpoint(request: request_model):
+        async def endpoint(**kwargs):
            async def sse_generator(event_gen):
                 try:
                     async for item in event_gen:
@@ -178,10 +181,7 @@
                     print("Generator cancelled")
                     await event_gen.aclose()
                 except Exception as e:
-                    print(e)
-                    import traceback
-
-                    traceback.print_exc()
+                    traceback.print_exception(e)
                     yield create_sse_event(
                         {
                             "error": {
@@ -191,25 +191,36 @@
                 )
 
         return StreamingResponse(
-            sse_generator(func(request)), media_type="text/event-stream"
+            sse_generator(func(**kwargs)), media_type="text/event-stream"
         )
 
     else:
 
-        async def endpoint(request: request_model):
+        async def endpoint(**kwargs):
             try:
                 return (
-                    await func(request)
+                    await func(**kwargs)
                     if asyncio.iscoroutinefunction(func)
-                    else func(request)
+                    else func(**kwargs)
                 )
             except Exception as e:
-                print(e)
-                import traceback
-
-                traceback.print_exc()
+                traceback.print_exception(e)
                 raise translate_exception(e) from e
 
+    sig = inspect.signature(func)
+    if method == "post":
+        # make sure every parameter is annotated with Body() so FASTAPI doesn't
+        # do anything too intelligent and ask for some parameters in the query
+        # and some in the body
+        endpoint.__signature__ = sig.replace(
+            parameters=[
+                param.replace(annotation=Annotated[param.annotation, Body()])
+                for param in sig.parameters.values()
+            ]
+        )
+    else:
+        endpoint.__signature__ = sig
+
     return endpoint
@@ -296,7 +307,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
             impl_method = getattr(impl, endpoint.name)
 
             getattr(app, endpoint.method)(endpoint.route, response_model=None)(
-                create_dynamic_typed_route(impl_method)
+                create_dynamic_typed_route(impl_method, endpoint.method)
             )
 
     for route in app.routes:
@@ -307,6 +318,7 @@
             attrs=["bold"],
         )
 
+    app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
 
     signal.signal(signal.SIGINT, handle_sigint)
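
Note (not part of the patch): the `endpoint.__signature__` trick above is what lets a generic `async def endpoint(**kwargs)` wrapper serve typed routes. FastAPI cannot introspect `**kwargs`, so the server copies the implementation's signature onto the wrapper, and for POST routes wraps each annotation in `Body()` so all parameters are read from the JSON body rather than split between query string and body. A standalone sketch of the same pattern; the route and parameter names here are illustrative only:

import inspect

from fastapi import Body, FastAPI
from typing_extensions import Annotated


async def create_bank(name: str, overlap: int = 64) -> dict:
    return {"name": name, "overlap": overlap}


async def endpoint(**kwargs):
    return await create_bank(**kwargs)


sig = inspect.signature(create_bank)
endpoint.__signature__ = sig.replace(
    parameters=[
        # Without Body(), FastAPI would look for `name` and `overlap` in the
        # query string; with it, both are expected in the JSON body.
        p.replace(annotation=Annotated[p.annotation, Body()])
        for p in sig.parameters.values()
    ]
)

app = FastAPI()
app.post("/memory_banks/create", response_model=None)(endpoint)
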
diff --git a/llama_toolchain/memory/api/endpoints.py b/llama_toolchain/memory/api/endpoints.py
index d4f1d5e20..9299872e3 100644
--- a/llama_toolchain/memory/api/endpoints.py
+++ b/llama_toolchain/memory/api/endpoints.py
@@ -20,7 +20,7 @@ class MemoryBankDocument(BaseModel):
     document_id: str
     content: InterleavedTextMedia | URL
     mime_type: str
-    metadata: Dict[str, Any]
+    metadata: Dict[str, Any] = Field(default_factory=dict)
 
 
 @json_schema_type
@@ -103,7 +103,7 @@ class Memory(Protocol):
     @webmethod(route="/memory_banks/list", method="GET")
     async def list_memory_banks(self) -> List[MemoryBank]: ...
 
-    @webmethod(route="/memory_banks/get")
+    @webmethod(route="/memory_banks/get", method="GET")
     async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
 
     @webmethod(route="/memory_banks/drop", method="DELETE")
@@ -136,14 +136,14 @@
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse: ...
 
-    @webmethod(route="/memory_bank/documents/get")
+    @webmethod(route="/memory_bank/documents/get", method="GET")
     async def get_documents(
         self,
         bank_id: str,
         document_ids: List[str],
     ) -> List[MemoryBankDocument]: ...
 
-    @webmethod(route="/memory_bank/documents/delete")
+    @webmethod(route="/memory_bank/documents/delete", method="DELETE")
     async def delete_documents(
         self,
         bank_id: str,
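
Note (not part of the patch): these annotations rely on `@webmethod` attaching its arguments to the decorated function so that `api_endpoints()` can read them back later. A simplified stand-in for that decorator, to show the mechanism; the real one lives in llama_toolchain and carries more fields:

from typing import Any, Callable, Optional


def webmethod(route: str, method: Optional[str] = None) -> Callable:
    def wrap(fn: Callable) -> Callable:
        # Attach routing metadata without changing the function's behavior.
        fn.__webmethod__ = type("WebMethod", (), {"route": route, "method": method})()
        return fn

    return wrap


@webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(bank_id: str) -> Optional[Any]: ...


assert get_memory_bank.__webmethod__.method == "GET"
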
diff --git a/llama_toolchain/memory/client.py b/llama_toolchain/memory/client.py
index 128d7cdd7..d4009a190 100644
--- a/llama_toolchain/memory/client.py
+++ b/llama_toolchain/memory/client.py
@@ -34,19 +34,19 @@
 
     async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
         async with httpx.AsyncClient() as client:
-            async with client.get(
+            r = await client.get(
                 f"{self.base_url}/memory_banks/get",
                 params={
                     "bank_id": bank_id,
                 },
                 headers={"Content-Type": "application/json"},
                 timeout=20,
-            ) as r:
-                r.raise_for_status()
-                d = r.json()
-                if len(d) == 0:
-                    return None
-                return MemoryBank(**d)
+            )
+            r.raise_for_status()
+            d = r.json()
+            if not d:
+                return None
+            return MemoryBank(**d)
 
     async def create_memory_bank(
         self,
@@ -55,21 +55,21 @@
         url: Optional[URL] = None,
     ) -> MemoryBank:
         async with httpx.AsyncClient() as client:
-            async with client.post(
+            r = await client.post(
                 f"{self.base_url}/memory_banks/create",
-                data={
+                json={
                     "name": name,
                     "config": config.dict(),
                     "url": url,
                 },
                 headers={"Content-Type": "application/json"},
                 timeout=20,
-            ) as r:
-                r.raise_for_status()
-                d = r.json()
-                if len(d) == 0:
-                    return None
-                return MemoryBank(**d)
+            )
+            r.raise_for_status()
+            d = r.json()
+            if not d:
+                return None
+            return MemoryBank(**d)
 
     async def insert_documents(
         self,
@@ -77,16 +77,16 @@
         documents: List[MemoryBankDocument],
     ) -> None:
         async with httpx.AsyncClient() as client:
-            async with client.post(
+            r = await client.post(
                 f"{self.base_url}/memory_bank/insert",
-                data={
+                json={
                     "bank_id": bank_id,
-                    "documents": documents,
+                    "documents": [d.dict() for d in documents],
                 },
                 headers={"Content-Type": "application/json"},
                 timeout=20,
-            ) as r:
-                r.raise_for_status()
+            )
+            r.raise_for_status()
 
     async def query_documents(
         self,
@@ -95,18 +95,18 @@
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         async with httpx.AsyncClient() as client:
-            async with client.post(
+            r = await client.post(
                 f"{self.base_url}/memory_bank/query",
-                data={
+                json={
                     "bank_id": bank_id,
                     "query": query,
                     "params": params,
                 },
                 headers={"Content-Type": "application/json"},
                 timeout=20,
-            ) as r:
-                r.raise_for_status()
-                return QueryDocumentsResponse(**r.json())
+            )
+            r.raise_for_status()
+            return QueryDocumentsResponse(**r.json())
 
 
 async def run_main(host: str, port: int, stream: bool):
@@ -126,31 +126,53 @@
     retrieved_bank = await client.get_memory_bank(bank.bank_id)
     assert retrieved_bank is not None
-    assert retrieved_bank.embedding_model == "dragon-roberta-query-2"
+    assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
+
+    urls = [
+        "memory_optimizations.rst",
+        "chat.rst",
+        "llama3.rst",
+        "datasets.rst",
+        "qat_finetune.rst",
+        "lora_finetune.rst",
+    ]
+    documents = [
+        MemoryBankDocument(
+            document_id=f"num-{i}",
+            content=URL(
+                uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
+            ),
+            mime_type="text/plain",
+        )
+        for i, url in enumerate(urls)
+    ]
 
     # insert some documents
     await client.insert_documents(
         bank_id=bank.bank_id,
-        documents=[
-            MemoryBankDocument(
-                document_id="1",
-                content="hello world",
-            ),
-            MemoryBankDocument(
-                document_id="2",
-                content="goodbye world",
-            ),
-        ],
+        documents=documents,
     )
 
     # query the documents
     response = await client.query_documents(
         bank_id=bank.bank_id,
         query=[
-            "hello world",
+            "How do I use Lora?",
         ],
     )
-    print(response)
+    for chunk, score in zip(response.chunks, response.scores):
+        print(f"Score: {score}")
+        print(f"Chunk:\n========\n{chunk}\n========\n")
+
+    response = await client.query_documents(
+        bank_id=bank.bank_id,
+        query=[
+            "Tell me more about llama3 and torchtune",
+        ],
+    )
+    for chunk, score in zip(response.chunks, response.scores):
+        print(f"Score: {score}")
+        print(f"Chunk:\n========\n{chunk}\n========\n")
 
 
 def main(host: str, port: int, stream: bool = True):
diff --git a/llama_toolchain/memory/meta_reference/faiss/__init__.py b/llama_toolchain/memory/meta_reference/faiss/__init__.py
index dda96f370..69a1a06b7 100644
--- a/llama_toolchain/memory/meta_reference/faiss/__init__.py
+++ b/llama_toolchain/memory/meta_reference/faiss/__init__.py
@@ -5,4 +5,4 @@
 # the root directory of this source tree.
 
 from .config import FaissImplConfig  # noqa
-from .memory import get_provider_impl  # noqa
+from .faiss import get_provider_impl  # noqa
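
Note (not part of the patch): the client changes above correct two httpx misuses. `client.get(...)`/`client.post(...)` return an awaitable `Response`, not an async context manager (the context-manager form is `client.stream(...)`), and `data=` sends form-encoded fields whereas `json=` serializes a dict to JSON. A minimal sketch of the corrected call shape, against a server assumed to be running at the illustrative URL below:

import asyncio

import httpx


async def demo() -> None:
    async with httpx.AsyncClient() as client:
        # Await the request directly, then check the response.
        r = await client.post(
            "http://localhost:5000/memory_banks/create",  # illustrative URL
            json={"name": "docs", "config": {"type": "vector"}},
            timeout=20,
        )
        r.raise_for_status()
        print(r.json())


if __name__ == "__main__":
    asyncio.run(demo())
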
diff --git a/llama_toolchain/memory/meta_reference/faiss/memory.py b/llama_toolchain/memory/meta_reference/faiss/faiss.py
similarity index 70%
rename from llama_toolchain/memory/meta_reference/faiss/memory.py
rename to llama_toolchain/memory/meta_reference/faiss/faiss.py
index 2322b8519..0558a6eda 100644
--- a/llama_toolchain/memory/meta_reference/faiss/memory.py
+++ b/llama_toolchain/memory/meta_reference/faiss/faiss.py
@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import uuid
+from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Tuple
 
 import faiss
 import httpx
 import numpy as np
-from sentence_transformers import SentenceTransformer
-
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -33,7 +34,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
 async def content_from_doc(doc: MemoryBankDocument) -> str:
     if isinstance(doc.content, URL):
         async with httpx.AsyncClient() as client:
-            return await client.get(doc.content).text
+            r = await client.get(doc.content.uri)
+            return r.text
 
     def _process(c):
         if isinstance(c, str):
@@ -62,16 +63,17 @@ def make_overlapped_chunks(
     return chunks
 
 
-class BankState(BaseModel):
+@dataclass
+class BankState:
     bank: MemoryBank
     index: Optional[faiss.IndexFlatL2] = None
-    doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
-    id_by_index: Dict[int, str] = Field(default_factory=dict)
-    chunk_by_index: Dict[int, str] = Field(default_factory=dict)
+    doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
+    id_by_index: Dict[int, str] = field(default_factory=dict)
+    chunk_by_index: Dict[int, str] = field(default_factory=dict)
 
     async def insert_documents(
         self,
-        model: SentenceTransformer,
+        model: "SentenceTransformer",
         documents: List[MemoryBankDocument],
     ) -> None:
         tokenizer = Tokenizer.get_instance()
@@ -97,21 +99,44 @@ class BankState(BaseModel):
                 content=chunk[0],
                 token_count=chunk[1],
             )
+            print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
             self.id_by_index[indexlen + i] = doc.document_id
 
     async def query_documents(
-        self, model: SentenceTransformer, query: str, params: Dict[str, Any]
-    ) -> Tuple[List[Chunk], List[float]]:
+        self,
+        model: "SentenceTransformer",
+        query: InterleavedTextMedia,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryDocumentsResponse:
+        if params is None:
+            params = {}
         k = params.get("max_chunks", 3)
-        query_vector = model.encode([query])[0]
+
+        def _process(c) -> str:
+            if isinstance(c, str):
+                return c
+            else:
+                return ""
+
+        if isinstance(query, list):
+            query_str = " ".join([_process(c) for c in query])
+        else:
+            query_str = _process(query)
+
+        query_vector = model.encode([query_str])[0]
         distances, indices = self.index.search(
             query_vector.reshape(1, -1).astype(np.float32), k
         )
-        chunks = [self.chunk_by_index[int(i)] for i in indices[0]]
-        scores = [1.0 / float(d) for d in distances[0]]
+        chunks = []
+        scores = []
+        for d, i in zip(distances[0], indices[0]):
+            if i < 0:
+                continue
+            chunks.append(self.chunk_by_index[int(i)])
+            scores.append(1.0 / float(d))
 
-        return chunks, scores
+        return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
     async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
         if self.index is None:
@@ -122,7 +147,7 @@
 
 class FaissMemoryImpl(Memory):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
-        self.model = SentenceTransformer("all-MiniLM-L6-v2")
+        self.model = None
         self.states = {}
 
     async def initialize(self) -> None: ...
@@ -135,20 +160,21 @@ class FaissMemoryImpl(Memory):
         config: MemoryBankConfig,
         url: Optional[URL] = None,
     ) -> MemoryBank:
+        print("Creating memory bank")
         assert url is None, "URL is not supported for this implementation"
         assert (
             config.type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {config.type}"
 
-        id = str(uuid.uuid4())
+        bank_id = str(uuid.uuid4())
         bank = MemoryBank(
-            bank_id=id,
+            bank_id=bank_id,
             name=name,
             config=config,
             url=url,
         )
         state = BankState(bank=bank)
-        self.states[id] = state
+        self.states[bank_id] = state
         return bank
 
     async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
@@ -164,7 +190,7 @@
         assert bank_id in self.states, f"Bank {bank_id} not found"
 
         state = self.states[bank_id]
-        await state.insert_documents(self.model, documents)
+        await state.insert_documents(self.get_model(), documents)
 
     async def query_documents(
         self,
@@ -175,5 +201,13 @@
         assert bank_id in self.states, f"Bank {bank_id} not found"
 
         state = self.states[bank_id]
-        chunks, scores = await state.query_documents(self.model, query, params)
-        return QueryDocumentsResponse(chunk=chunks, scores=scores)
+        return await state.query_documents(self.get_model(), query, params)
+
+    def get_model(self) -> "SentenceTransformer":
+        from sentence_transformers import SentenceTransformer
+
+        if self.model is None:
+            print("Loading sentence transformer")
+            self.model = SentenceTransformer("all-MiniLM-L6-v2")
+
+        return self.model
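
Note (not part of the patch): the `if i < 0: continue` guard added to `query_documents` above matters because faiss pads its result arrays with index -1 when the index holds fewer than `k` vectors, and indexing `chunk_by_index` with -1 would raise. A standalone sketch of that behavior, assuming faiss-cpu and numpy are installed:

import faiss
import numpy as np

index = faiss.IndexFlatL2(4)
index.add(np.random.rand(2, 4).astype(np.float32))  # only two vectors stored

# Ask for three neighbors; the third result slot is padding, not a hit.
distances, indices = index.search(np.random.rand(1, 4).astype(np.float32), 3)
hits = [(int(i), float(d)) for d, i in zip(distances[0], indices[0]) if i >= 0]
print(indices, hits)  # e.g. [[0 1 -1]] and two (index, distance) pairs
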
diff --git a/llama_toolchain/memory/providers.py b/llama_toolchain/memory/providers.py
index bfa098d36..0717e2340 100644
--- a/llama_toolchain/memory/providers.py
+++ b/llama_toolchain/memory/providers.py
@@ -9,13 +9,14 @@ from typing import List
 from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
 
 
-def available_inference_providers() -> List[ProviderSpec]:
+def available_memory_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.memory,
             provider_id="meta-reference-faiss",
             pip_packages=[
-                "faiss",
+                "blobfile",
+                "faiss-cpu",
                 "sentence-transformers",
             ],
             module="llama_toolchain.memory.meta_reference.faiss",
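
Note (not part of the patch): the pip package rename above matters because `faiss` itself is not published on PyPI — the installable CPU wheel is `faiss-cpu`. A quick sanity check of the renamed registry function from a REPL inside this repo:

from llama_toolchain.memory.providers import available_memory_providers

for spec in available_memory_providers():
    print(spec.provider_id, spec.pip_packages)
# expected: meta-reference-faiss ['blobfile', 'faiss-cpu', 'sentence-transformers']
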