Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 09:12:37 +00:00)
async call in separate thread

commit adb768f827 (parent ae43044a57)
5 changed files with 72 additions and 55 deletions
@@ -28,14 +28,17 @@ class MetaReferenceEvalsImpl(Evals):
     async def run_evals(
         self,
         model: str,
-        dataset: str,
         task: str,
+        dataset: Optional[str] = None,
     ) -> EvaluateResponse:
         cprint(f"model={model}, dataset={dataset}, task={task}", "red")
+        if not dataset:
+            raise ValueError("dataset must be specified for mete-reference evals")

         dataset = DatasetRegistry.get_dataset(dataset)
         dataset.load()
-        task_impl = TaskRegistry.get_task(task)(dataset)
+
+        task_impl = TaskRegistry.get_task(task)(dataset)
+        x1 = task_impl.preprocess()

         # TODO: replace w/ batch inference & async return eval job
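
Note on the evals hunk above: it leans on a registry pattern, where DatasetRegistry.get_dataset resolves a dataset by name and TaskRegistry.get_task returns a task class that is instantiated with that dataset and asked to preprocess() its rows before inference. With the new signature, dataset is keyword-optional and ordered after task, so it is validated at runtime rather than by the signature. A minimal, self-contained sketch of the registry pattern follows; everything beyond the names visible in the diff (the BaseTask class, the register helper, the example rows) is an assumption for illustration, not the repo's actual code.

    from typing import Dict, List, Type

    # Sketch of the registry pattern used above; not the repo's real classes.
    class BaseTask:
        def __init__(self, dataset: List[dict]):
            self.dataset = dataset

        def preprocess(self) -> List[str]:
            # Turn raw rows into prompts for inference (assumed row shape).
            return [row["question"] for row in self.dataset]


    class TaskRegistry:
        _tasks: Dict[str, Type[BaseTask]] = {}

        @classmethod
        def register(cls, name: str, task_cls: Type[BaseTask]) -> None:
            cls._tasks[name] = task_cls

        @classmethod
        def get_task(cls, name: str) -> Type[BaseTask]:
            if name not in cls._tasks:
                raise ValueError(f"unknown task: {name}")
            return cls._tasks[name]


    # Usage mirroring the diff: look up the task class, bind it to a dataset,
    # then preprocess, as run_evals does before calling inference.
    TaskRegistry.register("mmlu", BaseTask)
    task_impl = TaskRegistry.get_task("mmlu")([{"question": "2 + 2 = ?"}])
    x1 = task_impl.preprocess()
    print(x1)  # ['2 + 2 = ?']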
@@ -91,47 +91,52 @@ class MetaReferenceInferenceImpl(Inference):
         else:
             return self._nonstream_chat_completion(request)

-    def _nonstream_chat_completion(
+    async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = chat_completion_request_to_messages(request)
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)

-        tokens = []
-        logprobs = []
-        stop_reason = None
+            tokens = []
+            logprobs = []
+            stop_reason = None

-        for token_result in self.generator.chat_completion(
-            messages=messages,
-            temperature=request.sampling_params.temperature,
-            top_p=request.sampling_params.top_p,
-            max_gen_len=request.sampling_params.max_tokens,
-            logprobs=request.logprobs,
-            tool_prompt_format=request.tool_prompt_format,
-        ):
-            tokens.append(token_result.token)
+            for token_result in self.generator.chat_completion(
+                messages=messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)

-            if token_result.text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-            elif token_result.text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message

-            if request.logprobs:
-                assert len(token_result.logprobs) == 1
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1

-                logprobs.append(
-                    TokenLogProbs(
-                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
-                    )
-                )
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )

-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens

-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-        return ChatCompletionResponse(
-            completion_message=message,
-            logprobs=logprobs if request.logprobs else None,
-        )
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            return ChatCompletionResponse(
+                completion_message=message,
+                logprobs=logprobs if request.logprobs else None,
+            )

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
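
Note on the inference hunk above: the visible change turns _nonstream_chat_completion into a coroutine and serializes its body behind SEMAPHORE; per the commit title, the blocking generation is also run in a separate thread, presumably wired up in one of the other changed files. A rough sketch of that overall shape, assuming SEMAPHORE is an asyncio.Semaphore and using asyncio.to_thread for the offload (the helper names and semaphore size are illustrative, not taken from the repo):

    import asyncio
    import time

    # Mirrors the module-level SEMAPHORE in the diff (size assumed): only one
    # request may drive the single local generator at a time.
    SEMAPHORE = asyncio.Semaphore(1)


    def blocking_generate(prompt: str) -> str:
        # Stand-in for the synchronous, GPU-bound token generation loop.
        time.sleep(0.1)
        return f"echo: {prompt}"


    async def nonstream_completion(prompt: str) -> str:
        # Serialize access to the model, then run the blocking work in a worker
        # thread so the event loop stays free to serve other requests.
        async with SEMAPHORE:
            return await asyncio.to_thread(blocking_generate, prompt)


    async def main() -> None:
        # Three concurrent callers; generation runs one at a time, off the loop.
        results = await asyncio.gather(
            *(nonstream_completion(f"q{i}") for i in range(3))
        )
        print(results)


    if __name__ == "__main__":
        asyncio.run(main())

Keeping the offload behind a single semaphore means requests queue rather than contend for the one local model, while the event loop can still make progress on streaming and other endpoints.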