import asyncio import os import sys from unittest.mock import Mock import pytest from fastapi import Request import litellm sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request from litellm.types.utils import SupportedCacheControls @pytest.fixture def mock_request(monkeypatch): mock_request = Mock(spec=Request) mock_request.query_params = {} # Set mock query_params to an empty dictionary mock_request.headers = {"traceparent": "test_traceparent"} monkeypatch.setattr( "litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request", mock_request ) return mock_request @pytest.mark.parametrize("endpoint", ["/v1/threads", "/v1/thread/123"]) @pytest.mark.asyncio async def test_add_litellm_data_to_request_thread_endpoint(endpoint, mock_request): mock_request.url.path = endpoint user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" ) proxy_config = Mock() data = {} await add_litellm_data_to_request( data, mock_request, user_api_key_dict, proxy_config ) print("DATA: ", data) assert "litellm_metadata" in data assert "metadata" not in data @pytest.mark.parametrize( "endpoint", ["/chat/completions", "/v1/completions", "/completions"] ) @pytest.mark.asyncio async def test_add_litellm_data_to_request_non_thread_endpoint(endpoint, mock_request): mock_request.url.path = endpoint user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" ) proxy_config = Mock() data = {} await add_litellm_data_to_request( data, mock_request, user_api_key_dict, proxy_config ) print("DATA: ", data) assert "metadata" in data assert "litellm_metadata" not in data # test adding traceparent @pytest.mark.parametrize( "endpoint", ["/chat/completions", "/v1/completions", "/completions"] ) @pytest.mark.asyncio async def test_traceparent_not_added_by_default(endpoint, mock_request): """ This tests that traceparent is not forwarded in the extra_headers We had an incident where bedrock calls were failing because traceparent was forwarded """ from litellm.integrations.opentelemetry import OpenTelemetry otel_logger = OpenTelemetry() setattr(litellm.proxy.proxy_server, "open_telemetry_logger", otel_logger) mock_request.url.path = endpoint user_api_key_dict = UserAPIKeyAuth( api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" ) proxy_config = Mock() data = {} await add_litellm_data_to_request( data, mock_request, user_api_key_dict, proxy_config ) print("DATA: ", data) _extra_headers = data.get("extra_headers") or {} assert "traceparent" not in _extra_headers setattr(litellm.proxy.proxy_server, "open_telemetry_logger", None)