refactor schema check

Xi Yan 2024-12-19 16:20:47 -08:00
parent 55e4f4eeb3
commit c15b0d5395
7 changed files with 82 additions and 119 deletions


@@ -14,9 +14,10 @@ from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference, UserMessage
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_utils import (
+from llama_stack.providers.utils.common.data_schema_validator_mixin import (
     ColumnName,
-    get_expected_schema_for_eval,
+    DataSchemaValidatorMixin,
 )
 from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -28,7 +29,7 @@ from .config import MetaReferenceEvalConfig
 EVAL_TASKS_PREFIX = "eval_tasks:"

-class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
+class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
     def __init__(
         self,
         config: MetaReferenceEvalConfig,
@@ -72,17 +73,17 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         )
         self.eval_tasks[task_def.identifier] = task_def

-    async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
-        expected_schemas = get_expected_schema_for_eval()
-        if dataset_def.dataset_schema not in expected_schemas:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
-            )
+    # async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
+    #     dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+    #     if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+    #         raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+    #     expected_schemas = get_expected_schema_for_eval()
+    #     if dataset_def.dataset_schema not in expected_schemas:
+    #         raise ValueError(
+    #             f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
+    #         )

     async def run_eval(
         self,
@@ -93,8 +94,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         dataset_id = task_def.dataset_id
         candidate = task_config.eval_candidate
         scoring_functions = task_def.scoring_functions
-        await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        self.validate_dataset_schema_for_eval(dataset_def.dataset_schema)
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=(


@@ -12,6 +12,9 @@ from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.apis.datasetio import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.utils.common.data_schema_validator_mixin import (
+    DataSchemaValidatorMixin,
+)
 from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
@@ -21,7 +24,9 @@ from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
 FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]

-class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
+class BasicScoringImpl(
+    Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+):
     def __init__(
         self,
         config: BasicScoringConfig,
@@ -58,30 +63,15 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def register_scoring_function(self, function_def: ScoringFn) -> None:
         raise NotImplementedError("Register scoring function not implemented yet")

-    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
-            )
-        for required_column in ["generated_answer", "expected_answer", "input_query"]:
-            if required_column not in dataset_def.dataset_schema:
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column."
-                )
-            if dataset_def.dataset_schema[required_column].type != "string":
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
-                )

     async def score_batch(
         self,
         dataset_id: str,
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=-1,


@@ -15,19 +15,18 @@ from llama_stack.apis.datasets import * # noqa: F403
 import os

 from autoevals.llm import Factuality
-from autoevals.ragas import AnswerCorrectness, AnswerRelevancy
+from autoevals.ragas import AnswerCorrectness
 from pydantic import BaseModel

 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_utils import (
-    get_expected_schema_for_scoring,
+from llama_stack.providers.utils.common.data_schema_validator_mixin import (
+    DataSchemaValidatorMixin,
 )
 from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics

 from .config import BraintrustScoringConfig
 from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
-from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
 from .scoring_fn.fn_defs.factuality import factuality_fn_def
@@ -48,16 +47,14 @@ SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
         evaluator=AnswerCorrectness(),
         fn_def=answer_correctness_fn_def,
     ),
-    BraintrustScoringFnEntry(
-        identifier="braintrust::answer-relevancy",
-        evaluator=AnswerRelevancy(),
-        fn_def=answer_relevancy_fn_def,
-    ),
 ]

 class BraintrustScoringImpl(
-    Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
+    Scoring,
+    ScoringFunctionsProtocolPrivate,
+    NeedsRequestProviderData,
+    DataSchemaValidatorMixin,
 ):
     def __init__(
         self,
@@ -96,32 +93,6 @@ class BraintrustScoringImpl(
             "Registering scoring function not allowed for braintrust provider"
         )

-    async def validate_scoring_input_row_schema(
-        self, input_row: Dict[str, Any]
-    ) -> None:
-        expected_schemas = get_expected_schema_for_scoring()
-        for schema in expected_schemas:
-            if all(key in input_row for key in schema):
-                return
-        raise ValueError(
-            f"Input row {input_row} does not match any of the expected schemas"
-        )

-    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
-            )
-        expected_schemas = get_expected_schema_for_scoring()
-        if dataset_def.dataset_schema not in expected_schemas:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
-            )

     async def set_api_key(self) -> None:
         # api key is in the request headers
         if not self.config.openai_api_key:
