router table registration works

This commit is contained in:
Xi Yan 2024-09-21 14:26:48 -07:00
parent 85d927adde
commit 951cc9d7b7
4 changed files with 91 additions and 27 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from typing import Any, AsyncGenerator, Dict, List, Tuple
from llama_stack.distribution.datatypes import Api
@ -57,17 +57,20 @@ class MemoryRouter(Memory):
class InferenceRouter(Inference):
"""Routes to an provider based on the model"""
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.api = Api.inference.value
self.routing_table = routing_table
async def initialize(self) -> None:
pass
await self.routing_table.initialize(self.api)
async def shutdown(self) -> None:
pass
await self.routing_table.shutdown(self.api)
async def chat_completion(
self,
@ -80,5 +83,17 @@ class InferenceRouter(Inference):
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
print("Inference Router: chat_completion")
) -> AsyncGenerator:
provider_impl = self.routing_table.get_provider_impl(self.api, model)
print("InferenceRouter: chat_completion", provider_impl)
async for chunk in provider_impl.chat_completion(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
):
yield chunk