mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 06:53:47 +00:00
braintrust provider
This commit is contained in:
parent
68346fac39
commit
38186f7903
6 changed files with 116 additions and 49 deletions
|
@ -82,7 +82,7 @@ async def register_dataset(
|
|||
|
||||
dataset = DatasetDefWithProvider(
|
||||
identifier=dataset_id,
|
||||
provider_id=os.environ["PROVIDER_ID"] or os.environ["DATASETIO_PROVIDER_ID"],
|
||||
provider_id=os.environ["DATASETIO_PROVIDER_ID"] or os.environ["PROVIDER_ID"],
|
||||
url=URL(
|
||||
uri=test_url,
|
||||
),
|
||||
|
|
|
@ -141,16 +141,14 @@ async def test_scoring_score(scoring_settings, provider_scoring_functions):
|
|||
function_ids = [f.identifier for f in scoring_functions]
|
||||
provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
|
||||
provider_type = provider.__provider_spec__.provider_type
|
||||
if provider_type not in ("meta-reference"):
|
||||
pytest.skip(
|
||||
"Other scoring providers don't support registering scoring functions."
|
||||
)
|
||||
|
||||
response = await scoring_impl.score_batch(
|
||||
dataset_id=response[0].identifier,
|
||||
scoring_functions=list(provider_scoring_functions[provider_type]),
|
||||
)
|
||||
|
||||
print("RESPONSE", response)
|
||||
|
||||
assert len(response.results) == len(provider_scoring_functions[provider_type])
|
||||
for x in provider_scoring_functions[provider_type]:
|
||||
assert x in response.results
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue