mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
fix: Make SentenceTransformer embedding operations non-blocking (#3335)
- Wrap model loading with asyncio.to_thread() to prevent blocking during model download/initialization - Wrap encoding operations with asyncio.to_thread() to run in background thread - Convert _load_sentence_transformer_model() to async method This ensures the async event loop remains responsive during embedding operations. Closes: #3332 Signed-off-by: Derek Higgins <derekh@redhat.com> Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>
This commit is contained in:
parent
85f33762d7
commit
5bbca56cfc
1 changed files with 15 additions and 8 deletions
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import struct
|
import struct
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin:
|
||||||
task_type: EmbeddingTaskType | None = None,
|
task_type: EmbeddingTaskType | None = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
|
||||||
embeddings = embedding_model.encode(
|
embeddings = await asyncio.to_thread(
|
||||||
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
|
embedding_model.encode,
|
||||||
|
[interleaved_content_as_str(content) for content in contents],
|
||||||
|
show_progress_bar=False,
|
||||||
)
|
)
|
||||||
return EmbeddingsResponse(embeddings=embeddings)
|
return EmbeddingsResponse(embeddings=embeddings)
|
||||||
|
|
||||||
|
@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin:
|
||||||
|
|
||||||
# Get the model and generate embeddings
|
# Get the model and generate embeddings
|
||||||
model_obj = await self.model_store.get_model(model)
|
model_obj = await self.model_store.get_model(model)
|
||||||
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
||||||
embeddings = embedding_model.encode(input_list, show_progress_bar=False)
|
embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
|
||||||
|
|
||||||
# Convert embeddings to the requested format
|
# Convert embeddings to the requested format
|
||||||
data = []
|
data = []
|
||||||
|
@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin:
|
||||||
usage=usage,
|
usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
|
async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
|
||||||
global EMBEDDING_MODELS
|
global EMBEDDING_MODELS
|
||||||
|
|
||||||
loaded_model = EMBEDDING_MODELS.get(model)
|
loaded_model = EMBEDDING_MODELS.get(model)
|
||||||
|
@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin:
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
||||||
log.info(f"Loading sentence transformer for {model}...")
|
log.info(f"Loading sentence transformer for {model}...")
|
||||||
|
|
||||||
|
def _load_model():
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
loaded_model = SentenceTransformer(model)
|
return SentenceTransformer(model)
|
||||||
|
|
||||||
|
loaded_model = await asyncio.to_thread(_load_model)
|
||||||
EMBEDDING_MODELS[model] = loaded_model
|
EMBEDDING_MODELS[model] = loaded_model
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue