This commit is contained in:
Xi Yan 2025-03-18 21:49:11 -07:00
parent 011fd59a29
commit 8a576d7d72
24 changed files with 297 additions and 2525 deletions

View file

@ -20,7 +20,5 @@ context_entity_recall_fn_def = ScoringFn(
provider_id="braintrust",
provider_resource_id="context-entity-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
)

View file

@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.eval,
provider_type="inline::meta-reference",
pip_packages=["tree_sitter"],
module="llama_stack.providers.inline.eval.meta_reference",
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.scoring,
Api.inference,
Api.agents,
],
),
]

View file

@ -1,49 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::basic",
pip_packages=[],
module="llama_stack.providers.inline.scoring.basic",
config_class="llama_stack.providers.inline.scoring.basic.BasicScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::llm-as-judge",
pip_packages=[],
module="llama_stack.providers.inline.scoring.llm_as_judge",
config_class="llama_stack.providers.inline.scoring.llm_as_judge.LlmAsJudgeScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.inference,
],
),
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::braintrust",
pip_packages=["autoevals", "openai"],
module="llama_stack.providers.inline.scoring.braintrust",
config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator",
),
]

View file

@ -75,29 +75,31 @@ VALID_SCHEMAS_FOR_EVAL = [
]
def get_valid_schemas(api_str: str):
if api_str == Api.scoring.value:
return VALID_SCHEMAS_FOR_SCORING
elif api_str == Api.eval.value:
return VALID_SCHEMAS_FOR_EVAL
else:
raise ValueError(f"Invalid API string: {api_str}")
# TODO(xiyan): add this back
# def get_valid_schemas(api_str: str):
# if api_str == Api.scoring.value:
# return VALID_SCHEMAS_FOR_SCORING
# elif api_str == Api.eval.value:
# return VALID_SCHEMAS_FOR_EVAL
# else:
# raise ValueError(f"Invalid API string: {api_str}")
def validate_dataset_schema(
dataset_schema: Dict[str, Any],
expected_schemas: List[Dict[str, Any]],
):
if dataset_schema not in expected_schemas:
raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
# def validate_dataset_schema(
# dataset_schema: Dict[str, Any],
# expected_schemas: List[Dict[str, Any]],
# ):
# if dataset_schema not in expected_schemas:
# raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
def validate_row_schema(
input_row: Dict[str, Any],
expected_schemas: List[Dict[str, Any]],
):
for schema in expected_schemas:
if all(key in input_row for key in schema):
return
# def validate_row_schema(
# input_row: Dict[str, Any],
# expected_schemas: List[Dict[str, Any]],
# ):
# for schema in expected_schemas:
# if all(key in input_row for key in schema):
# return
raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")
# raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")