mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-07 20:50:52 +00:00
routers for inference chat_completion with models dependency
This commit is contained in:
parent
47be4c7222
commit
d2ec822b12
5 changed files with 113 additions and 15 deletions
|
@ -198,7 +198,7 @@ class ProviderRoutingEntry(GenericProviderConfig):
|
|||
routing_key: str
|
||||
|
||||
|
||||
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
|
||||
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry], str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -297,6 +297,13 @@ async def resolve_impls(
|
|||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[item.provider_id]
|
||||
elif isinstance(item, str) and item == "models-router":
|
||||
specs[api] = RouterProviderSpec(
|
||||
api=api,
|
||||
module=f"llama_stack.providers.routers.{api.value.lower()}",
|
||||
api_dependencies=[Api.models],
|
||||
inner_specs=[],
|
||||
)
|
||||
else:
|
||||
assert isinstance(item, list)
|
||||
inner_specs = []
|
||||
|
@ -314,6 +321,10 @@ async def resolve_impls(
|
|||
inner_specs=inner_specs,
|
||||
)
|
||||
|
||||
for k, v in specs.items():
|
||||
cprint(k, "blue")
|
||||
cprint(v, "blue")
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
|
||||
impls = {}
|
||||
|
@ -333,9 +344,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
print(config)
|
||||
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||
print(impls)
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
|
|
|
@ -38,19 +38,24 @@ async def instantiate_provider(
|
|||
elif isinstance(provider_spec, RouterProviderSpec):
|
||||
method = "get_router_impl"
|
||||
|
||||
assert isinstance(provider_config, list)
|
||||
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in provider_config:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_id],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
if isinstance(provider_config, list):
|
||||
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in provider_config:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_id],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [inner_impls, deps]
|
||||
config = None
|
||||
args = [inner_impls, deps]
|
||||
elif isinstance(provider_config, str) and provider_config == "models-router":
|
||||
config = None
|
||||
args = [[], deps]
|
||||
else:
|
||||
raise ValueError(f"provider_config {provider_config} is not valid")
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue