mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 07:14:20 +00:00)

commit adb768f827 (parent ae43044a57): async call in separate thread

5 changed files with 72 additions and 55 deletions
@@ -23,14 +23,16 @@ class EvaluationClient(Evals):
     async def shutdown(self) -> None:
         pass

-    async def run_evals(self, model: str, dataset: str, task: str) -> EvaluateResponse:
+    async def run_evals(
+        self, model: str, task: str, dataset: Optional[str] = None
+    ) -> EvaluateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/evals/run",
                 json={
                     "model": model,
-                    "dataset": dataset,
                     "task": task,
+                    "dataset": dataset,
                 },
                 headers={"Content-Type": "application/json"},
                 timeout=3600,
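With `dataset` now an optional keyword argument, the request body built above simply carries `null` when no dataset is given. A minimal sketch of the two payload shapes, with values taken from the call sites below and serialization shown via the standard `json` module (an assumption about how the body is encoded):

import json

# Dataset supplied (CustomDataset path):
print(json.dumps({"model": "Llama3.1-8B-Instruct", "task": "mmlu", "dataset": "mmlu-simple-eval-en"}))
# -> {"model": "Llama3.1-8B-Instruct", "task": "mmlu", "dataset": "mmlu-simple-eval-en"}

# Dataset omitted (Eleuther task path): the key is still sent, as null.
print(json.dumps({"model": "Llama3.1-8B-Instruct", "task": "meta_mmlu_pro_instruct", "dataset": None}))
# -> {"model": "Llama3.1-8B-Instruct", "task": "meta_mmlu_pro_instruct", "dataset": null}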
@@ -43,20 +45,19 @@ async def run_main(host: str, port: int):
     client = EvaluationClient(f"http://{host}:{port}")

     # CustomDataset
-    response = await client.run_evals(
-        "Llama3.1-8B-Instruct",
-        "mmlu-simple-eval-en",
-        "mmlu",
-    )
-    cprint(f"evaluate response={response}", "green")
-
-    # Eleuther Eval
     # response = await client.run_evals(
-    #     "Llama3.1-8B-Instruct",
-    #     "PLACEHOLDER_DATASET_NAME",
-    #     "mmlu",
+    #     model="Llama3.1-8B-Instruct",
+    #     dataset="mmlu-simple-eval-en",
+    #     task="mmlu",
     # )
-    # cprint(response.metrics["metrics_table"], "red")
+    # cprint(f"evaluate response={response}", "green")
+
+    # Eleuther Eval Task
+    response = await client.run_evals(
+        model="Llama3.1-8B-Instruct",
+        task="meta_mmlu_pro_instruct",
+    )
+    cprint(response.metrics["metrics_table"], "red")


 def main(host: str, port: int):
@@ -64,8 +64,8 @@ class Evals(Protocol):
     async def run_evals(
         self,
         model: str,
-        dataset: str,
         task: str,
+        dataset: Optional[str] = None,
     ) -> EvaluateResponse: ...

     @webmethod(route="/evals/jobs")
@@ -28,14 +28,17 @@ class MetaReferenceEvalsImpl(Evals):
     async def run_evals(
         self,
         model: str,
-        dataset: str,
         task: str,
+        dataset: Optional[str] = None,
     ) -> EvaluateResponse:
         cprint(f"model={model}, dataset={dataset}, task={task}", "red")
+        if not dataset:
+            raise ValueError("dataset must be specified for meta-reference evals")

         dataset = DatasetRegistry.get_dataset(dataset)
         dataset.load()
-        task_impl = TaskRegistry.get_task(task)(dataset)
+
+        task_impl = TaskRegistry.get_task(task)(dataset)
         x1 = task_impl.preprocess()

         # TODO: replace w/ batch inference & async return eval job
@@ -91,47 +91,52 @@ class MetaReferenceInferenceImpl(Inference):
         else:
             return self._nonstream_chat_completion(request)

-    def _nonstream_chat_completion(
+    async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = chat_completion_request_to_messages(request)
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)

             tokens = []
             logprobs = []
             stop_reason = None

             for token_result in self.generator.chat_completion(
                 messages=messages,
                 temperature=request.sampling_params.temperature,
                 top_p=request.sampling_params.top_p,
                 max_gen_len=request.sampling_params.max_tokens,
                 logprobs=request.logprobs,
                 tool_prompt_format=request.tool_prompt_format,
             ):
                 tokens.append(token_result.token)

                 if token_result.text == "<|eot_id|>":
                     stop_reason = StopReason.end_of_turn
                 elif token_result.text == "<|eom_id|>":
                     stop_reason = StopReason.end_of_message

                 if request.logprobs:
                     assert len(token_result.logprobs) == 1

                     logprobs.append(
                         TokenLogProbs(
-                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
-                    )
-                )
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )

             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens

-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-        return ChatCompletionResponse(
-            completion_message=message,
-            logprobs=logprobs if request.logprobs else None,
-        )
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            return ChatCompletionResponse(
+                completion_message=message,
+                logprobs=logprobs if request.logprobs else None,
+            )

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
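`SEMAPHORE` itself is defined outside this hunk; the point of wrapping the body in `async with SEMAPHORE:` is to let the now-async method be awaited concurrently while still feeding the single local generator one request at a time. A minimal sketch of that pattern, assuming a module-level `asyncio.Semaphore(1)` (the name and limit are assumptions, not visible in this diff):

import asyncio

SEMAPHORE = asyncio.Semaphore(1)  # assumed definition; one request at a time

async def guarded_generate(prompt: str) -> str:
    # Only one coroutine enters this block at a time, so the stateful,
    # non-reentrant generator is never driven by two requests at once.
    async with SEMAPHORE:
        await asyncio.sleep(0.1)  # stand-in for token generation
        return f"completed: {prompt}"

async def main() -> None:
    # Both calls are awaited concurrently but execute one after the other.
    results = await asyncio.gather(
        guarded_generate("first"), guarded_generate("second")
    )
    print(results)

if __name__ == "__main__":
    asyncio.run(main())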
@@ -4,10 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403
 import os
 import random
+import threading
 from pathlib import Path

 import lm_eval
@@ -19,6 +21,12 @@ from termcolor import cprint
 from .config import EleutherEvalsImplConfig  # noqa


+# https://stackoverflow.com/questions/74703727/how-to-call-async-function-from-sync-funcion-and-get-result-while-a-loop-is-alr
+# We will use another thread with its own event loop to run the async api within a sync function
+_loop = asyncio.new_event_loop()
+_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)
+
+
 class EleutherEvalsWrapper(LM):
     def __init__(
         self,
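The new module-level `_loop` and `_thr` implement the pattern from the linked Stack Overflow answer: a daemon thread runs its own event loop so that lm_eval's synchronous callbacks can hand coroutines to the async inference API and block for their results. A self-contained sketch of the pattern (the coroutine and caller names here are illustrative, not from this file):

import asyncio
import threading

_loop = asyncio.new_event_loop()
_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)

async def fake_chat_completion(message: str) -> str:
    # Stand-in for an async API such as inference_api.chat_completion(...).
    await asyncio.sleep(0.1)
    return f"echo: {message}"

def sync_caller(message: str) -> str:
    # Called from synchronous code (e.g. lm_eval's generate_until).
    if not _thr.is_alive():
        _thr.start()
    # Schedule the coroutine on the background loop and block for its result.
    future = asyncio.run_coroutine_threadsafe(fake_chat_completion(message), _loop)
    return future.result()

if __name__ == "__main__":
    print(sync_caller("hello"))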
@@ -89,8 +97,10 @@ class EleutherEvalsWrapper(LM):

     def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
         res = []
+        if not _thr.is_alive():
+            _thr.start()
         for req in requests:
-            response = self.inference_api.chat_completion(
+            chat_completion_coro_fn = self.inference_api.chat_completion(
                 model=self.model,
                 messages=[
                     {
@@ -100,7 +110,8 @@ class EleutherEvalsWrapper(LM):
                 ],
                 stream=False,
             )
-            print(response)
+            future = asyncio.run_coroutine_threadsafe(chat_completion_coro_fn, _loop)
+            response = future.result()
             res.append(response.completion_message.content)

         return res
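`asyncio.run_coroutine_threadsafe` returns a `concurrent.futures.Future`, so `future.result()` blocks the synchronous lm_eval thread while the chat completion runs on the background loop. If a bound on the wait were wanted, `result()` also accepts a timeout; a hedged variant of the loop body above (not part of this commit):

import concurrent.futures

# Continuing from `future` in the loop body above: wait at most 10 minutes.
try:
    response = future.result(timeout=600)
except concurrent.futures.TimeoutError:
    future.cancel()  # requests cancellation of the task on the background loop
    raise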
@@ -119,16 +130,13 @@ class EleutherEvalsAdapter(Evals):
     async def run_evals(
         self,
         model: str,
-        dataset: str,
         task: str,
+        dataset: Optional[str] = None,
     ) -> EvaluateResponse:
-        eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model)
-
         cprint(f"Eleuther Evals: {model} {dataset} {task}", "red")

-        task = "meta_mmlu_pro_instruct"
+        eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model)
         current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
-        print(current_dir)

         task_manager = TaskManager(
             include_path=str(current_dir / "tasks"),
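The hunk is cut off before the evaluation call itself. For orientation, lm-eval-harness is normally driven by passing the `LM` wrapper and the `TaskManager` to `lm_eval.simple_evaluate`; a hedged sketch of how such an adapter typically continues (this is the harness's usual API, not the verbatim remainder of this file):

import lm_eval
from lm_eval.tasks import TaskManager

def run_harness(eluther_wrapper, task: str, include_path: str) -> dict:
    # Make the custom task definitions under `include_path` visible to the harness.
    task_manager = TaskManager(include_path=include_path)
    # Run the requested task against the wrapped inference API and return
    # the per-task metrics computed by lm_eval.
    output = lm_eval.simple_evaluate(
        model=eluther_wrapper,
        tasks=[task],
        task_manager=task_manager,
    )
    return output["results"]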