Merge-related changes.

This commit is contained in:
ilya-kolchinsky 2025-04-02 19:56:44 +02:00
commit 60e9f46856
456 changed files with 38636 additions and 10892 deletions

View file

@ -4,22 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
BenchmarkConfig,
Eval,
EvaluateResponse,
Job,
JobStatus,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -27,13 +28,14 @@ from llama_stack.apis.inference import (
Message,
ResponseFormat,
SamplingParams,
StopReason,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.preprocessing import (
Preprocessing,
PreprocessingDataElement,
@ -48,18 +50,22 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.distribution.utils.chain import execute_preprocessor_chain
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
@ -126,9 +132,14 @@ class InferenceRouter(Inference):
def __init__(
    self,
    routing_table: RoutingTable,
    telemetry: Optional[Telemetry] = None,
) -> None:
    """Store the routing table and, when telemetry is given, set up token counting.

    Args:
        routing_table: Routing table kept on the instance as ``routing_table``.
        telemetry: Optional telemetry sink kept on the instance as
            ``telemetry``. The ``tokenizer`` and ``formatter`` attributes are
            only created when this value is truthy.
    """
    logger.debug("Initializing InferenceRouter")
    self.routing_table = routing_table
    self.telemetry = telemetry
    if not self.telemetry:
        # Without telemetry no tokenizer/formatter attributes are set.
        return
    tokenizer = Tokenizer.get_instance()
    self.tokenizer = tokenizer
    self.formatter = ChatFormat(tokenizer)
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
@ -151,6 +162,75 @@ class InferenceRouter(Inference):
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
    self,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
    model: Model,
) -> List[MetricEvent]:
    """Build one token-usage MetricEvent per counter for the current span.

    Args:
        prompt_tokens: Number of tokens in the prompt.
        completion_tokens: Number of tokens in the completion.
        total_tokens: Total number of tokens used.
        model: Model whose ``model_id`` and ``provider_id`` are attached as
            attributes to every event.

    Returns:
        Events for "prompt_tokens", "completion_tokens" and "total_tokens",
        in that order, or an empty list when there is no current span.
    """
    current_span = get_current_span()
    if current_span is None:
        logger.warning("No span found for token usage metrics")
        return []

    # Insertion order of the dict fixes the order of the emitted events.
    token_counts = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }
    return [
        MetricEvent(
            trace_id=current_span.trace_id,
            span_id=current_span.span_id,
            metric=name,
            value=count,
            # Each event gets its own timestamp and its own attributes dict.
            timestamp=time.time(),
            unit="tokens",
            attributes={
                "model_id": model.model_id,
                "provider_id": model.provider_id,
            },
        )
        for name, count in token_counts.items()
    ]
async def _compute_and_log_token_usage(
    self,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
    model: Model,
) -> List[MetricInResponse]:
    """Build token-usage metric events, log them, and return them in response form.

    Args:
        prompt_tokens: Number of tokens in the prompt.
        completion_tokens: Number of tokens in the completion.
        total_tokens: Total number of tokens used.
        model: Model passed through to ``_construct_metrics``.

    Returns:
        One MetricInResponse (metric name and value) per constructed event.
        Events are sent to ``self.telemetry.log_event`` only when telemetry
        is set.
    """
    events = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
    if self.telemetry:
        for event in events:
            await self.telemetry.log_event(event)

    response_metrics = []
    for event in events:
        response_metrics.append(MetricInResponse(metric=event.metric, value=event.value))
    return response_metrics
async def _count_tokens(
    self,
    messages: List[Message] | InterleavedContent,
    tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
    """Count the tokens produced by encoding a dialog or a piece of content.

    Args:
        messages: A list is encoded as a dialog prompt; anything else is
            encoded as plain content.
        tool_prompt_format: Forwarded to the dialog encoder; not used for
            non-list input.

    Returns:
        The number of encoded tokens (0 when the encoder yields nothing or
        no tokens), or None when this router has no formatter.
    """
    # The constructor only creates ``self.formatter`` when telemetry is
    # configured, while the completion paths call this method regardless.
    # Report "unknown" instead of raising AttributeError; callers already
    # fold None into 0 via ``prompt_tokens or 0``.
    formatter = getattr(self, "formatter", None)
    if formatter is None:
        return None

    if isinstance(messages, list):
        encoded = formatter.encode_dialog_prompt(messages, tool_prompt_format)
    else:
        encoded = formatter.encode_content(messages)
    return len(encoded.tokens) if encoded and encoded.tokens else 0
async def chat_completion(
self,
model_id: str,
@ -163,7 +243,7 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
@ -213,10 +293,52 @@ class InferenceRouter(Inference):
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
return (chunk async for chunk in await provider.chat_completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
return await provider.chat_completion(**params)
response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def completion(
self,
@ -246,10 +368,41 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
prompt_tokens = await self._count_tokens(content)
if stream:
return (chunk async for chunk in await provider.completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
return await provider.completion(**params)
response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def embeddings(
self,
@ -330,21 +483,36 @@ class DatasetIORouter(DatasetIO):
logger.debug("DatasetIORouter.shutdown")
pass
async def get_rows_paginated(
async def register_dataset(
    self,
    purpose: DatasetPurpose,
    source: DataSource,
    metadata: Optional[Dict[str, Any]] = None,
    dataset_id: Optional[str] = None,
) -> None:
    """Log the request and forward it unchanged to the routing table.

    Args:
        purpose: Dataset purpose, passed through as ``purpose``.
        source: Dataset source, passed through as ``source``.
        metadata: Optional metadata mapping, passed through as ``metadata``.
        dataset_id: Optional dataset identifier, passed through as
            ``dataset_id``.
    """
    logger.debug(
        f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
    )
    registration = {
        "purpose": purpose,
        "source": source,
        "metadata": metadata,
        "dataset_id": dataset_id,
    }
    await self.routing_table.register_dataset(**registration)
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,
filter_condition=filter_condition,
start_index=start_index,
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
@ -457,7 +625,7 @@ class EvalRouter(Eval):
self,
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
@ -547,7 +715,7 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)