user inference api to generate embeddings in vector store

This commit is contained in:
Dinesh Yeduguru 2024-12-09 12:49:35 -08:00
parent 96accc1216
commit 5bbeb985ca
12 changed files with 134 additions and 96 deletions

View file

@ -4,16 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import FaissImplConfig from .config import FaissImplConfig
async def get_provider_impl(config: FaissImplConfig, _deps): async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissMemoryImpl from .faiss import FaissMemoryImpl
assert isinstance( assert isinstance(
config, FaissImplConfig config, FaissImplConfig
), f"Unexpected config type: {type(config)}" ), f"Unexpected config type: {type(config)}"
impl = FaissMemoryImpl(config) impl = FaissMemoryImpl(config, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -19,13 +19,12 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
InferenceEmbeddingMixin,
) )
from .config import FaissImplConfig from .config import FaissImplConfig
@ -95,6 +94,15 @@ class FaissIndex(EmbeddingIndex):
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}") await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
# Add dimension check
embedding_dim = (
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
)
if embedding_dim != self.index.d:
raise ValueError(
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
)
indexlen = len(self.id_by_index) indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk self.chunk_by_index[indexlen + i] = chunk
@ -123,9 +131,10 @@ class FaissIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): class FaissMemoryImpl(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: FaissImplConfig) -> None: def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
self.config = config self.config = config
self.inference_api = inference_api
self.cache = {} self.cache = {}
self.kvstore = None self.kvstore = None
@ -138,10 +147,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
for bank_data in stored_banks: for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data) bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex( index = self._create_bank_with_index(
bank=bank, bank,
index=await FaissIndex.create( await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier bank.embedding_dimension, self.kvstore, bank.identifier
), ),
) )
self.cache[bank.identifier] = index self.cache[bank.identifier] = index
@ -166,13 +175,12 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
) )
# Store in cache # Store in cache
index = BankWithIndex( self.cache[memory_bank.identifier] = self._create_bank_with_index(
bank=memory_bank, memory_bank,
index=await FaissIndex.create( await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
), ),
) )
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]: async def list_memory_banks(self) -> List[MemoryBank]:
return [i.bank for i in self.cache.values()] return [i.bank for i in self.cache.values()]

View file

@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.memory.faiss", module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.", deprecation_warning="Please use the `inline::faiss` provider instead.",
api_dependencies=[Api.inference],
), ),
InlineProviderSpec( InlineProviderSpec(
api=Api.memory, api=Api.memory,
@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.memory.faiss", module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
api_dependencies=[Api.inference],
), ),
remote_provider_spec( remote_provider_spec(
Api.memory, Api.memory,
@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.chroma", module="llama_stack.providers.remote.memory.chroma",
config_class="llama_stack.distribution.datatypes.RemoteProviderConfig", config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
), ),
api_dependencies=[Api.inference],
), ),
remote_provider_spec( remote_provider_spec(
Api.memory, Api.memory,
@ -64,6 +67,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.pgvector", module="llama_stack.providers.remote.memory.pgvector",
config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig", config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
), ),
api_dependencies=[Api.inference],
), ),
remote_provider_spec( remote_provider_spec(
Api.memory, Api.memory,
@ -74,6 +78,7 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig", config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData", provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
), ),
api_dependencies=[Api.inference],
), ),
remote_provider_spec( remote_provider_spec(
api=Api.memory, api=Api.memory,
@ -83,6 +88,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.sample", module="llama_stack.providers.remote.memory.sample",
config_class="llama_stack.providers.remote.memory.sample.SampleConfig", config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
), ),
api_dependencies=[],
), ),
remote_provider_spec( remote_provider_spec(
Api.memory, Api.memory,
@ -92,5 +98,6 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.qdrant", module="llama_stack.providers.remote.memory.qdrant",
config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig", config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
), ),
api_dependencies=[Api.inference],
), ),
] ]

View file

