From adb768f8272c6d7a163b9246b233b831bfbe48ba Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 9 Oct 2024 13:18:15 -0700
Subject: [PATCH] async call in separate thread

---
 llama_stack/apis/evals/client.py              | 29 ++++----
 llama_stack/apis/evals/evals.py               |  2 +-
 .../impls/meta_reference/evals/evals.py       |  7 +-
 .../meta_reference/inference/inference.py     | 67 ++++++++++---------
 .../third_party/evals/eleuther/eleuther.py    | 22 ++++--
 5 files changed, 72 insertions(+), 55 deletions(-)

diff --git a/llama_stack/apis/evals/client.py b/llama_stack/apis/evals/client.py
index 4acbff5f6..3c9ba3bca 100644
--- a/llama_stack/apis/evals/client.py
+++ b/llama_stack/apis/evals/client.py
@@ -23,14 +23,16 @@ class EvaluationClient(Evals):
     async def shutdown(self) -> None:
         pass
 
-    async def run_evals(self, model: str, dataset: str, task: str) -> EvaluateResponse:
+    async def run_evals(
+        self, model: str, task: str, dataset: Optional[str] = None
+    ) -> EvaluateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/evals/run",
                 json={
                     "model": model,
-                    "dataset": dataset,
                     "task": task,
+                    "dataset": dataset,
                 },
                 headers={"Content-Type": "application/json"},
                 timeout=3600,
@@ -43,20 +45,19 @@ async def run_main(host: str, port: int):
     client = EvaluationClient(f"http://{host}:{port}")
 
     # CustomDataset
-    response = await client.run_evals(
-        "Llama3.1-8B-Instruct",
-        "mmlu-simple-eval-en",
-        "mmlu",
-    )
-    cprint(f"evaluate response={response}", "green")
-
-    # Eleuther Eval
     # response = await client.run_evals(
-    #     "Llama3.1-8B-Instruct",
-    #     "PLACEHOLDER_DATASET_NAME",
-    #     "mmlu",
+    #     model="Llama3.1-8B-Instruct",
+    #     dataset="mmlu-simple-eval-en",
+    #     task="mmlu",
     # )
-    # cprint(response.metrics["metrics_table"], "red")
+    # cprint(f"evaluate response={response}", "green")
+
+    # Eleuther Eval Task
+    response = await client.run_evals(
+        model="Llama3.1-8B-Instruct",
+        task="meta_mmlu_pro_instruct",
+    )
+    cprint(response.metrics["metrics_table"], "red")
 
 
 def main(host: str, port: int):
diff --git a/llama_stack/apis/evals/evals.py b/llama_stack/apis/evals/evals.py
index bc9215993..3acc9e68b 100644
--- a/llama_stack/apis/evals/evals.py
+++ b/llama_stack/apis/evals/evals.py
@@ -64,8 +64,8 @@ class Evals(Protocol):
     async def run_evals(
         self,
         model: str,
-        dataset: str,
         task: str,
+        dataset: Optional[str] = None,
     ) -> EvaluateResponse: ...
 
     @webmethod(route="/evals/jobs")
diff --git a/llama_stack/providers/impls/meta_reference/evals/evals.py b/llama_stack/providers/impls/meta_reference/evals/evals.py
index 9077a4905..e648e7dad 100644
--- a/llama_stack/providers/impls/meta_reference/evals/evals.py
+++ b/llama_stack/providers/impls/meta_reference/evals/evals.py
@@ -28,14 +28,17 @@ class MetaReferenceEvalsImpl(Evals):
     async def run_evals(
         self,
         model: str,
-        dataset: str,
         task: str,
+        dataset: Optional[str] = None,
     ) -> EvaluateResponse:
         cprint(f"model={model}, dataset={dataset}, task={task}", "red")
+        if not dataset:
+            raise ValueError("dataset must be specified for meta-reference evals")
+
         dataset = DatasetRegistry.get_dataset(dataset)
         dataset.load()
 
-        task_impl = TaskRegistry.get_task(task)(dataset)
+        task_impl = TaskRegistry.get_task(task)(dataset)
         x1 = task_impl.preprocess()
 
         # TODO: replace w/ batch inference & async return eval job
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index d9b3cefd8..bda5e54c1 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -91,47 +91,52 @@ class MetaReferenceInferenceImpl(Inference):
         else:
             return self._nonstream_chat_completion(request)
 
-    def _nonstream_chat_completion(
+    async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = chat_completion_request_to_messages(request)
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)
 
-        tokens = []
-        logprobs = []
-        stop_reason = None
+            tokens = []
+            logprobs = []
+            stop_reason = None
 
-        for token_result in self.generator.chat_completion(
-            messages=messages,
-            temperature=request.sampling_params.temperature,
-            top_p=request.sampling_params.top_p,
-            max_gen_len=request.sampling_params.max_tokens,
-            logprobs=request.logprobs,
-            tool_prompt_format=request.tool_prompt_format,
-        ):
-            tokens.append(token_result.token)
+            for token_result in self.generator.chat_completion(
+                messages=messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)
 
-            if token_result.text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-            elif token_result.text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
 
-            if request.logprobs:
-                assert len(token_result.logprobs) == 1
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
 
-                logprobs.append(
-                    TokenLogProbs(
-                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
                     )
-                )
 
-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
 
-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-        return ChatCompletionResponse(
-            completion_message=message,
-            logprobs=logprobs if request.logprobs else None,
-        )
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            return ChatCompletionResponse(
+                completion_message=message,
+                logprobs=logprobs if request.logprobs else None,
+            )
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py b/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py
index ab27fcaee..102d15020 100644
--- a/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py
+++ b/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py
@@ -4,10 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403
 import os
 import random
+import threading
 from pathlib import Path
 
 import lm_eval
@@ -19,6 +21,12 @@ from termcolor import cprint
 
 from .config import EleutherEvalsImplConfig  # noqa
 
+# https://stackoverflow.com/questions/74703727/how-to-call-async-function-from-sync-funcion-and-get-result-while-a-loop-is-alr
+# We will use another thread with its own event loop to run the async API within the sync function
+_loop = asyncio.new_event_loop()
+_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)
+
+
 class EleutherEvalsWrapper(LM):
     def __init__(
         self,
@@ -89,8 +97,10 @@ class EleutherEvalsWrapper(LM):
 
     def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
         res = []
+        if not _thr.is_alive():
+            _thr.start()
         for req in requests:
-            response = self.inference_api.chat_completion(
+            chat_completion_coro_fn = self.inference_api.chat_completion(
                 model=self.model,
                 messages=[
                     {
@@ -100,7 +110,8 @@ class EleutherEvalsWrapper(LM):
                 ],
                 stream=False,
             )
-            print(response)
+            future = asyncio.run_coroutine_threadsafe(chat_completion_coro_fn, _loop)
+            response = future.result()
             res.append(response.completion_message.content)
         return res
 
@@ -119,16 +130,13 @@ class EleutherEvalsAdapter(Evals):
     async def run_evals(
         self,
         model: str,
-        dataset: str,
         task: str,
+        dataset: Optional[str] = None,
     ) -> EvaluateResponse:
-        eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model)
-
         cprint(f"Eleuther Evals: {model} {dataset} {task}", "red")
-        task = "meta_mmlu_pro_instruct"
 
+        eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model)
         current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
-        print(current_dir)
 
         task_manager = TaskManager(
             include_path=str(current_dir / "tasks"),
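
Note (not part of the patch): the eleuther.py change above bridges lm-eval's synchronous LM.generate_until interface and the async inference API by running a dedicated event loop on a daemon thread and scheduling coroutines onto it with asyncio.run_coroutine_threadsafe. Below is a minimal, self-contained sketch of that pattern; chat_completion and generate_sync are hypothetical stand-ins for inference_api.chat_completion and generate_until, not code from the patch.

import asyncio
import threading

# Dedicated event loop that runs forever on a background daemon thread.
_loop = asyncio.new_event_loop()
_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)


async def chat_completion(prompt: str) -> str:
    # Stand-in for an async inference call such as inference_api.chat_completion().
    await asyncio.sleep(0.1)
    return f"echo: {prompt}"


def generate_sync(prompt: str) -> str:
    # Start the loop thread lazily, as generate_until() does in the patch.
    if not _thr.is_alive():
        _thr.start()
    # Schedule the coroutine on the background loop and block until it completes.
    future = asyncio.run_coroutine_threadsafe(chat_completion(prompt), _loop)
    return future.result()


if __name__ == "__main__":
    print(generate_sync("hello"))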