From c2d74188ee768ef9c3040239dee2447f3f209353 Mon Sep 17 00:00:00 2001 From: Sarthak Deshpande Date: Wed, 23 Oct 2024 13:16:36 +0530 Subject: [PATCH] Added Pinecone Memory Adapter --- .../adapters/memory/pinecone/__init__.py | 14 ++ .../adapters/memory/pinecone/config.py | 17 ++ .../adapters/memory/pinecone/pinecone.py | 195 ++++++++++++++++++ 3 files changed, 226 insertions(+) create mode 100644 llama_stack/providers/adapters/memory/pinecone/__init__.py create mode 100644 llama_stack/providers/adapters/memory/pinecone/config.py create mode 100644 llama_stack/providers/adapters/memory/pinecone/pinecone.py diff --git a/llama_stack/providers/adapters/memory/pinecone/__init__.py b/llama_stack/providers/adapters/memory/pinecone/__init__.py new file mode 100644 index 000000000..d91442e14 --- /dev/null +++ b/llama_stack/providers/adapters/memory/pinecone/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import PineconeConfig, PineconeRequestProviderData # noqa: F401 +from .pinecone import PineconeMemoryAdapter + + +async def get_adapter_impl(config: PineconeConfig, _deps): + impl = PineconeMemoryAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/adapters/memory/pinecone/config.py b/llama_stack/providers/adapters/memory/pinecone/config.py new file mode 100644 index 000000000..8043e0a53 --- /dev/null +++ b/llama_stack/providers/adapters/memory/pinecone/config.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class PineconeRequestProviderData(BaseModel): + pinecone_api_key: str + + +class PineconeConfig(BaseModel): + dimensions: int + cloud: str + region: str diff --git a/llama_stack/providers/adapters/memory/pinecone/pinecone.py b/llama_stack/providers/adapters/memory/pinecone/pinecone.py new file mode 100644 index 000000000..0cade2b10 --- /dev/null +++ b/llama_stack/providers/adapters/memory/pinecone/pinecone.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json + +from numpy.typing import NDArray +from pinecone import ServerlessSpec +from pinecone.grpc import PineconeGRPC as Pinecone + +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.utils.memory.vector_store import ( + BankWithIndex, + EmbeddingIndex, +) +from .config import PineconeConfig, PineconeRequestProviderData + + +class PineconeIndex(EmbeddingIndex): + def __init__(self, client: Pinecone, index_name: str): + self.client = client + self.index_name = index_name + + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + assert len(chunks) == len( + embeddings + ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + + data_objects = [] + for i, chunk in enumerate(chunks): + data_objects.append( + { + "id": f"vec{i+1}", + "values": embeddings[i].tolist(), + "metadata": {"chunk": chunk}, + } + ) + + # Inserting chunks into a prespecified Weaviate collection + index = self.client.Index(self.index_name) + index.upsert(vectors=data_objects) + + async def query( + self, embedding: NDArray, k: int, score_threshold: float + ) -> QueryDocumentsResponse: + index = self.client.Index(self.index_name) + + results = index.query( + vector=embedding, top_k=k, include_values=True, include_metadata=True + ) + + chunks = [] + scores = [] + for doc in results["matches"]: + chunk_json = doc["metadata"]["chunk"] + try: + chunk_dict = json.loads(chunk_json) + chunk = Chunk(**chunk_dict) + except Exception: + import traceback + + traceback.print_exc() + print(f"Failed to parse document: {chunk_json}") + continue + + chunks.append(chunk) + scores.append(doc.score) + + return QueryDocumentsResponse(chunks=chunks, scores=scores) + + +class PineconeMemoryAdapter( + Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate +): + def __init__(self, config: PineconeConfig) -> None: + self.config = config + self.client_cache = {} + self.cache = {} + + def _get_client(self) -> Pinecone: + provider_data = self.get_request_provider_data() + assert provider_data is not None, "Request provider data must be set" + assert isinstance(provider_data, PineconeRequestProviderData) + + key = f"{provider_data.pinecone_api_key}" + if key in self.client_cache: + return self.client_cache[key] + + client = Pinecone(api_key=provider_data.pinecone_api_key) + self.client_cache[key] = client + return client + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + def check_if_index_exists( + self, + client: Pinecone, + index_name: str, + ) -> bool: + try: + # Get list of all indexes + active_indexes = client.list_indexes() + for index in active_indexes: + if index["name"] == index_name: + return True + return False + except Exception as e: + print(f"Error checking index: {e}") + return False + + async def register_memory_bank( + self, + memory_bank: MemoryBankDef, + ) -> None: + assert ( + memory_bank.type == MemoryBankType.vector.value + ), f"Only vector banks are supported {memory_bank.type}" + + client = self._get_client() + + # Create collection if it doesn't exist + if not self.check_if_index_exists(client, memory_bank.identifier): + client.create_index( + name=memory_bank.identifier, + dimension=self.config.dimensions if self.config.dimensions else 1024, + metric="cosine", + spec=ServerlessSpec( + cloud=self.config.cloud if self.config.cloud else "aws", + region=self.config.region if self.config.region else "us-east-1", + ), + ) + + index = BankWithIndex( + bank=memory_bank, + index=PineconeIndex(client=client, index_name=memory_bank.identifier), + ) + self.cache[memory_bank.identifier] = index + + async def list_memory_banks(self) -> List[MemoryBankDef]: + # TODO: right now the Llama Stack is the source of truth for these banks. That is + # not ideal. It should be Weaviate which is the source of truth. Unfortunately, + # list() happens at Stack startup when the Pinecone client (credentials) is not + # yet available. We need to figure out a way to make this work. + return [i.bank for i in self.cache.values()] + + async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: + if bank_id in self.cache: + return self.cache[bank_id] + + bank = await self.memory_bank_store.get_memory_bank(bank_id) + if not bank: + raise ValueError(f"Bank {bank_id} not found") + + client = self._get_client() + if not self.check_if_index_exists(client, bank_id): + raise ValueError(f"Collection with name `{bank_id}` not found") + + index = BankWithIndex( + bank=bank, + index=PineconeIndex(client=client, index_name=bank_id), + ) + self.cache[bank_id] = index + return index + + async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + index = await self._get_and_cache_bank_index(bank_id) + if not index: + raise ValueError(f"Bank {bank_id} not found") + + await index.insert_documents(documents) + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + index = await self._get_and_cache_bank_index(bank_id) + if not index: + raise ValueError(f"Bank {bank_id} not found") + + return await index.query_documents(query, params)