diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py
index 65ba2854b..9bd0aa8ce 100644
--- a/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import base64
 import struct
 from typing import TYPE_CHECKING
@@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin:
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
-        embeddings = embedding_model.encode(
-            [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
+        embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
+        embeddings = await asyncio.to_thread(
+            embedding_model.encode,
+            [interleaved_content_as_str(content) for content in contents],
+            show_progress_bar=False,
         )
         return EmbeddingsResponse(embeddings=embeddings)
 
@@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin:
 
         # Get the model and generate embeddings
         model_obj = await self.model_store.get_model(model)
-        embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
-        embeddings = embedding_model.encode(input_list, show_progress_bar=False)
+        embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
+        embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
 
         # Convert embeddings to the requested format
         data = []
@@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin:
             usage=usage,
         )
 
-    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
+    async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS
 
         loaded_model = EMBEDDING_MODELS.get(model)
@@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin:
             return loaded_model
 
         log.info(f"Loading sentence transformer for {model}...")
-        from sentence_transformers import SentenceTransformer
 
-        loaded_model = SentenceTransformer(model)
+        def _load_model():
+            from sentence_transformers import SentenceTransformer
+
+            return SentenceTransformer(model)
+
+        loaded_model = await asyncio.to_thread(_load_model)
         EMBEDDING_MODELS[model] = loaded_model
         return loaded_model
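
The snippet below is a minimal standalone sketch (not part of the patch) of the technique this change applies: offloading the blocking SentenceTransformer constructor and encode() call to a worker thread with asyncio.to_thread so the event loop stays responsive while embeddings are computed. The embed() helper and the "all-MiniLM-L6-v2" model id are illustrative assumptions, not names from this repository.

    import asyncio

    from sentence_transformers import SentenceTransformer


    async def embed(model_name: str, texts: list[str]) -> list[list[float]]:
        # Model loading and encoding are blocking (CPU/GPU-bound); running them
        # in a worker thread keeps the asyncio event loop free for other requests.
        model = await asyncio.to_thread(SentenceTransformer, model_name)
        embeddings = await asyncio.to_thread(model.encode, texts, show_progress_bar=False)
        return embeddings.tolist()


    if __name__ == "__main__":
        # Illustrative model id; any sentence-transformers checkpoint works here.
        vectors = asyncio.run(embed("all-MiniLM-L6-v2", ["hello world"]))
        print(len(vectors[0]))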