Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 10:42:39 +00:00)

Commit 66412ab12b (parent 5403582582): convert blocking calls to async
Signed-off-by: Jaideep Rao <jrao@redhat.com>

3 changed files with 23 additions and 18 deletions
Changes to RunpodInferenceAdapter:

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from typing import AsyncGenerator
 
-from openai import OpenAI
+from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -85,24 +85,24 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
             tool_config=tool_config,
         )
 
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
         if stream:
             return self._stream_chat_completion(request, client)
         else:
             return await self._nonstream_chat_completion(request, client)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = client.completions.create(**params)
+        r = await client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
         params = self._get_params(request)
 
         async def _to_async_generator():
-            s = client.completions.create(**params)
+            s = await client.completions.create(**params)
             for chunk in s:
                 yield chunk
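For reference, below is a minimal self-contained sketch of the pattern the Runpod adapter moves to, written outside the class: AsyncOpenAI calls are awaited, and a streamed completion is consumed with async for (an AsyncOpenAI stream is async-iterable). The BASE_URL, API_KEY, and model name are hypothetical stand-ins for the adapter's config values, not values from this repository.

import asyncio

from openai import AsyncOpenAI

# Hypothetical endpoint and credentials, stand-ins for config.url / config.api_token.
BASE_URL = "http://localhost:8000/v1"
API_KEY = "dummy-key"


async def stream_completion(prompt: str):
    # AsyncOpenAI methods are awaitable; with stream=True the awaited result
    # is an async iterator of completion chunks.
    client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
    stream = await client.completions.create(
        model="my-model",  # hypothetical model name
        prompt=prompt,
        stream=True,
    )
    async for chunk in stream:
        yield chunk


async def main():
    async for chunk in stream_completion("Hello"):
        print(chunk.choices[0].text, end="")


if __name__ == "__main__":
    asyncio.run(main())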
Changes to SentenceTransformerEmbeddingMixin:

@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import logging
+import asyncio
 from typing import TYPE_CHECKING, List, Optional
 
 if TYPE_CHECKING:
@@ -37,13 +38,12 @@ class SentenceTransformerEmbeddingMixin:
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
-        embeddings = embedding_model.encode(
-            [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
-        )
+        embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
+        # Execute the synchronous encode method in an executor
+        embeddings = await self._run_in_executor(embedding_model.encode, [interleaved_content_as_str(content) for content in contents], show_progress_bar=False)
         return EmbeddingsResponse(embeddings=embeddings)
 
-    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
+    async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS
 
         loaded_model = EMBEDDING_MODELS.get(model)
@@ -53,6 +53,11 @@ class SentenceTransformerEmbeddingMixin:
         log.info(f"Loading sentence transformer for {model}...")
         from sentence_transformers import SentenceTransformer
 
-        loaded_model = SentenceTransformer(model)
+        # Execute the synchronous SentenceTransformer instantiation in an executor
+        loaded_model = await self._run_in_executor(SentenceTransformer, model)
         EMBEDDING_MODELS[model] = loaded_model
         return loaded_model
+
+    async def _run_in_executor(self, func, *args, **kwargs):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, func, *args, **kwargs)
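The executor helper added above is there to keep blocking sentence-transformers calls off the event loop. Below is a minimal sketch of that offloading pattern with a stand-in blocking function; note that loop.run_in_executor forwards only positional arguments, so keyword arguments such as show_progress_bar are usually bound with functools.partial first (or passed to asyncio.to_thread on Python 3.9+). The function and argument names here are illustrative, not part of the repository.

import asyncio
import functools
import time


def blocking_encode(texts, show_progress_bar=False):
    # Stand-in for a synchronous, heavy call such as SentenceTransformer.encode.
    time.sleep(0.1)
    return [[float(len(t))] for t in texts]


async def encode_async(texts):
    loop = asyncio.get_running_loop()
    # run_in_executor only forwards positional arguments, so keyword arguments
    # are bound with functools.partial before handing the callable to the executor.
    fn = functools.partial(blocking_encode, texts, show_progress_bar=False)
    return await loop.run_in_executor(None, fn)


async def main():
    embeddings = await encode_async(["hello", "world"])
    print(embeddings)


if __name__ == "__main__":
    asyncio.run(main())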
Changes to LiteLLMOpenAIMixin:

@@ -112,9 +112,9 @@ class LiteLLMOpenAIMixin(
 
         params = await self._get_params(request)
         logger.debug(f"params to litellm (openai compat): {params}")
-        # unfortunately, we need to use synchronous litellm.completion here because litellm
-        # caches various httpx.client objects in a non-eventloop aware manner
-        response = litellm.completion(**params)
+        # Litellm seems to have implemented an async completion method
+        # https://docs.litellm.ai/docs/completion/stream#async-completion
+        response = await litellm.acompletion(**params)
         if stream:
             return self._stream_chat_completion(response)
         else:
@@ -124,7 +124,7 @@ class LiteLLMOpenAIMixin(
         self, response: litellm.ModelResponse
     ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
         async def _stream_generator():
-            for chunk in response:
+            async for chunk in response:
                 yield chunk
 
         async for chunk in convert_openai_chat_completion_stream(
@@ -223,10 +223,10 @@ class LiteLLMOpenAIMixin(
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 
-        response = litellm.embedding(
+        response = await litellm.embedding(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )
 
-        embeddings = [data["embedding"] for data in response["data"]]
+        embeddings = await [data["embedding"] for data in response["data"]]
         return EmbeddingsResponse(embeddings=embeddings)
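For context, a minimal sketch of litellm's async entry points that this change leans on, per https://docs.litellm.ai/docs/completion/stream#async-completion: litellm.acompletion is the async counterpart of litellm.completion and litellm.aembedding is the async counterpart of litellm.embedding; with stream=True the awaited result is consumed with async for, while the embedding payload is a plain list and is read with an ordinary, non-awaited list comprehension. The model names below are hypothetical and an API key for the chosen provider is assumed to be configured.

import asyncio

import litellm


async def demo():
    # Async chat completion; with stream=True the awaited result is an async iterator of chunks.
    stream = await litellm.acompletion(
        model="gpt-3.5-turbo",  # hypothetical model name
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
    )
    async for chunk in stream:
        delta = chunk.choices[0].delta
        print(delta.content or "", end="")

    # Async embeddings; response["data"] is a plain list, so the comprehension is not awaited.
    response = await litellm.aembedding(
        model="text-embedding-ada-002",  # hypothetical model name
        input=["hello", "world"],
    )
    embeddings = [item["embedding"] for item in response["data"]]
    print(f"\ngot {len(embeddings)} embeddings")


if __name__ == "__main__":
    asyncio.run(demo())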