mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Fix pgvector, store source of truth in Chroma
This commit is contained in:
parent
dd9d34cf7d
commit
fe0dabe596
4 changed files with 86 additions and 40 deletions
|
@ -11,8 +11,11 @@ from urllib.parse import urlparse
|
||||||
import chromadb
|
import chromadb
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from pydantic import parse_obj_as
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
BankWithIndex,
|
BankWithIndex,
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
|
@ -63,7 +66,7 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class ChromaMemoryAdapter(Memory):
|
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
def __init__(self, url: str) -> None:
|
def __init__(self, url: str) -> None:
|
||||||
print(f"Initializing ChromaMemoryAdapter with url: {url}")
|
print(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||||
url = url.rstrip("/")
|
url = url.rstrip("/")
|
||||||
|
@ -101,31 +104,33 @@ class ChromaMemoryAdapter(Memory):
|
||||||
|
|
||||||
collection = await self.client.get_or_create_collection(
|
collection = await self.client.get_or_create_collection(
|
||||||
name=memory_bank.identifier,
|
name=memory_bank.identifier,
|
||||||
|
metadata={"bank": memory_bank.json()},
|
||||||
)
|
)
|
||||||
bank_index = BankWithIndex(
|
bank_index = BankWithIndex(
|
||||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||||
)
|
)
|
||||||
self.cache[memory_bank.identifier] = bank_index
|
self.cache[memory_bank.identifier] = bank_index
|
||||||
|
|
||||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||||
if bank_id in self.cache:
|
|
||||||
return self.cache[bank_id]
|
|
||||||
|
|
||||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
|
||||||
if bank is None:
|
|
||||||
raise ValueError(f"Bank {bank_id} not found")
|
|
||||||
|
|
||||||
collections = await self.client.list_collections()
|
collections = await self.client.list_collections()
|
||||||
for collection in collections:
|
for collection in collections:
|
||||||
if collection.name == bank_id:
|
try:
|
||||||
|
data = json.loads(collection.metadata["bank"])
|
||||||
|
bank = parse_obj_as(MemoryBankDef, data)
|
||||||
|
except Exception:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Failed to parse bank: {collection.metadata}")
|
||||||
|
continue
|
||||||
|
|
||||||
index = BankWithIndex(
|
index = BankWithIndex(
|
||||||
bank=bank,
|
bank=bank,
|
||||||
index=ChromaIndex(self.client, collection),
|
index=ChromaIndex(self.client, collection),
|
||||||
)
|
)
|
||||||
self.cache[bank_id] = index
|
self.cache[bank.identifier] = index
|
||||||
return index
|
|
||||||
|
|
||||||
return None
|
return [i.bank for i in self.cache.values()]
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -133,7 +138,7 @@ class ChromaMemoryAdapter(Memory):
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
index = await self._get_and_cache_bank_index(bank_id)
|
index = self.cache.get(bank_id, None)
|
||||||
if not index:
|
if not index:
|
||||||
raise ValueError(f"Bank {bank_id} not found")
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
@ -145,7 +150,7 @@ class ChromaMemoryAdapter(Memory):
|
||||||
query: InterleavedTextMedia,
|
query: InterleavedTextMedia,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryDocumentsResponse:
|
) -> QueryDocumentsResponse:
|
||||||
index = await self._get_and_cache_bank_index(bank_id)
|
index = self.cache.get(bank_id, None)
|
||||||
if not index:
|
if not index:
|
||||||
raise ValueError(f"Bank {bank_id} not found")
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
|
|
@ -4,15 +4,18 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Tuple
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from psycopg2 import sql
|
from psycopg2 import sql
|
||||||
from psycopg2.extras import execute_values, Json
|
from psycopg2.extras import execute_values, Json
|
||||||
|
|
||||||
|
from pydantic import BaseModel, parse_obj_as
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
ALL_MINILM_L6_V2_DIMENSION,
|
ALL_MINILM_L6_V2_DIMENSION,
|
||||||
BankWithIndex,
|
BankWithIndex,
|
||||||
|
@ -28,10 +31,31 @@ def check_extension_version(cur):
|
||||||
return result[0] if result else None
|
return result[0] if result else None
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
|
||||||
|
query = sql.SQL(
|
||||||
|
"""
|
||||||
|
INSERT INTO metadata_store (key, data)
|
||||||
|
VALUES %s
|
||||||
|
ON CONFLICT (key) DO UPDATE
|
||||||
|
SET data = EXCLUDED.data
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
values = [(key, Json(model.dict())) for key, model in keys_models]
|
||||||
|
execute_values(cur, query, values, template="(%s, %s)")
|
||||||
|
|
||||||
|
|
||||||
|
def load_models(cur, cls):
|
||||||
|
query = "SELECT key, data FROM metadata_store"
|
||||||
|
cur.execute(query)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
return [parse_obj_as(cls, row["data"]) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
class PGVectorIndex(EmbeddingIndex):
|
class PGVectorIndex(EmbeddingIndex):
|
||||||
def __init__(self, bank: MemoryBank, dimension: int, cursor):
|
def __init__(self, bank: MemoryBankDef, dimension: int, cursor):
|
||||||
self.cursor = cursor
|
self.cursor = cursor
|
||||||
self.table_name = f"vector_store_{bank.name}"
|
self.table_name = f"vector_store_{bank.identifier}"
|
||||||
|
|
||||||
self.cursor.execute(
|
self.cursor.execute(
|
||||||
f"""
|
f"""
|
||||||
|
@ -88,7 +112,7 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class PGVectorMemoryAdapter(Memory):
|
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
def __init__(self, config: PGVectorConfig) -> None:
|
def __init__(self, config: PGVectorConfig) -> None:
|
||||||
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
|
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -113,6 +137,14 @@ class PGVectorMemoryAdapter(Memory):
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Vector extension is not installed.")
|
raise RuntimeError("Vector extension is not installed.")
|
||||||
|
|
||||||
|
self.cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS metadata_store (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
data JSONB
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
@ -130,26 +162,29 @@ class PGVectorMemoryAdapter(Memory):
|
||||||
memory_bank.type == MemoryBankType.vector.value
|
memory_bank.type == MemoryBankType.vector.value
|
||||||
), f"Only vector banks are supported {memory_bank.type}"
|
), f"Only vector banks are supported {memory_bank.type}"
|
||||||
|
|
||||||
|
upsert_models(
|
||||||
|
self.cursor,
|
||||||
|
[
|
||||||
|
(memory_bank.identifier, memory_bank),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
index = BankWithIndex(
|
index = BankWithIndex(
|
||||||
bank=memory_bank,
|
bank=memory_bank,
|
||||||
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||||
)
|
)
|
||||||
self.cache[bank_id] = index
|
self.cache[memory_bank.identifier] = index
|
||||||
|
|
||||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
|
||||||
if bank_id in self.cache:
|
|
||||||
return self.cache[bank_id]
|
|
||||||
|
|
||||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
|
||||||
if not bank:
|
|
||||||
raise ValueError(f"Bank {bank_id} not found")
|
|
||||||
|
|
||||||
|
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||||
|
banks = load_models(self.cursor, MemoryBankDef)
|
||||||
|
for bank in banks:
|
||||||
|
if bank.identifier not in self.cache:
|
||||||
index = BankWithIndex(
|
index = BankWithIndex(
|
||||||
bank=bank,
|
bank=bank,
|
||||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||||
)
|
)
|
||||||
self.cache[bank_id] = index
|
self.cache[bank.identifier] = index
|
||||||
return index
|
return banks
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -157,7 +192,7 @@ class PGVectorMemoryAdapter(Memory):
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
index = await self._get_and_cache_bank_index(bank_id)
|
index = self.cache.get(bank_id, None)
|
||||||
if not index:
|
if not index:
|
||||||
raise ValueError(f"Bank {bank_id} not found")
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
@ -169,7 +204,7 @@ class PGVectorMemoryAdapter(Memory):
|
||||||
query: InterleavedTextMedia,
|
query: InterleavedTextMedia,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryDocumentsResponse:
|
) -> QueryDocumentsResponse:
|
||||||
index = await self._get_and_cache_bank_index(bank_id)
|
index = self.cache.get(bank_id, None)
|
||||||
if not index:
|
if not index:
|
||||||
raise ValueError(f"Bank {bank_id} not found")
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
|
|
@ -140,6 +140,10 @@ class WeaviateMemoryAdapter(
|
||||||
self.cache[memory_bank.identifier] = index
|
self.cache[memory_bank.identifier] = index
|
||||||
|
|
||||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||||
|
# TODO: right now the Llama Stack is the source of truth for these banks. That is
|
||||||
|
# not ideal. It should be Weaviate which is the source of truth. Unfortunately,
|
||||||
|
# list() happens at Stack startup when the Weaviate client (credentials) is not
|
||||||
|
# yet available. We need to figure out a way to make this work.
|
||||||
return [i.bank for i in self.cache.values()]
|
return [i.bank for i in self.cache.values()]
|
||||||
|
|
||||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||||
|
|
|
@ -81,6 +81,8 @@ async def register_memory_bank(banks_impl: MemoryBanks):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_banks_list(memory_settings):
|
async def test_banks_list(memory_settings):
|
||||||
|
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
|
# but so far we don't have an unregister API unfortunately, so be careful
|
||||||
banks_impl = memory_settings["memory_banks_impl"]
|
banks_impl = memory_settings["memory_banks_impl"]
|
||||||
response = await banks_impl.list_memory_banks()
|
response = await banks_impl.list_memory_banks()
|
||||||
assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue