diff --git a/llama_stack/apis/evals/client.py b/llama_stack/apis/evals/client.py
new file mode 100644
index 000000000..b1cb53607
--- /dev/null
+++ b/llama_stack/apis/evals/client.py
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+
+import fire
+import httpx
+from termcolor import cprint
+
+from .evals import *  # noqa: F403
+
+
+class EvaluationClient(Evals):
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def run_evals(self, model: str, dataset: str, task: str) -> EvaluateResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self.base_url}/evals/run",
+                json={
+                    "model": model,
+                    "dataset": dataset,
+                    "task": task,
+                },
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return EvaluateResponse(**response.json())
+
+
+async def run_main(host: str, port: int):
+    client = EvaluationClient(f"http://{host}:{port}")
+
+    response = await client.run_evals(
+        "Llama3.1-8B-Instruct",
+        "mmlu.csv",
+        "mmlu",
+    )
+    cprint(f"evaluate response={response}", "green")
+
+
+def main(host: str, port: int):
+    asyncio.run(run_main(host, port))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/llama_stack/apis/evals/evals.py b/llama_stack/apis/evals/evals.py
index b9fc339a2..5a4fafd4e 100644
--- a/llama_stack/apis/evals/evals.py
+++ b/llama_stack/apis/evals/evals.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from enum import Enum
 from typing import List, Protocol
 
 from llama_models.schema_utils import webmethod
@@ -16,22 +15,6 @@ from llama_stack.apis.dataset import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 
 
-class TextGenerationMetric(Enum):
-    perplexity = "perplexity"
-    rouge = "rouge"
-    bleu = "bleu"
-
-
-class QuestionAnsweringMetric(Enum):
-    em = "em"
-    f1 = "f1"
-
-
-class SummarizationMetric(Enum):
-    rouge = "rouge"
-    bleu = "bleu"
-
-
 class EvaluationJob(BaseModel):
     job_uuid: str
 
@@ -54,28 +37,7 @@ class EvaluateTaskRequestCommon(BaseModel):
 class EvaluateResponse(BaseModel):
     """Scores for evaluation."""
 
-    scores = Dict[str, str]
-
-
-@json_schema_type
-class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate text generation."""
-
-    metrics: List[TextGenerationMetric]
-
-
-@json_schema_type
-class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate question answering."""
-
-    metrics: List[QuestionAnsweringMetric]
-
-
-@json_schema_type
-class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate summarization."""
-
-    metrics: List[SummarizationMetric]
+    metrics: Dict[str, float]
 
 
 @json_schema_type
@@ -97,33 +59,36 @@ class EvaluationJobCreateResponse(BaseModel):
     job_uuid: str
 
 
-class Evaluations(Protocol):
-    @webmethod(route="/evaluate")
-    async def evaluate(
-        self, model: str, dataset: str, task: str
+class Evals(Protocol):
+    @webmethod(route="/evals/run")
+    async def run_evals(
+        self,
+        model: str,
+        dataset: str,
+        task: str,
     ) -> EvaluateResponse: ...
 
-    @webmethod(route="/evaluate/jobs")
+    @webmethod(route="/evals/jobs")
     def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
 
-    @webmethod(route="/evaluate/job/create")
+    @webmethod(route="/evals/job/create")
     async def create_evaluation_job(
         self, model: str, dataset: str, task: str
     ) -> EvaluationJob: ...
 
-    @webmethod(route="/evaluate/job/status")
+    @webmethod(route="/evals/job/status")
    def get_evaluation_job_status(
         self, job_uuid: str
     ) -> EvaluationJobStatusResponse: ...
 
     # sends SSE stream of logs
-    @webmethod(route="/evaluate/job/logs")
+    @webmethod(route="/evals/job/logs")
     def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
 
-    @webmethod(route="/evaluate/job/cancel")
+    @webmethod(route="/evals/job/cancel")
     def cancel_evaluation_job(self, job_uuid: str) -> None: ...
 
-    @webmethod(route="/evaluate/job/artifacts")
+    @webmethod(route="/evals/job/artifacts")
     def get_evaluation_job_artifacts(
         self, job_uuid: str
     ) -> EvaluationJobArtifactsResponse: ...
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index ae7d9ab40..e50dc810a 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -8,6 +8,7 @@ import importlib
 from typing import Any, Dict, List, Set
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
+
 from llama_stack.distribution.distribution import (
     builtin_automatically_routed_apis,
     get_provider_registry,
diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py
index 601e80e5d..c3fedd882 100644
--- a/llama_stack/distribution/server/endpoints.py
+++ b/llama_stack/distribution/server/endpoints.py
@@ -10,6 +10,7 @@ from typing import Dict, List
 from pydantic import BaseModel
 
 from llama_stack.apis.agents import Agents
+from llama_stack.apis.evals import Evals
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.memory import Memory
@@ -41,6 +42,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.shields: Shields,
         Api.memory_banks: MemoryBanks,
         Api.inspect: Inspect,
+        Api.evals: Evals,
     }
 
     for api, protocol in protocols.items():
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index a2e8851a2..604b457f7 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -24,6 +24,8 @@ class Api(Enum):
     shields = "shields"
     memory_banks = "memory_banks"
 
+    evals = "evals"
+
     # built-in API
     inspect = "inspect"
 
diff --git a/llama_stack/providers/impls/meta_reference/evals/__init__.py b/llama_stack/providers/impls/meta_reference/evals/__init__.py
new file mode 100644
index 000000000..f4dd4b79d
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/evals/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import MetaReferenceEvalsImplConfig  # noqa
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.distribution.datatypes import Api, ProviderSpec
+
+
+async def get_provider_impl(
+    config: MetaReferenceEvalsImplConfig, deps: Dict[Api, ProviderSpec]
+):
+    from .evals import MetaReferenceEvalsImpl
+
+    impl = MetaReferenceEvalsImpl(config, deps[Api.inference])
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/impls/meta_reference/evals/config.py b/llama_stack/providers/impls/meta_reference/evals/config.py
new file mode 100644
index 000000000..05dee366e
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/evals/config.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class MetaReferenceEvalsImplConfig(BaseModel): ...
diff --git a/llama_stack/providers/impls/meta_reference/evals/evals.py b/llama_stack/providers/impls/meta_reference/evals/evals.py
new file mode 100644
index 000000000..c68414c43
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/evals/evals.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.evals import *  # noqa: F403
+
+from .config import MetaReferenceEvalsImplConfig
+
+
+class MetaReferenceEvalsImpl(Evals):
+    def __init__(self, config: MetaReferenceEvalsImplConfig, inference_api: Inference):
+        self.inference_api = inference_api
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def run_evals(
+        self,
+        model: str,
+        dataset: str,
+        task: str,
+    ) -> EvaluateResponse:
+        print("hi")
+        return EvaluateResponse(
+            metrics={
+                "accuracy": 0.5,
+            }
+        )
diff --git a/llama_stack/providers/registry/evals.py b/llama_stack/providers/registry/evals.py
new file mode 100644
index 000000000..8f9bacdd6
--- /dev/null
+++ b/llama_stack/providers/registry/evals.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.evals,
+            provider_type="meta-reference",
+            pip_packages=[
+                "matplotlib",
+                "pillow",
+                "pandas",
+                "scikit-learn",
+            ],
+            module="llama_stack.providers.impls.meta_reference.evals",
+            config_class="llama_stack.providers.impls.meta_reference.evals.MetaReferenceEvalsImplConfig",
+            api_dependencies=[
+                Api.inference,
+            ],
+        ),
+    ]
diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml
index e4319750a..129b71e34 100644
--- a/tests/examples/local-run.yaml
+++ b/tests/examples/local-run.yaml
@@ -10,7 +10,11 @@ apis_to_serve:
 - memory_banks
 - inference
 - safety
+- evals
 api_providers:
+  evals:
+    provider_type: meta-reference
+    config: {}
   inference:
     providers:
     - meta-reference
@@ -34,12 +38,12 @@ routing_table:
   inference:
   - provider_type: meta-reference
     config:
-      model: Llama3.1-8B-Instruct
+      model: Llama3.2-1B
       quantization: null
       torch_seed: null
       max_seq_len: 4096
       max_batch_size: 1
-    routing_key: Llama3.1-8B-Instruct
+    routing_key: Llama3.2-1B
   safety:
   - provider_type: meta-reference
     config:
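
Test sketch (not part of the patch): a minimal driver that exercises the new /evals/run route through the EvaluationClient added above. It assumes a stack server with the evals provider is already running; the base URL (http://localhost:5000) and the model/dataset/task values are illustrative assumptions, not anything mandated by this diff. The same request can also be made from the command line via the client module itself (python -m llama_stack.apis.evals.client <host> <port>), since its main() wires the identical call through fire.

# Not part of the diff: a hypothetical driver for the new /evals/run route.
# Assumes a llama-stack server with the evals provider is listening on
# localhost:5000; host, port, and the model/dataset/task values are illustrative.
import asyncio

from llama_stack.apis.evals.client import EvaluationClient


async def try_run_evals() -> None:
    client = EvaluationClient("http://localhost:5000")
    # run_evals POSTs {"model", "dataset", "task"} to <base_url>/evals/run and
    # parses the JSON response into an EvaluateResponse with a metrics dict.
    response = await client.run_evals(
        model="Llama3.2-1B",
        dataset="mmlu.csv",
        task="mmlu",
    )
    # The meta-reference stub above currently returns {"accuracy": 0.5}.
    print(response.metrics)


if __name__ == "__main__":
    asyncio.run(try_run_evals())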