From fd581c3d883e18f20246014da0d677fa29c1179d Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 7 Nov 2024 16:17:49 -0800
Subject: [PATCH 1/5] only keep 1 run_eval

---
 llama_stack/apis/eval/eval.py                 | 22 +++------
 llama_stack/distribution/routers/routers.py   | 48 ++++++++-----------
 .../inline/meta_reference/eval/eval.py        | 36 +++++---------
 llama_stack/providers/tests/eval/test_eval.py | 15 +++---
 4 files changed, 45 insertions(+), 76 deletions(-)

diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index 6aa4cae34..5ae779ca7 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -66,36 +66,28 @@ class EvaluateResponse(BaseModel):
 
 
 class Eval(Protocol):
-    @webmethod(route="/eval/run_benchmark", method="POST")
-    async def run_benchmark(
-        self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkEvalTaskConfig,
-    ) -> Job: ...
-
     @webmethod(route="/eval/run_eval", method="POST")
     async def run_eval(
         self,
-        task: EvalTaskDef,
-        task_config: AppEvalTaskConfig,
+        task_id: str,
+        task_def: EvalTaskDef,
+        task_config: EvalTaskConfig,
     ) -> Job: ...
 
     @webmethod(route="/eval/evaluate_rows", method="POST")
     async def evaluate_rows(
         self,
+        task_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
         task_config: EvalTaskConfig,
-        eval_task_id: Optional[str] = None,
    ) -> EvaluateResponse: ...
 
     @webmethod(route="/eval/job/status", method="GET")
-    async def job_status(
-        self, job_id: str, eval_task_id: str
-    ) -> Optional[JobStatus]: ...
+    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
 
     @webmethod(route="/eval/job/cancel", method="POST")
-    async def job_cancel(self, job_id: str, eval_task_id: str) -> None: ...
+    async def job_cancel(self, task_id: str, job_id: str) -> None: ...
 
     @webmethod(route="/eval/job/result", method="GET")
-    async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse: ...
+    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
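To make the reshaped surface concrete: after this change there is a single `run_eval` entry point and every eval call is keyed by `task_id` first. A minimal client-side sketch, assuming an `eval_impl` that implements the protocol and a `task_def`/`task_config` built elsewhere (note that PATCH 3 below removes the `task_def` argument again):

```python
# Sketch only: `eval_impl`, `task_def` and `task_config` are assumed to come
# from the test harness / distribution setup, not from this patch.
from llama_stack.apis.eval.eval import JobStatus


async def run_app_eval(eval_impl, task_def, task_config):
    task_id = "meta-reference::app_eval"
    job = await eval_impl.run_eval(
        task_id=task_id,
        task_def=task_def,        # still required at this point in the series
        task_config=task_config,
    )
    # task_id comes first on every job call after this patch.
    if await eval_impl.job_status(task_id, job.job_id) == JobStatus.completed:
        return await eval_impl.job_result(task_id, job.job_id)
    return None
```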
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 06d50bd65..4b28a20d7 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -16,10 +16,6 @@ from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.scoring import *  # noqa: F403
 from llama_stack.apis.eval import *  # noqa: F403
 
-from llama_stack.providers.inline.meta_reference.eval.eval import (
-    DEFAULT_EVAL_TASK_IDENTIFIER,
-)
-
 
 class MemoryRouter(Memory):
     """Routes to an provider based on the memory bank identifier"""
@@ -268,36 +264,28 @@ class EvalRouter(Eval):
     async def shutdown(self) -> None:
         pass
 
-    async def run_benchmark(
-        self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkEvalTaskConfig,
-    ) -> Job:
-        pass
-
     async def run_eval(
         self,
-        task: EvalTaskDef,
+        task_id: str,
+        task_def: EvalTaskDef,
         task_config: AppEvalTaskConfig,
     ) -> Job:
-        return await self.routing_table.get_provider_impl(task.identifier).run_eval(
-            task=task,
+        return await self.routing_table.get_provider_impl(task_id).run_eval(
+            task_id=task_id,
+            task_def=task_def,
             task_config=task_config,
         )
 
     @webmethod(route="/eval/evaluate_rows", method="POST")
     async def evaluate_rows(
         self,
+        task_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
         task_config: EvalTaskConfig,
-        eval_task_id: Optional[str] = None,
     ) -> EvaluateResponse:
-        # NOTE: This is to deal with the case where we do not pre-register an eval benchmark_task
-        # We use default DEFAULT_EVAL_TASK_IDENTIFIER as identifier
-        if eval_task_id is None:
-            eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER
-        return await self.routing_table.get_provider_impl(eval_task_id).evaluate_rows(
+        return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
+            task_id=task_id,
             input_rows=input_rows,
             scoring_functions=scoring_functions,
             task_config=task_config,
@@ -305,27 +293,29 @@ class EvalRouter(Eval):
 
     async def job_status(
         self,
+        task_id: str,
         job_id: str,
-        eval_task_id: str,
     ) -> Optional[JobStatus]:
-        return await self.routing_table.get_provider_impl(eval_task_id).job_status(
-            job_id, eval_task_id
+        return await self.routing_table.get_provider_impl(task_id).job_status(
+            task_id, job_id
         )
 
     async def job_cancel(
         self,
+        task_id: str,
         job_id: str,
-        eval_task_id: str,
     ) -> None:
-        await self.routing_table.get_provider_impl(eval_task_id).job_cancel(
-            job_id, eval_task_id
+        await self.routing_table.get_provider_impl(task_id).job_cancel(
+            task_id,
+            job_id,
        )
 
     async def job_result(
         self,
+        task_id: str,
         job_id: str,
-        eval_task_id: str,
     ) -> EvaluateResponse:
-        return await self.routing_table.get_provider_impl(eval_task_id).job_result(
-            job_id, eval_task_id
+        return await self.routing_table.get_provider_impl(task_id).job_result(
+            task_id,
+            job_id,
         )
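The router change drops the `DEFAULT_EVAL_TASK_IDENTIFIER` fallback: dispatch is now keyed directly on the `task_id` the caller supplies. A toy sketch of that dispatch pattern (the real `RoutingTable` in llama-stack is more involved; this only illustrates the idea):

```python
# Illustrative only; not the actual llama_stack RoutingTable implementation.
class ToyRoutingTable:
    def __init__(self) -> None:
        self._impls_by_task_id: dict[str, object] = {}

    def register(self, task_id: str, impl: object) -> None:
        self._impls_by_task_id[task_id] = impl

    def get_provider_impl(self, task_id: str) -> object:
        # With the default identifier removed, an unknown task_id now fails
        # loudly instead of silently falling back to one provider.
        return self._impls_by_task_id[task_id]
```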
diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py
index a9a1978e9..d20cf30d2 100644
--- a/llama_stack/providers/inline/meta_reference/eval/eval.py
+++ b/llama_stack/providers/inline/meta_reference/eval/eval.py
@@ -7,14 +7,7 @@ from enum import Enum
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 
 from .....apis.common.job_types import Job
-from .....apis.eval.eval import (
-    AppEvalTaskConfig,
-    BenchmarkEvalTaskConfig,
-    Eval,
-    EvalTaskConfig,
-    EvaluateResponse,
-    JobStatus,
-)
+from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
@@ -98,21 +91,15 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )
 
-    async def run_benchmark(
-        self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkEvalTaskConfig,
-    ) -> Job:
-        raise NotImplementedError("Benchmark eval is not implemented yet")
-
     async def run_eval(
         self,
-        task: EvalTaskDef,
-        task_config: AppEvalTaskConfig,
+        task_id: str,
+        task_def: EvalTaskDef,
+        task_config: EvalTaskConfig,
     ) -> Job:
-        dataset_id = task.dataset_id
+        dataset_id = task_def.dataset_id
         candidate = task_config.eval_candidate
-        scoring_functions = task.scoring_functions
+        scoring_functions = task_def.scoring_functions
 
         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
@@ -120,6 +107,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             rows_in_page=-1,
         )
         res = await self.evaluate_rows(
+            task_id=task_id,
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
             task_config=task_config,
@@ -133,10 +121,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
 
     async def evaluate_rows(
         self,
+        task_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
         task_config: EvalTaskConfig,
-        eval_task_id: Optional[str] = None,
     ) -> EvaluateResponse:
         candidate = task_config.eval_candidate
         if candidate.type == "agent":
@@ -206,17 +194,17 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
 
         return EvaluateResponse(generations=generations, scores=score_response.results)
 
-    async def job_status(self, job_id: str, eval_task_id: str) -> Optional[JobStatus]:
+    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]:
         if job_id in self.jobs:
             return JobStatus.completed
         return None
 
-    async def job_cancel(self, job_id: str, eval_task_id: str) -> None:
+    async def job_cancel(self, task_id: str, job_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
-    async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse:
-        status = await self.job_status(job_id, eval_task_id)
+    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse:
+        status = await self.job_status(task_id, job_id)
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")
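In the meta-reference provider, `run_eval` still runs synchronously: it evaluates the rows immediately, stores the result, and `job_status`/`job_result` just consult an in-memory dict. A rough, self-contained sketch of that bookkeeping; the exact job-id scheme lives outside the shown hunks and is an assumption here (the test below only asserts that the first job id is `"0"`):

```python
# Toy illustration of the synchronous job pattern, not the provider itself.
class ToySyncJobStore:
    def __init__(self) -> None:
        self.jobs: dict[str, object] = {}

    def record_result(self, result: object) -> str:
        job_id = str(len(self.jobs))   # first recorded job gets id "0"
        self.jobs[job_id] = result
        return job_id

    def job_status(self, task_id: str, job_id: str) -> str | None:
        # Results exist as soon as the job id does, so known ids are "completed".
        return "completed" if job_id in self.jobs else None
```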
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index d97f74ec4..794026e63 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -50,6 +50,7 @@ class Testeval:
         ]
 
         response = await eval_impl.evaluate_rows(
+            task_id="meta-reference::app_eval",
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
             task_config=AppEvalTaskConfig(
@@ -75,10 +76,12 @@ class Testeval:
             "meta-reference::subset_of",
         ]
 
+        task_id = "meta-reference::app_eval"
         response = await eval_impl.run_eval(
-            task=EvalTaskDef(
+            task_id=task_id,
+            task_def=EvalTaskDef(
                 # NOTE: this is needed to make the router work for all app evals
-                identifier="meta-reference::app_eval",
+                identifier=task_id,
                 dataset_id="test_dataset_for_eval",
                 scoring_functions=scoring_functions,
             ),
@@ -90,13 +93,9 @@ class Testeval:
             ),
         )
         assert response.job_id == "0"
-        job_status = await eval_impl.job_status(
-            response.job_id, "meta-reference::app_eval"
-        )
+        job_status = await eval_impl.job_status(task_id, response.job_id)
         assert job_status and job_status.value == "completed"
-        eval_response = await eval_impl.job_result(
-            response.job_id, "meta-reference::app_eval"
-        )
+        eval_response = await eval_impl.job_result(task_id, response.job_id)
         assert eval_response is not None
         assert len(eval_response.generations) == 5

From 7ca479f400aee2c8a9e04f94d8a126dbb19582f5 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 7 Nov 2024 16:22:33 -0800
Subject: [PATCH 2/5] fix optional

---
 llama_stack/apis/scoring/scoring.py                    | 4 ++--
 llama_stack/distribution/routers/routers.py            | 4 ++--
 .../providers/inline/meta_reference/scoring/scoring.py | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py
index a68582057..c2bfdcd23 100644
--- a/llama_stack/apis/scoring/scoring.py
+++ b/llama_stack/apis/scoring/scoring.py
@@ -48,7 +48,7 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...
 
@@ -56,5 +56,5 @@ class Scoring(Protocol):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse: ...

diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 4b28a20d7..7fc65800f 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -212,7 +212,7 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
@@ -235,7 +235,7 @@ class ScoringRouter(Scoring):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl

diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
index d6eb3ae96..c4add966d 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
@@ -96,7 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -120,7 +120,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions.keys():
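PATCH 2 moves `Optional` inside the dict: `scoring_functions` is now a required mapping whose values are optional per-function params (note the `= None` default is kept even though the annotation itself is no longer `Optional`). A hedged usage sketch against the new shape; `scoring_impl` and `rows` are assumed to come from the surrounding test setup:

```python
# Sketch of calling Scoring.score with Dict[str, Optional[ScoringFnParams]].
async def score_rows(scoring_impl, rows):
    response = await scoring_impl.score(
        input_rows=rows,
        scoring_functions={
            "meta-reference::equality": None,   # None = use the function's defaults
            "meta-reference::subset_of": None,
            # A params object could be supplied here for functions that take one;
            # the concrete params type is provider-specific and not shown in this patch.
        },
    )
    return response
```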
From 94a56cc3f358fdfd2beb21538eebd5c0c4741380 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 7 Nov 2024 16:41:23 -0800
Subject: [PATCH 3/5] register task required

---
 llama_stack/apis/eval/eval.py                 |  1 -
 llama_stack/distribution/routers/routers.py   |  2 --
 .../inline/meta_reference/eval/eval.py        | 16 ++++-----
 llama_stack/providers/tests/eval/test_eval.py | 33 +++++++++++++------
 4 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index 5ae779ca7..50fb922fe 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -70,7 +70,6 @@ class Eval(Protocol):
     async def run_eval(
         self,
         task_id: str,
-        task_def: EvalTaskDef,
         task_config: EvalTaskConfig,
     ) -> Job: ...
 
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 7fc65800f..8edf950b2 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -267,12 +267,10 @@ class EvalRouter(Eval):
     async def run_eval(
         self,
         task_id: str,
-        task_def: EvalTaskDef,
         task_config: AppEvalTaskConfig,
     ) -> Job:
         return await self.routing_table.get_provider_impl(task_id).run_eval(
             task_id=task_id,
-            task_def=task_def,
             task_config=task_config,
         )
 
diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py
index d20cf30d2..57d1d0124 100644
--- a/llama_stack/providers/inline/meta_reference/eval/eval.py
+++ b/llama_stack/providers/inline/meta_reference/eval/eval.py
@@ -51,22 +51,20 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         # TODO: assume sync job, will need jobs API for async scheduling
         self.jobs = {}
 
+        self.eval_tasks = {}
+
     async def initialize(self) -> None: ...
 
     async def shutdown(self) -> None: ...
 
+    async def register_eval_task(self, task_def: EvalTaskDef) -> None:
+        self.eval_tasks[task_def.identifier] = task_def
+
     async def list_eval_tasks(self) -> List[EvalTaskDef]:
-        # NOTE: In order to be routed to this provider, the eval task def must have
-        # a EvalTaskDef with identifier defined as DEFAULT_EVAL_TASK_IDENTIFIER
-        # for app eval where eval task benchmark_id is not pre-registered
-        eval_tasks = [
-            EvalTaskDef(
-                identifier=DEFAULT_EVAL_TASK_IDENTIFIER,
-                dataset_id="",
-                scoring_functions=[],
-            )
-        ]
-        return eval_tasks
+        return list(self.eval_tasks.values())
 
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -94,9 +92,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
     async def run_eval(
         self,
         task_id: str,
-        task_def: EvalTaskDef,
         task_config: EvalTaskConfig,
     ) -> Job:
+        task_def = self.eval_tasks[task_id]
         dataset_id = task_def.dataset_id
         candidate = task_config.eval_candidate
         scoring_functions = task_def.scoring_functions
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 794026e63..242be50ca 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -9,7 +9,11 @@
 import pytest
 
 from llama_models.llama3.api import SamplingParams
 
-from llama_stack.apis.eval.eval import AppEvalTaskConfig, EvalTaskDef, ModelCandidate
+from llama_stack.apis.eval.eval import (
+    AppEvalTaskConfig,
+    EvalTaskDefWithProvider,
+    ModelCandidate,
+)
 
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 
@@ -28,7 +32,7 @@ class Testeval:
         _, eval_tasks_impl, _, _, _, _ = eval_stack
         response = await eval_tasks_impl.list_eval_tasks()
         assert isinstance(response, list)
-        assert len(response) >= 1
+        assert len(response) == 0
 
     @pytest.mark.asyncio
     async def test_eval_evaluate_rows(self, eval_stack):
@@ -48,9 +52,17 @@ class Testeval:
             "meta-reference::llm_as_judge_8b_correctness",
             "meta-reference::equality",
         ]
+        task_id = "meta-reference::app_eval"
+        task_def = EvalTaskDefWithProvider(
+            identifier=task_id,
+            dataset_id="test_dataset_for_eval",
+            scoring_functions=scoring_functions,
+            provider_id="meta-reference",
+        )
+        await eval_tasks_impl.register_eval_task(task_def)
 
         response = await eval_impl.evaluate_rows(
-            task_id="meta-reference::app_eval",
+            task_id=task_id,
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
             task_config=AppEvalTaskConfig(
@@ -76,15 +88,16 @@ class Testeval:
             "meta-reference::subset_of",
         ]
 
-        task_id = "meta-reference::app_eval"
+        task_id = "meta-reference::app_eval-2"
+        task_def = EvalTaskDefWithProvider(
+            identifier=task_id,
+            dataset_id="test_dataset_for_eval",
+            scoring_functions=scoring_functions,
+            provider_id="meta-reference",
+        )
+        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.run_eval(
             task_id=task_id,
-            task_def=EvalTaskDef(
-                # NOTE: this is needed to make the router work for all app evals
-                identifier=task_id,
-                dataset_id="test_dataset_for_eval",
-                scoring_functions=scoring_functions,
-            ),
             task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
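With PATCH 3, `run_eval` no longer accepts a `task_def`; the task must be registered first and is looked up by `task_id` inside the provider. A sketch of the resulting register-then-run flow, mirroring the updated test; `eval_tasks_impl`, `eval_impl` and `candidate` are assumed to come from the distribution/test setup:

```python
from llama_stack.apis.eval.eval import AppEvalTaskConfig, EvalTaskDefWithProvider


async def register_and_run(eval_tasks_impl, eval_impl, candidate):
    task_id = "meta-reference::app_eval"
    await eval_tasks_impl.register_eval_task(
        EvalTaskDefWithProvider(
            identifier=task_id,
            dataset_id="test_dataset_for_eval",
            scoring_functions=["meta-reference::subset_of"],
            provider_id="meta-reference",
        )
    )
    # run_eval now resolves dataset_id and scoring_functions from the registered task.
    job = await eval_impl.run_eval(
        task_id=task_id,
        task_config=AppEvalTaskConfig(eval_candidate=candidate),
    )
    return await eval_impl.job_result(task_id, job.job_id)
```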
From 3c17853d793e25cd4279b8fc82c655fc25ef39e0 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 7 Nov 2024 16:42:44 -0800
Subject: [PATCH 4/5] register task required

---
 llama_stack/providers/inline/meta_reference/eval/eval.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py
index 57d1d0124..4a61c9d93 100644
--- a/llama_stack/providers/inline/meta_reference/eval/eval.py
+++ b/llama_stack/providers/inline/meta_reference/eval/eval.py
@@ -19,12 +19,6 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
 
 from .config import MetaReferenceEvalConfig
 
-# NOTE: this is the default eval task identifier for app eval
-# it is used to make the router work for all app evals
-# For app eval using other eval providers, the eval task identifier will be different
-DEFAULT_EVAL_TASK_IDENTIFIER = "meta-reference::app_eval"
-
-
 class ColumnName(Enum):
     input_query = "input_query"
     expected_answer = "expected_answer"

From 027ee2335cb60ba2c3f669b32775712078ee4074 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 7 Nov 2024 18:06:21 -0800
Subject: [PATCH 5/5] delete old tests

---
 .../tests/scoring/provider_config_example.yaml |  17 --
 .../tests/scoring/test_scoring_old.py          | 152 ------------------
 2 files changed, 169 deletions(-)
 delete mode 100644 llama_stack/providers/tests/scoring/provider_config_example.yaml
 delete mode 100644 llama_stack/providers/tests/scoring/test_scoring_old.py

diff --git a/llama_stack/providers/tests/scoring/provider_config_example.yaml b/llama_stack/providers/tests/scoring/provider_config_example.yaml
deleted file mode 100644
index 6a9c0d842..000000000
--- a/llama_stack/providers/tests/scoring/provider_config_example.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-providers:
-  datasetio:
-  - provider_id: test-meta
-    provider_type: meta-reference
-    config: {}
-  scoring:
-  - provider_id: test-meta
-    provider_type: meta-reference
-    config: {}
-  - provider_id: test-braintrust
-    provider_type: braintrust
-    config: {}
-  inference:
-  - provider_id: tgi0
-    provider_type: remote::tgi
-    config:
-      url: http://127.0.0.1:5009
diff --git a/llama_stack/providers/tests/scoring/test_scoring_old.py b/llama_stack/providers/tests/scoring/test_scoring_old.py
deleted file mode 100644
index b9b920739..000000000
--- a/llama_stack/providers/tests/scoring/test_scoring_old.py
+++ /dev/null
@@ -1,152 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-import pytest
-import pytest_asyncio
-
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.datasetio import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
-
-from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
-from llama_stack.providers.tests.resolver import resolve_impls_for_test
-
-# How to run this test:
-#
-# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
-#    since it depends on the provider you are testing. On top of that you need
-#    `pytest` and `pytest-asyncio` installed.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID= \
-#   PROVIDER_CONFIG=provider_config.yaml \
-#   pytest -s llama_stack/providers/tests/scoring/test_scoring.py \
-#   --tb=short --disable-warnings
-# ```
-
-
-@pytest_asyncio.fixture(scope="session")
-async def scoring_settings():
-    impls = await resolve_impls_for_test(
-        Api.scoring, deps=[Api.datasetio, Api.inference]
-    )
-    return {
-        "scoring_impl": impls[Api.scoring],
-        "scoring_functions_impl": impls[Api.scoring_functions],
-        "datasets_impl": impls[Api.datasets],
-    }
-
-
-@pytest_asyncio.fixture(scope="session")
-async def provider_scoring_functions():
-    return {
-        "meta-reference": {
-            "meta-reference::equality",
-            "meta-reference::subset_of",
-            "meta-reference::llm_as_judge_8b_correctness",
-        },
-        "braintrust": {
-            "braintrust::factuality",
-            "braintrust::answer-correctness",
-        },
-    }
-
-
-@pytest.mark.asyncio
-async def test_scoring_functions_list(scoring_settings, provider_scoring_functions):
-    scoring_impl = scoring_settings["scoring_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(scoring_functions, list)
-    assert len(scoring_functions) > 0
-    function_ids = [f.identifier for f in scoring_functions]
-    # get current provider_type we're testing
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-
-    for x in provider_scoring_functions[provider_type]:
-        assert x in function_ids
-
-
-@pytest.mark.asyncio
-async def test_scoring_functions_register(scoring_settings):
-    scoring_impl = scoring_settings["scoring_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    datasets_impl = scoring_settings["datasets_impl"]
-
-    # get current provider_type we're testing
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    function_ids = [f.identifier for f in scoring_functions]
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-    if provider_type not in ("meta-reference"):
-        pytest.skip(
-            "Other scoring providers don't support registering scoring functions."
-        )
-
-    test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: """
-    # register the scoring function
-    await scoring_functions_impl.register_scoring_function(
-        ScoringFnDefWithProvider(
-            identifier="meta-reference::llm_as_judge_8b_random",
-            description="Llm As Judge Scoring Function",
-            parameters=[],
-            return_type=NumberType(),
-            context=LLMAsJudgeContext(
-                prompt_template=test_prompt,
-                judge_model="Llama3.1-8B-Instruct",
-                judge_score_regex=[r"Number: (\d+)"],
-            ),
-            provider_id="test-meta",
-        )
-    )
-
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(scoring_functions, list)
-    assert len(scoring_functions) > 0
-    function_ids = [f.identifier for f in scoring_functions]
-    assert "meta-reference::llm_as_judge_8b_random" in function_ids
-
-    # test score using newly registered scoring function
-    await register_dataset(datasets_impl)
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=[
-            "meta-reference::llm_as_judge_8b_random",
-        ],
-    )
-    assert "meta-reference::llm_as_judge_8b_random" in response.results
-
-
-@pytest.mark.asyncio
-async def test_scoring_score(scoring_settings, provider_scoring_functions):
-    scoring_impl = scoring_settings["scoring_impl"]
-    datasets_impl = scoring_settings["datasets_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    await register_dataset(datasets_impl)
-
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-
-    # get current provider_type we're testing
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    function_ids = [f.identifier for f in scoring_functions]
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=list(provider_scoring_functions[provider_type]),
-    )
-
-    assert len(response.results) == len(provider_scoring_functions[provider_type])
-    for x in provider_scoring_functions[provider_type]:
-        assert x in response.results