mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 00:34:44 +00:00
refactor
This commit is contained in:
parent
1077c521b1
commit
7aab3c63f4
8 changed files with 30 additions and 18 deletions
|
@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
|
||||
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
|
||||
|
@ -54,7 +54,7 @@ class EqualityScoringFn(BaseScoringFn):
|
|||
if scoring_params is not None:
|
||||
params = scoring_params
|
||||
|
||||
aggregation_functions = [AggregationFunctionType.accuracy]
|
||||
aggregation_functions = []
|
||||
if (
|
||||
params
|
||||
and hasattr(params, "aggregation_functions")
|
||||
|
|
|
@ -5,14 +5,20 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
||||
equality = ScoringFn(
|
||||
identifier="basic::equality",
|
||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||
params=None,
|
||||
provider_id="basic",
|
||||
provider_resource_id="equality",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.accuracy]
|
||||
),
|
||||
)
|
||||
|
|
|
@ -4,9 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
MULTILINGUAL_ANSWER_REGEXES = [
|
||||
r"Answer\s*:",
|
||||
|
@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn(
|
|||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
|
||||
for x in MULTILINGUAL_ANSWER_REGEXES
|
||||
],
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
||||
|
|
|
@ -5,7 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
||||
subset_of = ScoringFn(
|
||||
|
@ -14,4 +18,7 @@ subset_of = ScoringFn(
|
|||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="subset-of",
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.accuracy]
|
||||
),
|
||||
)
|
||||
|
|
|
@ -8,11 +8,7 @@ import re
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
ScoringFnParams,
|
||||
ScoringFnParamsType,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
|
||||
|
@ -76,7 +72,7 @@ class RegexParserScoringFn(BaseScoringFn):
|
|||
if scoring_params is not None:
|
||||
params = scoring_params
|
||||
|
||||
aggregation_functions = [AggregationFunctionType.accuracy]
|
||||
aggregation_functions = []
|
||||
if (
|
||||
params
|
||||
and hasattr(params, "aggregation_functions")
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
|
||||
|
@ -48,7 +48,7 @@ class SubsetOfScoringFn(BaseScoringFn):
|
|||
if scoring_params is not None:
|
||||
params = scoring_params
|
||||
|
||||
aggregation_functions = [AggregationFunctionType.accuracy]
|
||||
aggregation_functions = []
|
||||
if (
|
||||
params
|
||||
and hasattr(params, "aggregation_functions")
|
||||
|
|
|
@ -147,7 +147,7 @@ class BraintrustScoringImpl(
|
|||
await self.score_row(input_row, scoring_fn_id)
|
||||
for input_row in input_rows
|
||||
]
|
||||
|
||||
aggregation_functions = [AggregationFunctionType.average]
|
||||
agg_results = aggregate_average(score_results)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
|
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn
|
|||
|
||||
class BaseScoringFn(ABC):
|
||||
"""
|
||||
Base interface class for all meta-reference scoring_fns.
|
||||
Base interface class for all native scoring_fns.
|
||||
Each scoring_fn needs to implement the following methods:
|
||||
- score_row(self, row)
|
||||
- aggregate(self, scoring_fn_results)
|
||||
|
@ -20,7 +20,6 @@ class BaseScoringFn(ABC):
|
|||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue