From 52fe165db8f467be828bfd2b4425cab640494edf Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Fri, 25 Oct 2024 11:08:40 -0700
Subject: [PATCH] address comments

---
 .../impls/meta_reference/eval/eval.py         | 74 ++++++++++++-------
 llama_stack/providers/tests/eval/test_eval.py |  2 +-
 2 files changed, 50 insertions(+), 26 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/eval/eval.py b/llama_stack/providers/impls/meta_reference/eval/eval.py
index daa17a89e..d675e40eb 100644
--- a/llama_stack/providers/impls/meta_reference/eval/eval.py
+++ b/llama_stack/providers/impls/meta_reference/eval/eval.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from enum import Enum
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
@@ -16,6 +17,13 @@ from llama_stack.apis.scoring import Scoring
 from .config import MetaReferenceEvalConfig
 
 
+class ColumnName(Enum):
+    expected_answer = "expected_answer"
+    chat_completion_input = "chat_completion_input"
+    completion_input = "completion_input"
+    generated_answer = "generated_answer"
+
+
 class MetaReferenceEvalImpl(Eval):
     def __init__(
         self,
@@ -41,18 +49,16 @@ class MetaReferenceEvalImpl(Eval):
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
         if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
-            )
+            raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
 
         expected_schemas = [
             {
-                "expected_answer": StringType(),
-                "chat_completion_input": ChatCompletionInputType(),
+                ColumnName.expected_answer.value: StringType(),
+                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
             },
             {
-                "expected_answer": StringType(),
-                "chat_completion_input": CompletionInputType(),
+                ColumnName.expected_answer.value: StringType(),
+                ColumnName.completion_input.value: CompletionInputType(),
             },
         ]
 
@@ -94,27 +100,43 @@ class MetaReferenceEvalImpl(Eval):
             raise NotImplementedError(
                 "Evaluation with generation has not been implemented for agents"
             )
+        assert (
+            candidate.sampling_params.max_tokens is not None
+        ), "SamplingParams.max_tokens must be provided"
+
         generations = []
         for x in input_rows:
-            if "completion_input" in x:
-                raise NotImplementedError(
-                    "Evaluation with completion API has not been implemented"
+            if ColumnName.completion_input.value in x:
+                input_content = eval(str(x[ColumnName.completion_input.value]))
+                response = await self.inference_api.completion(
+                    model=candidate.model,
+                    content=input_content,
+                    sampling_params=candidate.sampling_params,
                 )
-
-            input_messages = eval(str(x["chat_completion_input"]))
-            input_messages = [UserMessage(**x) for x in input_messages]
-            messages = []
-            if candidate.system_message:
-                messages.append(candidate.system_message)
-            messages += input_messages
-            response = await self.inference_api.chat_completion(
-                model=candidate.model,
-                messages=messages,
-                sampling_params=candidate.sampling_params,
-            )
-            generations.append(
-                {"generated_answer": response.completion_message.content}
-            )
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
+            elif ColumnName.chat_completion_input.value in x:
+                input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
+                input_messages = [UserMessage(**x) for x in input_messages]
+                messages = []
+                if candidate.system_message:
+                    messages.append(candidate.system_message)
+                messages += input_messages
+                response = await self.inference_api.chat_completion(
+                    model=candidate.model,
+                    messages=messages,
+                    sampling_params=candidate.sampling_params,
+                )
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
+            else:
+                raise ValueError("Invalid input row")
 
         # scoring with generated_answer
         score_input_rows = [
@@ -132,6 +154,8 @@ class MetaReferenceEvalImpl(Eval):
         if job_id in self.jobs:
             return JobStatus.completed
 
+        return None
+
     async def job_cancel(self, job_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 4632cdd96..6b0d99a22 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -62,7 +62,7 @@ async def test_eval(eval_settings):
     response = await eval_impl.evaluate_batch(
         dataset_id=response[0].identifier,
         candidate=ModelCandidate(
-            model="Llama3.1-8B-Instruct",
+            model="Llama3.2-1B-Instruct",
             sampling_params=SamplingParams(),
         ),
         scoring_functions=["subset_of"],
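
Note (illustrative sketch, not part of the commit): the new generation loop
dispatches on which input column a row carries and recovers the serialized
input with eval(). Below is a minimal standalone sketch of that per-row
dispatch, using ast.literal_eval as a safer stand-in for eval() (it accepts
only Python literals, not arbitrary code); the row dicts and helper name are
hypothetical, while the column strings mirror the ColumnName enum added above.

    import ast

    # Column names mirroring the ColumnName enum introduced in eval.py.
    CHAT_COMPLETION_INPUT = "chat_completion_input"
    COMPLETION_INPUT = "completion_input"

    def classify_row(row: dict):
        # Dispatch on the input column present in the row, as the patched
        # loop does, and parse the stored string back into Python data.
        if COMPLETION_INPUT in row:
            return "completion", ast.literal_eval(str(row[COMPLETION_INPUT]))
        elif CHAT_COMPLETION_INPUT in row:
            return "chat_completion", ast.literal_eval(str(row[CHAT_COMPLETION_INPUT]))
        else:
            raise ValueError("Invalid input row")

    # Hypothetical row matching the chat-completion schema validated above.
    row = {
        "expected_answer": "4",
        CHAT_COMPLETION_INPUT: '[{"role": "user", "content": "What is 2 + 2?"}]',
    }
    kind, messages = classify_row(row)
    print(kind, messages)  # chat_completion [{'role': 'user', 'content': 'What is 2 + 2?'}]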