async call in separate thread

Xi Yan 2024-10-09 13:18:15 -07:00
parent ae43044a57
commit adb768f827
5 changed files with 72 additions and 55 deletions

View file

@@ -23,14 +23,16 @@ class EvaluationClient(Evals):
     async def shutdown(self) -> None:
         pass
 
-    async def run_evals(self, model: str, dataset: str, task: str) -> EvaluateResponse:
+    async def run_evals(
+        self, model: str, task: str, dataset: Optional[str] = None
+    ) -> EvaluateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/evals/run",
                 json={
                     "model": model,
-                    "dataset": dataset,
                     "task": task,
+                    "dataset": dataset,
                 },
                 headers={"Content-Type": "application/json"},
                 timeout=3600,
@@ -43,20 +45,19 @@ async def run_main(host: str, port: int):
     client = EvaluationClient(f"http://{host}:{port}")
 
     # CustomDataset
-    response = await client.run_evals(
-        "Llama3.1-8B-Instruct",
-        "mmlu-simple-eval-en",
-        "mmlu",
-    )
-    cprint(f"evaluate response={response}", "green")
-
-    # Eleuther Eval
     # response = await client.run_evals(
-    #     "Llama3.1-8B-Instruct",
-    #     "PLACEHOLDER_DATASET_NAME",
-    #     "mmlu",
+    #     model="Llama3.1-8B-Instruct",
+    #     dataset="mmlu-simple-eval-en",
+    #     task="mmlu",
     # )
-    # cprint(response.metrics["metrics_table"], "red")
+    # cprint(f"evaluate response={response}", "green")
+
+    # Eleuther Eval Task
+    response = await client.run_evals(
+        model="Llama3.1-8B-Instruct",
+        task="meta_mmlu_pro_instruct",
+    )
+    cprint(response.metrics["metrics_table"], "red")
 
 
 def main(host: str, port: int):
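Note: with the new signature, run_evals is called with keyword arguments and dataset becomes optional; the meta-reference provider still requires it (see the ValueError guard further down), while Eleuther tasks resolve their own data. A minimal usage sketch, assuming a server at localhost:5000 (host and port are placeholders):

    client = EvaluationClient("http://localhost:5000")

    # Meta-reference evals still need a registered dataset:
    response = await client.run_evals(
        model="Llama3.1-8B-Instruct",
        task="mmlu",
        dataset="mmlu-simple-eval-en",
    )

    # Eleuther tasks fetch their own data, so dataset may be omitted:
    response = await client.run_evals(
        model="Llama3.1-8B-Instruct",
        task="meta_mmlu_pro_instruct",
    )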

View file

@@ -64,8 +64,8 @@ class Evals(Protocol):
     async def run_evals(
         self,
         model: str,
-        dataset: str,
         task: str,
+        dataset: Optional[str] = None,
     ) -> EvaluateResponse: ...
 
     @webmethod(route="/evals/jobs")

View file

@@ -28,14 +28,17 @@ class MetaReferenceEvalsImpl(Evals):
     async def run_evals(
         self,
         model: str,
-        dataset: str,
         task: str,
+        dataset: Optional[str] = None,
     ) -> EvaluateResponse:
         cprint(f"model={model}, dataset={dataset}, task={task}", "red")
+        if not dataset:
+            raise ValueError("dataset must be specified for meta-reference evals")
+
         dataset = DatasetRegistry.get_dataset(dataset)
         dataset.load()
-        task_impl = TaskRegistry.get_task(task)(dataset)
 
+        task_impl = TaskRegistry.get_task(task)(dataset)
         x1 = task_impl.preprocess()
 
         # TODO: replace w/ batch inference & async return eval job

View file

@@ -91,47 +91,52 @@ class MetaReferenceInferenceImpl(Inference):
         else:
             return self._nonstream_chat_completion(request)
 
-    def _nonstream_chat_completion(
+    async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = chat_completion_request_to_messages(request)
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)
 
-        tokens = []
-        logprobs = []
+            tokens = []
+            logprobs = []
 
-        stop_reason = None
+            stop_reason = None
 
-        for token_result in self.generator.chat_completion(
-            messages=messages,
-            temperature=request.sampling_params.temperature,
-            top_p=request.sampling_params.top_p,
-            max_gen_len=request.sampling_params.max_tokens,
-            logprobs=request.logprobs,
-            tool_prompt_format=request.tool_prompt_format,
-        ):
-            tokens.append(token_result.token)
+            for token_result in self.generator.chat_completion(
+                messages=messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)
 
-            if token_result.text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-            elif token_result.text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
 
-            if request.logprobs:
-                assert len(token_result.logprobs) == 1
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
 
-                logprobs.append(
-                    TokenLogProbs(
-                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
                     )
-                )
 
-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
 
-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-        return ChatCompletionResponse(
-            completion_message=message,
-            logprobs=logprobs if request.logprobs else None,
-        )
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            return ChatCompletionResponse(
+                completion_message=message,
+                logprobs=logprobs if request.logprobs else None,
+            )
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
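The hunk above makes _nonstream_chat_completion a coroutine and gates the whole generation loop behind "async with SEMAPHORE", so concurrent requests queue up instead of interleaving on the single in-process generator. SEMAPHORE itself is defined outside this hunk; a self-contained sketch of the pattern, where the single permit and the sleep are stand-in assumptions:

    import asyncio

    # Assumption: the diff does not show SEMAPHORE's definition; one permit
    # serializes all generation on the single shared model.
    SEMAPHORE = asyncio.Semaphore(1)

    async def nonstream_chat_completion_sketch(request_id: int) -> str:
        async with SEMAPHORE:  # only one caller generates at a time
            await asyncio.sleep(0.1)  # stand-in for the token generation loop
            return f"response-{request_id}"

    async def main() -> None:
        # Concurrent callers are queued by the semaphore rather than interleaved.
        results = await asyncio.gather(
            *(nonstream_chat_completion_sketch(i) for i in range(3))
        )
        print(results)

    asyncio.run(main())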

View file

@@ -4,10 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403
 import os
 import random
+import threading
 from pathlib import Path
 
 import lm_eval
@@ -19,6 +21,12 @@ from termcolor import cprint
 
 from .config import EleutherEvalsImplConfig  # noqa
 
+# https://stackoverflow.com/questions/74703727/how-to-call-async-function-from-sync-funcion-and-get-result-while-a-loop-is-alr
+# We will use another thread with its own event loop to run the async API within a sync function.
+_loop = asyncio.new_event_loop()
+_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)
+
+
 class EleutherEvalsWrapper(LM):
     def __init__(
         self,
@@ -89,8 +97,10 @@ class EleutherEvalsWrapper(LM):
     def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
         res = []
+        if not _thr.is_alive():
+            _thr.start()
         for req in requests:
-            response = self.inference_api.chat_completion(
+            chat_completion_coro_fn = self.inference_api.chat_completion(
                 model=self.model,
                 messages=[
                     {
@@ -100,7 +110,8 @@ class EleutherEvalsWrapper(LM):
                 ],
                 stream=False,
             )
-            print(response)
+            future = asyncio.run_coroutine_threadsafe(chat_completion_coro_fn, _loop)
+            response = future.result()
             res.append(response.completion_message.content)
 
         return res
@@ -119,16 +130,13 @@ class EleutherEvalsAdapter(Evals):
     async def run_evals(
         self,
         model: str,
-        dataset: str,
         task: str,
+        dataset: Optional[str] = None,
     ) -> EvaluateResponse:
-        eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model)
         cprint(f"Eleuther Evals: {model} {dataset} {task}", "red")
-        task = "meta_mmlu_pro_instruct"
+        eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model)
         current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
-        print(current_dir)
         task_manager = TaskManager(
             include_path=str(current_dir / "tasks"),