Merge branch 'rag_scoring_fn_1' into rag_scoring_fn_2

This commit is contained in:
Xi Yan 2024-12-30 17:56:53 -08:00
commit 86b6d41065
6 changed files with 115 additions and 104 deletions

View file

@ -13,11 +13,13 @@ from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference, UserMessage from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import ( from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName, ColumnName,
DataSchemaValidatorMixin, DataSchemaValidatorMixin,
get_valid_schemas,
) )
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
@ -83,7 +85,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorM
candidate = task_config.eval_candidate candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_eval(dataset_def.dataset_schema) self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=( rows_in_page=(

View file

@ -14,11 +14,13 @@ from llama_stack.apis.scoring import (
ScoringResult, ScoringResult,
) )
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
DataSchemaValidatorMixin,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin,
get_valid_schemas,
)
from .config import BasicScoringConfig from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@ -73,7 +75,9 @@ class BasicScoringImpl(
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema) self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,

View file

@ -30,10 +30,13 @@ from llama_stack.apis.scoring import (
) )
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import ( from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin, DataSchemaValidatorMixin,
get_valid_schemas,
) )
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
@ -168,7 +171,9 @@ class BraintrustScoringImpl(
await self.set_api_key() await self.set_api_key()
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema) self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,

View file

@ -16,9 +16,11 @@ from llama_stack.apis.scoring import (
ScoringResult, ScoringResult,
) )
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import ( from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin, DataSchemaValidatorMixin,
get_valid_schemas,
) )
from .config import LlmAsJudgeScoringConfig from .config import LlmAsJudgeScoringConfig
@ -77,7 +79,9 @@ class LlmAsJudgeScoringImpl(
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema) self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,

View file

@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
CompletionInputType,
StringType,
)
from llama_stack.distribution.datatypes import Api
class ColumnName(Enum):
    """Column keys used when describing dataset and row schemas.

    Each member's value is the literal string used as the dictionary key
    for that column.
    """

    input_query = "input_query"
    expected_answer = "expected_answer"
    chat_completion_input = "chat_completion_input"
    completion_input = "completion_input"
    generated_answer = "generated_answer"
    context = "context"
# Dataset schemas accepted for scoring: a query, the expected answer and the
# generated answer, with an optional additional "context" column.
VALID_SCHEMAS_FOR_SCORING = [
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.generated_answer.value: StringType(),
    },
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.generated_answer.value: StringType(),
        ColumnName.context.value: StringType(),
    },
]
# Dataset schemas accepted for eval: a query and the expected answer, plus
# either a chat-completion input column or a completion input column.
VALID_SCHEMAS_FOR_EVAL = [
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.chat_completion_input.value: ChatCompletionInputType(),
    },
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.completion_input.value: CompletionInputType(),
    },
]
def get_valid_schemas(api_str: str):
    """Return the list of accepted dataset schemas for the given API name.

    Args:
        api_str: The string value of an ``Api`` member; only the scoring
            and eval values are recognised.

    Returns:
        ``VALID_SCHEMAS_FOR_SCORING`` or ``VALID_SCHEMAS_FOR_EVAL``.

    Raises:
        ValueError: If ``api_str`` matches neither recognised value.
    """
    if api_str == Api.scoring.value:
        return VALID_SCHEMAS_FOR_SCORING
    if api_str == Api.eval.value:
        return VALID_SCHEMAS_FOR_EVAL
    raise ValueError(f"Invalid API string: {api_str}")
class DataSchemaValidatorMixin:
    """Mixin providing schema checks for datasets and individual rows."""

    def validate_dataset_schema(
        self,
        dataset_schema: Dict[str, Any],
        expected_schemas: List[Dict[str, Any]],
    ):
        """Require the dataset schema to equal one of the expected schemas.

        Raises:
            ValueError: If ``dataset_schema`` is not an element of
                ``expected_schemas``.
        """
        if dataset_schema in expected_schemas:
            return
        raise ValueError(
            f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
        )

    def validate_row_schema(
        self,
        input_row: Dict[str, Any],
        expected_schemas: List[Dict[str, Any]],
    ):
        """Require the row to contain every key of at least one schema.

        Only key presence is checked; extra keys in the row are allowed.

        Raises:
            ValueError: If no schema has all of its keys present in
                ``input_row``.
        """
        row_is_valid = any(
            all(column in input_row for column in candidate)
            for candidate in expected_schemas
        )
        if not row_is_valid:
            raise ValueError(
                f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
            )

View file

@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
CompletionInputType,
StringType,
)
class ColumnName(Enum):
    """Column keys used when describing dataset and row schemas.

    Each member's value is the literal string used as the dictionary key
    for that column.
    """

    input_query = "input_query"
    expected_answer = "expected_answer"
    chat_completion_input = "chat_completion_input"
    completion_input = "completion_input"
    generated_answer = "generated_answer"
    context = "context"
class DataSchemaValidatorMixin:
    """Mixin providing schema checks for scoring and eval datasets and rows."""

    def validate_dataset_schema_for_scoring(self, dataset_schema: Dict[str, Any]):
        """Validate a dataset schema against the scoring schemas."""
        self.validate_dataset_schema(
            dataset_schema, self.get_expected_schema_for_scoring()
        )

    def validate_dataset_schema_for_eval(self, dataset_schema: Dict[str, Any]):
        """Validate a dataset schema against the eval schemas."""
        self.validate_dataset_schema(
            dataset_schema, self.get_expected_schema_for_eval()
        )

    def validate_row_schema_for_scoring(self, input_row: Dict[str, Any]):
        """Validate a single row against the scoring schemas."""
        self.validate_row_schema(input_row, self.get_expected_schema_for_scoring())

    def validate_row_schema_for_eval(self, input_row: Dict[str, Any]):
        """Validate a single row against the eval schemas."""
        self.validate_row_schema(input_row, self.get_expected_schema_for_eval())

    def get_expected_schema_for_scoring(self):
        """Return the schemas accepted for scoring.

        Both require query, expected answer and generated answer columns;
        the second additionally includes a context column.
        """
        return [
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.generated_answer.value: StringType(),
            },
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.generated_answer.value: StringType(),
                ColumnName.context.value: StringType(),
            },
        ]

    def get_expected_schema_for_eval(self):
        """Return the schemas accepted for eval.

        Both require query and expected answer columns, plus either a
        chat-completion input column or a completion input column.
        """
        return [
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
            },
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.completion_input.value: CompletionInputType(),
            },
        ]

    def validate_dataset_schema(
        self,
        dataset_schema: Dict[str, Any],
        expected_schemas: List[Dict[str, Any]],
    ):
        """Require the dataset schema to equal one of the expected schemas.

        Raises:
            ValueError: If ``dataset_schema`` is not an element of
                ``expected_schemas``.
        """
        if dataset_schema not in expected_schemas:
            # Include the rejected schema so the failure can be diagnosed
            # from the message alone, as validate_row_schema already does
            # for rows.
            raise ValueError(
                f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
            )

    def validate_row_schema(
        self,
        input_row: Dict[str, Any],
        expected_schemas: List[Dict[str, Any]],
    ):
        """Require the row to contain every key of at least one schema.

        Only key presence is checked; extra keys in the row are allowed.

        Raises:
            ValueError: If no schema has all of its keys present in
                ``input_row``.
        """
        for schema in expected_schemas:
            if all(key in input_row for key in schema):
                return
        raise ValueError(
            f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
        )