forked from phoenix-oss/llama-stack-mirror
Make embedding generation go through inference (#606)
This PR does the following: 1) adds the ability to generate embeddings in all supported inference providers. 2) Moves all the memory providers to use the inference API and improved the memory tests to setup the inference stack correctly and use the embedding models This is a merge from #589 and #598
This commit is contained in:
parent
a14785af46
commit
96e158eaac
37 changed files with 677 additions and 156 deletions
|
@ -19,11 +19,10 @@ from numpy.typing import NDArray
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
|
@ -32,7 +31,8 @@ from .config import FaissImplConfig
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_BANKS_PREFIX = "memory_banks:v1::"
|
||||
MEMORY_BANKS_PREFIX = "memory_banks:v2::"
|
||||
FAISS_INDEX_PREFIX = "faiss_index:v2::"
|
||||
|
||||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
|
@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
if not self.kvstore:
|
||||
return
|
||||
|
||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
|
||||
stored_data = await self.kvstore.get(index_key)
|
||||
|
||||
if stored_data:
|
||||
|
@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex):
|
|||
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
|
||||
}
|
||||
|
||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
|
||||
await self.kvstore.set(key=index_key, value=json.dumps(data))
|
||||
|
||||
async def delete(self):
|
||||
if not self.kvstore or not self.bank_id:
|
||||
return
|
||||
|
||||
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
|
||||
await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
# Add dimension check
|
||||
embedding_dim = (
|
||||
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
|
||||
)
|
||||
if embedding_dim != self.index.d:
|
||||
raise ValueError(
|
||||
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
|
||||
)
|
||||
|
||||
indexlen = len(self.id_by_index)
|
||||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
|
@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache = {}
|
||||
self.kvstore = None
|
||||
|
||||
|
@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
for bank_data in stored_banks:
|
||||
bank = VectorMemoryBank.model_validate_json(bank_data)
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=await FaissIndex.create(
|
||||
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
|
||||
bank,
|
||||
await FaissIndex.create(
|
||||
bank.embedding_dimension, self.kvstore, bank.identifier
|
||||
),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
|
@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
)
|
||||
|
||||
# Store in cache
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=await FaissIndex.create(
|
||||
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
|
||||
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||
memory_bank,
|
||||
await FaissIndex.create(
|
||||
memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
|
||||
),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue