mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00
use inference api to generate embeddings in vector store
This commit is contained in:
parent 96accc1216
commit 5bbeb985ca
12 changed files with 134 additions and 96 deletions

llama_stack/providers/inline/memory/faiss/__init__.py

@@ -4,16 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
 from .config import FaissImplConfig
 
 
-async def get_provider_impl(config: FaissImplConfig, _deps):
+async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
     from .faiss import FaissMemoryImpl
 
     assert isinstance(
         config, FaissImplConfig
     ), f"Unexpected config type: {type(config)}"
 
-    impl = FaissMemoryImpl(config)
+    impl = FaissMemoryImpl(config, deps[Api.inference])
     await impl.initialize()
     return impl
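
The factory now receives the resolved dependency map instead of ignoring it. A minimal sketch of how a caller might wire this up (hypothetical driver code: `inference_impl` stands in for whatever object the stack resolves for `Api.inference`, and constructing `FaissImplConfig()` with defaults is an assumption):

from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.memory.faiss import get_provider_impl
from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig


async def wire_faiss(inference_impl):
    # The resolver passes every declared dependency in a dict keyed by Api;
    # this provider only pulls out deps[Api.inference].
    deps = {Api.inference: inference_impl}
    return await get_provider_impl(FaissImplConfig(), deps)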

llama_stack/providers/inline/memory/faiss/faiss.py

@@ -19,13 +19,12 @@ from numpy.typing import NDArray
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
-    BankWithIndex,
     EmbeddingIndex,
+    InferenceEmbeddingMixin,
 )
 
 from .config import FaissImplConfig

@@ -95,6 +94,15 @@ class FaissIndex(EmbeddingIndex):
         await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
 
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        # Add dimension check
+        embedding_dim = (
+            embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
+        )
+        if embedding_dim != self.index.d:
+            raise ValueError(
+                f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
+            )
+
         indexlen = len(self.id_by_index)
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
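
The new guard derives the embedding width from the array shape and compares it against the FAISS index dimension (`self.index.d`) before any chunk is stored. A standalone numpy sketch of that shape logic (illustration only, not code from this commit):

import numpy as np


def embedding_dim(embeddings: np.ndarray) -> int:
    # Column count for an (n, d) batch, length for a flat (d,) vector:
    # the same expression the diff adds above.
    return embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]


batch = np.zeros((2, 384), dtype=np.float32)  # two 384-dim embeddings
single = np.zeros(384, dtype=np.float32)      # one flat 384-dim embedding
assert embedding_dim(batch) == embedding_dim(single) == 384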

@@ -123,9 +131,10 @@ class FaissIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: FaissImplConfig) -> None:
+class FaissMemoryImpl(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
+    def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cache = {}
         self.kvstore = None

@@ -138,10 +147,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
 
         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
-            index = BankWithIndex(
-                bank=bank,
-                index=await FaissIndex.create(
-                    ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
+            index = self._create_bank_with_index(
+                bank,
+                await FaissIndex.create(
+                    bank.embedding_dimension, self.kvstore, bank.identifier
                 ),
             )
             self.cache[bank.identifier] = index

@@ -166,13 +175,12 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         )
 
         # Store in cache
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=await FaissIndex.create(
-                ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
+        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+            memory_bank,
+            await FaissIndex.create(
+                memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
             ),
         )
-        self.cache[memory_bank.identifier] = index
 
     async def list_memory_banks(self) -> List[MemoryBank]:
         return [i.bank for i in self.cache.values()]

llama_stack/providers/registry/memory.py

@@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
+            api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.memory,

@@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,

@@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.chroma",
                 config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,

@@ -64,6 +67,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.pgvector",
                 config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,

@@ -74,6 +78,7 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
                 provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             api=Api.memory,

@@ -83,6 +88,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.sample",
                 config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
             ),
+            api_dependencies=[],
         ),
         remote_provider_spec(
             Api.memory,

@@ -92,5 +98,6 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.qdrant",
                 config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
     ]
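
With `api_dependencies=[Api.inference]` on every vector-memory spec (the sample stub declares an empty list), the stack resolves an inference provider before instantiating a memory provider, which is what populates the `deps` dict in the factories above. A minimal sketch of the pattern, mirroring the faiss entry; only the fields shown in this diff are confirmed, the rest are assumptions:

from llama_stack.providers.datatypes import Api, InlineProviderSpec

spec = InlineProviderSpec(
    api=Api.memory,
    provider_type="inline::faiss",  # assumed value, taken from the deprecation note
    pip_packages=["faiss-cpu"],     # the real spec uses EMBEDDING_DEPS + ["faiss-cpu"]
    module="llama_stack.providers.inline.memory.faiss",
    config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
    api_dependencies=[Api.inference],  # the declaration this commit adds
)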

llama_stack/providers/remote/memory/chroma/__init__.py

@@ -4,12 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
 from llama_stack.distribution.datatypes import RemoteProviderConfig
+from llama_stack.providers.datatypes import Api, ProviderSpec
 
 
-async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+async def get_adapter_impl(config: RemoteProviderConfig, deps: Dict[Api, ProviderSpec]):
     from .chroma import ChromaMemoryAdapter
 
-    impl = ChromaMemoryAdapter(config.url)
+    impl = ChromaMemoryAdapter(config.url, deps[Api.inference])
     await impl.initialize()
     return impl

llama_stack/providers/remote/memory/chroma/chroma.py

@@ -20,6 +20,7 @@ from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
+    InferenceEmbeddingMixin,
 )
 
 log = logging.getLogger(__name__)

@@ -71,8 +72,8 @@ class ChromaIndex(EmbeddingIndex):
         await self.client.delete_collection(self.collection.name)
 
 
-class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, url: str) -> None:
+class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
+    def __init__(self, url: str, inference_api: Api.inference) -> None:
         log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
         url = url.rstrip("/")
         parsed = urlparse(url)

@@ -82,6 +83,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
 
         self.host = parsed.hostname
         self.port = parsed.port
+        self.inference_api = inference_api
 
         self.client = None
         self.cache = {}

@@ -109,10 +111,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
             name=memory_bank.identifier,
             metadata={"bank": memory_bank.model_dump_json()},
         )
-        bank_index = BankWithIndex(
-            bank=memory_bank, index=ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+            memory_bank, ChromaIndex(self.client, collection)
         )
-        self.cache[memory_bank.identifier] = bank_index
 
     async def list_memory_banks(self) -> List[MemoryBank]:
         collections = await self.client.list_collections()

@@ -124,11 +125,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
                 log.exception(f"Failed to parse bank: {collection.metadata}")
                 continue
 
-            index = BankWithIndex(
-                bank=bank,
-                index=ChromaIndex(self.client, collection),
+            self.cache[bank.identifier] = self._create_bank_with_index(
+                bank,
+                ChromaIndex(self.client, collection),
             )
-            self.cache[bank.identifier] = index
 
         return [i.bank for i in self.cache.values()]
 

@@ -166,6 +166,6 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         collection = await self.client.get_collection(bank_id)
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
+        index = self._create_bank_with_index(bank, ChromaIndex(self.client, collection))
         self.cache[bank_id] = index
         return index

llama_stack/providers/remote/memory/pgvector/__init__.py

@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import PGVectorConfig
 
 
-async def get_adapter_impl(config: PGVectorConfig, _deps):
+async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
     from .pgvector import PGVectorMemoryAdapter
 
-    impl = PGVectorMemoryAdapter(config)
+    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl

llama_stack/providers/remote/memory/pgvector/pgvector.py

@@ -16,11 +16,12 @@ from pydantic import BaseModel, parse_obj_as
 
 from llama_stack.apis.memory import *  # noqa: F403
 
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 
 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
+    InferenceEmbeddingMixin,
 )
 
 from .config import PGVectorConfig

@@ -119,9 +120,12 @@ class PGVectorIndex(EmbeddingIndex):
         self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
 
 
-class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: PGVectorConfig) -> None:
+class PGVectorMemoryAdapter(
+    InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
+):
+    def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cursor = None
         self.conn = None
         self.cache = {}

@@ -160,27 +164,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass
 
-    async def register_memory_bank(
-        self,
-        memory_bank: MemoryBank,
-    ) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
         assert (
             memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        upsert_models(
-            self.cursor,
-            [
-                (memory_bank.identifier, memory_bank),
-            ],
-        )
-
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[memory_bank.identifier] = index
+        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
+        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
+        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+            memory_bank, index
+        )
 
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]

@@ -189,9 +183,9 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         banks = load_models(self.cursor, VectorMemoryBank)
         for bank in banks:
             if bank.identifier not in self.cache:
-                index = BankWithIndex(
-                    bank=bank,
-                    index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+                index = self._create_bank_with_index(
+                    bank,
+                    PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
                 )
                 self.cache[bank.identifier] = index
         return banks

@@ -214,14 +208,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = await self._get_and_cache_bank_index(bank_id)
         return await index.query_documents(query, params)
 
+        self.inference_api = inference_api
+
     async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
         if bank_id in self.cache:
             return self.cache[bank_id]
 
         bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
+        self.cache[bank_id] = self._create_bank_with_index(bank, index)
+        return self.cache[bank_id]

llama_stack/providers/remote/memory/qdrant/__init__.py

@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import QdrantConfig
 
 
-async def get_adapter_impl(config: QdrantConfig, _deps):
+async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorMemoryAdapter
 
-    impl = QdrantVectorMemoryAdapter(config)
+    impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl

llama_stack/providers/remote/memory/qdrant/qdrant.py

@@ -21,6 +21,7 @@ from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
+    InferenceEmbeddingMixin,
 )
 
 log = logging.getLogger(__name__)

@@ -100,11 +101,14 @@ class QdrantIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: QdrantConfig) -> None:
+class QdrantVectorMemoryAdapter(
+    InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
+):
+    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
         self.cache = {}
+        self.inference_api = inference_api
 
     async def initialize(self) -> None:
         pass

@@ -120,7 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
             memory_bank.memory_bank_type == MemoryBankType.vector
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        index = BankWithIndex(
+        index = self._create_bank_with_index(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
         )

@@ -140,7 +144,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         if not bank:
             raise ValueError(f"Bank {bank_id} not found")
 
-        index = BankWithIndex(
+        index = self._create_bank_with_index(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
         )

llama_stack/providers/remote/memory/weaviate/__init__.py

@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401
 
 
-async def get_adapter_impl(config: WeaviateConfig, _deps):
+async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
     from .weaviate import WeaviateMemoryAdapter
 
-    impl = WeaviateMemoryAdapter(config)
+    impl = WeaviateMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl

llama_stack/providers/remote/memory/weaviate/weaviate.py

@@ -19,6 +19,7 @@ from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
+    InferenceEmbeddingMixin,
 )
 
 from .config import WeaviateConfig, WeaviateRequestProviderData

@@ -82,10 +83,14 @@ class WeaviateIndex(EmbeddingIndex):
 
 
 class WeaviateMemoryAdapter(
-    Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
+    InferenceEmbeddingMixin,
+    Memory,
+    NeedsRequestProviderData,
+    MemoryBanksProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateConfig) -> None:
+    def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
 

@@ -135,11 +140,10 @@ class WeaviateMemoryAdapter(
             ],
         )
 
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+            memory_bank,
+            WeaviateIndex(client=client, collection_name=memory_bank.identifier),
         )
-        self.cache[memory_bank.identifier] = index
 
     async def list_memory_banks(self) -> List[MemoryBank]:
         # TODO: right now the Llama Stack is the source of truth for these banks. That is

@@ -160,7 +164,7 @@ class WeaviateMemoryAdapter(
         if not client.collections.exists(bank.identifier):
             raise ValueError(f"Collection with name `{bank.identifier}` not found")
 
-        index = BankWithIndex(
+        index = self._create_bank_with_index(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
         )

llama_stack/providers/utils/memory/vector_store.py

@@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import Api
 
 log = logging.getLogger(__name__)
 
-ALL_MINILM_L6_V2_DIMENSION = 384
-
-EMBEDDING_MODELS = {}
-
-
-def get_embedding_model(model: str) -> "SentenceTransformer":
-    global EMBEDDING_MODELS
-
-    loaded_model = EMBEDDING_MODELS.get(model)
-    if loaded_model is not None:
-        return loaded_model
-
-    log.info(f"Loading sentence transformer for {model}...")
-    from sentence_transformers import SentenceTransformer
-
-    loaded_model = SentenceTransformer(model)
-    EMBEDDING_MODELS[model] = loaded_model
-    return loaded_model
-
 
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string

@@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
 class BankWithIndex:
     bank: VectorMemoryBank
     index: EmbeddingIndex
+    inference_api: Api.inference
 
     async def insert_documents(
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(

@@ -183,7 +165,10 @@ class BankWithIndex:
             )
             if not chunks:
                 continue
-            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+            embeddings_response = await self.inference_api.embeddings(
+                self.bank.embedding_model, [x.content for x in chunks]
+            )
+            embeddings = np.array(embeddings_response.embeddings)
 
             await self.index.add_chunks(chunks, embeddings)
 
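
Chunk embedding now goes through the injected inference API instead of a locally loaded sentence-transformers model. A sketch of the call contract the new code relies on; the response shape (an object with an `.embeddings` list holding one vector per input) is inferred from the usage above, and `inference_api` is whatever implementation was injected:

from typing import List

import numpy as np


async def embed_texts(inference_api, model: str, texts: List[str]) -> np.ndarray:
    # One inference-API round trip replaces the local model.encode(...) path.
    response = await inference_api.embeddings(model, texts)
    return np.array(response.embeddings)  # shape: (len(texts), embedding_dim)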

@@ -208,6 +193,25 @@ class BankWithIndex:
         else:
             query_str = _process(query)
 
-        model = get_embedding_model(self.bank.embedding_model)
-        query_vector = model.encode([query_str])[0].astype(np.float32)
+        embeddings_response = await self.inference_api.embeddings(
+            self.bank.embedding_model, [query_str]
+        )
+        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
+
+
+class InferenceEmbeddingMixin:
+    inference_api: Api.inference
+
+    def __init__(self, inference_api: Api.inference):
+        self.inference_api = inference_api
+
+    def _create_bank_with_index(
+        self, bank: VectorMemoryBank, index: EmbeddingIndex
+    ) -> BankWithIndex:
+        return BankWithIndex(
+            bank=bank,
+            index=index,
+            inference_api=self.inference_api,
+        )
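
Taken together: each provider mixes in `InferenceEmbeddingMixin` and builds banks through `_create_bank_with_index`, so every `BankWithIndex` carries the injected inference API. A minimal end-to-end sketch under the types this diff defines (`FakeInference` and its response object are hypothetical stand-ins for a real inference provider):

from types import SimpleNamespace

from llama_stack.providers.utils.memory.vector_store import InferenceEmbeddingMixin


class FakeInference:
    # Hypothetical stand-in: returns one fixed-size vector per input text,
    # mimicking the `.embeddings` field that BankWithIndex reads.
    async def embeddings(self, model, contents):
        return SimpleNamespace(embeddings=[[0.0] * 384 for _ in contents])


class MyAdapter(InferenceEmbeddingMixin):
    pass


adapter = MyAdapter(inference_api=FakeInference())
# bank_with_index = adapter._create_bank_with_index(bank, index)
# await bank_with_index.insert_documents(documents)  # embeds via the inference API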