params override

This commit is contained in:
Xi Yan 2024-12-10 17:14:37 -08:00
parent f39e4c7117
commit ed33d00eb4
5 changed files with 24 additions and 88 deletions

View file

@ -4,12 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from .fn_defs.equality import equality
@ -43,22 +42,3 @@ class EqualityScoringFn(BaseScoringFn):
return {
"score": score,
}
async def aggregate(
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
if scoring_params is not None:
params = scoring_params
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")
and params.aggregation_functions
):
aggregation_functions.extend(params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)

View file

@ -5,11 +5,10 @@
# the root directory of this source tree.
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from .fn_defs.regex_parser_multiple_choice_answer import (
@ -61,22 +60,3 @@ class RegexParserScoringFn(BaseScoringFn):
return {
"score": score,
}
async def aggregate(
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
if scoring_params is not None:
params = scoring_params
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")
and params.aggregation_functions
):
aggregation_functions.extend(params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)

View file

@ -4,11 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from .fn_defs.subset_of import subset_of
@ -37,22 +36,3 @@ class SubsetOfScoringFn(BaseScoringFn):
return {
"score": score,
}
async def aggregate(
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
if scoring_params is not None:
params = scoring_params
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")
and params.aggregation_functions
):
aggregation_functions.extend(params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)

View file

@ -5,15 +5,13 @@
# the root directory of this source tree.
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
@ -47,7 +45,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
# override params if scoring_params is provided
if scoring_params is not None:
fn_def.params = scoring_params
for attr in scoring_params.__dict__:
override_attr = getattr(scoring_params, attr)
if override_attr is not None:
setattr(fn_def.params, attr, override_attr)
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
assert (
@ -90,22 +91,3 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
"score": judge_rating,
"judge_feedback": content,
}
async def aggregate(
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
if scoring_params is not None:
params = scoring_params
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")
and params.aggregation_functions
):
aggregation_functions.extend(params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)

View file

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
class BaseScoringFn(ABC):
@ -44,14 +45,27 @@ class BaseScoringFn(ABC):
) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
async def aggregate(
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
raise NotImplementedError()
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
if scoring_params is not None:
if params is None:
params = scoring_params
else:
params.aggregation_functions = scoring_params.aggregation_functions
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")
and params.aggregation_functions
):
aggregation_functions.extend(params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)
async def score(
self,