mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
backward compatibility
This commit is contained in:
parent
756e98cbd8
commit
6a95edc806
2 changed files with 82 additions and 59 deletions
|
@ -287,60 +287,60 @@ def snake_to_camel(snake_str):
|
||||||
return "".join(word.capitalize() for word in snake_str.split("_"))
|
return "".join(word.capitalize() for word in snake_str.split("_"))
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls_with_routing(
|
# async def resolve_impls_with_routing(
|
||||||
stack_run_config: StackRunConfig,
|
# stack_run_config: StackRunConfig,
|
||||||
) -> Dict[Api, Any]:
|
# ) -> Dict[Api, Any]:
|
||||||
|
|
||||||
all_providers = api_providers()
|
# all_providers = api_providers()
|
||||||
specs = {}
|
# specs = {}
|
||||||
|
|
||||||
for api_str in stack_run_config.apis_to_serve:
|
# for api_str in stack_run_config.apis_to_serve:
|
||||||
api = Api(api_str)
|
# api = Api(api_str)
|
||||||
providers = all_providers[api]
|
# providers = all_providers[api]
|
||||||
|
|
||||||
# check for regular providers without routing
|
# # check for regular providers without routing
|
||||||
if api_str in stack_run_config.provider_map:
|
# if api_str in stack_run_config.provider_map:
|
||||||
provider_map_entry = stack_run_config.provider_map[api_str]
|
# provider_map_entry = stack_run_config.provider_map[api_str]
|
||||||
if provider_map_entry.provider_id not in providers:
|
# if provider_map_entry.provider_id not in providers:
|
||||||
raise ValueError(
|
# raise ValueError(
|
||||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
# f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||||
)
|
# )
|
||||||
specs[api] = providers[provider_map_entry.provider_id]
|
# specs[api] = providers[provider_map_entry.provider_id]
|
||||||
|
|
||||||
# check for routing table, we need to pass routing table to the router implementation
|
# # check for routing table, we need to pass routing table to the router implementation
|
||||||
if api_str in stack_run_config.provider_routing_table:
|
# if api_str in stack_run_config.provider_routing_table:
|
||||||
router_entry = stack_run_config.provider_routing_table[api_str]
|
# router_entry = stack_run_config.provider_routing_table[api_str]
|
||||||
inner_specs = []
|
# inner_specs = []
|
||||||
for rt_entry in router_entry:
|
# for rt_entry in router_entry:
|
||||||
if rt_entry.provider_id not in providers:
|
# if rt_entry.provider_id not in providers:
|
||||||
raise ValueError(
|
# raise ValueError(
|
||||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
# f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||||
)
|
# )
|
||||||
inner_specs.append(providers[rt_entry.provider_id])
|
# inner_specs.append(providers[rt_entry.provider_id])
|
||||||
|
|
||||||
specs[api] = RouterProviderSpec(
|
# specs[api] = RouterProviderSpec(
|
||||||
api=api,
|
# api=api,
|
||||||
module=f"llama_stack.distribution.routers.{api.value.lower()}",
|
# module=f"llama_stack.distribution.routers.{api.value.lower()}",
|
||||||
api_dependencies=[],
|
# api_dependencies=[],
|
||||||
inner_specs=inner_specs,
|
# inner_specs=inner_specs,
|
||||||
)
|
# )
|
||||||
|
|
||||||
sorted_specs = topological_sort(specs.values())
|
# sorted_specs = topological_sort(specs.values())
|
||||||
|
|
||||||
impls = {}
|
# impls = {}
|
||||||
for spec in sorted_specs:
|
# for spec in sorted_specs:
|
||||||
api = spec.api
|
# api = spec.api
|
||||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
# deps = {api: impls[api] for api in spec.api_dependencies}
|
||||||
if api.value in stack_run_config.provider_map:
|
# if api.value in stack_run_config.provider_map:
|
||||||
provider_config = stack_run_config.provider_map[api.value]
|
# provider_config = stack_run_config.provider_map[api.value]
|
||||||
elif api.value in stack_run_config.provider_routing_table:
|
# elif api.value in stack_run_config.provider_routing_table:
|
||||||
provider_config = stack_run_config.provider_routing_table[api.value]
|
# provider_config = stack_run_config.provider_routing_table[api.value]
|
||||||
else:
|
# else:
|
||||||
raise ValueError(f"Cannot find provider_config for Api {api.value}")
|
# raise ValueError(f"Cannot find provider_config for Api {api.value}")
|
||||||
impl = await instantiate_provider(spec, deps, provider_config)
|
# impl = await instantiate_provider(spec, deps, provider_config)
|
||||||
impls[api] = impl
|
# impls[api] = impl
|
||||||
|
|
||||||
return impls, specs
|
# return impls, specs
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls(
|
async def resolve_impls(
|
||||||
|
@ -352,12 +352,10 @@ async def resolve_impls(
|
||||||
- for each API, produces either a (local, passthrough or router) implementation
|
- for each API, produces either a (local, passthrough or router) implementation
|
||||||
"""
|
"""
|
||||||
all_providers = api_providers()
|
all_providers = api_providers()
|
||||||
|
|
||||||
specs = {}
|
specs = {}
|
||||||
for api_str, item in provider_map.items():
|
for api_str, item in provider_map.items():
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
providers = all_providers[api]
|
providers = all_providers[api]
|
||||||
|
|
||||||
if isinstance(item, GenericProviderConfig):
|
if isinstance(item, GenericProviderConfig):
|
||||||
if item.provider_id not in providers:
|
if item.provider_id not in providers:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -365,31 +363,39 @@ async def resolve_impls(
|
||||||
)
|
)
|
||||||
specs[api] = providers[item.provider_id]
|
specs[api] = providers[item.provider_id]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
assert isinstance(item, list)
|
||||||
f"Please define routing table in provider_routing_table of run config"
|
inner_specs = []
|
||||||
|
for rt_entry in item:
|
||||||
|
if rt_entry.provider_id not in providers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||||
|
)
|
||||||
|
inner_specs.append(providers[rt_entry.provider_id])
|
||||||
|
|
||||||
|
specs[api] = RouterProviderSpec(
|
||||||
|
api=api,
|
||||||
|
module=f"llama_stack.providers.routers.{api.value.lower()}",
|
||||||
|
api_dependencies=[],
|
||||||
|
inner_specs=inner_specs,
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_specs = topological_sort(specs.values())
|
sorted_specs = topological_sort(specs.values())
|
||||||
|
|
||||||
impls = {}
|
impls = {}
|
||||||
for spec in sorted_specs:
|
for spec in sorted_specs:
|
||||||
api = spec.api
|
api = spec.api
|
||||||
|
|
||||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||||
impl = await instantiate_provider(spec, deps, provider_map[api.value])
|
impl = await instantiate_provider(spec, deps, provider_map[api.value])
|
||||||
impls[api] = impl
|
impls[api] = impl
|
||||||
|
|
||||||
return impls, specs
|
return impls, specs
|
||||||
|
|
||||||
|
|
||||||
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
# This runs
|
||||||
with open(yaml_config, "r") as fp:
|
def run_main_DEPRECATED(
|
||||||
config = StackRunConfig(**yaml.safe_load(fp))
|
config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False
|
||||||
|
):
|
||||||
cprint(f"StackRunConfig: {config}", "blue")
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||||
|
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
@ -449,5 +455,21 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
uvicorn.run(app, host=listen_host, port=port)
|
uvicorn.run(app, host=listen_host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
def run_main(config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False):
|
||||||
|
raise ValueError("Not implemented")
|
||||||
|
|
||||||
|
|
||||||
|
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
|
with open(yaml_config, "r") as fp:
|
||||||
|
config = StackRunConfig(**yaml.safe_load(fp))
|
||||||
|
|
||||||
|
cprint(f"StackRunConfig: {config}", "blue")
|
||||||
|
|
||||||
|
if not config.provider_routing_table:
|
||||||
|
run_main_DEPRECATED(config, port, disable_ipv6)
|
||||||
|
else:
|
||||||
|
run_main(config, port, disable_ipv6)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
fire.Fire(main)
|
fire.Fire(main)
|
||||||
|
|
|
@ -36,3 +36,4 @@ provider_map:
|
||||||
telemetry:
|
telemetry:
|
||||||
provider_id: meta-reference
|
provider_id: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
|
provider_routing_table: {}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue