diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py
index ab203ebb8..3161e3e87 100644
--- a/llama_stack/apis/common/job_types.py
+++ b/llama_stack/apis/common/job_types.py
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from enum import Enum
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
@@ -10,3 +12,10 @@ from pydantic import BaseModel
 @json_schema_type
 class Job(BaseModel):
     job_id: str
+
+
+@json_schema_type
+class JobStatus(Enum):
+    completed = "completed"
+    in_progress = "in_progress"
+    not_found = "not_found"
diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index a97af1fc0..cfcada766 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -12,7 +12,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.schema_utils import json_schema_type, webmethod
 from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job
+from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.scoring import *  # noqa: F403
 
 
@@ -40,7 +40,7 @@ class EvaluateResponse(BaseModel):
     generations: List[Dict[str, Any]]
 
     # each key in the dict is a scoring function name
-    scores: List[Dict[str, ScoringResult]]
+    scores: Dict[str, ScoringResult]
 
 
 class Eval(Protocol):
@@ -61,10 +61,10 @@ class Eval(Protocol):
     ) -> EvaluateResponse: ...
 
     @webmethod(route="/eval/job/status", method="GET")
-    async def job_status(self, job_id: str) -> None: ...
+    async def job_status(self, job_id: str) -> JobStatus: ...
 
     @webmethod(route="/eval/job/cancel", method="POST")
     async def job_cancel(self, job_id: str) -> None: ...
 
     @webmethod(route="/eval/job/result", method="GET")
-    async def job_result(self, job_id: str) -> None: ...
+    async def job_result(self, job_id: str) -> EvaluateResponse: ...
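With these protocol changes, `job_status` and `job_result` return concrete types (`JobStatus` and `EvaluateResponse`), so a caller can submit a batch evaluation, check the job, and read its scores. A minimal sketch of that flow, assuming a hypothetical `eval_impl` object implementing the `Eval` protocol, a candidate, and an already-registered dataset id (none of these names are part of this diff):

    # Sketch only: `eval_impl`, the dataset id, and `candidate` are placeholders.
    from llama_stack.apis.common.job_types import JobStatus

    async def run_eval(eval_impl, candidate):
        job = await eval_impl.evaluate_batch(
            dataset_id="my_eval_dataset",        # hypothetical dataset id
            candidate=candidate,
            scoring_functions=["inclusion"],     # scorer registered later in this diff
        )
        status = await eval_impl.job_status(job.job_id)
        if status == JobStatus.completed:
            result = await eval_impl.job_result(job.job_id)  # EvaluateResponse
            return result.scores                 # Dict[str, ScoringResult]
        return None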
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index b9b9fb229..cfe31a21d 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -14,6 +14,7 @@ from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.eval import Eval
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.memory import Memory
@@ -46,6 +47,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.datasetio: DatasetIO,
         Api.scoring_functions: ScoringFunctions,
         Api.scoring: Scoring,
+        Api.eval: Eval,
     }
 
 
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 903ff5438..8d476a509 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -25,6 +25,7 @@ class Api(Enum):
     memory = "memory"
     datasetio = "datasetio"
     scoring = "scoring"
+    eval = "eval"
     telemetry = "telemetry"
 
diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py
index 57ce8e10f..686c100f9 100644
--- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py
+++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py
@@ -143,11 +143,12 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         else:
             next_page_token = int(page_token)
 
-        if rows_in_page == -1:
-            rows = dataset_info.dataset_impl[next_page_token:]
-
         start = next_page_token
-        end = min(start + rows_in_page, len(dataset_info.dataset_impl))
+        if rows_in_page == -1:
+            end = len(dataset_info.dataset_impl)
+        else:
+            end = min(start + rows_in_page, len(dataset_info.dataset_impl))
+
         rows = dataset_info.dataset_impl[start:end]
 
         return PaginatedRowsResult(
diff --git a/llama_stack/providers/impls/meta_reference/eval/__init__.py b/llama_stack/providers/impls/meta_reference/eval/__init__.py
new file mode 100644
index 000000000..fb285c668
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/eval/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Dict
+
+from llama_stack.distribution.datatypes import Api, ProviderSpec
+
+from .config import MetaReferenceEvalConfig
+
+
+async def get_provider_impl(
+    config: MetaReferenceEvalConfig,
+    deps: Dict[Api, ProviderSpec],
+):
+    from .eval import MetaReferenceEvalImpl
+
+    impl = MetaReferenceEvalImpl(
+        config,
+        deps[Api.datasetio],
+        deps[Api.datasets],
+        deps[Api.scoring],
+        deps[Api.inference],
+    )
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/impls/meta_reference/eval/config.py b/llama_stack/providers/impls/meta_reference/eval/config.py
new file mode 100644
index 000000000..1892da2a2
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/eval/config.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.eval import *  # noqa: F401, F403
+
+
+class MetaReferenceEvalConfig(BaseModel): ...
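The datasetio change above reworks pagination so the page is always the slice [start, end), with `rows_in_page == -1` meaning "return the rest of the dataset" (which `evaluate_batch` relies on further below). A standalone illustration of that window logic, with made-up data:

    # Made-up data; mirrors the start/end window computed in get_rows_paginated above.
    def page_window(rows, next_page_token, rows_in_page):
        start = next_page_token
        if rows_in_page == -1:
            end = len(rows)
        else:
            end = min(start + rows_in_page, len(rows))
        return rows[start:end]

    data = list(range(10))
    assert page_window(data, 3, -1) == [3, 4, 5, 6, 7, 8, 9]  # -1 returns the remainder
    assert page_window(data, 8, 5) == [8, 9]                  # window clamped at the dataset end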
diff --git a/llama_stack/providers/impls/meta_reference/eval/eval.py b/llama_stack/providers/impls/meta_reference/eval/eval.py
new file mode 100644
index 000000000..abf187bab
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/eval/eval.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+
+from llama_stack.apis.eval import *  # noqa: F403
+from llama_stack.apis.common.job_types import Job
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.scoring import Scoring
+
+from .config import MetaReferenceEvalConfig
+
+
+class MetaReferenceEvalImpl(Eval):
+    def __init__(
+        self,
+        config: MetaReferenceEvalConfig,
+        datasetio_api: DatasetIO,
+        datasets_api: Datasets,
+        scoring_api: Scoring,
+        inference_api: Inference,
+    ) -> None:
+        self.config = config
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api
+        self.scoring_api = scoring_api
+        self.inference_api = inference_api
+
+        # TODO: assume sync job, will need jobs API for async scheduling
+        self.jobs = {}
+
+    async def initialize(self) -> None: ...
+
+    async def shutdown(self) -> None: ...
+
+    async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
+        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
+        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+            raise ValueError(
+                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
+            )
+
+        # TODO: we will require user defined message types for ToolResponseMessage or include message.context
+        # for now uses basic schema where messages={type: "user", content: "input_query"}
+        for required_column in ["expected_answer", "input_query"]:
+            if required_column not in dataset_def.dataset_schema:
+                raise ValueError(
+                    f"Dataset {dataset_id} does not have a '{required_column}' column."
+                )
+            if dataset_def.dataset_schema[required_column].type != "string":
+                raise ValueError(
+                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
+                )
+
+    async def evaluate_batch(
+        self,
+        dataset_id: str,
+        candidate: EvalCandidate,
+        scoring_functions: List[str],
+    ) -> Job:
+        await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
+        all_rows = await self.datasetio_api.get_rows_paginated(
+            dataset_id=dataset_id,
+            rows_in_page=-1,
+        )
+        res = await self.evaluate(
+            input_rows=all_rows.rows,
+            candidate=candidate,
+            scoring_functions=scoring_functions,
+        )
+
+        job_id = str(len(self.jobs))
+        self.jobs[job_id] = res
+        return Job(job_id=job_id)
+
+    async def evaluate(
+        self,
+        input_rows: List[Dict[str, Any]],
+        candidate: EvalCandidate,
+        scoring_functions: List[str],
+    ) -> EvaluateResponse:
+        if candidate.type == "agent":
+            raise NotImplementedError(
+                "Evaluation with generation has not been implemented for agents"
+            )
+        generations = []
+        for x in input_rows:
+            input_query = x["input_query"]
+            messages = []
+            if candidate.system_message:
+                messages.append(candidate.system_message)
+            messages.append(
+                UserMessage(content=input_query),
+            )
+            response = await self.inference_api.chat_completion(
+                model=candidate.model,
+                messages=messages,
+            )
+            generations.append(
+                {"generated_answer": response.completion_message.content}
+            )
+
+        # scoring with generated_answer
+        score_input_rows = [
+            input_r | generated_r
+            for input_r, generated_r in zip(input_rows, generations)
+        ]
+
+        score_response = await self.scoring_api.score(
+            input_rows=score_input_rows, scoring_functions=scoring_functions
+        )
+
+        return EvaluateResponse(generations=generations, scores=score_response.results)
+
+    async def job_status(self, job_id: str) -> JobStatus:
+        if job_id in self.jobs:
+            return JobStatus.completed
+        else:
+            return JobStatus.not_found
+
+    async def job_cancel(self, job_id: str) -> None:
+        raise NotImplementedError("Job cancel is not implemented yet")
+
+    async def job_result(self, job_id: str) -> EvaluateResponse:
+        status = await self.job_status(job_id)
+        if status != JobStatus.completed:
+            raise ValueError(f"Job is not completed, Status: {status.value}")
+
+        return self.jobs[job_id]
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py
new file mode 100644
index 000000000..506bc60a7
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
+    BaseScorer,
+)
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+
+
+class InclusionScorer(BaseScorer):
+    """
+    A scorer that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
+    """
+
+    scoring_function_def = DeterministicFunctionDef(
+        identifier="inclusion",
+        description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
+        parameters=[],
+        return_type=NumberType(),
+    )
+
+    def score_row(self, input_row: Dict[str, Any]) -> ScoringResult:
+        assert "expected_answer" in input_row, "Expected answer not found in input row."
+        assert (
+            "generated_answer" in input_row
+        ), "Generated answer not found in input row."
+
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        score = 1.0 if expected_answer in generated_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    def aggregate(self, scoring_results: List[ScoringResult]) -> ScoringResult:
+        assert len(scoring_results) > 0, "Empty scoring results provided."
+        num_correct = sum(result["score"] for result in scoring_results)
+        avg_score = num_correct / len(scoring_results)
+
+        return {
+            "accuracy": avg_score,
+            "num_correct": num_correct,
+            "num_total": len(scoring_results),
+        }
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
index 73f9fcc5a..554d637b8 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
@@ -16,11 +16,15 @@ from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
     EqualityScorer,
 )
+from llama_stack.providers.impls.meta_reference.scoring.scorer.inclusion_scorer import (
+    InclusionScorer,
+)
 
 from .config import MetaReferenceScoringConfig
 
 SUPPORTED_SCORERS = [
     EqualityScorer,
+    InclusionScorer,
 ]
 
 SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py
new file mode 100644
index 000000000..fc7c923d9
--- /dev/null
+++ b/llama_stack/providers/registry/eval.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.eval,
+            provider_type="meta-reference",
+            pip_packages=[],
+            module="llama_stack.providers.impls.meta_reference.eval",
+            config_class="llama_stack.providers.impls.meta_reference.eval.MetaReferenceEvalConfig",
+            api_dependencies=[
+                Api.datasetio,
+                Api.datasets,
+                Api.scoring,
+                Api.inference,
+            ],
+        ),
+    ]
diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py
index 9a351ba30..755ed9735 100644
--- a/llama_stack/providers/tests/datasetio/test_datasetio.py
+++ b/llama_stack/providers/tests/datasetio/test_datasetio.py
@@ -61,20 +61,26 @@ def data_url_from_file(file_path: str) -> str:
     return data_url
 
 
-async def register_dataset(datasets_impl: Datasets):
+async def register_dataset(
+    datasets_impl: Datasets, include_generated_answer=True, dataset_id="test_dataset"
+):
     test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
     test_url = data_url_from_file(str(test_file))
+
+    dataset_schema = {
+        "expected_answer": StringType(),
+        "input_query": StringType(),
+    }
+    if include_generated_answer:
+        dataset_schema["generated_answer"] = StringType()
+
     dataset = DatasetDefWithProvider(
-        identifier="test_dataset",
+        identifier=dataset_id,
         provider_id=os.environ["PROVIDER_ID"],
         url=URL(
             uri=test_url,
         ),
-        dataset_schema={
-            "generated_answer": StringType(),
-            "expected_answer": StringType(),
-            "input_query": StringType(),
-        },
+        dataset_schema=dataset_schema,
     )
     await datasets_impl.register_dataset(dataset)
 
diff --git a/llama_stack/providers/tests/eval/__init__.py b/llama_stack/providers/tests/eval/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/tests/eval/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/tests/eval/provider_config_example.yaml b/llama_stack/providers/tests/eval/provider_config_example.yaml
new file mode 100644
index 000000000..1576d2ef0
--- /dev/null
+++ b/llama_stack/providers/tests/eval/provider_config_example.yaml
@@ -0,0 +1,18 @@
+providers:
+  datasetio:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
+  scoring:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
+  eval:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
+  inference:
+    - provider_id: test-tgi
+      provider_type: remote::tgi
+      config:
+        url: http://127.0.0.1:5009
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
new file mode 100644
index 000000000..099153f03
--- /dev/null
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.apis.eval.eval import ModelCandidate
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+from llama_models.llama3.api import SamplingParams
+
+from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
+from llama_stack.providers.tests.resolver import resolve_impls_for_test
+
+# How to run this test:
+#
+# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
+#    since it depends on the provider you are testing. On top of that you need
+#    `pytest` and `pytest-asyncio` installed.
+#
+# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
+#
+# 3. Run:
+#
+# ```bash
+# PROVIDER_ID= \
+#   PROVIDER_CONFIG=provider_config.yaml \
+#   pytest -s llama_stack/providers/tests/eval/test_eval.py \
+#   --tb=short --disable-warnings
+# ```
+
+
+@pytest_asyncio.fixture(scope="session")
+async def eval_settings():
+    impls = await resolve_impls_for_test(
+        Api.eval, deps=[Api.datasetio, Api.scoring, Api.inference]
+    )
+    return {
+        "eval_impl": impls[Api.eval],
+        "scoring_impl": impls[Api.scoring],
+        "datasets_impl": impls[Api.datasets],
+    }
+
+
+@pytest.mark.asyncio
+async def test_eval(eval_settings):
+    datasets_impl = eval_settings["datasets_impl"]
+    await register_dataset(
+        datasets_impl,
+        include_generated_answer=False,
+        dataset_id="test_dataset_for_eval",
+    )
+
+    response = await datasets_impl.list_datasets()
+    assert len(response) == 1
+
+    eval_impl = eval_settings["eval_impl"]
+    response = await eval_impl.evaluate_batch(
+        dataset_id=response[0].identifier,
+        candidate=ModelCandidate(
+            model="Llama3.1-8B-Instruct",
+            sampling_params=SamplingParams(),
+        ),
+        scoring_functions=["inclusion"],
+    )
+    assert response.job_id == "0"
+    job_status = await eval_impl.job_status(response.job_id)
+
+    assert job_status.value == "completed"
+
+    eval_response = await eval_impl.job_result(response.job_id)
+
+    assert eval_response is not None
+    assert len(eval_response.generations) == 5
+    assert "inclusion" in eval_response.scores
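The test above drives the full path: dataset registration, batch generation with a ModelCandidate, the new "inclusion" scoring function, and the job status/result endpoints. For reference, a small standalone sketch of the scoring rule the test depends on (plain Python, independent of the provider classes; the example rows are invented):

    # Invented example rows; shows the "inclusion" rule: 1.0 when the expected
    # answer is a substring of the generated answer, 0.0 otherwise.
    rows = [
        {"expected_answer": "Paris", "generated_answer": "The capital of France is Paris."},
        {"expected_answer": "4", "generated_answer": "The answer is five."},
    ]
    scores = [1.0 if r["expected_answer"] in r["generated_answer"] else 0.0 for r in rows]
    accuracy = sum(scores) / len(scores)
    assert scores == [1.0, 0.0]
    assert accuracy == 0.5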