address comments

Xi Yan 2024-10-25 11:08:40 -07:00
parent 6b0baa6d53
commit 52fe165db8
2 changed files with 50 additions and 26 deletions


@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from enum import Enum
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
@@ -16,6 +17,13 @@ from llama_stack.apis.scoring import Scoring
 from .config import MetaReferenceEvalConfig
 
 
+class ColumnName(Enum):
+    expected_answer = "expected_answer"
+    chat_completion_input = "chat_completion_input"
+    completion_input = "completion_input"
+    generated_answer = "generated_answer"
+
+
 class MetaReferenceEvalImpl(Eval):
     def __init__(
         self,
@@ -41,18 +49,16 @@ class MetaReferenceEvalImpl(Eval):
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
         if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
-            )
+            raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
 
         expected_schemas = [
             {
-                "expected_answer": StringType(),
-                "chat_completion_input": ChatCompletionInputType(),
+                ColumnName.expected_answer.value: StringType(),
+                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
             },
             {
-                "expected_answer": StringType(),
-                "chat_completion_input": CompletionInputType(),
+                ColumnName.expected_answer.value: StringType(),
+                ColumnName.completion_input.value: CompletionInputType(),
             },
         ]
@@ -94,27 +100,43 @@ class MetaReferenceEvalImpl(Eval):
             raise NotImplementedError(
                 "Evaluation with generation has not been implemented for agents"
             )
+        assert (
+            candidate.sampling_params.max_tokens is not None
+        ), "SamplingParams.max_tokens must be provided"
 
         generations = []
         for x in input_rows:
-            if "completion_input" in x:
-                raise NotImplementedError(
-                    "Evaluation with completion API has not been implemented"
+            if ColumnName.completion_input.value in x:
+                input_content = eval(str(x[ColumnName.completion_input.value]))
+                response = await self.inference_api.completion(
+                    model=candidate.model,
+                    content=input_content,
+                    sampling_params=candidate.sampling_params,
                 )
-            input_messages = eval(str(x["chat_completion_input"]))
-            input_messages = [UserMessage(**x) for x in input_messages]
-            messages = []
-            if candidate.system_message:
-                messages.append(candidate.system_message)
-            messages += input_messages
-            response = await self.inference_api.chat_completion(
-                model=candidate.model,
-                messages=messages,
-                sampling_params=candidate.sampling_params,
-            )
-            generations.append(
-                {"generated_answer": response.completion_message.content}
-            )
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
+            elif ColumnName.chat_completion_input.value in x:
+                input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
+                input_messages = [UserMessage(**x) for x in input_messages]
+                messages = []
+                if candidate.system_message:
+                    messages.append(candidate.system_message)
+                messages += input_messages
+                response = await self.inference_api.chat_completion(
+                    model=candidate.model,
+                    messages=messages,
+                    sampling_params=candidate.sampling_params,
+                )
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
+            else:
+                raise ValueError("Invalid input row")
 
         # scoring with generated_answer
         score_input_rows = [
@@ -132,6 +154,8 @@ class MetaReferenceEvalImpl(Eval):
         if job_id in self.jobs:
             return JobStatus.completed
 
         return None
 
+    async def job_cancel(self, job_id: str) -> None:
+        raise NotImplementedError("Job cancel is not implemented yet")
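
Note on the change above: the generation loop now dispatches on which input column each dataset row carries, using the new ColumnName enum. The sketch below restates that dispatch outside the provider so it can be read (and run) standalone; the enum is repeated for self-containment, and the example rows and the pick_branch helper are illustrative, not part of this commit.

from enum import Enum


class ColumnName(Enum):
    expected_answer = "expected_answer"
    chat_completion_input = "chat_completion_input"
    completion_input = "completion_input"
    generated_answer = "generated_answer"


# Illustrative rows matching the two schemas accepted by
# validate_eval_input_dataset_schema (the values are made up).
completion_row = {"expected_answer": "4", "completion_input": "What is 2 + 2?"}
chat_row = {
    "expected_answer": "4",
    "chat_completion_input": '[{"role": "user", "content": "What is 2 + 2?"}]',
}


def pick_branch(row: dict) -> str:
    # Mirrors the if / elif / else added to the generation loop above.
    if ColumnName.completion_input.value in row:
        return "completion"
    elif ColumnName.chat_completion_input.value in row:
        return "chat_completion"
    raise ValueError("Invalid input row")


assert pick_branch(completion_row) == "completion"
assert pick_branch(chat_row) == "chat_completion"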


@@ -62,7 +62,7 @@ async def test_eval(eval_settings):
     response = await eval_impl.evaluate_batch(
         dataset_id=response[0].identifier,
         candidate=ModelCandidate(
-            model="Llama3.1-8B-Instruct",
+            model="Llama3.2-1B-Instruct",
             sampling_params=SamplingParams(),
         ),
         scoring_functions=["subset_of"],
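
The test now points the ModelCandidate at Llama3.2-1B-Instruct. As a hedged sketch of how a caller might exercise the job methods touched in the first file, assuming evaluate_batch returns a job object exposing a job_id attribute and that job_status takes that id (neither type is shown in this diff):

async def check_job(eval_impl, job) -> None:
    # Per the hunk in the first file, job_status returns JobStatus.completed
    # for known job ids and None otherwise; job.job_id is an assumed attribute.
    status = await eval_impl.job_status(job.job_id)
    assert status is not None

    # job_cancel is still a stub in this commit and raises NotImplementedError.
    try:
        await eval_impl.job_cancel(job.job_id)
    except NotImplementedError:
        pass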