Ran precommit

Omar Abdelwahab 2025-10-06 13:27:19 -07:00
parent 9886520b40
commit 9fc0d966f6
7 changed files with 153 additions and 310 deletions


@@ -7,9 +7,17 @@
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from datetime import datetime, UTC
from datetime import UTC, datetime
from typing import Annotated, Any
from openai.types.chat import (
ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
)
from openai.types.chat import (
ChatCompletionToolParam as OpenAIChatCompletionToolParam,
)
from pydantic import Field, TypeAdapter
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
@@ -48,12 +56,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
get_current_span,
)
from openai.types.chat import (
ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
ChatCompletionToolParam as OpenAIChatCompletionToolParam,
)
from pydantic import Field, TypeAdapter
logger = get_logger(name=__name__, category="core::routers")
@@ -96,9 +98,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata, model_type
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self,
@@ -153,16 +153,11 @@ class InferenceRouter(Inference):
total_tokens: int,
model: Model,
) -> list[MetricInResponse]:
metrics = self._construct_metrics(
prompt_tokens, completion_tokens, total_tokens, model
)
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
enqueue_event(metric)
return [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in metrics
]
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def _count_tokens(
self,
@@ -256,9 +251,7 @@ class InferenceRouter(Inference):
# these metrics will show up in the client response.
response.metrics = (
metrics
if not hasattr(response, "metrics") or response.metrics is None
else response.metrics + metrics
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
@@ -296,13 +289,9 @@ class InferenceRouter(Inference):
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
if tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(
tool_choice
)
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
if tools is None:
raise ValueError(
"'tool_choice' is only allowed when 'tools' is also provided"
)
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
if tools:
for tool in tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
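
The TypeAdapter check in this hunk can be exercised on its own; a minimal sketch, assuming the openai and pydantic (v2) packages are installed and using an invented tool payload:

from openai.types.chat import ChatCompletionToolParam
from pydantic import TypeAdapter, ValidationError

# A well-formed tool definition passes validation.
tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "parameters": {"type": "object", "properties": {}},
    },
}
TypeAdapter(ChatCompletionToolParam).validate_python(tool)

# A malformed definition raises, which is the extra input validation the router leans on.
try:
    TypeAdapter(ChatCompletionToolParam).validate_python({"type": "function"})
except ValidationError as exc:
    print(exc.error_count(), "validation error(s)")
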
@@ -367,9 +356,7 @@ class InferenceRouter(Inference):
enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics
if not hasattr(response, "metrics") or response.metrics is None
else response.metrics + metrics
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
@@ -405,31 +392,19 @@ class InferenceRouter(Inference):
) -> ListOpenAIChatCompletionResponse:
if self.store:
return await self.store.list_chat_completions(after, limit, model, order)
raise NotImplementedError(
"List chat completions is not supported: inference store is not configured."
)
raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
async def get_chat_completion(
self, completion_id: str
) -> OpenAICompletionWithInputMessages:
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
if self.store:
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError(
"Get chat completion is not supported: inference store is not configured."
)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
async def _nonstream_openai_chat_completion(
self, provider: Inference, params: dict
) -> OpenAIChatCompletion:
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
if (
choice.message
and choice.message.tool_calls is not None
and len(choice.message.tool_calls) == 0
):
if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
choice.message.tool_calls = None
return response
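
The empty-list-to-None normalization above is easy to test in isolation; a rough sketch using stand-in dataclasses rather than the real OpenAI response models:

from dataclasses import dataclass, field

@dataclass
class Message:
    tool_calls: list | None = field(default_factory=list)

@dataclass
class Choice:
    message: Message = field(default_factory=Message)

def normalize(choice: Choice) -> Choice:
    # Providers sometimes send [] for "no tool calls"; the OpenAI API uses None.
    if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
        choice.message.tool_calls = None
    return choice

assert normalize(Choice()).message.tool_calls is None
assert normalize(Choice(Message(tool_calls=[{"id": "call_1"}]))).message.tool_calls == [{"id": "call_1"}]
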
@@ -449,9 +424,7 @@ class InferenceRouter(Inference):
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED
)
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
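
This hunk sits inside a per-provider health loop; a minimal sketch of the same timeout / NotImplementedError / generic-error branching, using plain dicts and a made-up provider stub instead of HealthResponse:

import asyncio

async def check_health(provider, timeout: float = 1.0) -> dict:
    # The router distinguishes three failure modes: timeout, providers that
    # don't implement health(), and any other exception.
    try:
        return await asyncio.wait_for(provider.health(), timeout=timeout)
    except (TimeoutError, asyncio.TimeoutError):
        return {"status": "ERROR", "message": f"Health check timed out after {timeout} seconds"}
    except NotImplementedError:
        return {"status": "NOT_IMPLEMENTED"}
    except Exception as e:
        return {"status": "ERROR", "message": f"Health check failed: {str(e)}"}

class _StubProvider:
    async def health(self):
        raise NotImplementedError

print(asyncio.run(check_health(_StubProvider())))  # {'status': 'NOT_IMPLEMENTED'}
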
@@ -486,11 +459,7 @@ class InferenceRouter(Inference):
else:
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if (
hasattr(chunk, "stop_reason")
and chunk.stop_reason
and self.telemetry
):
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
complete = True
completion_tokens = await self._count_tokens(completion_text)
# if we are done receiving tokens
@@ -515,14 +484,9 @@ class InferenceRouter(Inference):
# Return metrics in response
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in completion_metrics
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
]
chunk.metrics = (
async_metrics
if chunk.metrics is None
else chunk.metrics + async_metrics
)
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
else:
# Fallback if no telemetry
completion_metrics = self._construct_metrics(
@@ -532,14 +496,9 @@ class InferenceRouter(Inference):
model,
)
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in completion_metrics
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
]
chunk.metrics = (
async_metrics
if chunk.metrics is None
else chunk.metrics + async_metrics
)
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
yield chunk
async def count_tokens_and_compute_metrics(
@@ -553,9 +512,7 @@ class InferenceRouter(Inference):
content = [response.completion_message]
else:
content = response.content
completion_tokens = await self._count_tokens(
messages=content, tool_prompt_format=tool_prompt_format
)
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for completion metrics
@@ -575,10 +532,7 @@ class InferenceRouter(Inference):
enqueue_event(metric)
# Return metrics in response
return [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in completion_metrics
]
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
# Fallback if no telemetry
metrics = self._construct_metrics(
@@ -587,10 +541,7 @@ class InferenceRouter(Inference):
total_tokens,
model,
)
return [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in metrics
]
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def stream_tokens_and_compute_metrics_openai_chat(
self,
@@ -631,48 +582,33 @@ class InferenceRouter(Inference):
if choice_delta.delta:
delta = choice_delta.delta
if delta.content:
current_choice_data["content_parts"].append(
delta.content
)
current_choice_data["content_parts"].append(delta.content)
if delta.tool_calls:
for tool_call_delta in delta.tool_calls:
tc_idx = tool_call_delta.index
if (
tc_idx
not in current_choice_data["tool_calls_builder"]
):
current_choice_data["tool_calls_builder"][
tc_idx
] = {
if tc_idx not in current_choice_data["tool_calls_builder"]:
current_choice_data["tool_calls_builder"][tc_idx] = {
"id": None,
"type": "function",
"function_name_parts": [],
"function_arguments_parts": [],
}
builder = current_choice_data["tool_calls_builder"][
tc_idx
]
builder = current_choice_data["tool_calls_builder"][tc_idx]
if tool_call_delta.id:
builder["id"] = tool_call_delta.id
if tool_call_delta.type:
builder["type"] = tool_call_delta.type
if tool_call_delta.function:
if tool_call_delta.function.name:
builder["function_name_parts"].append(
tool_call_delta.function.name
)
builder["function_name_parts"].append(tool_call_delta.function.name)
if tool_call_delta.function.arguments:
builder["function_arguments_parts"].append(
tool_call_delta.function.arguments
)
if choice_delta.finish_reason:
current_choice_data["finish_reason"] = (
choice_delta.finish_reason
)
current_choice_data["finish_reason"] = choice_delta.finish_reason
if choice_delta.logprobs and choice_delta.logprobs.content:
current_choice_data["logprobs_content_parts"].extend(
choice_delta.logprobs.content
)
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
# Compute metrics on final chunk
if chunk.choices and chunk.choices[0].finish_reason:
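
The builder dict in this hunk accumulates partial tool-call fragments across streamed chunks; a condensed sketch of the same idea over plain dict deltas (the sample deltas are invented):

def assemble_tool_calls(deltas):
    # index -> accumulated id / name fragments / argument fragments,
    # mirroring the tool_calls_builder dict above
    builders: dict[int, dict] = {}
    for delta in deltas:
        b = builders.setdefault(delta["index"], {"id": None, "name_parts": [], "args_parts": []})
        if delta.get("id"):
            b["id"] = delta["id"]
        fn = delta.get("function") or {}
        if fn.get("name"):
            b["name_parts"].append(fn["name"])
        if fn.get("arguments"):
            b["args_parts"].append(fn["arguments"])
    return [
        {
            "id": b["id"],
            "type": "function",
            "function": {"name": "".join(b["name_parts"]), "arguments": "".join(b["args_parts"])},
        }
        for b in builders.values()
        if b["id"]
    ]

deltas = [
    {"index": 0, "id": "call_1", "function": {"name": "get_weather"}},
    {"index": 0, "function": {"arguments": '{"city": '}},
    {"index": 0, "function": {"arguments": '"Paris"}'}},
]
print(assemble_tool_calls(deltas))
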
@@ -702,12 +638,8 @@ class InferenceRouter(Inference):
if choice_data["tool_calls_builder"]:
for tc_build_data in choice_data["tool_calls_builder"].values():
if tc_build_data["id"]:
func_name = "".join(
tc_build_data["function_name_parts"]
)
func_args = "".join(
tc_build_data["function_arguments_parts"]
)
func_name = "".join(tc_build_data["function_name_parts"])
func_args = "".join(tc_build_data["function_arguments_parts"])
assembled_tool_calls.append(
OpenAIChatCompletionToolCall(
id=tc_build_data["id"],
@@ -720,16 +652,10 @@ class InferenceRouter(Inference):
message = OpenAIAssistantMessageParam(
role="assistant",
content=content_str if content_str else None,
tool_calls=(
assembled_tool_calls if assembled_tool_calls else None
),
tool_calls=(assembled_tool_calls if assembled_tool_calls else None),
)
logprobs_content = choice_data["logprobs_content_parts"]
final_logprobs = (
OpenAIChoiceLogprobs(content=logprobs_content)
if logprobs_content
else None
)
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
assembled_choices.append(
OpenAIChoice(
@@ -748,6 +674,4 @@ class InferenceRouter(Inference):
object="chat.completion",
)
logger.debug(f"InferenceRouter.completion_response: {final_response}")
asyncio.create_task(
self.store.store_chat_completion(final_response, messages)
)
asyncio.create_task(self.store.store_chat_completion(final_response, messages))
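
The last line schedules the store write without awaiting it; a tiny sketch of that fire-and-forget pattern with a stand-in store object:

import asyncio

class _StubStore:
    async def store_chat_completion(self, response, messages):
        await asyncio.sleep(0)  # stand-in for the real write
        print("stored", response["id"])

async def handler():
    store = _StubStore()
    final_response = {"id": "chatcmpl-123"}
    # Schedule persistence in the background so the response can be returned immediately.
    task = asyncio.create_task(store.store_chat_completion(final_response, messages=[]))
    # ... the response would be returned to the caller here ...
    await task  # awaited only so this demo exits after the write completes

asyncio.run(handler())
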