faiss provider implementation

This commit is contained in:
Ashwin Bharambe 2024-08-23 20:58:27 -07:00
parent 14637bea66
commit a08958c000
9 changed files with 401 additions and 3 deletions

View file

@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
# import json
from typing import Dict, List, Optional
import fire
import httpx
# from termcolor import cprint
from .api import * # noqa: F403
async def get_client_impl(base_url: str):
return MemoryClient(base_url)
class MemoryClient(Memory):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
async with client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
d = r.json()
if len(d) == 0:
return None
return MemoryBank(**d)
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
async with client.post(
f"{self.base_url}/memory_banks/create",
data={
"name": name,
"config": config.dict(),
"url": url,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
d = r.json()
if len(d) == 0:
return None
return MemoryBank(**d)
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None:
async with httpx.AsyncClient() as client:
async with client.post(
f"{self.base_url}/memory_bank/insert",
data={
"bank_id": bank_id,
"documents": documents,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
async with httpx.AsyncClient() as client:
async with client.post(
f"{self.base_url}/memory_bank/query",
data={
"bank_id": bank_id,
"query": query,
"params": params,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
return QueryDocumentsResponse(**r.json())
async def run_main(host: str, port: int, stream: bool):
client = MemoryClient(f"http://{host}:{port}")
# create a memory bank
bank = await client.create_memory_bank(
name="test_bank",
config=VectorMemoryBankConfig(
bank_id="test_bank",
embedding_model="dragon-roberta-query-2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
print(bank)
retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None
assert retrieved_bank.embedding_model == "dragon-roberta-query-2"
# insert some documents
await client.insert_documents(
bank_id=bank.bank_id,
documents=[
MemoryBankDocument(
document_id="1",
content="hello world",
),
MemoryBankDocument(
document_id="2",
content="goodbye world",
),
],
)
# query the documents
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"hello world",
],
)
print(response)
def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))
if __name__ == "__main__":
fire.Fire(main)