llama-stack-mirror/llama_stack/providers/utils/inference/embedding_mixin.py
Jaideep Rao 66412ab12b convert blocking calls to async
Signed-off-by: Jaideep Rao <jrao@redhat.com>
2025-03-20 13:07:23 -04:00

63 lines
No EOL
2.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import asyncio
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
InterleavedContentItem,
ModelStore,
TextTruncation,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
# Process-wide cache of loaded SentenceTransformer instances, keyed by model name/path,
# so repeated embedding requests reuse an already-initialized model.
EMBEDDING_MODELS = {}
log = logging.getLogger(__name__)
class SentenceTransformerEmbeddingMixin:
    """Mixin implementing the `embeddings` API on top of sentence-transformers.

    Models are loaded lazily and cached in the module-level ``EMBEDDING_MODELS``
    dict; all blocking sentence-transformers calls (model loading and encoding)
    are offloaded to the default thread-pool executor so the event loop is
    never blocked.
    """

    # Injected by the hosting provider; used to resolve model ids to resources.
    model_store: ModelStore

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        """Embed ``contents`` with the sentence-transformers model behind ``model_id``.

        NOTE(review): ``text_truncation``, ``output_dimension`` and ``task_type``
        are accepted for API compatibility but are currently ignored here.
        """
        model = await self.model_store.get_model(model_id)
        embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
        # encode() is synchronous and CPU-bound; run it off the event loop.
        embeddings = await self._run_in_executor(
            embedding_model.encode,
            [interleaved_content_as_str(content) for content in contents],
            show_progress_bar=False,
        )
        return EmbeddingsResponse(embeddings=embeddings)

    async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
        """Return a cached SentenceTransformer for ``model``, loading it on first use."""
        # No `global` needed: the dict is only read/mutated, never rebound.
        loaded_model = EMBEDDING_MODELS.get(model)
        if loaded_model is not None:
            return loaded_model

        log.info(f"Loading sentence transformer for {model}...")
        from sentence_transformers import SentenceTransformer

        # Instantiation reads/downloads weights (blocking I/O); keep it off the loop.
        loaded_model = await self._run_in_executor(SentenceTransformer, model)
        EMBEDDING_MODELS[model] = loaded_model
        return loaded_model

    async def _run_in_executor(self, func, *args, **kwargs):
        """Run a blocking callable in the default executor and await its result.

        BUG FIX: ``loop.run_in_executor()`` does not forward keyword arguments
        to the target callable, so the original
        ``run_in_executor(None, func, *args, **kwargs)`` raised ``TypeError``
        whenever kwargs were supplied — which ``embeddings()`` always does via
        ``show_progress_bar=False``. Bind args/kwargs with ``functools.partial``
        instead. Also use ``get_running_loop()``, the correct (non-deprecated)
        call from inside a coroutine.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))