mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)
make the telemetry API dep optional in inference router
parent 52e533dc89
commit b180069def
2 changed files with 19 additions and 21 deletions
@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl
 
 
-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dict[str, Any]) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -66,17 +66,16 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dic
         "tool_runtime": ToolRuntimeRouter,
     }
     api_to_deps = {
-        "inference": [Api.telemetry],
+        "inference": {"telemetry": Api.telemetry},
     }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
 
-    deps = []
-    for dep in api_to_deps.get(api.value, []):
-        if dep not in _deps:
-            raise ValueError(f"Dependency {dep} not found in _deps")
-        deps.append(_deps[dep])
+    api_to_dep_impl = {}
+    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
+        if dep_api in deps:
+            api_to_dep_impl[dep_name] = deps[dep_api]
 
-    impl = api_to_routers[api.value](routing_table, *deps)
+    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
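In the new get_auto_router_impl, dependencies are declared as a mapping from keyword-argument name to Api value, and only the implementations actually present in deps are forwarded to the router, which is what allows the telemetry dependency to be absent. A minimal, self-contained sketch of that wiring pattern follows; the Api enum, the InferenceRouter stub, and the routing_table value are simplified stand-ins, not the real llama-stack definitions.

# Sketch only: Api, InferenceRouter, and the deps dict below are simplified
# stand-ins for the real llama-stack types.
import asyncio
from enum import Enum
from typing import Any, Dict, Optional


class Api(Enum):
    inference = "inference"
    telemetry = "telemetry"


class InferenceRouter:
    def __init__(self, routing_table: Any, telemetry: Optional[Any] = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry

    async def initialize(self) -> None:
        pass


async def get_auto_router_impl(api: Api, routing_table: Any, deps: Dict[Api, Any]) -> Any:
    api_to_routers = {"inference": InferenceRouter}
    # keyword-argument name -> Api key; mirrors the diff above
    api_to_deps = {"inference": {"telemetry": Api.telemetry}}

    api_to_dep_impl = {}
    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
        if dep_api in deps:  # forward only the deps that were actually provided
            api_to_dep_impl[dep_name] = deps[dep_api]

    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
    await impl.initialize()
    return impl


# Works whether or not a telemetry implementation is present in deps:
router = asyncio.run(get_auto_router_impl(Api.inference, routing_table=None, deps={}))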
@@ -119,10 +119,11 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
-        telemetry: Telemetry,
+        telemetry: Optional[Telemetry] = None,
     ) -> None:
         self.routing_table = routing_table
         self.telemetry = telemetry
+        if self.telemetry:
             self.tokenizer = Tokenizer.get_instance()
             self.formatter = ChatFormat(self.tokenizer)
 
@@ -226,13 +227,13 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        if stream:
-
-            async def stream_generator():
         model_input = self.formatter.encode_dialog_prompt(
             messages,
             tool_config.tool_prompt_format,
         )
+        if stream:
+
+            async def stream_generator():
                 prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
@@ -255,16 +256,13 @@ class InferenceRouter(Inference):
                                 total_tokens=total_tokens,
                             )
                         )
+                        if self.telemetry:
                             await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
                     yield chunk
 
             return stream_generator()
         else:
             response = await provider.chat_completion(**params)
-            model_input = self.formatter.encode_dialog_prompt(
-                messages,
-                tool_config.tool_prompt_format,
-            )
             model_output = self.formatter.encode_dialog_prompt(
                 [response.completion_message],
                 tool_config.tool_prompt_format,
@@ -281,6 +279,7 @@ class InferenceRouter(Inference):
                     total_tokens=total_tokens,
                 )
             )
+            if self.telemetry:
                 await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
             return response
 
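Viewed from the router's side, the effect of the change is that telemetry-only work (tokenizer/formatter setup and token-usage logging) happens only when a telemetry implementation was injected; a router constructed without one simply skips it. A small self-contained sketch of that guard pattern, assuming a made-up FakeTelemetry stand-in rather than the real Telemetry API:

# Sketch only: FakeTelemetry and the token counts are illustrative placeholders,
# not the llama-stack Telemetry API.
import asyncio
from typing import Any, Optional


class FakeTelemetry:
    async def log_token_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int) -> None:
        print(f"usage: prompt={prompt_tokens} completion={completion_tokens} total={total_tokens}")


class InferenceRouter:
    def __init__(self, routing_table: Any, telemetry: Optional[FakeTelemetry] = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry  # may be None when no telemetry provider is configured

    async def chat_completion(self, prompt_tokens: int, completion_tokens: int) -> str:
        # ... routing to the actual provider would happen here ...
        if self.telemetry:  # token accounting is skipped entirely without telemetry
            await self.telemetry.log_token_usage(
                prompt_tokens, completion_tokens, prompt_tokens + completion_tokens
            )
        return "ok"


async def main() -> None:
    with_telemetry = InferenceRouter(routing_table=None, telemetry=FakeTelemetry())
    without_telemetry = InferenceRouter(routing_table=None)
    await with_telemetry.chat_completion(12, 34)     # logs usage
    await without_telemetry.chat_completion(12, 34)  # silently skips logging


asyncio.run(main())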