fix weaviate, update run.yamls

Ashwin Bharambe 2024-10-09 22:15:28 -07:00
parent 238e658cdf
commit 8a175129fc
4 changed files with 7 additions and 49 deletions


@@ -40,19 +40,3 @@ providers:
   - provider_id: meta-reference
     provider_type: meta-reference
     config: {}
-models:
-- identifier: Llama3.1-8B-Instruct
-  llama_model: Llama3.1-8B-Instruct
-  provider_id: remote::ollama
-shields:
-- identifier: llama_guard
-  type: llama_guard
-  provider_id: meta-reference
-  params: {}
-memory_banks:
-- identifier: vector
-  provider_id: meta-reference
-  type: vector
-  embedding_model: all-MiniLM-L6-v2
-  chunk_size_in_tokens: 512
-  overlap_size_in_tokens: null
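For context: the block removed above (and the matching blocks in the other run.yaml files below) statically declared a model, a shield, and a vector memory bank at startup. A minimal sketch of what one of the removed memory-bank entries encodes, as a hypothetical dataclass whose field names come straight from the YAML (the class name and defaults are assumptions, not the actual llama_stack datatype):

from dataclasses import dataclass
from typing import Optional

@dataclass
class VectorMemoryBankConfig:
    # fields mirror the removed YAML entry; names taken from the diff above
    identifier: str                      # "vector"
    provider_id: str                     # "meta-reference" in these files
    embedding_model: str                 # "all-MiniLM-L6-v2"
    chunk_size_in_tokens: int = 512      # documents are split into 512-token chunks
    overlap_size_in_tokens: Optional[int] = None  # null in the YAML: no chunk overlap

bank = VectorMemoryBankConfig(
    identifier="vector",
    provider_id="meta-reference",
    embedding_model="all-MiniLM-L6-v2",
)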


@@ -43,19 +43,3 @@ providers:
   - provider_id: meta-reference
     provider_type: meta-reference
     config: {}
-models:
-- identifier: Llama3.1-8B-Instruct
-  llama_model: Llama3.1-8B-Instruct
-  provider_id: meta-reference
-shields:
-- identifier: llama_guard
-  type: llama_guard
-  provider_id: meta-reference
-  params: {}
-memory_banks:
-- identifier: vector
-  provider_id: meta-reference
-  type: vector
-  embedding_model: all-MiniLM-L6-v2
-  chunk_size_in_tokens: 512
-  overlap_size_in_tokens: null


@@ -14,6 +14,7 @@ from weaviate.classes.init import Auth

 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -78,7 +79,9 @@ class WeaviateIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
+class WeaviateMemoryAdapter(
+    Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
+):
     def __init__(self, config: WeaviateConfig) -> None:
         self.config = config
         self.client_cache = {}
@@ -136,6 +139,9 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
         )
         self.cache[memory_bank.identifier] = index

+    async def list_memory_banks(self) -> List[MemoryBankDef]:
+        return [i.bank for i in self.cache.values()]
+
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
             return self.cache[bank_id]
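The substance of the Weaviate fix: the adapter now participates in dynamic memory-bank registration instead of relying on the static run.yaml entries removed above. It declares MemoryBanksProtocolPrivate and can enumerate the banks it has registered from its cache. A self-contained sketch of that pattern, assuming the protocol exposes roughly these two methods (MemoryBankDef and BankWithIndex here are simplified stand-ins for the real llama_stack types):

from typing import Any, Dict, List, Protocol

class MemoryBankDef:
    # stand-in for the llama_stack memory-bank datatype
    def __init__(self, identifier: str) -> None:
        self.identifier = identifier

class BankWithIndex:
    # pairs a bank definition with its embedding index, as in the diff
    def __init__(self, bank: MemoryBankDef, index: Any) -> None:
        self.bank = bank
        self.index = index

class MemoryBanksProtocolPrivate(Protocol):
    # assumed shape of the protocol; only list_memory_banks appears in the diff
    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
    async def list_memory_banks(self) -> List[MemoryBankDef]: ...

class SketchAdapter:
    def __init__(self) -> None:
        self.cache: Dict[str, BankWithIndex] = {}

    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
        # mirrors the adapter: build an index for the bank, cache it by identifier
        self.cache[memory_bank.identifier] = BankWithIndex(memory_bank, index=None)

    async def list_memory_banks(self) -> List[MemoryBankDef]:
        # exactly the shape of the method added in the diff
        return [i.bank for i in self.cache.values()]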


@@ -48,19 +48,3 @@ providers:
   - provider_id: meta-reference
     provider_type: meta-reference
     config: {}
-models:
-- identifier: Llama3.1-8B-Instruct
-  llama_model: Llama3.1-8B-Instruct
-  provider_id: meta-reference
-shields:
-- identifier: llama_guard
-  type: llama_guard
-  provider_id: meta-reference
-  params: {}
-memory_banks:
-- identifier: vector
-  provider_id: meta-reference
-  type: vector
-  embedding_model: all-MiniLM-L6-v2
-  chunk_size_in_tokens: 512
-  overlap_size_in_tokens: null