mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)
make the telemetry API dep optional in inference router
parent
52e533dc89
commit
b180069def
2 changed files with 19 additions and 21 deletions
@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl


-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dict[str, Any]) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -66,17 +66,16 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps: Dict[str, Any]) -> Any:
         "tool_runtime": ToolRuntimeRouter,
     }
     api_to_deps = {
-        "inference": [Api.telemetry],
+        "inference": {"telemetry": Api.telemetry},
     }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")

-    deps = []
-    for dep in api_to_deps.get(api.value, []):
-        if dep not in _deps:
-            raise ValueError(f"Dependency {dep} not found in _deps")
-        deps.append(_deps[dep])
+    api_to_dep_impl = {}
+    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
+        if dep_api in deps:
+            api_to_dep_impl[dep_name] = deps[dep_api]

-    impl = api_to_routers[api.value](routing_table, *deps)
+    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
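The factory now forwards only the dependencies that were actually provided, keyed by the router's parameter name, instead of raising when one is missing. The sketch below is illustrative only and runs standalone; the Api values, routing table, and InferenceRouterStub are stand-ins for the real llama-stack types:

# Illustrative sketch only: Api, the routing table, and the router class below are
# stand-ins, not the real llama-stack implementations.
from enum import Enum
from typing import Any, Dict, Optional


class Api(Enum):
    inference = "inference"
    telemetry = "telemetry"


class InferenceRouterStub:
    # Mirrors the new constructor shape: the telemetry dependency is optional.
    def __init__(self, routing_table: Any, telemetry: Optional[Any] = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry


def build_router(api: Api, routing_table: Any, deps: Dict[Api, Any]) -> InferenceRouterStub:
    # Keyword-argument name -> dependency API, as in the diff above.
    api_to_deps = {"inference": {"telemetry": Api.telemetry}}
    # Forward only the dependencies that are present; a missing telemetry impl is
    # simply omitted rather than raising, which is what makes it optional.
    api_to_dep_impl = {}
    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
        if dep_api in deps:
            api_to_dep_impl[dep_name] = deps[dep_api]
    return InferenceRouterStub(routing_table, **api_to_dep_impl)


with_telemetry = build_router(Api.inference, "rt", {Api.telemetry: "telemetry-impl"})
without_telemetry = build_router(Api.inference, "rt", {})
assert with_telemetry.telemetry == "telemetry-impl"
assert without_telemetry.telemetry is None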
@@ -119,12 +119,13 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
-        telemetry: Telemetry,
+        telemetry: Optional[Telemetry] = None,
     ) -> None:
         self.routing_table = routing_table
         self.telemetry = telemetry
-        self.tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(self.tokenizer)
+        if self.telemetry:
+            self.tokenizer = Tokenizer.get_instance()
+            self.formatter = ChatFormat(self.tokenizer)

     async def initialize(self) -> None:
         pass
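In the constructor, the tokenizer and chat formatter exist to support token counting for telemetry, so they are now created only when a telemetry impl is supplied. A minimal standalone sketch of that pattern; the string placeholders stand in for Tokenizer.get_instance() and ChatFormat:

from typing import Any, Optional


class RouterWithOptionalTelemetry:
    def __init__(self, routing_table: Any, telemetry: Optional[Any] = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry
        # Token-counting helpers are only built when usage metrics will be logged.
        if self.telemetry:
            self.tokenizer = "tokenizer"   # Tokenizer.get_instance() in the real code
            self.formatter = "formatter"   # ChatFormat(self.tokenizer) in the real code


r = RouterWithOptionalTelemetry(routing_table="rt")  # no telemetry configured
assert r.telemetry is None and not hasattr(r, "formatter")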
@@ -226,13 +227,13 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
+        model_input = self.formatter.encode_dialog_prompt(
+            messages,
+            tool_config.tool_prompt_format,
+        )
         if stream:

             async def stream_generator():
-                model_input = self.formatter.encode_dialog_prompt(
-                    messages,
-                    tool_config.tool_prompt_format,
-                )
                 prompt_tokens = len(model_input.tokens) if model_input.tokens else 0
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
@@ -255,16 +256,13 @@ class InferenceRouter(Inference):
                         total_tokens=total_tokens,
                     )
                 )
-                await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+                if self.telemetry:
+                    await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
                 yield chunk

             return stream_generator()
         else:
             response = await provider.chat_completion(**params)
-            model_input = self.formatter.encode_dialog_prompt(
-                messages,
-                tool_config.tool_prompt_format,
-            )
             model_output = self.formatter.encode_dialog_prompt(
                 [response.completion_message],
                 tool_config.tool_prompt_format,
@@ -281,7 +279,8 @@ class InferenceRouter(Inference):
                     total_tokens=total_tokens,
                 )
             )
-            await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
+            if self.telemetry:
+                await self._log_token_usage(prompt_tokens, completion_tokens, total_tokens, model)
             return response

     async def completion(
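Both the streaming and non-streaming paths now gate the token-usage logging on self.telemetry. The standalone sketch below illustrates that gating with simplified token counting and a fake telemetry sink; it is not the llama-stack API, just the shape of the change:

import asyncio
from typing import AsyncIterator, List, Optional


class FakeTelemetry:
    # Stand-in for a telemetry sink; the real router calls self._log_token_usage().
    def __init__(self) -> None:
        self.events: List[dict] = []

    def log(self, **kwargs: int) -> None:
        self.events.append(dict(kwargs))


class MiniInferenceRouter:
    def __init__(self, telemetry: Optional[FakeTelemetry] = None) -> None:
        self.telemetry = telemetry

    async def chat_completion(self, prompt: str, stream: bool = False):
        chunks = ["hello", " world"]
        prompt_tokens = len(prompt.split())

        if stream:

            async def stream_generator() -> AsyncIterator[str]:
                completion_text = ""
                for chunk in chunks:
                    completion_text += chunk
                    yield chunk
                # Log usage only when a telemetry impl was provided.
                if self.telemetry:
                    self.telemetry.log(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=len(completion_text.split()),
                    )

            return stream_generator()

        response = "".join(chunks)
        if self.telemetry:
            self.telemetry.log(
                prompt_tokens=prompt_tokens,
                completion_tokens=len(response.split()),
            )
        return response


async def main() -> None:
    telemetry = FakeTelemetry()
    router = MiniInferenceRouter(telemetry)
    async for _ in await router.chat_completion("hi there", stream=True):
        pass
    print(telemetry.events)  # one usage record from the streaming path
    # A router constructed without telemetry still works; it just skips the logging.
    print(await MiniInferenceRouter().chat_completion("hi there"))


asyncio.run(main())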