mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fixes tests & move braintrust api_keys to request headers (#535)
# What does this PR do? - braintrust scoring provider requires OPENAI_API_KEY env variable to be set - move this to be able to be set as request headers (e.g. like together / fireworks api keys) - fixes pytest with agents dependency ## Test Plan **E2E** ``` llama stack run ``` ```yaml scoring: - provider_id: braintrust-0 provider_type: inline::braintrust config: {} ``` **Client** ```python self.client = LlamaStackClient( base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:5000"), provider_data={ "openai_api_key": os.environ.get("OPENAI_API_KEY", ""), }, ) ``` - run `llama-stack-client eval run_scoring` **Unit Test** ``` pytest -v -s -m meta_reference_eval_together_inference eval/test_eval.py ``` ``` pytest -v -s -m braintrust_scoring_together_inference scoring/test_scoring.py --env OPENAI_API_KEY=$OPENAI_API_KEY ``` <img width="745" alt="image" src="https://github.com/user-attachments/assets/68f5cdda-f6c8-496d-8b4f-1b3dabeca9c2"> ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
d3956a1d22
commit
50cc165077
8 changed files with 72 additions and 8 deletions
|
@ -35,7 +35,7 @@ class NeedsRequestProviderData:
|
||||||
provider_data = validator(**val)
|
provider_data = validator(**val)
|
||||||
return provider_data
|
return provider_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("Error parsing provider data", e)
|
log.error(f"Error parsing provider data: {e}")
|
||||||
|
|
||||||
|
|
||||||
def set_request_provider_data(headers: Dict[str, str]):
|
def set_request_provider_data(headers: Dict[str, str]):
|
||||||
|
|
|
@ -6,10 +6,15 @@
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from .config import BraintrustScoringConfig
|
from .config import BraintrustScoringConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BraintrustProviderDataValidator(BaseModel):
|
||||||
|
openai_api_key: str
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: BraintrustScoringConfig,
|
config: BraintrustScoringConfig,
|
||||||
deps: Dict[Api, ProviderSpec],
|
deps: Dict[Api, ProviderSpec],
|
||||||
|
|
|
@ -12,9 +12,11 @@ from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
from llama_stack.apis.datasetio import * # noqa: F403
|
from llama_stack.apis.datasetio import * # noqa: F403
|
||||||
from llama_stack.apis.datasets import * # noqa: F403
|
from llama_stack.apis.datasets import * # noqa: F403
|
||||||
|
|
||||||
# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
|
import os
|
||||||
|
|
||||||
from autoevals.llm import Factuality
|
from autoevals.llm import Factuality
|
||||||
from autoevals.ragas import AnswerCorrectness
|
from autoevals.ragas import AnswerCorrectness
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||||
|
|
||||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
|
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
|
||||||
|
@ -24,7 +26,9 @@ from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
|
||||||
from .scoring_fn.fn_defs.factuality import factuality_fn_def
|
from .scoring_fn.fn_defs.factuality import factuality_fn_def
|
||||||
|
|
||||||
|
|
||||||
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
class BraintrustScoringImpl(
|
||||||
|
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
|
||||||
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: BraintrustScoringConfig,
|
config: BraintrustScoringConfig,
|
||||||
|
@ -79,12 +83,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def set_api_key(self) -> None:
|
||||||
|
# api key is in the request headers
|
||||||
|
if self.config.openai_api_key is None:
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None or not provider_data.openai_api_key:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
|
||||||
|
)
|
||||||
|
self.config.openai_api_key = provider_data.openai_api_key
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
|
||||||
|
|
||||||
async def score_batch(
|
async def score_batch(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
scoring_functions: List[str],
|
scoring_functions: List[str],
|
||||||
save_results_dataset: bool = False,
|
save_results_dataset: bool = False,
|
||||||
) -> ScoreBatchResponse:
|
) -> ScoreBatchResponse:
|
||||||
|
await self.set_api_key()
|
||||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
|
@ -105,6 +122,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
async def score_row(
|
async def score_row(
|
||||||
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
||||||
) -> ScoringResultRow:
|
) -> ScoringResultRow:
|
||||||
|
await self.set_api_key()
|
||||||
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
|
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
|
||||||
expected_answer = input_row["expected_answer"]
|
expected_answer = input_row["expected_answer"]
|
||||||
generated_answer = input_row["generated_answer"]
|
generated_answer = input_row["generated_answer"]
|
||||||
|
@ -118,6 +136,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
async def score(
|
async def score(
|
||||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||||
) -> ScoreResponse:
|
) -> ScoreResponse:
|
||||||
|
await self.set_api_key()
|
||||||
res = {}
|
res = {}
|
||||||
for scoring_fn_id in scoring_functions:
|
for scoring_fn_id in scoring_functions:
|
||||||
if scoring_fn_id not in self.supported_fn_defs_registry:
|
if scoring_fn_id not in self.supported_fn_defs_registry:
|
||||||
|
|
|
@ -6,4 +6,8 @@
|
||||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||||
|
|
||||||
|
|
||||||
class BraintrustScoringConfig(BaseModel): ...
|
class BraintrustScoringConfig(BaseModel):
|
||||||
|
openai_api_key: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The OpenAI API Key",
|
||||||
|
)
|
||||||
|
|
|
@ -44,5 +44,6 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
Api.datasetio,
|
Api.datasetio,
|
||||||
Api.datasets,
|
Api.datasets,
|
||||||
],
|
],
|
||||||
|
provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -6,10 +6,14 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from ..agents.fixtures import AGENTS_FIXTURES
|
||||||
|
|
||||||
from ..conftest import get_provider_fixture_overrides
|
from ..conftest import get_provider_fixture_overrides
|
||||||
|
|
||||||
from ..datasetio.fixtures import DATASETIO_FIXTURES
|
from ..datasetio.fixtures import DATASETIO_FIXTURES
|
||||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||||
|
from ..memory.fixtures import MEMORY_FIXTURES
|
||||||
|
from ..safety.fixtures import SAFETY_FIXTURES
|
||||||
from ..scoring.fixtures import SCORING_FIXTURES
|
from ..scoring.fixtures import SCORING_FIXTURES
|
||||||
from .fixtures import EVAL_FIXTURES
|
from .fixtures import EVAL_FIXTURES
|
||||||
|
|
||||||
|
@ -20,6 +24,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
"scoring": "basic",
|
"scoring": "basic",
|
||||||
"datasetio": "localfs",
|
"datasetio": "localfs",
|
||||||
"inference": "fireworks",
|
"inference": "fireworks",
|
||||||
|
"agents": "meta_reference",
|
||||||
|
"safety": "llama_guard",
|
||||||
|
"memory": "faiss",
|
||||||
},
|
},
|
||||||
id="meta_reference_eval_fireworks_inference",
|
id="meta_reference_eval_fireworks_inference",
|
||||||
marks=pytest.mark.meta_reference_eval_fireworks_inference,
|
marks=pytest.mark.meta_reference_eval_fireworks_inference,
|
||||||
|
@ -30,6 +37,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
"scoring": "basic",
|
"scoring": "basic",
|
||||||
"datasetio": "localfs",
|
"datasetio": "localfs",
|
||||||
"inference": "together",
|
"inference": "together",
|
||||||
|
"agents": "meta_reference",
|
||||||
|
"safety": "llama_guard",
|
||||||
|
"memory": "faiss",
|
||||||
},
|
},
|
||||||
id="meta_reference_eval_together_inference",
|
id="meta_reference_eval_together_inference",
|
||||||
marks=pytest.mark.meta_reference_eval_together_inference,
|
marks=pytest.mark.meta_reference_eval_together_inference,
|
||||||
|
@ -40,6 +50,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
"scoring": "basic",
|
"scoring": "basic",
|
||||||
"datasetio": "huggingface",
|
"datasetio": "huggingface",
|
||||||
"inference": "together",
|
"inference": "together",
|
||||||
|
"agents": "meta_reference",
|
||||||
|
"safety": "llama_guard",
|
||||||
|
"memory": "faiss",
|
||||||
},
|
},
|
||||||
id="meta_reference_eval_together_inference_huggingface_datasetio",
|
id="meta_reference_eval_together_inference_huggingface_datasetio",
|
||||||
marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
|
marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
|
||||||
|
@ -75,6 +88,9 @@ def pytest_generate_tests(metafunc):
|
||||||
"scoring": SCORING_FIXTURES,
|
"scoring": SCORING_FIXTURES,
|
||||||
"datasetio": DATASETIO_FIXTURES,
|
"datasetio": DATASETIO_FIXTURES,
|
||||||
"inference": INFERENCE_FIXTURES,
|
"inference": INFERENCE_FIXTURES,
|
||||||
|
"agents": AGENTS_FIXTURES,
|
||||||
|
"safety": SAFETY_FIXTURES,
|
||||||
|
"memory": MEMORY_FIXTURES,
|
||||||
}
|
}
|
||||||
combinations = (
|
combinations = (
|
||||||
get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
||||||
|
|
|
@ -40,14 +40,30 @@ async def eval_stack(request):
|
||||||
|
|
||||||
providers = {}
|
providers = {}
|
||||||
provider_data = {}
|
provider_data = {}
|
||||||
for key in ["datasetio", "eval", "scoring", "inference"]:
|
for key in [
|
||||||
|
"datasetio",
|
||||||
|
"eval",
|
||||||
|
"scoring",
|
||||||
|
"inference",
|
||||||
|
"agents",
|
||||||
|
"safety",
|
||||||
|
"memory",
|
||||||
|
]:
|
||||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||||
providers[key] = fixture.providers
|
providers[key] = fixture.providers
|
||||||
if fixture.provider_data:
|
if fixture.provider_data:
|
||||||
provider_data.update(fixture.provider_data)
|
provider_data.update(fixture.provider_data)
|
||||||
|
|
||||||
test_stack = await construct_stack_for_test(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
|
[
|
||||||
|
Api.eval,
|
||||||
|
Api.datasetio,
|
||||||
|
Api.inference,
|
||||||
|
Api.scoring,
|
||||||
|
Api.agents,
|
||||||
|
Api.safety,
|
||||||
|
Api.memory,
|
||||||
|
],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,9 +10,10 @@ import pytest_asyncio
|
||||||
from llama_stack.apis.models import ModelInput
|
from llama_stack.apis.models import ModelInput
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
from llama_stack.providers.inline.scoring.braintrust import BraintrustScoringConfig
|
||||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
from ..env import get_env_or_fail
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
@ -40,7 +41,9 @@ def scoring_braintrust() -> ProviderFixture:
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="braintrust",
|
provider_id="braintrust",
|
||||||
provider_type="inline::braintrust",
|
provider_type="inline::braintrust",
|
||||||
config={},
|
config=BraintrustScoringConfig(
|
||||||
|
openai_api_key=get_env_or_fail("OPENAI_API_KEY"),
|
||||||
|
).model_dump(),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue