routers for inference chat_completion with models dependency

commit d2ec822b12
parent 47be4c7222
Author: Xi Yan
Date:   2024-09-19 20:59:32 -07:00

5 changed files with 113 additions and 15 deletions


@@ -198,7 +198,7 @@ class ProviderRoutingEntry(GenericProviderConfig):
     routing_key: str
 
-ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
+ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry], str]
 
 @json_schema_type

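With the widened union, a provider_map entry can now take three shapes: a single GenericProviderConfig, a list of ProviderRoutingEntry, or a bare string. A minimal sketch of the three shapes as parsed Python values (the API names, provider ids, and routing keys below are illustrative, not taken from this diff):

    provider_map = {
        # 1) GenericProviderConfig: one provider serves the whole API
        "safety": {"provider_id": "meta-reference", "config": {}},
        # 2) List[ProviderRoutingEntry]: pick a provider per routing_key
        "memory": [
            {"provider_id": "meta-reference", "routing_key": "vector", "config": {}},
        ],
        # 3) NEW string form: defer routing to the models API
        "inference": "models-router",
    }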

@@ -297,6 +297,13 @@ async def resolve_impls(
                 f"Unknown provider `{provider_id}` is not available for API `{api}`"
             )
             specs[api] = providers[item.provider_id]
+        elif isinstance(item, str) and item == "models-router":
+            specs[api] = RouterProviderSpec(
+                api=api,
+                module=f"llama_stack.providers.routers.{api.value.lower()}",
+                api_dependencies=[Api.models],
+                inner_specs=[],
+            )
         else:
             assert isinstance(item, list)
             inner_specs = []
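For Api.inference, the new elif expands the "models-router" string into a router spec pointing at the new routers package, with the models API as its only dependency. Roughly equivalent, written out by hand (a sketch; it assumes RouterProviderSpec and Api are importable from llama_stack.distribution.datatypes, as the surrounding code implies):

    from llama_stack.distribution.datatypes import Api, RouterProviderSpec

    spec = RouterProviderSpec(
        api=Api.inference,
        module="llama_stack.providers.routers.inference",  # the f-string, resolved for Api.inference
        api_dependencies=[Api.models],  # the router resolves models at request time
        inner_specs=[],  # no statically configured inner providers
    )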
@@ -314,6 +321,10 @@ async def resolve_impls(
             inner_specs=inner_specs,
         )
 
+    for k, v in specs.items():
+        cprint(k, "blue")
+        cprint(v, "blue")
+
     sorted_specs = topological_sort(specs.values())
     impls = {}
@@ -333,9 +344,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     app = FastAPI()
 
-    print(config)
     impls, specs = asyncio.run(resolve_impls(config.provider_map))
-    print(impls)
 
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])


@@ -38,19 +38,24 @@ async def instantiate_provider(
     elif isinstance(provider_spec, RouterProviderSpec):
         method = "get_router_impl"
 
-        assert isinstance(provider_config, list)
-        inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
-        inner_impls = []
-        for routing_entry in provider_config:
-            impl = await instantiate_provider(
-                inner_specs[routing_entry.provider_id],
-                deps,
-                routing_entry,
-            )
-            inner_impls.append((routing_entry.routing_key, impl))
+        if isinstance(provider_config, list):
+            inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
+            inner_impls = []
+            for routing_entry in provider_config:
+                impl = await instantiate_provider(
+                    inner_specs[routing_entry.provider_id],
+                    deps,
+                    routing_entry,
+                )
+                inner_impls.append((routing_entry.routing_key, impl))
 
-        config = None
-        args = [inner_impls, deps]
+            config = None
+            args = [inner_impls, deps]
+        elif isinstance(provider_config, str) and provider_config == "models-router":
+            config = None
+            args = [[], deps]
+        else:
+            raise ValueError(f"provider_config {provider_config} is not valid")
     else:
         method = "get_provider_impl"

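The branch above means get_router_impl is now reached with one of two argument shapes. A small sketch of both, with placeholder values (object() stands in for an instantiated provider; deps is whatever instantiate_provider was handed):

    inner_impl = object()  # placeholder for a resolved inner provider impl
    deps = []  # placeholder for the resolved API dependencies

    # a) provider_config was a List[ProviderRoutingEntry]
    args = [[("Llama3.1-8B-Instruct", inner_impl)], deps]

    # b) provider_config was the string "models-router": empty routing table
    args = [[], deps]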

@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, List, Tuple
+
+from llama_stack.distribution.datatypes import Api
+
+
+async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
+    from .inference import InferenceRouterImpl
+
+    impl = InferenceRouterImpl(inner_impls, deps)
+    await impl.initialize()
+    return impl

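A minimal sketch of invoking this factory the way the "models-router" path does, with an empty routing table (the empty deps list is a stand-in for whatever dynamic.py resolves):

    import asyncio

    from llama_stack.providers.routers.inference import get_router_impl

    async def demo():
        # empty inner_impls mirrors the "models-router" branch in dynamic.py
        return await get_router_impl(inner_impls=[], deps=[])

    router = asyncio.run(demo())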

@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, AsyncGenerator, Dict, List, Tuple
+
+from llama_stack.distribution.datatypes import Api
+
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.registry.inference import available_providers
+
+
+class InferenceRouterImpl(Inference):
+    """Routes to a provider based on the model"""
+
+    def __init__(
+        self,
+        inner_impls: List[Tuple[str, Any]],
+        deps: List[Api],
+    ) -> None:
+        self.inner_impls = inner_impls
+        self.deps = deps
+        print("INIT INFERENCE ROUTER!")
+        # self.providers = {}
+        # for routing_key, provider_impl in inner_impls:
+        #     self.providers[routing_key] = provider_impl
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+        # for p in self.providers.values():
+        #     await p.shutdown()
+
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        # zero-shot tool definitions as input to the model
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        print("router chat_completion")
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.progress,
+                delta="router chat completion",
+            )
+        )
+        # async for chunk in self.providers[model].chat_completion(
+        #     model=model,
+        #     messages=messages,
+        #     sampling_params=sampling_params,
+        #     tools=tools,
+        #     tool_choice=tool_choice,
+        #     tool_prompt_format=tool_prompt_format,
+        #     stream=stream,
+        #     logprobs=logprobs,
+        # ):
+        #     yield chunk
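As committed, chat_completion ignores its arguments and yields a single placeholder progress chunk; the real per-model dispatch is still commented out. A sketch of driving the stream end to end (the UserMessage import path is an assumption based on the star import above):

    import asyncio

    from llama_stack.apis.inference import UserMessage
    from llama_stack.providers.routers.inference.inference import InferenceRouterImpl

    async def demo() -> None:
        router = InferenceRouterImpl(inner_impls=[], deps=[])
        await router.initialize()
        async for chunk in router.chat_completion(
            model="Llama3.1-8B-Instruct",  # ignored by the placeholder body
            messages=[UserMessage(content="hello")],
        ):
            print(chunk.event.delta)  # -> "router chat completion"
        await router.shutdown()

    asyncio.run(demo())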