refactor schema check

This commit is contained in:
Xi Yan 2024-12-19 16:20:47 -08:00
parent 55e4f4eeb3
commit c15b0d5395
7 changed files with 82 additions and 119 deletions

View file

@ -12,6 +12,9 @@ from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
DataSchemaValidatorMixin,
)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
@ -21,7 +24,9 @@ from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BasicScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
):
def __init__(
self,
config: BasicScoringConfig,
@ -58,30 +63,15 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,

View file

@ -15,19 +15,18 @@ from llama_stack.apis.datasets import * # noqa: F403
import os
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness, AnswerRelevancy
from autoevals.ragas import AnswerCorrectness
from pydantic import BaseModel
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_utils import (
get_expected_schema_for_scoring,
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
DataSchemaValidatorMixin,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def
@ -48,16 +47,14 @@ SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
evaluator=AnswerCorrectness(),
fn_def=answer_correctness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-relevancy",
evaluator=AnswerRelevancy(),
fn_def=answer_relevancy_fn_def,
),
]
class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
Scoring,
ScoringFunctionsProtocolPrivate,
NeedsRequestProviderData,
DataSchemaValidatorMixin,
):
def __init__(
self,
@ -96,32 +93,6 @@ class BraintrustScoringImpl(
"Registering scoring function not allowed for braintrust provider"
)
async def validate_scoring_input_row_schema(
self, input_row: Dict[str, Any]
) -> None:
expected_schemas = get_expected_schema_for_scoring()
for schema in expected_schemas:
if all(key in input_row for key in schema):
return
raise ValueError(
f"Input row {input_row} does not match any of the expected schemas"
)
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
expected_schemas = get_expected_schema_for_scoring()
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def set_api_key(self) -> None:
# api key is in the request headers
if not self.config.openai_api_key:
@ -141,7 +112,10 @@ class BraintrustScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
@ -172,7 +146,6 @@ class BraintrustScoringImpl(
generated_answer,
expected_answer,
input=input_query,
context=input_row["context"] if "context" in input_row else None,
)
score = result.score
return {"score": score, "metadata": result.metadata}

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_relevancy_fn_def = ScoringFn(
identifier="braintrust::answer-relevancy",
description=(
"Scores answer relevancy according to the question"
"Uses Braintrust LLM-based scorer from autoevals library."
),
provider_id="braintrust",
provider_resource_id="answer-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -17,6 +17,9 @@ from llama_stack.apis.scoring import (
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
DataSchemaValidatorMixin,
)
from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
@ -25,7 +28,9 @@ from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class LlmAsJudgeScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
):
def __init__(
self,
config: LlmAsJudgeScoringConfig,
@ -65,30 +70,15 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,