fix tests after registration migration & rename meta-reference -> basic / llm_as_judge provider (#424)

* rename meta-reference -> basic

* config rename

* impl rename

* rename llm_as_judge, fix test

* util

* rebase

* naming fix
Xi Yan 2024-11-12 10:35:44 -05:00 committed by GitHub
parent 3d7561e55c
commit 84c6fbbd93
24 changed files with 268 additions and 73 deletions
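
Reviewer note (not part of the patch): the renames in this PR boil down to the mapping below. The identifiers and provider types are copied from the diffs that follow; the dict itself is only an orientation sketch.

# Summary sketch of the renames introduced by this PR (values taken from the diffs below).
RENAMED_SCORING_FN_IDS = {
    "meta-reference::equality": "basic::equality",
    "meta-reference::subset_of": "basic::subset_of",
    "meta-reference::regex_parser_multiple_choice_answer": "basic::regex_parser_multiple_choice_answer",
    "meta-reference::llm_as_judge_base": "llm-as-judge::llm_as_judge_base",
}
RENAMED_PROVIDER_TYPES = {
    "inline::meta-reference": ["inline::basic", "inline::llm-as-judge"],  # split into two providers
    "braintrust": "inline::braintrust",
}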

View file

@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec

from .config import BasicScoringConfig


async def get_provider_impl(
    config: BasicScoringConfig,
    deps: Dict[Api, ProviderSpec],
):
    from .scoring import BasicScoringImpl

    impl = BasicScoringImpl(
        config,
        deps[Api.datasetio],
        deps[Api.datasets],
    )
    await impl.initialize()
    return impl

View file

@@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from llama_stack.apis.scoring import *  # noqa: F401, F403
+from pydantic import BaseModel
-class MetaReferenceScoringConfig(BaseModel): ...
+class BasicScoringConfig(BaseModel): ...

View file

@@ -11,44 +11,33 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
from llama_stack.apis.common.type_system import *  # noqa: F403
from llama_stack.apis.datasetio import *  # noqa: F403
from llama_stack.apis.datasets import *  # noqa: F403
-from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from .config import MetaReferenceScoringConfig
+from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
-from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
-LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
-class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
+class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
    def __init__(
        self,
-        config: MetaReferenceScoringConfig,
+        config: BasicScoringConfig,
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
-        inference_api: Inference,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api
-        self.inference_api = inference_api
        self.scoring_fn_id_impls = {}
    async def initialize(self) -> None:
-        for x in FIXED_FNS:
-            impl = x()
+        for fn in FIXED_FNS:
+            impl = fn()
            for fn_defs in impl.get_supported_scoring_fn_defs():
                self.scoring_fn_id_impls[fn_defs.identifier] = impl
-        for x in LLM_JUDGE_FNS:
-            impl = x(inference_api=self.inference_api)
-            for fn_defs in impl.get_supported_scoring_fn_defs():
-                self.scoring_fn_id_impls[fn_defs.identifier] = impl
-            self.llm_as_judge_fn = impl
    async def shutdown(self) -> None: ...
@@ -61,8 +50,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
        for f in scoring_fn_defs_list:
            assert f.identifier.startswith(
-                "meta-reference"
+                "basic"
-            ), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! "
+            ), "All basic scoring fn must have identifier prefixed with 'basic'! "
        return scoring_fn_defs_list
@@ -70,18 +59,18 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
        raise NotImplementedError("Register scoring function not implemented yet")
    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        if not dataset_def.schema or len(dataset_def.schema) == 0:
            raise ValueError(
                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
            )
        for required_column in ["generated_answer", "expected_answer", "input_query"]:
-            if required_column not in dataset_def.dataset_schema:
+            if required_column not in dataset_def.schema:
                raise ValueError(
                    f"Dataset {dataset_id} does not have a '{required_column}' column."
                )
-            if dataset_def.dataset_schema[required_column].type != "string":
+            if dataset_def.schema[required_column].type != "string":
                raise ValueError(
                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
                )
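
For reviewers, a minimal usage sketch of the renamed basic provider. `scoring_impl` and the row contents are assumptions, and score() is assumed to share the (input_rows, scoring_functions) signature shown for LlmAsJudgeScoringImpl later in this diff.

async def score_one_row(scoring_impl):
    # scoring_impl: an already-initialized BasicScoringImpl (assumed to exist in the caller).
    rows = [
        {
            "input_query": "What is 2 + 2?",
            "generated_answer": "4",
            "expected_answer": "4",
        }
    ]
    # "basic::equality" is one of the identifiers renamed in this PR.
    response = await scoring_impl.score(
        input_rows=rows,
        scoring_functions={"basic::equality": None},  # None -> use the scoring fn's defaults
    )
    return response.results["basic::equality"].aggregated_results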

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
from llama_stack.apis.scoring import *  # noqa: F401, F403
from llama_stack.apis.common.type_system import *  # noqa: F403

View file

@@ -9,10 +9,10 @@ from llama_stack.apis.scoring_functions import ScoringFn
equality = ScoringFn(
-    identifier="meta-reference::equality",
+    identifier="basic::equality",
    description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
    params=None,
-    provider_id="meta-reference",
+    provider_id="basic",
    provider_resource_id="equality",
    return_type=NumberType(),
)

View file

@@ -57,10 +57,10 @@ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
)
regex_parser_multiple_choice_answer = ScoringFn(
-    identifier="meta-reference::regex_parser_multiple_choice_answer",
+    identifier="basic::regex_parser_multiple_choice_answer",
    description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
    return_type=NumberType(),
-    provider_id="meta-reference",
+    provider_id="basic",
    provider_resource_id="regex-parser-multiple-choice-answer",
    params=RegexParserScoringFnParams(
        parsing_regexes=[

View file

@@ -9,9 +9,9 @@ from llama_stack.apis.scoring_functions import ScoringFn
subset_of = ScoringFn(
-    identifier="meta-reference::subset_of",
+    identifier="basic::subset_of",
    description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
    return_type=NumberType(),
-    provider_id="meta-reference",
+    provider_id="basic",
    provider_resource_id="subset-of",
)

View file

@@ -5,7 +5,7 @@
# the root directory of this source tree.
import re
-from .base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
from llama_stack.apis.scoring import *  # noqa: F401, F403
from llama_stack.apis.common.type_system import *  # noqa: F403

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
from llama_stack.apis.scoring import *  # noqa: F401, F403
from llama_stack.apis.common.type_system import *  # noqa: F403

View file

@@ -63,18 +63,18 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
        )
    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        if not dataset_def.schema or len(dataset_def.schema) == 0:
            raise ValueError(
                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
            )
        for required_column in ["generated_answer", "expected_answer", "input_query"]:
-            if required_column not in dataset_def.dataset_schema:
+            if required_column not in dataset_def.schema:
                raise ValueError(
                    f"Dataset {dataset_id} does not have a '{required_column}' column."
                )
-            if dataset_def.dataset_schema[required_column].type != "string":
+            if dataset_def.schema[required_column].type != "string":
                raise ValueError(
                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
                )

View file

@@ -7,16 +7,16 @@ from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
-from .config import MetaReferenceScoringConfig
+from .config import LlmAsJudgeScoringConfig
async def get_provider_impl(
-    config: MetaReferenceScoringConfig,
+    config: LlmAsJudgeScoringConfig,
    deps: Dict[Api, ProviderSpec],
):
-    from .scoring import MetaReferenceScoringImpl
+    from .scoring import LlmAsJudgeScoringImpl
-    impl = MetaReferenceScoringImpl(
+    impl = LlmAsJudgeScoringImpl(
        config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
    )
    await impl.initialize()

View file

@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel


class LlmAsJudgeScoringConfig(BaseModel): ...

View file

@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]


class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
    def __init__(
        self,
        config: LlmAsJudgeScoringConfig,
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
        inference_api: Inference,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api
        self.inference_api = inference_api
        self.scoring_fn_id_impls = {}

    async def initialize(self) -> None:
        for fn in LLM_JUDGE_FNS:
            impl = fn(inference_api=self.inference_api)
            for fn_defs in impl.get_supported_scoring_fn_defs():
                self.scoring_fn_id_impls[fn_defs.identifier] = impl
            self.llm_as_judge_fn = impl

    async def shutdown(self) -> None: ...

    async def list_scoring_functions(self) -> List[ScoringFn]:
        scoring_fn_defs_list = [
            fn_def
            for impl in self.scoring_fn_id_impls.values()
            for fn_def in impl.get_supported_scoring_fn_defs()
        ]
        for f in scoring_fn_defs_list:
            assert f.identifier.startswith(
                "llm-as-judge"
            ), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
        return scoring_fn_defs_list

    async def register_scoring_function(self, function_def: ScoringFn) -> None:
        raise NotImplementedError("Register scoring function not implemented yet")

    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
        if not dataset_def.schema or len(dataset_def.schema) == 0:
            raise ValueError(
                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
            )
        for required_column in ["generated_answer", "expected_answer", "input_query"]:
            if required_column not in dataset_def.schema:
                raise ValueError(
                    f"Dataset {dataset_id} does not have a '{required_column}' column."
                )
            if dataset_def.schema[required_column].type != "string":
                raise ValueError(
                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
                )

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
        all_rows = await self.datasetio_api.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=-1,
        )
        res = await self.score(
            input_rows=all_rows.rows,
            scoring_functions=scoring_functions,
        )
        if save_results_dataset:
            # TODO: persist and register dataset on to server for reading
            # self.datasets_api.register_dataset()
            raise NotImplementedError("Save results dataset not implemented yet")

        return ScoreBatchResponse(
            results=res.results,
        )

    async def score(
        self,
        input_rows: List[Dict[str, Any]],
        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
    ) -> ScoreResponse:
        res = {}
        for scoring_fn_id in scoring_functions.keys():
            if scoring_fn_id not in self.scoring_fn_id_impls:
                raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
            scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
            score_results = await scoring_fn.score(
                input_rows, scoring_fn_id, scoring_fn_params
            )
            agg_results = await scoring_fn.aggregate(score_results)
            res[scoring_fn_id] = ScoringResult(
                score_rows=score_results,
                aggregated_results=agg_results,
            )

        return ScoreResponse(
            results=res,
        )
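
A hedged sketch of how this new llm-as-judge provider is driven with judge parameters; the parameter values mirror the test change further down, while `scoring_impl`, `rows`, and the import location of LLMAsJudgeScoringFnParams are assumptions.

from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams  # assumed import path


async def score_with_judge(scoring_impl, rows):
    # scoring_impl: an initialized LlmAsJudgeScoringImpl; rows: dicts with the
    # input_query / generated_answer / expected_answer columns validated above.
    params = LLMAsJudgeScoringFnParams(
        judge_model="Llama3.1-405B-Instruct",
        prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
        judge_score_regexes=[r"Score: (\d+)"],
    )
    response = await scoring_impl.score(
        input_rows=rows,
        scoring_functions={"llm-as-judge::llm_as_judge_base": params},
    )
    return response.results["llm-as-judge::llm_as_judge_base"]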

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -9,9 +9,9 @@ from llama_stack.apis.scoring_functions import ScoringFn
llm_as_judge_base = ScoringFn(
-    identifier="meta-reference::llm_as_judge_base",
+    identifier="llm-as-judge::llm_as_judge_base",
    description="Llm As Judge Scoring Function",
    return_type=NumberType(),
-    provider_id="meta-reference",
+    provider_id="llm-as-judge",
    provider_resource_id="llm-as-judge-base",
)

View file

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.inference.inference import Inference
-from .base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
from llama_stack.apis.scoring import *  # noqa: F401, F403
from llama_stack.apis.common.type_system import *  # noqa: F403

View file

@@ -13,10 +13,21 @@ def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.scoring,
-            provider_type="inline::meta-reference",
+            provider_type="inline::basic",
            pip_packages=[],
-            module="llama_stack.providers.inline.scoring.meta_reference",
+            module="llama_stack.providers.inline.scoring.basic",
-            config_class="llama_stack.providers.inline.scoring.meta_reference.MetaReferenceScoringConfig",
+            config_class="llama_stack.providers.inline.scoring.basic.BasicScoringConfig",
+            api_dependencies=[
+                Api.datasetio,
+                Api.datasets,
+            ],
+        ),
+        InlineProviderSpec(
+            api=Api.scoring,
+            provider_type="inline::llm-as-judge",
+            pip_packages=[],
+            module="llama_stack.providers.inline.scoring.llm_as_judge",
+            config_class="llama_stack.providers.inline.scoring.llm_as_judge.LlmAsJudgeScoringConfig",
            api_dependencies=[
                Api.datasetio,
                Api.datasets,
@@ -25,7 +36,7 @@ def available_providers() -> List[ProviderSpec]:
        ),
        InlineProviderSpec(
            api=Api.scoring,
-            provider_type="braintrust",
+            provider_type="inline::braintrust",
            pip_packages=["autoevals", "openai"],
            module="llama_stack.providers.inline.scoring.braintrust",
            config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig",

View file

@@ -15,21 +15,12 @@ from .fixtures import SCORING_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {
-            "scoring": "meta_reference",
+            "scoring": "basic",
-            "datasetio": "localfs",
-            "inference": "fireworks",
-        },
-        id="meta_reference_scoring_fireworks_inference",
-        marks=pytest.mark.meta_reference_scoring_fireworks_inference,
-    ),
-    pytest.param(
-        {
-            "scoring": "meta_reference",
            "datasetio": "localfs",
            "inference": "together",
        },
-        id="meta_reference_scoring_together_inference",
+        id="basic_scoring_together_inference",
-        marks=pytest.mark.meta_reference_scoring_together_inference,
+        marks=pytest.mark.basic_scoring_together_inference,
    ),
    pytest.param(
        {
@@ -40,13 +31,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [
        id="braintrust_scoring_together_inference",
        marks=pytest.mark.braintrust_scoring_together_inference,
    ),
+    pytest.param(
+        {
+            "scoring": "llm_as_judge",
+            "datasetio": "localfs",
+            "inference": "together",
+        },
+        id="llm_as_judge_scoring_together_inference",
+        marks=pytest.mark.llm_as_judge_scoring_together_inference,
+    ),
]
def pytest_configure(config):
    for fixture_name in [
-        "meta_reference_scoring_fireworks_inference",
+        "basic_scoring_together_inference",
-        "meta_reference_scoring_together_inference",
        "braintrust_scoring_together_inference",
    ]:
        config.addinivalue_line(
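
A hedged note on running the updated combinations locally: the mark names come from the conftest above, but the test-directory path here is an assumption.

import pytest

# Run only the new llm-as-judge combination by its pytest mark; adjust the path to
# wherever these scoring tests live in the repository (assumed path shown).
pytest.main([
    "llama_stack/providers/tests/scoring/",
    "-m", "llm_as_judge_scoring_together_inference",
    "-v",
])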

View file

@@ -19,12 +19,12 @@ def scoring_remote() -> ProviderFixture:
@pytest.fixture(scope="session")
-def scoring_meta_reference() -> ProviderFixture:
+def scoring_basic() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
-                provider_id="meta-reference",
+                provider_id="basic",
-                provider_type="meta-reference",
+                provider_type="inline::basic",
                config={},
            )
        ],
@@ -37,14 +37,27 @@ def scoring_braintrust() -> ProviderFixture:
        providers=[
            Provider(
                provider_id="braintrust",
-                provider_type="braintrust",
+                provider_type="inline::braintrust",
                config={},
            )
        ],
    )
-SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
+@pytest.fixture(scope="session")
+def scoring_llm_as_judge() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="llm-as-judge",
+                provider_type="inline::llm-as-judge",
+                config={},
+            )
+        ],
+    )
+SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"]
@pytest_asyncio.fixture(scope="session")

View file

@@ -43,6 +43,13 @@ class TestScoring:
            scoring_stack[Api.datasets],
            scoring_stack[Api.models],
        )
+        scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
+        provider_id = scoring_fns_list[0].provider_id
+        if provider_id == "llm-as-judge":
+            pytest.skip(
+                f"{provider_id} provider does not support scoring without params"
+            )
        await register_dataset(datasets_impl)
        response = await datasets_impl.list_datasets()
        assert len(response) == 1
@@ -111,8 +118,8 @@ class TestScoring:
        scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
        provider_id = scoring_fns_list[0].provider_id
-        if provider_id == "braintrust":
-            pytest.skip("Braintrust provider does not support scoring with params")
+        if provider_id == "braintrust" or provider_id == "basic":
+            pytest.skip(f"{provider_id} provider does not support scoring with params")
        # scoring individual rows
        rows = await datasetio_impl.get_rows_paginated(
@@ -122,7 +129,7 @@ class TestScoring:
        assert len(rows.rows) == 3
        scoring_functions = {
-            "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
+            "llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
                judge_model="Llama3.1-405B-Instruct",
                prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
                judge_score_regexes=[r"Score: (\d+)"],

View file

@@ -4,9 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
-from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
-from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
+from llama_stack.apis.scoring_functions import ScoringFn
class BaseScoringFn(ABC):