@ -4,12 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import RemoteProviderConfig from llama_stack.distribution.datatypes import RemoteProviderConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
async def get_adapter_impl(config: RemoteProviderConfig, _deps): async def get_adapter_impl(config: RemoteProviderConfig, deps: Dict[Api, ProviderSpec]):
from .chroma import ChromaMemoryAdapter from .chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config.url) impl = ChromaMemoryAdapter(config.url, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -20,6 +20,7 @@ from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
InferenceEmbeddingMixin,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -71,8 +72,8 @@ class ChromaIndex(EmbeddingIndex):
await self.client.delete_collection(self.collection.name) await self.client.delete_collection(self.collection.name)
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
def __init__(self, url: str) -> None: def __init__(self, url: str, inference_api: Api.inference) -> None:
log.info(f"Initializing ChromaMemoryAdapter with url: {url}") log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/") url = url.rstrip("/")
parsed = urlparse(url) parsed = urlparse(url)
@ -82,6 +83,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
self.host = parsed.hostname self.host = parsed.hostname
self.port = parsed.port self.port = parsed.port
self.inference_api = inference_api
self.client = None self.client = None
self.cache = {} self.cache = {}
@ -109,10 +111,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
name=memory_bank.identifier, name=memory_bank.identifier,
metadata={"bank": memory_bank.model_dump_json()}, metadata={"bank": memory_bank.model_dump_json()},
) )
bank_index = BankWithIndex( self.cache[memory_bank.identifier] = self._create_bank_with_index(
bank=memory_bank, index=ChromaIndex(self.client, collection) memory_bank, ChromaIndex(self.client, collection)
) )
self.cache[memory_bank.identifier] = bank_index
async def list_memory_banks(self) -> List[MemoryBank]: async def list_memory_banks(self) -> List[MemoryBank]:
collections = await self.client.list_collections() collections = await self.client.list_collections()
@ -124,11 +125,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
log.exception(f"Failed to parse bank: {collection.metadata}") log.exception(f"Failed to parse bank: {collection.metadata}")
continue continue
index = BankWithIndex( self.cache[bank.identifier] = self._create_bank_with_index(
bank=bank, bank,
index=ChromaIndex(self.client, collection), ChromaIndex(self.client, collection),
) )
self.cache[bank.identifier] = index
return [i.bank for i in self.cache.values()] return [i.bank for i in self.cache.values()]
@ -166,6 +166,6 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
collection = await self.client.get_collection(bank_id) collection = await self.client.get_collection(bank_id)
if not collection: if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma") raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection)) index = self._create_bank_with_index(bank, ChromaIndex(self.client, collection))
self.cache[bank_id] = index self.cache[bank_id] = index
return index return index

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import PGVectorConfig from .config import PGVectorConfig
async def get_adapter_impl(config: PGVectorConfig, _deps): async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorMemoryAdapter from .pgvector import PGVectorMemoryAdapter
impl = PGVectorMemoryAdapter(config) impl = PGVectorMemoryAdapter(config, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -16,11 +16,12 @@ from pydantic import BaseModel, parse_obj_as
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
InferenceEmbeddingMixin,
) )
from .config import PGVectorConfig from .config import PGVectorConfig
@ -119,9 +120,12 @@ class PGVectorIndex(EmbeddingIndex):
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}") self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): class PGVectorMemoryAdapter(
def __init__(self, config: PGVectorConfig) -> None: InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
):
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
self.config = config self.config = config
self.inference_api = inference_api
self.cursor = None self.cursor = None
self.conn = None self.conn = None
self.cache = {} self.cache = {}
@ -160,27 +164,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_memory_bank( async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
self,
memory_bank: MemoryBank,
) -> None:
assert ( assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}" ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
upsert_models( upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
self.cursor, index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
[ self.cache[memory_bank.identifier] = self._create_bank_with_index(
(memory_bank.identifier, memory_bank), memory_bank, index
],
) )
index = BankWithIndex(
bank=memory_bank,
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[memory_bank.identifier] = index
async def unregister_memory_bank(self, memory_bank_id: str) -> None: async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete() await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id] del self.cache[memory_bank_id]
@ -189,9 +183,9 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
banks = load_models(self.cursor, VectorMemoryBank) banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks: for bank in banks:
if bank.identifier not in self.cache: if bank.identifier not in self.cache:
index = BankWithIndex( index = self._create_bank_with_index(
bank=bank, bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
) )
self.cache[bank.identifier] = index self.cache[bank.identifier] = index
return banks return banks
@ -214,14 +208,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
index = await self._get_and_cache_bank_index(bank_id) index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params) return await index.query_documents(query, params)
self.inference_api = inference_api
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache: if bank_id in self.cache:
return self.cache[bank_id] return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id) bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = BankWithIndex( index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
bank=bank, self.cache[bank_id] = self._create_bank_with_index(bank, index)
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), return self.cache[bank_id]
)
self.cache[bank_id] = index
return index

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import QdrantConfig from .config import QdrantConfig
async def get_adapter_impl(config: QdrantConfig, _deps): async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorMemoryAdapter from .qdrant import QdrantVectorMemoryAdapter
impl = QdrantVectorMemoryAdapter(config) impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -21,6 +21,7 @@ from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
InferenceEmbeddingMixin,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -100,11 +101,14 @@ class QdrantIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): class QdrantVectorMemoryAdapter(
def __init__(self, config: QdrantConfig) -> None: InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
):
def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
self.config = config self.config = config
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
self.cache = {} self.cache = {}
self.inference_api = inference_api
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
@ -120,7 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
memory_bank.memory_bank_type == MemoryBankType.vector memory_bank.memory_bank_type == MemoryBankType.vector
), f"Only vector banks are supported {memory_bank.memory_bank_type}" ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
index = BankWithIndex( index = self._create_bank_with_index(
bank=memory_bank, bank=memory_bank,
index=QdrantIndex(self.client, memory_bank.identifier), index=QdrantIndex(self.client, memory_bank.identifier),
) )
@ -140,7 +144,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
if not bank: if not bank:
raise ValueError(f"Bank {bank_id} not found") raise ValueError(f"Bank {bank_id} not found")
index = BankWithIndex( index = self._create_bank_with_index(
bank=bank, bank=bank,
index=QdrantIndex(client=self.client, collection_name=bank_id), index=QdrantIndex(client=self.client, collection_name=bank_id),
) )

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401 from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
async def get_adapter_impl(config: WeaviateConfig, _deps): async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
from .weaviate import WeaviateMemoryAdapter from .weaviate import WeaviateMemoryAdapter
impl = WeaviateMemoryAdapter(config) impl = WeaviateMemoryAdapter(config, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -19,6 +19,7 @@ from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
InferenceEmbeddingMixin,
) )
from .config import WeaviateConfig, WeaviateRequestProviderData from .config import WeaviateConfig, WeaviateRequestProviderData
@ -82,10 +83,14 @@ class WeaviateIndex(EmbeddingIndex):
class WeaviateMemoryAdapter( class WeaviateMemoryAdapter(
Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate InferenceEmbeddingMixin,
Memory,
NeedsRequestProviderData,
MemoryBanksProtocolPrivate,
): ):
def __init__(self, config: WeaviateConfig) -> None: def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
self.config = config self.config = config
self.inference_api = inference_api
self.client_cache = {} self.client_cache = {}
self.cache = {} self.cache = {}
@ -135,11 +140,10 @@ class WeaviateMemoryAdapter(
], ],
) )
index = BankWithIndex( self.cache[memory_bank.identifier] = self._create_bank_with_index(
bank=memory_bank, memory_bank,
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier), WeaviateIndex(client=client, collection_name=memory_bank.identifier),
) )
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]: async def list_memory_banks(self) -> List[MemoryBank]:
# TODO: right now the Llama Stack is the source of truth for these banks. That is # TODO: right now the Llama Stack is the source of truth for these banks. That is
@ -160,7 +164,7 @@ class WeaviateMemoryAdapter(
if not client.collections.exists(bank.identifier): if not client.collections.exists(bank.identifier):
raise ValueError(f"Collection with name `{bank.identifier}` not found") raise ValueError(f"Collection with name `{bank.identifier}` not found")
index = BankWithIndex( index = self._create_bank_with_index(
bank=bank, bank=bank,
index=WeaviateIndex(client=client, collection_name=bank_id), index=WeaviateIndex(client=client, collection_name=bank_id),
) )

View file

@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODELS = {}
def get_embedding_model(model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS
loaded_model = EMBEDDING_MODELS.get(model)
if loaded_model is not None:
return loaded_model
log.info(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer
loaded_model = SentenceTransformer(model)
EMBEDDING_MODELS[model] = loaded_model
return loaded_model
def parse_pdf(data: bytes) -> str: def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string # For PDF and DOC/DOCX files, we can't reliably convert to string
@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
class BankWithIndex: class BankWithIndex:
bank: VectorMemoryBank bank: VectorMemoryBank
index: EmbeddingIndex index: EmbeddingIndex
inference_api: Api.inference
async def insert_documents( async def insert_documents(
self, self,
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
) -> None: ) -> None:
model = get_embedding_model(self.bank.embedding_model)
for doc in documents: for doc in documents:
content = await content_from_doc(doc) content = await content_from_doc(doc)
chunks = make_overlapped_chunks( chunks = make_overlapped_chunks(
@ -183,7 +165,10 @@ class BankWithIndex:
) )
if not chunks: if not chunks:
continue continue
embeddings = model.encode([x.content for x in chunks]).astype(np.float32) embeddings_response = await self.inference_api.embeddings(
self.bank.embedding_model, [x.content for x in chunks]
)
embeddings = np.array(embeddings_response.embeddings)
await self.index.add_chunks(chunks, embeddings) await self.index.add_chunks(chunks, embeddings)
@ -208,6 +193,25 @@ class BankWithIndex:
else: else:
query_str = _process(query) query_str = _process(query)
model = get_embedding_model(self.bank.embedding_model) embeddings_response = await self.inference_api.embeddings(
query_vector = model.encode([query_str])[0].astype(np.float32) self.bank.embedding_model, [query_str]
)
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
return await self.index.query(query_vector, k, score_threshold) return await self.index.query(query_vector, k, score_threshold)
class InferenceEmbeddingMixin:
inference_api: Api.inference
def __init__(self, inference_api: Api.inference):
self.inference_api = inference_api
def _create_bank_with_index(
self, bank: VectorMemoryBank, index: EmbeddingIndex
) -> BankWithIndex:
return BankWithIndex(
bank=bank,
index=index,
inference_api=self.inference_api,
)