refactor schema check

Xi Yan 2024-12-19 16:20:47 -08:00
parent 55e4f4eeb3
commit c15b0d5395
7 changed files with 82 additions and 119 deletions

View file

@@ -14,9 +14,10 @@ from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.common.data_schema_utils import (
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
ColumnName,
get_expected_schema_for_eval,
DataSchemaValidatorMixin,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -28,7 +29,7 @@ from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:"
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
def __init__(
self,
config: MetaReferenceEvalConfig,
@@ -72,17 +73,17 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
)
self.eval_tasks[task_def.identifier] = task_def
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
# async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
# if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
# raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
expected_schemas = get_expected_schema_for_eval()
# expected_schemas = get_expected_schema_for_eval()
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
# if dataset_def.dataset_schema not in expected_schemas:
# raise ValueError(
# f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
# )
async def run_eval(
self,
@@ -93,8 +94,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_eval(dataset_def.dataset_schema)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(

View file

@@ -12,6 +12,9 @@ from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
DataSchemaValidatorMixin,
)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
@@ -21,7 +24,9 @@ from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BasicScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
):
def __init__(
self,
config: BasicScoringConfig,
@@ -58,30 +63,15 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,

View file

@@ -15,19 +15,18 @@ from llama_stack.apis.datasets import * # noqa: F403
import os
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness, AnswerRelevancy
from autoevals.ragas import AnswerCorrectness
from pydantic import BaseModel
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_utils import (
get_expected_schema_for_scoring,
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
DataSchemaValidatorMixin,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def
@@ -48,16 +47,14 @@ SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
evaluator=AnswerCorrectness(),
fn_def=answer_correctness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-relevancy",
evaluator=AnswerRelevancy(),
fn_def=answer_relevancy_fn_def,
),
]
class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
Scoring,
ScoringFunctionsProtocolPrivate,
NeedsRequestProviderData,
DataSchemaValidatorMixin,
):
def __init__(
self,
@@ -96,32 +93,6 @@ class BraintrustScoringImpl(
"Registering scoring function not allowed for braintrust provider"
)
async def validate_scoring_input_row_schema(
self, input_row: Dict[str, Any]
) -> None:
expected_schemas = get_expected_schema_for_scoring()
for schema in expected_schemas:
if all(key in input_row for key in schema):
return
raise ValueError(
f"Input row {input_row} does not match any of the expected schemas"
)
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
expected_schemas = get_expected_schema_for_scoring()
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def set_api_key(self) -> None:
# api key is in the request headers
if not self.config.openai_api_key:
@@ -141,7 +112,10 @@ class BraintrustScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
@@ -172,7 +146,6 @@ class BraintrustScoringImpl(
generated_answer,
expected_answer,
input=input_query,
context=input_row["context"] if "context" in input_row else None,
)
score = result.score
return {"score": score, "metadata": result.metadata}

View file

@@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_relevancy_fn_def = ScoringFn(
identifier="braintrust::answer-relevancy",
description=(
"Scores answer relevancy according to the question"
"Uses Braintrust LLM-based scorer from autoevals library."
),
provider_id="braintrust",
provider_resource_id="answer-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@@ -17,6 +17,9 @@ from llama_stack.apis.scoring import (
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
DataSchemaValidatorMixin,
)
from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
@@ -25,7 +28,9 @@ from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class LlmAsJudgeScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
):
def __init__(
self,
config: LlmAsJudgeScoringConfig,
@@ -65,30 +70,15 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,

View file

@@ -7,8 +7,7 @@
import pytest
from llama_models.llama3.api import SamplingParams, URL
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
from llama_stack.apis.eval.eval import (
@@ -16,6 +15,7 @@ from llama_stack.apis.eval.eval import (
BenchmarkEvalTaskConfig,
ModelCandidate,
)
from llama_stack.apis.inference import SamplingParams
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
@@ -51,3 +52,38 @@ def get_expected_schema_for_eval():
ColumnName.completion_input.value: CompletionInputType(),
},
]
def validate_dataset_schema(
dataset_schema: Dict[str, Any], expected_schemas: List[Dict[str, Any]]
):
if dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset does not have a correct input schema in {expected_schemas}"
)
def validate_row_schema(
input_row: Dict[str, Any], expected_schemas: List[Dict[str, Any]]
):
for schema in expected_schemas:
if all(key in input_row for key in schema):
return
raise ValueError(
f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
)
class DataSchemaValidatorMixin:
def validate_dataset_schema_for_scoring(self, dataset_schema: Dict[str, Any]):
validate_dataset_schema(dataset_schema, get_expected_schema_for_scoring())
def validate_dataset_schema_for_eval(self, dataset_schema: Dict[str, Any]):
validate_dataset_schema(dataset_schema, get_expected_schema_for_eval())
def validate_row_schema_for_scoring(self, row_schema: Dict[str, Any]):
validate_row_schema(row_schema, get_expected_schema_for_scoring())
def validate_row_schema_for_eval(self, row_schema: Dict[str, Any]):
validate_row_schema(row_schema, get_expected_schema_for_eval())