make the telemetry API dep optional in inference router

Dinesh Yeduguru 2025-02-05 09:25:21 -08:00
parent 52e533dc89
commit b180069def
2 changed files with 19 additions and 21 deletions


@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl
 
 
-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dict[str, Any]) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -66,17 +66,16 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dic
         "tool_runtime": ToolRuntimeRouter,
     }
     api_to_deps = {
-        "inference": [Api.telemetry],
+        "inference": {"telemetry": Api.telemetry},
     }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
 
-    deps = []
-    for dep in api_to_deps.get(api.value, []):
-        if dep not in _deps:
-            raise ValueError(f"Dependency {dep} not found in _deps")
-        deps.append(_deps[dep])
-
-    impl = api_to_routers[api.value](routing_table, *deps)
+    api_to_dep_impl = {}
+    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
+        if dep_api in deps:
+            api_to_dep_impl[dep_name] = deps[dep_api]
+
+    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
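With this hunk, router dependencies are looked up by name and any that are missing are simply skipped rather than raising, and whatever survives is passed to the router constructor as keyword arguments. A minimal standalone sketch of that resolution behavior (resolve_optional_deps and the one-member Api enum are illustrative stand-ins, not code from this commit):

    from enum import Enum
    from typing import Any, Dict

    class Api(Enum):  # one-member stand-in for the real Api enum
        telemetry = "telemetry"

    def resolve_optional_deps(api_value: str, deps: Dict[Api, Any]) -> Dict[str, Any]:
        # Same shape as the new code path: map constructor kwarg names to Apis,
        # then keep only the deps that were actually provided.
        api_to_deps = {
            "inference": {"telemetry": Api.telemetry},
        }
        api_to_dep_impl = {}
        for dep_name, dep_api in api_to_deps.get(api_value, {}).items():
            if dep_api in deps:  # a missing dep is no longer an error
                api_to_dep_impl[dep_name] = deps[dep_api]
        return api_to_dep_impl

    # With telemetry registered the router receives it; without it, nothing is passed.
    assert resolve_optional_deps("inference", {Api.telemetry: object()}).keys() == {"telemetry"}
    assert resolve_optional_deps("inference", {}) == {}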


@@ -119,12 +119,13 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
-        telemetry: Telemetry,
+        telemetry: Optional[Telemetry] = None,
     ) -> None:
         self.routing_table = routing_table
         self.telemetry = telemetry
-        self.tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(self.tokenizer)
+        if self.telemetry:
+            self.tokenizer = Tokenizer.get_instance()
+            self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         pass
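Since telemetry now defaults to None, an InferenceRouter can be constructed without a telemetry provider, and the tokenizer and chat formatter are only initialized when one is supplied. A hedged usage sketch (FakeRoutingTable is a hypothetical stub, the import path is assumed from the llama-stack layout, and constructing with telemetry assumes the Llama tokenizer assets are available):

    from llama_stack.distribution.routers.routers import InferenceRouter  # assumed path

    class FakeRoutingTable:  # hypothetical stub, enough for construction only
        def get_provider_impl(self, model_id):
            raise NotImplementedError

    # Previously telemetry was a required argument; now it can be omitted.
    router_plain = InferenceRouter(routing_table=FakeRoutingTable())
    assert router_plain.telemetry is None  # tokenizer/formatter are not set up either

    router_traced = InferenceRouter(
        routing_table=FakeRoutingTable(),
        telemetry=object(),  # stand-in; a real Telemetry impl goes here
    )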
@@ -226,13 +227,13 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
+        model_input = self.formatter.encode_dialog_prompt(
+            messages,
+            tool_config.tool_prompt_format,
+        )
         if stream:
 
             async def stream_generator():
-                model_input = self.formatter.encode_dialog_prompt(
-                    messages,
-                    tool_config.tool_prompt_format,
-                )
                 prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
@@ -255,16 +256,13 @@ class InferenceRouter(Inference):
                                 total_tokens=total_tokens,
                             )
                         )
-                        await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+                        if self.telemetry:
+                            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
                     yield chunk
 
             return stream_generator()
         else:
             response = await provider.chat_completion(**params)
-            model_input = self.formatter.encode_dialog_prompt(
-                messages,
-                tool_config.tool_prompt_format,
-            )
             model_output = self.formatter.encode_dialog_prompt(
                 [response.completion_message],
                 tool_config.tool_prompt_format,
@@ -281,7 +279,8 @@ class InferenceRouter(Inference):
                     total_tokens=total_tokens,
                 )
             )
-            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+            if self.telemetry:
+                await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
             return response
 
     async def completion(
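Both the streaming and non-streaming paths now apply the same guard before logging token counts, so a router built without telemetry skips token accounting entirely. A self-contained sketch of that guard pattern (the Router class and its print-based _log_token_usage are illustrative only):

    import asyncio
    from typing import Optional

    class Router:
        def __init__(self, telemetry: Optional[object] = None) -> None:
            self.telemetry = telemetry

        async def _log_token_usage(self, prompt: int, completion: int, total: int, model: str) -> None:
            print(f"{model}: prompt={prompt} completion={completion} total={total}")

        async def respond(self, prompt_tokens: int, completion_tokens: int, model: str) -> None:
            total_tokens = prompt_tokens + completion_tokens
            # With no telemetry impl injected, logging is skipped entirely.
            if self.telemetry:
                await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)

    asyncio.run(Router().respond(10, 5, "llama"))                    # logs nothing
    asyncio.run(Router(telemetry=object()).respond(10, 5, "llama"))  # logs usage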