Mirror of https://github.com/meta-llama/llama-stack.git
Commit 199f92ddad (parent 0096c1a6fc): refactor base scoring fn v.s. registerable scoring fn

5 changed files with 26 additions and 11 deletions
RegexParserScoringFn:

@@ -9,14 +9,14 @@ from typing import Any, Dict, Optional
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
-from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

 from .fn_defs.regex_parser_multiple_choice_answer import (
     regex_parser_multiple_choice_answer,
 )


-class RegexParserScoringFn(BaseScoringFn):
+class RegexParserScoringFn(RegisteredBaseScoringFn):
     """
     A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
     """
SubsetOfScoringFn:

@@ -8,12 +8,12 @@ from typing import Any, Dict, Optional
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

 from .fn_defs.subset_of import subset_of


-class SubsetOfScoringFn(BaseScoringFn):
+class SubsetOfScoringFn(RegisteredBaseScoringFn):
     """
     A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
     """
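Both hunks above only swap the base class imported from llama_stack.providers.utils.scoring.base_scoring_fn: the concrete scoring functions now extend a registerable variant rather than the bare BaseScoringFn. The split itself is not shown in this diff; a minimal sketch of what it likely looks like, where supported_fn_defs_registry is taken from the Braintrust hunk further down and the method names are assumptions for illustration:

# Sketch only, not the repository's code. Assumed: the register/list method
# names; grounded: RegisteredBaseScoringFn, score_row, supported_fn_defs_registry.
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams


class BaseScoringFn(ABC):
    """Bare interface: turn one input row into one scored row."""

    @abstractmethod
    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow: ...


class RegisteredBaseScoringFn(BaseScoringFn):
    """Adds a registry of ScoringFn definitions so a provider can list and
    register the functions it ships (regex_parser, subset_of, ...)."""

    def __init__(self) -> None:
        super().__init__()
        self.supported_fn_defs_registry: Dict[str, ScoringFn] = {}

    def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
        return list(self.supported_fn_defs_registry.values())

    def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
        self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn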
Braintrust provider (imports):

@@ -20,7 +20,7 @@ from autoevals.ragas import AnswerCorrectness
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate

-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics

 from .config import BraintrustScoringConfig
 from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
BraintrustScoringImpl:

@@ -147,8 +147,11 @@ class BraintrustScoringImpl(
                 await self.score_row(input_row, scoring_fn_id)
                 for input_row in input_rows
             ]
-            aggregation_functions = [AggregationFunctionType.average]
-            agg_results = aggregate_average(score_results)
+            aggregation_functions = self.supported_fn_defs_registry[
+                scoring_fn_id
+            ].params.aggregation_functions
+
+            agg_results = aggregate_metrics(score_results, aggregation_functions)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
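Here the hard-coded average is replaced: the aggregation functions now come from the registered definition's params, and aggregate_metrics applies whatever was declared. The helper's implementation is not part of this diff; a plausible sketch, assuming it simply dispatches on AggregationFunctionType:

# Sketch of aggregation_utils. Only the names aggregate_average / aggregate_metrics
# and the (score_results, aggregation_functions) call shape come from this commit;
# the bodies below are assumptions.
from typing import Any, Dict, List

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType


def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    # Assumes each result row is dict-like with its numeric score under "score".
    return {
        "average": sum(row["score"] for row in scoring_results) / max(len(scoring_results), 1)
    }


def aggregate_metrics(
    scoring_results: List[ScoringResultRow],
    metrics: List[AggregationFunctionType],
) -> Dict[str, Any]:
    # Run every requested aggregation over the per-row results and merge the
    # partial dicts into the aggregated_results payload used above.
    handlers = {AggregationFunctionType.average: aggregate_average}
    agg_results: Dict[str, Any] = {}
    for metric in metrics:
        if metric not in handlers:
            raise NotImplementedError(f"Aggregation function {metric} is not supported")
        agg_results.update(handlers[metric](scoring_results))
    return agg_results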
answer_correctness_fn_def:

@@ -5,14 +5,20 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFn
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    ScoringFn,
+)


 answer_correctness_fn_def = ScoringFn(
     identifier="braintrust::answer-correctness",
     description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
-    params=None,
     provider_id="braintrust",
     provider_resource_id="answer-correctness",
     return_type=NumberType(),
+    params=BasicScoringFnParams(
+        aggregation_functions=[AggregationFunctionType.average]
+    ),
 )
factuality_fn_def:

@@ -5,14 +5,20 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFn
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    ScoringFn,
+)


 factuality_fn_def = ScoringFn(
     identifier="braintrust::factuality",
     description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
-    params=None,
     provider_id="braintrust",
     provider_resource_id="factuality",
     return_type=NumberType(),
+    params=BasicScoringFnParams(
+        aggregation_functions=[AggregationFunctionType.average]
+    ),
 )
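With the two fn_defs above declaring their aggregation behavior in params, the Braintrust provider reads it back from its registry instead of assuming an average. A small usage sketch tying the pieces together, using only constructor arguments and calls that appear in this diff (the example rows and the fn_def instance are illustrative, not repository code):

# Illustrative only: wiring a fn_def's declared aggregation_functions into
# aggregate_metrics, mirroring the BraintrustScoringImpl change above.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
    ScoringFn,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics

example_fn_def = ScoringFn(
    identifier="braintrust::answer-correctness",
    description="Example definition mirroring answer_correctness_fn_def.",
    provider_id="braintrust",
    provider_resource_id="answer-correctness",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average]
    ),
)

# score_results would normally come from score_row(); hard-coded here.
score_results = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}]
agg_results = aggregate_metrics(score_results, example_fn_def.params.aggregation_functions)
# e.g. {"average": 0.667}, assuming an average aggregator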