Mirror of https://github.com/meta-llama/llama-stack.git

Commit a25aff290e (parent fb565dfb06): generator + scorer API for MMLU
14 changed files with 618 additions and 131 deletions
@@ -14,6 +14,25 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated


+@json_schema_type
+class GenerationInput(BaseModel):
+    messages: List[Message]
+
+
+@json_schema_type
+class GenerationOutput(BaseModel):
+    completion_message: str
+    logprobs: Optional[List[TokenLogProbs]] = None
+
+
+@json_schema_type
+class PostprocessedGeneration(BaseModel):
+    completion_message: str
+    # structured transformed output from raw_completion_message to compute scorer metrics
+    transformed_generation: Optional[Any] = None
+
+
 # A sample (row) from dataset
 TDatasetSample = TypeVar("TDatasetSample")
@@ -27,20 +46,32 @@ class DictSample(DatasetSample):
     data: Dict[str, Any]


-# A sample (row) from evals intermediate dataset
-TProcessedSample = TypeVar("TProcessedSample")
+# A sample (row) from evals intermediate dataset after preprocessing
+TPreprocessedSample = TypeVar("TPreprocessedSample")


 @json_schema_type
-class PredictionSample(BaseModel):
-    completion_message: str
+class PreprocessedSample(DatasetSample):
+    generation_input: GenerationInput
+
+
+# A sample (row) from evals intermediate dataset after inference
+TGenerationResponseSample = TypeVar("TGenerationResponseSample")


 @json_schema_type
-class ProcessedDictSample(DictSample):
-    preprocessed: Optional[Dict[str, Any]] = None
-    prediction: Optional[PredictionSample] = None
-    postprocessed: Optional[Dict[str, Any]] = None
+class GenerationResponseSample(DatasetSample):
+    generation_output: GenerationOutput
+
+
+# A sample (row) for prepared evals dataset ready for scoring
+TScorerInputSample = TypeVar("TScorerInputSample")
+
+
+@json_schema_type
+class ScorerInputSample(DatasetSample):
+    generation_output: PostprocessedGeneration
+    expected_output: Union[str, List[str]]


 @json_schema_type
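Taken together, the two hunks above define one sample type per pipeline stage. The following is a minimal sketch (not part of the commit) of how a single row would move through them; the import path and the concrete field values are illustrative assumptions.

# Sketch only: one MMLU-style row at each stage of the new eval pipeline.
# Import path is an assumption; field values are made up for illustration.
from llama_stack.apis.dataset.dataset import (
    DictSample,
    GenerationInput,
    GenerationOutput,
    GenerationResponseSample,
    PostprocessedGeneration,
    PreprocessedSample,
    ScorerInputSample,
)

# 1) raw dataset row
row = DictSample(data={"Question": "2 + 2 = ?", "Answer": "B"})

# 2) after preprocessing: a prompt ready for inference
pre = PreprocessedSample(
    generation_input=GenerationInput(
        messages=[{"role": "user", "content": "2 + 2 = ?"}]
    )
)

# 3) after inference: the raw completion
gen = GenerationResponseSample(
    generation_output=GenerationOutput(completion_message="Answer: B")
)

# 4) after postprocessing: what a scorer consumes
scorer_input = ScorerInputSample(
    generation_output=PostprocessedGeneration(
        completion_message="Answer: B",
        transformed_generation="B",
    ),
    expected_output="B",
)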
@@ -33,7 +33,7 @@ class EvaluationClient(Evals):
     ) -> EvaluateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/evals/run",
+                f"{self.base_url}/evals/run_eval_task",
                 json={
                     "model": model,
                     "task": task,

@@ -55,28 +55,25 @@ async def run_main(host: str, port: int):
     client = EvaluationClient(f"http://{host}:{port}")

     # Custom Eval Task
+    response = await client.run_evals(
+        model="Llama3.1-8B-Instruct",
+        dataset="mmlu-simple-eval-en",
+        task="mmlu",
+    )
+
+    # Eleuther Eval Task
     # response = await client.run_evals(
     #     model="Llama3.1-8B-Instruct",
-    #     dataset="mmlu-simple-eval-en",
-    #     task="mmlu",
+    #     # task="meta_mmlu_pro_instruct",
+    #     task="meta_ifeval",
     #     eval_task_config=EvaluateTaskConfig(
     #         n_samples=2,
     #     ),
     # )
-
-    # Eleuther Eval Task
-    response = await client.run_evals(
-        model="Llama3.1-8B-Instruct",
-        # task="meta_mmlu_pro_instruct",
-        task="meta_ifeval",
-        eval_task_config=EvaluateTaskConfig(
-            n_samples=2,
-        ),
-    )
     if response.formatted_report:
         cprint(response.formatted_report, "green")
     else:
-        cprint(f"evaluate response={response}", "green")
+        cprint(f"Response: {response}", "green")


 def main(host: str, port: int):
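For reference, a sketch of hitting the renamed route directly with httpx rather than through EvaluationClient. The host, port, and the presence of a "dataset" field in the payload are assumptions; only "model" and "task" are visible in the truncated json dict above.

# Sketch only: call the renamed route without the EvaluationClient wrapper.
# Host/port are placeholders; the "dataset" field is assumed.
import asyncio

import httpx


async def run_once() -> None:
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://localhost:5000/evals/run_eval_task",
            json={
                "model": "Llama3.1-8B-Instruct",
                "task": "mmlu",
                "dataset": "mmlu-simple-eval-en",
            },
            timeout=3600,
        )
        print(response.json())


if __name__ == "__main__":
    asyncio.run(run_once())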
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from abc import ABC, abstractmethod
-from typing import Dict, Generic, List, Protocol
+from typing import Dict, Generic, List, Optional, Protocol

 from llama_models.schema_utils import webmethod
 from pydantic import BaseModel

@@ -24,14 +24,14 @@ class EvaluationJobLogStream(BaseModel):

 @json_schema_type
 class EvalResult(BaseModel):
-    """Evaluation result."""
+    """Aggregated final evaluation result."""

-    metrics: Dict[str, str]
+    metrics: Dict[str, float]


 @json_schema_type
 class SingleEvalResult(BaseModel):
-    """Single evaluation result."""
+    """Single evaluation result. Contains a scorer name, and corresponding metrics from scorer."""

     score_data: Dict[str, float]
@@ -64,57 +64,222 @@ class EvaluationJobCreateResponse(BaseModel):


 @json_schema_type
-class EvaluateTaskConfig(BaseModel):
-    # num examples to evaluate, evaluate all if None
-    n_samples: Optional[int] = None
-    # model evaluation params
+class EvaluateDatasetConfig(BaseModel):
+    # identifier to previously registered dataset via DatasetDef
+    dataset_name: str
+    # limit number of rows to evaluate
+    row_limit: Optional[int] = None
+    kwargs: Optional[Dict[str, Any]] = None
+
+
+@json_schema_type
+class EvaluatePreprocessConfig(BaseModel):
+    kwargs: Optional[Dict[str, Any]] = None
+
+
+@json_schema_type
+class EvaluateModelGenerationConfig(BaseModel):
+    model: str
     sampling_params: SamplingParams = SamplingParams()
+    kwargs: Optional[Dict[str, Any]] = None
+
+
+@json_schema_type
+class EvaluatePostprocessConfig(BaseModel):
+    kwargs: Optional[Dict[str, Any]] = None
+
+
+@json_schema_type
+class EvaluateJudgeScoringConfig(BaseModel): ...
+
+
+@json_schema_type
+class LLMJudgeConfig(BaseModel):
+    judge_preprocess_config: EvaluatePreprocessConfig
+    judge_model_generation_config: EvaluateModelGenerationConfig
+    judge_postprocess_config: EvaluatePostprocessConfig
+    judge_scoring_config: EvaluateJudgeScoringConfig
+
+
+@json_schema_type
+class EvaluateSingleScorerConfig(BaseModel):
+    scorer_name: str
+    llm_judge_config: Optional[LLMJudgeConfig] = None
+
+
+@json_schema_type
+class EvaluateScoringConfig(BaseModel):
+    # list of scorer (metrics) names to use
+    scorer_config_list: List[EvaluateSingleScorerConfig]
+
+
+@json_schema_type
+class EvaluateTaskConfig(BaseModel):
+    dataset_config: EvaluateDatasetConfig
+    preprocess_config: Optional[EvaluatePreprocessConfig] = None
+    generation_config: EvaluateModelGenerationConfig
+    postprocess_config: Optional[EvaluatePostprocessConfig] = None
+    scoring_config: EvaluateScoringConfig


-class BaseTask(ABC, Generic[TDatasetSample, TProcessedSample]):
+class BaseGeneratorProcessor(
+    ABC,
+    Generic[
+        TDatasetSample,
+        TPreprocessedSample,
+        TGenerationResponseSample,
+        TScorerInputSample,
+    ],
+):
     """
-    A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods.
-    Base class for all evaluation tasks. Each task needs to implement the following methods:
-    - F1: preprocess_sample(self)
+    Base class for all generator processors. Each processor needs to implement the following methods:
+    - F1: preprocess_sample(self, dataset)
     - F2: postprocess_sample(self)
-    - F3: score_sample(self)
     """

     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self._name = self.__class__.__name__
+
+    def __str__(self) -> str:
+        return self.__class__.__name__
+
+    def preprocess(
+        self, dataset: BaseDataset[TDatasetSample]
+    ) -> List[TPreprocessedSample]:
+        return [self.preprocess_sample(sample) for sample in dataset]
+
+    def postprocess(
+        self,
+        generation: List[TGenerationResponseSample],
+        dataset: BaseDataset[TDatasetSample],
+    ) -> List[TScorerInputSample]:
+        return [
+            self.postprocess_sample(generation_sample, dataset_sample)
+            for generation_sample, dataset_sample in zip(generation, dataset)
+        ]

     @abstractmethod
-    def preprocess_sample(self, sample: TDatasetSample) -> TProcessedSample:
+    def preprocess_sample(self, sample: TDatasetSample) -> TPreprocessedSample:
         raise NotImplementedError()

     @abstractmethod
-    def postprocess_sample(self, sample: TProcessedSample) -> TProcessedSample:
+    def postprocess_sample(
+        self,
+        generation_sample: TGenerationResponseSample,
+        dataset_sample: TDatasetSample,
+    ) -> TScorerInputSample:
         raise NotImplementedError()

+
+class BaseGenerator(ABC, Generic[TGenerationResponseSample]):
+    """
+    Base class for all generators. Each generator needs to implement the following methods:
+    - generate(self, preprocessed_dataset)
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def __str__(self) -> str:
+        return self.__class__.__name__
+
     @abstractmethod
-    def score_sample(self, sample: TProcessedSample) -> SingleEvalResult:
-        raise NotImplementedError()
+    def generate(
+        self, preprocessed_dataset: List[TPreprocessedSample]
+    ) -> List[TGenerationResponseSample]:
+        raise NotImplementedError()
+
+
+class BaseScorer(ABC, Generic[TScorerInputSample]):
+    """
+    Base class for all scorers. Each scorer needs to implement the following methods:
+    - score_sample(self, scorer_input_sample)
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def __str__(self) -> str:
+        return self.__class__.__name__
+
+    @abstractmethod
+    def score_sample(self, scorer_input_sample: TScorerInputSample) -> SingleEvalResult:
+        raise NotImplementedError()

     @abstractmethod
     def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
         raise NotImplementedError()

-    def preprocess(
-        self, dataset: BaseDataset[TProcessedSample]
-    ) -> List[TProcessedSample]:
-        return [self.preprocess_sample(sample) for sample in dataset]
-
-    def postprocess(self, generation: List[TProcessedSample]) -> List[TProcessedSample]:
-        return [self.postprocess_sample(sample) for sample in generation]
-
-    def score(self, postprocessed: List[TProcessedSample]) -> List[SingleEvalResult]:
-        return [self.score_sample(sample) for sample in postprocessed]
+    def score(
+        self, prepared_eval_dataset: List[TScorerInputSample]
+    ) -> List[SingleEvalResult]:
+        return [self.score_sample(sample) for sample in prepared_eval_dataset]
+
+
+class BaseTask(ABC):
+    def __init__(
+        self,
+        generator_processor: Optional[BaseGeneratorProcessor] = None,
+        generator: Optional[BaseGenerator] = None,
+        scorer: Optional[BaseScorer] = None,
+        *args,
+        **kwargs
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.generator_processor = generator_processor
+        self.generator = generator
+        self.scorer = scorer
+
+    @abstractmethod
+    def run(self, *args, **kwargs) -> EvalResult:
+        raise NotImplementedError()

[The new version also keeps the previous BaseTask(ABC, Generic[TDatasetSample, TProcessedSample]) implementation at this point as a commented-out block.]

 class Evals(Protocol):
-    @webmethod(route="/evals/run")
-    async def run_evals(
+    @webmethod(route="/evals/run_eval_task")
+    async def run_eval_task(
         self,
         model: str,
         task: str,

@@ -122,6 +287,13 @@ class Evals(Protocol):
         eval_task_config: Optional[EvaluateTaskConfig] = None,
     ) -> EvaluateResponse: ...

+    @webmethod(route="/evals/run_scorer")
+    async def run_scorer(
+        self,
+        dataset_config: EvaluateDatasetConfig,
+        eval_scoring_config: EvaluateScoringConfig,
+    ) -> EvaluateResponse: ...
+
     # @webmethod(route="/evals/jobs")
     # def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
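A short sketch (not part of the commit) of how the new config and scorer abstractions compose: it assembles an EvaluateTaskConfig and subclasses BaseScorer. The scorer name "exact_match" is a made-up value for illustration; the classes themselves come from llama_stack.apis.evals as shown in the hunks above.

# Sketch only: assemble the new task config and define a custom scorer.
from typing import List

from llama_stack.apis.evals import (
    BaseScorer,
    EvalResult,
    EvaluateDatasetConfig,
    EvaluateModelGenerationConfig,
    EvaluateScoringConfig,
    EvaluateSingleScorerConfig,
    EvaluateTaskConfig,
    SingleEvalResult,
)

task_config = EvaluateTaskConfig(
    dataset_config=EvaluateDatasetConfig(
        dataset_name="mmlu-simple-eval-en",
        row_limit=8,
    ),
    generation_config=EvaluateModelGenerationConfig(model="Llama3.1-8B-Instruct"),
    scoring_config=EvaluateScoringConfig(
        scorer_config_list=[EvaluateSingleScorerConfig(scorer_name="exact_match")]
    ),
)


class ExactMatchScorer(BaseScorer):
    """Scores 1.0 when the transformed generation equals the expected output."""

    def score_sample(self, scorer_input_sample) -> SingleEvalResult:
        hit = (
            scorer_input_sample.generation_output.transformed_generation
            == scorer_input_sample.expected_output
        )
        return SingleEvalResult(score_data={"exact_match": 1.0 if hit else 0.0})

    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
        scores = [r.score_data["exact_match"] for r in eval_results]
        return EvalResult(metrics={"exact_match": sum(scores) / len(scores)})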
@@ -25,23 +25,27 @@ class CustomDataset(BaseDataset[DictSample]):
         self.load()
         return (DictSample(data=x) for x in self.dataset)

-    def __str__(self):
+    def __str__(self) -> str:
         return f"CustomDataset({self.config})"

-    def __len__(self):
+    def __len__(self) -> int:
         if not self.dataset:
             self.load()
         return len(self.dataset)

-    def load(self):
+    def load(self, n_samples: Optional[int] = None) -> None:
         if self.dataset:
             return

         # TODO: better support w/ data url
         if self.config.url.endswith(".csv"):
             df = pandas.read_csv(self.config.url)
         elif self.config.url.endswith(".xlsx"):
             df = pandas.read_excel(self.config.url)

+        if n_samples is not None:
+            df = df.sample(n=n_samples)
+
         self.dataset = Dataset.from_pandas(df)
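A usage sketch for the new load(n_samples=...) parameter, assuming "mmlu-simple-eval-en" has already been registered with the DatasetRegistry as in the example script earlier in this diff.

# Sketch only: load a row-limited slice of a registered dataset and iterate it.
from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry

dataset = DatasetRegistry.get_dataset("mmlu-simple-eval-en")
dataset.load(n_samples=4)  # df.sample(n=4) under the hood, so rows are drawn at random
print(f"loaded {len(dataset)} rows")

for sample in dataset:  # yields DictSample objects
    print(sorted(sample.data.keys()))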
llama_stack/distribution/registry/scorers/__init__.py (new file, +6)
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+# TODO: make these import config based
llama_stack/distribution/registry/scorers/scorer_registry.py (new file, +32)
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import AbstractSet, Dict
+
+from llama_stack.apis.evals import BaseScorer
+
+
+class ScorerRegistry:
+    _REGISTRY: Dict[str, BaseScorer] = {}
+
+    @staticmethod
+    def names() -> AbstractSet[str]:
+        return ScorerRegistry._REGISTRY.keys()
+
+    @staticmethod
+    def register(name: str, scorer: BaseScorer) -> None:
+        if name in ScorerRegistry._REGISTRY:
+            raise ValueError(f"Scorer {name} already exists.")
+        ScorerRegistry._REGISTRY[name] = scorer
+
+    @staticmethod
+    def get_scorer(name: str) -> BaseScorer:
+        if name not in ScorerRegistry._REGISTRY:
+            raise ValueError(f"Scorer {name} not found.")
+        return ScorerRegistry._REGISTRY[name]
+
+    @staticmethod
+    def reset() -> None:
+        ScorerRegistry._REGISTRY = {}
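A sketch of the intended ScorerRegistry usage, mirroring how the removed TaskRegistry registered task classes by name; registering the class rather than an instance is an assumption based on that pattern.

# Sketch only: register and resolve a scorer by name.
from llama_stack.distribution.registry.scorers.scorer_registry import ScorerRegistry
from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import (
    AccuracyScorer,
)

ScorerRegistry.register("accuracy", AccuracyScorer)  # register the class, as TaskRegistry did
scorer = ScorerRegistry.get_scorer("accuracy")()     # instantiate when running an eval
print(ScorerRegistry.names())                        # dict_keys(['accuracy'])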
@@ -3,11 +3,3 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-# TODO: make these import config based
-from llama_stack.providers.impls.meta_reference.evals.tasks.mmlu_task import MMLUTask
-
-from .task_registry import TaskRegistry
-
-TaskRegistry.register(
-    "mmlu",
-    MMLUTask,
-)
@@ -3,16 +3,27 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
+
+from termcolor import cprint
+
+from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import (
+    AggregateScorer,
+)
+
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403
 from llama_stack.apis.dataset import *  # noqa: F403

-from termcolor import cprint
-
 from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
+from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
+    MMLUProcessor,
+)
+
+# from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
+# from .tasks.run_eval_task import RunEvalTask
+from .scorer.basic_scorers import *  # noqa: F403

-from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry

 from .config import MetaReferenceEvalsImplConfig

@@ -27,7 +38,7 @@ class MetaReferenceEvalsImpl(Evals):
     async def shutdown(self) -> None:
         pass

-    async def run_evals(
+    async def run_eval_task(
         self,
         model: str,
         task: str,

@@ -38,43 +49,142 @@ class MetaReferenceEvalsImpl(Evals):
             f"model={model}, dataset={dataset}, task={task}, eval_task_config={eval_task_config}",
             "red",
         )

         if not dataset:
             raise ValueError("dataset must be specified for meta-reference evals")

-        dataset = DatasetRegistry.get_dataset(dataset)
-        dataset.load()
-
-        task_impl = TaskRegistry.get_task(task)()
-        preprocessed = task_impl.preprocess(dataset)
-
-        # TODO: replace w/ batch inference & async return eval job
-        generation_outputs = []
-        if eval_task_config is None:
-            eval_task_config = EvaluateTaskConfig(n_samples=len(preprocessed))
-        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(
-            preprocessed
-        ):
-            eval_task_config.n_samples = len(preprocessed)
-
-        print(
-            f"Eval generation start, generate on {eval_task_config.n_samples} samples"
-        )
-
-        for sample in preprocessed[: eval_task_config.n_samples]:
+        if not eval_task_config:
+            # construct eval task config from inputs
+            eval_task_config = EvaluateTaskConfig(
+                dataset_config=EvaluateDatasetConfig(
+                    dataset_name=dataset,
+                    row_limit=2,
+                ),
+                generation_config=EvaluateModelGenerationConfig(
+                    model=model,
+                ),
+                scoring_config=EvaluateScoringConfig(
+                    scorer_config_list=[
+                        EvaluateSingleScorerConfig(scorer_name="accuracy"),
+                    ]
+                ),
+            )
+
+        # TODO: wrap inside task
+        # run_task = RunEvalTask(
+        #     eval_task_config=eval_task_config,
+        # )
+        # eval_result = run_task.run()
+
+        dataset = DatasetRegistry.get_dataset(
+            eval_task_config.dataset_config.dataset_name
+        )
+        dataset.load(n_samples=eval_task_config.dataset_config.row_limit)
+        print(f"Running on {len(dataset)} samples")
+
+        # F1
+        processor = MMLUProcessor()
+        preprocessed = processor.preprocess(dataset)
+
+        # Generation
+        # TODO: wrap inside BaseGenerator
+        generation_outputs = []
+        for sample in preprocessed:
             print("generation: ", sample)
             response = await self.inference_api.chat_completion(
                 model=model,
-                messages=sample.preprocessed["messages"],
+                messages=sample.generation_input.messages,
                 stream=False,
             )
-            sample.prediction = PredictionSample(
-                completion_message=response.completion_message.content
-            )
-            generation_outputs.append(sample)
-
-        postprocessed = task_impl.postprocess(generation_outputs)
-        eval_results = task_impl.score(postprocessed)
-        aggr_result = task_impl.aggregate_results(eval_results)
-        return EvaluateResponse(
-            eval_result=aggr_result,
-        )
+            cprint(f"response: {response}", "cyan")
+
+            generation_outputs.append(
+                GenerationResponseSample(
+                    generation_output=GenerationOutput(
+                        completion_message=response.completion_message.content
+                    )
+                )
+            )
+        cprint(generation_outputs, "green")
+
+        # F2
+        postprocessed = processor.postprocess(generation_outputs, dataset)
+        cprint(postprocessed, "blue")
+
+        # F3 - scorer
+        scorer = AggregateScorer(
+            scorers=[
+                AccuracyScorer(),
+                RandomScorer(),
+            ]
+        )
+
+        scorer_results = scorer.score(postprocessed)
+        cprint(scorer_results, "magenta")
+        eval_result = scorer.aggregate_results(scorer_results)
+
+        return EvaluateResponse(
+            eval_result=eval_result,
+            formatted_report=json.dumps(eval_result.json(), indent=4),
+        )
+
+    async def run_scorer(
+        self,
+        dataset_config: EvaluateDatasetConfig,
+        eval_scoring_config: EvaluateScoringConfig,
+    ) -> EvaluateResponse:
+        return EvaluateResponse(
+            eval_result={},
+        )

[The new version also keeps the previous run_evals implementation at the end of this hunk as a commented-out block.]
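run_scorer is stubbed out in this commit. Below is a hedged sketch of one way the body could be filled in using the pieces added here (ScorerRegistry, AggregateScorer, the dataset and scoring configs); it is not the commit's implementation, and it assumes the referenced dataset already holds scorer-ready rows.

# Sketch only: one possible run_scorer body for MetaReferenceEvalsImpl; NOT the
# commit's implementation. Assumes the dataset rows are already ScorerInputSamples
# and that each requested scorer class is registered under its scorer_name.
async def run_scorer(
    self,
    dataset_config: EvaluateDatasetConfig,
    eval_scoring_config: EvaluateScoringConfig,
) -> EvaluateResponse:
    dataset = DatasetRegistry.get_dataset(dataset_config.dataset_name)
    dataset.load(n_samples=dataset_config.row_limit)

    scorers = [
        ScorerRegistry.get_scorer(cfg.scorer_name)()
        for cfg in eval_scoring_config.scorer_config_list
    ]
    scorer = AggregateScorer(scorers=scorers)

    scorer_results = scorer.score(list(dataset))
    eval_result = scorer.aggregate_results(scorer_results)
    return EvaluateResponse(eval_result=eval_result)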
(new file)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -7,8 +7,6 @@ import re
 from llama_stack.apis.evals import *  # noqa: F403

-# from llama_stack.distribution.registry.tasks.task import BaseTask
-
 QUERY_TEMPLATE_MULTICHOICE = """
 Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.

@@ -112,60 +110,78 @@ def normalize_extracted_answer(extracted_answer: str) -> str:
     )


-class MMLUTask(BaseTask[DictSample, ProcessedDictSample]):
+class MMLUProcessor(
+    BaseGeneratorProcessor[
+        DictSample, PreprocessedSample, GenerationResponseSample, ScorerInputSample
+    ]
+):
     """
-    MMLU Task.
+    Generator processor for MMLU
     """

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-    def preprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+    def preprocess_sample(self, sample: DictSample) -> PreprocessedSample:
         content = QUERY_TEMPLATE_MULTICHOICE.format(**sample.data)
-        preprocessed = {
-            "messages": [
-                {
-                    "role": "user",
-                    "content": content,
-                }
-            ],
-        }
-        processed_sample = ProcessedDictSample(
-            data=sample.data,
-            preprocessed=preprocessed,
+        preprocessed_msgs = [
+            {
+                "role": "user",
+                "content": content,
+            }
+        ]
+        processed_sample = PreprocessedSample(
+            generation_input=GenerationInput(
+                messages=preprocessed_msgs,
+            )
         )
         return processed_sample

-    def postprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
-        if not sample.postprocessed:
-            sample.postprocessed = {}
-        sample.postprocessed["postprocessed"] = normalize_response(
-            sample.prediction.completion_message
-        )
-        return sample
-
-    def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
-        postprocessed_output = sample.postprocessed["postprocessed"]
-        expected_answer = sample.data["Answer"]
-
-        extracted_answer = None
+    def postprocess_sample(
+        self, generation_sample: GenerationResponseSample, dataset_sample: DictSample
+    ) -> ScorerInputSample:
+        response_text = generation_sample.generation_output.completion_message
+        normalized_response = normalize_response(response_text)
+
+        # extract answer
+        extracted_answer = ""
         for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
             regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-            match = re.search(regex, postprocessed_output)
+            match = re.search(regex, normalized_response)
             if match:
                 extracted_answer = normalize_extracted_answer(match.group(1))
                 break

-        score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
-
-        return SingleEvalResult(
-            score_data={
-                "score": score,
-            },
-        )
-
-    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
-        print("aggregate_results", eval_results)
-        sum_score = sum([result.score_data["score"] for result in eval_results])
-
-        return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
+        return ScorerInputSample(
+            generation_output=PostprocessedGeneration(
+                completion_message=response_text,
+                transformed_generation=extracted_answer,
+            ),
+            expected_output=dataset_sample.data["Answer"],
+        )

[The new version also keeps the previous score_sample and aggregate_results implementations at the end of this hunk as a commented-out block.]
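A sketch of exercising MMLUProcessor on one hand-written row, standing in for the F1/F2 steps of the meta-reference impl above. The non-"Answer" column names are assumptions about the MMLU CSV layout, the import path for the sample types is assumed, and the fake completion stands in for a real chat_completion response.

# Sketch only: run the processor on one hand-written row.
from llama_stack.apis.dataset.dataset import (  # path assumed
    DictSample,
    GenerationOutput,
    GenerationResponseSample,
)
from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
    MMLUProcessor,
)

processor = MMLUProcessor()

row = DictSample(
    data={
        "Question": "What is 2 + 2?",
        "A": "3",
        "B": "4",
        "C": "5",
        "D": "6",
        "Answer": "B",
    }
)

pre = processor.preprocess_sample(row)  # builds the multiple-choice prompt
fake_reply = GenerationResponseSample(
    generation_output=GenerationOutput(completion_message="Answer: B")
)
scorer_input = processor.postprocess_sample(fake_reply, row)

print(scorer_input.generation_output.transformed_generation)  # expected: "B"
print(scorer_input.expected_output)                           # "B"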
(new file)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
(new file)
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import random
+
+from llama_stack.apis.evals.evals import BaseScorer, EvalResult, SingleEvalResult
+from llama_stack.apis.dataset.dataset import *  # noqa: F401 F403
+
+
+class AggregateScorer(BaseScorer[ScorerInputSample]):
+    def __init__(self, scorers: List[BaseScorer[ScorerInputSample]]):
+        self.scorers = scorers
+
+    def score_sample(self, scorer_input_sample: ScorerInputSample) -> SingleEvalResult:
+        all_score_data = {}
+        for scorer in self.scorers:
+            score_data = scorer.score_sample(scorer_input_sample).score_data
+            for k, v in score_data.items():
+                all_score_data[k] = v
+
+        return SingleEvalResult(
+            score_data=all_score_data,
+        )
+
+    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
+        all_metrics = {}
+
+        for scorer in self.scorers:
+            metrics = scorer.aggregate_results(eval_results).metrics
+            for k, v in metrics.items():
+                all_metrics[f"{scorer.__class__.__name__}:{k}"] = v
+
+        return EvalResult(
+            metrics=all_metrics,
+        )
+
+
+class RandomScorer(BaseScorer[ScorerInputSample]):
+    def score_sample(self, scorer_input_sample: ScorerInputSample) -> SingleEvalResult:
+        return SingleEvalResult(score_data={"random": random.random()})
+
+    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
+        avg_random = sum(
+            [result.score_data["random"] for result in eval_results]
+        ) / len(eval_results)
+        max_random = max([result.score_data["random"] for result in eval_results])
+        return EvalResult(
+            metrics={
+                "avg_random": avg_random,
+                "max_random": max_random,
+            }
+        )
+
+
+class AccuracyScorer(BaseScorer[ScorerInputSample]):
+    def score_sample(self, scorer_input_sample: ScorerInputSample) -> SingleEvalResult:
+        extracted_answer = scorer_input_sample.generation_output.transformed_generation
+        expected_answer = scorer_input_sample.expected_output
+
+        accuracy = (
+            1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
+        )
+
+        return SingleEvalResult(score_data={"accuracy": accuracy})
+
+    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
+        num_correct = sum([result.score_data["accuracy"] for result in eval_results])
+        num_total = len(eval_results)
+
+        return EvalResult(
+            metrics={
+                "avg_accuracy": num_correct / num_total,
+                "num_correct": num_correct,
+                "num_total": num_total,
+            }
+        )
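A usage sketch for composing these scorers over a couple of hand-built scorer inputs; the import of the sample types from llama_stack.apis.dataset.dataset follows the wildcard import at the top of this file, and the basic_scorers path matches the import used by the meta-reference impl earlier in this diff.

# Sketch only: compose the scorers over two hand-built inputs.
from llama_stack.apis.dataset.dataset import PostprocessedGeneration, ScorerInputSample
from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import (
    AccuracyScorer,
    AggregateScorer,
    RandomScorer,
)

samples = [
    ScorerInputSample(
        generation_output=PostprocessedGeneration(
            completion_message="Answer: B", transformed_generation="B"
        ),
        expected_output="B",
    ),
    ScorerInputSample(
        generation_output=PostprocessedGeneration(
            completion_message="Answer: C", transformed_generation="C"
        ),
        expected_output="A",
    ),
]

scorer = AggregateScorer(scorers=[AccuracyScorer(), RandomScorer()])
results = scorer.score(samples)             # one SingleEvalResult per sample
report = scorer.aggregate_results(results)  # metrics keyed "<ScorerName>:<metric>"
print(report.metrics)  # e.g. {"AccuracyScorer:avg_accuracy": 0.5, ...}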
(new file)
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
+
+from llama_stack.apis.evals import *  # noqa: F403
+
+
+class RunEvalTask(BaseTask):
+    """
+    RunEvalTask for LlamaStack
+    """
+
+    def __init__(
+        self,
+        eval_task_config,
+        generator_processor: Optional[BaseGeneratorProcessor] = None,
+        generator: Optional[BaseGenerator] = None,
+        scorer: Optional[BaseScorer] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            generator_processor=generator_processor,
+            generator=generator,
+            scorer=scorer,
+            *args,
+            **kwargs,
+        )
+        self.eval_task_config = eval_task_config
+        self.dataset = DatasetRegistry.get_dataset(
+            eval_task_config.dataset_config.dataset_name
+        )
+
+    def run(self, *args, **kwargs) -> EvalResult:
+        print(f"Running eval task on {self.dataset}")
+        return EvalResult()
@@ -14,8 +14,8 @@ apis:
 - evals
 providers:
   evals:
-  - provider_id: eleuther
-    provider_type: eleuther
+  - provider_id: meta-reference
+    provider_type: meta-reference
     config: {}
   inference:
   - provider_id: remote::tgi