async call in separate thread

This commit is contained in:
Xi Yan 2024-10-09 13:18:15 -07:00
parent ae43044a57
commit adb768f827
5 changed files with 72 additions and 55 deletions

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.evals import * # noqa: F403
import os
import random
import threading
from pathlib import Path
import lm_eval
@ -19,6 +21,12 @@ from termcolor import cprint
from .config import EleutherEvalsImplConfig # noqa
# https://stackoverflow.com/questions/74703727/how-to-call-async-function-from-sync-funcion-and-get-result-while-a-loop-is-alr
# We will use another thread wih its own event loop to run the async api within sync function
_loop = asyncio.new_event_loop()
_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)
class EleutherEvalsWrapper(LM):
def __init__(
self,
@ -89,8 +97,10 @@ class EleutherEvalsWrapper(LM):
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
res = []
if not _thr.is_alive():
_thr.start()
for req in requests:
response = self.inference_api.chat_completion(
chat_completion_coro_fn = self.inference_api.chat_completion(
model=self.model,
messages=[
{
@ -100,7 +110,8 @@ class EleutherEvalsWrapper(LM):
],
stream=False,
)
print(response)
future = asyncio.run_coroutine_threadsafe(chat_completion_coro_fn, _loop)
response = future.result()
res.append(response.completion_message.content)
return res
@ -119,16 +130,13 @@ class EleutherEvalsAdapter(Evals):
async def run_evals(
self,
model: str,
dataset: str,
task: str,
dataset: Optional[str] = None,
) -> EvaluateResponse:
eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model)
cprint(f"Eleuther Evals: {model} {dataset} {task}", "red")
task = "meta_mmlu_pro_instruct"
eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model)
current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
print(current_dir)
task_manager = TaskManager(
include_path=str(current_dir / "tasks"),