Added in registry and tests passed

This commit is contained in:
Sarthak Deshpande 2024-10-23 23:45:01 +05:30
parent c2d74188ee
commit 07e9da19b3
5 changed files with 42 additions and 25 deletions

View file

@ -12,6 +12,6 @@ class PineconeRequestProviderData(BaseModel):
class PineconeConfig(BaseModel): class PineconeConfig(BaseModel):
dimensions: int dimension: int = 384
cloud: str cloud: str = "aws"
region: str region: str = "us-east-1"

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json import time
from numpy.typing import NDArray from numpy.typing import NDArray
from pinecone import ServerlessSpec from pinecone import ServerlessSpec
@ -34,15 +34,20 @@ class PineconeIndex(EmbeddingIndex):
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
data_objects.append( data_objects.append(
{ {
"id": f"vec{i+1}", "id": chunk.document_id,
"values": embeddings[i].tolist(), "values": embeddings[i].tolist(),
"metadata": {"chunk": chunk}, "metadata": {
"content": chunk.content,
"token_count": chunk.token_count,
"document_id": chunk.document_id,
},
} }
) )
# Inserting chunks into a prespecified Weaviate collection # Inserting chunks into a prespecified Weaviate collection
index = self.client.Index(self.index_name) index = self.client.Index(self.index_name)
index.upsert(vectors=data_objects) index.upsert(vectors=data_objects)
time.sleep(1)
async def query( async def query(
self, embedding: NDArray, k: int, score_threshold: float self, embedding: NDArray, k: int, score_threshold: float
@ -50,16 +55,16 @@ class PineconeIndex(EmbeddingIndex):
index = self.client.Index(self.index_name) index = self.client.Index(self.index_name)
results = index.query( results = index.query(
vector=embedding, top_k=k, include_values=True, include_metadata=True vector=embedding, top_k=k, include_values=False, include_metadata=True
) )
chunks = [] chunks = []
scores = [] scores = []
for doc in results["matches"]: for doc in results["matches"]:
chunk_json = doc["metadata"]["chunk"] chunk_json = doc["metadata"]
print(f"chunk_json: {chunk_json}")
try: try:
chunk_dict = json.loads(chunk_json) chunk = Chunk(**chunk_json)
chunk = Chunk(**chunk_dict)
except Exception: except Exception:
import traceback import traceback
@ -130,11 +135,11 @@ class PineconeMemoryAdapter(
if not self.check_if_index_exists(client, memory_bank.identifier): if not self.check_if_index_exists(client, memory_bank.identifier):
client.create_index( client.create_index(
name=memory_bank.identifier, name=memory_bank.identifier,
dimension=self.config.dimensions if self.config.dimensions else 1024, dimension=self.config.dimension,
metric="cosine", metric="cosine",
spec=ServerlessSpec( spec=ServerlessSpec(
cloud=self.config.cloud if self.config.cloud else "aws", cloud=self.config.cloud,
region=self.config.region if self.config.region else "us-east-1", region=self.config.region,
), ),
) )
@ -146,7 +151,7 @@ class PineconeMemoryAdapter(
async def list_memory_banks(self) -> List[MemoryBankDef]: async def list_memory_banks(self) -> List[MemoryBankDef]:
# TODO: right now the Llama Stack is the source of truth for these banks. That is # TODO: right now the Llama Stack is the source of truth for these banks. That is
# not ideal. It should be Weaviate which is the source of truth. Unfortunately, # not ideal. It should be pinecone which is the source of truth. Unfortunately,
# list() happens at Stack startup when the Pinecone client (credentials) is not # list() happens at Stack startup when the Pinecone client (credentials) is not
# yet available. We need to figure out a way to make this work. # yet available. We need to figure out a way to make this work.
return [i.bank for i in self.cache.values()] return [i.bank for i in self.cache.values()]

View file

@ -84,4 +84,14 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.memory.qdrant.QdrantConfig", config_class="llama_stack.providers.adapters.memory.qdrant.QdrantConfig",
), ),
), ),
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_type="pinecone",
pip_packages=EMBEDDING_DEPS + ["pinecone"],
module="llama_stack.providers.adapters.memory.pinecone",
config_class="llama_stack.providers.adapters.memory.pinecone.PineconeConfig",
provider_data_validator="llama_stack.providers.adapters.memory.pinecone.PineconeRequestProviderData",
),
),
] ]

