convert blocking calls to async

Signed-off-by: Jaideep Rao <jrao@redhat.com>
Author: Jaideep Rao
Date: 2025-03-14 13:36:27 -04:00
Parent: 5403582582
Commit: 66412ab12b

3 changed files with 23 additions and 18 deletions


@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 from typing import AsyncGenerator
-from openai import OpenAI
+from openai import AsyncOpenAI
 from llama_stack.apis.inference import *  # noqa: F403
@@ -85,24 +85,24 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
             tool_config=tool_config,
         )
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
         if stream:
             return self._stream_chat_completion(request, client)
         else:
             return await self._nonstream_chat_completion(request, client)

     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = client.completions.create(**params)
+        r = await client.completions.create(**params)
         return process_chat_completion_response(r, request)

-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
         params = self._get_params(request)

         async def _to_async_generator():
-            s = client.completions.create(**params)
+            s = await client.completions.create(**params)
             for chunk in s:
                 yield chunk
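Note: once the client is AsyncOpenAI, a streaming request awaited inside _to_async_generator returns an openai.AsyncStream, which is an async iterator rather than a plain iterable, so the loop that drains it would also need to be async. A minimal sketch of how the helper might look, assuming params carries stream=True:

    async def _to_async_generator():
        # with AsyncOpenAI and stream=True, the awaited call returns an AsyncStream
        s = await client.completions.create(**params)
        # AsyncStream is consumed with `async for`, not a plain `for` loop
        async for chunk in s:
            yield chunk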


@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 import logging
+import asyncio
 from typing import TYPE_CHECKING, List, Optional

 if TYPE_CHECKING:
@@ -37,13 +38,12 @@ class SentenceTransformerEmbeddingMixin:
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
-        embeddings = embedding_model.encode(
-            [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
-        )
+        embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
+        # Execute the synchronous encode method in an executor
+        embeddings = await self._run_in_executor(embedding_model.encode, [interleaved_content_as_str(content) for content in contents], show_progress_bar=False)
         return EmbeddingsResponse(embeddings=embeddings)

-    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
+    async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS

         loaded_model = EMBEDDING_MODELS.get(model)
@@ -53,6 +53,11 @@ class SentenceTransformerEmbeddingMixin:
         log.info(f"Loading sentence transformer for {model}...")
         from sentence_transformers import SentenceTransformer

-        loaded_model = SentenceTransformer(model)
+        # Execute the synchronous SentenceTransformer instantiation in an executor
+        loaded_model = await self._run_in_executor(SentenceTransformer, model)
         EMBEDDING_MODELS[model] = loaded_model
         return loaded_model
+
+    async def _run_in_executor(self, func, *args, **kwargs):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, func, *args, **kwargs)
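A caveat with the new helper: loop.run_in_executor() only forwards positional arguments, so passing **kwargs through it (as the encode call does with show_progress_bar=False) raises a TypeError; the asyncio docs recommend functools.partial for keyword arguments, and asyncio.get_running_loop() is the preferred call inside a coroutine. A hedged sketch of an equivalent helper under those assumptions:

    import asyncio
    import functools

    async def _run_in_executor(self, func, *args, **kwargs):
        # bind keyword arguments up front, since run_in_executor() only
        # accepts positional arguments after the callable
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

On Python 3.9+, asyncio.to_thread(func, *args, **kwargs) covers the same case in a single call.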


@@ -112,9 +112,9 @@ class LiteLLMOpenAIMixin(
         params = await self._get_params(request)
         logger.debug(f"params to litellm (openai compat): {params}")
-        # unfortunately, we need to use synchronous litellm.completion here because litellm
-        # caches various httpx.client objects in a non-eventloop aware manner
-        response = litellm.completion(**params)
+        # Litellm seems to have implemented an async completion method
+        # https://docs.litellm.ai/docs/completion/stream#async-completion
+        response = await litellm.acompletion(**params)
         if stream:
             return self._stream_chat_completion(response)
         else:
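Per the linked litellm docs, acompletion mirrors the synchronous completion API but is awaitable, and with stream=True the awaited result is an async iterator of chunks, which is what motivates the async for change in the next hunk. A rough usage sketch, with the model name and message as placeholders:

    import asyncio
    import litellm

    async def demo():
        # non-streaming: awaiting acompletion returns a complete response object
        resp = await litellm.acompletion(
            model="gpt-4o-mini",  # placeholder model name
            messages=[{"role": "user", "content": "hello"}],
        )
        print(resp.choices[0].message.content)

        # streaming: the awaited result is consumed with `async for`
        stream = await litellm.acompletion(
            model="gpt-4o-mini",  # placeholder model name
            messages=[{"role": "user", "content": "hello"}],
            stream=True,
        )
        async for chunk in stream:
            print(chunk)

    asyncio.run(demo())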
@@ -124,7 +124,7 @@ class LiteLLMOpenAIMixin(
         self, response: litellm.ModelResponse
     ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
         async def _stream_generator():
-            for chunk in response:
+            async for chunk in response:
                 yield chunk

         async for chunk in convert_openai_chat_completion_stream(
@@ -223,10 +223,10 @@ class LiteLLMOpenAIMixin(
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        response = litellm.embedding(
+        response = await litellm.embedding(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )
-        embeddings = [data["embedding"] for data in response["data"]]
+        embeddings = await [data["embedding"] for data in response["data"]]
         return EmbeddingsResponse(embeddings=embeddings)
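Two hedged notes on this last hunk, since they go beyond what the diff changes: litellm.embedding is the synchronous call and the litellm docs list litellm.aembedding as its awaitable counterpart, and awaiting a plain list comprehension raises a TypeError because a list is not awaitable. A sketch of how the embeddings path might read with those adjustments, reusing the names already shown in the hunk:

    # assumes litellm.aembedding is available, per the litellm async embedding docs
    response = await litellm.aembedding(
        model=model.provider_resource_id,
        input=[interleaved_content_as_str(content) for content in contents],
    )
    # the comprehension is ordinary synchronous code, so no await here
    embeddings = [data["embedding"] for data in response["data"]]
    return EmbeddingsResponse(embeddings=embeddings)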