[Evals API][4/n] evals with generation meta-reference impl (#303)

* wip

* dataset validation

* test_scoring

* cleanup

* clean up test

* comments

* error checking

* dataset client

* test client

* datasetio client

* clean up

* basic scoring function works

* scorer wip

* equality scorer

* score batch impl

* score batch

* update scoring test

* refactor

* validate scorer input

* address comments

* evals with generation

* add all rows scores to ScoringResult

* minor typing

* bugfix

* scoring function def rename

* rebase name

* refactor

* address comments

* Update iOS inference instructions for new quantization

* Small updates to quantization config

* Fix score threshold in faiss

* Bump version to 0.0.45

* Handle both ipv6 and ipv4 interfaces together

* update manifest for build templates

* Update getting_started.md

* chatcompletion & completion input type validation

* inclusion -> subsetof

* error checking

* scoring_function -> scoring_fn rename, scorer -> scoring_fn rename

* address comments

* [Evals API][5/n] fixes to generate openapi spec (#323)

* generate openapi

* typing comment, dataset -> dataset_id

* remove custom type

* sample eval run.yaml

---------

Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Author: Xi Yan, 2024-10-25 13:12:39 -07:00 (committed by GitHub)
Commit: abdf7cddf3 (parent 426d821e7f)
31 changed files with 3371 additions and 1296 deletions


@ -33,14 +33,16 @@ schema_utils.json_schema_type = json_schema_type
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.apis.evals import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.reward_scoring import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
@ -54,14 +56,16 @@ class LlamaStack(
Inference,
BatchInference,
Agents,
RewardScoring,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
Memory,
Evaluations,
Eval,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
Inspect,

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -3,6 +3,8 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@ -10,3 +12,9 @@ from pydantic import BaseModel
@json_schema_type
class Job(BaseModel):
job_id: str
@json_schema_type
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List, Literal, Union
from typing import Literal, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -24,12 +24,10 @@ class BooleanType(BaseModel):
class ArrayType(BaseModel):
type: Literal["array"] = "array"
items: "ParamType"
class ObjectType(BaseModel):
type: Literal["object"] = "object"
properties: Dict[str, "ParamType"] = Field(default_factory=dict)
class JsonType(BaseModel):
@ -38,12 +36,21 @@ class JsonType(BaseModel):
class UnionType(BaseModel):
type: Literal["union"] = "union"
options: List["ParamType"] = Field(default_factory=list)
class CustomType(BaseModel):
type: Literal["custom"] = "custom"
validator_class: str
class ChatCompletionInputType(BaseModel):
# expects List[Message] for messages
type: Literal["chat_completion_input"] = "chat_completion_input"
class CompletionInputType(BaseModel):
# expects InterleavedTextMedia for content
type: Literal["completion_input"] = "completion_input"
class AgentTurnInputType(BaseModel):
# expects List[Message] for messages (may also include attachments?)
type: Literal["agent_turn_input"] = "agent_turn_input"
ParamType = Annotated[
@ -55,11 +62,22 @@ ParamType = Annotated[
ObjectType,
JsonType,
UnionType,
CustomType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
],
Field(discriminator="type"),
]
ArrayType.model_rebuild()
ObjectType.model_rebuild()
UnionType.model_rebuild()
# TODO: recursive definition of ParamType in these containers
# will cause infinite recursion in OpenAPI generation script
# since we are going with ChatCompletionInputType and CompletionInputType
# we don't need to worry about ArrayType/ObjectType/UnionType for now
# ArrayType.model_rebuild()
# ObjectType.model_rebuild()
# UnionType.model_rebuild()
# class CustomType(BaseModel):
# type: Literal["custom"] = "custom"
# validator_class: str
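
For reference, the new input types are what eval dataset schemas use to mark generation inputs. A minimal sketch (illustrative, not part of this commit) of a schema for chat-completion generation, mirroring the check in the eval provider below:

```python
# Illustrative only: declaring an eval dataset schema with the new input types.
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType

dataset_schema = {
    "expected_answer": StringType(),
    # rows under this column carry a serialized List[Message]
    "chat_completion_input": ChatCompletionInputType(),
}
```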


@ -12,7 +12,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.schema_utils import json_schema_type, webmethod
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.scoring import * # noqa: F403
@ -40,7 +40,7 @@ class EvaluateResponse(BaseModel):
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: List[Dict[str, ScoringResult]]
scores: Dict[str, ScoringResult]
class Eval(Protocol):
@ -61,10 +61,10 @@ class Eval(Protocol):
) -> EvaluateResponse: ...
@webmethod(route="/eval/job/status", method="GET")
async def job_status(self, job_id: str) -> None: ...
async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/job/cancel", method="POST")
async def job_cancel(self, job_id: str) -> None: ...
@webmethod(route="/eval/job/result", method="GET")
async def job_result(self, job_id: str) -> None: ...
async def job_result(self, job_id: str) -> EvaluateResponse: ...
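
A minimal sketch of the intended call flow against this protocol, assuming eval_impl is an Eval implementation and the dataset is already registered (identifiers are illustrative); the meta-reference impl below runs jobs synchronously, so status is available right after evaluate_batch returns:

```python
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.eval.eval import ModelCandidate
from llama_models.llama3.api import SamplingParams


async def run_eval(eval_impl, dataset_id: str) -> None:
    # Submit a batch eval job for a model candidate against a registered dataset.
    job = await eval_impl.evaluate_batch(
        dataset_id=dataset_id,
        candidate=ModelCandidate(
            model="Llama3.2-1B-Instruct",
            sampling_params=SamplingParams(),
        ),
        scoring_functions=["subset_of"],
    )
    # job_status returns Optional[JobStatus]; job_result returns EvaluateResponse.
    if await eval_impl.job_status(job.job_id) == JobStatus.completed:
        result = await eval_impl.job_result(job.job_id)
        print(result.scores["subset_of"].aggregated_results)
```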


@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
@ -107,8 +107,8 @@ class PostTrainingSFTRequest(BaseModel):
job_uuid: str
model: str
dataset: TrainEvalDataset
validation_dataset: TrainEvalDataset
dataset_id: str
validation_dataset_id: str
algorithm: FinetuningAlgorithm
algorithm_config: Union[
@ -131,8 +131,8 @@ class PostTrainingRLHFRequest(BaseModel):
finetuned_model: URL
dataset: TrainEvalDataset
validation_dataset: TrainEvalDataset
dataset_id: str
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: Union[DPOAlignmentConfig]
@ -181,8 +181,8 @@ class PostTraining(Protocol):
self,
job_uuid: str,
model: str,
dataset: TrainEvalDataset,
validation_dataset: TrainEvalDataset,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
@ -198,8 +198,8 @@ class PostTraining(Protocol):
self,
job_uuid: str,
finetuned_model: URL,
dataset: TrainEvalDataset,
validation_dataset: TrainEvalDataset,
dataset_id: str,
validation_dataset_id: str,
algorithm: RLHFAlgorithm,
algorithm_config: Union[DPOAlignmentConfig],
optimizer_config: OptimizerConfig,


@ -37,7 +37,7 @@ class ScoreResponse(BaseModel):
class ScoringFunctionStore(Protocol):
def get_scoring_function(self, name: str) -> ScoringFunctionDefWithProvider: ...
def get_scoring_function(self, name: str) -> ScoringFnDefWithProvider: ...
@runtime_checkable


@ -29,7 +29,7 @@ class LLMAsJudgeContext(BaseModel):
@json_schema_type
class ScoringFunctionDef(BaseModel):
class ScoringFnDef(BaseModel):
identifier: str
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
@ -48,7 +48,7 @@ class ScoringFunctionDef(BaseModel):
@json_schema_type
class ScoringFunctionDefWithProvider(ScoringFunctionDef):
class ScoringFnDefWithProvider(ScoringFnDef):
provider_id: str = Field(
description="ID of the provider which serves this dataset",
)
@ -57,14 +57,14 @@ class ScoringFunctionDefWithProvider(ScoringFunctionDef):
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring_functions/list", method="GET")
async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: ...
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: ...
@webmethod(route="/scoring_functions/get", method="GET")
async def get_scoring_function(
self, name: str
) -> Optional[ScoringFunctionDefWithProvider]: ...
) -> Optional[ScoringFnDefWithProvider]: ...
@webmethod(route="/scoring_functions/register", method="POST")
async def register_scoring_function(
self, function_def: ScoringFunctionDefWithProvider
self, function_def: ScoringFnDefWithProvider
) -> None: ...


@ -13,7 +13,6 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.reward_scoring import * # noqa: F403
class FilteringFunction(Enum):
@ -40,7 +39,7 @@ class SyntheticDataGenerationRequest(BaseModel):
class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
synthetic_data: List[ScoredDialogGenerations]
synthetic_data: List[Dict[str, Any]]
statistics: Optional[Dict[str, Any]] = None


@ -34,7 +34,7 @@ RoutableObject = Union[
ShieldDef,
MemoryBankDef,
DatasetDef,
ScoringFunctionDef,
ScoringFnDef,
]
RoutableObjectWithProvider = Union[
@ -42,7 +42,7 @@ RoutableObjectWithProvider = Union[
ShieldDefWithProvider,
MemoryBankDefWithProvider,
DatasetDefWithProvider,
ScoringFunctionDefWithProvider,
ScoringFnDefWithProvider,
]
RoutedProtocol = Union[


@ -14,6 +14,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
@ -46,6 +47,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.datasetio: DatasetIO,
Api.scoring_functions: ScoringFunctions,
Api.scoring: Scoring,
Api.eval: Eval,
}


@ -100,7 +100,7 @@ class CommonRoutingTableImpl(RoutingTable):
scoring_functions = await p.list_scoring_functions()
add_objects(
[
ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid)
ScoringFnDefWithProvider(**s.dict(), provider_id=pid)
for s in scoring_functions
]
)
@ -239,7 +239,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]:
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
@ -247,10 +247,10 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
async def get_scoring_function(
self, name: str
) -> Optional[ScoringFunctionDefWithProvider]:
) -> Optional[ScoringFnDefWithProvider]:
return self.get_object_by_identifier(name)
async def register_scoring_function(
self, function_def: ScoringFunctionDefWithProvider
self, function_def: ScoringFnDefWithProvider
) -> None:
await self.register_object(function_def)


@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetDef
from llama_stack.apis.memory_banks import MemoryBankDef
from llama_stack.apis.models import ModelDef
from llama_stack.apis.scoring_functions import ScoringFunctionDef
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.shields import ShieldDef
@ -25,6 +25,7 @@ class Api(Enum):
memory = "memory"
datasetio = "datasetio"
scoring = "scoring"
eval = "eval"
telemetry = "telemetry"
@ -63,11 +64,9 @@ class DatasetsProtocolPrivate(Protocol):
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFunctionDef]: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]: ...
async def register_scoring_function(
self, function_def: ScoringFunctionDef
) -> None: ...
async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
@json_schema_type


@ -143,11 +143,12 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
else:
next_page_token = int(page_token)
if rows_in_page == -1:
rows = dataset_info.dataset_impl[next_page_token:]
start = next_page_token
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
if rows_in_page == -1:
end = len(dataset_info.dataset_impl)
else:
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
rows = dataset_info.dataset_impl[start:end]
return PaginatedRowsResult(
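
A minimal sketch of how this fix is consumed (assuming datasetio_api is a DatasetIO implementation): the eval provider fetches every row at once by passing rows_in_page=-1.

```python
async def fetch_all_rows(datasetio_api, dataset_id: str):
    # With the branch above, rows_in_page=-1 now means "all remaining rows".
    page = await datasetio_api.get_rows_paginated(
        dataset_id=dataset_id,
        rows_in_page=-1,
    )
    return page.rows  # List[Dict[str, Any]]
```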


@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceEvalConfig
async def get_provider_impl(
config: MetaReferenceEvalConfig,
deps: Dict[Api, ProviderSpec],
):
from .eval import MetaReferenceEvalImpl
impl = MetaReferenceEvalImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
deps[Api.scoring],
deps[Api.inference],
)
await impl.initialize()
return impl


@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.eval import * # noqa: F401, F403
class MetaReferenceEvalConfig(BaseModel): ...


@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
from .config import MetaReferenceEvalConfig
class ColumnName(Enum):
expected_answer = "expected_answer"
chat_completion_input = "chat_completion_input"
completion_input = "completion_input"
generated_answer = "generated_answer"
class MetaReferenceEvalImpl(Eval):
def __init__(
self,
config: MetaReferenceEvalConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
scoring_api: Scoring,
inference_api: Inference,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_api = scoring_api
self.inference_api = inference_api
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
expected_schemas = [
{
ColumnName.expected_answer.value: StringType(),
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
},
{
ColumnName.expected_answer.value: StringType(),
ColumnName.completion_input.value: CompletionInputType(),
},
]
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def evaluate_batch(
self,
dataset_id: str,
candidate: EvalCandidate,
scoring_functions: List[str],
) -> Job:
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.evaluate(
input_rows=all_rows.rows,
candidate=candidate,
scoring_functions=scoring_functions,
)
# TODO: currently needs to wait for generation before returning
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id)
async def evaluate(
self,
input_rows: List[Dict[str, Any]],
candidate: EvalCandidate,
scoring_functions: List[str],
) -> EvaluateResponse:
if candidate.type == "agent":
raise NotImplementedError(
"Evaluation with generation has not been implemented for agents"
)
assert (
candidate.sampling_params.max_tokens is not None
), "SamplingParams.max_tokens must be provided"
generations = []
for x in input_rows:
if ColumnName.completion_input.value in x:
input_content = eval(str(x[ColumnName.completion_input.value]))
response = await self.inference_api.completion(
model=candidate.model,
content=input_content,
sampling_params=candidate.sampling_params,
)
generations.append(
{
ColumnName.generated_answer.value: response.completion_message.content
}
)
elif ColumnName.chat_completion_input.value in x:
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
input_messages = [UserMessage(**x) for x in input_messages]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages += input_messages
response = await self.inference_api.chat_completion(
model=candidate.model,
messages=messages,
sampling_params=candidate.sampling_params,
)
generations.append(
{
ColumnName.generated_answer.value: response.completion_message.content
}
)
else:
raise ValueError("Invalid input row")
# scoring with generated_answer
score_input_rows = [
input_r | generated_r
for input_r, generated_r in zip(input_rows, generations)
]
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, job_id: str) -> Optional[JobStatus]:
if job_id in self.jobs:
return JobStatus.completed
return None
async def job_cancel(self, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, job_id: str) -> EvaluateResponse:
status = await self.job_status(job_id)
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")
return self.jobs[job_id]
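
An illustrative sketch of calling evaluate() directly on in-memory rows, assuming eval_impl has been constructed as in the provider __init__.py above (the row values are made up but follow the test dataset's CSV format):

```python
from llama_stack.apis.eval.eval import ModelCandidate
from llama_models.llama3.api import SamplingParams


async def evaluate_one(eval_impl):
    rows = [
        {
            "expected_answer": "Paris",
            # chat_completion_input holds a serialized list of messages,
            # matching the CSV format of the test dataset below.
            "chat_completion_input": "[{'role': 'user', 'content': 'What is the capital of France?'}]",
        }
    ]
    response = await eval_impl.evaluate(
        input_rows=rows,
        candidate=ModelCandidate(
            model="Llama3.2-1B-Instruct",
            sampling_params=SamplingParams(max_tokens=512),
        ),
        scoring_functions=["subset_of"],
    )
    # response.generations -> [{"generated_answer": ...}]
    # response.scores -> {"subset_of": ScoringResult(...)}
    return response
```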


@ -13,17 +13,22 @@ from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
EqualityScorer,
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
EqualityScoringFn,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
SubsetOfScoringFn,
)
from .config import MetaReferenceScoringConfig
SUPPORTED_SCORERS = [
EqualityScorer,
SUPPORTED_SCORING_FNS = [
EqualityScoringFn,
SubsetOfScoringFn,
]
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS}
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
@ -41,10 +46,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
return [x.scoring_function_def for x in SUPPORTED_SCORERS]
async def list_scoring_functions(self) -> List[ScoringFnDef]:
return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS]
async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None:
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
raise NotImplementedError(
"Dynamically registering scoring functions is not supported"
)
@ -96,9 +101,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in SCORER_REGISTRY:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scorer = SCORER_REGISTRY[scoring_fn_id]()
score_results = scorer.score(input_rows)
agg_results = scorer.aggregate(score_results)
scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
score_results = scoring_fn.score(input_rows)
agg_results = scoring_fn.aggregate(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
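
A small sketch of the scoring entry point this registry feeds, assuming scoring_api is the routed Scoring API or a MetaReferenceScoringImpl instance:

```python
async def score_rows(scoring_api, rows):
    # Each requested function id is dispatched through SCORER_REGISTRY.
    response = await scoring_api.score(
        input_rows=rows,
        scoring_functions=["equality", "subset_of"],
    )
    # response.results maps each scoring function id to a ScoringResult
    # with per-row score_rows and aggregated_results (e.g. accuracy).
    return response.results
```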


@ -9,15 +9,15 @@ from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
class BaseScorer(ABC):
class BaseScoringFn(ABC):
"""
Base interface class for all meta-reference scorers.
Each scorer needs to implement the following methods:
Base interface class for all meta-reference scoring_fns.
Each scoring_fn needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scorer_results)
- aggregate(self, scoring_fn_results)
"""
scoring_function_def: ScoringFunctionDef
scoring_function_def: ScoringFnDef
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)


@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_stack.apis.scoring import ScoringResultRow
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
num_correct = sum(result["score"] for result in scoring_results)
avg_score = num_correct / len(scoring_results)
return {
"accuracy": avg_score,
"num_correct": num_correct,
"num_total": len(scoring_results),
}
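
A tiny illustrative check of the shared aggregation helper (values are made up):

```python
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
    aggregate_accuracy,
)

rows = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}]
print(aggregate_accuracy(rows))
# {'accuracy': 0.666..., 'num_correct': 2.0, 'num_total': 3}
```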


@ -4,20 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
BaseScorer,
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
BaseScoringFn,
)
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_accuracy,
)
class EqualityScorer(BaseScorer):
class EqualityScoringFn(BaseScoringFn):
"""
A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
"""
scoring_function_def = ScoringFunctionDef(
scoring_function_def = ScoringFnDef(
identifier="equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
parameters=[],
@ -38,12 +41,4 @@ class EqualityScorer(BaseScorer):
}
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
assert len(scoring_results) > 0, "Empty scoring results provided."
num_correct = sum(result["score"] for result in scoring_results)
avg_score = num_correct / len(scoring_results)
return {
"accuracy": avg_score,
"num_correct": num_correct,
"num_total": len(scoring_results),
}
return aggregate_accuracy(scoring_results)


@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
BaseScoringFn,
)
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_accuracy,
)
class SubsetOfScoringFn(BaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
"""
scoring_function_def = ScoringFnDef(
identifier="subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
)
def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert (
"generated_answer" in input_row
), "Generated answer not found in input row."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
score = 1.0 if expected_answer in generated_answer else 0.0
return {
"score": score,
}
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)
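
An illustrative spot-check of the new subset_of scoring function (the row values are made up):

```python
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
    SubsetOfScoringFn,
)

fn = SubsetOfScoringFn()
row = {
    "expected_answer": "Paris",
    "generated_answer": "The capital of France is Paris.",
}
print(fn.score_row(row))                  # {'score': 1.0}
print(fn.aggregate([fn.score_row(row)]))  # {'accuracy': 1.0, 'num_correct': 1.0, 'num_total': 1}
```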


@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.eval,
provider_type="meta-reference",
pip_packages=[],
module="llama_stack.providers.impls.meta_reference.eval",
config_class="llama_stack.providers.impls.meta_reference.eval.MetaReferenceEvalConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.scoring,
Api.inference,
],
),
]


@ -1,6 +1,6 @@
input_query,generated_answer,expected_answer
What is the capital of France?,London,Paris
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
What is the largest planet in our solar system?,Jupiter,Jupiter
What is the smallest country in the world?,China,Vatican City
What is the currency of Japan?,Yen,Yen
input_query,generated_answer,expected_answer,chat_completion_input
What is the capital of France?,London,Paris,"[{'role': 'user', 'content': 'What is the capital of France?'}]"
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg,"[{'role': 'user', 'content': 'Who is the CEO of Meta?'}]"
What is the largest planet in our solar system?,Jupiter,Jupiter,"[{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]"
What is the smallest country in the world?,China,Vatican City,"[{'role': 'user', 'content': 'What is the smallest country in the world?'}]"
What is the currency of Japan?,Yen,Yen,"[{'role': 'user', 'content': 'What is the currency of Japan?'}]"



@ -61,20 +61,31 @@ def data_url_from_file(file_path: str) -> str:
return data_url
async def register_dataset(datasets_impl: Datasets):
async def register_dataset(
datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
if for_generation:
dataset_schema = {
"expected_answer": StringType(),
"chat_completion_input": ChatCompletionInputType(),
}
else:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"generated_answer": StringType(),
}
dataset = DatasetDefWithProvider(
identifier="test_dataset",
identifier=dataset_id,
provider_id=os.environ["PROVIDER_ID"],
url=URL(
uri=test_url,
),
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
dataset_schema=dataset_schema,
)
await datasets_impl.register_dataset(dataset)


@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@ -0,0 +1,18 @@
providers:
datasetio:
- provider_id: test-meta
provider_type: meta-reference
config: {}
scoring:
- provider_id: test-meta
provider_type: meta-reference
config: {}
eval:
- provider_id: test-meta
provider_type: meta-reference
config: {}
inference:
- provider_id: test-tgi
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009


@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.eval.eval import ModelCandidate
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_models.llama3.api import SamplingParams
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
#
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/eval/test_eval.py \
# --tb=short --disable-warnings
# ```
@pytest_asyncio.fixture(scope="session")
async def eval_settings():
impls = await resolve_impls_for_test(
Api.eval, deps=[Api.datasetio, Api.scoring, Api.inference]
)
return {
"eval_impl": impls[Api.eval],
"scoring_impl": impls[Api.scoring],
"datasets_impl": impls[Api.datasets],
}
@pytest.mark.asyncio
async def test_eval(eval_settings):
datasets_impl = eval_settings["datasets_impl"]
await register_dataset(
datasets_impl,
for_generation=True,
dataset_id="test_dataset_for_eval",
)
response = await datasets_impl.list_datasets()
assert len(response) == 1
eval_impl = eval_settings["eval_impl"]
response = await eval_impl.evaluate_batch(
dataset_id=response[0].identifier,
candidate=ModelCandidate(
model="Llama3.2-1B-Instruct",
sampling_params=SamplingParams(),
),
scoring_functions=["subset_of"],
)
assert response.job_id == "0"
job_status = await eval_impl.job_status(response.job_id)
assert job_status and job_status.value == "completed"
eval_response = await eval_impl.job_result(response.job_id)
assert eval_response is not None
assert len(eval_response.generations) == 5
assert "subset_of" in eval_response.scores


@ -14,7 +14,12 @@ apis:
- datasets
- datasetio
- scoring
- eval
providers:
eval:
- provider_id: meta0
provider_type: meta-reference
config: {}
scoring:
- provider_id: meta0
provider_type: meta-reference