View file

@ -20,10 +20,12 @@ providers:
config: config:
host: localhost host: localhost
port: 6333 port: 6333
- provider_id: test-pinecone
provider_type: remote::pinecone
config: {}
# if a provider needs private keys from the client, they use the # if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py) # "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data. # this is a place to provide such data.
provider_data: provider_data:
"test-weaviate": "test-pinecone":
weaviate_api_key: 0xdeadbeefputrealapikeyhere pinecone_api_key:
weaviate_cluster_url: http://foobarbaz

View file

@ -69,7 +69,7 @@ def sample_documents():
async def register_memory_bank(banks_impl: MemoryBanks): async def register_memory_bank(banks_impl: MemoryBanks):
bank = VectorMemoryBankDef( bank = VectorMemoryBankDef(
identifier="test_bank", identifier="test-bank",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
@ -95,7 +95,7 @@ async def test_banks_register(memory_settings):
# but so far we don't have an unregister API unfortunately, so be careful # but so far we don't have an unregister API unfortunately, so be careful
banks_impl = memory_settings["memory_banks_impl"] banks_impl = memory_settings["memory_banks_impl"]
bank = VectorMemoryBankDef( bank = VectorMemoryBankDef(
identifier="test_bank_no_provider", identifier="test-bank-no-provider",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
@ -119,33 +119,33 @@ async def test_query_documents(memory_settings, sample_documents):
banks_impl = memory_settings["memory_banks_impl"] banks_impl = memory_settings["memory_banks_impl"]
with pytest.raises(ValueError): with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents) await memory_impl.insert_documents("test-bank", sample_documents)
await register_memory_bank(banks_impl) await register_memory_bank(banks_impl)
await memory_impl.insert_documents("test_bank", sample_documents) await memory_impl.insert_documents("test-bank", sample_documents)
query1 = "programming language" query1 = "programming language"
response1 = await memory_impl.query_documents("test_bank", query1) response1 = await memory_impl.query_documents("test-bank", query1)
assert_valid_response(response1) assert_valid_response(response1)
assert any("Python" in chunk.content for chunk in response1.chunks) assert any("Python" in chunk.content for chunk in response1.chunks)
# Test case 3: Query with semantic similarity # Test case 3: Query with semantic similarity
query3 = "AI and brain-inspired computing" query3 = "AI and brain-inspired computing"
response3 = await memory_impl.query_documents("test_bank", query3) response3 = await memory_impl.query_documents("test-bank", query3)
assert_valid_response(response3) assert_valid_response(response3)
assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks) assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)
# Test case 4: Query with limit on number of results # Test case 4: Query with limit on number of results
query4 = "computer" query4 = "computer"
params4 = {"max_chunks": 2} params4 = {"max_chunks": 2}
response4 = await memory_impl.query_documents("test_bank", query4, params4) response4 = await memory_impl.query_documents("test-bank", query4, params4)
assert_valid_response(response4) assert_valid_response(response4)
assert len(response4.chunks) <= 2 assert len(response4.chunks) <= 2
# Test case 5: Query with threshold on similarity score # Test case 5: Query with threshold on similarity score
query5 = "quantum computing" # Not directly related to any document query5 = "quantum computing" # Not directly related to any document
params5 = {"score_threshold": 0.2} params5 = {"score_threshold": 0.2}
response5 = await memory_impl.query_documents("test_bank", query5, params5) response5 = await memory_impl.query_documents("test-bank", query5, params5)
assert_valid_response(response5) assert_valid_response(response5)
print("The scores are:", response5.scores) print("The scores are:", response5.scores)
assert all(score >= 0.2 for score in response5.scores) assert all(score >= 0.2 for score in response5.scores)