forked from phoenix-oss/llama-stack-mirror
Fix precommit check after moving to ruff (#927)
Lint check in main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fixing and ignoring some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
parent
4773092dd1
commit
34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
|
@ -54,15 +54,11 @@ class BasicScoringImpl(
|
|||
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
scoring_fn_defs_list = [
|
||||
fn_def
|
||||
for impl in self.scoring_fn_id_impls.values()
|
||||
for fn_def in impl.get_supported_scoring_fn_defs()
|
||||
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
|
||||
]
|
||||
|
||||
for f in scoring_fn_defs_list:
|
||||
assert f.identifier.startswith(
|
||||
"basic"
|
||||
), "All basic scoring fn must have identifier prefixed with 'basic'! "
|
||||
assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "
|
||||
|
||||
return scoring_fn_defs_list
|
||||
|
||||
|
@ -76,9 +72,7 @@ class BasicScoringImpl(
|
|||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(
|
||||
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
|
||||
)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
|
@ -108,12 +102,8 @@ class BasicScoringImpl(
|
|||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
|
||||
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
|
||||
score_results = await scoring_fn.score(
|
||||
input_rows, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
agg_results = await scoring_fn.aggregate(
|
||||
score_results, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
|
||||
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
aggregated_results=agg_results,
|
||||
|
|
|
@ -32,9 +32,7 @@ class EqualityScoringFn(RegisteredBaseScoringFn):
|
|||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert "expected_answer" in input_row, "Expected answer not found in input row."
|
||||
assert (
|
||||
"generated_answer" in input_row
|
||||
), "Generated answer not found in input row."
|
||||
assert "generated_answer" in input_row, "Generated answer not found in input row."
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
|
|
@ -18,7 +18,5 @@ equality = ScoringFn(
|
|||
provider_id="basic",
|
||||
provider_resource_id="equality",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.accuracy]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
||||
|
|
|
@ -55,9 +55,7 @@ MULTILINGUAL_ANSWER_REGEXES = [
|
|||
r"Àṣàyàn\s*:",
|
||||
]
|
||||
|
||||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
|
||||
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
|
||||
)
|
||||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
|
||||
|
||||
regex_parser_multiple_choice_answer = ScoringFn(
|
||||
identifier="basic::regex_parser_multiple_choice_answer",
|
||||
|
@ -66,10 +64,7 @@ regex_parser_multiple_choice_answer = ScoringFn(
|
|||
provider_id="basic",
|
||||
provider_resource_id="regex-parser-multiple-choice-answer",
|
||||
params=RegexParserScoringFnParams(
|
||||
parsing_regexes=[
|
||||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
|
||||
for x in MULTILINGUAL_ANSWER_REGEXES
|
||||
],
|
||||
parsing_regexes=[MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES],
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
||||
|
|
|
@ -18,7 +18,5 @@ subset_of = ScoringFn(
|
|||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="subset-of",
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.accuracy]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
||||
|
|
|
@ -33,17 +33,14 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
|
|||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert (
|
||||
scoring_fn_identifier is not None
|
||||
), "Scoring function identifier not found."
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert (
|
||||
fn_def.params is not None
|
||||
and fn_def.params.type == ScoringFnParamsType.regex_parser.value
|
||||
), f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
|
||||
f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
)
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
|
|
@ -124,12 +124,10 @@ class BraintrustScoringImpl(
|
|||
self.datasets_api = datasets_api
|
||||
|
||||
self.braintrust_evaluators = {
|
||||
entry.identifier: entry.evaluator
|
||||
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
|
||||
entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
|
||||
}
|
||||
self.supported_fn_defs_registry = {
|
||||
entry.identifier: entry.fn_def
|
||||
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
|
||||
entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
|
||||
}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
@ -139,16 +137,14 @@ class BraintrustScoringImpl(
|
|||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
|
||||
for f in scoring_fn_defs_list:
|
||||
assert f.identifier.startswith(
|
||||
"braintrust"
|
||||
), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
|
||||
assert f.identifier.startswith("braintrust"), (
|
||||
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
|
||||
)
|
||||
|
||||
return scoring_fn_defs_list
|
||||
|
||||
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
|
||||
raise NotImplementedError(
|
||||
"Registering scoring function not allowed for braintrust provider"
|
||||
)
|
||||
raise NotImplementedError("Registering scoring function not allowed for braintrust provider")
|
||||
|
||||
async def set_api_key(self) -> None:
|
||||
# api key is in the request headers
|
||||
|
@ -171,17 +167,13 @@ class BraintrustScoringImpl(
|
|||
await self.set_api_key()
|
||||
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(
|
||||
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
|
||||
)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.rows, scoring_functions=scoring_functions
|
||||
)
|
||||
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
|
||||
if save_results_dataset:
|
||||
# TODO: persist and register dataset on to server for reading
|
||||
# self.datasets_api.register_dataset()
|
||||
|
@ -222,13 +214,8 @@ class BraintrustScoringImpl(
|
|||
if scoring_fn_id not in self.supported_fn_defs_registry:
|
||||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
|
||||
score_results = [
|
||||
await self.score_row(input_row, scoring_fn_id)
|
||||
for input_row in input_rows
|
||||
]
|
||||
aggregation_functions = self.supported_fn_defs_registry[
|
||||
scoring_fn_id
|
||||
].params.aggregation_functions
|
||||
score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
|
||||
aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions
|
||||
|
||||
# override scoring_fn params if provided
|
||||
if scoring_functions[scoring_fn_id] is not None:
|
||||
|
|
|
@ -21,7 +21,5 @@ answer_correctness_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="answer-correctness",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ answer_relevancy_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="answer-relevancy",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ answer_similarity_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="answer-similarity",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ context_entity_recall_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="context-entity-recall",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ context_precision_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="context-precision",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ context_recall_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="context-recall",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -14,13 +14,10 @@ from llama_stack.apis.scoring_functions import (
|
|||
context_relevancy_fn_def = ScoringFn(
|
||||
identifier="braintrust::context-relevancy",
|
||||
description=(
|
||||
"Assesses how relevant the provided context is to the given question. "
|
||||
"See: github.com/braintrustdata/autoevals"
|
||||
"Assesses how relevant the provided context is to the given question. See: github.com/braintrustdata/autoevals"
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="context-relevancy",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -21,7 +21,5 @@ factuality_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="factuality",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -20,7 +20,5 @@ faithfulness_fn_def = ScoringFn(
|
|||
provider_id="braintrust",
|
||||
provider_resource_id="faithfulness",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||
)
|
||||
|
|
|
@ -16,8 +16,6 @@ async def get_provider_impl(
|
|||
):
|
||||
from .scoring import LlmAsJudgeScoringImpl
|
||||
|
||||
impl = LlmAsJudgeScoringImpl(
|
||||
config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
|
||||
)
|
||||
impl = LlmAsJudgeScoringImpl(config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -58,15 +58,13 @@ class LlmAsJudgeScoringImpl(
|
|||
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
scoring_fn_defs_list = [
|
||||
fn_def
|
||||
for impl in self.scoring_fn_id_impls.values()
|
||||
for fn_def in impl.get_supported_scoring_fn_defs()
|
||||
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
|
||||
]
|
||||
|
||||
for f in scoring_fn_defs_list:
|
||||
assert f.identifier.startswith(
|
||||
"llm-as-judge"
|
||||
), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
|
||||
assert f.identifier.startswith("llm-as-judge"), (
|
||||
"All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
|
||||
)
|
||||
|
||||
return scoring_fn_defs_list
|
||||
|
||||
|
@ -80,9 +78,7 @@ class LlmAsJudgeScoringImpl(
|
|||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(
|
||||
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
|
||||
)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
|
||||
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
|
@ -112,12 +108,8 @@ class LlmAsJudgeScoringImpl(
|
|||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
|
||||
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
|
||||
score_results = await scoring_fn.score(
|
||||
input_rows, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
agg_results = await scoring_fn.aggregate(
|
||||
score_results, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
|
||||
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
aggregated_results=agg_results,
|
||||
|
|
|
@ -38,9 +38,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
|||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert (
|
||||
scoring_fn_identifier is not None
|
||||
), "Scoring function identifier not found."
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
|
||||
# override params if scoring_params is provided
|
||||
|
@ -48,12 +46,8 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
|||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
|
||||
assert (
|
||||
fn_def.params.prompt_template is not None
|
||||
), "LLM Judge prompt_template not found."
|
||||
assert (
|
||||
fn_def.params.judge_score_regexes is not None
|
||||
), "LLM Judge judge_score_regexes not found."
|
||||
assert fn_def.params.prompt_template is not None, "LLM Judge prompt_template not found."
|
||||
assert fn_def.params.judge_score_regexes is not None, "LLM Judge judge_score_regexes not found."
|
||||
|
||||
input_query = input_row["input_query"]
|
||||
expected_answer = input_row["expected_answer"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue