mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
braintrust skeleton
This commit is contained in:
parent
caf6a266e0
commit
d3d2243dfb
5 changed files with 163 additions and 0 deletions
21
llama_stack/providers/impls/braintrust/scoring/__init__.py
Normal file
21
llama_stack/providers/impls/braintrust/scoring/__init__.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
|
from .config import BraintrustScoringConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(
|
||||||
|
config: BraintrustScoringConfig,
|
||||||
|
deps: Dict[Api, ProviderSpec],
|
||||||
|
):
|
||||||
|
from .braintrust import BraintrustScoringImpl
|
||||||
|
|
||||||
|
impl = BraintrustScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
119
llama_stack/providers/impls/braintrust/scoring/braintrust.py
Normal file
119
llama_stack/providers/impls/braintrust/scoring/braintrust.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_stack.apis.scoring import * # noqa: F403
|
||||||
|
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||||
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
|
from llama_stack.apis.datasetio import * # noqa: F403
|
||||||
|
from llama_stack.apis.datasets import * # noqa: F403
|
||||||
|
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||||
|
|
||||||
|
from .config import BraintrustScoringConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: BraintrustScoringConfig,
|
||||||
|
datasetio_api: DatasetIO,
|
||||||
|
datasets_api: Datasets,
|
||||||
|
) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.datasetio_api = datasetio_api
|
||||||
|
self.datasets_api = datasets_api
|
||||||
|
self.scoring_fn_id_impls = {}
|
||||||
|
|
||||||
|
async def initialize(self) -> None: ...
|
||||||
|
|
||||||
|
# for x in FIXED_FNS:
|
||||||
|
# impl = x()
|
||||||
|
# await impl.initialize()
|
||||||
|
# for fn_defs in impl.get_supported_scoring_fn_defs():
|
||||||
|
# self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
||||||
|
# for x in LLM_JUDGE_FNS:
|
||||||
|
# impl = x(inference_api=self.inference_api)
|
||||||
|
# await impl.initialize()
|
||||||
|
# for fn_defs in impl.get_supported_scoring_fn_defs():
|
||||||
|
# self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
||||||
|
# self.llm_as_judge_fn = impl
|
||||||
|
|
||||||
|
async def shutdown(self) -> None: ...
|
||||||
|
|
||||||
|
async def list_scoring_functions(self) -> List[ScoringFnDef]:
|
||||||
|
return []
|
||||||
|
# return [
|
||||||
|
# fn_defs
|
||||||
|
# for impl in self.scoring_fn_id_impls.values()
|
||||||
|
# for fn_defs in impl.get_supported_scoring_fn_defs()
|
||||||
|
# ]
|
||||||
|
|
||||||
|
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
|
||||||
|
# self.llm_as_judge_fn.register_scoring_fn_def(function_def)
|
||||||
|
# self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
|
||||||
|
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
|
||||||
|
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
|
||||||
|
)
|
||||||
|
|
||||||
|
for required_column in ["generated_answer", "expected_answer", "input_query"]:
|
||||||
|
if required_column not in dataset_def.dataset_schema:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset {dataset_id} does not have a '{required_column}' column."
|
||||||
|
)
|
||||||
|
if dataset_def.dataset_schema[required_column].type != "string":
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def score_batch(
|
||||||
|
self,
|
||||||
|
dataset_id: str,
|
||||||
|
scoring_functions: List[str],
|
||||||
|
save_results_dataset: bool = False,
|
||||||
|
) -> ScoreBatchResponse:
|
||||||
|
print("score_batch")
|
||||||
|
# await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||||
|
# all_rows = await self.datasetio_api.get_rows_paginated(
|
||||||
|
# dataset_id=dataset_id,
|
||||||
|
# rows_in_page=-1,
|
||||||
|
# )
|
||||||
|
# res = await self.score(
|
||||||
|
# input_rows=all_rows.rows, scoring_functions=scoring_functions
|
||||||
|
# )
|
||||||
|
# if save_results_dataset:
|
||||||
|
# # TODO: persist and register dataset on to server for reading
|
||||||
|
# # self.datasets_api.register_dataset()
|
||||||
|
# raise NotImplementedError("Save results dataset not implemented yet")
|
||||||
|
|
||||||
|
return ScoreBatchResponse(
|
||||||
|
results=res.results,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||||
|
) -> ScoreResponse:
|
||||||
|
res = {}
|
||||||
|
print("score")
|
||||||
|
# for scoring_fn_id in scoring_functions:
|
||||||
|
# if scoring_fn_id not in self.scoring_fn_id_impls:
|
||||||
|
# raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||||
|
# scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
|
||||||
|
# score_results = await scoring_fn.score(input_rows, scoring_fn_id)
|
||||||
|
# agg_results = await scoring_fn.aggregate(score_results)
|
||||||
|
# res[scoring_fn_id] = ScoringResult(
|
||||||
|
# score_rows=score_results,
|
||||||
|
# aggregated_results=agg_results,
|
||||||
|
# )
|
||||||
|
|
||||||
|
return ScoreResponse(
|
||||||
|
results=res,
|
||||||
|
)
|
9
llama_stack/providers/impls/braintrust/scoring/config.py
Normal file
9
llama_stack/providers/impls/braintrust/scoring/config.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||||
|
|
||||||
|
|
||||||
|
class BraintrustScoringConfig(BaseModel): ...
|
|
@ -23,4 +23,15 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
Api.inference,
|
Api.inference,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
InlineProviderSpec(
|
||||||
|
api=Api.scoring,
|
||||||
|
provider_type="braintrust",
|
||||||
|
pip_packages=[],
|
||||||
|
module="llama_stack.providers.impls.braintrust.scoring",
|
||||||
|
config_class="llama_stack.providers.impls.braintrust.scoring.BraintrustScoringConfig",
|
||||||
|
api_dependencies=[
|
||||||
|
Api.datasetio,
|
||||||
|
Api.datasets,
|
||||||
|
],
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -7,6 +7,9 @@ providers:
|
||||||
- provider_id: test-meta
|
- provider_id: test-meta
|
||||||
provider_type: meta-reference
|
provider_type: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
|
- provider_id: test-braintrust
|
||||||
|
provider_type: braintrust
|
||||||
|
config: {}
|
||||||
inference:
|
inference:
|
||||||
- provider_id: tgi0
|
- provider_id: tgi0
|
||||||
provider_type: remote::tgi
|
provider_type: remote::tgi
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue