mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Adapter -> Provider
This commit is contained in:
parent
db3e6dda07
commit
65a9e40174
15 changed files with 119 additions and 110 deletions
|
@ -36,9 +36,9 @@ from fastapi.routing import APIRoute
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
|
||||
from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter
|
||||
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints
|
||||
from .dynamic import instantiate_adapter, instantiate_client
|
||||
from .dynamic import instantiate_client, instantiate_provider
|
||||
|
||||
from .registry import resolve_distribution
|
||||
|
||||
|
@ -226,15 +226,15 @@ def create_dynamic_typed_route(func: Any):
|
|||
return endpoint
|
||||
|
||||
|
||||
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
|
||||
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||
|
||||
by_id = {x.api: x for x in adapters}
|
||||
by_id = {x.api: x for x in providers}
|
||||
|
||||
def dfs(a: Adapter, visited: Set[Api], stack: List[Api]):
|
||||
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||
visited.add(a.api)
|
||||
|
||||
if not isinstance(a, PassthroughApiAdapter):
|
||||
for api in a.adapter_dependencies:
|
||||
if not isinstance(a, RemoteProviderSpec):
|
||||
for api in a.api_dependencies:
|
||||
if api not in visited:
|
||||
dfs(by_id[api], visited, stack)
|
||||
|
||||
|
@ -243,7 +243,7 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
|
|||
visited = set()
|
||||
stack = []
|
||||
|
||||
for a in adapters:
|
||||
for a in providers:
|
||||
if a.api not in visited:
|
||||
dfs(a, visited, stack)
|
||||
|
||||
|
@ -251,23 +251,25 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
|
|||
|
||||
|
||||
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||
adapter_configs = config["adapters"]
|
||||
adapters = topological_sort(dist.adapters.values())
|
||||
provider_configs = config["providers"]
|
||||
provider_specs = topological_sort(dist.provider_specs.values())
|
||||
|
||||
impls = {}
|
||||
for adapter in adapters:
|
||||
api = adapter.api
|
||||
if api.value not in adapter_configs:
|
||||
for provider_spec in provider_specs:
|
||||
api = provider_spec.api
|
||||
if api.value not in provider_configs:
|
||||
raise ValueError(
|
||||
f"Could not find adapter config for {api}. Please add it to the config"
|
||||
f"Could not find provider_spec config for {api}. Please add it to the config"
|
||||
)
|
||||
|
||||
adapter_config = adapter_configs[api.value]
|
||||
if isinstance(adapter, PassthroughApiAdapter):
|
||||
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
|
||||
provider_config = provider_configs[api.value]
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
impls[api] = instantiate_client(
|
||||
provider_spec, provider_spec.base_url.rstrip("/")
|
||||
)
|
||||
else:
|
||||
deps = {api: impls[api] for api in adapter.adapter_dependencies}
|
||||
impl = instantiate_adapter(adapter, adapter_config, deps)
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
impl = instantiate_provider(provider_spec, provider_config, deps)
|
||||
impls[api] = impl
|
||||
|
||||
return impls
|
||||
|
@ -288,12 +290,12 @@ def main(
|
|||
all_endpoints = api_endpoints()
|
||||
impls = resolve_impls(dist, config)
|
||||
|
||||
for adapter in dist.adapters.values():
|
||||
api = adapter.api
|
||||
for provider_spec in dist.provider_specs.values():
|
||||
api = provider_spec.api
|
||||
endpoints = all_endpoints[api]
|
||||
if isinstance(adapter, PassthroughApiAdapter):
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
for endpoint in endpoints:
|
||||
url = adapter.base_url.rstrip("/") + endpoint.route
|
||||
url = provider_spec.base_url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue