mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
memory client works
This commit is contained in:
parent
a08958c000
commit
8d14d4228b
8 changed files with 164 additions and 86 deletions
|
@ -11,6 +11,8 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
|
|||
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
||||
from llama_toolchain.inference.api.endpoints import Inference
|
||||
from llama_toolchain.inference.providers import available_inference_providers
|
||||
from llama_toolchain.memory.api.endpoints import Memory
|
||||
from llama_toolchain.memory.providers import available_memory_providers
|
||||
from llama_toolchain.safety.api.endpoints import Safety
|
||||
from llama_toolchain.safety.providers import available_safety_providers
|
||||
|
||||
|
@ -47,6 +49,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
Api.inference: Inference,
|
||||
Api.safety: Safety,
|
||||
Api.agentic_system: AgenticSystem,
|
||||
Api.memory: Memory,
|
||||
}
|
||||
|
||||
for api, protocol in protocols.items():
|
||||
|
@ -60,9 +63,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
webmethod = method.__webmethod__
|
||||
route = webmethod.route
|
||||
|
||||
# use `post` for all methods right now until we fix up the `webmethod` openapi
|
||||
# annotation and write our own openapi generator
|
||||
endpoints.append(ApiEndpoint(route=route, method="post", name=name))
|
||||
if webmethod.method == "GET":
|
||||
method = "get"
|
||||
elif webmethod.method == "DELETE":
|
||||
method = "delete"
|
||||
else:
|
||||
method = "post"
|
||||
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
|
||||
|
||||
apis[api] = endpoints
|
||||
|
||||
|
@ -82,4 +89,5 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
|||
Api.inference: inference_providers_by_id,
|
||||
Api.safety: safety_providers_by_id,
|
||||
Api.agentic_system: agentic_system_providers_by_id,
|
||||
Api.memory: {a.provider_id: a for a in available_memory_providers()},
|
||||
}
|
||||
|
|
|
@ -53,6 +53,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
),
|
||||
DistributionSpec(
|
||||
spec_id="test-memory",
|
||||
description="Just a test distribution spec for testing memory bank APIs",
|
||||
provider_specs={
|
||||
Api.memory: providers[Api.memory]["meta-reference-faiss"],
|
||||
},
|
||||
|
|
|
@ -5,8 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import signal
|
||||
import traceback
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
|
@ -28,12 +30,13 @@ import fire
|
|||
import httpx
|
||||
import yaml
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints
|
||||
|
@ -66,6 +69,7 @@ def create_sse_event(data: Any) -> str:
|
|||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
traceback.print_exception(exc)
|
||||
http_exc = translate_exception(exc)
|
||||
|
||||
return JSONResponse(
|
||||
|
@ -155,9 +159,8 @@ def create_dynamic_passthrough(
|
|||
return endpoint
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any):
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
request_model = next(iter(hints.values()))
|
||||
response_model = hints["return"]
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
|
@ -168,7 +171,7 @@ def create_dynamic_typed_route(func: Any):
|
|||
|
||||
if is_streaming:
|
||||
|
||||
async def endpoint(request: request_model):
|
||||
async def endpoint(**kwargs):
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
|
@ -178,10 +181,7 @@ def create_dynamic_typed_route(func: Any):
|
|||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
@ -191,25 +191,36 @@ def create_dynamic_typed_route(func: Any):
|
|||
)
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(request)), media_type="text/event-stream"
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(request: request_model):
|
||||
async def endpoint(**kwargs):
|
||||
try:
|
||||
return (
|
||||
await func(request)
|
||||
await func(**kwargs)
|
||||
if asyncio.iscoroutinefunction(func)
|
||||
else func(request)
|
||||
else func(**kwargs)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
if method == "post":
|
||||
# make sure every parameter is annotated with Body() so FASTAPI doesn't
|
||||
# do anything too intelligent and ask for some parameters in the query
|
||||
# and some in the body
|
||||
endpoint.__signature__ = sig.replace(
|
||||
parameters=[
|
||||
param.replace(annotation=Annotated[param.annotation, Body()])
|
||||
for param in sig.parameters.values()
|
||||
]
|
||||
)
|
||||
else:
|
||||
endpoint.__signature__ = sig
|
||||
|
||||
return endpoint
|
||||
|
||||
|
||||
|
@ -296,7 +307,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(impl_method)
|
||||
create_dynamic_typed_route(impl_method, endpoint.method)
|
||||
)
|
||||
|
||||
for route in app.routes:
|
||||
|
@ -307,6 +318,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
attrs=["bold"],
|
||||
)
|
||||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ class MemoryBankDocument(BaseModel):
|
|||
document_id: str
|
||||
content: InterleavedTextMedia | URL
|
||||
mime_type: str
|
||||
metadata: Dict[str, Any]
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -103,7 +103,7 @@ class Memory(Protocol):
|
|||
@webmethod(route="/memory_banks/list", method="GET")
|
||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/get")
|
||||
@webmethod(route="/memory_banks/get", method="GET")
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/drop", method="DELETE")
|
||||
|
@ -136,14 +136,14 @@ class Memory(Protocol):
|
|||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse: ...
|
||||
|
||||
@webmethod(route="/memory_bank/documents/get")
|
||||
@webmethod(route="/memory_bank/documents/get", method="GET")
|
||||
async def get_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
document_ids: List[str],
|
||||
) -> List[MemoryBankDocument]: ...
|
||||
|
||||
@webmethod(route="/memory_bank/documents/delete")
|
||||
@webmethod(route="/memory_bank/documents/delete", method="DELETE")
|
||||
async def delete_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
|
|
|
@ -34,19 +34,19 @@ class MemoryClient(Memory):
|
|||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.get(
|
||||
r = await client.get(
|
||||
f"{self.base_url}/memory_banks/get",
|
||||
params={
|
||||
"bank_id": bank_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if len(d) == 0:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if not d:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
|
@ -55,21 +55,21 @@ class MemoryClient(Memory):
|
|||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.post(
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory_banks/create",
|
||||
data={
|
||||
json={
|
||||
"name": name,
|
||||
"config": config.dict(),
|
||||
"url": url,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if len(d) == 0:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if not d:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@ -77,16 +77,16 @@ class MemoryClient(Memory):
|
|||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.post(
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory_bank/insert",
|
||||
data={
|
||||
json={
|
||||
"bank_id": bank_id,
|
||||
"documents": documents,
|
||||
"documents": [d.dict() for d in documents],
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
|
@ -95,18 +95,18 @@ class MemoryClient(Memory):
|
|||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.post(
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory_bank/query",
|
||||
data={
|
||||
json={
|
||||
"bank_id": bank_id,
|
||||
"query": query,
|
||||
"params": params,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
return QueryDocumentsResponse(**r.json())
|
||||
)
|
||||
r.raise_for_status()
|
||||
return QueryDocumentsResponse(**r.json())
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
|
@ -126,31 +126,53 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
|
||||
retrieved_bank = await client.get_memory_bank(bank.bank_id)
|
||||
assert retrieved_bank is not None
|
||||
assert retrieved_bank.embedding_model == "dragon-roberta-query-2"
|
||||
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
document_id=f"num-{i}",
|
||||
content=URL(
|
||||
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
|
||||
),
|
||||
mime_type="text/plain",
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
# insert some documents
|
||||
await client.insert_documents(
|
||||
bank_id=bank.bank_id,
|
||||
documents=[
|
||||
MemoryBankDocument(
|
||||
document_id="1",
|
||||
content="hello world",
|
||||
),
|
||||
MemoryBankDocument(
|
||||
document_id="2",
|
||||
content="goodbye world",
|
||||
),
|
||||
],
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
# query the documents
|
||||
response = await client.query_documents(
|
||||
bank_id=bank.bank_id,
|
||||
query=[
|
||||
"hello world",
|
||||
"How do I use Lora?",
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
for chunk, score in zip(response.chunks, response.scores):
|
||||
print(f"Score: {score}")
|
||||
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||
|
||||
response = await client.query_documents(
|
||||
bank_id=bank.bank_id,
|
||||
query=[
|
||||
"Tell me more about llama3 and torchtune",
|
||||
],
|
||||
)
|
||||
for chunk, score in zip(response.chunks, response.scores):
|
||||
print(f"Score: {score}")
|
||||
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||
|
||||
|
||||
def main(host: str, port: int, stream: bool = True):
|
||||
|
|
|
@ -5,4 +5,4 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import FaissImplConfig # noqa
|
||||
from .memory import get_provider_impl # noqa
|
||||
from .faiss import get_provider_impl # noqa
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import faiss
|
||||
import httpx
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
@ -33,7 +33,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
|
|||
async def content_from_doc(doc: MemoryBankDocument) -> str:
|
||||
if isinstance(doc.content, URL):
|
||||
async with httpx.AsyncClient() as client:
|
||||
return await client.get(doc.content).text
|
||||
r = await client.get(doc.content.uri)
|
||||
return r.text
|
||||
|
||||
def _process(c):
|
||||
if isinstance(c, str):
|
||||
|
@ -62,16 +63,17 @@ def make_overlapped_chunks(
|
|||
return chunks
|
||||
|
||||
|
||||
class BankState(BaseModel):
|
||||
@dataclass
|
||||
class BankState:
|
||||
bank: MemoryBank
|
||||
index: Optional[faiss.IndexFlatL2] = None
|
||||
doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
|
||||
id_by_index: Dict[int, str] = Field(default_factory=dict)
|
||||
chunk_by_index: Dict[int, str] = Field(default_factory=dict)
|
||||
doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
|
||||
id_by_index: Dict[int, str] = field(default_factory=dict)
|
||||
chunk_by_index: Dict[int, str] = field(default_factory=dict)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
model: SentenceTransformer,
|
||||
model: "SentenceTransformer",
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
@ -97,21 +99,44 @@ class BankState(BaseModel):
|
|||
content=chunk[0],
|
||||
token_count=chunk[1],
|
||||
)
|
||||
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
|
||||
self.id_by_index[indexlen + i] = doc.document_id
|
||||
|
||||
async def query_documents(
|
||||
self, model: SentenceTransformer, query: str, params: Dict[str, Any]
|
||||
) -> Tuple[List[Chunk], List[float]]:
|
||||
self,
|
||||
model: "SentenceTransformer",
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
if params is None:
|
||||
params = {}
|
||||
k = params.get("max_chunks", 3)
|
||||
query_vector = model.encode([query])[0]
|
||||
|
||||
def _process(c) -> str:
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
else:
|
||||
return "<media>"
|
||||
|
||||
if isinstance(query, list):
|
||||
query_str = " ".join([_process(c) for c in query])
|
||||
else:
|
||||
query_str = _process(query)
|
||||
|
||||
query_vector = model.encode([query_str])[0]
|
||||
distances, indices = self.index.search(
|
||||
query_vector.reshape(1, -1).astype(np.float32), k
|
||||
)
|
||||
|
||||
chunks = [self.chunk_by_index[int(i)] for i in indices[0]]
|
||||
scores = [1.0 / float(d) for d in distances[0]]
|
||||
chunks = []
|
||||
scores = []
|
||||
for d, i in zip(distances[0], indices[0]):
|
||||
if i < 0:
|
||||
continue
|
||||
chunks.append(self.chunk_by_index[int(i)])
|
||||
scores.append(1.0 / float(d))
|
||||
|
||||
return chunks, scores
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
|
||||
if self.index is None:
|
||||
|
@ -122,7 +147,7 @@ class BankState(BaseModel):
|
|||
class FaissMemoryImpl(Memory):
|
||||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
self.config = config
|
||||
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
self.model = None
|
||||
self.states = {}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
@ -135,20 +160,21 @@ class FaissMemoryImpl(Memory):
|
|||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
print("Creating memory bank")
|
||||
assert url is None, "URL is not supported for this implementation"
|
||||
assert (
|
||||
config.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {config.type}"
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=id,
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
state = BankState(bank=bank)
|
||||
self.states[id] = state
|
||||
self.states[bank_id] = state
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
|
@ -164,7 +190,7 @@ class FaissMemoryImpl(Memory):
|
|||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
||||
state = self.states[bank_id]
|
||||
|
||||
await state.insert_documents(self.model, documents)
|
||||
await state.insert_documents(self.get_model(), documents)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
|
@ -175,5 +201,13 @@ class FaissMemoryImpl(Memory):
|
|||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
||||
state = self.states[bank_id]
|
||||
|
||||
chunks, scores = await state.query_documents(self.model, query, params)
|
||||
return QueryDocumentsResponse(chunk=chunks, scores=scores)
|
||||
return await state.query_documents(self.get_model(), query, params)
|
||||
|
||||
def get_model(self) -> "SentenceTransformer":
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
if self.model is None:
|
||||
print("Loading sentence transformer")
|
||||
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
|
||||
return self.model
|
|
@ -9,13 +9,14 @@ from typing import List
|
|||
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_inference_providers() -> List[ProviderSpec]:
|
||||
def available_memory_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
provider_id="meta-reference-faiss",
|
||||
pip_packages=[
|
||||
"faiss",
|
||||
"blobfile",
|
||||
"faiss-cpu",
|
||||
"sentence-transformers",
|
||||
],
|
||||
module="llama_toolchain.memory.meta_reference.faiss",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue