Mirror of https://github.com/meta-llama/llama-stack.git (synced 2026-01-02 10:20:00 +00:00)
convert blocking calls to async
Signed-off-by: Jaideep Rao <jrao@redhat.com>
commit 66412ab12b
parent 5403582582
3 changed files with 23 additions and 18 deletions
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 from typing import AsyncGenerator
 
-from openai import OpenAI
+from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
 
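For reference, both clients come from the same openai package and take the same constructor arguments, so the import swap is mechanical; only the call sites change. A minimal sketch of the difference (the base URL, key, and variable names below are placeholders, not values from this repo):

from openai import OpenAI, AsyncOpenAI

# Blocking client: .create() performs the HTTP request synchronously,
# holding the thread (and starving the event loop if called from async code).
sync_client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")

# Async client: same surface, but .create() returns a coroutine, so the
# event loop can service other tasks while the request is in flight.
async_client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")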
@@ -85,24 +85,24 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
             tool_config=tool_config,
         )
 
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
         if stream:
             return self._stream_chat_completion(request, client)
         else:
             return await self._nonstream_chat_completion(request, client)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = client.completions.create(**params)
+        r = await client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
         params = self._get_params(request)
 
         async def _to_async_generator():
-            s = client.completions.create(**params)
+            s = await client.completions.create(**params)
             for chunk in s:
                 yield chunk
 
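Taken together, the changes follow the standard AsyncOpenAI pattern: await the non-streaming call directly, and await the streaming call to obtain the stream object. Note that with AsyncOpenAI, a streaming .create() resolves to an AsyncStream, which is only asynchronously iterable, so the unchanged `for chunk in s:` line presumably also needs to become `async for chunk in s:`. A self-contained sketch of the converted pattern, using a placeholder endpoint and model name (in the adapter these come from self.config and _get_params):

import asyncio

from openai import AsyncOpenAI

async def main() -> None:
    # Placeholder endpoint/credentials; the adapter builds these from
    # self.config.url and self.config.api_token.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")

    # Non-streaming: the coroutine resolves to a Completion object.
    r = await client.completions.create(model="my-model", prompt="Hello", stream=False)
    print(r.choices[0].text)

    # Streaming: the coroutine resolves to an AsyncStream, which must be
    # consumed with `async for`; a plain `for` raises TypeError.
    s = await client.completions.create(model="my-model", prompt="Hello", stream=True)
    async for chunk in s:
        print(chunk.choices[0].text, end="")

asyncio.run(main())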