feat: Add NVIDIA Eval integration (#1890)

# What does this PR do?
This PR adds support for NVIDIA's NeMo Evaluator API to the Llama Stack
eval module. The integration enables users to evaluate models via the
Llama Stack interface.

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]
1. Added unit tests and successfully ran from root of project:
`./scripts/unit-tests.sh tests/unit/providers/nvidia/test_eval.py`
```
tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_job_cancel PASSED
tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_job_result PASSED
tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_job_status PASSED
tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_register_benchmark PASSED
tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_run_eval PASSED
```
2. Verified I could build the Llama Stack image: `LLAMA_STACK_DIR=$(pwd)
llama stack build --template nvidia --image-type venv`

Documentation added to
`llama_stack/providers/remote/eval/nvidia/README.md`

---------

Co-authored-by: Jash Gulabrai <jgulabrai@nvidia.com>
This commit is contained in:
Jash Gulabrai 2025-04-24 20:12:42 -04:00 committed by GitHub
parent 0b6cd45950
commit cc77f79f55
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 598 additions and 23 deletions

View file

@ -7,7 +7,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs` |
| eval | `inline::meta-reference` |
| eval | `remote::nvidia` |
| inference | `remote::nvidia` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
@ -29,6 +29,7 @@ The following environment variables can be configured:
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)

View file

@ -6,7 +6,7 @@
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
def available_providers() -> List[ProviderSpec]:
@ -25,4 +25,22 @@ def available_providers() -> List[ProviderSpec]:
Api.agents,
],
),
remote_provider_spec(
api=Api.eval,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[
"requests",
],
module="llama_stack.providers.remote.eval.nvidia",
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
),
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.scoring,
Api.inference,
Api.agents,
],
),
]

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,134 @@
# NVIDIA NeMo Evaluator Eval Provider
## Overview
For the first integration, Benchmarks are mapped to Evaluation Configs on in the NeMo Evaluator. The full evaluation config object is provided as part of the meta-data. The `dataset_id` and `scoring_functions` are not used.
Below are a few examples of how to register a benchmark, which in turn will create an evaluation config in NeMo Evaluator and how to trigger an evaluation.
### Example for register an academic benchmark
```
POST /eval/benchmarks
```
```json
{
"benchmark_id": "mmlu",
"dataset_id": "",
"scoring_functions": [],
"metadata": {
"type": "mmlu"
}
}
```
### Example for register a custom evaluation
```
POST /eval/benchmarks
```
```json
{
"benchmark_id": "my-custom-benchmark",
"dataset_id": "",
"scoring_functions": [],
"metadata": {
"type": "custom",
"params": {
"parallelism": 8
},
"tasks": {
"qa": {
"type": "completion",
"params": {
"template": {
"prompt": "{{prompt}}",
"max_tokens": 200
}
},
"dataset": {
"files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
},
"metrics": {
"bleu": {
"type": "bleu",
"params": {
"references": [
"{{ideal_response}}"
]
}
}
}
}
}
}
}
```
### Example for triggering a benchmark/custom evaluation
```
POST /eval/benchmarks/{benchmark_id}/jobs
```
```json
{
"benchmark_id": "my-custom-benchmark",
"benchmark_config": {
"eval_candidate": {
"type": "model",
"model": "meta-llama/Llama3.1-8B-Instruct",
"sampling_params": {
"max_tokens": 100,
"temperature": 0.7
}
},
"scoring_params": {}
}
}
```
Response example:
```json
{
"job_id": "eval-1234",
"status": "in_progress"
}
```
### Example for getting the status of a job
```
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
```
Response example:
```json
{
"job_id": "eval-1234",
"status": "in_progress"
}
```
### Example for cancelling a job
```
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
```
### Example for getting the results
```
GET /eval/benchmarks/{benchmark_id}/results
```
```json
{
"generations": [],
"scores": {
"{benchmark_id}": {
"score_rows": [],
"aggregated_results": {
"tasks": {},
"groups": {}
}
}
}
}
```

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api
from .config import NVIDIAEvalConfig
async def get_adapter_impl(
config: NVIDIAEvalConfig,
deps: Dict[Api, Any],
):
from .eval import NVIDIAEvalImpl
impl = NVIDIAEvalImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
deps[Api.scoring],
deps[Api.inference],
deps[Api.agents],
)
await impl.initialize()
return impl
__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict
from pydantic import BaseModel, Field
class NVIDIAEvalConfig(BaseModel):
"""
Configuration for the NVIDIA NeMo Evaluator microservice endpoint.
Attributes:
evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
"""
evaluator_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
description="The url for accessing the evaluator service",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
}

View file

@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
import requests
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring, ScoringResult
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import NVIDIAEvalConfig
DEFAULT_NAMESPACE = "nvidia"
class NVIDIAEvalImpl(
Eval,
BenchmarksProtocolPrivate,
ModelRegistryHelper,
):
def __init__(
self,
config: NVIDIAEvalConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
scoring_api: Scoring,
inference_api: Inference,
agents_api: Agents,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_api = scoring_api
self.inference_api = inference_api
self.agents_api = agents_api
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def _evaluator_get(self, path):
"""Helper for making GET requests to the evaluator service."""
response = requests.get(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
return response.json()
async def _evaluator_post(self, path, data):
"""Helper for making POST requests to the evaluator service."""
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
response.raise_for_status()
return response.json()
async def register_benchmark(self, task_def: Benchmark) -> None:
"""Register a benchmark as an evaluation configuration."""
await self._evaluator_post(
"/v1/evaluation/configs",
{
"namespace": DEFAULT_NAMESPACE,
"name": task_def.benchmark_id,
# metadata is copied to request body as-is
**task_def.metadata,
},
)
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation job for a benchmark."""
model = (
benchmark_config.eval_candidate.model
if benchmark_config.eval_candidate.type == "model"
else benchmark_config.eval_candidate.config.model
)
nvidia_model = self.get_provider_model_id(model) or model
result = await self._evaluator_post(
"/v1/evaluation/jobs",
{
"config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
"target": {"type": "model", "model": nvidia_model},
},
)
return Job(job_id=result["id"], status=JobStatus.in_progress)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
raise NotImplementedError()
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of an evaluation job.
EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
"""
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
result_status = result["status"]
job_status = JobStatus.failed
if result_status in ["created", "pending"]:
job_status = JobStatus.scheduled
elif result_status in ["running"]:
job_status = JobStatus.in_progress
elif result_status in ["completed"]:
job_status = JobStatus.completed
elif result_status in ["cancelled"]:
job_status = JobStatus.cancelled
return Job(job_id=job_id, status=job_status)
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel the evaluation job."""
await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Returns the results of the evaluation job."""
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job {job_id} not completed. Status: {status.value}")
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")
return EvaluateResponse(
# TODO: these are stored in detailed results on NeMo Evaluator side; can be added
generations=[],
scores={
benchmark_id: ScoringResult(
score_rows=[],
aggregated_results=result,
)
},
)

View file

@ -394,12 +394,10 @@
"aiosqlite",
"blobfile",
"chardet",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"nltk",
"numpy",
@ -411,7 +409,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -419,7 +416,6 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"ollama": [

View file

@ -1,6 +1,6 @@
version: '2'
distribution_spec:
description: Use NVIDIA NIM for running LLM inference and safety
description: Use NVIDIA NIM for running LLM inference, evaluation and safety
providers:
inference:
- remote::nvidia
@ -13,7 +13,7 @@ distribution_spec:
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
- remote::nvidia
post_training:
- remote::nvidia
datasetio:

View file

@ -7,6 +7,7 @@
from pathlib import Path
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
@ -20,7 +21,7 @@ def get_distribution_template() -> DistributionTemplate:
"safety": ["remote::nvidia"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"eval": ["remote::nvidia"],
"post_training": ["remote::nvidia"],
"datasetio": ["inline::localfs"],
"scoring": ["inline::basic"],
@ -37,6 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::nvidia",
config=NVIDIASafetyConfig.sample_run_config(),
)
eval_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIAEvalConfig.sample_run_config(),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia",
@ -60,7 +66,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate(
name="nvidia",
distro_type="self_hosted",
description="Use NVIDIA NIM for running LLM inference and safety",
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
@ -69,6 +75,7 @@ def get_distribution_template() -> DistributionTemplate:
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"eval": [eval_provider],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
@ -78,7 +85,8 @@ def get_distribution_template() -> DistributionTemplate:
"inference": [
inference_provider,
safety_provider,
]
],
"eval": [eval_provider],
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
@ -119,6 +127,10 @@ def get_distribution_template() -> DistributionTemplate:
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",
),
"NVIDIA_EVALUATOR_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Evaluator Service",
),
"INFERENCE_MODEL": (
"Llama3.1-8B-Instruct",
"Inference model",

View file

@ -53,13 +53,10 @@ providers:
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_id: nvidia
provider_type: remote::nvidia
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
post_training:
- provider_id: nvidia
provider_type: remote::nvidia

View file

@ -48,13 +48,10 @@ providers:
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_id: nvidia
provider_type: remote::nvidia
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
post_training:
- provider_id: nvidia
provider_type: remote::nvidia

View file

@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
from unittest.mock import MagicMock, patch
import pytest
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl
MOCK_DATASET_ID = "default/test-dataset"
MOCK_BENCHMARK_ID = "test-benchmark"
class TestNVIDIAEvalImpl(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
# Create mock APIs
self.datasetio_api = MagicMock()
self.datasets_api = MagicMock()
self.scoring_api = MagicMock()
self.inference_api = MagicMock()
self.agents_api = MagicMock()
self.config = NVIDIAEvalConfig(
evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
)
self.eval_impl = NVIDIAEvalImpl(
config=self.config,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
scoring_api=self.scoring_api,
inference_api=self.inference_api,
agents_api=self.agents_api,
)
# Mock the HTTP request methods
self.evaluator_get_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
)
self.evaluator_post_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
)
self.mock_evaluator_get = self.evaluator_get_patcher.start()
self.mock_evaluator_post = self.evaluator_post_patcher.start()
def tearDown(self):
"""Clean up after each test."""
self.evaluator_get_patcher.stop()
self.evaluator_post_patcher.stop()
def _assert_request_body(self, expected_json):
"""Helper method to verify request body in Evaluator POST request is correct"""
call_args = self.mock_evaluator_post.call_args
actual_json = call_args[0][1]
# Check that all expected keys contain the expected values in the actual JSON
for key, value in expected_json.items():
assert key in actual_json, f"Key '{key}' missing in actual JSON"
if isinstance(value, dict):
for nested_key, nested_value in value.items():
assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
else:
assert actual_json[key] == value, f"Value mismatch for '{key}'"
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def test_register_benchmark(self):
eval_config = {
"type": "custom",
"params": {"parallelism": 8},
"tasks": {
"qa": {
"type": "completion",
"params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
"dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
"metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
}
},
}
benchmark = Benchmark(
provider_id="nvidia",
type="benchmark",
identifier=MOCK_BENCHMARK_ID,
dataset_id=MOCK_DATASET_ID,
scoring_functions=["basic::equality"],
metadata=eval_config,
)
# Mock Evaluator API response
mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
self.mock_evaluator_post.return_value = mock_evaluator_response
# Register the benchmark
self.run_async(self.eval_impl.register_benchmark(benchmark))
# Verify the Evaluator API was called correctly
self.mock_evaluator_post.assert_called_once()
self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
def test_run_eval(self):
benchmark_config = BenchmarkConfig(
eval_candidate=ModelCandidate(
type="model",
model=CoreModelId.llama3_1_8b_instruct.value,
sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
)
)
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "created"}
self.mock_evaluator_post.return_value = mock_evaluator_response
# Run the Evaluation job
result = self.run_async(
self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
)
# Verify the Evaluator API was called correctly
self.mock_evaluator_post.assert_called_once()
self._assert_request_body(
{
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
}
)
# Verify the result
assert isinstance(result, Job)
assert result.job_id == "job-123"
assert result.status == JobStatus.in_progress
def test_job_status(self):
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "completed"}
self.mock_evaluator_get.return_value = mock_evaluator_response
# Get the Evaluation job
result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
# Verify the result
assert isinstance(result, Job)
assert result.job_id == "job-123"
assert result.status == JobStatus.completed
# Verify the API was called correctly
self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
def test_job_cancel(self):
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
self.mock_evaluator_post.return_value = mock_evaluator_response
# Cancel the Evaluation job
self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
# Verify the API was called correctly
self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
def test_job_result(self):
# Mock Evaluator API responses
mock_job_status_response = {"id": "job-123", "status": "completed"}
mock_job_results_response = {
"id": "job-123",
"status": "completed",
"results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
}
self.mock_evaluator_get.side_effect = [
mock_job_status_response, # First call to retrieve job
mock_job_results_response, # Second call to retrieve job results
]
# Get the Evaluation job results
result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
# Verify the result
assert isinstance(result, EvaluateResponse)
assert MOCK_BENCHMARK_ID in result.scores
assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
# Verify the API was called correctly
assert self.mock_evaluator_get.call_count == 2
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")