mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
feat: update eval runner to use openai endpoints
note: this breaks backward compatibility if eval setup used sampling params' repetition_penalty or strategy
This commit is contained in:
parent
65f7b81e98
commit
5ef2291a51
6 changed files with 300 additions and 11 deletions
|
@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
|
||||
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
|
@ -159,31 +159,40 @@ class MetaReferenceEvalImpl(
|
|||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
|
||||
sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
|
||||
|
||||
generations = []
|
||||
for x in tqdm(input_rows):
|
||||
if ColumnName.completion_input.value in x:
|
||||
if candidate.sampling_params.stop:
|
||||
sampling_params["stop"] = candidate.sampling_params.stop
|
||||
|
||||
input_content = json.loads(x[ColumnName.completion_input.value])
|
||||
response = await self.inference_api.completion(
|
||||
response = await self.inference_api.openai_completion(
|
||||
model=candidate.model,
|
||||
content=input_content,
|
||||
sampling_params=candidate.sampling_params,
|
||||
prompt=input_content,
|
||||
**sampling_params,
|
||||
)
|
||||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
|
||||
input_messages = [
|
||||
OpenAIUserMessageParam(**x) for x in chat_completion_input_json if x["role"] == "user"
|
||||
]
|
||||
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
|
||||
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=candidate.model,
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
model=candidate.model,
|
||||
messages=messages,
|
||||
sampling_params=candidate.sampling_params,
|
||||
**sampling_params,
|
||||
)
|
||||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
|
||||
else:
|
||||
raise ValueError("Invalid input row")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue