diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 457ab0d3a..5a7d0d64d 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -198,7 +198,7 @@ class ProviderRoutingEntry(GenericProviderConfig):
     routing_key: str


-ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
+ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry], str]


 @json_schema_type
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 583a25e1a..5cf299bff 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -297,6 +297,13 @@ async def resolve_impls(
                     f"Unknown provider `{provider_id}` is not available for API `{api}`"
                 )
             specs[api] = providers[item.provider_id]
+        elif isinstance(item, str) and item == "models-router":
+            specs[api] = RouterProviderSpec(
+                api=api,
+                module=f"llama_stack.providers.routers.{api.value.lower()}",
+                api_dependencies=[Api.models],
+                inner_specs=[],
+            )
         else:
             assert isinstance(item, list)
             inner_specs = []
@@ -314,6 +321,10 @@ async def resolve_impls(
                 inner_specs=inner_specs,
             )

+    for k, v in specs.items():
+        cprint(k, "blue")
+        cprint(v, "blue")
+
     sorted_specs = topological_sort(specs.values())

     impls = {}
@@ -333,9 +344,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):

     app = FastAPI()

-    print(config)
     impls, specs = asyncio.run(resolve_impls(config.provider_map))
-    print(impls)

     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py
index 002a738ae..048a418d4 100644
--- a/llama_stack/distribution/utils/dynamic.py
+++ b/llama_stack/distribution/utils/dynamic.py
@@ -38,19 +38,24 @@ async def instantiate_provider(
     elif isinstance(provider_spec, RouterProviderSpec):
         method = "get_router_impl"

-        assert isinstance(provider_config, list)
-        inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
-        inner_impls = []
-        for routing_entry in provider_config:
-            impl = await instantiate_provider(
-                inner_specs[routing_entry.provider_id],
-                deps,
-                routing_entry,
-            )
-            inner_impls.append((routing_entry.routing_key, impl))
+        if isinstance(provider_config, list):
+            inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
+            inner_impls = []
+            for routing_entry in provider_config:
+                impl = await instantiate_provider(
+                    inner_specs[routing_entry.provider_id],
+                    deps,
+                    routing_entry,
+                )
+                inner_impls.append((routing_entry.routing_key, impl))

-        config = None
-        args = [inner_impls, deps]
+            config = None
+            args = [inner_impls, deps]
+        elif isinstance(provider_config, str) and provider_config == "models-router":
+            config = None
+            args = [[], deps]
+        else:
+            raise ValueError(f"provider_config {provider_config} is not valid")
     else:
         method = "get_provider_impl"
diff --git a/llama_stack/providers/routers/inference/__init__.py b/llama_stack/providers/routers/inference/__init__.py
new file mode 100644
index 000000000..c6619ffc9
--- /dev/null
+++ b/llama_stack/providers/routers/inference/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, List, Tuple
+
+from llama_stack.distribution.datatypes import Api
+
+
+async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
+    from .inference import InferenceRouterImpl
+
+    impl = InferenceRouterImpl(inner_impls, deps)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/routers/inference/inference.py b/llama_stack/providers/routers/inference/inference.py
new file mode 100644
index 000000000..f029892b0
--- /dev/null
+++ b/llama_stack/providers/routers/inference/inference.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, AsyncGenerator, Dict, List, Tuple
+
+from llama_stack.distribution.datatypes import Api
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.registry.inference import available_providers
+
+
+class InferenceRouterImpl(Inference):
+    """Routes to an provider based on the memory bank type"""
+
+    def __init__(
+        self,
+        inner_impls: List[Tuple[str, Any]],
+        deps: List[Api],
+    ) -> None:
+        self.inner_impls = inner_impls
+        self.deps = deps
+        print("INIT INFERENCE ROUTER!")
+
+        # self.providers = {}
+        # for routing_key, provider_impl in inner_impls:
+        #     self.providers[routing_key] = provider_impl
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+        # for p in self.providers.values():
+        #     await p.shutdown()
+
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        # zero-shot tool definitions as input to the model
+        tools: Optional[List[ToolDefinition]] = list,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        print("router chat_completion")
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.progress,
+                delta="router chat completion",
+            )
+        )
+        # async for chunk in self.providers[model].chat_completion(
+        #     model=model,
+        #     messages=messages,
+        #     sampling_params=sampling_params,
+        #     tools=tools,
+        #     tool_choice=tool_choice,
+        #     tool_prompt_format=tool_prompt_format,
+        #     stream=stream,
+        #     logprobs=logprobs,
+        # ):
+        #     yield chunk
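
For context, below is a minimal, self-contained sketch of the dispatch behavior this patch adds around instantiate_provider: a provider_map entry may now be a single provider config, a list of routing entries, or the literal string "models-router", which builds a router with no statically configured backends. ToyRouter, build_router, and the sample routing key / provider id are illustrative stand-ins, not llama_stack APIs; only the "models-router" literal and the list-of-routing-entries shape come from the diff above.

import asyncio
from typing import Any, List, Tuple, Union


class ToyRouter:
    """Stand-in for InferenceRouterImpl: records (routing_key, impl) pairs."""

    def __init__(self, inner_impls: List[Tuple[str, Any]]) -> None:
        self.inner_impls = dict(inner_impls)

    async def initialize(self) -> None:
        pass


async def build_router(entry: Union[List[dict], str]) -> ToyRouter:
    # Mirrors the branch added in dynamic.py: a list is treated as explicit
    # routing entries, the literal string "models-router" builds a router with
    # no statically configured backends, and anything else is rejected.
    if isinstance(entry, list):
        inner = [(e["routing_key"], f"impl-for-{e['provider_id']}") for e in entry]
        router = ToyRouter(inner)
    elif isinstance(entry, str) and entry == "models-router":
        router = ToyRouter([])
    else:
        raise ValueError(f"provider_config {entry} is not valid")
    await router.initialize()
    return router


if __name__ == "__main__":
    # List-style entry: routing keys map to concrete provider impls at config time.
    listed = asyncio.run(
        build_router(
            [{"routing_key": "Meta-Llama3.1-8B-Instruct", "provider_id": "meta-reference"}]
        )
    )
    print(listed.inner_impls)  # {'Meta-Llama3.1-8B-Instruct': 'impl-for-meta-reference'}

    # String-style entry: no backends are wired up yet.
    routed = asyncio.run(build_router("models-router"))
    print(routed.inner_impls)  # {}

The difference between the two shapes matches the patch: the "models-router" path passes inner_specs=[] and args = [[], deps] and declares Api.models as a dependency of the RouterProviderSpec, so the choice of inference backends is deferred to the models API rather than fixed in the run config as it is with the list form.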