forked from phoenix-oss/llama-stack-mirror
fixes tests & move braintrust api_keys to request headers (#535)
# What does this PR do? - braintrust scoring provider requires OPENAI_API_KEY env variable to be set - move this to be able to be set as request headers (e.g. like together / fireworks api keys) - fixes pytest with agents dependency ## Test Plan **E2E** ``` llama stack run ``` ```yaml scoring: - provider_id: braintrust-0 provider_type: inline::braintrust config: {} ``` **Client** ```python self.client = LlamaStackClient( base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:5000"), provider_data={ "openai_api_key": os.environ.get("OPENAI_API_KEY", ""), }, ) ``` - run `llama-stack-client eval run_scoring` **Unit Test** ``` pytest -v -s -m meta_reference_eval_together_inference eval/test_eval.py ``` ``` pytest -v -s -m braintrust_scoring_together_inference scoring/test_scoring.py --env OPENAI_API_KEY=$OPENAI_API_KEY ``` <img width="745" alt="image" src="https://github.com/user-attachments/assets/68f5cdda-f6c8-496d-8b4f-1b3dabeca9c2"> ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
d3956a1d22
commit
50cc165077
8 changed files with 72 additions and 8 deletions
|
@ -6,10 +6,15 @@
|
|||
from typing import Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import BraintrustScoringConfig
|
||||
|
||||
|
||||
class BraintrustProviderDataValidator(BaseModel):
|
||||
openai_api_key: str
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: BraintrustScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
|
|
|
@ -12,9 +12,11 @@ from llama_stack.apis.common.type_system import * # noqa: F403
|
|||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
|
||||
# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
|
||||
import os
|
||||
|
||||
from autoevals.llm import Factuality
|
||||
from autoevals.ragas import AnswerCorrectness
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
|
||||
|
@ -24,7 +26,9 @@ from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
|
|||
from .scoring_fn.fn_defs.factuality import factuality_fn_def
|
||||
|
||||
|
||||
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||
class BraintrustScoringImpl(
|
||||
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: BraintrustScoringConfig,
|
||||
|
@ -79,12 +83,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||
)
|
||||
|
||||
async def set_api_key(self) -> None:
|
||||
# api key is in the request headers
|
||||
if self.config.openai_api_key is None:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.openai_api_key:
|
||||
raise ValueError(
|
||||
'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
|
||||
)
|
||||
self.config.openai_api_key = provider_data.openai_api_key
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.set_api_key()
|
||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
|
@ -105,6 +122,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def score_row(
|
||||
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
||||
) -> ScoringResultRow:
|
||||
await self.set_api_key()
|
||||
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
@ -118,6 +136,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def score(
|
||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||
) -> ScoreResponse:
|
||||
await self.set_api_key()
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions:
|
||||
if scoring_fn_id not in self.supported_fn_defs_registry:
|
||||
|
|
|
@ -6,4 +6,8 @@
|
|||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
|
||||
|
||||
class BraintrustScoringConfig(BaseModel): ...
|
||||
class BraintrustScoringConfig(BaseModel):
|
||||
openai_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The OpenAI API Key",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue