diff --git a/docs/source/distributions/self_hosted_distro/nvidia.md b/docs/source/distributions/self_hosted_distro/nvidia.md
index 0922cb512..147c5b2ae 100644
--- a/docs/source/distributions/self_hosted_distro/nvidia.md
+++ b/docs/source/distributions/self_hosted_distro/nvidia.md
@@ -7,7 +7,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 |-----|-------------|
 | agents | `inline::meta-reference` |
 | datasetio | `inline::localfs` |
-| eval | `inline::meta-reference` |
+| eval | `remote::nvidia` |
 | inference | `remote::nvidia` |
 | post_training | `remote::nvidia` |
 | safety | `remote::nvidia` |
@@ -29,6 +29,7 @@ The following environment variables can be configured:
 - `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
 - `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
 - `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
+- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
 - `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
 - `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
 
diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py
index f3e42c531..9604d5da4 100644
--- a/llama_stack/providers/registry/eval.py
+++ b/llama_stack/providers/registry/eval.py
@@ -6,7 +6,7 @@
 
 from typing import List
 
-from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
 
 
 def available_providers() -> List[ProviderSpec]:
@@ -25,4 +25,22 @@ def available_providers() -> List[ProviderSpec]:
                 Api.agents,
             ],
         ),
+        remote_provider_spec(
+            api=Api.eval,
+            adapter=AdapterSpec(
+                adapter_type="nvidia",
+                pip_packages=[
+                    "requests",
+                ],
+                module="llama_stack.providers.remote.eval.nvidia",
+                config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
+            ),
+            api_dependencies=[
+                Api.datasetio,
+                Api.datasets,
+                Api.scoring,
+                Api.inference,
+                Api.agents,
+            ],
+        ),
     ]
diff --git a/llama_stack/providers/remote/eval/__init__.py b/llama_stack/providers/remote/eval/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/remote/eval/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/remote/eval/nvidia/README.md b/llama_stack/providers/remote/eval/nvidia/README.md
new file mode 100644
index 000000000..cebc77920
--- /dev/null
+++ b/llama_stack/providers/remote/eval/nvidia/README.md
@@ -0,0 +1,134 @@
+# NVIDIA NeMo Evaluator Eval Provider
+
+
+## Overview
+
+For the first integration, benchmarks are mapped to Evaluation Configs in the NeMo Evaluator. The full evaluation config object is provided as part of the metadata. The `dataset_id` and `scoring_functions` are not used.
+
+Below are a few examples of how to register a benchmark (which in turn creates an evaluation config in NeMo Evaluator) and how to trigger an evaluation.
+
+### Example of registering an academic benchmark
+
+```
+POST /eval/benchmarks
+```
+```json
+{
+  "benchmark_id": "mmlu",
+  "dataset_id": "",
+  "scoring_functions": [],
+  "metadata": {
+    "type": "mmlu"
+  }
+}
+```
+
+### Example of registering a custom evaluation
+
+```
+POST /eval/benchmarks
+```
+```json
+{
+  "benchmark_id": "my-custom-benchmark",
+  "dataset_id": "",
+  "scoring_functions": [],
+  "metadata": {
+    "type": "custom",
+    "params": {
+      "parallelism": 8
+    },
+    "tasks": {
+      "qa": {
+        "type": "completion",
+        "params": {
+          "template": {
+            "prompt": "{{prompt}}",
+            "max_tokens": 200
+          }
+        },
+        "dataset": {
+          "files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
+        },
+        "metrics": {
+          "bleu": {
+            "type": "bleu",
+            "params": {
+              "references": [
+                "{{ideal_response}}"
+              ]
+            }
+          }
+        }
+      }
+    }
+  }
+}
+```
+
+### Example of triggering a benchmark/custom evaluation
+
+```
+POST /eval/benchmarks/{benchmark_id}/jobs
+```
+```json
+{
+  "benchmark_id": "my-custom-benchmark",
+  "benchmark_config": {
+    "eval_candidate": {
+      "type": "model",
+      "model": "meta-llama/Llama3.1-8B-Instruct",
+      "sampling_params": {
+        "max_tokens": 100,
+        "temperature": 0.7
+      }
+    },
+    "scoring_params": {}
+  }
+}
+```
+
+Response example:
+```json
+{
+  "job_id": "eval-1234",
+  "status": "in_progress"
+}
+```
+
+### Example of getting the status of a job
+```
+GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
+```
+
+Response example:
+```json
+{
+  "job_id": "eval-1234",
+  "status": "in_progress"
+}
+```
+
+### Example of cancelling a job
+```
+POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
+```
+
+### Example of getting the results
+```
+GET /eval/benchmarks/{benchmark_id}/results
+```
+```json
+{
+  "generations": [],
+  "scores": {
+    "{benchmark_id}": {
+      "score_rows": [],
+      "aggregated_results": {
+        "tasks": {},
+        "groups": {}
+      }
+    }
+  }
+}
+```
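+
+### Example of the full flow from Python
+
+A minimal sketch of the same flow as the HTTP examples above, using `requests` directly. The base URL is an assumption; adjust it (and any route version prefix such as `/v1`) to match your Llama Stack deployment.
+
+```python
+import time
+
+import requests
+
+BASE_URL = "http://localhost:8321"  # assumed Llama Stack server address
+
+# Register a benchmark; the metadata carries the NeMo Evaluator config
+requests.post(
+    f"{BASE_URL}/eval/benchmarks",
+    json={
+        "benchmark_id": "mmlu",
+        "dataset_id": "",
+        "scoring_functions": [],
+        "metadata": {"type": "mmlu"},
+    },
+).raise_for_status()
+
+# Trigger an evaluation job for the registered benchmark
+job = requests.post(
+    f"{BASE_URL}/eval/benchmarks/mmlu/jobs",
+    json={
+        "benchmark_id": "mmlu",
+        "benchmark_config": {
+            "eval_candidate": {
+                "type": "model",
+                "model": "meta-llama/Llama3.1-8B-Instruct",
+                "sampling_params": {"max_tokens": 100, "temperature": 0.7},
+            },
+            "scoring_params": {},
+        },
+    },
+).json()
+
+# Poll the job status until it is no longer scheduled/in progress
+while requests.get(f"{BASE_URL}/eval/benchmarks/mmlu/jobs/{job['job_id']}").json()["status"] in ("scheduled", "in_progress"):
+    time.sleep(5)
+
+# Fetch the aggregated results
+results = requests.get(f"{BASE_URL}/eval/benchmarks/mmlu/results").json()
+print(results["scores"]["mmlu"]["aggregated_results"])
+```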
diff --git a/llama_stack/providers/remote/eval/nvidia/__init__.py b/llama_stack/providers/remote/eval/nvidia/__init__.py
new file mode 100644
index 000000000..8abbec9b2
--- /dev/null
+++ b/llama_stack/providers/remote/eval/nvidia/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Any, Dict
+
+from llama_stack.distribution.datatypes import Api
+
+from .config import NVIDIAEvalConfig
+
+
+async def get_adapter_impl(
+    config: NVIDIAEvalConfig,
+    deps: Dict[Api, Any],
+):
+    from .eval import NVIDIAEvalImpl
+
+    impl = NVIDIAEvalImpl(
+        config,
+        deps[Api.datasetio],
+        deps[Api.datasets],
+        deps[Api.scoring],
+        deps[Api.inference],
+        deps[Api.agents],
+    )
+    await impl.initialize()
+    return impl
+
+
+__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]
diff --git a/llama_stack/providers/remote/eval/nvidia/config.py b/llama_stack/providers/remote/eval/nvidia/config.py
new file mode 100644
index 000000000..b660fcd68
--- /dev/null
+++ b/llama_stack/providers/remote/eval/nvidia/config.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import os
+from typing import Any, Dict
+
+from pydantic import BaseModel, Field
+
+
+class NVIDIAEvalConfig(BaseModel):
+    """
+    Configuration for the NVIDIA NeMo Evaluator microservice endpoint.
+
+    Attributes:
+        evaluator_url (str): The base URL for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
+    """
+
+    evaluator_url: str = Field(
+        default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
+        description="The URL for accessing the evaluator service",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+        return {
+            "evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
+        }
diff --git a/llama_stack/providers/remote/eval/nvidia/eval.py b/llama_stack/providers/remote/eval/nvidia/eval.py
new file mode 100644
index 000000000..e1a3b5355
--- /dev/null
+++ b/llama_stack/providers/remote/eval/nvidia/eval.py
@@ -0,0 +1,154 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Any, Dict, List
+
+import requests
+
+from llama_stack.apis.agents import Agents
+from llama_stack.apis.benchmarks import Benchmark
+from llama_stack.apis.common.job_types import Job, JobStatus
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.scoring import Scoring, ScoringResult
+from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
+from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+
+from .config import NVIDIAEvalConfig
+
+DEFAULT_NAMESPACE = "nvidia"
+
+
+class NVIDIAEvalImpl(
+    Eval,
+    BenchmarksProtocolPrivate,
+    ModelRegistryHelper,
+):
+    def __init__(
+        self,
+        config: NVIDIAEvalConfig,
+        datasetio_api: DatasetIO,
+        datasets_api: Datasets,
+        scoring_api: Scoring,
+        inference_api: Inference,
+        agents_api: Agents,
+    ) -> None:
+        self.config = config
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api
+        self.scoring_api = scoring_api
+        self.inference_api = inference_api
+        self.agents_api = agents_api
+
+        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
+
+    async def initialize(self) -> None: ...
+
+    async def shutdown(self) -> None: ...
+
+    async def _evaluator_get(self, path):
+        """Helper for making GET requests to the evaluator service."""
+        response = requests.get(url=f"{self.config.evaluator_url}{path}")
+        response.raise_for_status()
+        return response.json()
+
+    async def _evaluator_post(self, path, data):
+        """Helper for making POST requests to the evaluator service."""
+        response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
+        response.raise_for_status()
+        return response.json()
+
+    async def register_benchmark(self, task_def: Benchmark) -> None:
+        """Register a benchmark as an evaluation configuration."""
+        await self._evaluator_post(
+            "/v1/evaluation/configs",
+            {
+                "namespace": DEFAULT_NAMESPACE,
+                "name": task_def.benchmark_id,
+                # metadata is copied into the request body as-is
+                **task_def.metadata,
+            },
+        )
+
+    async def run_eval(
+        self,
+        benchmark_id: str,
+        benchmark_config: BenchmarkConfig,
+    ) -> Job:
+        """Run an evaluation job for a benchmark."""
+        model = (
+            benchmark_config.eval_candidate.model
+            if benchmark_config.eval_candidate.type == "model"
+            else benchmark_config.eval_candidate.config.model
+        )
+        nvidia_model = self.get_provider_model_id(model) or model
+
+        result = await self._evaluator_post(
+            "/v1/evaluation/jobs",
+            {
+                "config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
+                "target": {"type": "model", "model": nvidia_model},
+            },
+        )
+
+        return Job(job_id=result["id"], status=JobStatus.in_progress)
+
+    async def evaluate_rows(
+        self,
+        benchmark_id: str,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        benchmark_config: BenchmarkConfig,
+    ) -> EvaluateResponse:
+        raise NotImplementedError()
+
+    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
+        """Get the status of an evaluation job.
+
+        EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
+        JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
+        """
+        result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
+        result_status = result["status"]
+
+        job_status = JobStatus.failed
+        if result_status in ["created", "pending"]:
+            job_status = JobStatus.scheduled
+        elif result_status in ["running"]:
+            job_status = JobStatus.in_progress
+        elif result_status in ["completed"]:
+            job_status = JobStatus.completed
+        elif result_status in ["cancelled"]:
+            job_status = JobStatus.cancelled
+
+        return Job(job_id=job_id, status=job_status)
+
+    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
+        """Cancel the evaluation job."""
+        await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})
+
+    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
+        """Returns the results of the evaluation job."""
+
+        job = await self.job_status(benchmark_id, job_id)
+        status = job.status
+        if not status or status != JobStatus.completed:
+            raise ValueError(f"Job {job_id} not completed. Status: {status.value}")
+
+        result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")
+
+        return EvaluateResponse(
+            # TODO: these are stored in detailed results on the NeMo Evaluator side; they can be added
+            generations=[],
+            scores={
+                benchmark_id: ScoringResult(
+                    score_rows=[],
+                    aggregated_results=result,
+                )
+            },
+        )
diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json
index b96191752..63c4ecfa5 100644
--- a/llama_stack/templates/dependencies.json
+++ b/llama_stack/templates/dependencies.json
@@ -394,12 +394,10 @@
     "aiosqlite",
     "blobfile",
     "chardet",
-    "emoji",
     "faiss-cpu",
     "fastapi",
     "fire",
     "httpx",
-    "langdetect",
     "matplotlib",
     "nltk",
     "numpy",
@@ -411,7 +409,6 @@
     "psycopg2-binary",
     "pymongo",
     "pypdf",
-    "pythainlp",
     "redis",
     "requests",
     "scikit-learn",
@@ -419,7 +416,6 @@
     "sentencepiece",
     "tqdm",
     "transformers",
-    "tree_sitter",
     "uvicorn"
   ],
   "ollama": [
diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml
index f99ff6c81..a33fa3737 100644
--- a/llama_stack/templates/nvidia/build.yaml
+++ b/llama_stack/templates/nvidia/build.yaml
@@ -1,6 +1,6 @@
 version: '2'
 distribution_spec:
-  description: Use NVIDIA NIM for running LLM inference and safety
+  description: Use NVIDIA NIM for running LLM inference, evaluation and safety
   providers:
     inference:
     - remote::nvidia
@@ -13,7 +13,7 @@ distribution_spec:
     telemetry:
     - inline::meta-reference
     eval:
-    - inline::meta-reference
+    - remote::nvidia
    post_training:
     - remote::nvidia
     datasetio:
diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py
index a0cefba52..32ddf78e3 100644
--- a/llama_stack/templates/nvidia/nvidia.py
+++ b/llama_stack/templates/nvidia/nvidia.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
+from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
 from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
@@ -20,7 +21,7 @@ def get_distribution_template() -> DistributionTemplate:
         "safety": ["remote::nvidia"],
         "agents": ["inline::meta-reference"],
         "telemetry": ["inline::meta-reference"],
-        "eval": ["inline::meta-reference"],
+        "eval": ["remote::nvidia"],
         "post_training": ["remote::nvidia"],
         "datasetio": ["inline::localfs"],
         "scoring": ["inline::basic"],
@@ -37,6 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::nvidia",
         config=NVIDIASafetyConfig.sample_run_config(),
     )
+    eval_provider = Provider(
+        provider_id="nvidia",
+        provider_type="remote::nvidia",
+        config=NVIDIAEvalConfig.sample_run_config(),
+    )
     inference_model = ModelInput(
         model_id="${env.INFERENCE_MODEL}",
         provider_id="nvidia",
@@ -60,7 +66,7 @@ def get_distribution_template() -> DistributionTemplate:
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",
-        description="Use NVIDIA NIM for running LLM inference and safety",
+        description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
         container_image=None,
         template_path=Path(__file__).parent / "doc_template.md",
         providers=providers,
@@ -69,6 +75,7 @@ def get_distribution_template() -> DistributionTemplate:
             "run.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [inference_provider],
+                    "eval": [eval_provider],
                 },
                 default_models=default_models,
                 default_tool_groups=default_tool_groups,
@@ -78,7 +85,8 @@ def get_distribution_template() -> DistributionTemplate:
                     "inference": [
                         inference_provider,
                         safety_provider,
-                    ]
+                    ],
+                    "eval": [eval_provider],
                 },
                 default_models=[inference_model, safety_model],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
@@ -119,6 +127,10 @@ def get_distribution_template() -> DistributionTemplate:
                 "http://0.0.0.0:7331",
                 "URL for the NeMo Guardrails Service",
             ),
+            "NVIDIA_EVALUATOR_URL": (
+                "http://0.0.0.0:7331",
+                "URL for the NeMo Evaluator Service",
+            ),
             "INFERENCE_MODEL": (
                 "Llama3.1-8B-Instruct",
                 "Inference model",
diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml
index 658d9377e..8483fb9bf 100644
--- a/llama_stack/templates/nvidia/run-with-safety.yaml
+++ b/llama_stack/templates/nvidia/run-with-safety.yaml
@@ -53,13 +53,10 @@ providers:
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
   eval:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
+  - provider_id: nvidia
+    provider_type: remote::nvidia
     config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
+      evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
   post_training:
   - provider_id: nvidia
     provider_type: remote::nvidia
diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml
index ff548d82e..d7e2753ba 100644
--- a/llama_stack/templates/nvidia/run.yaml
+++ b/llama_stack/templates/nvidia/run.yaml
@@ -48,13 +48,10 @@ providers:
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
   eval:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
+  - provider_id: nvidia
+    provider_type: remote::nvidia
     config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
+      evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
   post_training:
   - provider_id: nvidia
     provider_type: remote::nvidia
diff --git a/tests/unit/providers/nvidia/test_eval.py b/tests/unit/providers/nvidia/test_eval.py
new file mode 100644
index 000000000..584ca2101
--- /dev/null
+++ b/tests/unit/providers/nvidia/test_eval.py
@@ -0,0 +1,201 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+import unittest
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from llama_stack.apis.benchmarks import Benchmark
+from llama_stack.apis.common.job_types import Job, JobStatus
+from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
+from llama_stack.models.llama.sku_types import CoreModelId
+from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
+from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl
+
+MOCK_DATASET_ID = "default/test-dataset"
+MOCK_BENCHMARK_ID = "test-benchmark"
+
+
+class TestNVIDIAEvalImpl(unittest.TestCase):
+    def setUp(self):
+        os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
+
+        # Create mock APIs
+        self.datasetio_api = MagicMock()
+        self.datasets_api = MagicMock()
+        self.scoring_api = MagicMock()
+        self.inference_api = MagicMock()
+        self.agents_api = MagicMock()
+
+        self.config = NVIDIAEvalConfig(
+            evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
+        )
+
+        self.eval_impl = NVIDIAEvalImpl(
+            config=self.config,
+            datasetio_api=self.datasetio_api,
+            datasets_api=self.datasets_api,
+            scoring_api=self.scoring_api,
+            inference_api=self.inference_api,
+            agents_api=self.agents_api,
+        )
+
+        # Mock the HTTP request methods
+        self.evaluator_get_patcher = patch(
+            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
+        )
+        self.evaluator_post_patcher = patch(
+            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
+        )
+
+        self.mock_evaluator_get = self.evaluator_get_patcher.start()
+        self.mock_evaluator_post = self.evaluator_post_patcher.start()
+
+    def tearDown(self):
+        """Clean up after each test."""
+        self.evaluator_get_patcher.stop()
+        self.evaluator_post_patcher.stop()
+
+    def _assert_request_body(self, expected_json):
+        """Helper method to verify that the body of an Evaluator POST request is correct."""
+        call_args = self.mock_evaluator_post.call_args
+        actual_json = call_args[0][1]
+
+        # Check that all expected keys contain the expected values in the actual JSON
+        for key, value in expected_json.items():
+            assert key in actual_json, f"Key '{key}' missing in actual JSON"
+
+            if isinstance(value, dict):
+                for nested_key, nested_value in value.items():
+                    assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
+                    assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
+            else:
+                assert actual_json[key] == value, f"Value mismatch for '{key}'"
+
+    @pytest.fixture(autouse=True)
+    def inject_fixtures(self, run_async):
+        self.run_async = run_async
+
+    def test_register_benchmark(self):
+        eval_config = {
+            "type": "custom",
+            "params": {"parallelism": 8},
+            "tasks": {
+                "qa": {
+                    "type": "completion",
+                    "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
+                    "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
+                    "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
+                }
+            },
+        }
+
+        benchmark = Benchmark(
+            provider_id="nvidia",
+            type="benchmark",
+            identifier=MOCK_BENCHMARK_ID,
+            dataset_id=MOCK_DATASET_ID,
+            scoring_functions=["basic::equality"],
+            metadata=eval_config,
+        )
+
+        # Mock Evaluator API response
+        mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
+        self.mock_evaluator_post.return_value = mock_evaluator_response
+
+        # Register the benchmark
+        self.run_async(self.eval_impl.register_benchmark(benchmark))
+
+        # Verify the Evaluator API was called correctly
+        self.mock_evaluator_post.assert_called_once()
+        self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
+
+    def test_run_eval(self):
+        benchmark_config = BenchmarkConfig(
+            eval_candidate=ModelCandidate(
+                type="model",
+                model=CoreModelId.llama3_1_8b_instruct.value,
+                sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
+            )
+        )
+
+        # Mock Evaluator API response
+        mock_evaluator_response = {"id": "job-123", "status": "created"}
+        self.mock_evaluator_post.return_value = mock_evaluator_response
+
+        # Run the Evaluation job
+        result = self.run_async(
+            self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
+        )
+
+        # Verify the Evaluator API was called correctly
+        self.mock_evaluator_post.assert_called_once()
+        self._assert_request_body(
+            {
+                "config": f"nvidia/{MOCK_BENCHMARK_ID}",
+                "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
+            }
+        )
+
+        # Verify the result
+        assert isinstance(result, Job)
+        assert result.job_id == "job-123"
+        assert result.status == JobStatus.in_progress
+
+    def test_job_status(self):
+        # Mock Evaluator API response
+        mock_evaluator_response = {"id": "job-123", "status": "completed"}
+        self.mock_evaluator_get.return_value = mock_evaluator_response
+
+        # Get the Evaluation job
+        result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
+
+        # Verify the result
+        assert isinstance(result, Job)
+        assert result.job_id == "job-123"
+        assert result.status == JobStatus.completed
+
+        # Verify the API was called correctly
+        self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
+
+    def test_job_cancel(self):
+        # Mock Evaluator API response
+        mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
+        self.mock_evaluator_post.return_value = mock_evaluator_response
+
+        # Cancel the Evaluation job
+        self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
+
+        # Verify the API was called correctly
+        self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
+
+    def test_job_result(self):
+        # Mock Evaluator API responses
+        mock_job_status_response = {"id": "job-123", "status": "completed"}
+        mock_job_results_response = {
+            "id": "job-123",
+            "status": "completed",
+            "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
+        }
+        self.mock_evaluator_get.side_effect = [
+            mock_job_status_response,  # First call to retrieve the job
+            mock_job_results_response,  # Second call to retrieve the job results
+        ]
+
+        # Get the Evaluation job results
+        result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
+
+        # Verify the result
+        assert isinstance(result, EvaluateResponse)
+        assert MOCK_BENCHMARK_ID in result.scores
+        assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
+
+        # Verify the API was called correctly
+        assert self.mock_evaluator_get.call_count == 2
+        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
+        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
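+
+    def test_job_status_unknown_maps_to_failed(self):
+        # Sketch of an additional check (not part of the original change): any Evaluator
+        # status outside the mapped set -- e.g. "cancelling" -- falls back to
+        # JobStatus.failed, per the default in NVIDIAEvalImpl.job_status.
+        self.mock_evaluator_get.return_value = {"id": "job-123", "status": "cancelling"}
+
+        result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
+
+        assert result.status == JobStatus.failed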