diff --git a/litellm/litellm_core_utils/dd_tracing.py b/litellm/litellm_core_utils/dd_tracing.py
new file mode 100644
index 0000000000..4df3026144
--- /dev/null
+++ b/litellm/litellm_core_utils/dd_tracing.py
@@ -0,0 +1,53 @@
+"""
+Handles Tracing on DataDog Traces.
+
+If the ddtrace package is not installed, the tracer will be a no-op.
+"""
+
+from contextlib import contextmanager
+
+try:
+    from ddtrace import tracer as dd_tracer
+
+    has_ddtrace = True
+except ImportError:
+    has_ddtrace = False
+
+    @contextmanager
+    def null_tracer(name, **kwargs):
+        class NullSpan:
+            def __enter__(self):
+                return self
+
+            def __exit__(self, *args):
+                pass
+
+            def finish(self):
+                pass
+
+        yield NullSpan()
+
+    class NullTracer:
+        def trace(self, name, **kwargs):
+            class NullSpan:
+                def __enter__(self):
+                    return self
+
+                def __exit__(self, *args):
+                    pass
+
+                def finish(self):
+                    pass
+
+            return NullSpan()
+
+        def wrap(self, name=None, **kwargs):
+            def decorator(f):
+                return f
+
+            return decorator
+
+    dd_tracer = NullTracer()
+
+# Export the tracer instance
+tracer = dd_tracer
diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py
index 7b04b2c02a..c46a5f8a0e 100644
--- a/litellm/llms/bedrock/base_aws_llm.py
+++ b/litellm/llms/bedrock/base_aws_llm.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel
 
 from litellm._logging import verbose_logger
 from litellm.caching.caching import DualCache
+from litellm.litellm_core_utils.dd_tracing import tracer
 from litellm.secret_managers.main import get_secret, get_secret_str
 
 if TYPE_CHECKING:
@@ -63,6 +64,7 @@ class BaseAWSLLM:
         credential_str = json.dumps(credential_args, sort_keys=True)
         return hashlib.sha256(credential_str.encode()).hexdigest()
 
+    @tracer.wrap()
     def get_credentials(
         self,
         aws_access_key_id: Optional[str] = None,
@@ -200,6 +202,7 @@ class BaseAWSLLM:
         self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
         return credentials
 
+    @tracer.wrap()
     def _auth_with_web_identity_token(
         self,
         aws_web_identity_token: str,
@@ -230,11 +233,12 @@ class BaseAWSLLM:
                 status_code=401,
             )
 
-        sts_client = boto3.client(
-            "sts",
-            region_name=aws_region_name,
-            endpoint_url=sts_endpoint,
-        )
+        with tracer.trace("boto3.client(sts)"):
+            sts_client = boto3.client(
+                "sts",
+                region_name=aws_region_name,
+                endpoint_url=sts_endpoint,
+            )
 
         # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
         # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
@@ -258,11 +262,13 @@ class BaseAWSLLM:
                 f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
             )
 
-        session = boto3.Session(**iam_creds_dict)
+        with tracer.trace("boto3.Session(**iam_creds_dict)"):
+            session = boto3.Session(**iam_creds_dict)
         iam_creds = session.get_credentials()
 
         return iam_creds, self._get_default_ttl_for_boto3_credentials()
 
+    @tracer.wrap()
     def _auth_with_aws_role(
         self,
         aws_access_key_id: Optional[str],
@@ -276,11 +282,12 @@ class BaseAWSLLM:
         import boto3
         from botocore.credentials import Credentials
 
-        sts_client = boto3.client(
-            "sts",
-            aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
-            aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
-        )
+        with tracer.trace("boto3.client(sts)"):
+            sts_client = boto3.client(
+                "sts",
+                aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
+                aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
+            )
 
         sts_response = sts_client.assume_role(
             RoleArn=aws_role_name, RoleSessionName=aws_session_name
@@ -288,7 +295,6 @@ class BaseAWSLLM:
 
         # Extract the credentials from the response and convert to Session Credentials
         sts_credentials = sts_response["Credentials"]
-
         credentials = Credentials(
             access_key=sts_credentials["AccessKeyId"],
             secret_key=sts_credentials["SecretAccessKey"],
@@ -301,6 +307,7 @@ class BaseAWSLLM:
         sts_ttl = (sts_expiry - current_time).total_seconds() - 60
         return credentials, sts_ttl
 
+    @tracer.wrap()
     def _auth_with_aws_profile(
         self, aws_profile_name: str
     ) -> Tuple[Credentials, Optional[int]]:
@@ -310,9 +317,11 @@ class BaseAWSLLM:
         import boto3
 
         # uses auth values from AWS profile usually stored in ~/.aws/credentials
-        client = boto3.Session(profile_name=aws_profile_name)
-        return client.get_credentials(), None
+        with tracer.trace("boto3.Session(profile_name=aws_profile_name)"):
+            client = boto3.Session(profile_name=aws_profile_name)
+            return client.get_credentials(), None
 
+    @tracer.wrap()
     def _auth_with_aws_session_token(
         self,
         aws_access_key_id: str,
@@ -333,6 +342,7 @@ class BaseAWSLLM:
 
         return credentials, None
 
+    @tracer.wrap()
     def _auth_with_access_key_and_secret_key(
         self,
         aws_access_key_id: str,
@@ -345,26 +355,31 @@ class BaseAWSLLM:
         import boto3
 
         # Check if credentials are already in cache. These credentials have no expiry time.
-
-        session = boto3.Session(
-            aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key,
-            region_name=aws_region_name,
-        )
+        with tracer.trace(
+            "boto3.Session(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region_name)"
+        ):
+            session = boto3.Session(
+                aws_access_key_id=aws_access_key_id,
+                aws_secret_access_key=aws_secret_access_key,
+                region_name=aws_region_name,
+            )
 
         credentials = session.get_credentials()
         return credentials, self._get_default_ttl_for_boto3_credentials()
 
+    @tracer.wrap()
     def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
         """
         Authenticate with AWS Environment Variables
        """
         import boto3
 
-        session = boto3.Session()
-        credentials = session.get_credentials()
-        return credentials, None
+        with tracer.trace("boto3.Session()"):
+            session = boto3.Session()
+            credentials = session.get_credentials()
+            return credentials, None
 
+    @tracer.wrap()
     def _get_default_ttl_for_boto3_credentials(self) -> int:
         """
         Get the default TTL for boto3 credentials
@@ -475,6 +490,7 @@ class BaseAWSLLM:
             aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
         )
 
+    @tracer.wrap()
     def get_request_headers(
         self,
         credentials: Credentials,
diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py
index 57cccad7e0..b70c15b3e1 100644
--- a/litellm/llms/bedrock/chat/converse_handler.py
+++ b/litellm/llms/bedrock/chat/converse_handler.py
@@ -1,6 +1,6 @@
 import json
 import urllib
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import httpx
 
@@ -60,7 +60,6 @@ def make_sync_call(
         api_key="",
         data=data,
         messages=messages,
-        print_verbose=litellm.print_verbose,
         encoding=litellm.encoding,
     )  # type: ignore
     completion_stream: Any = MockResponseIterator(
@@ -102,7 +101,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         messages: list,
         api_base: str,
         model_response: ModelResponse,
-        print_verbose: Callable,
         timeout: Optional[Union[float, httpx.Timeout]],
         encoding,
         logging_obj,
@@ -170,7 +168,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         messages: list,
         api_base: str,
         model_response: ModelResponse,
-        print_verbose: Callable,
         timeout: Optional[Union[float, httpx.Timeout]],
         encoding,
         logging_obj: LiteLLMLoggingObject,
@@ -247,7 +244,6 @@ class BedrockConverseLLM(BaseAWSLLM):
             api_key="",
             data=data,
             messages=messages,
-            print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
         )
@@ -259,7 +255,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         api_base: Optional[str],
         custom_prompt_dict: dict,
         model_response: ModelResponse,
-        print_verbose: Callable,
         encoding,
         logging_obj: LiteLLMLoggingObject,
         optional_params: dict,
@@ -271,11 +266,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
     ):
 
-        try:
-            from botocore.credentials import Credentials
-        except ImportError:
-            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
-
         ## SETUP ##
         stream = optional_params.pop("stream", None)
         modelId = optional_params.pop("model_id", None)
@@ -367,7 +357,6 @@ class BedrockConverseLLM(BaseAWSLLM):
                 messages=messages,
                 api_base=proxy_endpoint_url,
                 model_response=model_response,
-                print_verbose=print_verbose,
                 encoding=encoding,
                 logging_obj=logging_obj,
                 optional_params=optional_params,
@@ -387,7 +376,6 @@ class BedrockConverseLLM(BaseAWSLLM):
                 messages=messages,
                 api_base=proxy_endpoint_url,
                 model_response=model_response,
-                print_verbose=print_verbose,
                 encoding=encoding,
                 logging_obj=logging_obj,
                 optional_params=optional_params,
@@ -489,7 +477,6 @@ class BedrockConverseLLM(BaseAWSLLM):
             api_key="",
             data=data,
             messages=messages,
-            print_verbose=print_verbose,
             optional_params=optional_params,
             encoding=encoding,
         )
diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py
index ae79bcb0af..68ae3af478 100644
--- a/litellm/llms/bedrock/chat/converse_transformation.py
+++ b/litellm/llms/bedrock/chat/converse_transformation.py
@@ -5,7 +5,7 @@ Translating between OpenAI's `/chat/completion` format and Amazon's `/converse`
 import copy
 import time
 import types
-from typing import Callable, List, Literal, Optional, Tuple, Union, cast, overload
+from typing import List, Literal, Optional, Tuple, Union, cast, overload
 
 import httpx
 
@@ -542,7 +542,6 @@ class AmazonConverseConfig(BaseConfig):
             api_key=api_key,
             data=request_data,
             messages=messages,
-            print_verbose=None,
             encoding=encoding,
         )
 
@@ -557,7 +556,6 @@ class AmazonConverseConfig(BaseConfig):
         api_key: Optional[str],
         data: Union[dict, str],
         messages: List,
-        print_verbose: Optional[Callable],
         encoding,
     ) -> ModelResponse:
         ## LOGGING
diff --git a/litellm/main.py b/litellm/main.py
index 152c61ac37..cc74080245 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2638,7 +2638,6 @@ def completion(  # type: ignore # noqa: PLR0915
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,
                 model_response=model_response,
-                print_verbose=print_verbose,
                 optional_params=optional_params,
                 litellm_params=litellm_params,  # type: ignore
                 logger_fn=logger_fn,
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 8b763c043c..5b5cb038e0 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -20,6 +20,7 @@ import litellm
 from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm._service_logger import ServiceLogging
 from litellm.caching import DualCache
+from litellm.litellm_core_utils.dd_tracing import tracer
 from litellm.proxy._types import *
 from litellm.proxy.auth.auth_checks import (
     _cache_key_object,
@@ -897,7 +898,10 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
         # Check 3. If token is expired
         if valid_token.expires is not None:
             current_time = datetime.now(timezone.utc)
-            expiry_time = datetime.fromisoformat(valid_token.expires)
+            if isinstance(valid_token.expires, datetime):
+                expiry_time = valid_token.expires
+            else:
+                expiry_time = datetime.fromisoformat(valid_token.expires)
             if (
                 expiry_time.tzinfo is None
                 or expiry_time.tzinfo.utcoffset(expiry_time) is None
@@ -1127,6 +1131,7 @@
     )
 
 
+@tracer.wrap()
 async def user_api_key_auth(
     request: Request,
     api_key: str = fastapi.Security(api_key_header),
diff --git a/litellm/router.py b/litellm/router.py
index bd7d101c7a..84946c6b43 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -48,6 +48,7 @@ from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.asyncify import run_async_function
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+from litellm.litellm_core_utils.dd_tracing import tracer
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
@@ -2857,6 +2858,7 @@ class Router:
 
     #### [END] ASSISTANTS API ####
 
+    @tracer.wrap()
     async def async_function_with_fallbacks(self, *args, **kwargs):  # noqa: PLR0915
         """
         Try calling the function_with_retries
@@ -3127,6 +3129,7 @@
             Context_Policy_Fallbacks={content_policy_fallbacks}",
         )
 
+    @tracer.wrap()
     async def async_function_with_retries(self, *args, **kwargs):  # noqa: PLR0915
         verbose_router_logger.debug("Inside async function with retries.")
         original_function = kwargs.pop("original_function")
diff --git a/tests/litellm/litellm_core_utils/test_dd_tracing.py b/tests/litellm/litellm_core_utils/test_dd_tracing.py
new file mode 100644
index 0000000000..0492295071
--- /dev/null
+++ b/tests/litellm/litellm_core_utils/test_dd_tracing.py
@@ -0,0 +1,53 @@
+import json
+import os
+import sys
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.litellm_core_utils.dd_tracing import dd_tracer
+
+def test_dd_tracer_when_package_exists():
+    with patch("litellm.litellm_core_utils.dd_tracing.has_ddtrace", True):
+        # Test the trace context manager
+        with dd_tracer.trace("test_operation") as span:
+            assert span is not None
+
+        # Test the wrapper decorator
+        @dd_tracer.wrap(name="test_function")
+        def sample_function():
+            return "test"
+
+        result = sample_function()
+        assert result == "test"
+
+
+def test_dd_tracer_when_package_not_exists():
+    with patch("litellm.litellm_core_utils.dd_tracing.has_ddtrace", False):
+        # Test the trace context manager with null tracer
+        with dd_tracer.trace("test_operation") as span:
+            assert span is not None
+            # Verify null span methods don't raise exceptions
+            span.finish()
+
+        # Test the wrapper decorator with null tracer
+        @dd_tracer.wrap(name="test_function")
+        def sample_function():
+            return "test"
+
+        result = sample_function()
+        assert result == "test"
+
+
+def test_null_tracer_context_manager():
+    with patch("litellm.litellm_core_utils.dd_tracing.has_ddtrace", False):
+        # Test that the context manager works without raising exceptions
+        with dd_tracer.trace("test_operation") as span:
+            # Test that we can call methods on the null span
+            span.finish()
+
+        assert True  # If we get here without exceptions, the test passes
diff --git a/tests/litellm/llms/bedrock/test_base_aws_llm.py b/tests/litellm/llms/bedrock/test_base_aws_llm.py
new file mode 100644
index 0000000000..3a2f691c1d
--- /dev/null
+++ b/tests/litellm/llms/bedrock/test_base_aws_llm.py
@@ -0,0 +1,100 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+
+from datetime import datetime, timezone
+from typing import Any, Dict
+from unittest.mock import MagicMock, patch
+
+from botocore.credentials import Credentials
+
+import litellm
+from litellm.llms.bedrock.base_aws_llm import (
+    AwsAuthError,
+    BaseAWSLLM,
+    Boto3CredentialsInfo,
+)
+
+# Global variable for the base_aws_llm.py file path
+
+BASE_AWS_LLM_PATH = os.path.join(
+    os.path.dirname(__file__), "../../../../litellm/llms/bedrock/base_aws_llm.py"
+)
+
+
+def test_boto3_init_tracer_wrapping():
+    """
+    Test that all boto3 initializations are wrapped in tracer.trace or @tracer.wrap
+
+    Ensures observability of boto3 calls in litellm.
+    """
+    # Get the source code of base_aws_llm.py
+    with open(BASE_AWS_LLM_PATH, "r") as f:
+        content = f.read()
+
+    # List all boto3 initialization patterns we want to check
+    boto3_init_patterns = ["boto3.client", "boto3.Session"]
+
+    lines = content.split("\n")
+    # Check each boto3 initialization is wrapped in tracer.trace
+    for line_number, line in enumerate(lines, 1):
+        for pattern in boto3_init_patterns:
+            if pattern in line:
+                # Look back up to 5 lines for decorator or trace block
+                start_line = max(0, line_number - 5)
+                context_lines = lines[start_line:line_number]
+
+                has_trace = (
+                    "tracer.trace" in line
+                    or any("tracer.trace" in prev_line for prev_line in context_lines)
+                    or any("@tracer.wrap" in prev_line for prev_line in context_lines)
+                )
+
+                if not has_trace:
+                    print(f"\nContext for line {line_number}:")
+                    for i, ctx_line in enumerate(context_lines, start=start_line + 1):
+                        print(f"{i}: {ctx_line}")
+
+                assert (
+                    has_trace
+                ), f"boto3 initialization '{pattern}' on line {line_number} is not wrapped with tracer.trace or @tracer.wrap"
+
+
+def test_auth_functions_tracer_wrapping():
+    """
+    Test that all _auth functions in base_aws_llm.py are wrapped with @tracer.wrap
+
+    Ensures observability of AWS authentication calls in litellm.
+    """
+    # Get the source code of base_aws_llm.py
+    with open(BASE_AWS_LLM_PATH, "r") as f:
+        content = f.read()
+
+    lines = content.split("\n")
+    # Check each line for _auth function definitions
+    for line_number, line in enumerate(lines, 1):
+        if line.strip().startswith("def _auth_"):
+            # Look back up to 2 lines for the @tracer.wrap decorator
+            start_line = max(0, line_number - 2)
+            context_lines = lines[start_line:line_number]
+
+            has_tracer_wrap = any(
+                "@tracer.wrap" in prev_line for prev_line in context_lines
+            )
+
+            if not has_tracer_wrap:
+                print(f"\nContext for line {line_number}:")
+                for i, ctx_line in enumerate(context_lines, start=start_line + 1):
+                    print(f"{i}: {ctx_line}")
+
+            assert (
+                has_tracer_wrap
+            ), f"Auth function on line {line_number} is not wrapped with @tracer.wrap: {line.strip()}"
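Reviewer note: below is a minimal usage sketch of the new tracer shim (it assumes this patch is applied and `litellm` is importable; `fetch_credentials_demo` and the span name are illustrative, not part of the diff). It exercises the same two entry points the new tests cover: the `@tracer.wrap()` decorator and the `tracer.trace()` context manager, both of which fall back to no-ops when `ddtrace` is not installed.

```python
# Usage sketch only -- not part of the patch. Assumes litellm with this change
# is importable; the function name and span name below are illustrative.
from litellm.litellm_core_utils.dd_tracing import tracer


@tracer.wrap()  # real ddtrace decorator when available, otherwise a pass-through
def fetch_credentials_demo() -> str:
    # Explicit span around an "expensive" section, mirroring the boto3 wrapping
    # added in base_aws_llm.py; without ddtrace this yields a NullSpan.
    with tracer.trace("demo.expensive_section"):
        return "credentials-placeholder"


if __name__ == "__main__":
    print(fetch_credentials_demo())  # runs with or without ddtrace installed
```

Because the module exports a single `tracer` object, call sites such as `BaseAWSLLM` and `user_api_key_auth` can be instrumented unconditionally without adding a hard dependency on `ddtrace`.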