From c7139b0b6792e5c3b30d245b05c3870554c7e758 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 12 Mar 2025 11:59:21 -0700 Subject: [PATCH 01/11] fix: fix precommit (#1594) # What does this PR do? - fix precommit [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan CI [//]: # (## Documentation) --- .../open-benchmark/open_benchmark.py | 29 ++++++++++++------- llama_stack/templates/template.py | 6 ++-- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index 7df33a715..2b40797f9 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -4,8 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Tuple +from typing import Dict, List, Tuple +from llama_stack.apis.common.content_types import URL from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ( BenchmarkInput, @@ -15,21 +16,27 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) -from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig +from llama_stack.providers.inline.vector_io.sqlite_vec.config import ( + SQLiteVectorIOConfig, +) from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig from llama_stack.providers.remote.inference.gemini.config import GeminiConfig from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig -from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, +from llama_stack.providers.remote.vector_io.pgvector.config import ( + PGVectorVectorIOConfig, +) +from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry +from llama_stack.templates.template import ( + DistributionTemplate, + RunConfigSettings, + get_model_registry, ) -from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry -def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]: +def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]: # in this template, we allow each API key to be optional providers = [ ( @@ -164,7 +171,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="simpleqa", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"), metadata={ "path": "llamastack/simpleqa", "split": "train", @@ -178,7 +185,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="mmlu_cot", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"), metadata={ "path": "llamastack/mmlu_cot", "name": "all", @@ -193,7 +200,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( 
dataset_id="gpqa_cot", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"), metadata={ "path": "llamastack/gpqa_0shot_cot", "name": "gpqa_main", @@ -208,7 +215,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="math_500", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/math_500"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"), metadata={ "path": "llamastack/math_500", "split": "test", diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index aa1ce144f..a5c8e80bc 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -30,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEn from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]: +def get_model_registry( + available_models: Dict[str, List[ProviderModelEntry]], +) -> List[ModelInput]: models = [] for provider_id, entries in available_models.items(): for entry in entries: @@ -193,7 +195,7 @@ class DistributionTemplate(BaseModel): default_models.append( DefaultModel( model_id=model_entry.provider_model_id, - doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "", + doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""), ) ) From 58d08d100ef807ddd663cfaae18be383aa911ae5 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 12 Mar 2025 12:01:03 -0700 Subject: [PATCH 02/11] feat: Add back inference metrics and preserve context variables across asyncio boundary (#1552) # What does this PR do? This PR adds back the changes in #1300 which were reverted in #1476 . It also adds logic to preserve context variables across asyncio boundary. this is needed with the library client since the async generator logic yields control to code outside the event loop, and on resuming, does not have the same context as before and this requires preserving the context vars. address #1477 ## Test Plan ``` curl --request POST \ --url http://localhost:8321/v1/inference/chat-completion \ --header 'content-type: application/json' \ --data '{ "model_id": "meta-llama/Llama-3.1-70B-Instruct", "messages": [ { "role": "user", "content": { "type": "text", "text": "where do humans live" } } ], "stream": false }' | jq . { "metrics": [ { "trace_id": "kCZwO3tyQC-FuAGb", "span_id": "bsP_5a5O", "timestamp": "2025-03-11T16:47:38.549084Z", "attributes": { "model_id": "meta-llama/Llama-3.1-70B-Instruct", "provider_id": "fireworks" }, "type": "metric", "metric": "prompt_tokens", "value": 10, "unit": "tokens" }, { "trace_id": "kCZwO3tyQC-FuAGb", "span_id": "bsP_5a5O", "timestamp": "2025-03-11T16:47:38.549449Z", "attributes": { "model_id": "meta-llama/Llama-3.1-70B-Instruct", "provider_id": "fireworks" }, "type": "metric", "metric": "completion_tokens", "value": 369, "unit": "tokens" }, { "trace_id": "kCZwO3tyQC-FuAGb", "span_id": "bsP_5a5O", "timestamp": "2025-03-11T16:47:38.549457Z", "attributes": { "model_id": "meta-llama/Llama-3.1-70B-Instruct", "provider_id": "fireworks" }, "type": "metric", "metric": "total_tokens", "value": 379, "unit": "tokens" } ], "completion_message": { "role": "assistant", "content": "Humans live on the planet Earth, specifically on its landmasses and in its oceans. 
Here's a breakdown of where humans live:\n\n1. **Continents:** Humans inhabit all seven continents:\n\t* Africa\n\t* Antarctica ( temporary residents, mostly scientists and researchers)\n\t* Asia\n\t* Australia\n\t* Europe\n\t* North America\n\t* South America\n2. **Countries:** There are 196 countries recognized by the United Nations, and humans live in almost all of them.\n3. **Cities and towns:** Many humans live in urban areas, such as cities and towns, which are often located near coastlines, rivers, or other bodies of water.\n4. **Rural areas:** Some humans live in rural areas, such as villages, farms, and countryside.\n5. **Islands:** Humans inhabit many islands around the world, including those in the Pacific, Indian, and Atlantic Oceans.\n6. **Mountains and highlands:** Humans live in mountainous regions, such as the Himalayas, the Andes, and the Rocky Mountains.\n7. **Deserts:** Some humans live in desert regions, such as the Sahara, the Mojave, and the Atacama.\n8. **Coastal areas:** Many humans live in coastal areas, such as beaches, ports, and coastal cities.\n9. **Underwater habitats:** A few humans live in underwater habitats, such as research stations and submarines.\n10. **Space:** A small number of humans have lived in space, including astronauts on the International Space Station and those who have visited the Moon.\n\nOverall, humans can be found living in almost every environment on Earth, from the frozen tundra to the hottest deserts, and from the highest mountains to the deepest oceans.", "stop_reason": "end_of_turn", "tool_calls": [] }, "logprobs": null } ``` Orignal repro no longer showing any error: ``` LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run ~/.llama/distributions/fireworks/fireworks-run.yaml python -m examples.agents.e2e_loop_with_client_tools localhost 8321 ``` client logs: https://gist.github.com/dineshyv/047c7e87b18a5792aa660e311ea53166 server logs: https://gist.github.com/dineshyv/97a2174099619e9916c7c490be26e559 --- llama_stack/apis/inference/inference.py | 8 +- llama_stack/distribution/library_client.py | 8 +- llama_stack/distribution/request_headers.py | 37 +--- llama_stack/distribution/resolver.py | 4 +- llama_stack/distribution/routers/__init__.py | 12 +- llama_stack/distribution/routers/routers.py | 163 +++++++++++++++++- llama_stack/distribution/server/server.py | 8 +- llama_stack/distribution/utils/context.py | 33 ++++ .../distribution/utils/tests/test_context.py | 155 +++++++++++++++++ .../telemetry/meta_reference/telemetry.py | 3 + 10 files changed, 380 insertions(+), 51 deletions(-) create mode 100644 llama_stack/distribution/utils/context.py create mode 100644 llama_stack/distribution/utils/tests/test_context.py diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index d0f5d15c5..fa917ac22 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -285,7 +285,7 @@ class CompletionRequest(BaseModel): @json_schema_type -class CompletionResponse(BaseModel): +class CompletionResponse(MetricResponseMixin): """Response from a completion request. :param content: The generated completion text @@ -299,7 +299,7 @@ class CompletionResponse(BaseModel): @json_schema_type -class CompletionResponseStreamChunk(BaseModel): +class CompletionResponseStreamChunk(MetricResponseMixin): """A chunk of a streamed completion response. :param delta: New content generated since last chunk. This can be one or more tokens. 
@@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel): @json_schema_type -class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): +class ChatCompletionResponseStreamChunk(MetricResponseMixin): """A chunk of a streamed chat completion response. :param event: The event containing the new content @@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): @json_schema_type -class ChatCompletionResponse(MetricResponseMixin, BaseModel): +class ChatCompletionResponse(MetricResponseMixin): """Response from a chat completion request. :param completion_message: The complete response message diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 5dc70bb67..15c4fe6ea 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.datatypes import Api from llama_stack.distribution.request_headers import ( - preserve_headers_context_async_generator, + PROVIDER_DATA_VAR, request_provider_data_context, ) from llama_stack.distribution.resolver import ProviderRegistry @@ -44,8 +44,10 @@ from llama_stack.distribution.stack import ( redact_sensitive_fields, replace_env_vars, ) +from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.distribution.utils.exec import in_notebook from llama_stack.providers.utils.telemetry.tracing import ( + CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace, @@ -384,8 +386,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): finally: await end_trace() - # Wrap the generator to preserve context across iterations - wrapped_gen = preserve_headers_context_async_generator(gen()) + wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]) + mock_response = httpx.Response( status_code=httpx.codes.OK, content=wrapped_gen, diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 19afae59b..8709fc040 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -7,14 +7,14 @@ import contextvars import json import logging -from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar +from typing import Any, ContextManager, Dict, Optional from .utils.dynamic import instantiate_class_type log = logging.getLogger(__name__) # Context variable for request provider data -_provider_data_var = contextvars.ContextVar("provider_data", default=None) +PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) class RequestProviderDataContext(ContextManager): @@ -26,40 +26,13 @@ class RequestProviderDataContext(ContextManager): def __enter__(self): # Save the current value and set the new one - self.token = _provider_data_var.set(self.provider_data) + self.token = PROVIDER_DATA_VAR.set(self.provider_data) return self def __exit__(self, exc_type, exc_val, exc_tb): # Restore the previous value if self.token is not None: - _provider_data_var.reset(self.token) - - -T = TypeVar("T") - - -def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]: - """ - Wraps an async generator to preserve request headers context variables across iterations. 
- - This ensures that context variables set during generator creation are - available during each iteration of the generator, even if the original - context manager has exited. - """ - # Capture the current context value right now - context_value = _provider_data_var.get() - - async def wrapper(): - while True: - # Set context before each anext() call - _ = _provider_data_var.set(context_value) - try: - item = await gen.__anext__() - yield item - except StopAsyncIteration: - break - - return wrapper() + PROVIDER_DATA_VAR.reset(self.token) class NeedsRequestProviderData: @@ -72,7 +45,7 @@ class NeedsRequestProviderData: if not validator_class: raise ValueError(f"Provider {provider_type} does not have a validator") - val = _provider_data_var.get() + val = PROVIDER_DATA_VAR.get() if not val: return None diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index d7ca4414d..ab075f399 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -165,7 +165,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, module="llama_stack.distribution.routers", routing_table_api=info.routing_table_api, api_dependencies=[info.routing_table_api], - deps__=[info.routing_table_api.value], + # Add telemetry as an optional dependency to all auto-routed providers + optional_api_dependencies=[Api.telemetry], + deps__=([info.routing_table_api.value, Api.telemetry.value]), ), ) } diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index a54f57fb3..d0fca8771 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -45,7 +45,7 @@ async def get_routing_table_impl( return impl -async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: +async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any: from .routers import ( DatasetIORouter, EvalRouter, @@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> "eval": EvalRouter, "tool_runtime": ToolRuntimeRouter, } + api_to_deps = { + "inference": {"telemetry": Api.telemetry}, + } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_routers[api.value](routing_table) + api_to_dep_impl = {} + for dep_name, dep_api in api_to_deps.get(api.value, {}).items(): + if dep_api in deps: + api_to_dep_impl[dep_name] = deps[dep_api] + + impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 28df67922..68b8e55cb 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -4,7 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, AsyncGenerator, Dict, List, Optional +import time +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from llama_stack.apis.common.content_types import ( URL, @@ -20,6 +21,10 @@ from llama_stack.apis.eval import ( JobStatus, ) from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionMessage, EmbeddingsResponse, EmbeddingTaskType, Inference, @@ -27,13 +32,14 @@ from llama_stack.apis.inference import ( Message, ResponseFormat, SamplingParams, + StopReason, TextTruncation, ToolChoice, ToolConfig, ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.models import ModelType +from llama_stack.apis.models import Model, ModelType from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.scoring import ( ScoreBatchResponse, @@ -42,6 +48,7 @@ from llama_stack.apis.scoring import ( ScoringFnParams, ) from llama_stack.apis.shields import Shield +from llama_stack.apis.telemetry import MetricEvent, Telemetry from llama_stack.apis.tools import ( RAGDocument, RAGQueryConfig, @@ -52,7 +59,10 @@ from llama_stack.apis.tools import ( ) from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.log import get_logger +from llama_stack.models.llama.llama3.chat_format import ChatFormat +from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import RoutingTable +from llama_stack.providers.utils.telemetry.tracing import get_current_span logger = get_logger(name=__name__, category="core") @@ -119,9 +129,14 @@ class InferenceRouter(Inference): def __init__( self, routing_table: RoutingTable, + telemetry: Optional[Telemetry] = None, ) -> None: logger.debug("Initializing InferenceRouter") self.routing_table = routing_table + self.telemetry = telemetry + if self.telemetry: + self.tokenizer = Tokenizer.get_instance() + self.formatter = ChatFormat(self.tokenizer) async def initialize(self) -> None: logger.debug("InferenceRouter.initialize") @@ -144,6 +159,71 @@ class InferenceRouter(Inference): ) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) + def _construct_metrics( + self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model + ) -> List[MetricEvent]: + """Constructs a list of MetricEvent objects containing token usage metrics. 
+ + Args: + prompt_tokens: Number of tokens in the prompt + completion_tokens: Number of tokens in the completion + total_tokens: Total number of tokens used + model: Model object containing model_id and provider_id + + Returns: + List of MetricEvent objects with token usage metrics + """ + span = get_current_span() + if span is None: + logger.warning("No span found for token usage metrics") + return [] + metrics = [ + ("prompt_tokens", prompt_tokens), + ("completion_tokens", completion_tokens), + ("total_tokens", total_tokens), + ] + metric_events = [] + for metric_name, value in metrics: + metric_events.append( + MetricEvent( + trace_id=span.trace_id, + span_id=span.span_id, + metric=metric_name, + value=value, + timestamp=time.time(), + unit="tokens", + attributes={ + "model_id": model.model_id, + "provider_id": model.provider_id, + }, + ) + ) + return metric_events + + async def _compute_and_log_token_usage( + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + model: Model, + ) -> List[MetricEvent]: + metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) + if self.telemetry: + for metric in metrics: + await self.telemetry.log_event(metric) + return metrics + + async def _count_tokens( + self, + messages: List[Message] | InterleavedContent, + tool_prompt_format: Optional[ToolPromptFormat] = None, + ) -> Optional[int]: + if isinstance(messages, list): + encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format) + else: + encoded = self.formatter.encode_content(messages) + return len(encoded.tokens) if encoded and encoded.tokens else 0 + async def chat_completion( self, model_id: str, @@ -156,8 +236,9 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> AsyncGenerator: + ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: logger.debug( + "core", f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) if sampling_params is None: @@ -206,10 +287,47 @@ class InferenceRouter(Inference): tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) + prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) + if stream: - return (chunk async for chunk in await provider.chat_completion(**params)) + + async def stream_generator(): + completion_text = "" + async for chunk in await provider.chat_completion(**params): + if chunk.event.event_type == ChatCompletionResponseEventType.progress: + if chunk.event.delta.type == "text": + completion_text += chunk.event.delta.text + if chunk.event.event_type == ChatCompletionResponseEventType.complete: + completion_tokens = await self._count_tokens( + [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)], + tool_config.tool_prompt_format, + ) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + yield chunk + + return stream_generator() else: - return await provider.chat_completion(**params) + response = await provider.chat_completion(**params) + completion_tokens = await self._count_tokens( + [response.completion_message], + tool_config.tool_prompt_format, + ) + total_tokens = 
(prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics + return response async def completion( self, @@ -239,10 +357,41 @@ class InferenceRouter(Inference): stream=stream, logprobs=logprobs, ) + + prompt_tokens = await self._count_tokens(content) + if stream: - return (chunk async for chunk in await provider.completion(**params)) + + async def stream_generator(): + completion_text = "" + async for chunk in await provider.completion(**params): + if hasattr(chunk, "delta"): + completion_text += chunk.delta + if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + completion_tokens = await self._count_tokens(completion_text) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + yield chunk + + return stream_generator() else: - return await provider.completion(**params) + response = await provider.completion(**params) + completion_tokens = await self._count_tokens(response.content) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics + return response async def embeddings( self, diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 2cc70a738..7ca009b13 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -28,7 +28,7 @@ from typing_extensions import Annotated from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import ( - preserve_headers_context_async_generator, + PROVIDER_DATA_VAR, request_provider_data_context, ) from llama_stack.distribution.resolver import InvalidProviderError @@ -38,6 +38,7 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) +from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig @@ -45,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( TelemetryAdapter, ) from llama_stack.providers.utils.telemetry.tracing import ( + CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace, @@ -182,7 +184,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): try: if is_streaming: - gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs))) + gen = preserve_contexts_async_generator( + sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR] + ) return StreamingResponse(gen, media_type="text/event-stream") else: value = func(**kwargs) diff --git a/llama_stack/distribution/utils/context.py b/llama_stack/distribution/utils/context.py new file mode 100644 index 000000000..107ce7127 --- /dev/null +++ b/llama_stack/distribution/utils/context.py @@ -0,0 
+1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from contextvars import ContextVar +from typing import AsyncGenerator, List, TypeVar + +T = TypeVar("T") + + +def preserve_contexts_async_generator( + gen: AsyncGenerator[T, None], context_vars: List[ContextVar] +) -> AsyncGenerator[T, None]: + """ + Wraps an async generator to preserve context variables across iterations. + This is needed because we start a new asyncio event loop for each streaming request, + and we need to preserve the context across the event loop boundary. + """ + + async def wrapper(): + while True: + try: + item = await gen.__anext__() + context_values = {context_var.name: context_var.get() for context_var in context_vars} + yield item + for context_var in context_vars: + _ = context_var.set(context_values[context_var.name]) + except StopAsyncIteration: + break + + return wrapper() diff --git a/llama_stack/distribution/utils/tests/test_context.py b/llama_stack/distribution/utils/tests/test_context.py new file mode 100644 index 000000000..84944bfe8 --- /dev/null +++ b/llama_stack/distribution/utils/tests/test_context.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from contextvars import ContextVar + +import pytest + +from llama_stack.distribution.utils.context import preserve_contexts_async_generator + + +@pytest.mark.asyncio +async def test_preserve_contexts_with_exception(): + # Create context variable + context_var = ContextVar("exception_var", default="initial") + token = context_var.set("start_value") + + # Create an async generator that raises an exception + async def exception_generator(): + yield context_var.get() + context_var.set("modified") + raise ValueError("Test exception") + yield None # This will never be reached + + # Wrap the generator + wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var]) + + # First iteration should work + value = await wrapped_gen.__anext__() + assert value == "start_value" + + # Second iteration should raise the exception + with pytest.raises(ValueError, match="Test exception"): + await wrapped_gen.__anext__() + + # Clean up + context_var.reset(token) + + +@pytest.mark.asyncio +async def test_preserve_contexts_empty_generator(): + # Create context variable + context_var = ContextVar("empty_var", default="initial") + token = context_var.set("value") + + # Create an empty async generator + async def empty_generator(): + if False: # This condition ensures the generator yields nothing + yield None + + # Wrap the generator + wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var]) + + # The generator should raise StopAsyncIteration immediately + with pytest.raises(StopAsyncIteration): + await wrapped_gen.__anext__() + + # Context variable should remain unchanged + assert context_var.get() == "value" + + # Clean up + context_var.reset(token) + + +@pytest.mark.asyncio +async def test_preserve_contexts_across_event_loops(): + """ + Test that context variables are preserved across event loop boundaries with nested generators. + This simulates the real-world scenario where: + 1. 
A new event loop is created for each streaming request + 2. The async generator runs inside that loop + 3. There are multiple levels of nested generators + 4. Context needs to be preserved across these boundaries + """ + # Create context variables + request_id = ContextVar("request_id", default=None) + user_id = ContextVar("user_id", default=None) + + # Set initial values + + # Results container to verify values across thread boundaries + results = [] + + # Inner-most generator (level 2) + async def inner_generator(): + # Should have the context from the outer scope + yield (1, request_id.get(), user_id.get()) + + # Modify one context variable + user_id.set("user-modified") + + # Should reflect the modification + yield (2, request_id.get(), user_id.get()) + + # Middle generator (level 1) + async def middle_generator(): + inner_gen = inner_generator() + + # Forward the first yield from inner + item = await inner_gen.__anext__() + yield item + + # Forward the second yield from inner + item = await inner_gen.__anext__() + yield item + + request_id.set("req-modified") + + # Add our own yield with both modified variables + yield (3, request_id.get(), user_id.get()) + + # Function to run in a separate thread with a new event loop + def run_in_new_loop(): + # Create a new event loop for this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # Outer generator (runs in the new loop) + async def outer_generator(): + request_id.set("req-12345") + user_id.set("user-6789") + # Wrap the middle generator + wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id]) + + # Process all items from the middle generator + async for item in wrapped_gen: + # Store results for verification + results.append(item) + + # Run the outer generator in the new loop + loop.run_until_complete(outer_generator()) + finally: + loop.close() + + # Run the generator chain in a separate thread with a new event loop + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_new_loop) + future.result() # Wait for completion + + # Verify the results + assert len(results) == 3 + + # First yield should have original values + assert results[0] == (1, "req-12345", "user-6789") + + # Second yield should have modified user_id + assert results[1] == (2, "req-12345", "user-modified") + + # Third yield should have both modified values + assert results[2] == (3, "req-modified", "user-modified") diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index e713a057f..4cdb420b2 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None: self.config = config self.datasetio_api = deps.get(Api.datasetio) + self.meter = None resource = Resource.create( { @@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): return _GLOBAL_STORAGE["gauges"][name] def _log_metric(self, event: MetricEvent) -> None: + if self.meter is None: + return if isinstance(event.value, int): counter = self._get_or_create_counter(event.metric, event.unit) counter.add(event.value, attributes=event.attributes) From b7a9c454779979eaa89196901fc7856a5f10b900 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 12 Mar 2025 12:10:21 -0700 
Subject: [PATCH 03/11] chore: deprecate ToolResponseMessage in agent.resume API (#1566) # Summary: closes #1431 # Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct --- docs/_static/llama-stack-spec.html | 20 +++++-------------- docs/_static/llama-stack-spec.yaml | 13 ++++-------- llama_stack/apis/agents/agents.py | 5 ++--- .../agents/meta_reference/agent_instance.py | 18 +++++------------ .../inline/agents/meta_reference/agents.py | 2 +- 5 files changed, 17 insertions(+), 41 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index b0febbbef..709360ede 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9490,21 +9490,11 @@ "type": "object", "properties": { "tool_responses": { - "oneOf": [ - { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolResponse" - } - }, - { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolResponseMessage" - } - } - ], - "description": "The tool call responses to resume the turn with. NOTE: ToolResponseMessage will be deprecated. Use ToolResponse." + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolResponse" + }, + "description": "The tool call responses to resume the turn with." }, "stream": { "type": "boolean", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 2985e6222..4c00fbe63 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6405,16 +6405,11 @@ components: type: object properties: tool_responses: - oneOf: - - type: array - items: - $ref: '#/components/schemas/ToolResponse' - - type: array - items: - $ref: '#/components/schemas/ToolResponseMessage' + type: array + items: + $ref: '#/components/schemas/ToolResponse' description: >- - The tool call responses to resume the turn with. NOTE: ToolResponseMessage - will be deprecated. Use ToolResponse. + The tool call responses to resume the turn with. stream: type: boolean description: Whether to stream the response. diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 1170a56d5..5cc910a55 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -370,7 +370,7 @@ class AgentTurnResumeRequest(BaseModel): agent_id: str session_id: str turn_id: str - tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]] + tool_responses: List[ToolResponse] stream: Optional[bool] = False @@ -449,7 +449,7 @@ class Agents(Protocol): agent_id: str, session_id: str, turn_id: str, - tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]], + tool_responses: List[ToolResponse], stream: Optional[bool] = False, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: """Resume an agent turn with executed tool call responses. @@ -460,7 +460,6 @@ class Agents(Protocol): :param session_id: The ID of the session to resume. :param turn_id: The ID of the turn to resume. :param tool_responses: The tool call responses to resume the turn with. - NOTE: ToolResponseMessage will be deprecated. Use ToolResponse. :param stream: Whether to stream the response. :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects. 
""" diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index fedd695c1..1d9f54e96 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -218,18 +218,10 @@ class ChatAgent(ShieldRunnerMixin): steps = [] messages = await self.get_messages_from_turns(turns) if is_resume: - if isinstance(request.tool_responses[0], ToolResponseMessage): - tool_response_messages = request.tool_responses - tool_responses = [ - ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content) - for x in request.tool_responses - ] - else: - tool_response_messages = [ - ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content) - for x in request.tool_responses - ] - tool_responses = request.tool_responses + tool_response_messages = [ + ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content) + for x in request.tool_responses + ] messages.extend(tool_response_messages) last_turn = turns[-1] last_turn_messages = self.turn_to_messages(last_turn) @@ -252,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin): step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), turn_id=request.turn_id, tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), - tool_responses=tool_responses, + tool_responses=request.tool_responses, completed_at=now, started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index c24b14e35..5ca123595 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -172,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents): agent_id: str, session_id: str, turn_id: str, - tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]], + tool_responses: List[ToolResponse], stream: Optional[bool] = False, ) -> AsyncGenerator: request = AgentTurnResumeRequest( From 0fdb15bcc74f236cd4e6ac4291a361a08b6bf1b3 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 12 Mar 2025 13:26:23 -0700 Subject: [PATCH 04/11] fix: fix build error in context.py (#1595) # What does this PR do? 
This fixes the build error ## Test Plan pre-commit run --all-files check for merge conflicts................................................Passed trim trailing whitespace.................................................Passed check for added large files..............................................Passed fix end of files.........................................................Passed Insert license in comments...............................................Passed ruff.....................................................................Passed ruff-format..............................................................Passed blacken-docs.............................................................Passed uv-lock..................................................................Passed uv-export................................................................Passed mypy.....................................................................Passed Distribution Template Codegen............................................Passed --- llama_stack/distribution/utils/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/distribution/utils/context.py b/llama_stack/distribution/utils/context.py index 107ce7127..2f32afba2 100644 --- a/llama_stack/distribution/utils/context.py +++ b/llama_stack/distribution/utils/context.py @@ -19,7 +19,7 @@ def preserve_contexts_async_generator( and we need to preserve the context across the event loop boundary. """ - async def wrapper(): + async def wrapper() -> AsyncGenerator[T, None]: while True: try: item = await gen.__anext__() From 1311faf3f5e7e18111b642be6bbdd941c2034e02 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 12 Mar 2025 14:57:31 -0700 Subject: [PATCH 05/11] fix: logging (#1598) Summary: Test Plan: --- llama_stack/distribution/routers/routers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 68b8e55cb..34102d04b 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -238,7 +238,6 @@ class InferenceRouter(Inference): tool_config: Optional[ToolConfig] = None, ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: logger.debug( - "core", f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) if sampling_params is None: From ad939c97c37f8e33d0e94fe43893773cffe7618e Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Wed, 12 Mar 2025 18:41:35 -0400 Subject: [PATCH 06/11] docs: add unit test badge to README (#1591) # What does this PR do? 
This PR adds a simple unit test badge to the project README It also modifies the workflow to run on merges to main, so that the status reflected in the README is that of main and not pull request branches --------- Signed-off-by: Nathan Weinberg --- .github/workflows/unit-tests.yml | 2 ++ README.md | 1 + 2 files changed, 3 insertions(+) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 39505ba11..59d18b3be 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,6 +1,8 @@ name: Unit Tests on: + push: + branches: [ main ] pull_request: branches: [ main ] workflow_dispatch: diff --git a/README.md b/README.md index b24e69514..6e1fd088e 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/) [![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE) [![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack) +![Unit](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main) [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) From 99bbe0e70b125f93da659ca722a9d5c2f6ef7022 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 12 Mar 2025 15:45:44 -0700 Subject: [PATCH 07/11] feat: Add new compact MetricInResponse type (#1593) # What does this PR do? This change adds a compact type to include metrics in response as opposed to the full MetricEvent which is relevant for internal logging purposes. ## Test Plan ``` LLAMA_STACK_CONFIG=~/.llama/distributions/fireworks/fireworks-run.yaml pytest -s -v agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct llama stack run ~/.llama/distributions/fireworks/fireworks-run.yaml curl --request POST \ --url http://localhost:8321/v1/inference/chat-completion \ --header 'content-type: application/json' \ --data '{ "model_id": "meta-llama/Llama-3.1-70B-Instruct", "messages": [ { "role": "user", "content": { "type": "text", "text": "where do humans live" } } ], "stream": false }' { "metrics": [ { "metric": "prompt_tokens", "value": 10, "unit": null }, { "metric": "completion_tokens", "value": 522, "unit": null }, { "metric": "total_tokens", "value": 532, "unit": null } ], "completion_message": { "role": "assistant", "content": "Humans live in various parts of the world...............", "stop_reason": "out_of_tokens", "tool_calls": [] }, "logprobs": null } ``` --- docs/_static/llama-stack-spec.html | 133 +++++++++++++------- docs/_static/llama-stack-spec.yaml | 82 +++++++----- llama_stack/apis/telemetry/telemetry.py | 9 +- llama_stack/distribution/routers/routers.py | 6 +- 4 files changed, 150 insertions(+), 80 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 709360ede..dbd530aa3 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4549,7 +4549,7 @@ "metrics": { "type": "array", "items": { - "$ref": "#/components/schemas/MetricEvent" + "$ref": "#/components/schemas/MetricInResponse" } }, "completion_message": { @@ -4571,46 +4571,9 @@ "title": "ChatCompletionResponse", "description": "Response from a chat completion request." 
}, - "MetricEvent": { + "MetricInResponse": { "type": "object", "properties": { - "trace_id": { - "type": "string" - }, - "span_id": { - "type": "string" - }, - "timestamp": { - "type": "string", - "format": "date-time" - }, - "attributes": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" - } - ] - } - }, - "type": { - "type": "string", - "const": "metric", - "default": "metric" - }, "metric": { "type": "string" }, @@ -4630,15 +4593,10 @@ }, "additionalProperties": false, "required": [ - "trace_id", - "span_id", - "timestamp", - "type", "metric", - "value", - "unit" + "value" ], - "title": "MetricEvent" + "title": "MetricInResponse" }, "TokenLogProbs": { "type": "object", @@ -4715,6 +4673,12 @@ "CompletionResponse": { "type": "object", "properties": { + "metrics": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MetricInResponse" + } + }, "content": { "type": "string", "description": "The generated completion text" @@ -4924,7 +4888,7 @@ "metrics": { "type": "array", "items": { - "$ref": "#/components/schemas/MetricEvent" + "$ref": "#/components/schemas/MetricInResponse" } }, "event": { @@ -5082,6 +5046,12 @@ "CompletionResponseStreamChunk": { "type": "object", "properties": { + "metrics": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MetricInResponse" + } + }, "delta": { "type": "string", "description": "New content generated since last chunk. This can be one or more tokens." @@ -8363,6 +8333,75 @@ ], "title": "LogSeverity" }, + "MetricEvent": { + "type": "object", + "properties": { + "trace_id": { + "type": "string" + }, + "span_id": { + "type": "string" + }, + "timestamp": { + "type": "string", + "format": "date-time" + }, + "attributes": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" + } + ] + } + }, + "type": { + "type": "string", + "const": "metric", + "default": "metric" + }, + "metric": { + "type": "string" + }, + "value": { + "oneOf": [ + { + "type": "integer" + }, + { + "type": "number" + } + ] + }, + "unit": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "trace_id", + "span_id", + "timestamp", + "type", + "metric", + "value", + "unit" + ], + "title": "MetricEvent" + }, "SpanEndPayload": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 4c00fbe63..cca1872a4 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3101,7 +3101,7 @@ components: metrics: type: array items: - $ref: '#/components/schemas/MetricEvent' + $ref: '#/components/schemas/MetricInResponse' completion_message: $ref: '#/components/schemas/CompletionMessage' description: The complete response message @@ -3116,29 +3116,9 @@ components: - completion_message title: ChatCompletionResponse description: Response from a chat completion request. 
- MetricEvent: + MetricInResponse: type: object properties: - trace_id: - type: string - span_id: - type: string - timestamp: - type: string - format: date-time - attributes: - type: object - additionalProperties: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' - type: - type: string - const: metric - default: metric metric: type: string value: @@ -3149,14 +3129,9 @@ components: type: string additionalProperties: false required: - - trace_id - - span_id - - timestamp - - type - metric - value - - unit - title: MetricEvent + title: MetricInResponse TokenLogProbs: type: object properties: @@ -3213,6 +3188,10 @@ components: CompletionResponse: type: object properties: + metrics: + type: array + items: + $ref: '#/components/schemas/MetricInResponse' content: type: string description: The generated completion text @@ -3412,7 +3391,7 @@ components: metrics: type: array items: - $ref: '#/components/schemas/MetricEvent' + $ref: '#/components/schemas/MetricInResponse' event: $ref: '#/components/schemas/ChatCompletionResponseEvent' description: The event containing the new content @@ -3531,6 +3510,10 @@ components: CompletionResponseStreamChunk: type: object properties: + metrics: + type: array + items: + $ref: '#/components/schemas/MetricInResponse' delta: type: string description: >- @@ -5703,6 +5686,47 @@ components: - error - critical title: LogSeverity + MetricEvent: + type: object + properties: + trace_id: + type: string + span_id: + type: string + timestamp: + type: string + format: date-time + attributes: + type: object + additionalProperties: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + type: + type: string + const: metric + default: metric + metric: + type: string + value: + oneOf: + - type: integer + - type: number + unit: + type: string + additionalProperties: false + required: + - trace_id + - span_id + - timestamp + - type + - metric + - value + - unit + title: MetricEvent SpanEndPayload: type: object properties: diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index fe75677e7..cbea57e79 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -96,6 +96,13 @@ class MetricEvent(EventCommon): unit: str +@json_schema_type +class MetricInResponse(BaseModel): + metric: str + value: Union[int, float] + unit: Optional[str] = None + + # This is a short term solution to allow inference API to return metrics # The ideal way to do this is to have a way for all response types to include metrics # and all metric events logged to the telemetry API to be inlcuded with the response @@ -117,7 +124,7 @@ class MetricEvent(EventCommon): class MetricResponseMixin(BaseModel): - metrics: Optional[List[MetricEvent]] = None + metrics: Optional[List[MetricInResponse]] = None @json_schema_type diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 34102d04b..22a1e46f9 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -48,7 +48,7 @@ from llama_stack.apis.scoring import ( ScoringFnParams, ) from llama_stack.apis.shields import Shield -from llama_stack.apis.telemetry import MetricEvent, Telemetry +from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.apis.tools import ( RAGDocument, RAGQueryConfig, @@ -206,12 +206,12 @@ class InferenceRouter(Inference): completion_tokens: 
int, total_tokens: int, model: Model, - ) -> List[MetricEvent]: + ) -> List[MetricInResponse]: metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) if self.telemetry: for metric in metrics: await self.telemetry.log_event(metric) - return metrics + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] async def _count_tokens( self, From 18de4cd08ae8a2eaed8136934cceb709b2e8d95d Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 12 Mar 2025 18:38:07 -0700 Subject: [PATCH 08/11] comments --- llama_stack/apis/datasets/datasets.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 4b3ce3e6f..b18dd204b 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -13,10 +13,10 @@ from llama_stack.apis.resource import Resource, ResourceType from llama_stack.schema_utils import json_schema_type, register_schema, webmethod -class Schema(Enum): +class DatasetPurpose(Enum): """ - Schema of the dataset. Each type has a different column format. - :cvar messages: The dataset contains messages used for post-training. Examples: + Purpose of the dataset. Each type has a different column format. + :cvar tuning/messages: The dataset contains messages used for post-training. Examples: { "messages": [ {"role": "user", "content": "Hello, world!"}, @@ -25,7 +25,8 @@ class Schema(Enum): } """ - messages = "messages" + tuning_messages = "tuning/messages" + # TODO: add more schemas here @@ -99,8 +100,8 @@ class Datasets(Protocol): @webmethod(route="/datasets", method="POST") async def register_dataset( self, - schema: Schema, - data_source: DataSource, + purpose: DatasetPurpose, + source: DataSource, metadata: Optional[Dict[str, Any]] = None, dataset_id: Optional[str] = None, ) -> Dataset: From a3173e8284501d7bd9c275733472deb571d4c30c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 12 Mar 2025 18:46:40 -0700 Subject: [PATCH 09/11] update --- docs/_static/llama-stack-spec.html | 49 ++++++++++++++++----------- docs/_static/llama-stack-spec.yaml | 45 ++++++++++++++++-------- llama_stack/apis/datasets/datasets.py | 42 +++++++++++++++++++---- 3 files changed, 95 insertions(+), 41 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 821e5ed53..856c6e715 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -6846,13 +6846,14 @@ "const": "dataset", "default": "dataset" }, - "schema": { + "purpose": { "type": "string", "enum": [ - "messages" + "post-training/messages", + "eval/question-answer" ], - "title": "Schema", - "description": "Schema of the dataset. Each type has a different column format." + "title": "DatasetPurpose", + "description": "Purpose of the dataset. Each type has a different column format." }, "data_source": { "$ref": "#/components/schemas/DataSource" @@ -6889,7 +6890,7 @@ "provider_resource_id", "provider_id", "type", - "schema", + "purpose", "data_source", "metadata" ], @@ -6903,8 +6904,9 @@ "const": "huggingface", "default": "huggingface" }, - "dataset_path": { - "type": "string" + "path": { + "type": "string", + "description": "The path to the dataset in Huggingface. E.g. - \"llamastack/simpleqa\"" }, "params": { "type": "object", @@ -6929,16 +6931,18 @@ "type": "object" } ] - } + }, + "description": "The parameters for the dataset." 
} }, "additionalProperties": false, "required": [ "type", - "dataset_path", + "path", "params" ], - "title": "HuggingfaceDataSource" + "title": "HuggingfaceDataSource", + "description": "A dataset stored in Huggingface." }, "RowsDataSource": { "type": "object", @@ -6974,7 +6978,8 @@ } ] } - } + }, + "description": "The dataset is stored in rows. E.g. - [ {\"messages\": [{\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}]} ]" } }, "additionalProperties": false, @@ -6982,7 +6987,8 @@ "type", "rows" ], - "title": "RowsDataSource" + "title": "RowsDataSource", + "description": "A dataset stored in rows." }, "URIDataSource": { "type": "object", @@ -6993,7 +6999,8 @@ "default": "uri" }, "uri": { - "type": "string" + "type": "string", + "description": "The dataset can be obtained from a URI. E.g. - \"https://mywebsite.com/mydata.jsonl\" - \"lsfs://mydata.jsonl\" - \"data:csv;base64,{base64_content}\"" } }, "additionalProperties": false, @@ -7001,7 +7008,8 @@ "type", "uri" ], - "title": "URIDataSource" + "title": "URIDataSource", + "description": "A dataset that can be obtained from a URI." }, "Model": { "type": "object", @@ -9419,14 +9427,15 @@ "RegisterDatasetRequest": { "type": "object", "properties": { - "schema": { + "purpose": { "type": "string", "enum": [ - "messages" + "post-training/messages", + "eval/question-answer" ], - "description": "The schema format of the dataset. One of - messages: The dataset contains a messages column with list of messages for post-training." + "description": "The purpose of the dataset. One of - \"post-training/messages\": The dataset contains a messages column with list of messages for post-training. - \"eval/question-answer\": The dataset contains a question and answer column." }, - "data_source": { + "source": { "$ref": "#/components/schemas/DataSource", "description": "The data source of the dataset. Examples: - { \"type\": \"uri\", \"uri\": \"https://mywebsite.com/mydata.jsonl\" } - { \"type\": \"uri\", \"uri\": \"lsfs://mydata.jsonl\" } - { \"type\": \"huggingface\", \"dataset_path\": \"tatsu-lab/alpaca\", \"params\": { \"split\": \"train\" } } - { \"type\": \"rows\", \"rows\": [ { \"messages\": [ {\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}, ] } ] }" }, @@ -9463,8 +9472,8 @@ }, "additionalProperties": false, "required": [ - "schema", - "data_source" + "purpose", + "source" ], "title": "RegisterDatasetRequest" }, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 21625827a..93ba4ba30 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -4738,13 +4738,14 @@ components: type: string const: dataset default: dataset - schema: + purpose: type: string enum: - - messages - title: Schema + - post-training/messages + - eval/question-answer + title: DatasetPurpose description: >- - Schema of the dataset. Each type has a different column format. + Purpose of the dataset. Each type has a different column format. data_source: $ref: '#/components/schemas/DataSource' metadata: @@ -4763,7 +4764,7 @@ components: - provider_resource_id - provider_id - type - - schema + - purpose - data_source - metadata title: Dataset @@ -4774,8 +4775,10 @@ components: type: string const: huggingface default: huggingface - dataset_path: + path: type: string + description: >- + The path to the dataset in Huggingface. E.g. 
- "llamastack/simpleqa" params: type: object additionalProperties: @@ -4786,12 +4789,14 @@ components: - type: string - type: array - type: object + description: The parameters for the dataset. additionalProperties: false required: - type - - dataset_path + - path - params title: HuggingfaceDataSource + description: A dataset stored in Huggingface. RowsDataSource: type: object properties: @@ -4811,11 +4816,16 @@ components: - type: string - type: array - type: object + description: >- + The dataset is stored in rows. E.g. - [ {"messages": [{"role": "user", + "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, + world!"}]} ] additionalProperties: false required: - type - rows title: RowsDataSource + description: A dataset stored in rows. URIDataSource: type: object properties: @@ -4825,11 +4835,16 @@ components: default: uri uri: type: string + description: >- + The dataset can be obtained from a URI. E.g. - "https://mywebsite.com/mydata.jsonl" + - "lsfs://mydata.jsonl" - "data:csv;base64,{base64_content}" additionalProperties: false required: - type - uri title: URIDataSource + description: >- + A dataset that can be obtained from a URI. Model: type: object properties: @@ -6367,14 +6382,16 @@ components: RegisterDatasetRequest: type: object properties: - schema: + purpose: type: string enum: - - messages + - post-training/messages + - eval/question-answer description: >- - The schema format of the dataset. One of - messages: The dataset contains - a messages column with list of messages for post-training. - data_source: + The purpose of the dataset. One of - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. - + "eval/question-answer": The dataset contains a question and answer column. + source: $ref: '#/components/schemas/DataSource' description: >- The data source of the dataset. Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl" @@ -6401,8 +6418,8 @@ components: The ID of the dataset. If not provided, a random ID will be generated. additionalProperties: false required: - - schema - - data_source + - purpose + - source title: RegisterDatasetRequest RegisterModelRequest: type: object diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index b18dd204b..26ad85422 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho class DatasetPurpose(Enum): """ Purpose of the dataset. Each type has a different column format. - :cvar tuning/messages: The dataset contains messages used for post-training. Examples: + :cvar post-training/messages: The dataset contains messages used for post-training. Examples: { "messages": [ {"role": "user", "content": "Hello, world!"}, @@ -25,12 +25,19 @@ class DatasetPurpose(Enum): } """ - tuning_messages = "tuning/messages" + post_training_messages = "post-training/messages" + eval_question_answer = "eval/question-answer" # TODO: add more schemas here class DatasetType(Enum): + """ + Type of the dataset source. + :cvar huggingface: The dataset is stored in Huggingface. + :cvar uri: The dataset can be obtained from a URI. + :cvar rows: The dataset is stored in rows. + """ huggingface = "huggingface" uri = "uri" rows = "rows" @@ -38,19 +45,36 @@ class DatasetType(Enum): @json_schema_type class URIDataSource(BaseModel): + """A dataset that can be obtained from a URI. 
+ :param uri: The dataset can be obtained from a URI. E.g. + - "https://mywebsite.com/mydata.jsonl" + - "lsfs://mydata.jsonl" + - "data:csv;base64,{base64_content}" + """ type: Literal["uri"] = "uri" uri: str @json_schema_type class HuggingfaceDataSource(BaseModel): + """A dataset stored in Huggingface. + :param path: The path to the dataset in Huggingface. E.g. + - "llamastack/simpleqa" + :param params: The parameters for the dataset. + """ type: Literal["huggingface"] = "huggingface" - dataset_path: str + path: str params: Dict[str, Any] @json_schema_type class RowsDataSource(BaseModel): + """A dataset stored in rows. + :param rows: The dataset is stored in rows. E.g. + - [ + {"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]} + ] + """ type: Literal["rows"] = "rows" rows: List[Dict[str, Any]] @@ -65,7 +89,10 @@ DataSource = register_schema( class CommonDatasetFields(BaseModel): - schema: Schema + """ + Common fields for a dataset. + """ + purpose: DatasetPurpose data_source: DataSource metadata: Dict[str, Any] = Field( default_factory=dict, @@ -108,9 +135,10 @@ class Datasets(Protocol): """ Register a new dataset. - :param schema: The schema format of the dataset. One of - - messages: The dataset contains a messages column with list of messages for post-training. - :param data_source: The data source of the dataset. Examples: + :param purpose: The purpose of the dataset. One of + - "post-training/messages": The dataset contains a messages column with list of messages for post-training. + - "eval/question-answer": The dataset contains a question and answer column. + :param source: The data source of the dataset. Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl" From 790b2d5cc0e7f90ac38f86c9521fba157dbae6a9 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 12 Mar 2025 18:51:46 -0700 Subject: [PATCH 10/11] source --- llama_stack/apis/datasets/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 26ad85422..36f75d7b3 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -93,7 +93,7 @@ class CommonDatasetFields(BaseModel): Common fields for a dataset. """ purpose: DatasetPurpose - data_source: DataSource + source: DataSource metadata: Dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this dataset", From 09039eca5740af1f8d1e9ff04cd50a7ceb7f94af Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 12 Mar 2025 18:52:05 -0700 Subject: [PATCH 11/11] source --- docs/_static/llama-stack-spec.html | 4 ++-- docs/_static/llama-stack-spec.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 856c6e715..47f48df3c 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -6855,7 +6855,7 @@ "title": "DatasetPurpose", "description": "Purpose of the dataset. Each type has a different column format." 
}, - "data_source": { + "source": { "$ref": "#/components/schemas/DataSource" }, "metadata": { @@ -6891,7 +6891,7 @@ "provider_id", "type", "purpose", - "data_source", + "source", "metadata" ], "title": "Dataset" diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 93ba4ba30..16ef6fed4 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -4746,7 +4746,7 @@ components: title: DatasetPurpose description: >- Purpose of the dataset. Each type has a different column format. - data_source: + source: $ref: '#/components/schemas/DataSource' metadata: type: object @@ -4765,7 +4765,7 @@ components: - provider_id - type - purpose - - data_source + - source - metadata title: Dataset HuggingfaceDataSource:
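
For reviewers, here is a minimal usage sketch of the `Datasets.register_dataset()` surface as it stands after patches 08-11: the `purpose`/`source` argument names, the `DatasetPurpose` enum values, and the `path`/`params` fields on `HuggingfaceDataSource` all mirror the hunks above. The import path follows the file touched in these patches; the `impl` parameter stands in for whichever concrete `Datasets` provider a distribution wires up, and the helper function and dataset IDs are made up for illustration only — none of this is part of the patches themselves.

```python
from typing import Any, Dict, List

from llama_stack.apis.datasets.datasets import (
    DatasetPurpose,
    Datasets,
    HuggingfaceDataSource,
    RowsDataSource,
    URIDataSource,
)


async def register_examples(impl: Datasets) -> None:
    # 1) A post-training dataset fetched directly from a URI.
    await impl.register_dataset(
        purpose=DatasetPurpose.post_training_messages,
        source=URIDataSource(uri="https://mywebsite.com/mydata.jsonl"),
    )

    # 2) An eval dataset hosted on Huggingface; `path` replaces the old
    #    `dataset_path` field and `params` carries loader options such as split.
    await impl.register_dataset(
        purpose=DatasetPurpose.eval_question_answer,
        source=HuggingfaceDataSource(
            path="llamastack/simpleqa",
            params={"split": "train"},
        ),
        dataset_id="simpleqa",  # optional; a random ID is generated if omitted
    )

    # 3) Inline rows, convenient for small smoke tests.
    rows: List[Dict[str, Any]] = [
        {
            "messages": [
                {"role": "user", "content": "Hello, world!"},
                {"role": "assistant", "content": "Hello, world!"},
            ]
        }
    ]
    await impl.register_dataset(
        purpose=DatasetPurpose.post_training_messages,
        source=RowsDataSource(rows=rows),
    )
```

Each call would be awaited inside whatever async context the caller already has (e.g. `asyncio.run(register_examples(impl))` against a library client); the `type` discriminator on each data source is supplied by its Pydantic default, so callers only pass the payload fields shown here.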