mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
refactor a method out
This commit is contained in:
parent
125fdb1b2a
commit
db3e6dda07
1 changed files with 27 additions and 18 deletions
|
@ -36,7 +36,7 @@ from fastapi.routing import APIRoute
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from .datatypes import Adapter, Api, PassthroughApiAdapter
|
from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter
|
||||||
from .distribution import api_endpoints
|
from .distribution import api_endpoints
|
||||||
from .dynamic import instantiate_adapter, instantiate_client
|
from .dynamic import instantiate_adapter, instantiate_client
|
||||||
|
|
||||||
|
@ -250,6 +250,29 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
|
||||||
return [by_id[x] for x in stack]
|
return [by_id[x] for x in stack]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||||
|
adapter_configs = config["adapters"]
|
||||||
|
adapters = topological_sort(dist.adapters.values())
|
||||||
|
|
||||||
|
impls = {}
|
||||||
|
for adapter in adapters:
|
||||||
|
api = adapter.api
|
||||||
|
if api.value not in adapter_configs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not find adapter config for {api}. Please add it to the config"
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter_config = adapter_configs[api.value]
|
||||||
|
if isinstance(adapter, PassthroughApiAdapter):
|
||||||
|
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
|
||||||
|
else:
|
||||||
|
deps = {api: impls[api] for api in adapter.adapter_dependencies}
|
||||||
|
impl = instantiate_adapter(adapter, adapter_config, deps)
|
||||||
|
impls[api] = impl
|
||||||
|
|
||||||
|
return impls
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
|
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
|
||||||
):
|
):
|
||||||
|
@ -263,21 +286,10 @@ def main(
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
all_endpoints = api_endpoints()
|
all_endpoints = api_endpoints()
|
||||||
|
impls = resolve_impls(dist, config)
|
||||||
|
|
||||||
adapter_configs = config["adapters"]
|
for adapter in dist.adapters.values():
|
||||||
adapters = topological_sort(dist.adapters.values())
|
|
||||||
|
|
||||||
# TODO: split this into two parts, first you resolve all impls
|
|
||||||
# and then you create the routes.
|
|
||||||
impls = {}
|
|
||||||
for adapter in adapters:
|
|
||||||
api = adapter.api
|
api = adapter.api
|
||||||
if api.value not in adapter_configs:
|
|
||||||
raise ValueError(
|
|
||||||
f"Could not find adapter config for {api}. Please add it to the config"
|
|
||||||
)
|
|
||||||
|
|
||||||
adapter_config = adapter_configs[api.value]
|
|
||||||
endpoints = all_endpoints[api]
|
endpoints = all_endpoints[api]
|
||||||
if isinstance(adapter, PassthroughApiAdapter):
|
if isinstance(adapter, PassthroughApiAdapter):
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
|
@ -285,11 +297,8 @@ def main(
|
||||||
getattr(app, endpoint.method)(endpoint.route)(
|
getattr(app, endpoint.method)(endpoint.route)(
|
||||||
create_dynamic_passthrough(url)
|
create_dynamic_passthrough(url)
|
||||||
)
|
)
|
||||||
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
|
|
||||||
else:
|
else:
|
||||||
deps = {api: impls[api] for api in adapter.adapter_dependencies}
|
impl = impls[api]
|
||||||
impl = instantiate_adapter(adapter, adapter_config, deps)
|
|
||||||
impls[api] = impl
|
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
if not hasattr(impl, endpoint.name):
|
if not hasattr(impl, endpoint.name):
|
||||||
# ideally this should be a typing violation already
|
# ideally this should be a typing violation already
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue