Code refactoring and dead code removal

Omar Abdelwahab 2025-10-02 18:38:30 -07:00
parent ef0736527d
commit f6080040da
6 changed files with 302 additions and 137 deletions


@@ -7,24 +7,16 @@
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from datetime import datetime, UTC
from typing import Annotated, Any
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
@@ -57,7 +49,16 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.providers.utils.telemetry.tracing import (
enqueue_event,
get_current_span,
)
from openai.types.chat import (
ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
ChatCompletionToolParam as OpenAIChatCompletionToolParam,
)
from pydantic import Field, TypeAdapter
logger = get_logger(name=__name__, category="core::routers")
@@ -101,7 +102,9 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata, model_type
)
def _construct_metrics(
self,
@@ -156,11 +159,16 @@ class InferenceRouter(Inference):
total_tokens: int,
model: Model,
) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
metrics = self._construct_metrics(
prompt_tokens, completion_tokens, total_tokens, model
)
if self.telemetry:
for metric in metrics:
enqueue_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
return [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in metrics
]
async def _count_tokens(
self,
@@ -207,8 +215,13 @@ class InferenceRouter(Inference):
if tool_config:
if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match")
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
if (
tool_prompt_format
and tool_prompt_format != tool_config.tool_prompt_format
):
raise ValueError(
"tool_prompt_format and tool_config.tool_prompt_format must match"
)
else:
params = {}
if tool_choice:
@@ -226,9 +239,14 @@ class InferenceRouter(Inference):
pass
else:
# verify tool_choice is one of the tools
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
tool_names = [
t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
for t in tools
]
if tool_config.tool_choice not in tool_names:
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
raise ValueError(
f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
)
params = dict(
model_id=model_id,
@@ -243,7 +261,9 @@ class InferenceRouter(Inference):
tool_config=tool_config,
)
provider = await self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
prompt_tokens = await self._count_tokens(
messages, tool_config.tool_prompt_format
)
if stream:
response_stream = await provider.chat_completion(**params)
@@ -263,7 +283,9 @@ class InferenceRouter(Inference):
)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
metrics
if not hasattr(response, "metrics") or response.metrics is None
else response.metrics + metrics
)
return response
@@ -336,7 +358,9 @@ class InferenceRouter(Inference):
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
metrics
if not hasattr(response, "metrics") or response.metrics is None
else response.metrics + metrics
)
return response
@@ -374,9 +398,13 @@ class InferenceRouter(Inference):
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
if tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(
tool_choice
)
if tools is None:
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
raise ValueError(
"'tool_choice' is only allowed when 'tools' is also provided"
)
if tools:
for tool in tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
@@ -441,7 +469,9 @@ class InferenceRouter(Inference):
enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
metrics
if not hasattr(response, "metrics") or response.metrics is None
else response.metrics + metrics
)
return response
@@ -477,19 +507,31 @@ class InferenceRouter(Inference):
) -> ListOpenAIChatCompletionResponse:
if self.store:
return await self.store.list_chat_completions(after, limit, model, order)
raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
raise NotImplementedError(
"List chat completions is not supported: inference store is not configured."
)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
async def get_chat_completion(
self, completion_id: str
) -> OpenAICompletionWithInputMessages:
if self.store:
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
raise NotImplementedError(
"Get chat completion is not supported: inference store is not configured."
)
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
async def _nonstream_openai_chat_completion(
self, provider: Inference, params: dict
) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
if (
choice.message
and choice.message.tool_calls is not None
and len(choice.message.tool_calls) == 0
):
choice.message.tool_calls = None
return response
@@ -509,7 +551,9 @@ class InferenceRouter(Inference):
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED
)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
@@ -522,7 +566,7 @@ class InferenceRouter(Inference):
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
completion_text = ""
async for chunk in response:
complete = False
@@ -544,7 +588,11 @@ class InferenceRouter(Inference):
else:
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
if (
hasattr(chunk, "stop_reason")
and chunk.stop_reason
and self.telemetry
):
complete = True
completion_tokens = await self._count_tokens(completion_text)
# if we are done receiving tokens
@@ -569,9 +617,14 @@ class InferenceRouter(Inference):
# Return metrics in response
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
chunk.metrics = (
async_metrics
if chunk.metrics is None
else chunk.metrics + async_metrics
)
else:
# Fallback if no telemetry
completion_metrics = self._construct_metrics(
@@ -581,14 +634,19 @@ class InferenceRouter(Inference):
model,
)
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
chunk.metrics = (
async_metrics
if chunk.metrics is None
else chunk.metrics + async_metrics
)
yield chunk
async def count_tokens_and_compute_metrics(
self,
response: ChatCompletionResponse | CompletionResponse,
response: ChatCompletionResponse,
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
@@ -597,7 +655,9 @@ class InferenceRouter(Inference):
content = [response.completion_message]
else:
content = response.content
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
completion_tokens = await self._count_tokens(
messages=content, tool_prompt_format=tool_prompt_format
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for completion metrics
@@ -610,11 +670,17 @@ class InferenceRouter(Inference):
model=model,
)
for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
if metric.metric in [
"completion_tokens",
"total_tokens",
]: # Only log completion and total tokens
enqueue_event(metric)
# Return metrics in response
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
return [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in completion_metrics
]
# Fallback if no telemetry
metrics = self._construct_metrics(
@@ -623,7 +689,10 @@ class InferenceRouter(Inference):
total_tokens,
model,
)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
return [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in metrics
]
async def stream_tokens_and_compute_metrics_openai_chat(
self,
@@ -664,33 +733,48 @@ class InferenceRouter(Inference):
if choice_delta.delta:
delta = choice_delta.delta
if delta.content:
current_choice_data["content_parts"].append(delta.content)
current_choice_data["content_parts"].append(
delta.content
)
if delta.tool_calls:
for tool_call_delta in delta.tool_calls:
tc_idx = tool_call_delta.index
if tc_idx not in current_choice_data["tool_calls_builder"]:
current_choice_data["tool_calls_builder"][tc_idx] = {
if (
tc_idx
not in current_choice_data["tool_calls_builder"]
):
current_choice_data["tool_calls_builder"][
tc_idx
] = {
"id": None,
"type": "function",
"function_name_parts": [],
"function_arguments_parts": [],
}
builder = current_choice_data["tool_calls_builder"][tc_idx]
builder = current_choice_data["tool_calls_builder"][
tc_idx
]
if tool_call_delta.id:
builder["id"] = tool_call_delta.id
if tool_call_delta.type:
builder["type"] = tool_call_delta.type
if tool_call_delta.function:
if tool_call_delta.function.name:
builder["function_name_parts"].append(tool_call_delta.function.name)
builder["function_name_parts"].append(
tool_call_delta.function.name
)
if tool_call_delta.function.arguments:
builder["function_arguments_parts"].append(
tool_call_delta.function.arguments
)
if choice_delta.finish_reason:
current_choice_data["finish_reason"] = choice_delta.finish_reason
current_choice_data["finish_reason"] = (
choice_delta.finish_reason
)
if choice_delta.logprobs and choice_delta.logprobs.content:
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
current_choice_data["logprobs_content_parts"].extend(
choice_delta.logprobs.content
)
# Compute metrics on final chunk
if chunk.choices and chunk.choices[0].finish_reason:
@@ -720,8 +804,12 @@ class InferenceRouter(Inference):
if choice_data["tool_calls_builder"]:
for tc_build_data in choice_data["tool_calls_builder"].values():
if tc_build_data["id"]:
func_name = "".join(tc_build_data["function_name_parts"])
func_args = "".join(tc_build_data["function_arguments_parts"])
func_name = "".join(
tc_build_data["function_name_parts"]
)
func_args = "".join(
tc_build_data["function_arguments_parts"]
)
assembled_tool_calls.append(
OpenAIChatCompletionToolCall(
id=tc_build_data["id"],
@@ -734,10 +822,16 @@ class InferenceRouter(Inference):
message = OpenAIAssistantMessageParam(
role="assistant",
content=content_str if content_str else None,
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
tool_calls=(
assembled_tool_calls if assembled_tool_calls else None
),
)
logprobs_content = choice_data["logprobs_content_parts"]
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
final_logprobs = (
OpenAIChoiceLogprobs(content=logprobs_content)
if logprobs_content
else None
)
assembled_choices.append(
OpenAIChoice(
@@ -756,4 +850,6 @@ class InferenceRouter(Inference):
object="chat.completion",
)
logger.debug(f"InferenceRouter.completion_response: {final_response}")
asyncio.create_task(self.store.store_chat_completion(final_response, messages))
asyncio.create_task(
self.store.store_chat_completion(final_response, messages)
)
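
Note: the reformatted openai_chat_completion path above guards tool_choice and tools by validating them against the OpenAI request types with pydantic's TypeAdapter. The snippet below is an illustrative, standalone sketch of that validation pattern, not code from this commit; it assumes only the openai and pydantic packages, and get_weather is a made-up tool name.

# Sketch of the TypeAdapter validation pattern used in the diff above.
# All imports come from the openai and pydantic packages, not from llama_stack.
from openai.types.chat import ChatCompletionToolChoiceOptionParam
from pydantic import TypeAdapter, ValidationError

adapter = TypeAdapter(ChatCompletionToolChoiceOptionParam)

candidates = [
    "auto",                                                     # literal option: accepted
    {"type": "function", "function": {"name": "get_weather"}},  # named tool choice: accepted
    "bogus",                                                    # not a valid option: rejected
]
for candidate in candidates:
    try:
        adapter.validate_python(candidate)
        print("accepted:", candidate)
    except ValidationError:
        print("rejected:", candidate)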