diff --git a/src/llama_stack/core/server/tracing.py b/src/llama_stack/core/server/tracing.py index 7a6aec436..210b74de3 100644 --- a/src/llama_stack/core/server/tracing.py +++ b/src/llama_stack/core/server/tracing.py @@ -11,9 +11,17 @@ from llama_stack.core.server.routes import find_matching_route, initialize_route from llama_stack.core.telemetry.tracing import end_trace, start_trace from llama_stack.log import get_logger from llama_stack_api.datatypes import Api +from llama_stack_api.version import ( + LLAMA_STACK_API_V1, + LLAMA_STACK_API_V1ALPHA, + LLAMA_STACK_API_V1BETA, +) logger = get_logger(name=__name__, category="core::server") +# Valid API version levels - all routes must start with one of these +VALID_API_LEVELS = {LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA, LLAMA_STACK_API_V1BETA} + class TracingMiddleware: def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): @@ -30,9 +38,9 @@ class TracingMiddleware: We need to check if the path matches any router-based API prefix. """ # Extract API name from path (e.g., /v1/batches -> batches) - # Paths are typically /v1/{api_name} or /v1/{api_name}/... + # Paths must start with a valid API level: /v1/{api_name} or /v1alpha/{api_name} or /v1beta/{api_name} parts = path.strip("/").split("/") - if len(parts) >= 2 and parts[0].startswith("v"): + if len(parts) >= 2 and parts[0] in VALID_API_LEVELS: api_name = parts[1] try: api = Api(api_name)