diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 7d62dc76c..deee16ae8 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 from llama_stack.distribution.datatypes import Api
 
@@ -12,6 +12,10 @@ from .routing_table import RoutingTable
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 
+from types import MethodType
+
+from termcolor import cprint
+
 
 class MemoryRouter(Memory):
     """Routes to an provider based on the memory bank type"""
@@ -84,9 +88,10 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        provider_impl = self.routing_table.get_provider_impl(self.api, model)
-        print("InferenceRouter: chat_completion", provider_impl)
-        async for chunk in provider_impl.chat_completion(
+        # TODO: we need to fix streaming response to align provider implementations with Protocol
+        async for chunk in self.routing_table.get_provider_impl(
+            self.api, model
+        ).chat_completion(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
@@ -97,3 +102,29 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         ):
             yield chunk
+
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+        return await self.routing_table.get_provider_impl(self.api, model).completion(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        return await self.routing_table.get_provider_impl(self.api, model).embeddings(
+            model=model,
+            contents=contents,
+        )
diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml
index 4264b8984..a400011d3 100644
--- a/llama_stack/examples/router-table-run.yaml
+++ b/llama_stack/examples/router-table-run.yaml
@@ -20,14 +20,14 @@ provider_routing_table:
         torch_seed: null
         max_seq_len: 4096
         max_batch_size: 1
-    - routing_key: Meta-Llama3.1-8B
-      provider_id: meta-reference
-      config:
-        model: Meta-Llama3.1-8B
-        quantization: null
-        torch_seed: null
-        max_seq_len: 4096
-        max_batch_size: 1
+    # - routing_key: Meta-Llama3.1-8B
+    #   provider_id: meta-reference
+    #   config:
+    #     model: Meta-Llama3.1-8B
+    #     quantization: null
+    #     torch_seed: null
+    #     max_seq_len: 4096
+    #     max_batch_size: 1
 # memory:
 #   - routing_key: keyvalue
 #     provider_id: remote::pgvector
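
Illustrative usage (not part of the patch above): a minimal sketch of how the new InferenceRouter.completion() and InferenceRouter.embeddings() methods are expected to be called once a stack is running. The router construction is elided, and the model name is only an assumed example of a routing_key configured in the run YAML; both are assumptions, not taken from the diff.

    async def demo(router) -> None:
        # completion() looks up the provider for this model in the routing table
        # and forwards the call unchanged.
        response = await router.completion(
            model="Meta-Llama3.1-8B-Instruct",  # assumed routing_key; must match the run YAML
            content="Hello, world",
        )
        print(response)

        # embeddings() follows the same delegation pattern.
        embeddings = await router.embeddings(
            model="Meta-Llama3.1-8B-Instruct",
            contents=["Hello, world"],
        )
        print(embeddings)

    # e.g. asyncio.run(demo(router)) once a configured InferenceRouter instance is available.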