Merge branch 'rag_scoring_fn_2' into rag_scoring_fn_3

This commit is contained in:
Xi Yan 2024-12-30 17:59:19 -08:00
commit e4d8290f72
6 changed files with 117 additions and 106 deletions

View file

@ -13,11 +13,13 @@ from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
DataSchemaValidatorMixin,
get_valid_schemas,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -83,7 +85,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorM
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_eval(dataset_def.dataset_schema)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(