mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 12:19:47 +00:00
api keys refactor
This commit is contained in:
parent
d3956a1d22
commit
e01d6d793c
4 changed files with 32 additions and 3 deletions
|
|
@ -6,10 +6,15 @@
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from .config import BraintrustScoringConfig
|
from .config import BraintrustScoringConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BraintrustProviderDataValidator(BaseModel):
|
||||||
|
openai_api_key: str
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: BraintrustScoringConfig,
|
config: BraintrustScoringConfig,
|
||||||
deps: Dict[Api, ProviderSpec],
|
deps: Dict[Api, ProviderSpec],
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,11 @@ from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
from llama_stack.apis.datasetio import * # noqa: F403
|
from llama_stack.apis.datasetio import * # noqa: F403
|
||||||
from llama_stack.apis.datasets import * # noqa: F403
|
from llama_stack.apis.datasets import * # noqa: F403
|
||||||
|
|
||||||
# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
|
import os
|
||||||
|
|
||||||
from autoevals.llm import Factuality
|
from autoevals.llm import Factuality
|
||||||
from autoevals.ragas import AnswerCorrectness
|
from autoevals.ragas import AnswerCorrectness
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||||
|
|
||||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
|
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
|
||||||
|
|
@ -24,7 +26,9 @@ from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
|
||||||
from .scoring_fn.fn_defs.factuality import factuality_fn_def
|
from .scoring_fn.fn_defs.factuality import factuality_fn_def
|
||||||
|
|
||||||
|
|
||||||
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
class BraintrustScoringImpl(
|
||||||
|
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
|
||||||
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: BraintrustScoringConfig,
|
config: BraintrustScoringConfig,
|
||||||
|
|
@ -79,12 +83,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def set_api_key(self) -> None:
|
||||||
|
# api key is in the request headers
|
||||||
|
if self.config.openai_api_key is None:
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None or not provider_data.openai_api_key:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
|
||||||
|
)
|
||||||
|
self.config.openai_api_key = provider_data.openai_api_key
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
|
||||||
|
|
||||||
async def score_batch(
|
async def score_batch(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
scoring_functions: List[str],
|
scoring_functions: List[str],
|
||||||
save_results_dataset: bool = False,
|
save_results_dataset: bool = False,
|
||||||
) -> ScoreBatchResponse:
|
) -> ScoreBatchResponse:
|
||||||
|
await self.set_api_key()
|
||||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
|
|
@ -105,6 +122,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
async def score_row(
|
async def score_row(
|
||||||
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
||||||
) -> ScoringResultRow:
|
) -> ScoringResultRow:
|
||||||
|
await self.set_api_key()
|
||||||
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
|
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
|
||||||
expected_answer = input_row["expected_answer"]
|
expected_answer = input_row["expected_answer"]
|
||||||
generated_answer = input_row["generated_answer"]
|
generated_answer = input_row["generated_answer"]
|
||||||
|
|
@ -118,6 +136,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
async def score(
|
async def score(
|
||||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||||
) -> ScoreResponse:
|
) -> ScoreResponse:
|
||||||
|
await self.set_api_key()
|
||||||
res = {}
|
res = {}
|
||||||
for scoring_fn_id in scoring_functions:
|
for scoring_fn_id in scoring_functions:
|
||||||
if scoring_fn_id not in self.supported_fn_defs_registry:
|
if scoring_fn_id not in self.supported_fn_defs_registry:
|
||||||
|
|
|
||||||
|
|
@ -6,4 +6,8 @@
|
||||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||||
|
|
||||||
|
|
||||||
class BraintrustScoringConfig(BaseModel): ...
|
class BraintrustScoringConfig(BaseModel):
|
||||||
|
openai_api_key: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The OpenAI API Key",
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -44,5 +44,6 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
Api.datasetio,
|
Api.datasetio,
|
||||||
Api.datasets,
|
Api.datasets,
|
||||||
],
|
],
|
||||||
|
provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue