Rename Distribution -> DistributionSpec, simplify RemoteProviders

This commit is contained in:
Ashwin Bharambe 2024-08-06 10:45:06 -07:00
parent 0a67f3d3e6
commit 7cc0445517
9 changed files with 181 additions and 147 deletions

View file

@ -36,11 +36,11 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution
from .registry import resolve_distribution_spec
load_dotenv()
@ -250,7 +250,7 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(dist.provider_specs.values())
@ -265,7 +265,7 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec, provider_spec.base_url.rstrip("/")
provider_spec, provider_config.base_url.rstrip("/")
)
else:
deps = {api: impls[api] for api in provider_spec.api_dependencies}
@ -275,16 +275,15 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
return impls
def main(
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
dist = resolve_distribution(dist_name)
if dist is None:
raise ValueError(f"Could not find distribution {dist_name}")
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
spec = config["spec"]
dist = resolve_distribution_spec(spec)
if dist is None:
raise ValueError(f"Could not find distribution specification `{spec}`")
app = FastAPI()
all_endpoints = api_endpoints()
@ -293,14 +292,15 @@ def main(
for provider_spec in dist.provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
impl = impls[api]
if isinstance(provider_spec, RemoteProviderSpec):
for endpoint in endpoints:
url = provider_spec.base_url.rstrip("/") + endpoint.route
url = impl.base_url + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
else:
impl = impls[api]
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already