forked from phoenix-oss/llama-stack-mirror
Make embedding generation go through inference (#606)
This PR does the following: 1) adds the ability to generate embeddings in all supported inference providers. 2) Moves all the memory providers to use the inference API and improved the memory tests to setup the inference stack correctly and use the embedding models This is a merge from #589 and #598
This commit is contained in:
parent
a14785af46
commit
96e158eaac
37 changed files with 677 additions and 156 deletions
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
import json
|
||||
|
||||
from botocore.client import BaseClient
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
@ -19,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
|
||||
|
||||
|
||||
model_aliases = [
|
||||
|
@ -448,4 +451,21 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embeddings = []
|
||||
for content in contents:
|
||||
assert not content_has_media(
|
||||
content
|
||||
), "Bedrock does not support media for embeddings"
|
||||
input_text = interleaved_text_media_as_str(content)
|
||||
input_body = {"inputText": input_text}
|
||||
body = json.dumps(input_body)
|
||||
response = self.client.invoke_model(
|
||||
body=body,
|
||||
modelId=model.provider_resource_id,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
embeddings.append(response_body.get("embedding"))
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
|
|
@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
|
|||
@json_schema_type
|
||||
class FireworksImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default="https://api.fireworks.ai/inference",
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
|
@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference",
|
||||
"url": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": "${env.FIREWORKS_API_KEY}",
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_message_to_dict,
|
||||
request_has_media,
|
||||
)
|
||||
|
@ -89,17 +90,19 @@ class FireworksInferenceAdapter(
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = None
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key is not None:
|
||||
fireworks_api_key = self.config.api_key
|
||||
return self.config.api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
raise ValueError(
|
||||
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
|
||||
)
|
||||
fireworks_api_key = provider_data.fireworks_api_key
|
||||
return provider_data.fireworks_api_key
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = self._get_api_key()
|
||||
return Fireworks(api_key=fireworks_api_key)
|
||||
|
||||
async def completion(
|
||||
|
@ -264,4 +267,19 @@ class FireworksInferenceAdapter(
|
|||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
if model.metadata.get("embedding_dimensions"):
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
|
||||
assert all(
|
||||
not content_has_media(content) for content in contents
|
||||
), "Fireworks does not support media for embeddings"
|
||||
response = self._get_client().embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_text_media_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
|
|
@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_image_media_to_url,
|
||||
request_has_media,
|
||||
)
|
||||
|
@ -321,9 +322,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
assert all(
|
||||
not content_has_media(content) for content in contents
|
||||
), "Ollama does not support media for embeddings"
|
||||
response = await self.client.embed(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_text_media_as_str(content) for content in contents],
|
||||
)
|
||||
embeddings = response["embeddings"]
|
||||
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
# ollama does not have embedding models running. Check if the model is in list of available models.
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
response = await self.client.list()
|
||||
available_models = [m["model"] for m in response["models"]]
|
||||
if model.provider_resource_id not in available_models:
|
||||
raise ValueError(
|
||||
f"Model '{model.provider_resource_id}' is not available in Ollama. "
|
||||
f"Available models: {', '.join(available_models)}"
|
||||
)
|
||||
return model
|
||||
model = await self.register_helper.register_model(model)
|
||||
models = await self.client.ps()
|
||||
available_models = [m["model"] for m in models["models"]]
|
||||
|
|
|
@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_message_to_dict,
|
||||
request_has_media,
|
||||
)
|
||||
|
@ -253,4 +254,13 @@ class TogetherInferenceAdapter(
|
|||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
assert all(
|
||||
not content_has_media(content) for content in contents
|
||||
), "Together does not support media for embeddings"
|
||||
r = self._get_client().embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_text_media_as_str(content) for content in contents],
|
||||
)
|
||||
embeddings = [item.embedding for item in r.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
|
|
@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_message_to_dict,
|
||||
request_has_media,
|
||||
)
|
||||
|
@ -203,4 +204,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
assert model.model_type == ModelType.embedding_model
|
||||
assert model.metadata.get("embedding_dimensions")
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
|
||||
assert all(
|
||||
not content_has_media(content) for content in contents
|
||||
), "VLLM does not support media for embeddings"
|
||||
response = self.client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_text_media_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
|
|
@ -4,12 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import ChromaRemoteImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps):
|
||||
async def get_adapter_impl(
|
||||
config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]
|
||||
):
|
||||
from .chroma import ChromaMemoryAdapter
|
||||
|
||||
impl = ChromaMemoryAdapter(config)
|
||||
impl = ChromaMemoryAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -13,8 +13,7 @@ import chromadb
|
|||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
|
@ -87,10 +86,14 @@ class ChromaIndex(EmbeddingIndex):
|
|||
|
||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(
|
||||
self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
|
||||
self,
|
||||
config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
|
||||
inference_api: Api.inference,
|
||||
) -> None:
|
||||
log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
|
||||
|
@ -127,10 +130,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
metadata={"bank": memory_bank.model_dump_json()},
|
||||
)
|
||||
)
|
||||
bank_index = BankWithIndex(
|
||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||
memory_bank, ChromaIndex(self.client, collection), self.inference_api
|
||||
)
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
await self.cache[memory_bank_id].index.delete()
|
||||
|
@ -166,6 +168,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
collection = await maybe_await(self.client.get_collection(bank_id))
|
||||
if not collection:
|
||||
raise ValueError(f"Bank {bank_id} not found in Chroma")
|
||||
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
|
||||
index = BankWithIndex(
|
||||
bank, ChromaIndex(self.client, collection), self.inference_api
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
|
|
@ -4,12 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import PGVectorConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: PGVectorConfig, _deps):
|
||||
async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from .pgvector import PGVectorMemoryAdapter
|
||||
|
||||
impl = PGVectorMemoryAdapter(config)
|
||||
impl = PGVectorMemoryAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as
|
|||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
|
@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: PGVectorConfig) -> None:
|
||||
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cursor = None
|
||||
self.conn = None
|
||||
self.cache = {}
|
||||
|
@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
|
||||
assert (
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
upsert_models(
|
||||
self.cursor,
|
||||
[
|
||||
(memory_bank.identifier, memory_bank),
|
||||
],
|
||||
upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
|
||||
index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
|
||||
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||
memory_bank, index, self.inference_api
|
||||
)
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
await self.cache[memory_bank_id].index.delete()
|
||||
del self.cache[memory_bank_id]
|
||||
|
@ -203,14 +194,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
return await index.query_documents(query, params)
|
||||
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
|
||||
self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
|
||||
return self.cache[bank_id]
|
||||
|
|
|
@ -4,12 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import QdrantConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: QdrantConfig, _deps):
|
||||
async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from .qdrant import QdrantVectorMemoryAdapter
|
||||
|
||||
impl = QdrantVectorMemoryAdapter(config)
|
||||
impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: QdrantConfig) -> None:
|
||||
def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
|
||||
self.config = config
|
||||
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
|
||||
self.cache = {}
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=QdrantIndex(self.client, memory_bank.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
@ -138,6 +140,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=QdrantIndex(client=self.client, collection_name=bank_id),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
|
|
@ -4,12 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WeaviateConfig, _deps):
|
||||
async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from .weaviate import WeaviateMemoryAdapter
|
||||
|
||||
impl = WeaviateMemoryAdapter(config)
|
||||
impl = WeaviateMemoryAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -12,10 +12,11 @@ import weaviate
|
|||
import weaviate.classes as wvc
|
||||
from numpy.typing import NDArray
|
||||
from weaviate.classes.init import Auth
|
||||
from weaviate.classes.query import Filter
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
|
@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self, chunk_ids: List[str]) -> None:
|
||||
collection = self.client.collections.get(self.collection_name)
|
||||
collection.data.delete_many(
|
||||
where=Filter.by_property("id").contains_any(chunk_ids)
|
||||
)
|
||||
|
||||
|
||||
class WeaviateMemoryAdapter(
|
||||
Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
|
||||
Memory,
|
||||
NeedsRequestProviderData,
|
||||
MemoryBanksProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: WeaviateConfig) -> None:
|
||||
def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
|
||||
|
@ -117,7 +127,7 @@ class WeaviateMemoryAdapter(
|
|||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
client = self._get_client()
|
||||
|
@ -135,11 +145,11 @@ class WeaviateMemoryAdapter(
|
|||
],
|
||||
)
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
|
||||
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||
memory_bank,
|
||||
WeaviateIndex(client=client, collection_name=memory_bank.identifier),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
if bank_id in self.cache:
|
||||
|
@ -156,6 +166,7 @@ class WeaviateMemoryAdapter(
|
|||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=WeaviateIndex(client=client, collection_name=bank_id),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue