mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
refactor
This commit is contained in:
parent
1077c521b1
commit
7aab3c63f4
8 changed files with 30 additions and 18 deletions
|
@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from llama_stack.apis.scoring import ScoringResultRow
|
from llama_stack.apis.scoring import ScoringResultRow
|
||||||
|
|
||||||
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams
|
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
||||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ class EqualityScoringFn(BaseScoringFn):
|
||||||
if scoring_params is not None:
|
if scoring_params is not None:
|
||||||
params = scoring_params
|
params = scoring_params
|
||||||
|
|
||||||
aggregation_functions = [AggregationFunctionType.accuracy]
|
aggregation_functions = []
|
||||||
if (
|
if (
|
||||||
params
|
params
|
||||||
and hasattr(params, "aggregation_functions")
|
and hasattr(params, "aggregation_functions")
|
||||||
|
|
|
@ -5,14 +5,20 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
from llama_stack.apis.scoring_functions import (
|
||||||
|
AggregationFunctionType,
|
||||||
|
BasicScoringFnParams,
|
||||||
|
ScoringFn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
equality = ScoringFn(
|
equality = ScoringFn(
|
||||||
identifier="basic::equality",
|
identifier="basic::equality",
|
||||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||||
params=None,
|
|
||||||
provider_id="basic",
|
provider_id="basic",
|
||||||
provider_resource_id="equality",
|
provider_resource_id="equality",
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
|
params=BasicScoringFnParams(
|
||||||
|
aggregation_functions=[AggregationFunctionType.accuracy]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,9 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
|
||||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
|
from llama_stack.apis.scoring_functions import (
|
||||||
|
AggregationFunctionType,
|
||||||
|
RegexParserScoringFnParams,
|
||||||
|
ScoringFn,
|
||||||
|
)
|
||||||
|
|
||||||
MULTILINGUAL_ANSWER_REGEXES = [
|
MULTILINGUAL_ANSWER_REGEXES = [
|
||||||
r"Answer\s*:",
|
r"Answer\s*:",
|
||||||
|
@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn(
|
||||||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
|
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
|
||||||
for x in MULTILINGUAL_ANSWER_REGEXES
|
for x in MULTILINGUAL_ANSWER_REGEXES
|
||||||
],
|
],
|
||||||
|
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,7 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
from llama_stack.apis.scoring_functions import (
|
||||||
|
AggregationFunctionType,
|
||||||
|
BasicScoringFnParams,
|
||||||
|
ScoringFn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
subset_of = ScoringFn(
|
subset_of = ScoringFn(
|
||||||
|
@ -14,4 +18,7 @@ subset_of = ScoringFn(
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
provider_id="basic",
|
provider_id="basic",
|
||||||
provider_resource_id="subset-of",
|
provider_resource_id="subset-of",
|
||||||
|
params=BasicScoringFnParams(
|
||||||
|
aggregation_functions=[AggregationFunctionType.accuracy]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -8,11 +8,7 @@ import re
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from llama_stack.apis.scoring import ScoringResultRow
|
from llama_stack.apis.scoring import ScoringResultRow
|
||||||
from llama_stack.apis.scoring_functions import (
|
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||||
AggregationFunctionType,
|
|
||||||
ScoringFnParams,
|
|
||||||
ScoringFnParamsType,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
||||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||||
|
|
||||||
|
@ -76,7 +72,7 @@ class RegexParserScoringFn(BaseScoringFn):
|
||||||
if scoring_params is not None:
|
if scoring_params is not None:
|
||||||
params = scoring_params
|
params = scoring_params
|
||||||
|
|
||||||
aggregation_functions = [AggregationFunctionType.accuracy]
|
aggregation_functions = []
|
||||||
if (
|
if (
|
||||||
params
|
params
|
||||||
and hasattr(params, "aggregation_functions")
|
and hasattr(params, "aggregation_functions")
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from llama_stack.apis.scoring import ScoringResultRow
|
from llama_stack.apis.scoring import ScoringResultRow
|
||||||
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams
|
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
||||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ class SubsetOfScoringFn(BaseScoringFn):
|
||||||
if scoring_params is not None:
|
if scoring_params is not None:
|
||||||
params = scoring_params
|
params = scoring_params
|
||||||
|
|
||||||
aggregation_functions = [AggregationFunctionType.accuracy]
|
aggregation_functions = []
|
||||||
if (
|
if (
|
||||||
params
|
params
|
||||||
and hasattr(params, "aggregation_functions")
|
and hasattr(params, "aggregation_functions")
|
||||||
|
|
|
@ -147,7 +147,7 @@ class BraintrustScoringImpl(
|
||||||
await self.score_row(input_row, scoring_fn_id)
|
await self.score_row(input_row, scoring_fn_id)
|
||||||
for input_row in input_rows
|
for input_row in input_rows
|
||||||
]
|
]
|
||||||
|
aggregation_functions = [AggregationFunctionType.average]
|
||||||
agg_results = aggregate_average(score_results)
|
agg_results = aggregate_average(score_results)
|
||||||
res[scoring_fn_id] = ScoringResult(
|
res[scoring_fn_id] = ScoringResult(
|
||||||
score_rows=score_results,
|
score_rows=score_results,
|
||||||
|
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
|
|
||||||
class BaseScoringFn(ABC):
|
class BaseScoringFn(ABC):
|
||||||
"""
|
"""
|
||||||
Base interface class for all meta-reference scoring_fns.
|
Base interface class for all native scoring_fns.
|
||||||
Each scoring_fn needs to implement the following methods:
|
Each scoring_fn needs to implement the following methods:
|
||||||
- score_row(self, row)
|
- score_row(self, row)
|
||||||
- aggregate(self, scoring_fn_results)
|
- aggregate(self, scoring_fn_results)
|
||||||
|
@ -20,7 +20,6 @@ class BaseScoringFn(ABC):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.supported_fn_defs_registry = {}
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue