Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
feat: Add NVIDIA Eval integration (#1890)
# What does this PR do?

This PR adds support for NVIDIA's NeMo Evaluator API to the Llama Stack eval module. The integration enables users to evaluate models via the Llama Stack interface.

## Test Plan

1. Added unit tests and ran them successfully from the root of the project:

   `./scripts/unit-tests.sh tests/unit/providers/nvidia/test_eval.py`

   ```
   tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_job_cancel PASSED
   tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_job_result PASSED
   tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_job_status PASSED
   tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_register_benchmark PASSED
   tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_run_eval PASSED
   ```

2. Verified that the Llama Stack image builds:

   `LLAMA_STACK_DIR=$(pwd) llama stack build --template nvidia --image-type venv`

Documentation added to `llama_stack/providers/remote/eval/nvidia/README.md`

Co-authored-by: Jash Gulabrai <jgulabrai@nvidia.com>
This commit is contained in: parent 0b6cd45950, commit cc77f79f55.
13 changed files with 598 additions and 23 deletions.
NVIDIA distribution docs:

```diff
@@ -7,7 +7,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following providers:
 |-----|-------------|
 | agents | `inline::meta-reference` |
 | datasetio | `inline::localfs` |
-| eval | `inline::meta-reference` |
+| eval | `remote::nvidia` |
 | inference | `remote::nvidia` |
 | post_training | `remote::nvidia` |
 | safety | `remote::nvidia` |
@@ -29,6 +29,7 @@ The following environment variables can be configured:
 - `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
 - `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
 - `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
+- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
 - `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
 - `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
```
Eval provider registry:

```diff
@@ -6,7 +6,7 @@

 from typing import List

-from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec


 def available_providers() -> List[ProviderSpec]:
@@ -25,4 +25,22 @@ def available_providers() -> List[ProviderSpec]:
             Api.agents,
         ],
     ),
+    remote_provider_spec(
+        api=Api.eval,
+        adapter=AdapterSpec(
+            adapter_type="nvidia",
+            pip_packages=[
+                "requests",
+            ],
+            module="llama_stack.providers.remote.eval.nvidia",
+            config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
+        ),
+        api_dependencies=[
+            Api.datasetio,
+            Api.datasets,
+            Api.scoring,
+            Api.inference,
+            Api.agents,
+        ],
+    ),
 ]
```
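For a quick check that the new spec is picked up, a hedged sketch (not from the PR); the registry module path `llama_stack.providers.registry.eval` is an assumption, since this view does not show the file path:

```python
# Hedged sketch: list the eval provider specs after this change.
# Assumes the diffed file is llama_stack/providers/registry/eval.py.
from llama_stack.providers.registry.eval import available_providers

for spec in available_providers():
    # Expect inline::meta-reference plus the new remote::nvidia entry.
    print(spec.provider_type)
```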
`llama_stack/providers/remote/eval/__init__.py` (new file, 5 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
```
`llama_stack/providers/remote/eval/nvidia/README.md` (new file, 134 lines):

# NVIDIA NeMo Evaluator Eval Provider

## Overview

For the first integration, Benchmarks are mapped to Evaluation Configs in NeMo Evaluator. The full evaluation config object is provided as part of the metadata. The `dataset_id` and `scoring_functions` are not used.

Below are a few examples of how to register a benchmark (which in turn creates an evaluation config in NeMo Evaluator) and how to trigger an evaluation.
### Example of registering an academic benchmark

```
POST /eval/benchmarks
```
```json
{
  "benchmark_id": "mmlu",
  "dataset_id": "",
  "scoring_functions": [],
  "metadata": {
    "type": "mmlu"
  }
}
```
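As a usage sketch (not part of this PR), the same registration performed with Python `requests`; the server base URL is an assumption, so adjust it to wherever your Llama Stack server runs:

```python
# A usage sketch, not part of the PR: registering the benchmark over HTTP.
import requests

BASE_URL = "http://localhost:8321"  # assumed Llama Stack server address

payload = {
    "benchmark_id": "mmlu",
    "dataset_id": "",
    "scoring_functions": [],
    "metadata": {"type": "mmlu"},
}
response = requests.post(f"{BASE_URL}/eval/benchmarks", json=payload)
response.raise_for_status()
```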

### Example of registering a custom evaluation

```
POST /eval/benchmarks
```
```json
{
  "benchmark_id": "my-custom-benchmark",
  "dataset_id": "",
  "scoring_functions": [],
  "metadata": {
    "type": "custom",
    "params": {
      "parallelism": 8
    },
    "tasks": {
      "qa": {
        "type": "completion",
        "params": {
          "template": {
            "prompt": "{{prompt}}",
            "max_tokens": 200
          }
        },
        "dataset": {
          "files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
        },
        "metrics": {
          "bleu": {
            "type": "bleu",
            "params": {
              "references": [
                "{{ideal_response}}"
              ]
            }
          }
        }
      }
    }
  }
}
```

### Example of triggering a benchmark/custom evaluation

```
POST /eval/benchmarks/{benchmark_id}/jobs
```
```json
{
  "benchmark_id": "my-custom-benchmark",
  "benchmark_config": {
    "eval_candidate": {
      "type": "model",
      "model": "meta-llama/Llama3.1-8B-Instruct",
      "sampling_params": {
        "max_tokens": 100,
        "temperature": 0.7
      }
    },
    "scoring_params": {}
  }
}
```

Response example:
```json
{
  "job_id": "eval-1234",
  "status": "in_progress"
}
```
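A hedged end-to-end sketch (not part of this PR) that triggers the job and polls the status route below until it settles; the base URL and the poll interval are assumptions:

```python
# A usage sketch, not part of the PR: trigger a job, then poll its status.
import time

import requests

BASE_URL = "http://localhost:8321"  # assumed Llama Stack server address
benchmark_id = "my-custom-benchmark"

job = requests.post(
    f"{BASE_URL}/eval/benchmarks/{benchmark_id}/jobs",
    json={
        "benchmark_id": benchmark_id,
        "benchmark_config": {
            "eval_candidate": {
                "type": "model",
                "model": "meta-llama/Llama3.1-8B-Instruct",
                "sampling_params": {"max_tokens": 100, "temperature": 0.7},
            },
            "scoring_params": {},
        },
    },
).json()

# Poll until the job leaves the in-progress states.
while True:
    status = requests.get(
        f"{BASE_URL}/eval/benchmarks/{benchmark_id}/jobs/{job['job_id']}"
    ).json()["status"]
    if status in ("completed", "failed", "cancelled"):
        break
    time.sleep(10)  # assumed poll interval
print(f"Job {job['job_id']} finished with status: {status}")
```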

### Example of getting the status of a job
```
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
```

Response example:
```json
{
  "job_id": "eval-1234",
  "status": "in_progress"
}
```

### Example of cancelling a job
```
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
```

### Example of getting the results
```
GET /eval/benchmarks/{benchmark_id}/results
```
```json
{
  "generations": [],
  "scores": {
    "{benchmark_id}": {
      "score_rows": [],
      "aggregated_results": {
        "tasks": {},
        "groups": {}
      }
    }
  }
}
```
`llama_stack/providers/remote/eval/nvidia/__init__.py` (new file, 31 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict

from llama_stack.distribution.datatypes import Api

from .config import NVIDIAEvalConfig


async def get_adapter_impl(
    config: NVIDIAEvalConfig,
    deps: Dict[Api, Any],
):
    from .eval import NVIDIAEvalImpl

    impl = NVIDIAEvalImpl(
        config,
        deps[Api.datasetio],
        deps[Api.datasets],
        deps[Api.scoring],
        deps[Api.inference],
        deps[Api.agents],
    )
    await impl.initialize()
    return impl


__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]
```
`llama_stack/providers/remote/eval/nvidia/config.py` (new file, 29 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict

from pydantic import BaseModel, Field


class NVIDIAEvalConfig(BaseModel):
    """
    Configuration for the NVIDIA NeMo Evaluator microservice endpoint.

    Attributes:
        evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
    """

    evaluator_url: str = Field(
        default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
        description="The url for accessing the evaluator service",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
        }
```
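A small usage sketch (not part of this PR) showing the two ways the URL can be supplied, explicitly or via the `NVIDIA_EVALUATOR_URL` environment variable read by the `default_factory`:

```python
# A usage sketch, not part of the PR.
import os

from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig

explicit = NVIDIAEvalConfig(evaluator_url="http://localhost:7331")

os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
from_env = NVIDIAEvalConfig()  # default_factory reads the environment variable

print(explicit.evaluator_url, from_env.evaluator_url)
```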
`llama_stack/providers/remote/eval/nvidia/eval.py` (new file, 154 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List

import requests

from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring, ScoringResult
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import NVIDIAEvalConfig

DEFAULT_NAMESPACE = "nvidia"


class NVIDIAEvalImpl(
    Eval,
    BenchmarksProtocolPrivate,
    ModelRegistryHelper,
):
    def __init__(
        self,
        config: NVIDIAEvalConfig,
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
        scoring_api: Scoring,
        inference_api: Inference,
        agents_api: Agents,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api
        self.scoring_api = scoring_api
        self.inference_api = inference_api
        self.agents_api = agents_api

        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    async def _evaluator_get(self, path):
        """Helper for making GET requests to the evaluator service."""
        response = requests.get(url=f"{self.config.evaluator_url}{path}")
        response.raise_for_status()
        return response.json()

    async def _evaluator_post(self, path, data):
        """Helper for making POST requests to the evaluator service."""
        response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
        response.raise_for_status()
        return response.json()

    async def register_benchmark(self, task_def: Benchmark) -> None:
        """Register a benchmark as an evaluation configuration."""
        await self._evaluator_post(
            "/v1/evaluation/configs",
            {
                "namespace": DEFAULT_NAMESPACE,
                "name": task_def.benchmark_id,
                # metadata is copied to the request body as-is
                **task_def.metadata,
            },
        )

    async def run_eval(
        self,
        benchmark_id: str,
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        """Run an evaluation job for a benchmark."""
        model = (
            benchmark_config.eval_candidate.model
            if benchmark_config.eval_candidate.type == "model"
            else benchmark_config.eval_candidate.config.model
        )
        nvidia_model = self.get_provider_model_id(model) or model

        result = await self._evaluator_post(
            "/v1/evaluation/jobs",
            {
                "config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
                "target": {"type": "model", "model": nvidia_model},
            },
        )

        return Job(job_id=result["id"], status=JobStatus.in_progress)

    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        raise NotImplementedError()

    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
        """Get the status of an evaluation job.

        EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
        JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
        """
        result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
        result_status = result["status"]

        job_status = JobStatus.failed
        if result_status in ["created", "pending"]:
            job_status = JobStatus.scheduled
        elif result_status in ["running"]:
            job_status = JobStatus.in_progress
        elif result_status in ["completed"]:
            job_status = JobStatus.completed
        elif result_status in ["cancelled"]:
            job_status = JobStatus.cancelled

        return Job(job_id=job_id, status=job_status)

    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
        """Cancel the evaluation job."""
        await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})

    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
        """Returns the results of the evaluation job."""

        job = await self.job_status(benchmark_id, job_id)
        status = job.status
        if not status or status != JobStatus.completed:
            raise ValueError(f"Job {job_id} not completed. Status: {status.value}")

        result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")

        return EvaluateResponse(
            # TODO: these are stored in detailed results on the NeMo Evaluator side; can be added
            generations=[],
            scores={
                benchmark_id: ScoringResult(
                    score_rows=[],
                    aggregated_results=result,
                )
            },
        )
```
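For orientation, a hedged sketch (not part of this PR) of driving the adapter directly, outside a running stack; it assumes a NeMo Evaluator service is reachable at the configured URL, and the `MagicMock` objects stand in for the stack's real API dependencies:

```python
# A hedged sketch, not part of the PR: exercising NVIDIAEvalImpl outside the stack.
import asyncio
from unittest.mock import MagicMock

from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl


async def main() -> None:
    impl = NVIDIAEvalImpl(
        config=NVIDIAEvalConfig(evaluator_url="http://localhost:7331"),
        datasetio_api=MagicMock(),
        datasets_api=MagicMock(),
        scoring_api=MagicMock(),
        inference_api=MagicMock(),
        agents_api=MagicMock(),
    )
    await impl.initialize()

    # job_status() maps the evaluator's "created"/"pending"/"running"/...
    # states onto the stack's JobStatus enum, as implemented above.
    job = await impl.job_status(benchmark_id="mmlu", job_id="eval-1234")
    print(job.status)


asyncio.run(main())
```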
Distribution dependencies:

```
@@ -394,12 +394,10 @@
       "aiosqlite",
       "blobfile",
       "chardet",
       "emoji",
       "faiss-cpu",
       "fastapi",
       "fire",
       "httpx",
       "langdetect",
       "matplotlib",
       "nltk",
       "numpy",
@@ -411,7 +409,6 @@
       "psycopg2-binary",
       "pymongo",
       "pypdf",
       "pythainlp",
       "redis",
       "requests",
       "scikit-learn",
@@ -419,7 +416,6 @@
       "sentencepiece",
       "tqdm",
       "transformers",
       "tree_sitter",
       "uvicorn"
     ],
     "ollama": [
```
NVIDIA distribution spec:

```diff
@@ -1,6 +1,6 @@
 version: '2'
 distribution_spec:
-  description: Use NVIDIA NIM for running LLM inference and safety
+  description: Use NVIDIA NIM for running LLM inference, evaluation and safety
   providers:
     inference:
     - remote::nvidia
@@ -13,7 +13,7 @@ distribution_spec:
     telemetry:
     - inline::meta-reference
     eval:
-    - inline::meta-reference
+    - remote::nvidia
     post_training:
     - remote::nvidia
     datasetio:
```
NVIDIA distribution template:

```diff
@@ -7,6 +7,7 @@
 from pathlib import Path

 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
+from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
 from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
@@ -20,7 +21,7 @@ def get_distribution_template() -> DistributionTemplate:
         "safety": ["remote::nvidia"],
         "agents": ["inline::meta-reference"],
         "telemetry": ["inline::meta-reference"],
-        "eval": ["inline::meta-reference"],
+        "eval": ["remote::nvidia"],
         "post_training": ["remote::nvidia"],
         "datasetio": ["inline::localfs"],
         "scoring": ["inline::basic"],
@@ -37,6 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::nvidia",
         config=NVIDIASafetyConfig.sample_run_config(),
     )
+    eval_provider = Provider(
+        provider_id="nvidia",
+        provider_type="remote::nvidia",
+        config=NVIDIAEvalConfig.sample_run_config(),
+    )
     inference_model = ModelInput(
         model_id="${env.INFERENCE_MODEL}",
         provider_id="nvidia",
@@ -60,7 +66,7 @@ def get_distribution_template() -> DistributionTemplate:
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",
-        description="Use NVIDIA NIM for running LLM inference and safety",
+        description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
         container_image=None,
         template_path=Path(__file__).parent / "doc_template.md",
         providers=providers,
@@ -69,6 +75,7 @@ def get_distribution_template() -> DistributionTemplate:
             "run.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [inference_provider],
+                    "eval": [eval_provider],
                 },
                 default_models=default_models,
                 default_tool_groups=default_tool_groups,
@@ -78,7 +85,8 @@ def get_distribution_template() -> DistributionTemplate:
                     "inference": [
                         inference_provider,
                         safety_provider,
-                    ]
+                    ],
+                    "eval": [eval_provider],
                 },
                 default_models=[inference_model, safety_model],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
@@ -119,6 +127,10 @@ def get_distribution_template() -> DistributionTemplate:
                 "http://0.0.0.0:7331",
                 "URL for the NeMo Guardrails Service",
             ),
+            "NVIDIA_EVALUATOR_URL": (
+                "http://0.0.0.0:7331",
+                "URL for the NeMo Evaluator Service",
+            ),
             "INFERENCE_MODEL": (
                 "Llama3.1-8B-Instruct",
                 "Inference model",
```
First run configuration:

```diff
@@ -53,13 +53,10 @@ providers:
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
   eval:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
+  - provider_id: nvidia
+    provider_type: remote::nvidia
     config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
+      evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
   post_training:
   - provider_id: nvidia
     provider_type: remote::nvidia
```
Second run configuration (same change):

```diff
@@ -48,13 +48,10 @@ providers:
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
   eval:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
+  - provider_id: nvidia
+    provider_type: remote::nvidia
     config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
+      evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
   post_training:
   - provider_id: nvidia
     provider_type: remote::nvidia
```
`tests/unit/providers/nvidia/test_eval.py` (new file, 201 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import unittest
from unittest.mock import MagicMock, patch

import pytest

from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl

MOCK_DATASET_ID = "default/test-dataset"
MOCK_BENCHMARK_ID = "test-benchmark"


class TestNVIDIAEvalImpl(unittest.TestCase):
    def setUp(self):
        os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"

        # Create mock APIs
        self.datasetio_api = MagicMock()
        self.datasets_api = MagicMock()
        self.scoring_api = MagicMock()
        self.inference_api = MagicMock()
        self.agents_api = MagicMock()

        self.config = NVIDIAEvalConfig(
            evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
        )

        self.eval_impl = NVIDIAEvalImpl(
            config=self.config,
            datasetio_api=self.datasetio_api,
            datasets_api=self.datasets_api,
            scoring_api=self.scoring_api,
            inference_api=self.inference_api,
            agents_api=self.agents_api,
        )

        # Mock the HTTP request methods
        self.evaluator_get_patcher = patch(
            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
        )
        self.evaluator_post_patcher = patch(
            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
        )

        self.mock_evaluator_get = self.evaluator_get_patcher.start()
        self.mock_evaluator_post = self.evaluator_post_patcher.start()

    def tearDown(self):
        """Clean up after each test."""
        self.evaluator_get_patcher.stop()
        self.evaluator_post_patcher.stop()

    def _assert_request_body(self, expected_json):
        """Helper method to verify that the request body in the Evaluator POST request is correct."""
        call_args = self.mock_evaluator_post.call_args
        actual_json = call_args[0][1]

        # Check that all expected keys contain the expected values in the actual JSON
        for key, value in expected_json.items():
            assert key in actual_json, f"Key '{key}' missing in actual JSON"

            if isinstance(value, dict):
                for nested_key, nested_value in value.items():
                    assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
                    assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
            else:
                assert actual_json[key] == value, f"Value mismatch for '{key}'"

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        self.run_async = run_async

    def test_register_benchmark(self):
        eval_config = {
            "type": "custom",
            "params": {"parallelism": 8},
            "tasks": {
                "qa": {
                    "type": "completion",
                    "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
                    "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
                    "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
                }
            },
        }

        benchmark = Benchmark(
            provider_id="nvidia",
            type="benchmark",
            identifier=MOCK_BENCHMARK_ID,
            dataset_id=MOCK_DATASET_ID,
            scoring_functions=["basic::equality"],
            metadata=eval_config,
        )

        # Mock Evaluator API response
        mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
        self.mock_evaluator_post.return_value = mock_evaluator_response

        # Register the benchmark
        self.run_async(self.eval_impl.register_benchmark(benchmark))

        # Verify the Evaluator API was called correctly
        self.mock_evaluator_post.assert_called_once()
        self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})

    def test_run_eval(self):
        benchmark_config = BenchmarkConfig(
            eval_candidate=ModelCandidate(
                type="model",
                model=CoreModelId.llama3_1_8b_instruct.value,
                sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
            )
        )

        # Mock Evaluator API response
        mock_evaluator_response = {"id": "job-123", "status": "created"}
        self.mock_evaluator_post.return_value = mock_evaluator_response

        # Run the Evaluation job
        result = self.run_async(
            self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
        )

        # Verify the Evaluator API was called correctly
        self.mock_evaluator_post.assert_called_once()
        self._assert_request_body(
            {
                "config": f"nvidia/{MOCK_BENCHMARK_ID}",
                "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
            }
        )

        # Verify the result
        assert isinstance(result, Job)
        assert result.job_id == "job-123"
        assert result.status == JobStatus.in_progress

    def test_job_status(self):
        # Mock Evaluator API response
        mock_evaluator_response = {"id": "job-123", "status": "completed"}
        self.mock_evaluator_get.return_value = mock_evaluator_response

        # Get the Evaluation job
        result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))

        # Verify the result
        assert isinstance(result, Job)
        assert result.job_id == "job-123"
        assert result.status == JobStatus.completed

        # Verify the API was called correctly
        self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")

    def test_job_cancel(self):
        # Mock Evaluator API response
        mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
        self.mock_evaluator_post.return_value = mock_evaluator_response

        # Cancel the Evaluation job
        self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))

        # Verify the API was called correctly
        self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})

    def test_job_result(self):
        # Mock Evaluator API responses
        mock_job_status_response = {"id": "job-123", "status": "completed"}
        mock_job_results_response = {
            "id": "job-123",
            "status": "completed",
            "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
        }
        self.mock_evaluator_get.side_effect = [
            mock_job_status_response,  # First call to retrieve the job
            mock_job_results_response,  # Second call to retrieve the job results
        ]

        # Get the Evaluation job results
        result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))

        # Verify the result
        assert isinstance(result, EvaluateResponse)
        assert MOCK_BENCHMARK_ID in result.scores
        assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85

        # Verify the API was called correctly
        assert self.mock_evaluator_get.call_count == 2
        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
```
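The tests rely on a `run_async` fixture injected from elsewhere in the test suite; a plausible definition (an assumption, not necessarily the repo's actual fixture) would be:

```python
# Hypothetical conftest.py sketch: one way to provide the run_async fixture
# used by TestNVIDIAEvalImpl to drive coroutines from synchronous tests.
import asyncio

import pytest


@pytest.fixture
def run_async():
    def _run_async(coro):
        return asyncio.run(coro)

    return _run_async
```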