From 29e48cc5c100962b6c624616a529f61cecd955cb Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 24 Oct 2024 16:11:25 -0700 Subject: [PATCH] chatcompletion & completion input type validation --- llama_stack/apis/common/type_system.py | 18 ++++++++++ .../impls/meta_reference/eval/eval.py | 36 ++++++++++--------- .../tests/datasetio/test_dataset.csv | 12 +++---- .../tests/datasetio/test_datasetio.py | 19 ++++++---- llama_stack/providers/tests/eval/test_eval.py | 2 +- 5 files changed, 57 insertions(+), 30 deletions(-) diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py index 35a26e9ef..5f94e85f3 100644 --- a/llama_stack/apis/common/type_system.py +++ b/llama_stack/apis/common/type_system.py @@ -46,6 +46,21 @@ class CustomType(BaseModel): validator_class: str +class ChatCompletionInputType(BaseModel): + # expects List[Message] for messages + type: Literal["chat_completion_input"] = "chat_completion_input" + + +class CompletionInputType(BaseModel): + # expects InterleavedTextMedia for content + type: Literal["completion_input"] = "completion_input" + + +class AgentTurnInputType(BaseModel): + # expects List[Message] for messages (may also include attachments?) + type: Literal["agent_turn_input"] = "agent_turn_input" + + ParamType = Annotated[ Union[ StringType, @@ -56,6 +71,9 @@ ParamType = Annotated[ JsonType, UnionType, CustomType, + ChatCompletionInputType, + CompletionInputType, + AgentTurnInputType, ], Field(discriminator="type"), ] diff --git a/llama_stack/providers/impls/meta_reference/eval/eval.py b/llama_stack/providers/impls/meta_reference/eval/eval.py index e5e2bcdc0..70f523040 100644 --- a/llama_stack/providers/impls/meta_reference/eval/eval.py +++ b/llama_stack/providers/impls/meta_reference/eval/eval.py @@ -5,10 +5,11 @@ # the root directory of this source tree. from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.eval import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.common.job_types import Job from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets +from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus from llama_stack.apis.inference import Inference from llama_stack.apis.scoring import Scoring @@ -44,17 +45,21 @@ class MetaReferenceEvalImpl(Eval): f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." ) - # TODO: we will require user defined message types for ToolResponseMessage or include message.context - # for now uses basic schema where messages={type: "user", content: "input_query"} - for required_column in ["expected_answer", "input_query"]: - if required_column not in dataset_def.dataset_schema: - raise ValueError( - f"Dataset {dataset_id} does not have a '{required_column}' column." - ) - if dataset_def.dataset_schema[required_column].type != "string": - raise ValueError( - f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." 
-                )
+        expected_schemas = [
+            {
+                "expected_answer": StringType(),
+                "chat_completion_input": ChatCompletionInputType(),
+            },
+            {
+                "expected_answer": StringType(),
+                "completion_input": CompletionInputType(),
+            },
+        ]
+
+        if dataset_def.dataset_schema not in expected_schemas:
+            raise ValueError(
+                f"Dataset {dataset_id} does not conform to any of the expected schemas: {expected_schemas}"
+            )
 
     async def evaluate_batch(
         self,
@@ -91,13 +96,12 @@ class MetaReferenceEvalImpl(Eval):
         )
         generations = []
         for x in input_rows:
-            input_query = x["input_query"]
+            input_messages = eval(str(x["chat_completion_input"]))
+            input_messages = [UserMessage(**x) for x in input_messages]
             messages = []
             if candidate.system_message:
                 messages.append(candidate.system_message)
-            messages.append(
-                UserMessage(content=input_query),
-            )
+            messages += input_messages
             response = await self.inference_api.chat_completion(
                 model=candidate.model,
                 messages=messages,
diff --git a/llama_stack/providers/tests/datasetio/test_dataset.csv b/llama_stack/providers/tests/datasetio/test_dataset.csv
index a1a250753..f682c6d3d 100644
--- a/llama_stack/providers/tests/datasetio/test_dataset.csv
+++ b/llama_stack/providers/tests/datasetio/test_dataset.csv
@@ -1,6 +1,6 @@
-input_query,generated_answer,expected_answer
-What is the capital of France?,London,Paris
-Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
-What is the largest planet in our solar system?,Jupiter,Jupiter
-What is the smallest country in the world?,China,Vatican City
-What is the currency of Japan?,Yen,Yen
+input_query,generated_answer,expected_answer,chat_completion_input
+What is the capital of France?,London,Paris,"[{'role': 'user', 'content': 'What is the capital of France?'}]"
+Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg,"[{'role': 'user', 'content': 'Who is the CEO of Meta?'}]"
+What is the largest planet in our solar system?,Jupiter,Jupiter,"[{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]"
+What is the smallest country in the world?,China,Vatican City,"[{'role': 'user', 'content': 'What is the smallest country in the world?'}]"
+What is the currency of Japan?,Yen,Yen,"[{'role': 'user', 'content': 'What is the currency of Japan?'}]"
diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py
index 755ed9735..9bd80f94d 100644
--- a/llama_stack/providers/tests/datasetio/test_datasetio.py
+++ b/llama_stack/providers/tests/datasetio/test_datasetio.py
@@ -62,17 +62,22 @@ def data_url_from_file(file_path: str) -> str:
 
 
 async def register_dataset(
-    datasets_impl: Datasets, include_generated_answer=True, dataset_id="test_dataset"
+    datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
 ):
     test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
     test_url = data_url_from_file(str(test_file))
-    dataset_schema = {
-        "expected_answer": StringType(),
-        "input_query": StringType(),
-    }
-    if include_generated_answer:
-        dataset_schema["generated_answer"] = StringType()
+    if for_generation:
+        dataset_schema = {
+            "expected_answer": StringType(),
+            "chat_completion_input": ChatCompletionInputType(),
+        }
+    else:
+        dataset_schema = {
+            "expected_answer": StringType(),
+            "input_query": StringType(),
+            "generated_answer": StringType(),
+        }
 
     dataset = DatasetDefWithProvider(
         identifier=dataset_id,
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 3a1ca169b..e4f47f8c3 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -51,7 +51,7 @@ async def test_eval(eval_settings):
     datasets_impl = eval_settings["datasets_impl"]
     await register_dataset(
         datasets_impl,
-        include_generated_answer=False,
+        for_generation=True,
         dataset_id="test_dataset_for_eval",
     )
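
Note for reviewers: the new chat_completion_input column stores each row's
messages as a Python-literal list of dicts, which eval.py above parses with
eval(str(...)) before wrapping each dict in a UserMessage. The sketch below
is a minimal, dependency-free illustration of that round-trip under stated
assumptions: it swaps in ast.literal_eval (which accepts only Python
literals, so a malicious cell cannot execute code the way bare eval() could)
and uses a hypothetical dataclass standing in for the real UserMessage from
llama_models.llama3.api.datatypes.

    import ast
    import csv
    import io
    from dataclasses import dataclass


    @dataclass
    class UserMessage:
        # Hypothetical stand-in for llama_models' UserMessage, for this sketch only.
        role: str
        content: str


    # One row in the same shape as test_dataset.csv above.
    CSV_TEXT = (
        "input_query,generated_answer,expected_answer,chat_completion_input\n"
        "What is the capital of France?,London,Paris,"
        "\"[{'role': 'user', 'content': 'What is the capital of France?'}]\"\n"
    )

    for row in csv.DictReader(io.StringIO(CSV_TEXT)):
        # literal_eval turns the serialized cell back into a list of dicts.
        raw_messages = ast.literal_eval(row["chat_completion_input"])
        messages = [UserMessage(**m) for m in raw_messages]
        assert messages[0].content == row["input_query"]
        print(messages)

Running it prints the parsed UserMessage list for the CSV row, i.e. the same
message sequence evaluate_batch would hand to chat_completion.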