mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 16:19:49 +00:00
implement embedding generation in supported inference providers (#589)
This PR adds the ability to generate embeddings in all supported inference providers. ``` pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py -k "bedrock" --inference-model="amazon.titan-embed-text-v2:0" --env EMBEDDING_DIMENSION=1024 pytest -v -s -k "vllm" --inferrence-model="intfloat/e5-mistral-7b-instruct" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=4096 --env VLLM_URL="http://localhost:9798/v1" pytest -v -s --inference-model="nomic-ai/nomic-embed-text-v1.5" llama_stack/providers/tests/inference/test_embeddings.py -k "fireworks" --env FIREWORKS_API_KEY=<API_KEY>--env EMBEDDING_DIMENSION=128 pytest -v -s --inference-model="togethercomputer/m2-bert-80M-2k-retrieval" llama_stack/providers/tests/inference/test_embeddings.py -k "together" --env TOGETHER_API_KEY=<API_KEY>--env EMBEDDING_DIMENSION=768 pytest -v -s -k "ollama" --inference-model="all-minilm:v8" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 ```
This commit is contained in:
parent
6a23f24ee0
commit
d362d2d740
32 changed files with 597 additions and 143 deletions
47
llama_stack/providers/utils/inference/embedding_mixin.py
Normal file
47
llama_stack/providers/utils/inference/embedding_mixin.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import InterleavedTextMedia
|
||||
|
||||
from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore
|
||||
|
||||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingMixin:
|
||||
model_store: ModelStore
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embedding_model = self._load_sentence_transformer_model(
|
||||
model.provider_resource_id
|
||||
)
|
||||
embeddings = embedding_model.encode(contents)
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
|
||||
global EMBEDDING_MODELS
|
||||
|
||||
loaded_model = EMBEDDING_MODELS.get(model)
|
||||
if loaded_model is not None:
|
||||
return loaded_model
|
||||
|
||||
log.info(f"Loading sentence transformer for {model}...")
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
loaded_model = SentenceTransformer(model)
|
||||
EMBEDDING_MODELS[model] = loaded_model
|
||||
return loaded_model
|
||||
|
|
@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ALL_MINILM_L6_V2_DIMENSION = 384
|
||||
|
||||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
||||
def get_embedding_model(model: str) -> "SentenceTransformer":
|
||||
global EMBEDDING_MODELS
|
||||
|
||||
loaded_model = EMBEDDING_MODELS.get(model)
|
||||
if loaded_model is not None:
|
||||
return loaded_model
|
||||
|
||||
log.info(f"Loading sentence transformer for {model}...")
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
loaded_model = SentenceTransformer(model)
|
||||
EMBEDDING_MODELS[model] = loaded_model
|
||||
return loaded_model
|
||||
|
||||
|
||||
def parse_pdf(data: bytes) -> str:
|
||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||
|
|
@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
|
|||
class BankWithIndex:
|
||||
bank: VectorMemoryBank
|
||||
index: EmbeddingIndex
|
||||
inference_api: Api.inference
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
model = get_embedding_model(self.bank.embedding_model)
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
chunks = make_overlapped_chunks(
|
||||
|
|
@ -183,7 +165,10 @@ class BankWithIndex:
|
|||
)
|
||||
if not chunks:
|
||||
continue
|
||||
embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
|
||||
embeddings_response = await self.inference_api.embeddings(
|
||||
self.bank.embedding_model, [x.content for x in chunks]
|
||||
)
|
||||
embeddings = np.array(embeddings_response.embeddings)
|
||||
|
||||
await self.index.add_chunks(chunks, embeddings)
|
||||
|
||||
|
|
@ -208,6 +193,8 @@ class BankWithIndex:
|
|||
else:
|
||||
query_str = _process(query)
|
||||
|
||||
model = get_embedding_model(self.bank.embedding_model)
|
||||
query_vector = model.encode([query_str])[0].astype(np.float32)
|
||||
embeddings_response = await self.inference_api.embeddings(
|
||||
self.bank.embedding_model, [query_str]
|
||||
)
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
return await self.index.query(query_vector, k, score_threshold)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue