supported models wip

This commit is contained in:
Xi Yan 2024-09-21 18:37:22 -07:00
parent 20a4302877
commit c0199029e5
10 changed files with 215 additions and 34 deletions

View file

@ -51,6 +51,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import (
instantiate_builtin_provider,
instantiate_provider,
instantiate_router,
)
@ -306,15 +307,6 @@ async def resolve_impls_with_routing(
api = Api(api_str)
providers = all_providers[api]
# check for regular providers without routing
if api_str in stack_run_config.provider_map:
provider_map_entry = stack_run_config.provider_map[api_str]
if provider_map_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[provider_map_entry.provider_id]
# check for routing table, we need to pass routing table to the router implementation
if api_str in stack_run_config.provider_routing_table:
specs[api] = RouterProviderSpec(
@ -323,6 +315,19 @@ async def resolve_impls_with_routing(
api_dependencies=[],
routing_table=stack_run_config.provider_routing_table[api_str],
)
else:
if api_str in stack_run_config.provider_map:
provider_map_entry = stack_run_config.provider_map[api_str]
provider_id = provider_map_entry.provider_id
else:
# not defined in config, will be a builtin provider, assign builtin provider id
provider_id = "builtin"
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[provider_id]
sorted_specs = topological_sort(specs.values())
@ -338,7 +343,7 @@ async def resolve_impls_with_routing(
spec, api.value, stack_run_config.provider_routing_table
)
else:
raise ValueError(f"Cannot find provider_config for Api {api.value}")
impl = await instantiate_builtin_provider(spec, stack_run_config)
impls[api] = impl
return impls, specs