Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes the lint check after
we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921: we
need to move to a `ruff.toml` file, and fix or ignore some additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
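
For reference, a `ruff.toml` along these lines would do what the commit message describes. This is a minimal sketch, not the file actually added in this commit; the line length and the rule codes selected and ignored below are assumptions.

# ruff.toml (hypothetical sketch, not the exact config added by this commit)
line-length = 120                # assumed, based on the long single-line rewrites in the diff below

[lint]
select = ["E", "F", "I", "UP"]   # pycodestyle, pyflakes, isort, pyupgrade (assumed rule set)
ignore = ["E501"]                # example of ignoring an additional check, as the message describes

[format]
docstring-code-format = true     # assumed formatter option

With a `ruff.toml` at the repository root, `ruff check .` and `ruff format .` (and the corresponding pre-commit hooks) pick the configuration up automatically.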

@@ -124,12 +124,10 @@ class BraintrustScoringImpl(
         self.datasets_api = datasets_api
         self.braintrust_evaluators = {
-            entry.identifier: entry.evaluator
-            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }
         self.supported_fn_defs_registry = {
-            entry.identifier: entry.fn_def
-            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }
 
     async def initialize(self) -> None: ...
@@ -139,16 +137,14 @@ class BraintrustScoringImpl(
     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith(
-                "braintrust"
-            ), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
+            assert f.identifier.startswith("braintrust"), (
+                "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
+            )
 
         return scoring_fn_defs_list
 
     async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
-        raise NotImplementedError(
-            "Registering scoring function not allowed for braintrust provider"
-        )
+        raise NotImplementedError("Registering scoring function not allowed for braintrust provider")
 
     async def set_api_key(self) -> None:
         # api key is in the request headers
@@ -171,17 +167,13 @@ class BraintrustScoringImpl(
         await self.set_api_key()
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
 
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=-1,
         )
-        res = await self.score(
-            input_rows=all_rows.rows, scoring_functions=scoring_functions
-        )
+        res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
 
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
@@ -222,13 +214,8 @@ class BraintrustScoringImpl(
             if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            score_results = [
-                await self.score_row(input_row, scoring_fn_id)
-                for input_row in input_rows
-            ]
-            aggregation_functions = self.supported_fn_defs_registry[
-                scoring_fn_id
-            ].params.aggregation_functions
+            score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
+            aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions
 
             # override scoring_fn params if provided
             if scoring_functions[scoring_fn_id] is not None:

@@ -21,7 +21,5 @@ answer_correctness_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="answer-correctness",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )

@@ -20,7 +20,5 @@ answer_relevancy_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="answer-relevancy",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )

@@ -20,7 +20,5 @@ answer_similarity_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="answer-similarity",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )

@@ -20,7 +20,5 @@ context_entity_recall_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="context-entity-recall",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )

@@ -20,7 +20,5 @@ context_precision_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="context-precision",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )

@@ -20,7 +20,5 @@ context_recall_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="context-recall",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )

@@ -14,13 +14,10 @@ from llama_stack.apis.scoring_functions import (
 context_relevancy_fn_def = ScoringFn(
     identifier="braintrust::context-relevancy",
     description=(
-        "Assesses how relevant the provided context is to the given question. "
-        "See: github.com/braintrustdata/autoevals"
+        "Assesses how relevant the provided context is to the given question. See: github.com/braintrustdata/autoevals"
     ),
     provider_id="braintrust",
     provider_resource_id="context-relevancy",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )

@@ -21,7 +21,5 @@ factuality_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="factuality",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )

@@ -20,7 +20,5 @@ faithfulness_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="faithfulness",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )