Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

Commit 33b6d9b7b7 (merge)
8 changed files with 67 additions and 304 deletions
@@ -74,36 +74,27 @@ class EvaluateResponse(BaseModel):
 
 
 
 class Eval(Protocol):
-    @webmethod(route="/eval/run_benchmark", method="POST")
-    async def run_benchmark(
-        self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkEvalTaskConfig,
-    ) -> Job: ...
-
     @webmethod(route="/eval/run_eval", method="POST")
     async def run_eval(
         self,
-        task: EvalTaskDef,
-        task_config: AppEvalTaskConfig,
+        task_id: str,
+        task_config: EvalTaskConfig,
     ) -> Job: ...
 
     @webmethod(route="/eval/evaluate_rows", method="POST")
     async def evaluate_rows(
         self,
+        task_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
         task_config: EvalTaskConfig,
-        eval_task_id: Optional[str] = None,
     ) -> EvaluateResponse: ...
 
     @webmethod(route="/eval/job/status", method="GET")
-    async def job_status(
-        self, job_id: str, eval_task_id: str
-    ) -> Optional[JobStatus]: ...
+    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
 
     @webmethod(route="/eval/job/cancel", method="POST")
-    async def job_cancel(self, job_id: str, eval_task_id: str) -> None: ...
+    async def job_cancel(self, task_id: str, job_id: str) -> None: ...
 
     @webmethod(route="/eval/job/result", method="GET")
-    async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse: ...
+    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
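Reviewer note: the hunk above reworks the Eval protocol so that tasks are registered once and every call is keyed by task_id, instead of passing an inline EvalTaskDef to run_eval or a trailing eval_task_id to the job methods. A minimal caller-side sketch follows, assuming eval_impl / eval_tasks_impl handles obtained from the stack (as in the test fixtures); the task identifier and dataset name are illustrative and not part of this diff.

from llama_models.llama3.api import SamplingParams

# Illustrative only: register the task once, then address it by task_id.
task_id = "example::app_eval"  # hypothetical identifier
await eval_tasks_impl.register_eval_task(
    EvalTaskDefWithProvider(
        identifier=task_id,
        dataset_id="example_eval_dataset",  # hypothetical dataset
        scoring_functions=["meta-reference::equality"],
        provider_id="meta-reference",
    )
)

task_config = AppEvalTaskConfig(
    eval_candidate=ModelCandidate(
        model="Llama3.2-3B-Instruct",
        sampling_params=SamplingParams(),  # assumed candidate fields, as in the tests
    )
)
job = await eval_impl.run_eval(task_id=task_id, task_config=task_config)
if await eval_impl.job_status(task_id, job.job_id) == JobStatus.completed:
    result = await eval_impl.job_result(task_id, job.job_id)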
@@ -48,7 +48,7 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...
 
@@ -56,5 +56,5 @@ class Scoring(Protocol):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse: ...
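A short illustration of the scoring signature change above: the Optional moves from the whole mapping to the per-function parameters, so callers pass a dict whose values are either a ScoringFnParams object or None (meaning "use the function's registered defaults"). The scoring_impl handle, row contents, and function ids below are placeholders for the example, not part of the diff.

# Hypothetical call against the new signature.
rows = [{"input_query": "2+2?", "generated_answer": "4", "expected_answer": "4"}]
response = await scoring_impl.score(
    input_rows=rows,
    scoring_functions={
        "meta-reference::equality": None,  # None -> use registered defaults
        "meta-reference::subset_of": None,
    },
)
for fn_id, result in response.results.items():
    print(fn_id, result)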
@@ -16,10 +16,6 @@ from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.scoring import *  # noqa: F403
 from llama_stack.apis.eval import *  # noqa: F403
-
-from llama_stack.providers.inline.meta_reference.eval.eval import (
-    DEFAULT_EVAL_TASK_IDENTIFIER,
-)
 
 
 class MemoryRouter(Memory):
     """Routes to an provider based on the memory bank identifier"""
@@ -216,7 +212,7 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
@@ -239,7 +235,7 @@ class ScoringRouter(Scoring):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
@@ -268,39 +264,26 @@ class EvalRouter(Eval):
     async def shutdown(self) -> None:
         pass
 
-    async def run_benchmark(
-        self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkEvalTaskConfig,
-    ) -> Job:
-        return await self.routing_table.get_provider_impl(benchmark_id).run_benchmark(
-            benchmark_id=benchmark_id,
-            benchmark_config=benchmark_config,
-        )
-
     async def run_eval(
         self,
-        task: EvalTaskDef,
+        task_id: str,
         task_config: AppEvalTaskConfig,
     ) -> Job:
-        return await self.routing_table.get_provider_impl(task.identifier).run_eval(
-            task=task,
+        return await self.routing_table.get_provider_impl(task_id).run_eval(
+            task_id=task_id,
             task_config=task_config,
         )
 
     @webmethod(route="/eval/evaluate_rows", method="POST")
     async def evaluate_rows(
         self,
+        task_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
         task_config: EvalTaskConfig,
-        eval_task_id: Optional[str] = None,
     ) -> EvaluateResponse:
-        # NOTE: This is to deal with the case where we do not pre-register an eval benchmark_task
-        # We use default DEFAULT_EVAL_TASK_IDENTIFIER as identifier
-        if eval_task_id is None:
-            eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER
-        return await self.routing_table.get_provider_impl(eval_task_id).evaluate_rows(
+        return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
+            task_id=task_id,
             input_rows=input_rows,
             scoring_functions=scoring_functions,
             task_config=task_config,
@@ -308,27 +291,29 @@ class EvalRouter(Eval):
 
     async def job_status(
         self,
+        task_id: str,
         job_id: str,
-        eval_task_id: str,
     ) -> Optional[JobStatus]:
-        return await self.routing_table.get_provider_impl(eval_task_id).job_status(
-            job_id, eval_task_id
+        return await self.routing_table.get_provider_impl(task_id).job_status(
+            task_id, job_id
         )
 
     async def job_cancel(
         self,
+        task_id: str,
         job_id: str,
-        eval_task_id: str,
     ) -> None:
-        await self.routing_table.get_provider_impl(eval_task_id).job_cancel(
-            job_id, eval_task_id
+        await self.routing_table.get_provider_impl(task_id).job_cancel(
+            task_id,
+            job_id,
         )
 
     async def job_result(
         self,
+        task_id: str,
         job_id: str,
-        eval_task_id: str,
     ) -> EvaluateResponse:
-        return await self.routing_table.get_provider_impl(eval_task_id).job_result(
-            job_id, eval_task_id
+        return await self.routing_table.get_provider_impl(task_id).job_result(
+            task_id,
+            job_id,
         )
@@ -7,14 +7,7 @@ from enum import Enum
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 
 from .....apis.common.job_types import Job
-from .....apis.eval.eval import (
-    AppEvalTaskConfig,
-    BenchmarkEvalTaskConfig,
-    Eval,
-    EvalTaskConfig,
-    EvaluateResponse,
-    JobStatus,
-)
+from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from tqdm import tqdm
 
@@ -28,12 +21,6 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
 from .config import MetaReferenceEvalConfig
 
 
-# NOTE: this is the default eval task identifier for app eval
-# it is used to make the router work for all app evals
-# For app eval using other eval providers, the eval task identifier will be different
-DEFAULT_EVAL_TASK_IDENTIFIER = "meta-reference::app_eval"
-
-
 class ColumnName(Enum):
     input_query = "input_query"
     expected_answer = "expected_answer"
@@ -60,30 +47,15 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         # TODO: assume sync job, will need jobs API for async scheduling
         self.jobs = {}
 
-        # Keep track of benchmark eval tasks that are supported by this provider
         self.eval_tasks = {}
 
-    async def initialize(self) -> None:
-        self.eval_tasks = {
-            # NOTE: In order to be routed to this provider, the eval task def must have
-            # a EvalTaskDef with identifier defined as DEFAULT_EVAL_TASK_IDENTIFIER
-            # for app eval where eval task benchmark_id is not pre-registered
-            DEFAULT_EVAL_TASK_IDENTIFIER: EvalTaskDef(
-                identifier=DEFAULT_EVAL_TASK_IDENTIFIER,
-                dataset_id="",
-                scoring_functions=[],
-            ),
-            "meta-reference-mmlu": EvalTaskDef(
-                identifier="meta-reference-mmlu",
-                dataset_id="llamastack_mmlu",
-                scoring_functions=[
-                    "meta-reference::regex_parser_multiple_choice_answer"
-                ],
-            ),
-        }
+    async def initialize(self) -> None: ...
 
     async def shutdown(self) -> None: ...
 
+    async def register_eval_task(self, task_def: EvalTaskDef) -> None:
+        self.eval_tasks[task_def.identifier] = task_def
+
     async def list_eval_tasks(self) -> List[EvalTaskDef]:
         return list(self.eval_tasks.values())
 
@@ -110,39 +82,15 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )
 
-    async def run_benchmark(
-        self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkEvalTaskConfig,
-    ) -> Job:
-        eval_task_def = self.eval_tasks[benchmark_id]
-        all_rows = await self.datasetio_api.get_rows_paginated(
-            dataset_id=eval_task_def.dataset_id,
-            rows_in_page=(
-                -1
-                if benchmark_config.num_examples is None
-                else benchmark_config.num_examples
-            ),
-        )
-        res = await self.evaluate_rows(
-            input_rows=all_rows.rows,
-            scoring_functions=eval_task_def.scoring_functions,
-            task_config=benchmark_config,
-        )
-        # TODO: currently needs to wait for generation before returning
-        # need job scheduler queue (celery) w/ jobs api
-        job_id = str(len(self.jobs))
-        self.jobs[job_id] = res
-        return Job(job_id=job_id)
-
     async def run_eval(
         self,
-        task: EvalTaskDef,
-        task_config: AppEvalTaskConfig,
+        task_id: str,
+        task_config: EvalTaskConfig,
     ) -> Job:
-        dataset_id = task.dataset_id
+        task_def = self.eval_tasks[task_id]
+        dataset_id = task_def.dataset_id
         candidate = task_config.eval_candidate
-        scoring_functions = task.scoring_functions
+        scoring_functions = task_def.scoring_functions
 
         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
@@ -152,6 +100,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             ),
         )
         res = await self.evaluate_rows(
+            task_id=task_id,
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
             task_config=task_config,
@@ -165,10 +114,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
 
     async def evaluate_rows(
         self,
+        task_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
         task_config: EvalTaskConfig,
-        eval_task_id: Optional[str] = None,
    ) -> EvaluateResponse:
         candidate = task_config.eval_candidate
         if candidate.type == "agent":
@@ -238,17 +187,17 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
 
         return EvaluateResponse(generations=generations, scores=score_response.results)
 
-    async def job_status(self, job_id: str, eval_task_id: str) -> Optional[JobStatus]:
+    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]:
         if job_id in self.jobs:
             return JobStatus.completed
 
         return None
 
-    async def job_cancel(self, job_id: str, eval_task_id: str) -> None:
+    async def job_cancel(self, task_id: str, job_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
-    async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse:
-        status = await self.job_status(job_id, eval_task_id)
+    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse:
+        status = await self.job_status(task_id, job_id)
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")
 
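One behavioural consequence of the provider hunks above is worth flagging: with DEFAULT_EVAL_TASK_IDENTIFIER and the seeded initialize() gone, the `self.eval_tasks[task_id]` lookup in run_eval only succeeds for tasks that were explicitly registered, and an unknown task_id now raises a KeyError instead of falling back to the app-eval default. Anyone relying on the previously seeded MMLU entry would need to register it themselves; a hedged sketch using the values from the removed seed (the provider_id is an assumption, not from the diff):

# Illustrative re-registration of the benchmark entry initialize() used to seed.
await eval_tasks_impl.register_eval_task(
    EvalTaskDefWithProvider(
        identifier="meta-reference-mmlu",
        dataset_id="llamastack_mmlu",
        scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
        provider_id="meta-reference",  # assumption for the example
    )
)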
@@ -89,7 +89,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -113,7 +113,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions.keys():
@@ -11,8 +11,7 @@ from llama_models.llama3.api import SamplingParams
 
 from llama_stack.apis.eval.eval import (
     AppEvalTaskConfig,
-    BenchmarkEvalTaskConfig,
-    EvalTaskDef,
+    EvalTaskDefWithProvider,
     ModelCandidate,
 )
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
@@ -33,7 +32,7 @@ class Testeval:
         _, eval_tasks_impl, _, _, _, _ = eval_stack
         response = await eval_tasks_impl.list_eval_tasks()
         assert isinstance(response, list)
-        assert len(response) >= 1
+        assert len(response) == 0
 
     @pytest.mark.asyncio
     async def test_eval_evaluate_rows(self, eval_stack):
@@ -59,8 +58,17 @@ class Testeval:
             "meta-reference::llm_as_judge_8b_correctness",
             "meta-reference::equality",
         ]
+        task_id = "meta-reference::app_eval"
+        task_def = EvalTaskDefWithProvider(
+            identifier=task_id,
+            dataset_id="test_dataset_for_eval",
+            scoring_functions=scoring_functions,
+            provider_id="meta-reference",
+        )
+        await eval_tasks_impl.register_eval_task(task_def)
 
         response = await eval_impl.evaluate_rows(
+            task_id=task_id,
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
             task_config=AppEvalTaskConfig(
@@ -91,13 +99,16 @@ class Testeval:
             "meta-reference::subset_of",
         ]
 
+        task_id = "meta-reference::app_eval-2"
+        task_def = EvalTaskDefWithProvider(
+            identifier=task_id,
+            dataset_id="test_dataset_for_eval",
+            scoring_functions=scoring_functions,
+            provider_id="meta-reference",
+        )
+        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.run_eval(
-            task=EvalTaskDef(
-                # NOTE: this is needed to make the router work for all app evals
-                identifier="meta-reference::app_eval",
-                dataset_id="test_dataset_for_eval",
-                scoring_functions=scoring_functions,
-            ),
+            task_id=task_id,
             task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
@@ -106,13 +117,9 @@ class Testeval:
             ),
         )
         assert response.job_id == "0"
-        job_status = await eval_impl.job_status(
-            response.job_id, "meta-reference::app_eval"
-        )
+        job_status = await eval_impl.job_status(task_id, response.job_id)
         assert job_status and job_status.value == "completed"
-        eval_response = await eval_impl.job_result(
-            response.job_id, "meta-reference::app_eval"
-        )
+        eval_response = await eval_impl.job_result(task_id, response.job_id)
 
         assert eval_response is not None
         assert len(eval_response.generations) == 5
@@ -1,17 +0,0 @@
-providers:
-  datasetio:
-    - provider_id: test-meta
-      provider_type: meta-reference
-      config: {}
-  scoring:
-    - provider_id: test-meta
-      provider_type: meta-reference
-      config: {}
-    - provider_id: test-braintrust
-      provider_type: braintrust
-      config: {}
-  inference:
-    - provider_id: tgi0
-      provider_type: remote::tgi
-      config:
-        url: http://127.0.0.1:5009
@@ -1,152 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-import pytest
-import pytest_asyncio
-
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.datasetio import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
-
-from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
-from llama_stack.providers.tests.resolver import resolve_impls_for_test
-
-# How to run this test:
-#
-# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
-#    since it depends on the provider you are testing. On top of that you need
-#    `pytest` and `pytest-asyncio` installed.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-#   PROVIDER_CONFIG=provider_config.yaml \
-#   pytest -s llama_stack/providers/tests/scoring/test_scoring.py \
-#   --tb=short --disable-warnings
-# ```
-
-
-@pytest_asyncio.fixture(scope="session")
-async def scoring_settings():
-    impls = await resolve_impls_for_test(
-        Api.scoring, deps=[Api.datasetio, Api.inference]
-    )
-    return {
-        "scoring_impl": impls[Api.scoring],
-        "scoring_functions_impl": impls[Api.scoring_functions],
-        "datasets_impl": impls[Api.datasets],
-    }
-
-
-@pytest_asyncio.fixture(scope="session")
-async def provider_scoring_functions():
-    return {
-        "meta-reference": {
-            "meta-reference::equality",
-            "meta-reference::subset_of",
-            "meta-reference::llm_as_judge_8b_correctness",
-        },
-        "braintrust": {
-            "braintrust::factuality",
-            "braintrust::answer-correctness",
-        },
-    }
-
-
-@pytest.mark.asyncio
-async def test_scoring_functions_list(scoring_settings, provider_scoring_functions):
-    scoring_impl = scoring_settings["scoring_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(scoring_functions, list)
-    assert len(scoring_functions) > 0
-    function_ids = [f.identifier for f in scoring_functions]
-    # get current provider_type we're testing
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-
-    for x in provider_scoring_functions[provider_type]:
-        assert x in function_ids
-
-
-@pytest.mark.asyncio
-async def test_scoring_functions_register(scoring_settings):
-    scoring_impl = scoring_settings["scoring_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    datasets_impl = scoring_settings["datasets_impl"]
-
-    # get current provider_type we're testing
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    function_ids = [f.identifier for f in scoring_functions]
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-    if provider_type not in ("meta-reference"):
-        pytest.skip(
-            "Other scoring providers don't support registering scoring functions."
-        )
-
-    test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: <answer>"""
-    # register the scoring function
-    await scoring_functions_impl.register_scoring_function(
-        ScoringFnDefWithProvider(
-            identifier="meta-reference::llm_as_judge_8b_random",
-            description="Llm As Judge Scoring Function",
-            parameters=[],
-            return_type=NumberType(),
-            context=LLMAsJudgeContext(
-                prompt_template=test_prompt,
-                judge_model="Llama3.1-8B-Instruct",
-                judge_score_regex=[r"Number: (\d+)"],
-            ),
-            provider_id="test-meta",
-        )
-    )
-
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(scoring_functions, list)
-    assert len(scoring_functions) > 0
-    function_ids = [f.identifier for f in scoring_functions]
-    assert "meta-reference::llm_as_judge_8b_random" in function_ids
-
-    # test score using newly registered scoring function
-    await register_dataset(datasets_impl)
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=[
-            "meta-reference::llm_as_judge_8b_random",
-        ],
-    )
-    assert "meta-reference::llm_as_judge_8b_random" in response.results
-
-
-@pytest.mark.asyncio
-async def test_scoring_score(scoring_settings, provider_scoring_functions):
-    scoring_impl = scoring_settings["scoring_impl"]
-    datasets_impl = scoring_settings["datasets_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    await register_dataset(datasets_impl)
-
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-
-    # get current provider_type we're testing
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    function_ids = [f.identifier for f in scoring_functions]
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=list(provider_scoring_functions[provider_type]),
-    )
-
-    assert len(response.results) == len(provider_scoring_functions[provider_type])
-    for x in provider_scoring_functions[provider_type]:
-        assert x in response.results