Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 23:29:43 +00:00)
Added in registry and tests passed

commit 07e9da19b3 (parent c2d74188ee)
5 changed files with 42 additions and 25 deletions
@@ -12,6 +12,6 @@ class PineconeRequestProviderData(BaseModel):
 
 
 class PineconeConfig(BaseModel):
-    dimensions: int
-    cloud: str
-    region: str
+    dimension: int = 384
+    cloud: str = "aws"
+    region: str = "us-east-1"
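The new defaults line up with the rest of the commit: dimension 384 matches the all-MiniLM-L6-v2 embedding model used in the tests, and cloud/region mirror the inline fallbacks removed from create_index further down. A minimal sketch of the resulting model (the pydantic import is implied by BaseModel; anything beyond the three fields shown in the diff is an assumption):

from pydantic import BaseModel

class PineconeConfig(BaseModel):
    dimension: int = 384      # all-MiniLM-L6-v2 produces 384-dim embeddings
    cloud: str = "aws"
    region: str = "us-east-1"

# With defaults on every field, the empty `config: {}` block in the run YAML
# below is enough to construct it:
assert PineconeConfig().dimension == 384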
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
+import time
 
 from numpy.typing import NDArray
 from pinecone import ServerlessSpec
@@ -34,15 +34,20 @@ class PineconeIndex(EmbeddingIndex):
         for i, chunk in enumerate(chunks):
             data_objects.append(
                 {
-                    "id": f"vec{i+1}",
+                    "id": chunk.document_id,
                     "values": embeddings[i].tolist(),
-                    "metadata": {"chunk": chunk},
+                    "metadata": {
+                        "content": chunk.content,
+                        "token_count": chunk.token_count,
+                        "document_id": chunk.document_id,
+                    },
                 }
             )
 
         # Inserting chunks into a prespecified Weaviate collection
         index = self.client.Index(self.index_name)
         index.upsert(vectors=data_objects)
+        time.sleep(1)
 
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
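Two judgment calls in this hunk are worth spelling out. Pinecone metadata values must be strings, numbers, booleans, or lists of strings, so a whole Chunk object cannot be stored under one key; flattening it into scalar fields makes it reconstructable on the query side. And because Pinecone upserts are eventually consistent, an immediate read can miss fresh vectors; the time.sleep(1) is a crude wait for that. (The "Weaviate collection" comment left as context is a copy-paste leftover from the Weaviate adapter.) A sketch of a less flaky alternative to the fixed sleep, polling the vector count; describe_index_stats is a real pinecone client call, while the helper and its parameters are ours:

import time

def wait_for_vectors(index, expected_count: int, timeout: float = 10.0) -> None:
    # Poll until the index reports at least `expected_count` vectors, instead
    # of sleeping a fixed second and hoping the upsert has landed.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        stats = index.describe_index_stats()
        if stats.total_vector_count >= expected_count:
            return
        time.sleep(0.5)
    raise TimeoutError("Pinecone index did not reach the expected vector count")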
@@ -50,16 +55,16 @@ class PineconeIndex(EmbeddingIndex):
         index = self.client.Index(self.index_name)
 
         results = index.query(
-            vector=embedding, top_k=k, include_values=True, include_metadata=True
+            vector=embedding, top_k=k, include_values=False, include_metadata=True
        )
 
         chunks = []
         scores = []
         for doc in results["matches"]:
-            chunk_json = doc["metadata"]["chunk"]
+            chunk_json = doc["metadata"]
+            print(f"chunk_json: {chunk_json}")
             try:
-                chunk_dict = json.loads(chunk_json)
-                chunk = Chunk(**chunk_dict)
+                chunk = Chunk(**chunk_json)
             except Exception:
                 import traceback
 
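With the flattened metadata the stored dict feeds Chunk directly, so the json.loads round trip goes away, and include_values=False stops Pinecone from shipping the raw embeddings back when only the metadata is used. The print(f"chunk_json: ...") added here reads like leftover debug logging and will be noisy per match. Under the new layout the parse loop reduces to something like this sketch; in the adapter the metadata dict is passed to Chunk(**metadata), and the explicit score filter here is ours, mirroring what the score_threshold test below asserts:

from typing import Any

def parse_matches(results: dict[str, Any], score_threshold: float):
    # Each match carries the flattened chunk fields in its metadata.
    chunks, scores = [], []
    for match in results["matches"]:
        if match["score"] < score_threshold:
            continue  # mirrors the score_threshold test below
        chunks.append(match["metadata"])  # adapter wraps this as Chunk(**metadata)
        scores.append(match["score"])
    return chunks, scores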
@@ -130,11 +135,11 @@ class PineconeMemoryAdapter(
         if not self.check_if_index_exists(client, memory_bank.identifier):
             client.create_index(
                 name=memory_bank.identifier,
-                dimension=self.config.dimensions if self.config.dimensions else 1024,
+                dimension=self.config.dimension,
                 metric="cosine",
                 spec=ServerlessSpec(
-                    cloud=self.config.cloud if self.config.cloud else "aws",
-                    region=self.config.region if self.config.region else "us-east-1",
+                    cloud=self.config.cloud,
+                    region=self.config.region,
                 ),
             )
 
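Moving the fallbacks into PineconeConfig defaults removes the inline `x if x else y` chains and fixes the dimension mismatch (the old fallback of 1024 did not match the 384-dim embedding model). For reference, a standalone sketch of the same call against the public pinecone client; the API key placeholder and index name are ours:

from pinecone import Pinecone, ServerlessSpec

pc = Pinecone(api_key="YOUR_API_KEY")  # placeholder key
if "test-bank" not in pc.list_indexes().names():
    pc.create_index(
        name="test-bank",
        dimension=384,    # must match the embedding model's output size
        metric="cosine",  # query scores are then cosine similarities
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )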
@@ -146,7 +151,7 @@ class PineconeMemoryAdapter(
 
     async def list_memory_banks(self) -> List[MemoryBankDef]:
         # TODO: right now the Llama Stack is the source of truth for these banks. That is
-        # not ideal. It should be Weaviate which is the source of truth. Unfortunately,
+        # not ideal. It should be pinecone which is the source of truth. Unfortunately,
         # list() happens at Stack startup when the Pinecone client (credentials) is not
         # yet available. We need to figure out a way to make this work.
         return [i.bank for i in self.cache.values()]
@@ -84,4 +84,14 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.memory.qdrant.QdrantConfig",
             ),
         ),
+        remote_provider_spec(
+            Api.memory,
+            AdapterSpec(
+                adapter_type="pinecone",
+                pip_packages=EMBEDDING_DEPS + ["pinecone"],
+                module="llama_stack.providers.adapters.memory.pinecone",
+                config_class="llama_stack.providers.adapters.memory.pinecone.PineconeConfig",
+                provider_data_validator="llama_stack.providers.adapters.memory.pinecone.PineconeRequestProviderData",
+            ),
+        ),
     ]
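This registry entry is what makes the adapter addressable from a run config: a remote adapter registered with adapter_type "pinecone" is selected as provider_type "remote::pinecone" (exactly what the YAML change below does), and pip_packages declares what a Stack build must install. A toy illustration of the naming convention; the helper is ours, not Stack code:

def remote_provider_type(adapter_type: str) -> str:
    # Remote adapters are addressed as "remote::<adapter_type>" in run configs.
    return f"remote::{adapter_type}"

assert remote_provider_type("pinecone") == "remote::pinecone"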
@@ -20,10 +20,12 @@ providers:
     config:
       host: localhost
       port: 6333
+  - provider_id: test-pinecone
+    provider_type: remote::pinecone
+    config: {}
 # if a provider needs private keys from the client, they use the
 # "get_request_provider_data" function (see distribution/request_headers.py)
 # this is a place to provide such data.
 provider_data:
-  "test-weaviate":
-    weaviate_api_key: 0xdeadbeefputrealapikeyhere
-    weaviate_cluster_url: http://foobarbaz
+  "test-pinecone":
+    pinecone_api_key:
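The empty config: {} works because every PineconeConfig field now has a default, and the API key deliberately stays out of the YAML: it arrives per request as provider data and is validated by PineconeRequestProviderData. A sketch of how a client could supply it; the X-LlamaStack-ProviderData header name is taken from distribution/request_headers.py at the time of this commit, and the endpoint path and port here are assumptions:

import json

import httpx

headers = {
    # JSON-encoded provider data, parsed server-side via PineconeRequestProviderData
    "X-LlamaStack-ProviderData": json.dumps({"pinecone_api_key": "YOUR_API_KEY"}),
}
resp = httpx.post(
    "http://localhost:5000/memory/query_documents",  # assumed endpoint path/port
    headers=headers,
    json={"bank_id": "test-bank", "query": "programming language"},
)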
@@ -69,7 +69,7 @@ def sample_documents():
 
 async def register_memory_bank(banks_impl: MemoryBanks):
     bank = VectorMemoryBankDef(
-        identifier="test_bank",
+        identifier="test-bank",
         embedding_model="all-MiniLM-L6-v2",
         chunk_size_in_tokens=512,
         overlap_size_in_tokens=64,
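The underscore-to-hyphen renames across these tests are not cosmetic: Pinecone index names only allow lowercase alphanumerics and hyphens, so an identifier like "test_bank" is rejected when the adapter creates the backing index. A quick local check mirroring that rule; the regex is our approximation of Pinecone's validation:

import re

# Lowercase alphanumerics and hyphens, starting and ending alphanumeric —
# our approximation of Pinecone's index-name rule.
VALID_INDEX_NAME = re.compile(r"^[a-z0-9]([a-z0-9-]*[a-z0-9])?$")

assert VALID_INDEX_NAME.match("test-bank")
assert not VALID_INDEX_NAME.match("test_bank")  # underscores are rejected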
@@ -95,7 +95,7 @@ async def test_banks_register(memory_settings):
     # but so far we don't have an unregister API unfortunately, so be careful
     banks_impl = memory_settings["memory_banks_impl"]
     bank = VectorMemoryBankDef(
-        identifier="test_bank_no_provider",
+        identifier="test-bank-no-provider",
         embedding_model="all-MiniLM-L6-v2",
         chunk_size_in_tokens=512,
         overlap_size_in_tokens=64,
@@ -119,33 +119,33 @@ async def test_query_documents(memory_settings, sample_documents):
     banks_impl = memory_settings["memory_banks_impl"]
 
     with pytest.raises(ValueError):
-        await memory_impl.insert_documents("test_bank", sample_documents)
+        await memory_impl.insert_documents("test-bank", sample_documents)
 
     await register_memory_bank(banks_impl)
-    await memory_impl.insert_documents("test_bank", sample_documents)
+    await memory_impl.insert_documents("test-bank", sample_documents)
 
     query1 = "programming language"
-    response1 = await memory_impl.query_documents("test_bank", query1)
+    response1 = await memory_impl.query_documents("test-bank", query1)
     assert_valid_response(response1)
     assert any("Python" in chunk.content for chunk in response1.chunks)
 
     # Test case 3: Query with semantic similarity
     query3 = "AI and brain-inspired computing"
-    response3 = await memory_impl.query_documents("test_bank", query3)
+    response3 = await memory_impl.query_documents("test-bank", query3)
     assert_valid_response(response3)
     assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)
 
     # Test case 4: Query with limit on number of results
     query4 = "computer"
     params4 = {"max_chunks": 2}
-    response4 = await memory_impl.query_documents("test_bank", query4, params4)
+    response4 = await memory_impl.query_documents("test-bank", query4, params4)
     assert_valid_response(response4)
     assert len(response4.chunks) <= 2
 
     # Test case 5: Query with threshold on similarity score
     query5 = "quantum computing"  # Not directly related to any document
     params5 = {"score_threshold": 0.2}
-    response5 = await memory_impl.query_documents("test_bank", query5, params5)
+    response5 = await memory_impl.query_documents("test-bank", query5, params5)
     assert_valid_response(response5)
     print("The scores are:", response5.scores)
     assert all(score >= 0.2 for score in response5.scores)