diff --git a/llama_stack/apis/memory/client.py b/llama_stack/apis/memory/client.py index 87fec6add..cb9d7463d 100644 --- a/llama_stack/apis/memory/client.py +++ b/llama_stack/apis/memory/client.py @@ -44,7 +44,8 @@ class MemoryClient(Memory): }, headers={ "Content-Type": "application/json", - "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}), + "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234", + "weaviate_cluster_url": "http://localhost:8080"}), }, timeout=20, ) @@ -70,7 +71,8 @@ class MemoryClient(Memory): }, headers={ "Content-Type": "application/json", - "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}), + "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234", + "weaviate_cluster_url": "http://localhost:8080"}), }, timeout=20, ) @@ -94,7 +96,8 @@ class MemoryClient(Memory): }, headers={ "Content-Type": "application/json", - "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}), + "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234", + "weaviate_cluster_url": "http://localhost:8080"}), }, timeout=20, ) @@ -116,7 +119,8 @@ class MemoryClient(Memory): }, headers={ "Content-Type": "application/json", - "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}), + "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234", + "weaviate_cluster_url": "http://localhost:8080"}), }, timeout=20, ) diff --git a/llama_stack/providers/adapters/memory/weaviate/config.py b/llama_stack/providers/adapters/memory/weaviate/config.py index b869fd544..db73604d2 100644 --- a/llama_stack/providers/adapters/memory/weaviate/config.py +++ b/llama_stack/providers/adapters/memory/weaviate/config.py @@ -11,8 +11,8 @@ class WeaviateRequestProviderData(BaseModel): # if there _is_ provider data, it must specify the API KEY # if you want it to be optional, use Optional[str] weaviate_api_key: str + weaviate_cluster_url: str @json_schema_type class WeaviateConfig(BaseModel): - url: str = Field(default="http://localhost:8080") collection: str = Field(default="MemoryBank") diff --git a/llama_stack/providers/adapters/memory/weaviate/weaviate.py b/llama_stack/providers/adapters/memory/weaviate/weaviate.py index f53851898..275722be2 100644 --- a/llama_stack/providers/adapters/memory/weaviate/weaviate.py +++ b/llama_stack/providers/adapters/memory/weaviate/weaviate.py @@ -1,7 +1,6 @@ import json import uuid from typing import List, Optional, Dict, Any -from urllib.parse import urlparse from numpy.typing import NDArray import weaviate @@ -72,7 +71,6 @@ class WeaviateIndex(EmbeddingIndex): class WeaviateMemoryAdapter(Memory): def __init__(self, config: WeaviateConfig) -> None: - print(f"Initializing WeaviateMemoryAdapter with URL: {config.url}") self.config = config self.client = None self.cache = {} @@ -85,9 +83,10 @@ class WeaviateMemoryAdapter(Memory): assert isinstance(request_provider_data, WeaviateRequestProviderData) print(f"WEAVIATE API KEY: {request_provider_data.weaviate_api_key}") + print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}") # Connect to Weaviate Cloud self.client = weaviate.connect_to_weaviate_cloud( - cluster_url = self.config.url, + cluster_url = request_provider_data.weaviate_cluster_url, auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key), )