forked from phoenix-oss/llama-stack-mirror
Fix precommit check after moving to ruff (#927)
Lint check in main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fixing and ignoring some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
parent
4773092dd1
commit
34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
|
@ -124,12 +124,10 @@ class BraintrustScoringImpl(
|
|||
self.datasets_api = datasets_api
|
||||
|
||||
self.braintrust_evaluators = {
|
||||
entry.identifier: entry.evaluator
|
||||
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
|
||||
entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
|
||||
}
|
||||
self.supported_fn_defs_registry = {
|
||||
entry.identifier: entry.fn_def
|
||||
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
|
||||
entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
|
||||
}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
@ -139,16 +137,14 @@ class BraintrustScoringImpl(
|
|||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
|
||||
for f in scoring_fn_defs_list:
|
||||
assert f.identifier.startswith(
|
||||
"braintrust"
|
||||
), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
|
||||
assert f.identifier.startswith("braintrust"), (
|
||||
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
|
||||
)
|
||||
|
||||
return scoring_fn_defs_list
|
||||
|
||||
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
|
||||
raise NotImplementedError(
|
||||
"Registering scoring function not allowed for braintrust provider"
|
||||
)
|
||||
raise NotImplementedError("Registering scoring function not allowed for braintrust provider")
|
||||
|
||||
async def set_api_key(self) -> None:
|
||||
# api key is in the request headers
|
||||
|
@ -171,17 +167,13 @@ class BraintrustScoringImpl(
|
|||
await self.set_api_key()
|
||||
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(
|
||||
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
|
||||
)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.rows, scoring_functions=scoring_functions
|
||||
)
|
||||
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
|
||||
if save_results_dataset:
|
||||
# TODO: persist and register dataset on to server for reading
|
||||
# self.datasets_api.register_dataset()
|
||||
|
@ -222,13 +214,8 @@ class BraintrustScoringImpl(
|
|||
if scoring_fn_id not in self.supported_fn_defs_registry:
|
||||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
|
||||
score_results = [
|
||||
await self.score_row(input_row, scoring_fn_id)
|
||||
for input_row in input_rows
|
||||
]
|
||||
aggregation_functions = self.supported_fn_defs_registry[
|
||||
scoring_fn_id
|
||||
].params.aggregation_functions
|
||||
score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
|
||||
aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions
|
||||
|
||||
# override scoring_fn params if provided
|
||||
if scoring_functions[scoring_fn_id] is not None:
|
||||
|
|
|
@ -21,7 +21,5 @@ answer_correctness_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="answer-correctness",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ answer_relevancy_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="answer-relevancy",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ answer_similarity_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="answer-similarity",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ context_entity_recall_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="context-entity-recall",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ context_precision_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="context-precision",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ context_recall_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="context-recall",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -14,13 +14,10 @@ from llama_stack.apis.scoring_functions import (
|
|||
context_relevancy_fn_def = ScoringFn(
|
||||
identifier="braintrust::context-relevancy",
|
||||
description=(
|
||||
"Assesses how relevant the provided context is to the given question. "
|
||||
"See: github.com/braintrustdata/autoevals"
|
||||
"Assesses how relevant the provided context is to the given question. See: github.com/braintrustdata/autoevals"
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="context-relevancy",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -21,7 +21,5 @@ factuality_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="factuality",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ faithfulness_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="faithfulness",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue