[Evals API][3/n] scoring_functions / scoring meta-reference implementations (#296)

* wip

* dataset validation

* test_scoring

* cleanup

* clean up test

* comments

* error checking

* dataset client

* test client

* datasetio client

* clean up

* basic scoring function works

* scorer wip

* equality scorer

* score batch impl

* score batch

* update scoring test

* refactor

* validate scorer input

* address comments

* add all rows scores to ScoringResult

* bugfix

* scoring function def rename
Xi Yan 2024-10-24 14:52:30 -07:00 committed by GitHub
parent e70420a06e
commit cb84034567
28 changed files with 904 additions and 51 deletions
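
Before the per-file diffs, a minimal end-to-end sketch of the surface this commit adds. It assumes a locally running stack with the meta-reference scoring provider configured, and that the scoring client added below lives at `llama_stack.apis.scoring.client` (by analogy with the datasets/datasetio clients); the host, port, and rows are illustrative.

```python
import asyncio

# Assumed import path, mirroring the datasets/datasetio clients in this PR.
from llama_stack.apis.scoring.client import ScoringClient


async def demo() -> None:
    client = ScoringClient("http://localhost:5000")
    # Score two in-memory rows with the built-in "equality" scoring function.
    response = await client.score(
        input_rows=[
            {"input_query": "2+2?", "generated_answer": "4", "expected_answer": "4"},
            {"input_query": "3+3?", "generated_answer": "5", "expected_answer": "6"},
        ],
        scoring_functions=["equality"],
    )
    result = response.results["equality"]
    print(result.score_rows)          # per-row scores: [{'score': 1.0}, {'score': 0.0}]
    print(result.aggregated_results)  # {'accuracy': 0.5, 'num_correct': 1.0, 'num_total': 2}


asyncio.run(demo())
```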

View file

@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from pathlib import Path
from typing import Optional
import fire
import httpx
from termcolor import cprint
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasets.client import DatasetsClient
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
class DatasetIOClient(DatasetIO):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def get_rows_paginated(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/datasetio/get_rows_paginated",
params={
"dataset_id": dataset_id,
"rows_in_page": rows_in_page,
"page_token": page_token,
"filter_condition": filter_condition,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return PaginatedRowsResult(**response.json())
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")
# register dataset
test_file = (
Path(os.path.abspath(__file__)).parent.parent.parent
/ "providers/tests/datasetio/test_dataset.csv"
)
test_url = data_url_from_file(str(test_file))
response = await client.register_dataset(
DatasetDefWithProvider(
identifier="test-dataset",
provider_id="meta0",
url=URL(
uri=test_url,
),
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
)
)
# list datasets
list_dataset = await client.list_datasets()
cprint(list_dataset, "blue")
# datasetio client to get the rows
datasetio_client = DatasetIOClient(f"http://{host}:{port}")
response = await datasetio_client.get_rows_paginated(
dataset_id="test-dataset",
rows_in_page=4,
page_token=None,
filter_condition=None,
)
cprint(f"Returned {len(response.rows)} rows \n {response}", "green")
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)

View file

@@ -29,7 +29,7 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/dataio/get_rows_paginated")
@webmethod(route="/datasetio/get_rows_paginated", method="GET")
async def get_rows_paginated(
self,
dataset_id: str,

View file

@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import os
from pathlib import Path
from typing import Optional
import fire
import httpx
from termcolor import cprint
from .datasets import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
class DatasetsClient(Datasets):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_dataset(
self,
dataset_def: DatasetDefWithProvider,
) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/datasets/register",
json={
"dataset_def": json.loads(dataset_def.json()),
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
return
async def get_dataset(
self,
dataset_identifier: str,
) -> Optional[DatasetDefWithProvider]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/datasets/get",
params={
"dataset_identifier": dataset_identifier,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return DatasetDefWithProvider(**response.json())
async def list_datasets(self) -> List[DatasetDefWithProvider]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/datasets/list",
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return [DatasetDefWithProvider(**x) for x in response.json()]
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")
# register dataset
test_file = (
Path(os.path.abspath(__file__)).parent.parent.parent
/ "providers/tests/datasetio/test_dataset.csv"
)
test_url = data_url_from_file(str(test_file))
response = await client.register_dataset(
DatasetDefWithProvider(
identifier="test-dataset",
provider_id="meta0",
url=URL(
uri=test_url,
),
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
)
)
# list datasets
list_dataset = await client.list_datasets()
cprint(list_dataset, "blue")
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)

View file

@@ -20,7 +20,7 @@ class DatasetDef(BaseModel):
identifier: str = Field(
description="A unique name for the dataset",
)
columns_schema: Dict[str, ParamType] = Field(
dataset_schema: Dict[str, ParamType] = Field(
description="The schema definition for this dataset",
)
url: URL

View file

@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from pathlib import Path
import fire
import httpx
from termcolor import cprint
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio.client import DatasetIOClient
from llama_stack.apis.datasets.client import DatasetsClient
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
class ScoringClient(Scoring):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
) -> ScoreBatchResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/scoring/score_batch",
json={
"dataset_id": dataset_id,
"scoring_functions": scoring_functions,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return ScoreBatchResponse(**response.json())
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/scoring/score",
json={
"input_rows": input_rows,
"scoring_functions": scoring_functions,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return ScoreResponse(**response.json())
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")
# register dataset
test_file = (
Path(os.path.abspath(__file__)).parent.parent.parent
/ "providers/tests/datasetio/test_dataset.csv"
)
test_url = data_url_from_file(str(test_file))
response = await client.register_dataset(
DatasetDefWithProvider(
identifier="test-dataset",
provider_id="meta0",
url=URL(
uri=test_url,
),
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
)
)
# list datasets
list_dataset = await client.list_datasets()
cprint(list_dataset, "blue")
# datasetio client to get the rows
datasetio_client = DatasetIOClient(f"http://{host}:{port}")
response = await datasetio_client.get_rows_paginated(
dataset_id="test-dataset",
rows_in_page=4,
page_token=None,
filter_condition=None,
)
cprint(f"Returned {len(response.rows)} rows \n {response}", "green")
# scoring client to score the rows
scoring_client = ScoringClient(f"http://{host}:{port}")
response = await scoring_client.score(
input_rows=response.rows,
scoring_functions=["equality"],
)
cprint(f"score response={response}", "blue")
# test scoring batch using datasetio api
scoring_client = ScoringClient(f"http://{host}:{port}")
response = await scoring_client.score_batch(
dataset_id="test-dataset",
scoring_functions=["equality"],
)
cprint(f"score_batch response={response}", "cyan")
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)

View file

@@ -13,18 +13,27 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
ScoringResult = Dict[str, Any]
# mapping of metric to value
ScoringResultRow = Dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
score_rows: List[ScoringResultRow]
# aggregated metrics to value
aggregated_results: Dict[str, Any]
@json_schema_type
class ScoreBatchResponse(BaseModel):
dataset_id: str
dataset_id: Optional[str] = None
results: Dict[str, ScoringResult]
@json_schema_type
class ScoreResponse(BaseModel):
# each key in the dict is a scoring function name
results: List[Dict[str, ScoringResult]]
results: Dict[str, ScoringResult]
class ScoringFunctionStore(Protocol):
@@ -37,7 +46,10 @@ class Scoring(Protocol):
@webmethod(route="/scoring/score_batch")
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score")
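
To make the new models concrete, a `ScoreBatchResponse` for a single scoring function would serialize roughly as follows (values illustrative; `dataset_id` stays `None` until `save_results_dataset` is implemented):

```python
# Illustrative payload implied by ScoreBatchResponse / ScoringResult above.
payload = {
    "dataset_id": None,
    "results": {
        "equality": {
            "score_rows": [{"score": 1.0}, {"score": 0.0}],
            "aggregated_results": {"accuracy": 0.5, "num_correct": 1.0, "num_total": 2},
        }
    },
}
```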

View file

@@ -4,20 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
@@ -33,45 +23,37 @@ class Parameter(BaseModel):
# with standard metrics so they can be rolled up?
class LLMAsJudgeContext(BaseModel):
judge_model: str
prompt_template: Optional[str] = None
@json_schema_type
class CommonDef(BaseModel):
name: str
class ScoringFunctionDef(BaseModel):
identifier: str
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
# Hack: same with memory_banks for union defs
provider_id: str = ""
@json_schema_type
class DeterministicFunctionDef(CommonDef):
type: Literal["deterministic"] = "deterministic"
parameters: List[Parameter] = Field(
description="List of parameters for the deterministic function",
default_factory=list,
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
context: Optional[LLMAsJudgeContext] = None
# We can optionally add information here to support packaging of code, etc.
@json_schema_type
class LLMJudgeFunctionDef(CommonDef):
type: Literal["judge"] = "judge"
model: str = Field(
description="The LLM model to use for the judge function",
class ScoringFunctionDefWithProvider(ScoringFunctionDef):
provider_id: str = Field(
description="ID of the provider which serves this scoring function",
)
ScoringFunctionDef = Annotated[
Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type")
]
ScoringFunctionDefWithProvider = ScoringFunctionDef
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring_functions/list", method="GET")
@@ -84,5 +66,5 @@ class ScoringFunctions(Protocol):
@webmethod(route="/scoring_functions/register", method="POST")
async def register_scoring_function(
self, function: ScoringFunctionDefWithProvider
self, function_def: ScoringFunctionDefWithProvider
) -> None: ...

View file

@@ -15,10 +15,12 @@ from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@@ -32,6 +34,7 @@ RoutableObject = Union[
ShieldDef,
MemoryBankDef,
DatasetDef,
ScoringFunctionDef,
]
RoutableObjectWithProvider = Union[
@@ -39,6 +42,7 @@ RoutableObjectWithProvider = Union[
ShieldDefWithProvider,
MemoryBankDefWithProvider,
DatasetDefWithProvider,
ScoringFunctionDefWithProvider,
]
RoutedProtocol = Union[
@@ -46,6 +50,7 @@ RoutedProtocol = Union[
Safety,
Memory,
DatasetIO,
Scoring,
]

View file

@@ -39,6 +39,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
routing_table_api=Api.datasets,
router_api=Api.datasetio,
),
AutoRoutedApiInfo(
routing_table_api=Api.scoring_functions,
router_api=Api.scoring,
),
]

View file

@@ -20,6 +20,8 @@ from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import (
@@ -42,6 +44,8 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.telemetry: Telemetry,
Api.datasets: Datasets,
Api.datasetio: DatasetIO,
Api.scoring_functions: ScoringFunctions,
Api.scoring: Scoring,
}

View file

@@ -11,6 +11,7 @@ from .routing_tables import (
DatasetsRoutingTable,
MemoryBanksRoutingTable,
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
ShieldsRoutingTable,
)
@@ -25,7 +26,9 @@ async def get_routing_table_impl(
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
@@ -35,13 +38,20 @@ async def get_routing_table_impl(
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import DatasetIORouter, InferenceRouter, MemoryRouter, SafetyRouter
from .routers import (
DatasetIORouter,
InferenceRouter,
MemoryRouter,
SafetyRouter,
ScoringRouter,
)
api_to_routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
"scoring": ScoringRouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")

View file

@@ -13,6 +13,7 @@ from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
class MemoryRouter(Memory):
@@ -192,3 +193,56 @@ class DatasetIORouter(DatasetIO):
page_token=page_token,
filter_condition=filter_condition,
)
class ScoringRouter(Scoring):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
res = {}
for fn_identifier in scoring_functions:
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score_batch(
dataset_id=dataset_id,
scoring_functions=[fn_identifier],
)
res.update(score_response.results)
if save_results_dataset:
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res,
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions:
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score(
input_rows=input_rows,
scoring_functions=[fn_identifier],
)
res.update(score_response.results)
return ScoreResponse(results=res)

View file

@@ -30,6 +30,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
await p.register_memory_bank(obj)
elif api == Api.datasetio:
await p.register_dataset(obj)
elif api == Api.scoring:
await p.register_scoring_function(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@@ -93,7 +95,15 @@ class CommonRoutingTableImpl(RoutingTable):
for d in datasets:
d.provider_id = pid
add_objects(datasets)
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
add_objects(
[
ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid)
for s in scoring_functions
]
)
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
@@ -109,6 +119,10 @@ class CommonRoutingTableImpl(RoutingTable):
return ("Safety", "shield")
elif isinstance(self, MemoryBanksRoutingTable):
return ("Memory", "memory_bank")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
else:
raise ValueError("Unknown routing table type")
@@ -218,7 +232,25 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def get_dataset(
self, dataset_identifier: str
) -> Optional[DatasetDefWithProvider]:
return self.get_object_by_identifier(identifier)
return self.get_object_by_identifier(dataset_identifier)
async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
await self.register_object(dataset_def)
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
async def get_scoring_function(
self, name: str
) -> Optional[ScoringFunctionDefWithProvider]:
return self.get_object_by_identifier(name)
async def register_scoring_function(
self, function_def: ScoringFunctionDefWithProvider
) -> None:
await self.register_object(function_def)

View file

@@ -11,10 +11,9 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetDef
from llama_stack.apis.memory_banks import MemoryBankDef
from llama_stack.apis.models import ModelDef
from llama_stack.apis.scoring_functions import ScoringFunctionDef
from llama_stack.apis.shields import ShieldDef
@@ -25,6 +24,7 @@ class Api(Enum):
agents = "agents"
memory = "memory"
datasetio = "datasetio"
scoring = "scoring"
telemetry = "telemetry"
@@ -32,6 +32,7 @@ class Api(Enum):
shields = "shields"
memory_banks = "memory_banks"
datasets = "datasets"
scoring_functions = "scoring_functions"
# built-in API
inspect = "inspect"
@@ -61,6 +62,14 @@ class DatasetsProtocolPrivate(Protocol):
async def register_datasets(self, dataset_def: DatasetDef) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFunctionDef]: ...
async def register_scoring_function(
self, function_def: ScoringFunctionDef
) -> None: ...
@json_schema_type
class ProviderSpec(BaseModel):
api: Api

View file

@@ -3,17 +3,20 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
from typing import List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass
from urllib.parse import unquote
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import MetaReferenceDatasetIOConfig
@@ -52,11 +55,20 @@ class PandasDataframeDataset(BaseDataset):
return len(self.df)
def __getitem__(self, idx):
assert self.df is not None, "Dataset not loaded. Please call .load() first"
if isinstance(idx, slice):
return self.df.iloc[idx].to_dict(orient="records")
else:
return self.df.iloc[idx].to_dict()
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
# note that we will drop any columns in dataset that are not in the schema
df = df[self.dataset_def.dataset_schema.keys()]
# check all columns in dataset schema are present
assert len(df.columns) == len(self.dataset_def.dataset_schema)
# TODO: type checking against column types in dataset schema
return df
def load(self) -> None:
if self.df is not None:
return
@@ -87,7 +99,7 @@ class PandasDataframeDataset(BaseDataset):
else:
raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
self.df = df
self.df = self._validate_dataset_schema(df)
class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -123,7 +135,10 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load()
if page_token is None:
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
if page_token is None or len(page_token) == 0:
next_page_token = 0
else:
next_page_token = int(page_token)
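
With this scheme the page token is simply the next row index as a numeric string, so a caller can drain a dataset page by page. A minimal sketch, assuming an empty page signals the end and using an illustrative dataset id and page size:

```python
# Hedged sketch: iterate all rows of a dataset using numeric-string page tokens.
async def iterate_all_rows(datasetio_impl) -> list:
    rows, token = [], None
    while True:
        page = await datasetio_impl.get_rows_paginated(
            dataset_id="test_dataset",
            rows_in_page=2,
            page_token=token,
        )
        if not page.rows:
            break
        rows.extend(page.rows)
        token = page.next_page_token
    return rows
```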

View file

@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceScoringConfig
async def get_provider_impl(
config: MetaReferenceScoringConfig,
deps: Dict[Api, ProviderSpec],
):
from .scoring import MetaReferenceScoringImpl
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
await impl.initialize()
return impl

View file

@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring import * # noqa: F401, F403
class MetaReferenceScoringConfig(BaseModel): ...

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
class BaseScorer(ABC):
"""
Base interface class for all meta-reference scorers.
Each scorer needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scorer_results)
"""
scoring_function_def: ScoringFunctionDef
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __str__(self) -> str:
return self.__class__.__name__
@abstractmethod
def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
raise NotImplementedError()
def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
return [self.score_row(input_row) for input_row in input_rows]
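
To illustrate the contract, here is a hypothetical subclass (not part of this commit) that scores substring containment instead of exact equality:

```python
from llama_stack.apis.common.type_system import NumberType

# Hypothetical example only: shows the score_row/aggregate contract.
class ContainsScorer(BaseScorer):
    scoring_function_def = ScoringFunctionDef(
        identifier="contains",
        description="Returns 1.0 if expected_answer is a substring of generated_answer.",
        parameters=[],
        return_type=NumberType(),
    )

    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
        hit = input_row["expected_answer"] in input_row["generated_answer"]
        return {"score": 1.0 if hit else 0.0}

    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
        num_correct = sum(r["score"] for r in scoring_results)
        return {
            "accuracy": num_correct / len(scoring_results),
            "num_correct": num_correct,
            "num_total": len(scoring_results),
        }
```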

View file

@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
BaseScorer,
)
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
class EqualityScorer(BaseScorer):
"""
A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
"""
scoring_function_def = ScoringFunctionDef(
identifier="equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
)
def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert (
"generated_answer" in input_row
), "Generated answer not found in input row."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
score = 1.0 if expected_answer == generated_answer else 0.0
return {
"score": score,
}
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
assert len(scoring_results) > 0, "Empty scoring results provided."
num_correct = sum(result["score"] for result in scoring_results)
avg_score = num_correct / len(scoring_results)
return {
"accuracy": avg_score,
"num_correct": num_correct,
"num_total": len(scoring_results),
}
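
Driven directly (made-up rows; `score()` is the batch helper inherited from `BaseScorer`):

```python
# Hedged sketch: exercise EqualityScorer without going through the API.
scorer = EqualityScorer()
rows = [
    {"input_query": "capital of France?", "generated_answer": "Paris", "expected_answer": "Paris"},
    {"input_query": "capital of Japan?", "generated_answer": "Kyoto", "expected_answer": "Tokyo"},
]
score_rows = scorer.score(rows)      # [{'score': 1.0}, {'score': 0.0}]
print(scorer.aggregate(score_rows))  # {'accuracy': 0.5, 'num_correct': 1.0, 'num_total': 2}
```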

View file

@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
EqualityScorer,
)
from .config import MetaReferenceScoringConfig
SUPPORTED_SCORERS = [
EqualityScorer,
]
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(
self,
config: MetaReferenceScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
return [x.scoring_function_def for x in SUPPORTED_SCORERS]
async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None:
raise NotImplementedError(
"Dynamically registering scoring functions is not supported"
)
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.score(
input_rows=all_rows.rows, scoring_functions=scoring_functions
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res.results,
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in SCORER_REGISTRY:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scorer = SCORER_REGISTRY[scoring_fn_id]()
score_results = scorer.score(input_rows)
agg_results = scorer.aggregate(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
)
return ScoreResponse(
results=res,
)

View file

@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.scoring,
provider_type="meta-reference",
pip_packages=[],
module="llama_stack.providers.impls.meta_reference.scoring",
config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
]

View file

@@ -0,0 +1,6 @@
input_query,generated_answer,expected_answer
What is the capital of France?,London,Paris
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
What is the largest planet in our solar system?,Jupiter,Jupiter
What is the smallest country in the world?,China,Vatican City
What is the currency of Japan?,Yen,Yen
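(For reference, the equality scorer marks 3 of these 5 rows correct, since the capital-of-France and smallest-country rows mismatch, giving an accuracy of 0.6.)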

View file

@@ -8,8 +8,13 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import base64
import mimetypes
from pathlib import Path
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
@@ -41,14 +46,35 @@ async def datasetio_settings():
}
def data_url_from_file(file_path: str) -> str:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url
async def register_dataset(datasets_impl: Datasets):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
dataset = DatasetDefWithProvider(
identifier="test_dataset",
provider_id=os.environ["PROVIDER_ID"],
url=URL(
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
uri=test_url,
),
columns_schema={},
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
)
await datasets_impl.register_dataset(dataset)
@@ -100,10 +126,10 @@ async def test_get_rows_paginated(datasetio_settings):
# iterate over all rows
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=10,
rows_in_page=2,
page_token=response.next_page_token,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 10
assert response.next_page_token == "13"
assert len(response.rows) == 2
assert response.next_page_token == "5"

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,9 @@
providers:
datasetio:
- provider_id: test-meta
provider_type: meta-reference
config: {}
scoring:
- provider_id: test-meta
provider_type: meta-reference
config: {}

View file

@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
#
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \
# --tb=short --disable-warnings
# ```
@pytest_asyncio.fixture(scope="session")
async def scoring_settings():
impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio])
return {
"scoring_impl": impls[Api.scoring],
"scoring_functions_impl": impls[Api.scoring_functions],
"datasets_impl": impls[Api.datasets],
}
@pytest.mark.asyncio
async def test_scoring_functions_list(scoring_settings):
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
scoring_functions = await scoring_functions_impl.list_scoring_functions()
assert isinstance(scoring_functions, list)
assert len(scoring_functions) > 0
function_ids = [f.identifier for f in scoring_functions]
assert "equality" in function_ids
@pytest.mark.asyncio
async def test_scoring_score(scoring_settings):
scoring_impl = scoring_settings["scoring_impl"]
datasets_impl = scoring_settings["datasets_impl"]
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert len(response) == 1
response = await scoring_impl.score_batch(
dataset_id=response[0].identifier,
scoring_functions=["equality"],
)
assert len(response.results) == 1
assert "equality" in response.results

View file

@@ -13,7 +13,12 @@ apis:
- inference
- datasets
- datasetio
- scoring
providers:
scoring:
- provider_id: meta0
provider_type: meta-reference
config: {}
datasetio:
- provider_id: meta0
provider_type: meta-reference