mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Rename Distribution -> DistributionSpec, simplify RemoteProviders
This commit is contained in:
parent
0a67f3d3e6
commit
7cc0445517
9 changed files with 181 additions and 147 deletions
|
@ -36,11 +36,11 @@ from fastapi.routing import APIRoute
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
|
||||
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
|
||||
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints
|
||||
from .dynamic import instantiate_client, instantiate_provider
|
||||
|
||||
from .registry import resolve_distribution
|
||||
from .registry import resolve_distribution_spec
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -250,7 +250,7 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
|||
return [by_id[x] for x in stack]
|
||||
|
||||
|
||||
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||
provider_configs = config["providers"]
|
||||
provider_specs = topological_sort(dist.provider_specs.values())
|
||||
|
||||
|
@ -265,7 +265,7 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
|||
provider_config = provider_configs[api.value]
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
impls[api] = instantiate_client(
|
||||
provider_spec, provider_spec.base_url.rstrip("/")
|
||||
provider_spec, provider_config.base_url.rstrip("/")
|
||||
)
|
||||
else:
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
|
@ -275,16 +275,15 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
|||
return impls
|
||||
|
||||
|
||||
def main(
|
||||
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
|
||||
):
|
||||
dist = resolve_distribution(dist_name)
|
||||
if dist is None:
|
||||
raise ValueError(f"Could not find distribution {dist_name}")
|
||||
|
||||
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||
with open(yaml_config, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
|
||||
spec = config["spec"]
|
||||
dist = resolve_distribution_spec(spec)
|
||||
if dist is None:
|
||||
raise ValueError(f"Could not find distribution specification `{spec}`")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
all_endpoints = api_endpoints()
|
||||
|
@ -293,14 +292,15 @@ def main(
|
|||
for provider_spec in dist.provider_specs.values():
|
||||
api = provider_spec.api
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
for endpoint in endpoints:
|
||||
url = provider_spec.base_url.rstrip("/") + endpoint.route
|
||||
url = impl.base_url + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
else:
|
||||
impl = impls[api]
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue