fix tests after registration migration & rename meta-reference -> basic / llm_as_judge provider (#424)

* rename meta-reference -> basic

* config rename

* impl rename

* rename llm_as_judge, fix test

* util

* rebase

* naming fix
Authored by Xi Yan on 2024-11-12 10:35:44 -05:00; committed by GitHub
parent 3d7561e55c
commit 84c6fbbd93
24 changed files with 268 additions and 73 deletions
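For readers migrating configs or tests, the renames in this commit boil down to the following identifier changes. This is a sketch inferred from the hunks below, not an official migration table:

# Inferred from the diffs in this commit; illustrative only.
RENAMED_PROVIDER_TYPES = {
    "inline::meta-reference": ["inline::basic", "inline::llm-as-judge"],  # split into two providers
    "braintrust": "inline::braintrust",
}
RENAMED_SCORING_FN_IDS = {
    "meta-reference::equality": "basic::equality",
    "meta-reference::subset_of": "basic::subset_of",
    "meta-reference::regex_parser_multiple_choice_answer": "basic::regex_parser_multiple_choice_answer",
    "meta-reference::llm_as_judge_base": "llm-as-judge::llm_as_judge_base",
}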

@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import BasicScoringConfig
async def get_provider_impl(
config: BasicScoringConfig,
deps: Dict[Api, ProviderSpec],
):
from .scoring import BasicScoringImpl
impl = BasicScoringImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
await impl.initialize()
return impl
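A minimal smoke test for the renamed provider might look like the following. This is a hypothetical sketch: the module paths are taken from the registry hunk further down, and the None stand-ins are assumptions that work only because initialize() does not touch the dataset APIs.

# Hypothetical usage sketch; module paths and the None stand-ins are assumptions.
import asyncio

from llama_stack.providers.inline.scoring.basic.config import BasicScoringConfig
from llama_stack.providers.inline.scoring.basic.scoring import BasicScoringImpl

async def main():
    impl = BasicScoringImpl(BasicScoringConfig(), datasetio_api=None, datasets_api=None)
    await impl.initialize()  # registers the equality, subset_of and regex_parser scoring fns
    fns = await impl.list_scoring_functions()
    print([fn.identifier for fn in fns])  # expect identifiers prefixed with "basic::"

asyncio.run(main())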

@@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring import * # noqa: F401, F403
from pydantic import BaseModel
class MetaReferenceScoringConfig(BaseModel): ...
class BasicScoringConfig(BaseModel): ...

@@ -11,44 +11,33 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from .config import MetaReferenceScoringConfig
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(
self,
config: MetaReferenceScoringConfig,
config: BasicScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
inference_api: Inference,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.inference_api = inference_api
self.scoring_fn_id_impls = {}
async def initialize(self) -> None:
for x in FIXED_FNS:
impl = x()
for fn in FIXED_FNS:
impl = fn()
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
for x in LLM_JUDGE_FNS:
impl = x(inference_api=self.inference_api)
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
self.llm_as_judge_fn = impl
async def shutdown(self) -> None: ...
@@ -61,8 +50,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"meta-reference"
), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! "
"basic"
), "All basic scoring fn must have identifier prefixed with 'basic'! "
return scoring_fn_defs_list
@@ -70,18 +59,18 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.schema or len(dataset_def.schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
if required_column not in dataset_def.schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
if dataset_def.schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
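In other words, a dataset passes this validation only if its schema defines string-typed input_query, expected_answer, and generated_answer columns. A minimal sketch of a conforming schema follows; the StringType constructor is assumed from the type_system API and is not shown in this diff.

# Illustrative only; StringType is assumed to carry type == "string".
from llama_stack.apis.common.type_system import StringType

dataset_schema = {
    "input_query": StringType(),
    "expected_answer": StringType(),
    "generated_answer": StringType(),
}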

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403

@@ -9,10 +9,10 @@ from llama_stack.apis.scoring_functions import ScoringFn
equality = ScoringFn(
identifier="meta-reference::equality",
identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
params=None,
provider_id="meta-reference",
provider_id="basic",
provider_resource_id="equality",
return_type=NumberType(),
)
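After the rename, callers address the built-in functions by their "basic::" identifiers; a parameter-free scoring call might look like the sketch below, where scoring_impl is a stand-in for a resolved scoring provider.

# Illustrative call shape; "scoring_impl" is a stand-in, not part of this diff.
response = await scoring_impl.score(
    input_rows=[{"input_query": "2+2?", "expected_answer": "4", "generated_answer": "4"}],
    scoring_functions={"basic::equality": None},
)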

@@ -57,10 +57,10 @@ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
)
regex_parser_multiple_choice_answer = ScoringFn(
identifier="meta-reference::regex_parser_multiple_choice_answer",
identifier="basic::regex_parser_multiple_choice_answer",
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
return_type=NumberType(),
provider_id="meta-reference",
provider_id="basic",
provider_resource_id="regex-parser-multiple-choice-answer",
params=RegexParserScoringFnParams(
parsing_regexes=[

@@ -9,9 +9,9 @@ from llama_stack.apis.scoring_functions import ScoringFn
subset_of = ScoringFn(
identifier="meta-reference::subset_of",
identifier="basic::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
return_type=NumberType(),
provider_id="meta-reference",
provider_id="basic",
provider_resource_id="subset-of",
)

@@ -5,7 +5,7 @@
# the root directory of this source tree.
import re
from .base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403

@@ -63,18 +63,18 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
)
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.schema or len(dataset_def.schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
if required_column not in dataset_def.schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
if dataset_def.schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)

@@ -7,16 +7,16 @@ from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceScoringConfig
from .config import LlmAsJudgeScoringConfig
async def get_provider_impl(
config: MetaReferenceScoringConfig,
config: LlmAsJudgeScoringConfig,
deps: Dict[Api, ProviderSpec],
):
from .scoring import MetaReferenceScoringImpl
from .scoring import LlmAsJudgeScoringImpl
impl = MetaReferenceScoringImpl(
impl = LlmAsJudgeScoringImpl(
config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
)
await impl.initialize()

@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class LlmAsJudgeScoringConfig(BaseModel): ...

@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(
self,
config: LlmAsJudgeScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
inference_api: Inference,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.inference_api = inference_api
self.scoring_fn_id_impls = {}
async def initialize(self) -> None:
for fn in LLM_JUDGE_FNS:
impl = fn(inference_api=self.inference_api)
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
self.llm_as_judge_fn = impl
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [
fn_def
for impl in self.scoring_fn_id_impls.values()
for fn_def in impl.get_supported_scoring_fn_defs()
]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"llm-as-judge"
), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.schema or len(dataset_def.schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.score(
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res.results,
)
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions.keys():
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
)
return ScoreResponse(
results=res,
)
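The judge-based provider is exercised the same way but with judge parameters, mirroring the updated test further down. In the sketch below, scoring_impl and rows are stand-ins, and the import location of LLMAsJudgeScoringFnParams is assumed to match ScoringFnParams above.

# Illustrative; parameter values mirror the updated test below.
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams  # assumed location

params = LLMAsJudgeScoringFnParams(
    judge_model="Llama3.1-405B-Instruct",
    prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
    judge_score_regexes=[r"Score: (\d+)"],
)
response = await scoring_impl.score(
    input_rows=rows.rows,
    scoring_functions={"llm-as-judge::llm_as_judge_base": params},
)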

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -9,9 +9,9 @@ from llama_stack.apis.scoring_functions import ScoringFn
llm_as_judge_base = ScoringFn(
identifier="meta-reference::llm_as_judge_base",
identifier="llm-as-judge::llm_as_judge_base",
description="Llm As Judge Scoring Function",
return_type=NumberType(),
provider_id="meta-reference",
provider_id="llm-as-judge",
provider_resource_id="llm-as-judge-base",
)

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.inference.inference import Inference
from .base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403

@@ -13,10 +13,21 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::meta-reference",
provider_type="inline::basic",
pip_packages=[],
module="llama_stack.providers.inline.scoring.meta_reference",
config_class="llama_stack.providers.inline.scoring.meta_reference.MetaReferenceScoringConfig",
module="llama_stack.providers.inline.scoring.basic",
config_class="llama_stack.providers.inline.scoring.basic.BasicScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::llm-as-judge",
pip_packages=[],
module="llama_stack.providers.inline.scoring.llm_as_judge",
config_class="llama_stack.providers.inline.scoring.llm_as_judge.LlmAsJudgeScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
@@ -25,7 +36,7 @@ def available_providers() -> List[ProviderSpec]:
),
InlineProviderSpec(
api=Api.scoring,
provider_type="braintrust",
provider_type="inline::braintrust",
pip_packages=["autoevals", "openai"],
module="llama_stack.providers.inline.scoring.braintrust",
config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig",
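With the registry updated, a quick sanity check is that all three inline scoring provider types resolve. This is an illustrative snippet; the registry module path is an assumption, not shown in this diff.

# Illustrative sanity check; the registry module path is assumed.
from llama_stack.providers.registry.scoring import available_providers

provider_types = {spec.provider_type for spec in available_providers()}
assert {"inline::basic", "inline::llm-as-judge", "inline::braintrust"} <= provider_types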

@@ -15,21 +15,12 @@ from .fixtures import SCORING_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"scoring": "meta_reference",
"datasetio": "localfs",
"inference": "fireworks",
},
id="meta_reference_scoring_fireworks_inference",
marks=pytest.mark.meta_reference_scoring_fireworks_inference,
),
pytest.param(
{
"scoring": "meta_reference",
"scoring": "basic",
"datasetio": "localfs",
"inference": "together",
},
id="meta_reference_scoring_together_inference",
marks=pytest.mark.meta_reference_scoring_together_inference,
id="basic_scoring_together_inference",
marks=pytest.mark.basic_scoring_together_inference,
),
pytest.param(
{
@@ -40,13 +31,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="braintrust_scoring_together_inference",
marks=pytest.mark.braintrust_scoring_together_inference,
),
pytest.param(
{
"scoring": "llm_as_judge",
"datasetio": "localfs",
"inference": "together",
},
id="llm_as_judge_scoring_together_inference",
marks=pytest.mark.llm_as_judge_scoring_together_inference,
),
]
def pytest_configure(config):
for fixture_name in [
"meta_reference_scoring_fireworks_inference",
"meta_reference_scoring_together_inference",
"basic_scoring_together_inference",
"braintrust_scoring_together_inference",
]:
config.addinivalue_line(

@@ -19,12 +19,12 @@ def scoring_remote() -> ProviderFixture:
@pytest.fixture(scope="session")
def scoring_meta_reference() -> ProviderFixture:
def scoring_basic() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
provider_id="basic",
provider_type="inline::basic",
config={},
)
],
@@ -37,14 +37,27 @@ def scoring_braintrust() -> ProviderFixture:
providers=[
Provider(
provider_id="braintrust",
provider_type="braintrust",
provider_type="inline::braintrust",
config={},
)
],
)
SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
@pytest.fixture(scope="session")
def scoring_llm_as_judge() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="llm-as-judge",
provider_type="inline::llm-as-judge",
config={},
)
],
)
SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"]
@pytest_asyncio.fixture(scope="session")

@@ -43,6 +43,13 @@ class TestScoring:
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id
if provider_id == "llm-as-judge":
pytest.skip(
f"{provider_id} provider does not support scoring without params"
)
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert len(response) == 1
@@ -111,8 +118,8 @@ class TestScoring:
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id
if provider_id == "braintrust":
pytest.skip("Braintrust provider does not support scoring with params")
if provider_id == "braintrust" or provider_id == "basic":
pytest.skip(f"{provider_id} provider does not support scoring with params")
# scoring individual rows
rows = await datasetio_impl.get_rows_paginated(
@@ -122,7 +129,7 @@ class TestScoring:
assert len(rows.rows) == 3
scoring_functions = {
"meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
"llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-405B-Instruct",
prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
judge_score_regexes=[r"Score: (\d+)"],

@@ -4,9 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn
class BaseScoringFn(ABC):