mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 06:53:47 +00:00
chatcompletion & completion input type validation
This commit is contained in:
parent
e468e23249
commit
29e48cc5c1
5 changed files with 57 additions and 30 deletions
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.eval import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
|
||||
|
@ -44,17 +45,21 @@ class MetaReferenceEvalImpl(Eval):
|
|||
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
|
||||
)
|
||||
|
||||
# TODO: we will require user defined message types for ToolResponseMessage or include message.context
|
||||
# for now uses basic schema where messages={type: "user", content: "input_query"}
|
||||
for required_column in ["expected_answer", "input_query"]:
|
||||
if required_column not in dataset_def.dataset_schema:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column."
|
||||
)
|
||||
if dataset_def.dataset_schema[required_column].type != "string":
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||
)
|
||||
expected_schemas = [
|
||||
{
|
||||
"expected_answer": StringType(),
|
||||
"chat_completion_input": ChatCompletionInputType(),
|
||||
},
|
||||
{
|
||||
"expected_answer": StringType(),
|
||||
"chat_completion_input": CompletionInputType(),
|
||||
},
|
||||
]
|
||||
|
||||
if dataset_def.dataset_schema not in expected_schemas:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
|
||||
)
|
||||
|
||||
async def evaluate_batch(
|
||||
self,
|
||||
|
@ -91,13 +96,12 @@ class MetaReferenceEvalImpl(Eval):
|
|||
)
|
||||
generations = []
|
||||
for x in input_rows:
|
||||
input_query = x["input_query"]
|
||||
input_messages = eval(str(x["chat_completion_input"]))
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages.append(
|
||||
UserMessage(content=input_query),
|
||||
)
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model=candidate.model,
|
||||
messages=messages,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue