Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 04:04:14 +00:00
commit f6080040da (parent ef0736527d)

Code refactoring and removing dead code

6 changed files with 302 additions and 137 deletions
@@ -7,24 +7,16 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
-from datetime import UTC, datetime
+from datetime import datetime, UTC
 from typing import Annotated, Any
 
-from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
-from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
-from pydantic import Field, TypeAdapter
-
-from llama_stack.apis.common.content_types import (
-    InterleavedContent,
-)
+from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
-    CompletionResponse,
-    CompletionResponseStreamChunk,
     Inference,
     ListOpenAIChatCompletionResponse,
     LogProbConfig,
@@ -57,7 +49,16 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
+from llama_stack.providers.utils.telemetry.tracing import (
+    enqueue_event,
+    get_current_span,
+)
+
+from openai.types.chat import (
+    ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam as OpenAIChatCompletionToolParam,
+)
+from pydantic import Field, TypeAdapter
 
 logger = get_logger(name=__name__, category="core::routers")
 
@@ -101,7 +102,9 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+        await self.routing_table.register_model(
+            model_id, provider_model_id, provider_id, metadata, model_type
+        )
 
     def _construct_metrics(
         self,
@@ -156,11 +159,16 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> list[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model
+        )
         if self.telemetry:
             for metric in metrics:
                 enqueue_event(metric)
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+        return [
+            MetricInResponse(metric=metric.metric, value=metric.value)
+            for metric in metrics
+        ]
 
     async def _count_tokens(
         self,
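The method being rewrapped above follows one pattern: build the token-count metric events once, push each to telemetry, and echo the same values back in the API response. A minimal, hedged sketch of that flow; the dataclasses are simplified stand-ins (not llama-stack's actual types) and a print replaces enqueue_event:

from dataclasses import dataclass

@dataclass
class MetricEvent:  # simplified stand-in for the telemetry-side metric type
    metric: str
    value: int

@dataclass
class MetricInResponse:  # simplified stand-in for the response-side metric type
    metric: str
    value: int

def emit_and_echo_metrics(
    prompt_tokens: int, completion_tokens: int, telemetry_enabled: bool
) -> list[MetricInResponse]:
    # Build the three token-count metrics once...
    metrics = [
        MetricEvent("prompt_tokens", prompt_tokens),
        MetricEvent("completion_tokens", completion_tokens),
        MetricEvent("total_tokens", prompt_tokens + completion_tokens),
    ]
    # ...push each one to telemetry when it is configured...
    if telemetry_enabled:
        for metric in metrics:
            print(f"enqueue: {metric}")  # stands in for enqueue_event(metric)
    # ...and echo the same values back so they appear in the client response.
    return [MetricInResponse(m.metric, m.value) for m in metrics]

print(emit_and_echo_metrics(12, 34, telemetry_enabled=True))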
@@ -207,8 +215,13 @@ class InferenceRouter(Inference):
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
-                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+            if (
+                tool_prompt_format
+                and tool_prompt_format != tool_config.tool_prompt_format
+            ):
+                raise ValueError(
+                    "tool_prompt_format and tool_config.tool_prompt_format must match"
+                )
         else:
             params = {}
             if tool_choice:
@@ -226,9 +239,14 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+                tool_names = [
+                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
+                    for t in tools
+                ]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+                    raise ValueError(
+                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
+                    )
 
         params = dict(
             model_id=model_id,
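The check rewrapped in this hunk rejects a tool_choice that names a tool absent from the declared tool list. A standalone sketch of the same validation, assuming plain dicts with a "name" key in place of llama-stack's tool definitions, and assuming "auto"/"required"/"none" as the pass-through sentinels:

def validate_tool_choice(tool_choice: str, tools: list[dict]) -> None:
    # Sentinel values never have to match a declared tool (assumed names).
    if tool_choice in ("auto", "required", "none"):
        return
    tool_names = [t["name"] for t in tools]
    if tool_choice not in tool_names:
        raise ValueError(f"Tool choice {tool_choice} is not one of the tools: {tool_names}")

validate_tool_choice("get_weather", [{"name": "get_weather"}])  # passes silently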
@@ -243,7 +261,9 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = await self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+        prompt_tokens = await self._count_tokens(
+            messages, tool_config.tool_prompt_format
+        )
 
         if stream:
             response_stream = await provider.chat_completion(**params)
@@ -263,7 +283,9 @@ class InferenceRouter(Inference):
         )
         # these metrics will show up in the client response.
         response.metrics = (
-            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+            metrics
+            if not hasattr(response, "metrics") or response.metrics is None
+            else response.metrics + metrics
         )
         return response
 
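The conditional expression being rewrapped here (and in two later hunks) is a small merge pattern: a provider response may lack a .metrics attribute, carry None, or carry its own list that must be preserved. The same logic isolated, with SimpleNamespace standing in for a provider response object:

from types import SimpleNamespace

def merge_metrics(response, new_metrics: list):
    # Keep any metrics the provider already attached; otherwise start fresh.
    response.metrics = (
        new_metrics
        if not hasattr(response, "metrics") or response.metrics is None
        else response.metrics + new_metrics
    )
    return response

r = SimpleNamespace(metrics=["provider_metric"])
print(merge_metrics(r, ["prompt_tokens"]).metrics)  # ['provider_metric', 'prompt_tokens']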
@@ -336,7 +358,9 @@ class InferenceRouter(Inference):
 
         # these metrics will show up in the client response.
         response.metrics = (
-            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+            metrics
+            if not hasattr(response, "metrics") or response.metrics is None
+            else response.metrics + metrics
         )
         return response
 
@@ -374,9 +398,13 @@ class InferenceRouter(Inference):
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
         if tool_choice:
-            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
+            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(
+                tool_choice
+            )
             if tools is None:
-                raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
+                raise ValueError(
+                    "'tool_choice' is only allowed when 'tools' is also provided"
+                )
         if tools:
             for tool in tools:
                 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
@@ -441,7 +469,9 @@ class InferenceRouter(Inference):
                 enqueue_event(metric)
         # these metrics will show up in the client response.
         response.metrics = (
-            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+            metrics
+            if not hasattr(response, "metrics") or response.metrics is None
+            else response.metrics + metrics
         )
         return response
 
@@ -477,19 +507,31 @@ class InferenceRouter(Inference):
     ) -> ListOpenAIChatCompletionResponse:
         if self.store:
             return await self.store.list_chat_completions(after, limit, model, order)
-        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
+        raise NotImplementedError(
+            "List chat completions is not supported: inference store is not configured."
+        )
 
-    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+    async def get_chat_completion(
+        self, completion_id: str
+    ) -> OpenAICompletionWithInputMessages:
         if self.store:
             return await self.store.get_chat_completion(completion_id)
-        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
+        raise NotImplementedError(
+            "Get chat completion is not supported: inference store is not configured."
+        )
 
-    async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
+    async def _nonstream_openai_chat_completion(
+        self, provider: Inference, params: dict
+    ) -> OpenAIChatCompletion:
         response = await provider.openai_chat_completion(**params)
         for choice in response.choices:
             # some providers return an empty list for no tool calls in non-streaming responses
             # but the OpenAI API returns None. So, set tool_calls to None if it's empty
-            if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
+            if (
+                choice.message
+                and choice.message.tool_calls is not None
+                and len(choice.message.tool_calls) == 0
+            ):
                 choice.message.tool_calls = None
         return response
 
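The comment in the hunk above explains the normalization it rewraps: some providers return [] for "no tool calls" where the OpenAI API returns None. A sketch of that fix-up, with SimpleNamespace standing in for OpenAI-shaped response objects:

from types import SimpleNamespace

def normalize_tool_calls(choices: list) -> None:
    # An empty tool_calls list means "no tool calls"; OpenAI encodes that as None.
    for choice in choices:
        message = choice.message
        if message and message.tool_calls is not None and len(message.tool_calls) == 0:
            message.tool_calls = None

# A provider that answered with plain text but an empty tool_calls list:
choice = SimpleNamespace(message=SimpleNamespace(content="hi", tool_calls=[]))
normalize_tool_calls([choice])
print(choice.message.tool_calls)  # None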
@@ -509,7 +551,9 @@ class InferenceRouter(Inference):
                     message=f"Health check timed out after {timeout} seconds",
                 )
             except NotImplementedError:
-                health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
+                health_statuses[provider_id] = HealthResponse(
+                    status=HealthStatus.NOT_IMPLEMENTED
+                )
             except Exception as e:
                 health_statuses[provider_id] = HealthResponse(
                     status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
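The except clauses in this hunk map each failure mode of a provider probe to a health status. A sketch of one such probe, with a simplified HealthStatus enum and plain dicts standing in for llama-stack's HealthResponse:

import asyncio
from enum import Enum
from types import SimpleNamespace

class HealthStatus(str, Enum):  # simplified stand-in for the provider datatype
    OK = "OK"
    ERROR = "Error"
    NOT_IMPLEMENTED = "Not Implemented"

async def check_provider_health(provider, timeout: float = 1.0) -> dict:
    # Map timeout / missing support / any other failure to a status.
    try:
        await asyncio.wait_for(provider.health(), timeout=timeout)
        return {"status": HealthStatus.OK}
    except asyncio.TimeoutError:
        return {
            "status": HealthStatus.ERROR,
            "message": f"Health check timed out after {timeout} seconds",
        }
    except NotImplementedError:
        return {"status": HealthStatus.NOT_IMPLEMENTED}
    except Exception as e:
        return {"status": HealthStatus.ERROR, "message": f"Health check failed: {str(e)}"}

async def never_ready():
    await asyncio.sleep(10)  # a provider that hangs past the timeout

print(asyncio.run(check_provider_health(SimpleNamespace(health=never_ready), timeout=0.1)))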
|
@ -522,7 +566,7 @@ class InferenceRouter(Inference):
|
|||
prompt_tokens,
|
||||
model,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
completion_text = ""
|
||||
async for chunk in response:
|
||||
complete = False
|
||||
|
@@ -544,7 +588,11 @@ class InferenceRouter(Inference):
             else:
                 if hasattr(chunk, "delta"):
                     completion_text += chunk.delta
-                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                if (
+                    hasattr(chunk, "stop_reason")
+                    and chunk.stop_reason
+                    and self.telemetry
+                ):
                     complete = True
                     completion_tokens = await self._count_tokens(completion_text)
             # if we are done receiving tokens
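The branch above accumulates streamed delta text and counts completion tokens only once a stop_reason arrives, rather than tokenizing on every chunk. A sketch of that shape, with a whitespace split standing in for the real Llama tokenizer behind _count_tokens:

from types import SimpleNamespace

def count_tokens(text: str) -> int:
    return len(text.split())  # whitespace split stands in for the Llama tokenizer

def consume_stream(chunks: list) -> int | None:
    completion_text = ""
    completion_tokens = None
    for chunk in chunks:
        if hasattr(chunk, "delta"):
            completion_text += chunk.delta
        if getattr(chunk, "stop_reason", None):
            # Count once, over the accumulated text, when the stream stops.
            completion_tokens = count_tokens(completion_text)
    return completion_tokens

chunks = [
    SimpleNamespace(delta="Hello ", stop_reason=None),
    SimpleNamespace(delta="world", stop_reason="end_of_turn"),
]
print(consume_stream(chunks))  # 2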
@@ -569,9 +617,14 @@ class InferenceRouter(Inference):
 
                     # Return metrics in response
                     async_metrics = [
-                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
+                        MetricInResponse(metric=metric.metric, value=metric.value)
+                        for metric in completion_metrics
                     ]
-                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
+                    chunk.metrics = (
+                        async_metrics
+                        if chunk.metrics is None
+                        else chunk.metrics + async_metrics
+                    )
                 else:
                     # Fallback if no telemetry
                     completion_metrics = self._construct_metrics(
@@ -581,14 +634,19 @@ class InferenceRouter(Inference):
                         model,
                     )
                     async_metrics = [
-                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
+                        MetricInResponse(metric=metric.metric, value=metric.value)
+                        for metric in completion_metrics
                     ]
-                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
+                    chunk.metrics = (
+                        async_metrics
+                        if chunk.metrics is None
+                        else chunk.metrics + async_metrics
+                    )
             yield chunk
 
     async def count_tokens_and_compute_metrics(
         self,
-        response: ChatCompletionResponse | CompletionResponse,
+        response: ChatCompletionResponse,
         prompt_tokens,
         model,
         tool_prompt_format: ToolPromptFormat | None = None,
@@ -597,7 +655,9 @@ class InferenceRouter(Inference):
             content = [response.completion_message]
         else:
             content = response.content
-        completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
+        completion_tokens = await self._count_tokens(
+            messages=content, tool_prompt_format=tool_prompt_format
+        )
         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
 
         # Create a separate span for completion metrics
@@ -610,11 +670,17 @@ class InferenceRouter(Inference):
                 model=model,
             )
             for metric in completion_metrics:
-                if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
+                if metric.metric in [
+                    "completion_tokens",
+                    "total_tokens",
+                ]:  # Only log completion and total tokens
                     enqueue_event(metric)
 
             # Return metrics in response
-            return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
+            return [
+                MetricInResponse(metric=metric.metric, value=metric.value)
+                for metric in completion_metrics
+            ]
 
         # Fallback if no telemetry
         metrics = self._construct_metrics(
@@ -623,7 +689,10 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+        return [
+            MetricInResponse(metric=metric.metric, value=metric.value)
+            for metric in metrics
+        ]
 
     async def stream_tokens_and_compute_metrics_openai_chat(
         self,
@@ -664,33 +733,48 @@ class InferenceRouter(Inference):
                         if choice_delta.delta:
                             delta = choice_delta.delta
                             if delta.content:
-                                current_choice_data["content_parts"].append(delta.content)
+                                current_choice_data["content_parts"].append(
+                                    delta.content
+                                )
                             if delta.tool_calls:
                                 for tool_call_delta in delta.tool_calls:
                                     tc_idx = tool_call_delta.index
-                                    if tc_idx not in current_choice_data["tool_calls_builder"]:
-                                        current_choice_data["tool_calls_builder"][tc_idx] = {
+                                    if (
+                                        tc_idx
+                                        not in current_choice_data["tool_calls_builder"]
+                                    ):
+                                        current_choice_data["tool_calls_builder"][
+                                            tc_idx
+                                        ] = {
                                             "id": None,
                                             "type": "function",
                                             "function_name_parts": [],
                                             "function_arguments_parts": [],
                                         }
-                                    builder = current_choice_data["tool_calls_builder"][tc_idx]
+                                    builder = current_choice_data["tool_calls_builder"][
+                                        tc_idx
+                                    ]
                                     if tool_call_delta.id:
                                         builder["id"] = tool_call_delta.id
                                     if tool_call_delta.type:
                                         builder["type"] = tool_call_delta.type
                                     if tool_call_delta.function:
                                         if tool_call_delta.function.name:
-                                            builder["function_name_parts"].append(tool_call_delta.function.name)
+                                            builder["function_name_parts"].append(
+                                                tool_call_delta.function.name
+                                            )
                                         if tool_call_delta.function.arguments:
                                             builder["function_arguments_parts"].append(
                                                 tool_call_delta.function.arguments
                                             )
                         if choice_delta.finish_reason:
-                            current_choice_data["finish_reason"] = choice_delta.finish_reason
+                            current_choice_data["finish_reason"] = (
+                                choice_delta.finish_reason
+                            )
                         if choice_delta.logprobs and choice_delta.logprobs.content:
-                            current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
+                            current_choice_data["logprobs_content_parts"].extend(
+                                choice_delta.logprobs.content
+                            )
 
                 # Compute metrics on final chunk
                 if chunk.choices and chunk.choices[0].finish_reason:
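The hunk above rewraps the builder that reassembles tool calls from streaming fragments: deltas are keyed by index, and id/name/argument pieces are concatenated in arrival order. A self-contained sketch of that technique, with assumed dataclasses approximating the OpenAI delta shapes:

from dataclasses import dataclass

@dataclass
class FunctionDelta:  # assumed shape of a streamed function fragment
    name: str | None = None
    arguments: str | None = None

@dataclass
class ToolCallDelta:  # assumed shape of a streamed tool-call fragment
    index: int
    id: str | None = None
    function: FunctionDelta | None = None

def assemble_tool_calls(deltas: list[ToolCallDelta]) -> list[dict]:
    # Fragments for the same index are accumulated in arrival order.
    builders: dict[int, dict] = {}
    for d in deltas:
        b = builders.setdefault(
            d.index,
            {"id": None, "name_parts": [], "arguments_parts": []},
        )
        if d.id:
            b["id"] = d.id
        if d.function and d.function.name:
            b["name_parts"].append(d.function.name)
        if d.function and d.function.arguments:
            b["arguments_parts"].append(d.function.arguments)
    # Only builders that received an id become real tool calls, as in the hunk.
    return [
        {
            "id": b["id"],
            "type": "function",
            "name": "".join(b["name_parts"]),
            "arguments": "".join(b["arguments_parts"]),
        }
        for b in builders.values()
        if b["id"]
    ]

# Two fragments for index 0 arrive separately and are re-joined:
deltas = [
    ToolCallDelta(0, id="call_1", function=FunctionDelta(name="get_weather")),
    ToolCallDelta(0, function=FunctionDelta(arguments='{"city": "Paris"}')),
]
print(assemble_tool_calls(deltas))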
@@ -720,8 +804,12 @@ class InferenceRouter(Inference):
                     if choice_data["tool_calls_builder"]:
                         for tc_build_data in choice_data["tool_calls_builder"].values():
                             if tc_build_data["id"]:
-                                func_name = "".join(tc_build_data["function_name_parts"])
-                                func_args = "".join(tc_build_data["function_arguments_parts"])
+                                func_name = "".join(
+                                    tc_build_data["function_name_parts"]
+                                )
+                                func_args = "".join(
+                                    tc_build_data["function_arguments_parts"]
+                                )
                                 assembled_tool_calls.append(
                                     OpenAIChatCompletionToolCall(
                                         id=tc_build_data["id"],
@@ -734,10 +822,16 @@ class InferenceRouter(Inference):
                     message = OpenAIAssistantMessageParam(
                         role="assistant",
                         content=content_str if content_str else None,
-                        tool_calls=assembled_tool_calls if assembled_tool_calls else None,
+                        tool_calls=(
+                            assembled_tool_calls if assembled_tool_calls else None
+                        ),
                     )
                     logprobs_content = choice_data["logprobs_content_parts"]
-                    final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
+                    final_logprobs = (
+                        OpenAIChoiceLogprobs(content=logprobs_content)
+                        if logprobs_content
+                        else None
+                    )
 
                     assembled_choices.append(
                         OpenAIChoice(
@@ -756,4 +850,6 @@ class InferenceRouter(Inference):
                 object="chat.completion",
             )
             logger.debug(f"InferenceRouter.completion_response: {final_response}")
-            asyncio.create_task(self.store.store_chat_completion(final_response, messages))
+            asyncio.create_task(
+                self.store.store_chat_completion(final_response, messages)
+            )
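In this final hunk the store write is scheduled with asyncio.create_task rather than awaited, so yielding the last chunk is not delayed by storage latency. A sketch of that fire-and-forget shape; the store call is a stand-in, and note that in general a reference to the task should be kept so it is not garbage-collected before it runs:

import asyncio

async def store_chat_completion(response: dict, messages: list) -> None:
    # Stand-in for the inference store write; real persistence is assumed.
    await asyncio.sleep(0.01)  # simulate storage I/O
    print(f"stored {response['id']} with {len(messages)} input message(s)")

async def finish_stream(response: dict, messages: list) -> asyncio.Task:
    # Schedule the write without awaiting it, so the caller can yield the
    # final chunk immediately -- the same fire-and-forget shape as above.
    return asyncio.create_task(store_chat_completion(response, messages))

async def main() -> None:
    task = await finish_stream({"id": "chatcmpl-123"}, [{"role": "user", "content": "hi"}])
    await task  # the demo waits only so the background write can be observed

asyncio.run(main())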