@@ -141,7 +112,10 @@ class BraintrustScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.set_api_key()
-        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=-1,
@@ -172,7 +146,6 @@ class BraintrustScoringImpl(
             generated_answer,
             expected_answer,
             input=input_query,
-            context=input_row["context"] if "context" in input_row else None,
         )
         score = result.score
         return {"score": score, "metadata": result.metadata}


@@ -1,27 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import (
-    AggregationFunctionType,
-    BasicScoringFnParams,
-    ScoringFn,
-)
-
-answer_relevancy_fn_def = ScoringFn(
-    identifier="braintrust::answer-relevancy",
-    description=(
-        "Scores answer relevancy according to the question"
-        "Uses Braintrust LLM-based scorer from autoevals library."
-    ),
-    provider_id="braintrust",
-    provider_resource_id="answer-relevancy",
-    return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
-)


@@ -17,6 +17,9 @@ from llama_stack.apis.scoring import (
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.utils.common.data_schema_validator_mixin import (
+    DataSchemaValidatorMixin,
+)
 from .config import LlmAsJudgeScoringConfig
 from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
@@ -25,7 +28,9 @@ from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
 LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]

-class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
+class LlmAsJudgeScoringImpl(
+    Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+):
     def __init__(
         self,
         config: LlmAsJudgeScoringConfig,
@@ -65,30 +70,15 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def register_scoring_function(self, function_def: ScoringFn) -> None:
         raise NotImplementedError("Register scoring function not implemented yet")

-    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
-            )
-        for required_column in ["generated_answer", "expected_answer", "input_query"]:
-            if required_column not in dataset_def.dataset_schema:
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column."
-                )
-            if dataset_def.dataset_schema[required_column].type != "string":
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
-                )

     async def score_batch(
         self,
         dataset_id: str,
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=-1,


@@ -7,8 +7,7 @@
 import pytest

-from llama_models.llama3.api import SamplingParams, URL
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
 from llama_stack.apis.eval.eval import (
@@ -16,6 +15,7 @@ from llama_stack.apis.eval.eval import (
     BenchmarkEvalTaskConfig,
     ModelCandidate,
 )
+from llama_stack.apis.inference import SamplingParams
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset


@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
+from typing import Any, Dict, List

 from llama_stack.apis.common.type_system import (
     ChatCompletionInputType,
@@ -51,3 +52,38 @@ def get_expected_schema_for_eval():
             ColumnName.completion_input.value: CompletionInputType(),
         },
     ]
+
+
+def validate_dataset_schema(
+    dataset_schema: Dict[str, Any], expected_schemas: List[Dict[str, Any]]
+):
+    if dataset_schema not in expected_schemas:
+        raise ValueError(
+            f"Dataset does not have a correct input schema in {expected_schemas}"
+        )
+
+
+def validate_row_schema(
+    input_row: Dict[str, Any], expected_schemas: List[Dict[str, Any]]
+):
+    for schema in expected_schemas:
+        if all(key in input_row for key in schema):
+            return
+    raise ValueError(
+        f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
+    )
+
+
+class DataSchemaValidatorMixin:
+    def validate_dataset_schema_for_scoring(self, dataset_schema: Dict[str, Any]):
+        validate_dataset_schema(dataset_schema, get_expected_schema_for_scoring())
+
+    def validate_dataset_schema_for_eval(self, dataset_schema: Dict[str, Any]):
+        validate_dataset_schema(dataset_schema, get_expected_schema_for_eval())
+
+    def validate_row_schema_for_scoring(self, row_schema: Dict[str, Any]):
+        validate_row_schema(row_schema, get_expected_schema_for_scoring())
+
+    def validate_row_schema_for_eval(self, row_schema: Dict[str, Any]):
+        validate_row_schema(row_schema, get_expected_schema_for_eval())