# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,
    InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset, DatasetPurpose, DataSource
from llama_stack.apis.evaluation import (
    Evaluation,
    EvaluationCandidate,
    EvaluationJob,
    EvaluationResponse,
    EvaluationTask,
)
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    StopReason,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import (
    RAGDocument,
    RAGQueryConfig,
    RAGQueryResult,
    RAGToolRuntime,
    ToolDef,
    ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.telemetry.tracing import get_current_span

logger = get_logger(name=__name__, category="core")


class VectorIORouter(VectorIO):
    """Routes to a provider based on the vector db identifier."""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("VectorIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("VectorIORouter.shutdown")
        pass

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: Optional[int] = 384,
        provider_id: Optional[str] = None,
        provider_vector_db_id: Optional[str] = None,
    ) -> None:
        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
        await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            provider_vector_db_id,
        )

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: List[Chunk],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, "
            f"ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}"
            f"{' and more...' if len(chunks) > 3 else ''}",
        )
        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryChunksResponse:
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
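
# Illustrative usage sketch, not part of this module's API: how a caller might
# drive the VectorIORouter once the stack has constructed a RoutingTable. The
# vector db id, embedding model name, and query text below are hypothetical.
async def _example_vector_io_usage(routing_table: RoutingTable) -> QueryChunksResponse:
    router = VectorIORouter(routing_table=routing_table)
    await router.register_vector_db("my-docs", embedding_model="all-MiniLM-L6-v2")
    # Both calls resolve the provider from the vector db identifier.
    await router.insert_chunks("my-docs", chunks=[], ttl_seconds=None)
    return await router.query_chunks("my-docs", "what does the router do?", params=None)
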

class InferenceRouter(Inference):
    """Routes to a provider based on the model."""

    def __init__(
        self,
        routing_table: RoutingTable,
        telemetry: Optional[Telemetry] = None,
    ) -> None:
        logger.debug("Initializing InferenceRouter")
        self.routing_table = routing_table
        self.telemetry = telemetry
        if self.telemetry:
            self.tokenizer = Tokenizer.get_instance()
            self.formatter = ChatFormat(self.tokenizer)

    async def initialize(self) -> None:
        logger.debug("InferenceRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("InferenceRouter.shutdown")
        pass

    async def register_model(
        self,
        model_id: str,
        provider_model_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        model_type: Optional[ModelType] = None,
    ) -> None:
        logger.debug(
            f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
        )
        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

    def _construct_metrics(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        model: Model,
    ) -> List[MetricEvent]:
        """Constructs a list of MetricEvent objects containing token usage metrics.

        Args:
            prompt_tokens: Number of tokens in the prompt
            completion_tokens: Number of tokens in the completion
            total_tokens: Total number of tokens used
            model: Model object containing model_id and provider_id

        Returns:
            List of MetricEvent objects with token usage metrics
        """
        span = get_current_span()
        if span is None:
            logger.warning("No span found for token usage metrics")
            return []
        metrics = [
            ("prompt_tokens", prompt_tokens),
            ("completion_tokens", completion_tokens),
            ("total_tokens", total_tokens),
        ]
        metric_events = []
        for metric_name, value in metrics:
            metric_events.append(
                MetricEvent(
                    trace_id=span.trace_id,
                    span_id=span.span_id,
                    metric=metric_name,
                    value=value,
                    timestamp=time.time(),
                    unit="tokens",
                    attributes={
                        "model_id": model.model_id,
                        "provider_id": model.provider_id,
                    },
                )
            )
        return metric_events

    async def _compute_and_log_token_usage(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        model: Model,
    ) -> List[MetricInResponse]:
        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
        if self.telemetry:
            for metric in metrics:
                await self.telemetry.log_event(metric)
        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

    async def _count_tokens(
        self,
        messages: List[Message] | InterleavedContent,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
    ) -> Optional[int]:
        if isinstance(messages, list):
            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
        else:
            encoded = self.formatter.encode_content(messages)
        return len(encoded.tokens) if encoded and encoded.tokens else 0
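
    # A worked sketch of the token accounting used below (hypothetical numbers):
    # _count_tokens encodes the prompt, e.g. prompt_tokens=12; after the model
    # replies, the completion message is re-encoded, e.g. completion_tokens=30;
    # _compute_and_log_token_usage then emits three MetricEvents on the current
    # span (prompt=12, completion=30, total=42, unit="tokens") and returns
    # MetricInResponse entries that chat_completion/completion attach to their
    # responses.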

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = None,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        logger.debug(
            f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
        )
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.routing_table.get_model(model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        if model.model_type == ModelType.embedding:
            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
        if tool_config:
            if tool_choice and tool_choice != tool_config.tool_choice:
                raise ValueError("tool_choice and tool_config.tool_choice must match")
            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
        else:
            params = {}
            if tool_choice:
                params["tool_choice"] = tool_choice
            if tool_prompt_format:
                params["tool_prompt_format"] = tool_prompt_format
            tool_config = ToolConfig(**params)

        tools = tools or []
        if tool_config.tool_choice == ToolChoice.none:
            tools = []
        elif tool_config.tool_choice == ToolChoice.auto:
            pass
        elif tool_config.tool_choice == ToolChoice.required:
            pass
        else:
            # verify tool_choice is one of the tools
            tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
            if tool_config.tool_choice not in tool_names:
                raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")

        params = dict(
            model_id=model_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools,
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )
        provider = self.routing_table.get_provider_impl(model_id)
        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)

        if stream:

            async def stream_generator():
                completion_text = ""
                async for chunk in await provider.chat_completion(**params):
                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                        if chunk.event.delta.type == "text":
                            completion_text += chunk.event.delta.text
                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                        completion_tokens = await self._count_tokens(
                            [
                                CompletionMessage(
                                    content=completion_text,
                                    stop_reason=StopReason.end_of_turn,
                                )
                            ],
                            tool_config.tool_prompt_format,
                        )
                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                        metrics = await self._compute_and_log_token_usage(
                            prompt_tokens or 0,
                            completion_tokens or 0,
                            total_tokens,
                            model,
                        )
                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                    yield chunk

            return stream_generator()
        else:
            response = await provider.chat_completion(**params)
            completion_tokens = await self._count_tokens(
                [response.completion_message],
                tool_config.tool_prompt_format,
            )
            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
            metrics = await self._compute_and_log_token_usage(
                prompt_tokens or 0,
                completion_tokens or 0,
                total_tokens,
                model,
            )
            response.metrics = metrics if response.metrics is None else response.metrics + metrics
            return response
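
    # Caller-side sketch for the streaming path above (hypothetical model id):
    #
    #     async for chunk in await router.chat_completion(
    #         model_id="llama3.1-8b", messages=messages, stream=True
    #     ):
    #         if chunk.event.event_type == ChatCompletionResponseEventType.complete:
    #             print(chunk.metrics)  # prompt/completion/total token counts
    #
    # The wrapper accumulates text deltas so token metrics can be computed once,
    # when the "complete" event arrives, instead of on every chunk.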

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        if sampling_params is None:
            sampling_params = SamplingParams()
        logger.debug(
            f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
        )
        model = await self.routing_table.get_model(model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        if model.model_type == ModelType.embedding:
            raise ValueError(f"Model '{model_id}' is an embedding model and does not support completions")
        provider = self.routing_table.get_provider_impl(model_id)
        params = dict(
            model_id=model_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )

        prompt_tokens = await self._count_tokens(content)

        if stream:

            async def stream_generator():
                completion_text = ""
                async for chunk in await provider.completion(**params):
                    if hasattr(chunk, "delta"):
                        completion_text += chunk.delta
                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                        completion_tokens = await self._count_tokens(completion_text)
                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                        metrics = await self._compute_and_log_token_usage(
                            prompt_tokens or 0,
                            completion_tokens or 0,
                            total_tokens,
                            model,
                        )
                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                    yield chunk

            return stream_generator()
        else:
            response = await provider.completion(**params)
            completion_tokens = await self._count_tokens(response.content)
            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
            metrics = await self._compute_and_log_token_usage(
                prompt_tokens or 0,
                completion_tokens or 0,
                total_tokens,
                model,
            )
            response.metrics = metrics if response.metrics is None else response.metrics + metrics
            return response

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        logger.debug(f"InferenceRouter.embeddings: {model_id}")
        model = await self.routing_table.get_model(model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        if model.model_type == ModelType.llm:
            raise ValueError(f"Model '{model_id}' is an LLM and does not support embeddings")
        return await self.routing_table.get_provider_impl(model_id).embeddings(
            model_id=model_id,
            contents=contents,
            text_truncation=text_truncation,
            output_dimension=output_dimension,
            task_type=task_type,
        )


class SafetyRouter(Safety):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing SafetyRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("SafetyRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("SafetyRouter.shutdown")
        pass

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Shield:
        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

    async def run_shield(
        self,
        shield_id: str,
        messages: List[Message],
        params: Optional[Dict[str, Any]] = None,
    ) -> RunShieldResponse:
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
        return await self.routing_table.get_provider_impl(shield_id).run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )
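
# Illustrative sketch, not part of this module: screening a message through the
# SafetyRouter. The shield id is hypothetical, and the check of the response's
# `violation` field is an assumption about RunShieldResponse's shape; routing
# resolves the provider from the shield identifier, mirroring the model-based
# routing above.
async def _example_run_shield(router: SafetyRouter, message: Message) -> bool:
    response: RunShieldResponse = await router.run_shield(
        shield_id="llama-guard",  # hypothetical shield identifier
        messages=[message],
        params={},
    )
    # No violation on the response means the shield passed the content.
    return response.violation is None
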

class DatasetIORouter(DatasetIO):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing DatasetIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("DatasetIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("DatasetIORouter.shutdown")
        pass

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: Optional[Dict[str, Any]] = None,
        dataset_id: Optional[str] = None,
    ) -> Dataset:
        logger.debug(
            f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
        )
        return await self.routing_table.register_dataset(
            purpose=purpose,
            source=source,
            metadata=metadata,
            dataset_id=dataset_id,
        )

    async def iterrows(
        self,
        dataset_id: str,
        start_index: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> IterrowsResponse:
        logger.debug(
            f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
        )
        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
            dataset_id=dataset_id,
            start_index=start_index,
            limit=limit,
        )

    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
            dataset_id=dataset_id,
            rows=rows,
        )


class ToolRuntimeRouter(ToolRuntime):
    class RagToolImpl(RAGToolRuntime):
        def __init__(
            self,
            routing_table: RoutingTable,
        ) -> None:
            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
            self.routing_table = routing_table

        async def query(
            self,
            content: InterleavedContent,
            vector_db_ids: List[str],
            query_config: Optional[RAGQueryConfig] = None,
        ) -> RAGQueryResult:
            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
            return await self.routing_table.get_provider_impl("knowledge_search").query(
                content, vector_db_ids, query_config
            )

        async def insert(
            self,
            documents: List[RAGDocument],
            vector_db_id: str,
            chunk_size_in_tokens: int = 512,
        ) -> None:
            logger.debug(
                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                documents, vector_db_id, chunk_size_in_tokens
            )

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing ToolRuntimeRouter")
        self.routing_table = routing_table

        # HACK ALERT: this should be in sync with "get_all_api_endpoints()"
        self.rag_tool = self.RagToolImpl(routing_table)
        for method in ("query", "insert"):
            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

    async def initialize(self) -> None:
        logger.debug("ToolRuntimeRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ToolRuntimeRouter.shutdown")
        pass

    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
        )

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
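
# Illustrative sketch, not part of this module: invoking a tool by name through
# the ToolRuntimeRouter. The tool name and arguments are hypothetical; the
# router resolves the provider from the tool name, and the RAG helpers are also
# reachable through the "rag_tool.query"/"rag_tool.insert" aliases installed in
# ToolRuntimeRouter.__init__.
async def _example_invoke_tool(router: ToolRuntimeRouter) -> Any:
    return await router.invoke_tool(
        tool_name="web_search",  # hypothetical tool registered with some provider
        kwargs={"query": "llama stack routers"},
    )
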

class EvaluationRouter(Evaluation):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing EvaluationRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("EvaluationRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("EvaluationRouter.shutdown")
        pass

    async def register_benchmark(
        self,
        dataset_id: str,
        grader_ids: List[str],
        benchmark_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Benchmark:
        logger.debug(
            f"EvaluationRouter.register_benchmark: {benchmark_id=} {dataset_id=} {grader_ids=} {metadata=}",
        )
        return await self.routing_table.register_benchmark(
            benchmark_id=benchmark_id,
            dataset_id=dataset_id,
            grader_ids=grader_ids,
            metadata=metadata,
        )

    async def run(
        self,
        task: EvaluationTask,
        candidate: EvaluationCandidate,
    ) -> EvaluationJob:
        raise NotImplementedError("Run is not implemented yet")

    async def run_sync(
        self,
        task: EvaluationTask,
        candidate: EvaluationCandidate,
    ) -> EvaluationResponse:
        raise NotImplementedError("Run sync is not implemented yet")

    async def grade(self, task: EvaluationTask) -> EvaluationJob:
        raise NotImplementedError("Grade is not implemented yet")

    async def grade_sync(self, task: EvaluationTask) -> EvaluationResponse:
        raise NotImplementedError("Grade sync is not implemented yet")
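
# Illustrative sketch, not part of this module: registering a benchmark through
# the EvaluationRouter. All identifiers are hypothetical. Only registration is
# exercised here, since run/run_sync/grade/grade_sync above are still
# unimplemented.
async def _example_register_benchmark(router: EvaluationRouter) -> Benchmark:
    return await router.register_benchmark(
        dataset_id="my-eval-dataset",  # hypothetical dataset id
        grader_ids=["accuracy-grader"],  # hypothetical grader id
        benchmark_id="my-benchmark",
        metadata={},
    )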