[Evals API][4/n] evals with generation meta-reference impl (#303)
* wip
* dataset validation
* test_scoring
* cleanup
* clean up test
* comments
* error checking
* dataset client
* test client
* datasetio client
* clean up
* basic scoring function works
* scorer wip
* equality scorer
* score batch impl
* score batch
* update scoring test
* refactor
* validate scorer input
* address comments
* evals with generation
* add all rows scores to ScoringResult
* minor typing
* bugfix
* scoring function def rename
* rebase name
* refactor
* address comments
* Update iOS inference instructions for new quantization
* Small updates to quantization config
* Fix score threshold in faiss
* Bump version to 0.0.45
* Handle both ipv6 and ipv4 interfaces together
* update manifest for build templates
* Update getting_started.md
* chatcompletion & completion input type validation
* inclusion -> subsetof
* error checking
* scoring_function -> scoring_fn rename, scorer -> scoring_fn rename
* address comments
* [Evals API][5/n] fixes to generate openapi spec (#323)
* generate openapi
* typing comment, dataset -> dataset_id
* remove custom type
* sample eval run.yaml

---------

Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent 426d821e7f
commit abdf7cddf3

31 changed files with 3371 additions and 1296 deletions
@@ -33,14 +33,16 @@ schema_utils.json_schema_type = json_schema_type
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.apis.dataset import *  # noqa: F403
-from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.apis.datasets import *  # noqa: F403
+from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.apis.eval import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.batch_inference import *  # noqa: F403
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.telemetry import *  # noqa: F403
 from llama_stack.apis.post_training import *  # noqa: F403
 from llama_stack.apis.reward_scoring import *  # noqa: F403
 from llama_stack.apis.synthetic_data_generation import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.models import *  # noqa: F403
@@ -54,14 +56,16 @@ class LlamaStack(
     Inference,
     BatchInference,
     Agents,
     RewardScoring,
     Safety,
     SyntheticDataGeneration,
     Datasets,
     Telemetry,
     PostTraining,
     Memory,
-    Evaluations,
+    Eval,
+    Scoring,
+    ScoringFunctions,
+    DatasetIO,
     Models,
     Shields,
     Inspect,
(Two file diffs suppressed because they are too large.)
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from enum import Enum
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
@@ -10,3 +12,9 @@ from pydantic import BaseModel
 @json_schema_type
 class Job(BaseModel):
     job_id: str
+
+
+@json_schema_type
+class JobStatus(Enum):
+    completed = "completed"
+    in_progress = "in_progress"
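For context, a quick standalone sketch (not part of the diff) of how the new `JobStatus` enum behaves; the eval test added later in this commit compares `job_status.value == "completed"`, which relies on the string values defined here:

```python
# Sketch of the Job/JobStatus datatypes added above, assuming pydantic's
# BaseModel exactly as in the diff.
from enum import Enum

from pydantic import BaseModel


class Job(BaseModel):
    job_id: str


class JobStatus(Enum):
    completed = "completed"
    in_progress = "in_progress"


job = Job(job_id="0")
assert JobStatus("completed") is JobStatus.completed  # round-trip by value
assert JobStatus.completed.value == "completed"       # what the tests check
```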
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Dict, List, Literal, Union
+from typing import Literal, Union

 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
@@ -24,12 +24,10 @@ class BooleanType(BaseModel):

 class ArrayType(BaseModel):
     type: Literal["array"] = "array"
-    items: "ParamType"


 class ObjectType(BaseModel):
     type: Literal["object"] = "object"
-    properties: Dict[str, "ParamType"] = Field(default_factory=dict)


 class JsonType(BaseModel):
@@ -38,12 +36,21 @@

 class UnionType(BaseModel):
     type: Literal["union"] = "union"
-    options: List["ParamType"] = Field(default_factory=list)


-class CustomType(BaseModel):
-    type: Literal["custom"] = "custom"
-    validator_class: str
+class ChatCompletionInputType(BaseModel):
+    # expects List[Message] for messages
+    type: Literal["chat_completion_input"] = "chat_completion_input"
+
+
+class CompletionInputType(BaseModel):
+    # expects InterleavedTextMedia for content
+    type: Literal["completion_input"] = "completion_input"
+
+
+class AgentTurnInputType(BaseModel):
+    # expects List[Message] for messages (may also include attachments?)
+    type: Literal["agent_turn_input"] = "agent_turn_input"


 ParamType = Annotated[
@@ -55,11 +62,22 @@
         ObjectType,
         JsonType,
         UnionType,
-        CustomType,
+        ChatCompletionInputType,
+        CompletionInputType,
+        AgentTurnInputType,
     ],
     Field(discriminator="type"),
 ]

-ArrayType.model_rebuild()
-ObjectType.model_rebuild()
-UnionType.model_rebuild()
+# TODO: recursive definition of ParamType in these containers
+# will cause infinite recursion in OpenAPI generation script
+# since we are going with ChatCompletionInputType and CompletionInputType
+# we don't need to worry about ArrayType/ObjectType/UnionType for now
+# ArrayType.model_rebuild()
+# ObjectType.model_rebuild()
+# UnionType.model_rebuild()
+
+
+# class CustomType(BaseModel):
+#     type: Literal["custom"] = "custom"
+#     validator_class: str
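A minimal sketch (assuming pydantic v2, which provides `TypeAdapter` and discriminated unions) of how the `Field(discriminator="type")` annotation routes a plain dict to the right member of `ParamType`; the two-member union here is an illustrative reduction of the full one above:

```python
# Hypothetical standalone reduction of the ParamType union: the "type"
# literal acts as the discriminator when validating raw dicts.
from typing import Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class StringType(BaseModel):
    type: Literal["string"] = "string"


class ChatCompletionInputType(BaseModel):
    type: Literal["chat_completion_input"] = "chat_completion_input"


ParamType = Annotated[
    Union[StringType, ChatCompletionInputType],
    Field(discriminator="type"),
]

parsed = TypeAdapter(ParamType).validate_python({"type": "chat_completion_input"})
assert isinstance(parsed, ChatCompletionInputType)
```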
@@ -12,7 +12,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.schema_utils import json_schema_type, webmethod
 from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job
+from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.scoring import *  # noqa: F403
@@ -40,7 +40,7 @@ class EvaluateResponse(BaseModel):
     generations: List[Dict[str, Any]]

     # each key in the dict is a scoring function name
-    scores: List[Dict[str, ScoringResult]]
+    scores: Dict[str, ScoringResult]


 class Eval(Protocol):
@@ -61,10 +61,10 @@
     ) -> EvaluateResponse: ...

     @webmethod(route="/eval/job/status", method="GET")
-    async def job_status(self, job_id: str) -> None: ...
+    async def job_status(self, job_id: str) -> Optional[JobStatus]: ...

     @webmethod(route="/eval/job/cancel", method="POST")
     async def job_cancel(self, job_id: str) -> None: ...

     @webmethod(route="/eval/job/result", method="GET")
-    async def job_result(self, job_id: str) -> None: ...
+    async def job_result(self, job_id: str) -> EvaluateResponse: ...
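Putting the typed protocol together, a hedged sketch of the client-side job lifecycle it now supports (`evaluate_batch` and `ModelCandidate` appear later in this commit; the dataset id is illustrative):

```python
# Sketch of driving the Eval protocol end to end, assuming an `eval_impl`
# object implementing it (as the meta-reference provider below does).
from llama_stack.apis.common.job_types import JobStatus


async def run_eval(eval_impl, candidate):
    job = await eval_impl.evaluate_batch(
        dataset_id="test_dataset_for_eval",  # illustrative dataset id
        candidate=candidate,
        scoring_functions=["subset_of"],
    )
    status = await eval_impl.job_status(job.job_id)  # now Optional[JobStatus]
    if status == JobStatus.completed:
        response = await eval_impl.job_result(job.job_id)  # EvaluateResponse
        # scores is now Dict[str, ScoringResult], keyed by scoring fn name
        return response.scores["subset_of"]
    return None
```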
@@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.dataset import *  # noqa: F403
+from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
@@ -107,8 +107,8 @@ class PostTrainingSFTRequest(BaseModel):
     job_uuid: str

     model: str
-    dataset: TrainEvalDataset
-    validation_dataset: TrainEvalDataset
+    dataset_id: str
+    validation_dataset_id: str

     algorithm: FinetuningAlgorithm
     algorithm_config: Union[
@@ -131,8 +131,8 @@ class PostTrainingRLHFRequest(BaseModel):

     finetuned_model: URL

-    dataset: TrainEvalDataset
-    validation_dataset: TrainEvalDataset
+    dataset_id: str
+    validation_dataset_id: str

     algorithm: RLHFAlgorithm
     algorithm_config: Union[DPOAlignmentConfig]
@@ -181,8 +181,8 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         model: str,
-        dataset: TrainEvalDataset,
-        validation_dataset: TrainEvalDataset,
+        dataset_id: str,
+        validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: Union[
             LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
@@ -198,8 +198,8 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         finetuned_model: URL,
-        dataset: TrainEvalDataset,
-        validation_dataset: TrainEvalDataset,
+        dataset_id: str,
+        validation_dataset_id: str,
         algorithm: RLHFAlgorithm,
         algorithm_config: Union[DPOAlignmentConfig],
         optimizer_config: OptimizerConfig,
@@ -37,7 +37,7 @@ class ScoreResponse(BaseModel):


 class ScoringFunctionStore(Protocol):
-    def get_scoring_function(self, name: str) -> ScoringFunctionDefWithProvider: ...
+    def get_scoring_function(self, name: str) -> ScoringFnDefWithProvider: ...


 @runtime_checkable
@@ -29,7 +29,7 @@ class LLMAsJudgeContext(BaseModel):


 @json_schema_type
-class ScoringFunctionDef(BaseModel):
+class ScoringFnDef(BaseModel):
     identifier: str
     description: Optional[str] = None
     metadata: Dict[str, Any] = Field(
@@ -48,7 +48,7 @@

 @json_schema_type
-class ScoringFunctionDefWithProvider(ScoringFunctionDef):
+class ScoringFnDefWithProvider(ScoringFnDef):
     provider_id: str = Field(
         description="ID of the provider which serves this dataset",
     )
@@ -57,14 +57,14 @@ class ScoringFunctionDefWithProvider(ScoringFunctionDef):
 @runtime_checkable
 class ScoringFunctions(Protocol):
     @webmethod(route="/scoring_functions/list", method="GET")
-    async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: ...
+    async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: ...

     @webmethod(route="/scoring_functions/get", method="GET")
     async def get_scoring_function(
         self, name: str
-    ) -> Optional[ScoringFunctionDefWithProvider]: ...
+    ) -> Optional[ScoringFnDefWithProvider]: ...

     @webmethod(route="/scoring_functions/register", method="POST")
     async def register_scoring_function(
-        self, function_def: ScoringFunctionDefWithProvider
+        self, function_def: ScoringFnDefWithProvider
     ) -> None: ...
@@ -13,7 +13,6 @@
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.reward_scoring import *  # noqa: F403


 class FilteringFunction(Enum):
@@ -40,7 +39,7 @@ class SyntheticDataGenerationRequest(BaseModel):
 class SyntheticDataGenerationResponse(BaseModel):
     """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""

-    synthetic_data: List[ScoredDialogGenerations]
+    synthetic_data: List[Dict[str, Any]]
     statistics: Optional[Dict[str, Any]] = None
@@ -34,7 +34,7 @@ RoutableObject = Union[
     ShieldDef,
     MemoryBankDef,
     DatasetDef,
-    ScoringFunctionDef,
+    ScoringFnDef,
 ]

 RoutableObjectWithProvider = Union[
@@ -42,7 +42,7 @@ RoutableObjectWithProvider = Union[
     ShieldDefWithProvider,
     MemoryBankDefWithProvider,
     DatasetDefWithProvider,
-    ScoringFunctionDefWithProvider,
+    ScoringFnDefWithProvider,
 ]

 RoutedProtocol = Union[
@@ -14,6 +14,7 @@ from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.eval import Eval
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.memory import Memory
@@ -46,6 +47,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.datasetio: DatasetIO,
         Api.scoring_functions: ScoringFunctions,
         Api.scoring: Scoring,
+        Api.eval: Eval,
     }
@@ -100,7 +100,7 @@ class CommonRoutingTableImpl(RoutingTable):
             scoring_functions = await p.list_scoring_functions()
             add_objects(
                 [
-                    ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid)
+                    ScoringFnDefWithProvider(**s.dict(), provider_id=pid)
                     for s in scoring_functions
                 ]
             )
@@ -239,7 +239,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):


 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
-    async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]:
+    async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
         objects = []
         for objs in self.registry.values():
             objects.extend(objs)
@@ -247,10 +247,10 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):

     async def get_scoring_function(
         self, name: str
-    ) -> Optional[ScoringFunctionDefWithProvider]:
+    ) -> Optional[ScoringFnDefWithProvider]:
         return self.get_object_by_identifier(name)

     async def register_scoring_function(
-        self, function_def: ScoringFunctionDefWithProvider
+        self, function_def: ScoringFnDefWithProvider
     ) -> None:
         await self.register_object(function_def)
@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
 from llama_stack.apis.datasets import DatasetDef
 from llama_stack.apis.memory_banks import MemoryBankDef
 from llama_stack.apis.models import ModelDef
-from llama_stack.apis.scoring_functions import ScoringFunctionDef
+from llama_stack.apis.scoring_functions import ScoringFnDef
 from llama_stack.apis.shields import ShieldDef
@@ -25,6 +25,7 @@ class Api(Enum):
     memory = "memory"
     datasetio = "datasetio"
     scoring = "scoring"
+    eval = "eval"

     telemetry = "telemetry"
@@ -63,11 +64,9 @@ class DatasetsProtocolPrivate(Protocol):


 class ScoringFunctionsProtocolPrivate(Protocol):
-    async def list_scoring_functions(self) -> List[ScoringFunctionDef]: ...
+    async def list_scoring_functions(self) -> List[ScoringFnDef]: ...

-    async def register_scoring_function(
-        self, function_def: ScoringFunctionDef
-    ) -> None: ...
+    async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...


 @json_schema_type
@@ -143,11 +143,12 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         else:
             next_page_token = int(page_token)

-        if rows_in_page == -1:
-            rows = dataset_info.dataset_impl[next_page_token:]
-
         start = next_page_token
-        end = min(start + rows_in_page, len(dataset_info.dataset_impl))
+        if rows_in_page == -1:
+            end = len(dataset_info.dataset_impl)
+        else:
+            end = min(start + rows_in_page, len(dataset_info.dataset_impl))

         rows = dataset_info.dataset_impl[start:end]

         return PaginatedRowsResult(
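The fix is easier to see in isolation: previously the `rows_in_page == -1` branch computed a slice that was then overwritten by the generic `min(start + rows_in_page, ...)` bound, so "-1 means all rows" never actually held. A small sketch (hypothetical helper, not in the diff) of the corrected bounds arithmetic:

```python
# Standalone sketch of the fixed pagination arithmetic: rows_in_page == -1
# now selects the full tail of the dataset instead of a truncated page.
from typing import Tuple


def page_bounds(total: int, start: int, rows_in_page: int) -> Tuple[int, int]:
    """Return [start, end) for one page of a dataset with `total` rows."""
    if rows_in_page == -1:
        end = total
    else:
        end = min(start + rows_in_page, total)
    return start, end


rows = list(range(10))
start, end = page_bounds(len(rows), 4, -1)
assert rows[start:end] == [4, 5, 6, 7, 8, 9]  # full tail of the dataset
start, end = page_bounds(len(rows), 8, 5)
assert rows[start:end] == [8, 9]              # clamped at the end
```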
llama_stack/providers/impls/meta_reference/eval/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Dict
+
+from llama_stack.distribution.datatypes import Api, ProviderSpec
+
+from .config import MetaReferenceEvalConfig
+
+
+async def get_provider_impl(
+    config: MetaReferenceEvalConfig,
+    deps: Dict[Api, ProviderSpec],
+):
+    from .eval import MetaReferenceEvalImpl
+
+    impl = MetaReferenceEvalImpl(
+        config,
+        deps[Api.datasetio],
+        deps[Api.datasets],
+        deps[Api.scoring],
+        deps[Api.inference],
+    )
+    await impl.initialize()
+    return impl
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.eval import *  # noqa: F401, F403
+
+
+class MetaReferenceEvalConfig(BaseModel): ...
llama_stack/providers/impls/meta_reference/eval/eval.py (new file, 167 lines)
@@ -0,0 +1,167 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from enum import Enum
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.common.job_types import Job
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.scoring import Scoring
+
+from .config import MetaReferenceEvalConfig
+
+
+class ColumnName(Enum):
+    expected_answer = "expected_answer"
+    chat_completion_input = "chat_completion_input"
+    completion_input = "completion_input"
+    generated_answer = "generated_answer"
+
+
+class MetaReferenceEvalImpl(Eval):
+    def __init__(
+        self,
+        config: MetaReferenceEvalConfig,
+        datasetio_api: DatasetIO,
+        datasets_api: Datasets,
+        scoring_api: Scoring,
+        inference_api: Inference,
+    ) -> None:
+        self.config = config
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api
+        self.scoring_api = scoring_api
+        self.inference_api = inference_api
+
+        # TODO: assume sync job, will need jobs API for async scheduling
+        self.jobs = {}
+
+    async def initialize(self) -> None: ...
+
+    async def shutdown(self) -> None: ...
+
+    async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
+        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
+        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+            raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+
+        expected_schemas = [
+            {
+                ColumnName.expected_answer.value: StringType(),
+                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
+            },
+            {
+                ColumnName.expected_answer.value: StringType(),
+                ColumnName.completion_input.value: CompletionInputType(),
+            },
+        ]
+
+        if dataset_def.dataset_schema not in expected_schemas:
+            raise ValueError(
+                f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
+            )
+
+    async def evaluate_batch(
+        self,
+        dataset_id: str,
+        candidate: EvalCandidate,
+        scoring_functions: List[str],
+    ) -> Job:
+        await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
+        all_rows = await self.datasetio_api.get_rows_paginated(
+            dataset_id=dataset_id,
+            rows_in_page=-1,
+        )
+        res = await self.evaluate(
+            input_rows=all_rows.rows,
+            candidate=candidate,
+            scoring_functions=scoring_functions,
+        )
+
+        # TODO: currently needs to wait for generation before returning
+        # need job scheduler queue (ray/celery) w/ jobs api
+        job_id = str(len(self.jobs))
+        self.jobs[job_id] = res
+        return Job(job_id=job_id)
+
+    async def evaluate(
+        self,
+        input_rows: List[Dict[str, Any]],
+        candidate: EvalCandidate,
+        scoring_functions: List[str],
+    ) -> EvaluateResponse:
+        if candidate.type == "agent":
+            raise NotImplementedError(
+                "Evaluation with generation has not been implemented for agents"
+            )
+        assert (
+            candidate.sampling_params.max_tokens is not None
+        ), "SamplingParams.max_tokens must be provided"
+
+        generations = []
+        for x in input_rows:
+            if ColumnName.completion_input.value in x:
+                input_content = eval(str(x[ColumnName.completion_input.value]))
+                response = await self.inference_api.completion(
+                    model=candidate.model,
+                    content=input_content,
+                    sampling_params=candidate.sampling_params,
+                )
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
+            elif ColumnName.chat_completion_input.value in x:
+                input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
+                input_messages = [UserMessage(**x) for x in input_messages]
+                messages = []
+                if candidate.system_message:
+                    messages.append(candidate.system_message)
+                messages += input_messages
+                response = await self.inference_api.chat_completion(
+                    model=candidate.model,
+                    messages=messages,
+                    sampling_params=candidate.sampling_params,
+                )
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
+            else:
+                raise ValueError("Invalid input row")
+
+        # scoring with generated_answer
+        score_input_rows = [
+            input_r | generated_r
+            for input_r, generated_r in zip(input_rows, generations)
+        ]
+
+        score_response = await self.scoring_api.score(
+            input_rows=score_input_rows, scoring_functions=scoring_functions
+        )
+
+        return EvaluateResponse(generations=generations, scores=score_response.results)
+
+    async def job_status(self, job_id: str) -> Optional[JobStatus]:
+        if job_id in self.jobs:
+            return JobStatus.completed
+
+        return None
+
+    async def job_cancel(self, job_id: str) -> None:
+        raise NotImplementedError("Job cancel is not implemented yet")
+
+    async def job_result(self, job_id: str) -> EvaluateResponse:
+        status = await self.job_status(job_id)
+        if not status or status != JobStatus.completed:
+            raise ValueError(f"Job is not completed, Status: {status.value}")
+
+        return self.jobs[job_id]
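One implementation detail worth noting: the dataset serializes `chat_completion_input` as a Python-literal string, which `evaluate` round-trips with `eval()`. For purely literal payloads like these, `ast.literal_eval` is an injection-safe equivalent; a sketch (not part of the diff):

```python
# ast.literal_eval accepts only Python literals (lists, dicts, strings,
# numbers, ...) and rejects arbitrary code, unlike eval().
import ast

raw = "[{'role': 'user', 'content': 'What is the capital of France?'}]"
messages = ast.literal_eval(raw)
assert messages == [{"role": "user", "content": "What is the capital of France?"}]
```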
@@ -13,17 +13,22 @@ from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403

 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
-    EqualityScorer,
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
+    EqualityScoringFn,
 )
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
+    SubsetOfScoringFn,
+)

 from .config import MetaReferenceScoringConfig

-SUPPORTED_SCORERS = [
-    EqualityScorer,
+SUPPORTED_SCORING_FNS = [
+    EqualityScoringFn,
+    SubsetOfScoringFn,
 ]

-SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
+SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS}


 class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
@@ -41,10 +46,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):

     async def shutdown(self) -> None: ...

-    async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
-        return [x.scoring_function_def for x in SUPPORTED_SCORERS]
+    async def list_scoring_functions(self) -> List[ScoringFnDef]:
+        return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS]

-    async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None:
+    async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
         raise NotImplementedError(
             "Dynamically registering scoring functions is not supported"
         )
@@ -96,9 +101,9 @@
         for scoring_fn_id in scoring_functions:
             if scoring_fn_id not in SCORER_REGISTRY:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            scorer = SCORER_REGISTRY[scoring_fn_id]()
-            score_results = scorer.score(input_rows)
-            agg_results = scorer.aggregate(score_results)
+            scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
+            score_results = scoring_fn.score(input_rows)
+            agg_results = scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
@@ -9,15 +9,15 @@ from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403


-class BaseScorer(ABC):
+class BaseScoringFn(ABC):
     """
-    Base interface class for all meta-reference scorers.
-    Each scorer needs to implement the following methods:
+    Base interface class for all meta-reference scoring_fns.
+    Each scoring_fn needs to implement the following methods:
     - score_row(self, row)
-    - aggregate(self, scorer_results)
+    - aggregate(self, scoring_fn_results)
     """

-    scoring_function_def: ScoringFunctionDef
+    scoring_function_def: ScoringFnDef

     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Any, Dict, List
+
+from llama_stack.apis.scoring import ScoringResultRow
+
+
+def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    num_correct = sum(result["score"] for result in scoring_results)
+    avg_score = num_correct / len(scoring_results)
+
+    return {
+        "accuracy": avg_score,
+        "num_correct": num_correct,
+        "num_total": len(scoring_results),
+    }
@@ -4,20 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
-    BaseScorer,
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
 )
 from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_accuracy,
+)


-class EqualityScorer(BaseScorer):
+class EqualityScoringFn(BaseScoringFn):
     """
-    A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
+    A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
     """

-    scoring_function_def = ScoringFunctionDef(
+    scoring_function_def = ScoringFnDef(
         identifier="equality",
         description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
         parameters=[],
@@ -38,12 +41,4 @@ class EqualityScorer(BaseScorer):
         }

     def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
-        assert len(scoring_results) > 0, "Empty scoring results provided."
-        num_correct = sum(result["score"] for result in scoring_results)
-        avg_score = num_correct / len(scoring_results)
-
-        return {
-            "accuracy": avg_score,
-            "num_correct": num_correct,
-            "num_total": len(scoring_results),
-        }
+        return aggregate_accuracy(scoring_results)
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
+)
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_accuracy,
+)
+
+
+class SubsetOfScoringFn(BaseScoringFn):
+    """
+    A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
+    """
+
+    scoring_function_def = ScoringFnDef(
+        identifier="subset_of",
+        description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
+        parameters=[],
+        return_type=NumberType(),
+    )
+
+    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+        assert "expected_answer" in input_row, "Expected answer not found in input row."
+        assert (
+            "generated_answer" in input_row
+        ), "Generated answer not found in input row."
+
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        score = 1.0 if expected_answer in generated_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+        return aggregate_accuracy(scoring_results)
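The `subset_of` check in isolation: it is plain substring containment of the expected answer in the generation, so it rewards verbatim mentions and is case-sensitive.

```python
# Standalone illustration of SubsetOfScoringFn.score_row's core test.
row = {
    "expected_answer": "Paris",
    "generated_answer": "The capital of France is Paris.",
}
score = 1.0 if row["expected_answer"] in row["generated_answer"] else 0.0
assert score == 1.0                               # "Paris" appears verbatim
assert "paris" not in row["generated_answer"]     # containment is case-sensitive
```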
llama_stack/providers/registry/eval.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.eval,
+            provider_type="meta-reference",
+            pip_packages=[],
+            module="llama_stack.providers.impls.meta_reference.eval",
+            config_class="llama_stack.providers.impls.meta_reference.eval.MetaReferenceEvalConfig",
+            api_dependencies=[
+                Api.datasetio,
+                Api.datasets,
+                Api.scoring,
+                Api.inference,
+            ],
+        ),
+    ]
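How this spec gets used: the resolver imports the declared `module`, constructs the `api_dependencies` first, and hands them to the `get_provider_impl` from the new `__init__.py` above. A hedged sketch of that wiring (the real resolver lives in `llama_stack/distribution/resolver.py` and is more general):

```python
# Hypothetical plumbing implied by InlineProviderSpec; names of the helper
# and its arguments are illustrative, not the actual resolver API.
import importlib


async def instantiate_eval_provider(spec, config, deps):
    # spec.module == "llama_stack.providers.impls.meta_reference.eval"
    module = importlib.import_module(spec.module)
    # deps maps Api.datasetio / Api.datasets / Api.scoring / Api.inference
    # to already-constructed provider impls, per spec.api_dependencies.
    return await module.get_provider_impl(config, deps)
```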
@@ -1,6 +1,6 @@
-input_query,generated_answer,expected_answer
-What is the capital of France?,London,Paris
-Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
-What is the largest planet in our solar system?,Jupiter,Jupiter
-What is the smallest country in the world?,China,Vatican City
-What is the currency of Japan?,Yen,Yen
+input_query,generated_answer,expected_answer,chat_completion_input
+What is the capital of France?,London,Paris,"[{'role': 'user', 'content': 'What is the capital of France?'}]"
+Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg,"[{'role': 'user', 'content': 'Who is the CEO of Meta?'}]"
+What is the largest planet in our solar system?,Jupiter,Jupiter,"[{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]"
+What is the smallest country in the world?,China,Vatican City,"[{'role': 'user', 'content': 'What is the smallest country in the world?'}]"
+What is the currency of Japan?,Yen,Yen,"[{'role': 'user', 'content': 'What is the currency of Japan?'}]"
@@ -61,20 +61,31 @@ def data_url_from_file(file_path: str) -> str:
     return data_url


-async def register_dataset(datasets_impl: Datasets):
+async def register_dataset(
+    datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
+):
     test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
     test_url = data_url_from_file(str(test_file))

+    if for_generation:
+        dataset_schema = {
+            "expected_answer": StringType(),
+            "chat_completion_input": ChatCompletionInputType(),
+        }
+    else:
+        dataset_schema = {
+            "expected_answer": StringType(),
+            "input_query": StringType(),
+            "generated_answer": StringType(),
+        }
+
     dataset = DatasetDefWithProvider(
-        identifier="test_dataset",
+        identifier=dataset_id,
         provider_id=os.environ["PROVIDER_ID"],
         url=URL(
             uri=test_url,
         ),
-        dataset_schema={
-            "generated_answer": StringType(),
-            "expected_answer": StringType(),
-            "input_query": StringType(),
-        },
+        dataset_schema=dataset_schema,
     )
     await datasets_impl.register_dataset(dataset)
llama_stack/providers/tests/eval/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -0,0 +1,18 @@
+providers:
+  datasetio:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
+  scoring:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
+  eval:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
+  inference:
+    - provider_id: test-tgi
+      provider_type: remote::tgi
+      config:
+        url: http://127.0.0.1:5009
llama_stack/providers/tests/eval/test_eval.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.apis.eval.eval import ModelCandidate
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+from llama_models.llama3.api import SamplingParams
+
+from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
+from llama_stack.providers.tests.resolver import resolve_impls_for_test
+
+# How to run this test:
+#
+# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
+#    since it depends on the provider you are testing. On top of that you need
+#    `pytest` and `pytest-asyncio` installed.
+#
+# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
+#
+# 3. Run:
+#
+# ```bash
+# PROVIDER_ID=<your_provider> \
+#   PROVIDER_CONFIG=provider_config.yaml \
+#   pytest -s llama_stack/providers/tests/eval/test_eval.py \
+#   --tb=short --disable-warnings
+# ```
+
+
+@pytest_asyncio.fixture(scope="session")
+async def eval_settings():
+    impls = await resolve_impls_for_test(
+        Api.eval, deps=[Api.datasetio, Api.scoring, Api.inference]
+    )
+    return {
+        "eval_impl": impls[Api.eval],
+        "scoring_impl": impls[Api.scoring],
+        "datasets_impl": impls[Api.datasets],
+    }
+
+
+@pytest.mark.asyncio
+async def test_eval(eval_settings):
+    datasets_impl = eval_settings["datasets_impl"]
+    await register_dataset(
+        datasets_impl,
+        for_generation=True,
+        dataset_id="test_dataset_for_eval",
+    )
+
+    response = await datasets_impl.list_datasets()
+    assert len(response) == 1
+
+    eval_impl = eval_settings["eval_impl"]
+    response = await eval_impl.evaluate_batch(
+        dataset_id=response[0].identifier,
+        candidate=ModelCandidate(
+            model="Llama3.2-1B-Instruct",
+            sampling_params=SamplingParams(),
+        ),
+        scoring_functions=["subset_of"],
+    )
+    assert response.job_id == "0"
+    job_status = await eval_impl.job_status(response.job_id)
+
+    assert job_status and job_status.value == "completed"
+
+    eval_response = await eval_impl.job_result(response.job_id)
+
+    assert eval_response is not None
+    assert len(eval_response.generations) == 5
+    assert "subset_of" in eval_response.scores
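For readers following the final assertion: each entry of `eval_response.scores` is a `ScoringResult` carrying per-row scores plus the aggregate (see the scoring impl above), so the data has roughly the shape sketched here (values are illustrative; the real ones depend on the model's generations):

```python
# Illustrative shape only, mirroring ScoringResult(score_rows=...,
# aggregated_results=...) as constructed in MetaReferenceScoringImpl.
example_scores = {
    "subset_of": {
        "score_rows": [{"score": 1.0}, {"score": 0.0}],  # one entry per row
        "aggregated_results": {
            "accuracy": 0.5,
            "num_correct": 1.0,
            "num_total": 2,
        },
    }
}
assert "subset_of" in example_scores
```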
@@ -14,7 +14,12 @@ apis:
 - datasets
 - datasetio
 - scoring
+- eval
 providers:
+  eval:
+    - provider_id: meta0
+      provider_type: meta-reference
+      config: {}
   scoring:
     - provider_id: meta0
       provider_type: meta-reference