mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 21:19:48 +00:00
implement embedding generation in supported inference providers (#589)
This PR adds the ability to generate embeddings in all supported inference providers. ``` pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py -k "bedrock" --inference-model="amazon.titan-embed-text-v2:0" --env EMBEDDING_DIMENSION=1024 pytest -v -s -k "vllm" --inferrence-model="intfloat/e5-mistral-7b-instruct" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=4096 --env VLLM_URL="http://localhost:9798/v1" pytest -v -s --inference-model="nomic-ai/nomic-embed-text-v1.5" llama_stack/providers/tests/inference/test_embeddings.py -k "fireworks" --env FIREWORKS_API_KEY=<API_KEY>--env EMBEDDING_DIMENSION=128 pytest -v -s --inference-model="togethercomputer/m2-bert-80M-2k-retrieval" llama_stack/providers/tests/inference/test_embeddings.py -k "together" --env TOGETHER_API_KEY=<API_KEY>--env EMBEDDING_DIMENSION=768 pytest -v -s -k "ollama" --inference-model="all-minilm:v8" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 ```
This commit is contained in:
parent
6a23f24ee0
commit
d362d2d740
32 changed files with 597 additions and 143 deletions
|
|
@ -19,11 +19,10 @@ from numpy.typing import NDArray
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
|
|
@ -32,7 +31,8 @@ from .config import FaissImplConfig
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_BANKS_PREFIX = "memory_banks:v1::"
|
||||
MEMORY_BANKS_PREFIX = "memory_banks:v2::"
|
||||
FAISS_INDEX_PREFIX = "faiss_index:v2::"
|
||||
|
||||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
|
|
@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
if not self.kvstore:
|
||||
return
|
||||
|
||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
|
||||
stored_data = await self.kvstore.get(index_key)
|
||||
|
||||
if stored_data:
|
||||
|
|
@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex):
|
|||
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
|
||||
}
|
||||
|
||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
|
||||
await self.kvstore.set(key=index_key, value=json.dumps(data))
|
||||
|
||||
async def delete(self):
|
||||
if not self.kvstore or not self.bank_id:
|
||||
return
|
||||
|
||||
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
|
||||
await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
# Add dimension check
|
||||
embedding_dim = (
|
||||
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
|
||||
)
|
||||
if embedding_dim != self.index.d:
|
||||
raise ValueError(
|
||||
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
|
||||
)
|
||||
|
||||
indexlen = len(self.id_by_index)
|
||||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
|
|
@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache = {}
|
||||
self.kvstore = None
|
||||
|
||||
|
|
@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
for bank_data in stored_banks:
|
||||
bank = VectorMemoryBank.model_validate_json(bank_data)
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=await FaissIndex.create(
|
||||
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
|
||||
bank,
|
||||
await FaissIndex.create(
|
||||
bank.embedding_dimension, self.kvstore, bank.identifier
|
||||
),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
|
|
@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
)
|
||||
|
||||
# Store in cache
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=await FaissIndex.create(
|
||||
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
|
||||
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||
memory_bank,
|
||||
await FaissIndex.create(
|
||||
memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
|
||||
),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue