user inference api to generate embeddings in vector store

This commit is contained in:
Dinesh Yeduguru 2024-12-09 12:49:35 -08:00
parent 96accc1216
commit 5bbeb985ca
12 changed files with 134 additions and 96 deletions

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import PGVectorConfig
async def get_adapter_impl(config: PGVectorConfig, _deps):
async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorMemoryAdapter
impl = PGVectorMemoryAdapter(config)
impl = PGVectorMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -16,11 +16,12 @@ from pydantic import BaseModel, parse_obj_as
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex,
InferenceEmbeddingMixin,
)
from .config import PGVectorConfig
@ -119,9 +120,12 @@ class PGVectorIndex(EmbeddingIndex):
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: PGVectorConfig) -> None:
class PGVectorMemoryAdapter(
InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
):
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.cursor = None
self.conn = None
self.cache = {}
@ -160,27 +164,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def shutdown(self) -> None:
pass
async def register_memory_bank(
self,
memory_bank: MemoryBank,
) -> None:
async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
upsert_models(
self.cursor,
[
(memory_bank.identifier, memory_bank),
],
upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
self.cache[memory_bank.identifier] = self._create_bank_with_index(
memory_bank, index
)
index = BankWithIndex(
bank=memory_bank,
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[memory_bank.identifier] = index
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
@ -189,9 +183,9 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks:
if bank.identifier not in self.cache:
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
index = self._create_bank_with_index(
bank,
PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
)
self.cache[bank.identifier] = index
return banks
@ -214,14 +208,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
self.inference_api = inference_api
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return index
index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
self.cache[bank_id] = self._create_bank_with_index(bank, index)
return self.cache[bank_id]