mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
memory client works
This commit is contained in:
parent
a08958c000
commit
8d14d4228b
8 changed files with 164 additions and 86 deletions
|
@ -11,6 +11,8 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
|
||||||
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
||||||
from llama_toolchain.inference.api.endpoints import Inference
|
from llama_toolchain.inference.api.endpoints import Inference
|
||||||
from llama_toolchain.inference.providers import available_inference_providers
|
from llama_toolchain.inference.providers import available_inference_providers
|
||||||
|
from llama_toolchain.memory.api.endpoints import Memory
|
||||||
|
from llama_toolchain.memory.providers import available_memory_providers
|
||||||
from llama_toolchain.safety.api.endpoints import Safety
|
from llama_toolchain.safety.api.endpoints import Safety
|
||||||
from llama_toolchain.safety.providers import available_safety_providers
|
from llama_toolchain.safety.providers import available_safety_providers
|
||||||
|
|
||||||
|
@ -47,6 +49,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||||
Api.inference: Inference,
|
Api.inference: Inference,
|
||||||
Api.safety: Safety,
|
Api.safety: Safety,
|
||||||
Api.agentic_system: AgenticSystem,
|
Api.agentic_system: AgenticSystem,
|
||||||
|
Api.memory: Memory,
|
||||||
}
|
}
|
||||||
|
|
||||||
for api, protocol in protocols.items():
|
for api, protocol in protocols.items():
|
||||||
|
@ -60,9 +63,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||||
webmethod = method.__webmethod__
|
webmethod = method.__webmethod__
|
||||||
route = webmethod.route
|
route = webmethod.route
|
||||||
|
|
||||||
# use `post` for all methods right now until we fix up the `webmethod` openapi
|
if webmethod.method == "GET":
|
||||||
# annotation and write our own openapi generator
|
method = "get"
|
||||||
endpoints.append(ApiEndpoint(route=route, method="post", name=name))
|
elif webmethod.method == "DELETE":
|
||||||
|
method = "delete"
|
||||||
|
else:
|
||||||
|
method = "post"
|
||||||
|
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
|
||||||
|
|
||||||
apis[api] = endpoints
|
apis[api] = endpoints
|
||||||
|
|
||||||
|
@ -82,4 +89,5 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||||
Api.inference: inference_providers_by_id,
|
Api.inference: inference_providers_by_id,
|
||||||
Api.safety: safety_providers_by_id,
|
Api.safety: safety_providers_by_id,
|
||||||
Api.agentic_system: agentic_system_providers_by_id,
|
Api.agentic_system: agentic_system_providers_by_id,
|
||||||
|
Api.memory: {a.provider_id: a for a in available_memory_providers()},
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,6 +53,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
||||||
),
|
),
|
||||||
DistributionSpec(
|
DistributionSpec(
|
||||||
spec_id="test-memory",
|
spec_id="test-memory",
|
||||||
|
description="Just a test distribution spec for testing memory bank APIs",
|
||||||
provider_specs={
|
provider_specs={
|
||||||
Api.memory: providers[Api.memory]["meta-reference-faiss"],
|
Api.memory: providers[Api.memory]["meta-reference-faiss"],
|
||||||
},
|
},
|
||||||
|
|
|
@ -5,8 +5,10 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import signal
|
import signal
|
||||||
|
import traceback
|
||||||
from collections.abc import (
|
from collections.abc import (
|
||||||
AsyncGenerator as AsyncGeneratorABC,
|
AsyncGenerator as AsyncGeneratorABC,
|
||||||
AsyncIterator as AsyncIteratorABC,
|
AsyncIterator as AsyncIteratorABC,
|
||||||
|
@ -28,12 +30,13 @@ import fire
|
||||||
import httpx
|
import httpx
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request, Response
|
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
|
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
|
||||||
from .distribution import api_endpoints
|
from .distribution import api_endpoints
|
||||||
|
@ -66,6 +69,7 @@ def create_sse_event(data: Any) -> str:
|
||||||
|
|
||||||
|
|
||||||
async def global_exception_handler(request: Request, exc: Exception):
|
async def global_exception_handler(request: Request, exc: Exception):
|
||||||
|
traceback.print_exception(exc)
|
||||||
http_exc = translate_exception(exc)
|
http_exc = translate_exception(exc)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
@ -155,9 +159,8 @@ def create_dynamic_passthrough(
|
||||||
return endpoint
|
return endpoint
|
||||||
|
|
||||||
|
|
||||||
def create_dynamic_typed_route(func: Any):
|
def create_dynamic_typed_route(func: Any, method: str):
|
||||||
hints = get_type_hints(func)
|
hints = get_type_hints(func)
|
||||||
request_model = next(iter(hints.values()))
|
|
||||||
response_model = hints["return"]
|
response_model = hints["return"]
|
||||||
|
|
||||||
# NOTE: I think it is better to just add a method within each Api
|
# NOTE: I think it is better to just add a method within each Api
|
||||||
|
@ -168,7 +171,7 @@ def create_dynamic_typed_route(func: Any):
|
||||||
|
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
|
|
||||||
async def endpoint(request: request_model):
|
async def endpoint(**kwargs):
|
||||||
async def sse_generator(event_gen):
|
async def sse_generator(event_gen):
|
||||||
try:
|
try:
|
||||||
async for item in event_gen:
|
async for item in event_gen:
|
||||||
|
@ -178,10 +181,7 @@ def create_dynamic_typed_route(func: Any):
|
||||||
print("Generator cancelled")
|
print("Generator cancelled")
|
||||||
await event_gen.aclose()
|
await event_gen.aclose()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
traceback.print_exception(e)
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
yield create_sse_event(
|
yield create_sse_event(
|
||||||
{
|
{
|
||||||
"error": {
|
"error": {
|
||||||
|
@ -191,25 +191,36 @@ def create_dynamic_typed_route(func: Any):
|
||||||
)
|
)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
sse_generator(func(request)), media_type="text/event-stream"
|
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
async def endpoint(request: request_model):
|
async def endpoint(**kwargs):
|
||||||
try:
|
try:
|
||||||
return (
|
return (
|
||||||
await func(request)
|
await func(**kwargs)
|
||||||
if asyncio.iscoroutinefunction(func)
|
if asyncio.iscoroutinefunction(func)
|
||||||
else func(request)
|
else func(**kwargs)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
traceback.print_exception(e)
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
raise translate_exception(e) from e
|
raise translate_exception(e) from e
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
if method == "post":
|
||||||
|
# make sure every parameter is annotated with Body() so FASTAPI doesn't
|
||||||
|
# do anything too intelligent and ask for some parameters in the query
|
||||||
|
# and some in the body
|
||||||
|
endpoint.__signature__ = sig.replace(
|
||||||
|
parameters=[
|
||||||
|
param.replace(annotation=Annotated[param.annotation, Body()])
|
||||||
|
for param in sig.parameters.values()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
endpoint.__signature__ = sig
|
||||||
|
|
||||||
return endpoint
|
return endpoint
|
||||||
|
|
||||||
|
|
||||||
|
@ -296,7 +307,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
|
|
||||||
impl_method = getattr(impl, endpoint.name)
|
impl_method = getattr(impl, endpoint.name)
|
||||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||||
create_dynamic_typed_route(impl_method)
|
create_dynamic_typed_route(impl_method, endpoint.method)
|
||||||
)
|
)
|
||||||
|
|
||||||
for route in app.routes:
|
for route in app.routes:
|
||||||
|
@ -307,6 +318,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
attrs=["bold"],
|
attrs=["bold"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||||
app.exception_handler(Exception)(global_exception_handler)
|
app.exception_handler(Exception)(global_exception_handler)
|
||||||
signal.signal(signal.SIGINT, handle_sigint)
|
signal.signal(signal.SIGINT, handle_sigint)
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ class MemoryBankDocument(BaseModel):
|
||||||
document_id: str
|
document_id: str
|
||||||
content: InterleavedTextMedia | URL
|
content: InterleavedTextMedia | URL
|
||||||
mime_type: str
|
mime_type: str
|
||||||
metadata: Dict[str, Any]
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -103,7 +103,7 @@ class Memory(Protocol):
|
||||||
@webmethod(route="/memory_banks/list", method="GET")
|
@webmethod(route="/memory_banks/list", method="GET")
|
||||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/get")
|
@webmethod(route="/memory_banks/get", method="GET")
|
||||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/drop", method="DELETE")
|
@webmethod(route="/memory_banks/drop", method="DELETE")
|
||||||
|
@ -136,14 +136,14 @@ class Memory(Protocol):
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryDocumentsResponse: ...
|
) -> QueryDocumentsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/documents/get")
|
@webmethod(route="/memory_bank/documents/get", method="GET")
|
||||||
async def get_documents(
|
async def get_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
document_ids: List[str],
|
document_ids: List[str],
|
||||||
) -> List[MemoryBankDocument]: ...
|
) -> List[MemoryBankDocument]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/documents/delete")
|
@webmethod(route="/memory_bank/documents/delete", method="DELETE")
|
||||||
async def delete_documents(
|
async def delete_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
|
|
|
@ -34,19 +34,19 @@ class MemoryClient(Memory):
|
||||||
|
|
||||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
async with client.get(
|
r = await client.get(
|
||||||
f"{self.base_url}/memory_banks/get",
|
f"{self.base_url}/memory_banks/get",
|
||||||
params={
|
params={
|
||||||
"bank_id": bank_id,
|
"bank_id": bank_id,
|
||||||
},
|
},
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
) as r:
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
d = r.json()
|
d = r.json()
|
||||||
if len(d) == 0:
|
if not d:
|
||||||
return None
|
return None
|
||||||
return MemoryBank(**d)
|
return MemoryBank(**d)
|
||||||
|
|
||||||
async def create_memory_bank(
|
async def create_memory_bank(
|
||||||
self,
|
self,
|
||||||
|
@ -55,21 +55,21 @@ class MemoryClient(Memory):
|
||||||
url: Optional[URL] = None,
|
url: Optional[URL] = None,
|
||||||
) -> MemoryBank:
|
) -> MemoryBank:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
async with client.post(
|
r = await client.post(
|
||||||
f"{self.base_url}/memory_banks/create",
|
f"{self.base_url}/memory_banks/create",
|
||||||
data={
|
json={
|
||||||
"name": name,
|
"name": name,
|
||||||
"config": config.dict(),
|
"config": config.dict(),
|
||||||
"url": url,
|
"url": url,
|
||||||
},
|
},
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
) as r:
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
d = r.json()
|
d = r.json()
|
||||||
if len(d) == 0:
|
if not d:
|
||||||
return None
|
return None
|
||||||
return MemoryBank(**d)
|
return MemoryBank(**d)
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -77,16 +77,16 @@ class MemoryClient(Memory):
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
) -> None:
|
) -> None:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
async with client.post(
|
r = await client.post(
|
||||||
f"{self.base_url}/memory_bank/insert",
|
f"{self.base_url}/memory_bank/insert",
|
||||||
data={
|
json={
|
||||||
"bank_id": bank_id,
|
"bank_id": bank_id,
|
||||||
"documents": documents,
|
"documents": [d.dict() for d in documents],
|
||||||
},
|
},
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
) as r:
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
async def query_documents(
|
async def query_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -95,18 +95,18 @@ class MemoryClient(Memory):
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryDocumentsResponse:
|
) -> QueryDocumentsResponse:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
async with client.post(
|
r = await client.post(
|
||||||
f"{self.base_url}/memory_bank/query",
|
f"{self.base_url}/memory_bank/query",
|
||||||
data={
|
json={
|
||||||
"bank_id": bank_id,
|
"bank_id": bank_id,
|
||||||
"query": query,
|
"query": query,
|
||||||
"params": params,
|
"params": params,
|
||||||
},
|
},
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
) as r:
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
return QueryDocumentsResponse(**r.json())
|
return QueryDocumentsResponse(**r.json())
|
||||||
|
|
||||||
|
|
||||||
async def run_main(host: str, port: int, stream: bool):
|
async def run_main(host: str, port: int, stream: bool):
|
||||||
|
@ -126,31 +126,53 @@ async def run_main(host: str, port: int, stream: bool):
|
||||||
|
|
||||||
retrieved_bank = await client.get_memory_bank(bank.bank_id)
|
retrieved_bank = await client.get_memory_bank(bank.bank_id)
|
||||||
assert retrieved_bank is not None
|
assert retrieved_bank is not None
|
||||||
assert retrieved_bank.embedding_model == "dragon-roberta-query-2"
|
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
|
||||||
|
|
||||||
|
urls = [
|
||||||
|
"memory_optimizations.rst",
|
||||||
|
"chat.rst",
|
||||||
|
"llama3.rst",
|
||||||
|
"datasets.rst",
|
||||||
|
"qat_finetune.rst",
|
||||||
|
"lora_finetune.rst",
|
||||||
|
]
|
||||||
|
documents = [
|
||||||
|
MemoryBankDocument(
|
||||||
|
document_id=f"num-{i}",
|
||||||
|
content=URL(
|
||||||
|
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
|
||||||
|
),
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
for i, url in enumerate(urls)
|
||||||
|
]
|
||||||
|
|
||||||
# insert some documents
|
# insert some documents
|
||||||
await client.insert_documents(
|
await client.insert_documents(
|
||||||
bank_id=bank.bank_id,
|
bank_id=bank.bank_id,
|
||||||
documents=[
|
documents=documents,
|
||||||
MemoryBankDocument(
|
|
||||||
document_id="1",
|
|
||||||
content="hello world",
|
|
||||||
),
|
|
||||||
MemoryBankDocument(
|
|
||||||
document_id="2",
|
|
||||||
content="goodbye world",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# query the documents
|
# query the documents
|
||||||
response = await client.query_documents(
|
response = await client.query_documents(
|
||||||
bank_id=bank.bank_id,
|
bank_id=bank.bank_id,
|
||||||
query=[
|
query=[
|
||||||
"hello world",
|
"How do I use Lora?",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
print(response)
|
for chunk, score in zip(response.chunks, response.scores):
|
||||||
|
print(f"Score: {score}")
|
||||||
|
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||||
|
|
||||||
|
response = await client.query_documents(
|
||||||
|
bank_id=bank.bank_id,
|
||||||
|
query=[
|
||||||
|
"Tell me more about llama3 and torchtune",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
for chunk, score in zip(response.chunks, response.scores):
|
||||||
|
print(f"Score: {score}")
|
||||||
|
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||||
|
|
||||||
|
|
||||||
def main(host: str, port: int, stream: bool = True):
|
def main(host: str, port: int, stream: bool = True):
|
||||||
|
|
|
@ -5,4 +5,4 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .config import FaissImplConfig # noqa
|
from .config import FaissImplConfig # noqa
|
||||||
from .memory import get_provider_impl # noqa
|
from .faiss import get_provider_impl # noqa
|
||||||
|
|
|
@ -4,13 +4,13 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import faiss
|
import faiss
|
||||||
import httpx
|
import httpx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
@ -33,7 +33,8 @@ async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSp
|
||||||
async def content_from_doc(doc: MemoryBankDocument) -> str:
|
async def content_from_doc(doc: MemoryBankDocument) -> str:
|
||||||
if isinstance(doc.content, URL):
|
if isinstance(doc.content, URL):
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
return await client.get(doc.content).text
|
r = await client.get(doc.content.uri)
|
||||||
|
return r.text
|
||||||
|
|
||||||
def _process(c):
|
def _process(c):
|
||||||
if isinstance(c, str):
|
if isinstance(c, str):
|
||||||
|
@ -62,16 +63,17 @@ def make_overlapped_chunks(
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
class BankState(BaseModel):
|
@dataclass
|
||||||
|
class BankState:
|
||||||
bank: MemoryBank
|
bank: MemoryBank
|
||||||
index: Optional[faiss.IndexFlatL2] = None
|
index: Optional[faiss.IndexFlatL2] = None
|
||||||
doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
|
doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
|
||||||
id_by_index: Dict[int, str] = Field(default_factory=dict)
|
id_by_index: Dict[int, str] = field(default_factory=dict)
|
||||||
chunk_by_index: Dict[int, str] = Field(default_factory=dict)
|
chunk_by_index: Dict[int, str] = field(default_factory=dict)
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
model: SentenceTransformer,
|
model: "SentenceTransformer",
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
) -> None:
|
) -> None:
|
||||||
tokenizer = Tokenizer.get_instance()
|
tokenizer = Tokenizer.get_instance()
|
||||||
|
@ -97,21 +99,44 @@ class BankState(BaseModel):
|
||||||
content=chunk[0],
|
content=chunk[0],
|
||||||
token_count=chunk[1],
|
token_count=chunk[1],
|
||||||
)
|
)
|
||||||
|
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
|
||||||
self.id_by_index[indexlen + i] = doc.document_id
|
self.id_by_index[indexlen + i] = doc.document_id
|
||||||
|
|
||||||
async def query_documents(
|
async def query_documents(
|
||||||
self, model: SentenceTransformer, query: str, params: Dict[str, Any]
|
self,
|
||||||
) -> Tuple[List[Chunk], List[float]]:
|
model: "SentenceTransformer",
|
||||||
|
query: InterleavedTextMedia,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
|
if params is None:
|
||||||
|
params = {}
|
||||||
k = params.get("max_chunks", 3)
|
k = params.get("max_chunks", 3)
|
||||||
query_vector = model.encode([query])[0]
|
|
||||||
|
def _process(c) -> str:
|
||||||
|
if isinstance(c, str):
|
||||||
|
return c
|
||||||
|
else:
|
||||||
|
return "<media>"
|
||||||
|
|
||||||
|
if isinstance(query, list):
|
||||||
|
query_str = " ".join([_process(c) for c in query])
|
||||||
|
else:
|
||||||
|
query_str = _process(query)
|
||||||
|
|
||||||
|
query_vector = model.encode([query_str])[0]
|
||||||
distances, indices = self.index.search(
|
distances, indices = self.index.search(
|
||||||
query_vector.reshape(1, -1).astype(np.float32), k
|
query_vector.reshape(1, -1).astype(np.float32), k
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks = [self.chunk_by_index[int(i)] for i in indices[0]]
|
chunks = []
|
||||||
scores = [1.0 / float(d) for d in distances[0]]
|
scores = []
|
||||||
|
for d, i in zip(distances[0], indices[0]):
|
||||||
|
if i < 0:
|
||||||
|
continue
|
||||||
|
chunks.append(self.chunk_by_index[int(i)])
|
||||||
|
scores.append(1.0 / float(d))
|
||||||
|
|
||||||
return chunks, scores
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
|
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
|
||||||
if self.index is None:
|
if self.index is None:
|
||||||
|
@ -122,7 +147,7 @@ class BankState(BaseModel):
|
||||||
class FaissMemoryImpl(Memory):
|
class FaissMemoryImpl(Memory):
|
||||||
def __init__(self, config: FaissImplConfig) -> None:
|
def __init__(self, config: FaissImplConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
self.model = None
|
||||||
self.states = {}
|
self.states = {}
|
||||||
|
|
||||||
async def initialize(self) -> None: ...
|
async def initialize(self) -> None: ...
|
||||||
|
@ -135,20 +160,21 @@ class FaissMemoryImpl(Memory):
|
||||||
config: MemoryBankConfig,
|
config: MemoryBankConfig,
|
||||||
url: Optional[URL] = None,
|
url: Optional[URL] = None,
|
||||||
) -> MemoryBank:
|
) -> MemoryBank:
|
||||||
|
print("Creating memory bank")
|
||||||
assert url is None, "URL is not supported for this implementation"
|
assert url is None, "URL is not supported for this implementation"
|
||||||
assert (
|
assert (
|
||||||
config.type == MemoryBankType.vector.value
|
config.type == MemoryBankType.vector.value
|
||||||
), f"Only vector banks are supported {config.type}"
|
), f"Only vector banks are supported {config.type}"
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
bank_id = str(uuid.uuid4())
|
||||||
bank = MemoryBank(
|
bank = MemoryBank(
|
||||||
bank_id=id,
|
bank_id=bank_id,
|
||||||
name=name,
|
name=name,
|
||||||
config=config,
|
config=config,
|
||||||
url=url,
|
url=url,
|
||||||
)
|
)
|
||||||
state = BankState(bank=bank)
|
state = BankState(bank=bank)
|
||||||
self.states[id] = state
|
self.states[bank_id] = state
|
||||||
return bank
|
return bank
|
||||||
|
|
||||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||||
|
@ -164,7 +190,7 @@ class FaissMemoryImpl(Memory):
|
||||||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
assert bank_id in self.states, f"Bank {bank_id} not found"
|
||||||
state = self.states[bank_id]
|
state = self.states[bank_id]
|
||||||
|
|
||||||
await state.insert_documents(self.model, documents)
|
await state.insert_documents(self.get_model(), documents)
|
||||||
|
|
||||||
async def query_documents(
|
async def query_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -175,5 +201,13 @@ class FaissMemoryImpl(Memory):
|
||||||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
assert bank_id in self.states, f"Bank {bank_id} not found"
|
||||||
state = self.states[bank_id]
|
state = self.states[bank_id]
|
||||||
|
|
||||||
chunks, scores = await state.query_documents(self.model, query, params)
|
return await state.query_documents(self.get_model(), query, params)
|
||||||
return QueryDocumentsResponse(chunk=chunks, scores=scores)
|
|
||||||
|
def get_model(self) -> "SentenceTransformer":
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
if self.model is None:
|
||||||
|
print("Loading sentence transformer")
|
||||||
|
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||||
|
|
||||||
|
return self.model
|
|
@ -9,13 +9,14 @@ from typing import List
|
||||||
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
def available_inference_providers() -> List[ProviderSpec]:
|
def available_memory_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.memory,
|
api=Api.memory,
|
||||||
provider_id="meta-reference-faiss",
|
provider_id="meta-reference-faiss",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"faiss",
|
"blobfile",
|
||||||
|
"faiss-cpu",
|
||||||
"sentence-transformers",
|
"sentence-transformers",
|
||||||
],
|
],
|
||||||
module="llama_toolchain.memory.meta_reference.faiss",
|
module="llama_toolchain.memory.meta_reference.faiss",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue