diff --git a/llama_stack/providers/adapters/memory/pinecone/config.py b/llama_stack/providers/adapters/memory/pinecone/config.py
index 8043e0a53..8e66eefec 100644
--- a/llama_stack/providers/adapters/memory/pinecone/config.py
+++ b/llama_stack/providers/adapters/memory/pinecone/config.py
@@ -12,6 +12,6 @@ class PineconeRequestProviderData(BaseModel):
 
 
 class PineconeConfig(BaseModel):
-    dimensions: int
-    cloud: str
-    region: str
+    dimension: int = 384
+    cloud: str = "aws"
+    region: str = "us-east-1"
diff --git a/llama_stack/providers/adapters/memory/pinecone/pinecone.py b/llama_stack/providers/adapters/memory/pinecone/pinecone.py
index 0cade2b10..acc2a8a9d 100644
--- a/llama_stack/providers/adapters/memory/pinecone/pinecone.py
+++ b/llama_stack/providers/adapters/memory/pinecone/pinecone.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
+import asyncio
 
 from numpy.typing import NDArray
 from pinecone import ServerlessSpec
@@ -34,15 +34,20 @@ class PineconeIndex(EmbeddingIndex):
         for i, chunk in enumerate(chunks):
             data_objects.append(
                 {
-                    "id": f"vec{i+1}",
+                    "id": f"{chunk.document_id}:chunk-{i}",  # unique per chunk; document_id alone collides
                     "values": embeddings[i].tolist(),
-                    "metadata": {"chunk": chunk},
+                    "metadata": {
+                        "content": chunk.content,
+                        "token_count": chunk.token_count,
+                        "document_id": chunk.document_id,
+                    },
                 }
             )
 
-        # Inserting chunks into a prespecified Weaviate collection
+        # Insert chunks into the prespecified Pinecone index
         index = self.client.Index(self.index_name)
         index.upsert(vectors=data_objects)
+        await asyncio.sleep(1)  # upserts are eventually consistent; let them become queryable
 
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
@@ -50,16 +55,15 @@ class PineconeIndex(EmbeddingIndex):
         index = self.client.Index(self.index_name)
 
         results = index.query(
-            vector=embedding, top_k=k, include_values=True, include_metadata=True
+            vector=embedding, top_k=k, include_values=False, include_metadata=True
         )
 
         chunks = []
         scores = []
         for doc in results["matches"]:
-            chunk_json = doc["metadata"]["chunk"]
+            chunk_json = doc["metadata"]  # chunk fields stored as metadata by add_chunks
             try:
-                chunk_dict = json.loads(chunk_json)
-                chunk = Chunk(**chunk_dict)
+                chunk = Chunk(**chunk_json)
             except Exception:
                 import traceback
 
@@ -130,11 +134,11 @@ class PineconeMemoryAdapter(
         if not self.check_if_index_exists(client, memory_bank.identifier):
             client.create_index(
                 name=memory_bank.identifier,
-                dimension=self.config.dimensions if self.config.dimensions else 1024,
+                dimension=self.config.dimension,
                 metric="cosine",
                 spec=ServerlessSpec(
-                    cloud=self.config.cloud if self.config.cloud else "aws",
-                    region=self.config.region if self.config.region else "us-east-1",
+                    cloud=self.config.cloud,
+                    region=self.config.region,
                 ),
             )
 
@@ -146,7 +150,7 @@ class PineconeMemoryAdapter(
 
     async def list_memory_banks(self) -> List[MemoryBankDef]:
         # TODO: right now the Llama Stack is the source of truth for these banks. That is
-        # not ideal. It should be Weaviate which is the source of truth. Unfortunately,
+        # not ideal. It should be Pinecone which is the source of truth. Unfortunately,
         # list() happens at Stack startup when the Pinecone client (credentials) is not
         # yet available. We need to figure out a way to make this work.
         return [i.bank for i in self.cache.values()]
diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py
index a0fbf1636..62b07e9ca 100644
--- a/llama_stack/providers/registry/memory.py
+++ b/llama_stack/providers/registry/memory.py
@@ -84,4 +84,14 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.memory.qdrant.QdrantConfig",
             ),
         ),
+        remote_provider_spec(
+            Api.memory,
+            AdapterSpec(
+                adapter_type="pinecone",
+                pip_packages=EMBEDDING_DEPS + ["pinecone"],
+                module="llama_stack.providers.adapters.memory.pinecone",
+                config_class="llama_stack.providers.adapters.memory.pinecone.PineconeConfig",
+                provider_data_validator="llama_stack.providers.adapters.memory.pinecone.PineconeRequestProviderData",
+            ),
+        ),
     ]
diff --git a/llama_stack/providers/tests/memory/provider_config_example.yaml b/llama_stack/providers/tests/memory/provider_config_example.yaml
index 13575a598..da226d694 100644
--- a/llama_stack/providers/tests/memory/provider_config_example.yaml
+++ b/llama_stack/providers/tests/memory/provider_config_example.yaml
@@ -20,10 +20,12 @@ providers:
     config:
       host: localhost
       port: 6333
+  - provider_id: test-pinecone
+    provider_type: remote::pinecone
+    config: {}
 # if a provider needs private keys from the client, they use the
 # "get_request_provider_data" function (see distribution/request_headers.py)
 # this is a place to provide such data.
 provider_data:
-  "test-weaviate":
-    weaviate_api_key: 0xdeadbeefputrealapikeyhere
-    weaviate_cluster_url: http://foobarbaz
+  "test-pinecone":
+    pinecone_api_key:
diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py
index b26bf75a7..7043772db 100644
--- a/llama_stack/providers/tests/memory/test_memory.py
+++ b/llama_stack/providers/tests/memory/test_memory.py
@@ -69,7 +69,7 @@ def sample_documents():
 
 async def register_memory_bank(banks_impl: MemoryBanks):
     bank = VectorMemoryBankDef(
-        identifier="test_bank",
+        identifier="test-bank",
         embedding_model="all-MiniLM-L6-v2",
         chunk_size_in_tokens=512,
         overlap_size_in_tokens=64,
@@ -95,7 +95,7 @@ async def test_banks_register(memory_settings):
     # but so far we don't have an unregister API unfortunately, so be careful
     banks_impl = memory_settings["memory_banks_impl"]
     bank = VectorMemoryBankDef(
-        identifier="test_bank_no_provider",
+        identifier="test-bank-no-provider",
         embedding_model="all-MiniLM-L6-v2",
         chunk_size_in_tokens=512,
         overlap_size_in_tokens=64,
@@ -119,33 +119,33 @@ async def test_query_documents(memory_settings, sample_documents):
     banks_impl = memory_settings["memory_banks_impl"]
 
     with pytest.raises(ValueError):
-        await memory_impl.insert_documents("test_bank", sample_documents)
+        await memory_impl.insert_documents("test-bank", sample_documents)
 
     await register_memory_bank(banks_impl)
-    await memory_impl.insert_documents("test_bank", sample_documents)
+    await memory_impl.insert_documents("test-bank", sample_documents)
 
     query1 = "programming language"
-    response1 = await memory_impl.query_documents("test_bank", query1)
+    response1 = await memory_impl.query_documents("test-bank", query1)
     assert_valid_response(response1)
     assert any("Python" in chunk.content for chunk in response1.chunks)
 
     # Test case 3: Query with semantic similarity
     query3 = "AI and brain-inspired computing"
-    response3 = await memory_impl.query_documents("test_bank", query3)
+    response3 = await memory_impl.query_documents("test-bank", query3)
     assert_valid_response(response3)
     assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)
 
     # Test case 4: Query with limit on number of results
     query4 = "computer"
     params4 = {"max_chunks": 2}
-    response4 = await memory_impl.query_documents("test_bank", query4, params4)
+    response4 = await memory_impl.query_documents("test-bank", query4, params4)
     assert_valid_response(response4)
     assert len(response4.chunks) <= 2
 
     # Test case 5: Query with threshold on similarity score
     query5 = "quantum computing"  # Not directly related to any document
     params5 = {"score_threshold": 0.2}
-    response5 = await memory_impl.query_documents("test_bank", query5, params5)
+    response5 = await memory_impl.query_documents("test-bank", query5, params5)
     assert_valid_response(response5)
     print("The scores are:", response5.scores)
    assert all(score >= 0.2 for score in response5.scores)