Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
[Evals API][4/n] evals with generation meta-reference impl (#303)
* wip
* dataset validation
* test_scoring
* cleanup
* clean up test
* comments
* error checking
* dataset client
* test client:
* datasetio client
* clean up
* basic scoring function works
* scorer wip
* equality scorer
* score batch impl
* score batch
* update scoring test
* refactor
* validate scorer input
* address comments
* evals with generation
* add all rows scores to ScoringResult
* minor typing
* bugfix
* scoring function def rename
* rebase name
* refactor
* address comments
* Update iOS inference instructions for new quantization
* Small updates to quantization config
* Fix score threshold in faiss
* Bump version to 0.0.45
* Handle both ipv6 and ipv4 interfaces together
* update manifest for build templates
* Update getting_started.md
* chatcompletion & completion input type validation
* inclusion->subsetof
* error checking
* scoring_function -> scoring_fn rename, scorer -> scoring_fn rename
* address comments
* [Evals API][5/n] fixes to generate openapi spec (#323)
* generate openapi
* typing comment, dataset -> dataset_id
* remove custom type
* sample eval run.yaml

---------

Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Parent: 426d821e7f
Commit: abdf7cddf3
31 changed files with 3371 additions and 1296 deletions
@@ -33,14 +33,16 @@ schema_utils.json_schema_type = json_schema_type
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.apis.dataset import *  # noqa: F403
-from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.apis.datasets import *  # noqa: F403
+from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.apis.eval import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.batch_inference import *  # noqa: F403
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.telemetry import *  # noqa: F403
 from llama_stack.apis.post_training import *  # noqa: F403
-from llama_stack.apis.reward_scoring import *  # noqa: F403
 from llama_stack.apis.synthetic_data_generation import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.models import *  # noqa: F403
@@ -54,14 +56,16 @@ class LlamaStack(
     Inference,
     BatchInference,
     Agents,
-    RewardScoring,
     Safety,
     SyntheticDataGeneration,
     Datasets,
     Telemetry,
     PostTraining,
     Memory,
-    Evaluations,
+    Eval,
+    Scoring,
+    ScoringFunctions,
+    DatasetIO,
     Models,
     Shields,
     Inspect,
Two file diffs suppressed because they are too large.
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from enum import Enum
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
@@ -10,3 +12,9 @@ from pydantic import BaseModel
 @json_schema_type
 class Job(BaseModel):
     job_id: str
+
+
+@json_schema_type
+class JobStatus(Enum):
+    completed = "completed"
+    in_progress = "in_progress"
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Dict, List, Literal, Union
+from typing import Literal, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
@@ -24,12 +24,10 @@ class BooleanType(BaseModel):
 
 class ArrayType(BaseModel):
     type: Literal["array"] = "array"
-    items: "ParamType"
 
 
 class ObjectType(BaseModel):
     type: Literal["object"] = "object"
-    properties: Dict[str, "ParamType"] = Field(default_factory=dict)
 
 
 class JsonType(BaseModel):
@@ -38,12 +36,21 @@ class JsonType(BaseModel):
 
 class UnionType(BaseModel):
     type: Literal["union"] = "union"
-    options: List["ParamType"] = Field(default_factory=list)
 
 
-class CustomType(BaseModel):
-    type: Literal["custom"] = "custom"
-    validator_class: str
+class ChatCompletionInputType(BaseModel):
+    # expects List[Message] for messages
+    type: Literal["chat_completion_input"] = "chat_completion_input"
+
+
+class CompletionInputType(BaseModel):
+    # expects InterleavedTextMedia for content
+    type: Literal["completion_input"] = "completion_input"
+
+
+class AgentTurnInputType(BaseModel):
+    # expects List[Message] for messages (may also include attachments?)
+    type: Literal["agent_turn_input"] = "agent_turn_input"
 
 
 ParamType = Annotated[
@@ -55,11 +62,22 @@ ParamType = Annotated[
         ObjectType,
         JsonType,
         UnionType,
-        CustomType,
+        ChatCompletionInputType,
+        CompletionInputType,
+        AgentTurnInputType,
     ],
     Field(discriminator="type"),
 ]
 
-ArrayType.model_rebuild()
-ObjectType.model_rebuild()
-UnionType.model_rebuild()
+# TODO: recursive definition of ParamType in these containers
+# will cause infinite recursion in OpenAPI generation script
+# since we are going with ChatCompletionInputType and CompletionInputType
+# we don't need to worry about ArrayType/ObjectType/UnionType for now
+# ArrayType.model_rebuild()
+# ObjectType.model_rebuild()
+# UnionType.model_rebuild()
+
+
+# class CustomType(BaseModel):
+#     type: Literal["custom"] = "custom"
+#     validator_class: str
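The replacement input types above are flat discriminated-union members: each carries only a literal `type` tag, so pydantic can route a serialized value to the right class without the recursive `items`/`properties`/`options` fields that broke OpenAPI generation. A minimal standalone sketch of that same pattern, using toy stand-in classes rather than the llama_stack ones (assumes pydantic v2):

# Toy stand-ins for the llama_stack types, illustrating only the discriminated union.
from typing import Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class StringType(BaseModel):
    type: Literal["string"] = "string"


class ChatCompletionInputType(BaseModel):
    type: Literal["chat_completion_input"] = "chat_completion_input"


ParamType = Annotated[
    Union[StringType, ChatCompletionInputType],
    Field(discriminator="type"),
]

# The "type" tag alone decides which model gets constructed.
parsed = TypeAdapter(ParamType).validate_python({"type": "chat_completion_input"})
assert isinstance(parsed, ChatCompletionInputType)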
@@ -12,7 +12,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.schema_utils import json_schema_type, webmethod
 from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job
+from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.scoring import *  # noqa: F403
 
 
@@ -40,7 +40,7 @@ class EvaluateResponse(BaseModel):
     generations: List[Dict[str, Any]]
 
     # each key in the dict is a scoring function name
-    scores: List[Dict[str, ScoringResult]]
+    scores: Dict[str, ScoringResult]
 
 
 class Eval(Protocol):
@@ -61,10 +61,10 @@ class Eval(Protocol):
     ) -> EvaluateResponse: ...
 
     @webmethod(route="/eval/job/status", method="GET")
-    async def job_status(self, job_id: str) -> None: ...
+    async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
 
     @webmethod(route="/eval/job/cancel", method="POST")
     async def job_cancel(self, job_id: str) -> None: ...
 
     @webmethod(route="/eval/job/result", method="GET")
-    async def job_result(self, job_id: str) -> None: ...
+    async def job_result(self, job_id: str) -> EvaluateResponse: ...
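Taken together, the new signatures give the job endpoints concrete return types: job_status reports an Optional[JobStatus] and job_result returns the full EvaluateResponse, whose scores field is now a single dict keyed by scoring function name. A rough usage sketch against an already-resolved Eval implementation; run_eval is a hypothetical helper, and it assumes a registered dataset plus the ModelCandidate/SamplingParams types used by the tests later in this commit:

# Sketch only: assumes an already-resolved Eval implementation (e.g. via the
# provider test resolver) and a dataset registered with a generation schema.
from llama_models.llama3.api import SamplingParams

from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval.eval import ModelCandidate


async def run_eval(eval_impl: Eval, dataset_id: str) -> None:
    job = await eval_impl.evaluate_batch(
        dataset_id=dataset_id,
        candidate=ModelCandidate(
            model="Llama3.2-1B-Instruct",
            sampling_params=SamplingParams(),
        ),
        scoring_functions=["subset_of"],
    )

    # The meta-reference provider runs synchronously, so the job is already finished.
    if await eval_impl.job_status(job.job_id) == JobStatus.completed:
        result = await eval_impl.job_result(job.job_id)
        # scores is a Dict[str, ScoringResult] keyed by scoring function name.
        print(result.scores["subset_of"].aggregated_results)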
@@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.dataset import *  # noqa: F403
+from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 
 
@@ -107,8 +107,8 @@ class PostTrainingSFTRequest(BaseModel):
     job_uuid: str
 
     model: str
-    dataset: TrainEvalDataset
-    validation_dataset: TrainEvalDataset
+    dataset_id: str
+    validation_dataset_id: str
 
     algorithm: FinetuningAlgorithm
     algorithm_config: Union[
@@ -131,8 +131,8 @@ class PostTrainingRLHFRequest(BaseModel):
 
     finetuned_model: URL
 
-    dataset: TrainEvalDataset
-    validation_dataset: TrainEvalDataset
+    dataset_id: str
+    validation_dataset_id: str
 
     algorithm: RLHFAlgorithm
     algorithm_config: Union[DPOAlignmentConfig]
@@ -181,8 +181,8 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         model: str,
-        dataset: TrainEvalDataset,
-        validation_dataset: TrainEvalDataset,
+        dataset_id: str,
+        validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: Union[
             LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
@@ -198,8 +198,8 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         finetuned_model: URL,
-        dataset: TrainEvalDataset,
-        validation_dataset: TrainEvalDataset,
+        dataset_id: str,
+        validation_dataset_id: str,
         algorithm: RLHFAlgorithm,
         algorithm_config: Union[DPOAlignmentConfig],
         optimizer_config: OptimizerConfig,
|
@ -37,7 +37,7 @@ class ScoreResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionStore(Protocol):
|
class ScoringFunctionStore(Protocol):
|
||||||
def get_scoring_function(self, name: str) -> ScoringFunctionDefWithProvider: ...
|
def get_scoring_function(self, name: str) -> ScoringFnDefWithProvider: ...
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|
|
@@ -29,7 +29,7 @@ class LLMAsJudgeContext(BaseModel):
 
 
 @json_schema_type
-class ScoringFunctionDef(BaseModel):
+class ScoringFnDef(BaseModel):
     identifier: str
     description: Optional[str] = None
     metadata: Dict[str, Any] = Field(
@@ -48,7 +48,7 @@ class ScoringFunctionDef(BaseModel):
 
 
 @json_schema_type
-class ScoringFunctionDefWithProvider(ScoringFunctionDef):
+class ScoringFnDefWithProvider(ScoringFnDef):
     provider_id: str = Field(
         description="ID of the provider which serves this dataset",
     )
@@ -57,14 +57,14 @@ class ScoringFunctionDefWithProvider(ScoringFunctionDef):
 @runtime_checkable
 class ScoringFunctions(Protocol):
     @webmethod(route="/scoring_functions/list", method="GET")
-    async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: ...
+    async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: ...
 
     @webmethod(route="/scoring_functions/get", method="GET")
     async def get_scoring_function(
         self, name: str
-    ) -> Optional[ScoringFunctionDefWithProvider]: ...
+    ) -> Optional[ScoringFnDefWithProvider]: ...
 
     @webmethod(route="/scoring_functions/register", method="POST")
     async def register_scoring_function(
-        self, function_def: ScoringFunctionDefWithProvider
+        self, function_def: ScoringFnDefWithProvider
     ) -> None: ...
@@ -13,7 +13,6 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.reward_scoring import *  # noqa: F403
 
 
 class FilteringFunction(Enum):
@@ -40,7 +39,7 @@ class SyntheticDataGenerationRequest(BaseModel):
 class SyntheticDataGenerationResponse(BaseModel):
     """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
 
-    synthetic_data: List[ScoredDialogGenerations]
+    synthetic_data: List[Dict[str, Any]]
     statistics: Optional[Dict[str, Any]] = None
 
 
@@ -34,7 +34,7 @@ RoutableObject = Union[
     ShieldDef,
     MemoryBankDef,
     DatasetDef,
-    ScoringFunctionDef,
+    ScoringFnDef,
 ]
 
 RoutableObjectWithProvider = Union[
@@ -42,7 +42,7 @@ RoutableObjectWithProvider = Union[
     ShieldDefWithProvider,
     MemoryBankDefWithProvider,
     DatasetDefWithProvider,
-    ScoringFunctionDefWithProvider,
+    ScoringFnDefWithProvider,
 ]
 
 RoutedProtocol = Union[
@@ -14,6 +14,7 @@ from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.eval import Eval
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.memory import Memory
@@ -46,6 +47,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.datasetio: DatasetIO,
         Api.scoring_functions: ScoringFunctions,
         Api.scoring: Scoring,
+        Api.eval: Eval,
     }
 
 
@@ -100,7 +100,7 @@ class CommonRoutingTableImpl(RoutingTable):
                 scoring_functions = await p.list_scoring_functions()
                 add_objects(
                     [
-                        ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid)
+                        ScoringFnDefWithProvider(**s.dict(), provider_id=pid)
                         for s in scoring_functions
                     ]
                 )
@@ -239,7 +239,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
 
 
 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
-    async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]:
+    async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
         objects = []
         for objs in self.registry.values():
             objects.extend(objs)
@@ -247,10 +247,10 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
 
     async def get_scoring_function(
         self, name: str
-    ) -> Optional[ScoringFunctionDefWithProvider]:
+    ) -> Optional[ScoringFnDefWithProvider]:
         return self.get_object_by_identifier(name)
 
     async def register_scoring_function(
-        self, function_def: ScoringFunctionDefWithProvider
+        self, function_def: ScoringFnDefWithProvider
     ) -> None:
         await self.register_object(function_def)
@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
 from llama_stack.apis.datasets import DatasetDef
 from llama_stack.apis.memory_banks import MemoryBankDef
 from llama_stack.apis.models import ModelDef
-from llama_stack.apis.scoring_functions import ScoringFunctionDef
+from llama_stack.apis.scoring_functions import ScoringFnDef
 from llama_stack.apis.shields import ShieldDef
 
 
@@ -25,6 +25,7 @@ class Api(Enum):
     memory = "memory"
     datasetio = "datasetio"
     scoring = "scoring"
+    eval = "eval"
 
     telemetry = "telemetry"
 
@@ -63,11 +64,9 @@ class DatasetsProtocolPrivate(Protocol):
 
 
 class ScoringFunctionsProtocolPrivate(Protocol):
-    async def list_scoring_functions(self) -> List[ScoringFunctionDef]: ...
+    async def list_scoring_functions(self) -> List[ScoringFnDef]: ...
 
-    async def register_scoring_function(
-        self, function_def: ScoringFunctionDef
-    ) -> None: ...
+    async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
 
 
 @json_schema_type
@@ -143,11 +143,12 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         else:
             next_page_token = int(page_token)
 
-        if rows_in_page == -1:
-            rows = dataset_info.dataset_impl[next_page_token:]
-
         start = next_page_token
-        end = min(start + rows_in_page, len(dataset_info.dataset_impl))
+        if rows_in_page == -1:
+            end = len(dataset_info.dataset_impl)
+        else:
+            end = min(start + rows_in_page, len(dataset_info.dataset_impl))
 
         rows = dataset_info.dataset_impl[start:end]
 
         return PaginatedRowsResult(
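With this change, rows_in_page=-1 means "everything from the current page token onward" while start/end stay consistent for the returned page slice. A plain-Python illustration of the slicing on toy data (not the provider class):

# Toy data standing in for dataset_info.dataset_impl and next_page_token.
rows = list(range(10))
start = 3

rows_in_page = -1
end = len(rows) if rows_in_page == -1 else min(start + rows_in_page, len(rows))
assert rows[start:end] == [3, 4, 5, 6, 7, 8, 9]

rows_in_page = 2
end = len(rows) if rows_in_page == -1 else min(start + rows_in_page, len(rows))
assert rows[start:end] == [3, 4]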
llama_stack/providers/impls/meta_reference/eval/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Dict
+
+from llama_stack.distribution.datatypes import Api, ProviderSpec
+
+from .config import MetaReferenceEvalConfig
+
+
+async def get_provider_impl(
+    config: MetaReferenceEvalConfig,
+    deps: Dict[Api, ProviderSpec],
+):
+    from .eval import MetaReferenceEvalImpl
+
+    impl = MetaReferenceEvalImpl(
+        config,
+        deps[Api.datasetio],
+        deps[Api.datasets],
+        deps[Api.scoring],
+        deps[Api.inference],
+    )
+    await impl.initialize()
+    return impl
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.eval import *  # noqa: F401, F403
+
+
+class MetaReferenceEvalConfig(BaseModel): ...
llama_stack/providers/impls/meta_reference/eval/eval.py (new file, 167 lines)
@@ -0,0 +1,167 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from enum import Enum
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.common.job_types import Job
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.scoring import Scoring
+
+from .config import MetaReferenceEvalConfig
+
+
+class ColumnName(Enum):
+    expected_answer = "expected_answer"
+    chat_completion_input = "chat_completion_input"
+    completion_input = "completion_input"
+    generated_answer = "generated_answer"
+
+
+class MetaReferenceEvalImpl(Eval):
+    def __init__(
+        self,
+        config: MetaReferenceEvalConfig,
+        datasetio_api: DatasetIO,
+        datasets_api: Datasets,
+        scoring_api: Scoring,
+        inference_api: Inference,
+    ) -> None:
+        self.config = config
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api
+        self.scoring_api = scoring_api
+        self.inference_api = inference_api
+
+        # TODO: assume sync job, will need jobs API for async scheduling
+        self.jobs = {}
+
+    async def initialize(self) -> None: ...
+
+    async def shutdown(self) -> None: ...
+
+    async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
+        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
+        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+            raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+
+        expected_schemas = [
+            {
+                ColumnName.expected_answer.value: StringType(),
+                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
+            },
+            {
+                ColumnName.expected_answer.value: StringType(),
+                ColumnName.completion_input.value: CompletionInputType(),
+            },
+        ]
+
+        if dataset_def.dataset_schema not in expected_schemas:
+            raise ValueError(
+                f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
+            )
+
+    async def evaluate_batch(
+        self,
+        dataset_id: str,
+        candidate: EvalCandidate,
+        scoring_functions: List[str],
+    ) -> Job:
+        await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
+        all_rows = await self.datasetio_api.get_rows_paginated(
+            dataset_id=dataset_id,
+            rows_in_page=-1,
+        )
+        res = await self.evaluate(
+            input_rows=all_rows.rows,
+            candidate=candidate,
+            scoring_functions=scoring_functions,
+        )
+
+        # TODO: currently needs to wait for generation before returning
+        # need job scheduler queue (ray/celery) w/ jobs api
+        job_id = str(len(self.jobs))
+        self.jobs[job_id] = res
+        return Job(job_id=job_id)
+
+    async def evaluate(
+        self,
+        input_rows: List[Dict[str, Any]],
+        candidate: EvalCandidate,
+        scoring_functions: List[str],
+    ) -> EvaluateResponse:
+        if candidate.type == "agent":
+            raise NotImplementedError(
+                "Evaluation with generation has not been implemented for agents"
+            )
+        assert (
+            candidate.sampling_params.max_tokens is not None
+        ), "SamplingParams.max_tokens must be provided"
+
+        generations = []
+        for x in input_rows:
+            if ColumnName.completion_input.value in x:
+                input_content = eval(str(x[ColumnName.completion_input.value]))
+                response = await self.inference_api.completion(
+                    model=candidate.model,
+                    content=input_content,
+                    sampling_params=candidate.sampling_params,
+                )
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
+            elif ColumnName.chat_completion_input.value in x:
+                input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
+                input_messages = [UserMessage(**x) for x in input_messages]
+                messages = []
+                if candidate.system_message:
+                    messages.append(candidate.system_message)
+                messages += input_messages
+                response = await self.inference_api.chat_completion(
+                    model=candidate.model,
+                    messages=messages,
+                    sampling_params=candidate.sampling_params,
+                )
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
+            else:
+                raise ValueError("Invalid input row")
+
+        # scoring with generated_answer
+        score_input_rows = [
+            input_r | generated_r
+            for input_r, generated_r in zip(input_rows, generations)
+        ]
+
+        score_response = await self.scoring_api.score(
+            input_rows=score_input_rows, scoring_functions=scoring_functions
+        )
+
+        return EvaluateResponse(generations=generations, scores=score_response.results)
+
+    async def job_status(self, job_id: str) -> Optional[JobStatus]:
+        if job_id in self.jobs:
+            return JobStatus.completed
+
+        return None
+
+    async def job_cancel(self, job_id: str) -> None:
+        raise NotImplementedError("Job cancel is not implemented yet")
+
+    async def job_result(self, job_id: str) -> EvaluateResponse:
+        status = await self.job_status(job_id)
+        if not status or status != JobStatus.completed:
+            raise ValueError(f"Job is not completed, Status: {status.value}")
+
+        return self.jobs[job_id]
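The generation loop above only adds a generated_answer column; the dict-merge input_r | generated_r then builds the rows handed to the scoring API. A toy illustration of that merge with plain dicts (requires Python 3.9+ for the | operator on dicts):

# Plain dicts standing in for a dataset row and the generation produced for it.
input_rows = [
    {
        "chat_completion_input": "[{'role': 'user', 'content': 'What is the capital of France?'}]",
        "expected_answer": "Paris",
    },
]
generations = [{"generated_answer": "Paris is the capital of France."}]

# Same merge as the impl: each scoring input row carries both the original
# columns and the new generated_answer column.
score_input_rows = [
    input_r | generated_r for input_r, generated_r in zip(input_rows, generations)
]
assert score_input_rows[0]["expected_answer"] == "Paris"
assert "generated_answer" in score_input_rows[0]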
@@ -13,17 +13,22 @@ from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
 
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
-    EqualityScorer,
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
+    EqualityScoringFn,
+)
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
+    SubsetOfScoringFn,
 )
 
 from .config import MetaReferenceScoringConfig
 
-SUPPORTED_SCORERS = [
-    EqualityScorer,
+SUPPORTED_SCORING_FNS = [
+    EqualityScoringFn,
+    SubsetOfScoringFn,
 ]
 
-SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
+SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS}
 
 
 class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
@@ -41,10 +46,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
 
     async def shutdown(self) -> None: ...
 
-    async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
-        return [x.scoring_function_def for x in SUPPORTED_SCORERS]
+    async def list_scoring_functions(self) -> List[ScoringFnDef]:
+        return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS]
 
-    async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None:
+    async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
         raise NotImplementedError(
             "Dynamically registering scoring functions is not supported"
         )
@@ -96,9 +101,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         for scoring_fn_id in scoring_functions:
             if scoring_fn_id not in SCORER_REGISTRY:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            scorer = SCORER_REGISTRY[scoring_fn_id]()
-            score_results = scorer.score(input_rows)
-            agg_results = scorer.aggregate(score_results)
+            scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
+            score_results = scoring_fn.score(input_rows)
+            agg_results = scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
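The renamed registry keys each scoring-fn class by its scoring_function_def.identifier, which is how score() dispatches on the names passed in scoring_functions. A self-contained toy version of that dispatch pattern (stand-in classes, not the meta-reference ones):

# Toy sketch of the identifier-keyed registry dispatch used by the scoring impl.
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ToyScoringFnDef:
    identifier: str


class ToyEqualityFn:
    scoring_function_def = ToyScoringFnDef(identifier="equality")

    def score(self, rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        return [
            {"score": 1.0 if r["generated_answer"] == r["expected_answer"] else 0.0}
            for r in rows
        ]

    def aggregate(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        return {"accuracy": sum(r["score"] for r in results) / len(results)}


SUPPORTED_SCORING_FNS = [ToyEqualityFn]
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS}

rows = [{"generated_answer": "Paris", "expected_answer": "Paris"}]
fn = SCORER_REGISTRY["equality"]()   # look up by name, then instantiate
print(fn.aggregate(fn.score(rows)))  # {'accuracy': 1.0}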
@@ -9,15 +9,15 @@ from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
 
 
-class BaseScorer(ABC):
+class BaseScoringFn(ABC):
     """
-    Base interface class for all meta-reference scorers.
-    Each scorer needs to implement the following methods:
+    Base interface class for all meta-reference scoring_fns.
+    Each scoring_fn needs to implement the following methods:
     - score_row(self, row)
-    - aggregate(self, scorer_results)
+    - aggregate(self, scoring_fn_results)
     """
 
-    scoring_function_def: ScoringFunctionDef
+    scoring_function_def: ScoringFnDef
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Any, Dict, List
+
+from llama_stack.apis.scoring import ScoringResultRow
+
+
+def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    num_correct = sum(result["score"] for result in scoring_results)
+    avg_score = num_correct / len(scoring_results)
+
+    return {
+        "accuracy": avg_score,
+        "num_correct": num_correct,
+        "num_total": len(scoring_results),
+    }
@@ -4,20 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
-    BaseScorer,
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
 )
 from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_accuracy,
+)
 
 
-class EqualityScorer(BaseScorer):
+class EqualityScoringFn(BaseScoringFn):
     """
-    A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
+    A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
     """
 
-    scoring_function_def = ScoringFunctionDef(
+    scoring_function_def = ScoringFnDef(
         identifier="equality",
         description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
         parameters=[],
@@ -38,12 +41,4 @@ class EqualityScorer(BaseScorer):
         }
 
     def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
-        assert len(scoring_results) > 0, "Empty scoring results provided."
-        num_correct = sum(result["score"] for result in scoring_results)
-        avg_score = num_correct / len(scoring_results)
-
-        return {
-            "accuracy": avg_score,
-            "num_correct": num_correct,
-            "num_total": len(scoring_results),
-        }
+        return aggregate_accuracy(scoring_results)
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
+)
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_accuracy,
+)
+
+
+class SubsetOfScoringFn(BaseScoringFn):
+    """
+    A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
+    """
+
+    scoring_function_def = ScoringFnDef(
+        identifier="subset_of",
+        description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
+        parameters=[],
+        return_type=NumberType(),
+    )
+
+    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+        assert "expected_answer" in input_row, "Expected answer not found in input row."
+        assert (
+            "generated_answer" in input_row
+        ), "Generated answer not found in input row."
+
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        score = 1.0 if expected_answer in generated_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+        return aggregate_accuracy(scoring_results)
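Concretely, subset_of scores a row 1.0 whenever the expected answer appears as a substring of the generated answer, and aggregate_accuracy reduces the per-row scores to accuracy/num_correct/num_total. A plain-Python restatement of that behaviour on toy rows (not the provider classes):

# Plain-Python restatement of the subset_of scoring + accuracy aggregation above.
rows = [
    {"expected_answer": "Paris", "generated_answer": "The capital of France is Paris."},
    {"expected_answer": "Vatican City", "generated_answer": "China"},
]

score_rows = [
    {"score": 1.0 if r["expected_answer"] in r["generated_answer"] else 0.0}
    for r in rows
]
num_correct = sum(r["score"] for r in score_rows)
aggregated = {
    "accuracy": num_correct / len(score_rows),
    "num_correct": num_correct,
    "num_total": len(score_rows),
}
assert aggregated == {"accuracy": 0.5, "num_correct": 1.0, "num_total": 2}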
llama_stack/providers/registry/eval.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.eval,
+            provider_type="meta-reference",
+            pip_packages=[],
+            module="llama_stack.providers.impls.meta_reference.eval",
+            config_class="llama_stack.providers.impls.meta_reference.eval.MetaReferenceEvalConfig",
+            api_dependencies=[
+                Api.datasetio,
+                Api.datasets,
+                Api.scoring,
+                Api.inference,
+            ],
+        ),
+    ]
@@ -1,6 +1,6 @@
-input_query,generated_answer,expected_answer
-What is the capital of France?,London,Paris
-Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
-What is the largest planet in our solar system?,Jupiter,Jupiter
-What is the smallest country in the world?,China,Vatican City
-What is the currency of Japan?,Yen,Yen
+input_query,generated_answer,expected_answer,chat_completion_input
+What is the capital of France?,London,Paris,"[{'role': 'user', 'content': 'What is the capital of France?'}]"
+Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg,"[{'role': 'user', 'content': 'Who is the CEO of Meta?'}]"
+What is the largest planet in our solar system?,Jupiter,Jupiter,"[{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]"
+What is the smallest country in the world?,China,Vatican City,"[{'role': 'user', 'content': 'What is the smallest country in the world?'}]"
+What is the currency of Japan?,Yen,Yen,"[{'role': 'user', 'content': 'What is the currency of Japan?'}]"
@@ -61,20 +61,31 @@ def data_url_from_file(file_path: str) -> str:
     return data_url
 
 
-async def register_dataset(datasets_impl: Datasets):
+async def register_dataset(
+    datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
+):
     test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
     test_url = data_url_from_file(str(test_file))
 
+    if for_generation:
+        dataset_schema = {
+            "expected_answer": StringType(),
+            "chat_completion_input": ChatCompletionInputType(),
+        }
+    else:
+        dataset_schema = {
+            "expected_answer": StringType(),
+            "input_query": StringType(),
+            "generated_answer": StringType(),
+        }
+
     dataset = DatasetDefWithProvider(
-        identifier="test_dataset",
+        identifier=dataset_id,
         provider_id=os.environ["PROVIDER_ID"],
         url=URL(
             uri=test_url,
         ),
-        dataset_schema={
-            "generated_answer": StringType(),
-            "expected_answer": StringType(),
-            "input_query": StringType(),
-        },
+        dataset_schema=dataset_schema,
     )
     await datasets_impl.register_dataset(dataset)
 
llama_stack/providers/tests/eval/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -0,0 +1,18 @@
+providers:
+  datasetio:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
+  scoring:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
+  eval:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
+  inference:
+    - provider_id: test-tgi
+      provider_type: remote::tgi
+      config:
+        url: http://127.0.0.1:5009
llama_stack/providers/tests/eval/test_eval.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.apis.eval.eval import ModelCandidate
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+from llama_models.llama3.api import SamplingParams
+
+from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
+from llama_stack.providers.tests.resolver import resolve_impls_for_test
+
+# How to run this test:
+#
+# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
+#    since it depends on the provider you are testing. On top of that you need
+#    `pytest` and `pytest-asyncio` installed.
+#
+# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
+#
+# 3. Run:
+#
+# ```bash
+# PROVIDER_ID=<your_provider> \
+#  PROVIDER_CONFIG=provider_config.yaml \
+#  pytest -s llama_stack/providers/tests/eval/test_eval.py \
+#  --tb=short --disable-warnings
+# ```
+
+
+@pytest_asyncio.fixture(scope="session")
+async def eval_settings():
+    impls = await resolve_impls_for_test(
+        Api.eval, deps=[Api.datasetio, Api.scoring, Api.inference]
+    )
+    return {
+        "eval_impl": impls[Api.eval],
+        "scoring_impl": impls[Api.scoring],
+        "datasets_impl": impls[Api.datasets],
+    }
+
+
+@pytest.mark.asyncio
+async def test_eval(eval_settings):
+    datasets_impl = eval_settings["datasets_impl"]
+    await register_dataset(
+        datasets_impl,
+        for_generation=True,
+        dataset_id="test_dataset_for_eval",
+    )
+
+    response = await datasets_impl.list_datasets()
+    assert len(response) == 1
+
+    eval_impl = eval_settings["eval_impl"]
+    response = await eval_impl.evaluate_batch(
+        dataset_id=response[0].identifier,
+        candidate=ModelCandidate(
+            model="Llama3.2-1B-Instruct",
+            sampling_params=SamplingParams(),
+        ),
+        scoring_functions=["subset_of"],
+    )
+    assert response.job_id == "0"
+    job_status = await eval_impl.job_status(response.job_id)
+
+    assert job_status and job_status.value == "completed"
+
+    eval_response = await eval_impl.job_result(response.job_id)
+
+    assert eval_response is not None
+    assert len(eval_response.generations) == 5
+    assert "subset_of" in eval_response.scores
@@ -14,7 +14,12 @@ apis:
 - datasets
 - datasetio
 - scoring
+- eval
 providers:
+  eval:
+    - provider_id: meta0
+      provider_type: meta-reference
+      config: {}
   scoring:
     - provider_id: meta0
       provider_type: meta-reference