Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 07:39:38 +00:00)
routers for inference chat_completion with models dependency

commit d2ec822b12 (parent 47be4c7222)
5 changed files with 113 additions and 15 deletions
@@ -198,7 +198,7 @@ class ProviderRoutingEntry(GenericProviderConfig):
     routing_key: str


-ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
+ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry], str]


 @json_schema_type
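With ProviderMapEntry widened to include str, a run config can now point an API at a router by name instead of a concrete provider or a list of routing entries. A minimal sketch of such a provider_map, where everything other than the "models-router" string (the form this commit adds) is an assumed, illustrative shape:

# Hypothetical provider_map exercising all three ProviderMapEntry forms.
provider_map = {
    # new str form from this commit: hand the inference API to the models-router
    "inference": "models-router",
    # GenericProviderConfig-style entry (provider id and config shape assumed)
    "models": {"provider_id": "meta-reference", "config": {}},
    # list-of-ProviderRoutingEntry form, dispatched by routing_key
    "memory": [
        {"routing_key": "vector", "provider_id": "meta-reference", "config": {}},
    ],
}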
@@ -297,6 +297,13 @@ async def resolve_impls(
                     f"Unknown provider `{provider_id}` is not available for API `{api}`"
                 )
             specs[api] = providers[item.provider_id]
+        elif isinstance(item, str) and item == "models-router":
+            specs[api] = RouterProviderSpec(
+                api=api,
+                module=f"llama_stack.providers.routers.{api.value.lower()}",
+                api_dependencies=[Api.models],
+                inner_specs=[],
+            )
         else:
             assert isinstance(item, list)
             inner_specs = []
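Note that the router spec declares api_dependencies=[Api.models]: the resolver later topologically sorts the specs, so the models provider comes up before the router that needs it. A self-contained sketch of that ordering idea, assuming a plain mapping from API name to dependency names (this is not the stack's actual sort implementation):

from typing import Dict, List


def dependency_order(deps: Dict[str, List[str]]) -> List[str]:
    # Depth-first topological sort: each dependency is emitted before its dependents.
    order: List[str] = []
    seen = set()

    def visit(node: str) -> None:
        if node in seen:
            return
        seen.add(node)
        for dep in deps.get(node, []):
            visit(dep)
        order.append(node)

    for node in deps:
        visit(node)
    return order


# The inference router depends on the models API, so "models" is instantiated first.
print(dependency_order({"inference": ["models"], "models": []}))  # ['models', 'inference']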
@@ -314,6 +321,10 @@ async def resolve_impls(
                 inner_specs=inner_specs,
             )

+    for k, v in specs.items():
+        cprint(k, "blue")
+        cprint(v, "blue")
+
     sorted_specs = topological_sort(specs.values())

     impls = {}
@@ -333,9 +344,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):

     app = FastAPI()

-    print(config)
     impls, specs = asyncio.run(resolve_impls(config.provider_map))
-    print(impls)
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])

@@ -38,19 +38,24 @@ async def instantiate_provider(
     elif isinstance(provider_spec, RouterProviderSpec):
         method = "get_router_impl"

-        assert isinstance(provider_config, list)
-        inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
-        inner_impls = []
-        for routing_entry in provider_config:
-            impl = await instantiate_provider(
-                inner_specs[routing_entry.provider_id],
-                deps,
-                routing_entry,
-            )
-            inner_impls.append((routing_entry.routing_key, impl))
-
-        config = None
-        args = [inner_impls, deps]
+        if isinstance(provider_config, list):
+            inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
+            inner_impls = []
+            for routing_entry in provider_config:
+                impl = await instantiate_provider(
+                    inner_specs[routing_entry.provider_id],
+                    deps,
+                    routing_entry,
+                )
+                inner_impls.append((routing_entry.routing_key, impl))
+
+            config = None
+            args = [inner_impls, deps]
+        elif isinstance(provider_config, str) and provider_config == "models-router":
+            config = None
+            args = [[], deps]
+        else:
+            raise ValueError(f"provider_config {provider_config} is not valid")

     else:
         method = "get_provider_impl"
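The resolver pairs a module path with a factory name ("get_router_impl" for router specs) and calls it with the assembled args. A rough sketch of that dynamic-dispatch pattern, assuming only that the target module exposes an async factory with that name (simplified relative to the real instantiate_provider):

import importlib
from typing import Any, List


async def load_router(module_path: str, inner_impls: List[Any], deps: List[Any]) -> Any:
    # Import the router package, look up its factory by name, and await it.
    module = importlib.import_module(module_path)
    get_router_impl = getattr(module, "get_router_impl")
    return await get_router_impl(inner_impls, deps)


# e.g. await load_router("llama_stack.providers.routers.inference", [], [])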
llama_stack/providers/routers/inference/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, List, Tuple
+
+from llama_stack.distribution.datatypes import Api
+
+
+async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
+    from .inference import InferenceRouterImpl
+
+    impl = InferenceRouterImpl(inner_impls, deps)
+    await impl.initialize()
+    return impl
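A quick usage sketch of this factory: the inner impls are (routing_key, provider) tuples per the signature, so a caller might wire it up roughly as below. The model name and the placeholder provider object are illustrative only, and this assumes the package from this commit is importable:

import asyncio

from llama_stack.providers.routers.inference import get_router_impl


async def demo() -> None:
    fake_provider = object()  # stand-in; a real run would pass a concrete Inference provider
    router = await get_router_impl([("Llama3.1-8B-Instruct", fake_provider)], deps=[])
    print(type(router).__name__)  # InferenceRouterImpl


# asyncio.run(demo())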
llama_stack/providers/routers/inference/inference.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, AsyncGenerator, Dict, List, Tuple
+
+from llama_stack.distribution.datatypes import Api
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.registry.inference import available_providers
+
+
+class InferenceRouterImpl(Inference):
+    """Routes to a provider based on the routing key"""
+
+    def __init__(
+        self,
+        inner_impls: List[Tuple[str, Any]],
+        deps: List[Api],
+    ) -> None:
+        self.inner_impls = inner_impls
+        self.deps = deps
+        print("INIT INFERENCE ROUTER!")
+
+        # self.providers = {}
+        # for routing_key, provider_impl in inner_impls:
+        #     self.providers[routing_key] = provider_impl
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+        # for p in self.providers.values():
+        #     await p.shutdown()
+
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        # zero-shot tool definitions as input to the model
+        tools: Optional[List[ToolDefinition]] = list,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        print("router chat_completion")
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.progress,
+                delta="router chat completion",
+            )
+        )
+        # async for chunk in self.providers[model].chat_completion(
+        #     model=model,
+        #     messages=messages,
+        #     sampling_params=sampling_params,
+        #     tools=tools,
+        #     tool_choice=tool_choice,
+        #     tool_prompt_format=tool_prompt_format,
+        #     stream=stream,
+        #     logprobs=logprobs,
+        # ):
+        #     yield chunk
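The commented-out body sketches where this is headed: index the inner impls by routing key and forward chat_completion to whichever provider is registered for the requested model. A minimal, self-contained sketch of that dispatch, with stand-in provider objects rather than the stack's real Inference classes:

from typing import Any, AsyncGenerator, Dict, List, Tuple


class ModelRouter:
    """Toy version of the routing idea: one inner provider per model routing key."""

    def __init__(self, inner_impls: List[Tuple[str, Any]]) -> None:
        self.providers: Dict[str, Any] = dict(inner_impls)

    async def chat_completion(self, model: str, messages: List[Any]) -> AsyncGenerator:
        if model not in self.providers:
            raise ValueError(f"no provider registered for routing key `{model}`")
        # Forward the call and re-yield whatever the selected provider streams back.
        async for chunk in self.providers[model].chat_completion(model=model, messages=messages):
            yield chunk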