diff --git a/llama_stack/providers/inline/scoring/braintrust/__init__.py b/llama_stack/providers/inline/scoring/braintrust/__init__.py index f442a6c3b..dc4ea4951 100644 --- a/llama_stack/providers/inline/scoring/braintrust/__init__.py +++ b/llama_stack/providers/inline/scoring/braintrust/__init__.py @@ -6,10 +6,15 @@ from typing import Dict from llama_stack.distribution.datatypes import Api, ProviderSpec +from pydantic import BaseModel from .config import BraintrustScoringConfig +class BraintrustProviderDataValidator(BaseModel): + openai_api_key: str + + async def get_provider_impl( config: BraintrustScoringConfig, deps: Dict[Api, ProviderSpec], diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index 00817bb33..cf6e22a29 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -12,9 +12,11 @@ from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 -# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn +import os + from autoevals.llm import Factuality from autoevals.ragas import AnswerCorrectness +from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average @@ -24,7 +26,9 @@ from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def from .scoring_fn.fn_defs.factuality import factuality_fn_def -class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): +class BraintrustScoringImpl( + Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData +): def __init__( self, config: BraintrustScoringConfig, @@ -79,12 +83,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." ) + async def set_api_key(self) -> None: + # api key is in the request headers + if self.config.openai_api_key is None: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.openai_api_key: + raise ValueError( + 'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": }' + ) + self.config.openai_api_key = provider_data.openai_api_key + + os.environ["OPENAI_API_KEY"] = self.config.openai_api_key + async def score_batch( self, dataset_id: str, scoring_functions: List[str], save_results_dataset: bool = False, ) -> ScoreBatchResponse: + await self.set_api_key() await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) all_rows = await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, @@ -105,6 +122,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def score_row( self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None ) -> ScoringResultRow: + await self.set_api_key() assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None" expected_answer = input_row["expected_answer"] generated_answer = input_row["generated_answer"] @@ -118,6 +136,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def score( self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] ) -> ScoreResponse: + await self.set_api_key() res = {} for scoring_fn_id in scoring_functions: if scoring_fn_id not in self.supported_fn_defs_registry: diff --git a/llama_stack/providers/inline/scoring/braintrust/config.py b/llama_stack/providers/inline/scoring/braintrust/config.py index fef6df5c8..fae0b17eb 100644 --- a/llama_stack/providers/inline/scoring/braintrust/config.py +++ b/llama_stack/providers/inline/scoring/braintrust/config.py @@ -6,4 +6,8 @@ from llama_stack.apis.scoring import * # noqa: F401, F403 -class BraintrustScoringConfig(BaseModel): ... +class BraintrustScoringConfig(BaseModel): + openai_api_key: Optional[str] = Field( + default=None, + description="The OpenAI API Key", + ) diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py index 2da9797bc..f31ff44d7 100644 --- a/llama_stack/providers/registry/scoring.py +++ b/llama_stack/providers/registry/scoring.py @@ -44,5 +44,6 @@ def available_providers() -> List[ProviderSpec]: Api.datasetio, Api.datasets, ], + provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator", ), ]