From 66412ab12b99509f708538ef09adbe950ff85aaf Mon Sep 17 00:00:00 2001
From: Jaideep Rao
Date: Fri, 14 Mar 2025 13:36:27 -0400
Subject: [PATCH] convert blocking calls to async

Signed-off-by: Jaideep Rao
---
 .../providers/remote/inference/runpod/runpod.py | 14 +++++++-------
 .../utils/inference/embedding_mixin.py          | 17 +++++++++++------
 .../utils/inference/litellm_openai_mixin.py     | 10 +++++-----
 3 files changed, 23 insertions(+), 18 deletions(-)

diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index 72f858cd8..6ae21e545 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 from typing import AsyncGenerator
 
-from openai import OpenAI
+from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
 
@@ -85,24 +85,24 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
             tool_config=tool_config,
         )
 
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
         if stream:
             return self._stream_chat_completion(request, client)
         else:
             return await self._nonstream_chat_completion(request, client)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = client.completions.create(**params)
+        r = await client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
         params = self._get_params(request)
 
         async def _to_async_generator():
-            s = client.completions.create(**params)
-            for chunk in s:
+            s = await client.completions.create(**params)
+            async for chunk in s:
                 yield chunk
 
diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py
index 8b14c7502..e00bfd91d 100644
--- a/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import logging
+import asyncio
 from typing import TYPE_CHECKING, List, Optional
 
 if TYPE_CHECKING:
@@ -37,13 +38,12 @@ class SentenceTransformerEmbeddingMixin:
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
-        embeddings = embedding_model.encode(
-            [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
-        )
+        embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
+        # Execute the synchronous encode method in an executor
+        embeddings = await self._run_in_executor(embedding_model.encode, [interleaved_content_as_str(content) for content in contents], show_progress_bar=False)
         return EmbeddingsResponse(embeddings=embeddings)
 
-    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
+    async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS
 
         loaded_model = EMBEDDING_MODELS.get(model)
@@ -53,6 +53,11 @@ class SentenceTransformerEmbeddingMixin:
         log.info(f"Loading sentence transformer for {model}...")
         from sentence_transformers import SentenceTransformer
 
-        loaded_model = SentenceTransformer(model)
+        # Execute the synchronous SentenceTransformer instantiation in an executor
+        loaded_model = await self._run_in_executor(SentenceTransformer, model)
         EMBEDDING_MODELS[model] = loaded_model
         return loaded_model
+
+    async def _run_in_executor(self, func, *args, **kwargs):
+        # asyncio.to_thread runs the blocking call in a worker thread and forwards kwargs
+        return await asyncio.to_thread(func, *args, **kwargs)
\ No newline at end of file
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index f99883990..db8bcdb81 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -112,9 +112,9 @@ class LiteLLMOpenAIMixin(
 
         params = await self._get_params(request)
         logger.debug(f"params to litellm (openai compat): {params}")
-        # unfortunately, we need to use synchronous litellm.completion here because litellm
-        # caches various httpx.client objects in a non-eventloop aware manner
-        response = litellm.completion(**params)
+        # litellm provides an async completion method (acompletion):
+        # https://docs.litellm.ai/docs/completion/stream#async-completion
+        response = await litellm.acompletion(**params)
         if stream:
             return self._stream_chat_completion(response)
         else:
@@ -124,7 +124,7 @@ class LiteLLMOpenAIMixin(
         self, response: litellm.ModelResponse
     ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
         async def _stream_generator():
-            for chunk in response:
+            async for chunk in response:
                 yield chunk
 
         async for chunk in convert_openai_chat_completion_stream(
@@ -223,10 +223,10 @@ class LiteLLMOpenAIMixin(
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 
-        response = litellm.embedding(
+        response = await litellm.aembedding(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )
 
         embeddings = [data["embedding"] for data in response["data"]]
         return EmbeddingsResponse(embeddings=embeddings)