Added Pinecone Memory Adapter

This commit is contained in:
Sarthak Deshpande 2024-10-23 13:16:36 +05:30
parent 2e5e46d896
commit c2d74188ee
3 changed files with 226 additions and 0 deletions

View file

@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import PineconeConfig, PineconeRequestProviderData # noqa: F401
from .pinecone import PineconeMemoryAdapter
async def get_adapter_impl(config: PineconeConfig, _deps):
impl = PineconeMemoryAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class PineconeRequestProviderData(BaseModel):
pinecone_api_key: str
class PineconeConfig(BaseModel):
dimensions: int
cloud: str
region: str

View file

@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from numpy.typing import NDArray
from pinecone import ServerlessSpec
from pinecone.grpc import PineconeGRPC as Pinecone
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import PineconeConfig, PineconeRequestProviderData
class PineconeIndex(EmbeddingIndex):
def __init__(self, client: Pinecone, index_name: str):
self.client = client
self.index_name = index_name
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
data_objects = []
for i, chunk in enumerate(chunks):
data_objects.append(
{
"id": f"vec{i+1}",
"values": embeddings[i].tolist(),
"metadata": {"chunk": chunk},
}
)
# Inserting chunks into a prespecified Weaviate collection
index = self.client.Index(self.index_name)
index.upsert(vectors=data_objects)
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
index = self.client.Index(self.index_name)
results = index.query(
vector=embedding, top_k=k, include_values=True, include_metadata=True
)
chunks = []
scores = []
for doc in results["matches"]:
chunk_json = doc["metadata"]["chunk"]
try:
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {chunk_json}")
continue
chunks.append(chunk)
scores.append(doc.score)
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class PineconeMemoryAdapter(
Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
):
def __init__(self, config: PineconeConfig) -> None:
self.config = config
self.client_cache = {}
self.cache = {}
def _get_client(self) -> Pinecone:
provider_data = self.get_request_provider_data()
assert provider_data is not None, "Request provider data must be set"
assert isinstance(provider_data, PineconeRequestProviderData)
key = f"{provider_data.pinecone_api_key}"
if key in self.client_cache:
return self.client_cache[key]
client = Pinecone(api_key=provider_data.pinecone_api_key)
self.client_cache[key] = client
return client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def check_if_index_exists(
self,
client: Pinecone,
index_name: str,
) -> bool:
try:
# Get list of all indexes
active_indexes = client.list_indexes()
for index in active_indexes:
if index["name"] == index_name:
return True
return False
except Exception as e:
print(f"Error checking index: {e}")
return False
async def register_memory_bank(
self,
memory_bank: MemoryBankDef,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
client = self._get_client()
# Create collection if it doesn't exist
if not self.check_if_index_exists(client, memory_bank.identifier):
client.create_index(
name=memory_bank.identifier,
dimension=self.config.dimensions if self.config.dimensions else 1024,
metric="cosine",
spec=ServerlessSpec(
cloud=self.config.cloud if self.config.cloud else "aws",
region=self.config.region if self.config.region else "us-east-1",
),
)
index = BankWithIndex(
bank=memory_bank,
index=PineconeIndex(client=client, index_name=memory_bank.identifier),
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBankDef]:
# TODO: right now the Llama Stack is the source of truth for these banks. That is
# not ideal. It should be Weaviate which is the source of truth. Unfortunately,
# list() happens at Stack startup when the Pinecone client (credentials) is not
# yet available. We need to figure out a way to make this work.
return [i.bank for i in self.cache.values()]
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
client = self._get_client()
if not self.check_if_index_exists(client, bank_id):
raise ValueError(f"Collection with name `{bank_id}` not found")
index = BankWithIndex(
bank=bank,
index=PineconeIndex(client=client, index_name=bank_id),
)
self.cache[bank_id] = index
return index
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)