Make Safety test work, other cleanup

This commit is contained in:
Ashwin Bharambe 2024-10-09 21:09:50 -07:00
parent ba1f294cc6
commit fcd22b6baa
16 changed files with 229 additions and 123 deletions

View file

@ -7,6 +7,7 @@
import json
import os
from datetime import datetime
from typing import Any, Dict, List
import yaml
@ -16,9 +17,7 @@ from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
async def resolve_impls_for_test(
api: Api,
):
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
if "PROVIDER_CONFIG" not in os.environ:
raise ValueError(
"You must set PROVIDER_CONFIG to a YAML file containing provider config"
@ -27,15 +26,69 @@ async def resolve_impls_for_test(
with open(os.environ["PROVIDER_CONFIG"], "r") as f:
config_dict = yaml.safe_load(f)
providers = read_providers(api, config_dict)
chosen = choose_providers(providers, api, deps)
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=[api] + (deps or []),
providers=chosen,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls_with_routing(run_config)
if "provider_data" in config_dict:
provider_id = chosen[api.value][0].provider_id
provider_data = config_dict["provider_data"].get(provider_id, {})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
)
return impls
def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]:
if "providers" not in config_dict:
raise ValueError("Config file should contain a `providers` key")
providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
if len(providers_by_id) == 0:
raise ValueError("No providers found in config file")
providers = config_dict["providers"]
if isinstance(providers, dict):
return providers
elif isinstance(providers, list):
return {
api.value: providers,
}
else:
raise ValueError(
"Config file should contain a list of providers or dict(api to providers)"
)
if "PROVIDER_ID" in os.environ:
provider_id = os.environ["PROVIDER_ID"]
def choose_providers(
providers: Dict[str, Any], api: Api, deps: List[Api] = None
) -> Dict[str, Provider]:
chosen = {}
if api.value not in providers:
raise ValueError(f"No providers found for `{api}`?")
chosen[api.value] = [pick_provider(api, providers[api.value], "PROVIDER_ID")]
for dep in deps or []:
if dep.value not in providers:
raise ValueError(f"No providers specified for `{dep}` in config?")
chosen[dep.value] = [Provider(**x) for x in providers[dep.value]]
return chosen
def pick_provider(api: Api, providers: List[Any], key: str) -> Provider:
providers_by_id = {x["provider_id"]: x for x in providers}
if len(providers_by_id) == 0:
raise ValueError(f"No providers found for `{api}` in config file")
if key in os.environ:
provider_id = os.environ[key]
if provider_id not in providers_by_id:
raise ValueError(f"Provider ID {provider_id} not found in config file")
provider = providers_by_id[provider_id]
@ -44,20 +97,4 @@ async def resolve_impls_for_test(
provider_id = provider["provider_id"]
print(f"No provider ID specified, picking first `{provider_id}`")
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=[api],
providers={api.value: [Provider(**provider)]},
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls_with_routing(run_config)
if "provider_data" in config_dict:
provider_data = config_dict["provider_data"].get(provider_id, {})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
)
return impls
return Provider(**provider)