diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py
index 6660e180c..d0fca8771 100644
--- a/llama_stack/distribution/routers/__init__.py
+++ b/llama_stack/distribution/routers/__init__.py
@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl
 
 
-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dict[str, Any]) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -66,17 +66,16 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dic
         "tool_runtime": ToolRuntimeRouter,
     }
     api_to_deps = {
-        "inference": [Api.telemetry],
+        "inference": {"telemetry": Api.telemetry},
     }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
 
-    deps = []
-    for dep in api_to_deps.get(api.value, []):
-        if dep not in _deps:
-            raise ValueError(f"Dependency {dep} not found in _deps")
-        deps.append(_deps[dep])
+    api_to_dep_impl = {}
+    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
+        if dep_api in deps:
+            api_to_dep_impl[dep_name] = deps[dep_api]
 
-    impl = api_to_routers[api.value](routing_table, *deps)
+    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 4620efef7..a9fc13502 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -119,12 +119,13 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
-        telemetry: Telemetry,
+        telemetry: Optional[Telemetry] = None,
     ) -> None:
         self.routing_table = routing_table
         self.telemetry = telemetry
-        self.tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(self.tokenizer)
+        if self.telemetry:
+            self.tokenizer = Tokenizer.get_instance()
+            self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         pass
@@ -226,13 +227,13 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
+        model_input = self.formatter.encode_dialog_prompt(
+            messages,
+            tool_config.tool_prompt_format,
+        )
         if stream:
 
             async def stream_generator():
-                model_input = self.formatter.encode_dialog_prompt(
-                    messages,
-                    tool_config.tool_prompt_format,
-                )
                 prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
@@ -255,16 +256,13 @@ class InferenceRouter(Inference):
                             total_tokens=total_tokens,
                         )
                     )
-                    await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+                    if self.telemetry:
+                        await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
                     yield chunk
             return stream_generator()
         else:
             response = await provider.chat_completion(**params)
-            model_input = self.formatter.encode_dialog_prompt(
-                messages,
-                tool_config.tool_prompt_format,
-            )
             model_output = self.formatter.encode_dialog_prompt(
                 [response.completion_message],
                 tool_config.tool_prompt_format,
             )
@@ -281,7 +279,8 @@ class InferenceRouter(Inference):
                     total_tokens=total_tokens,
                 )
             )
-            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+            if self.telemetry:
+                await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
             return response
 
     async def completion(
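The first file switches `get_auto_router_impl` from positional dependency injection (an ordered list unpacked with `*deps`) to keyword injection (a name-to-`Api` mapping unpacked with `**`), and a missing dependency is now skipped rather than raised. Below is a self-contained sketch of that wiring, assuming stand-in types; everything except the `Api.telemetry` key and the `telemetry` kwarg name is illustrative, not part of the diff.

```python
from enum import Enum
from typing import Any, Dict


class Api(Enum):
    telemetry = "telemetry"


class InferenceRouter:
    # Mirrors the new signature: the dependency arrives as a named,
    # optional keyword instead of a positional argument.
    def __init__(self, routing_table: Any, telemetry: Any = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry


def build_router(api: str, routing_table: Any, deps: Dict[Api, Any]) -> Any:
    api_to_routers = {"inference": InferenceRouter}
    # Each dependency is keyed by the kwarg name it should be passed under.
    api_to_deps = {"inference": {"telemetry": Api.telemetry}}

    kwargs = {
        name: deps[dep_api]
        for name, dep_api in api_to_deps.get(api, {}).items()
        if dep_api in deps  # a missing dep is skipped, no longer an error
    }
    return api_to_routers[api](routing_table, **kwargs)


# With telemetry registered, the router receives it by name; without it,
# construction still succeeds and the kwarg falls back to its default.
router = build_router("inference", routing_table=object(), deps={Api.telemetry: "telemetry-impl"})
assert router.telemetry == "telemetry-impl"
router = build_router("inference", routing_table=object(), deps={})
assert router.telemetry is None
```

Keyword unpacking makes the constructor call order-independent, which is what allows `telemetry` to become `Optional` with a default in the second file.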
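The second file makes telemetry optional in `InferenceRouter`: the tokenizer and formatter are only constructed when a telemetry sink exists, and each `_log_token_usage` call is guarded. Here is a minimal sketch of that pattern with stand-in names (`UsageLogger`, `Router`, `complete` are hypothetical, not the library's API); it guards every tokenizer touchpoint, since those members only exist once telemetry is configured.

```python
from typing import Optional


class UsageLogger:
    """Stand-in for the Telemetry API; not part of the diff."""

    def log(self, prompt_tokens: int, completion_tokens: int) -> None:
        print(f"usage: prompt={prompt_tokens} completion={completion_tokens}")


class Router:
    def __init__(self, telemetry: Optional[UsageLogger] = None) -> None:
        self.telemetry = telemetry
        # Token accounting only matters for telemetry, so its machinery
        # is deferred until a sink is actually configured.
        if self.telemetry:
            self.tokenize = lambda text: text.split()  # stand-in tokenizer

    def complete(self, prompt: str) -> str:
        completion = prompt.upper()  # stand-in provider call
        if self.telemetry:  # every telemetry touchpoint is guarded
            self.telemetry.log(len(self.tokenize(prompt)), len(self.tokenize(completion)))
        return completion


Router().complete("no telemetry configured")            # runs without logging
Router(UsageLogger()).complete("telemetry configured")  # logs token usage
```

The same guard appears twice in the diff because the streaming and non-streaming branches each log usage at their own completion point.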