chatcompletion & completion input type validation

This commit is contained in:
Xi Yan 2024-10-24 16:11:25 -07:00
parent e468e23249
commit 29e48cc5c1
5 changed files with 57 additions and 30 deletions

View file

@@ -5,10 +5,11 @@
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
@@ -44,17 +45,21 @@ class MetaReferenceEvalImpl(Eval):
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
# TODO: we will require user defined message types for ToolResponseMessage or include message.context
# for now uses basic schema where messages={type: "user", content: "input_query"}
for required_column in ["expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
expected_schemas = [
{
"expected_answer": StringType(),
"chat_completion_input": ChatCompletionInputType(),
},
{
"expected_answer": StringType(),
"chat_completion_input": CompletionInputType(),
},
]
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def evaluate_batch(
self,
@@ -91,13 +96,12 @@ class MetaReferenceEvalImpl(Eval):
)
generations = []
for x in input_rows:
input_query = x["input_query"]
input_messages = eval(str(x["chat_completion_input"]))
input_messages = [UserMessage(**x) for x in input_messages]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages.append(
UserMessage(content=input_query),
)
messages += input_messages
response = await self.inference_api.chat_completion(
model=candidate.model,
messages=messages,