This commit is contained in:
Xi Yan 2025-03-15 14:48:26 -07:00
parent 9b38ae9323
commit f262bfd061
10 changed files with 107 additions and 319 deletions

View file

@ -64,15 +64,11 @@ class BasicScoringImpl(
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [
fn_def
for impl in self.scoring_fn_id_impls.values()
for fn_def in impl.get_supported_scoring_fn_defs()
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"basic"
), "All basic scoring fn must have identifier prefixed with 'basic'! "
assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "
return scoring_fn_defs_list
@ -86,9 +82,7 @@ class BasicScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
@ -118,12 +112,8 @@ class BasicScoringImpl(
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(
score_results, scoring_fn_id, scoring_fn_params
)
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -122,12 +122,10 @@ class BraintrustScoringImpl(
self.datasets_api = datasets_api
self.braintrust_evaluators = {
entry.identifier: entry.evaluator
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
self.supported_fn_defs_registry = {
entry.identifier: entry.fn_def
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
async def initialize(self) -> None: ...
@ -137,16 +135,14 @@ class BraintrustScoringImpl(
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"braintrust"
), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
assert f.identifier.startswith("braintrust"), (
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
)
return scoring_fn_defs_list
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
raise NotImplementedError(
"Registering scoring function not allowed for braintrust provider"
)
raise NotImplementedError("Registering scoring function not allowed for braintrust provider")
async def set_api_key(self) -> None:
# api key is in the request headers
@ -169,17 +165,13 @@ class BraintrustScoringImpl(
await self.set_api_key()
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.score(
input_rows=all_rows.rows, scoring_functions=scoring_functions
)
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
@ -220,13 +212,8 @@ class BraintrustScoringImpl(
if scoring_fn_id not in self.supported_fn_defs_registry:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
score_results = [
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
aggregation_functions = self.supported_fn_defs_registry[
scoring_fn_id
].params.aggregation_functions
score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions
# override scoring_fn params if provided
if scoring_functions[scoring_fn_id] is not None:

View file

@ -54,9 +54,9 @@ class LlmAsJudgeScoringImpl(
scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
assert f.identifier.startswith(
"llm-as-judge"
), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
assert f.identifier.startswith("llm-as-judge"), (
"All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
)
return scoring_fn_defs_list
@ -70,9 +70,7 @@ class LlmAsJudgeScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
@ -100,12 +98,8 @@ class LlmAsJudgeScoringImpl(
for scoring_fn_id in scoring_functions.keys():
scoring_fn = self.llm_as_judge_fn
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(
score_results, scoring_fn_id, scoring_fn_params
)
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,