Revert "backward compatibility"

This reverts commit 6a95edc806.
This commit is contained in:
Xi Yan 2024-09-21 12:42:08 -07:00
parent 3ea55d9b0f
commit 74765cc78f
2 changed files with 59 additions and 82 deletions

View file

@ -287,60 +287,60 @@ def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_")) return "".join(word.capitalize() for word in snake_str.split("_"))
# async def resolve_impls_with_routing( async def resolve_impls_with_routing(
# stack_run_config: StackRunConfig, stack_run_config: StackRunConfig,
# ) -> Dict[Api, Any]: ) -> Dict[Api, Any]:
# all_providers = api_providers() all_providers = api_providers()
# specs = {} specs = {}
# for api_str in stack_run_config.apis_to_serve: for api_str in stack_run_config.apis_to_serve:
# api = Api(api_str) api = Api(api_str)
# providers = all_providers[api] providers = all_providers[api]
# # check for regular providers without routing # check for regular providers without routing
# if api_str in stack_run_config.provider_map: if api_str in stack_run_config.provider_map:
# provider_map_entry = stack_run_config.provider_map[api_str] provider_map_entry = stack_run_config.provider_map[api_str]
# if provider_map_entry.provider_id not in providers: if provider_map_entry.provider_id not in providers:
# raise ValueError( raise ValueError(
# f"Unknown provider `{provider_id}` is not available for API `{api}`" f"Unknown provider `{provider_id}` is not available for API `{api}`"
# ) )
# specs[api] = providers[provider_map_entry.provider_id] specs[api] = providers[provider_map_entry.provider_id]
# # check for routing table, we need to pass routing table to the router implementation # check for routing table, we need to pass routing table to the router implementation
# if api_str in stack_run_config.provider_routing_table: if api_str in stack_run_config.provider_routing_table:
# router_entry = stack_run_config.provider_routing_table[api_str] router_entry = stack_run_config.provider_routing_table[api_str]
# inner_specs = [] inner_specs = []
# for rt_entry in router_entry: for rt_entry in router_entry:
# if rt_entry.provider_id not in providers: if rt_entry.provider_id not in providers:
# raise ValueError( raise ValueError(
# f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
# ) )
# inner_specs.append(providers[rt_entry.provider_id]) inner_specs.append(providers[rt_entry.provider_id])
# specs[api] = RouterProviderSpec( specs[api] = RouterProviderSpec(
# api=api, api=api,
# module=f"llama_stack.distribution.routers.{api.value.lower()}", module=f"llama_stack.distribution.routers.{api.value.lower()}",
# api_dependencies=[], api_dependencies=[],
# inner_specs=inner_specs, inner_specs=inner_specs,
# ) )
# sorted_specs = topological_sort(specs.values()) sorted_specs = topological_sort(specs.values())
# impls = {} impls = {}
# for spec in sorted_specs: for spec in sorted_specs:
# api = spec.api api = spec.api
# deps = {api: impls[api] for api in spec.api_dependencies} deps = {api: impls[api] for api in spec.api_dependencies}
# if api.value in stack_run_config.provider_map: if api.value in stack_run_config.provider_map:
# provider_config = stack_run_config.provider_map[api.value] provider_config = stack_run_config.provider_map[api.value]
# elif api.value in stack_run_config.provider_routing_table: elif api.value in stack_run_config.provider_routing_table:
# provider_config = stack_run_config.provider_routing_table[api.value] provider_config = stack_run_config.provider_routing_table[api.value]
# else: else:
# raise ValueError(f"Cannot find provider_config for Api {api.value}") raise ValueError(f"Cannot find provider_config for Api {api.value}")
# impl = await instantiate_provider(spec, deps, provider_config) impl = await instantiate_provider(spec, deps, provider_config)
# impls[api] = impl impls[api] = impl
# return impls, specs return impls, specs
async def resolve_impls( async def resolve_impls(
@ -352,10 +352,12 @@ async def resolve_impls(
- for each API, produces either a (local, passthrough or router) implementation - for each API, produces either a (local, passthrough or router) implementation
""" """
all_providers = api_providers() all_providers = api_providers()
specs = {} specs = {}
for api_str, item in provider_map.items(): for api_str, item in provider_map.items():
api = Api(api_str) api = Api(api_str)
providers = all_providers[api] providers = all_providers[api]
if isinstance(item, GenericProviderConfig): if isinstance(item, GenericProviderConfig):
if item.provider_id not in providers: if item.provider_id not in providers:
raise ValueError( raise ValueError(
@ -363,39 +365,31 @@ async def resolve_impls(
) )
specs[api] = providers[item.provider_id] specs[api] = providers[item.provider_id]
else: else:
assert isinstance(item, list) raise ValueError(
inner_specs = [] f"Please define routing table in provider_routing_table of run config"
for rt_entry in item:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.providers.routers.{api.value.lower()}",
api_dependencies=[],
inner_specs=inner_specs,
) )
sorted_specs = topological_sort(specs.values()) sorted_specs = topological_sort(specs.values())
impls = {} impls = {}
for spec in sorted_specs: for spec in sorted_specs:
api = spec.api api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies} deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, provider_map[api.value]) impl = await instantiate_provider(spec, deps, provider_map[api.value])
impls[api] = impl impls[api] = impl
return impls, specs return impls, specs
# This runs def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
def run_main_DEPRECATED( with open(yaml_config, "r") as fp:
config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False config = StackRunConfig(**yaml.safe_load(fp))
):
cprint(f"StackRunConfig: {config}", "blue")
app = FastAPI() app = FastAPI()
impls, specs = asyncio.run(resolve_impls(config.provider_map)) impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
@ -455,21 +449,5 @@ def run_main_DEPRECATED(
uvicorn.run(app, host=listen_host, port=port) uvicorn.run(app, host=listen_host, port=port)
def run_main(config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False):
raise ValueError("Not implemented")
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = StackRunConfig(**yaml.safe_load(fp))
cprint(f"StackRunConfig: {config}", "blue")
if not config.provider_routing_table:
run_main_DEPRECATED(config, port, disable_ipv6)
else:
run_main(config, port, disable_ipv6)
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(main) fire.Fire(main)

View file

@ -36,4 +36,3 @@ provider_map:
telemetry: telemetry:
provider_id: meta-reference provider_id: meta-reference
config: {} config: {}
provider_routing_table: {}