From d7b294dd0a7dba3f172ca5d930997bf3d55639a3 Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Sat, 29 Mar 2025 11:02:13 -0700
Subject: [PATCH] build(pyproject.toml): add new dev dependencies - for type checking (#9631)

* build(pyproject.toml): add new dev dependencies - for type checking
* build: reformat files to fit black
* ci: reformat to fit black
* ci(test-litellm.yml): make tests run clear
* build(pyproject.toml): add ruff
* fix: fix ruff checks
* build(mypy/): fix mypy linting errors
* fix(hashicorp_secret_manager.py): fix passing cert for tls auth
* build(mypy/): resolve all mypy errors
* test: update test
* fix: fix black formatting
* build(pre-commit-config.yaml): use poetry run black
* fix(proxy_server.py): fix linting error
* fix: fix ruff safe representation error
---
 .github/workflows/test-linting.yml | 53 +++
 .github/workflows/test-litellm.yml | 2 +-
 .pre-commit-config.yaml | 10 +-
 .../enterprise_hooks/secret_detection.py | 4 -
 litellm/__init__.py | 98 +++---
 litellm/_service_logger.py | 2 +-
 litellm/batches/main.py | 2 -
 litellm/caching/base_cache.py | 4 +-
 litellm/caching/caching_handler.py | 11 +-
 litellm/caching/disk_cache.py | 4 +-
 litellm/caching/dual_cache.py | 4 +-
 litellm/caching/llm_caching_handler.py | 1 -
 litellm/caching/redis_cache.py | 3 +-
 litellm/caching/redis_cluster_cache.py | 4 +-
 litellm/caching/redis_semantic_cache.py | 127 +++----
 litellm/caching/s3_cache.py | 2 +-
 litellm/cost_calculator.py | 1 -
 litellm/fine_tuning/main.py | 3 -
 .../SlackAlerting/batching_handler.py | 1 -
 .../SlackAlerting/slack_alerting.py | 28 +-
 litellm/integrations/_types/open_inference.py | 2 +-
 litellm/integrations/argilla.py | 1 -
 litellm/integrations/arize/_utils.py | 4 +-
 litellm/integrations/arize/arize.py | 3 +-
 litellm/integrations/arize/arize_phoenix.py | 27 +-
 litellm/integrations/athina.py | 5 +-
 .../azure_storage/azure_storage.py | 13 +-
 litellm/integrations/braintrust_logging.py | 103 ++++--
 litellm/integrations/custom_batch_logger.py | 1 -
 litellm/integrations/custom_guardrail.py | 1 -
 litellm/integrations/custom_logger.py | 2 +-
 litellm/integrations/datadog/datadog.py | 1 -
 .../gcs_bucket/gcs_bucket_base.py | 6 +-
 litellm/integrations/gcs_pubsub/pub_sub.py | 13 +-
 litellm/integrations/humanloop.py | 6 +-
 litellm/integrations/langfuse/langfuse.py | 6 +-
 .../integrations/langfuse/langfuse_handler.py | 5 +-
 .../langfuse/langfuse_prompt_management.py | 6 +-
 litellm/integrations/langsmith.py | 8 +-
 litellm/integrations/langtrace.py | 4 +-
 litellm/integrations/lunary.py | 2 -
 litellm/integrations/mlflow.py | 9 +-
 litellm/integrations/opentelemetry.py | 28 +-
 litellm/integrations/opik/opik.py | 1 -
 litellm/integrations/prometheus.py | 6 +-
 .../integrations/prompt_management_base.py | 7 +-
 litellm/integrations/s3.py | 2 +-
 litellm/integrations/weights_biases.py | 8 +-
 litellm/litellm_core_utils/core_helpers.py | 2 +-
 .../litellm_core_utils/default_encoding.py | 4 +-
 litellm/litellm_core_utils/litellm_logging.py | 317 +++++++++---------
 .../convert_dict_to_response.py | 8 +-
 .../litellm_core_utils/model_param_helper.py | 1 -
 .../prompt_templates/common_utils.py | 1 -
 .../prompt_templates/factory.py | 26 +-
 .../litellm_core_utils/realtime_streaming.py | 1 -
 litellm/litellm_core_utils/redact_messages.py | 6 +-
 .../sensitive_data_masker.py | 6 +-
 .../streaming_chunk_builder_utils.py | 1 -
 .../litellm_core_utils/streaming_handler.py | 20 +-
 litellm/llms/anthropic/chat/handler.py | 16 +-
 litellm/llms/anthropic/chat/transformation.py | 46 +--
 .../anthropic/completion/transformation.py | 6 +-
 .../messages/handler.py | 28 +-
 litellm/llms/azure/azure.py | 6 -
 litellm/llms/azure/batches/handler.py | 72 ++--
 litellm/llms/azure/common_utils.py | 1 -
 litellm/llms/azure/files/handler.py | 91 +++--
 litellm/llms/azure/fine_tuning/handler.py | 9 +-
 litellm/llms/azure_ai/chat/transformation.py | 1 -
 .../azure_ai/embed/cohere_transformation.py | 1 -
 litellm/llms/azure_ai/embed/handler.py | 2 -
 .../llms/azure_ai/rerank/transformation.py | 1 +
 litellm/llms/base.py | 1 -
 litellm/llms/base_llm/chat/transformation.py | 1 -
 .../llms/base_llm/responses/transformation.py | 1 -
 litellm/llms/bedrock/chat/converse_handler.py | 10 +-
 .../bedrock/chat/converse_transformation.py | 20 +-
 litellm/llms/bedrock/chat/invoke_handler.py | 14 +-
 .../base_invoke_transformation.py | 7 +-
 litellm/llms/bedrock/common_utils.py | 1 -
 .../amazon_titan_multimodal_transformation.py | 7 +-
 .../amazon_nova_canvas_transformation.py | 93 +++--
 litellm/llms/bedrock/image/image_handler.py | 13 +-
 litellm/llms/bedrock/rerank/handler.py | 1 -
 litellm/llms/bedrock/rerank/transformation.py | 1 -
 litellm/llms/codestral/completion/handler.py | 3 -
 .../codestral/completion/transformation.py | 1 -
 litellm/llms/cohere/chat/transformation.py | 2 -
 litellm/llms/cohere/embed/handler.py | 1 -
 litellm/llms/cohere/embed/transformation.py | 2 -
 litellm/llms/cohere/rerank/transformation.py | 2 +-
 .../llms/cohere/rerank_v2/transformation.py | 3 +-
 litellm/llms/custom_httpx/aiohttp_handler.py | 2 -
 litellm/llms/custom_httpx/http_handler.py | 4 -
 litellm/llms/custom_httpx/llm_http_handler.py | 5 -
 litellm/llms/databricks/common_utils.py | 6 +-
 .../llms/databricks/embed/transformation.py | 6 +-
 litellm/llms/databricks/streaming_utils.py | 1 -
 .../audio_transcription/transformation.py | 6 +-
 litellm/llms/deepseek/chat/transformation.py | 1 -
 .../llms/deprecated_providers/aleph_alpha.py | 6 +-
 .../llms/fireworks_ai/chat/transformation.py | 3 -
 litellm/llms/gemini/chat/transformation.py | 1 -
 litellm/llms/gemini/common_utils.py | 1 -
 litellm/llms/groq/chat/transformation.py | 1 -
 litellm/llms/groq/stt/transformation.py | 1 -
 .../llms/huggingface/chat/transformation.py | 24 +-
 litellm/llms/maritalk.py | 1 -
 .../llms/ollama/completion/transformation.py | 6 +-
 .../llms/openai/chat/gpt_transformation.py | 1 -
 litellm/llms/openai/completion/handler.py | 1 -
 .../llms/openai/completion/transformation.py | 6 +-
 litellm/llms/openai/fine_tuning/handler.py | 9 +-
 litellm/llms/openai/openai.py | 25 +-
 .../transcriptions/whisper_transformation.py | 6 +-
 .../llms/openrouter/chat/transformation.py | 8 +-
 .../llms/petals/completion/transformation.py | 6 +-
 litellm/llms/predibase/chat/handler.py | 1 -
 litellm/llms/predibase/chat/transformation.py | 12 +-
 litellm/llms/replicate/chat/handler.py | 1 -
 litellm/llms/sagemaker/chat/handler.py | 2 -
 litellm/llms/sagemaker/common_utils.py | 2 -
 litellm/llms/sagemaker/completion/handler.py | 12 +-
 .../sagemaker/completion/transformation.py | 6 +-
 .../llms/together_ai/rerank/transformation.py | 1 -
 .../topaz/image_variations/transformation.py | 2 -
 .../llms/triton/completion/transformation.py | 1 -
 litellm/llms/vertex_ai/batches/handler.py | 7 +-
 litellm/llms/vertex_ai/common_utils.py | 4 +-
 litellm/llms/vertex_ai/files/handler.py | 9 +-
 litellm/llms/vertex_ai/fine_tuning/handler.py | 15 +-
 .../llms/vertex_ai/gemini/transformation.py | 26 +-
 .../vertex_and_google_ai_studio_gemini.py | 18 +-
 .../batch_embed_content_handler.py | 1 -
 .../batch_embed_content_transformation.py | 1 -
 .../embedding_handler.py | 1 -
 .../multimodal_embeddings/transformation.py | 1 -
 .../llms/vertex_ai/vertex_ai_non_gemini.py | 2 -
 .../vertex_ai_partner_models/main.py | 1 -
 .../vertex_embeddings/embedding_handler.py | 12 +-
 .../vertex_embeddings/transformation.py | 1 -
 .../vertex_ai/vertex_model_garden/main.py | 1 -
 litellm/llms/watsonx/chat/transformation.py | 1 -
 litellm/main.py | 56 ++--
 litellm/proxy/_types.py | 68 ++--
 litellm/proxy/auth/auth_checks.py | 7 +-
 .../proxy/auth/auth_checks_organization.py | 14 +-
 litellm/proxy/auth/auth_exception_handler.py | 5 +-
 litellm/proxy/auth/auth_utils.py | 2 -
 litellm/proxy/auth/handle_jwt.py | 23 +-
 litellm/proxy/auth/litellm_license.py | 1 -
 litellm/proxy/auth/model_checks.py | 1 -
 litellm/proxy/auth/route_checks.py | 3 -
 litellm/proxy/auth/user_api_key_auth.py | 28 +-
 .../common_utils/encrypt_decrypt_utils.py | 2 -
 .../proxy/common_utils/http_parsing_utils.py | 8 +-
 litellm/proxy/db/log_db_metrics.py | 1 -
 litellm/proxy/db/redis_update_buffer.py | 22 +-
 litellm/proxy/guardrails/guardrail_helpers.py | 1 -
 .../guardrail_hooks/bedrock_guardrails.py | 1 -
 .../guardrails/guardrail_hooks/lakera_ai.py | 6 +-
 .../guardrails/guardrail_hooks/presidio.py | 10 +-
 litellm/proxy/hooks/dynamic_rate_limiter.py | 58 ++--
 .../proxy/hooks/key_management_event_hooks.py | 1 -
 .../proxy/hooks/parallel_request_limiter.py | 34 +-
 .../proxy/hooks/prompt_injection_detection.py | 1 -
 .../proxy/hooks/proxy_track_cost_callback.py | 8 +-
 litellm/proxy/litellm_pre_call_utils.py | 16 +-
 .../budget_management_endpoints.py | 1 -
 .../management_endpoints/common_utils.py | 1 -
 .../customer_endpoints.py | 1 -
 .../internal_user_endpoints.py | 41 ++-
 .../key_management_endpoints.py | 41 ++-
 .../model_management_endpoints.py | 3 -
 .../organization_endpoints.py | 43 +--
 .../management_endpoints/team_endpoints.py | 28 +-
 litellm/proxy/management_endpoints/ui_sso.py | 12 +-
 litellm/proxy/management_helpers/utils.py | 3 +-
 .../openai_files_endpoints/files_endpoints.py | 1 -
 .../anthropic_passthrough_logging_handler.py | 7 +-
 .../vertex_passthrough_logging_handler.py | 5 +-
 .../pass_through_endpoints.py | 3 -
 .../passthrough_endpoint_router.py | 6 +-
 .../streaming_handler.py | 1 -
 litellm/proxy/prisma_migration.py | 30 +-
 litellm/proxy/proxy_server.py | 36 +-
 .../spend_management_endpoints.py | 6 -
 litellm/proxy/utils.py | 9 +-
 .../vertex_ai_endpoints/langfuse_endpoints.py | 8 +-
 litellm/rerank_api/main.py | 18 +-
 litellm/responses/main.py | 27 +-
 litellm/router.py | 94 +++---
 .../router_strategy/base_routing_strategy.py | 6 +-
 litellm/router_strategy/budget_limiter.py | 41 +--
 litellm/router_strategy/lowest_latency.py | 2 +-
 litellm/router_strategy/lowest_tpm_rpm_v2.py | 3 +-
 litellm/router_utils/cooldown_cache.py | 4 +-
 litellm/router_utils/cooldown_callbacks.py | 6 +-
 litellm/router_utils/cooldown_handlers.py | 2 +-
 litellm/router_utils/handle_error.py | 4 +-
 .../router_utils/pattern_match_deployments.py | 12 +-
 litellm/router_utils/prompt_caching_cache.py | 4 +-
 .../secret_managers/aws_secret_manager_v2.py | 1 -
 .../hashicorp_secret_manager.py | 6 +-
 litellm/types/integrations/arize_phoenix.py | 4 +-
 litellm/types/llms/openai.py | 36 +-
 litellm/types/rerank.py | 7 +-
 litellm/types/router.py | 29 +-
 litellm/types/utils.py | 3 -
 litellm/utils.py | 53 ++-
 poetry.lock | 139 +++++++-
 pyproject.toml | 7 +-
 tests/litellm_utils_tests/test_hashicorp.py | 36 +-
 214 files changed, 1553 insertions(+), 1433 deletions(-)
 create mode 100644 .github/workflows/test-linting.yml
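The linting workflow added below runs Black, Ruff, and MyPy through Poetry from inside the litellm/ package. A rough local equivalent of those steps, as a hypothetical helper script that is not part of this commit (it assumes Poetry and the dev dependency group from pyproject.toml are installed):

    # Hypothetical helper, not included in this commit: mirrors the commands the
    # .github/workflows/test-linting.yml job runs so CI failures can be reproduced locally.
    import subprocess

    LINT_COMMANDS = [
        ["poetry", "run", "black", ".", "--check"],
        ["poetry", "run", "ruff", "check", "."],
        ["poetry", "run", "mypy", ".", "--ignore-missing-imports"],
    ]


    def run_lint_checks(package_dir: str = "litellm") -> None:
        for command in LINT_COMMANDS:
            # check=True stops at the first failing tool, matching the
            # one-step-per-tool layout of the workflow.
            subprocess.run(command, cwd=package_dir, check=True)


    if __name__ == "__main__":
        run_lint_checks()
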
diff --git a/.github/workflows/test-linting.yml b/.github/workflows/test-linting.yml
new file mode 100644
index 0000000000..d117d12de4
--- /dev/null
+++ b/.github/workflows/test-linting.yml
@@ -0,0 +1,53 @@
+name: LiteLLM Linting
+
+on:
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    timeout-minutes: 5
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.12'
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+
+      - name: Install dependencies
+        run: |
+          poetry install --with dev
+
+      - name: Run Black formatting check
+        run: |
+          cd litellm
+          poetry run black . --check
+          cd ..
+
+      - name: Run Ruff linting
+        run: |
+          cd litellm
+          poetry run ruff check .
+          cd ..
+
+      - name: Run MyPy type checking
+        run: |
+          cd litellm
+          poetry run mypy . --ignore-missing-imports
+          cd ..
+
+      - name: Check for circular imports
+        run: |
+          cd litellm
+          poetry run python ../tests/documentation_tests/test_circular_imports.py
+          cd ..
+
+      - name: Check import safety
+        run: |
+          poetry run python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
\ No newline at end of file
diff --git a/.github/workflows/test-litellm.yml b/.github/workflows/test-litellm.yml
index 09de0ff110..12d09725ed 100644
--- a/.github/workflows/test-litellm.yml
+++ b/.github/workflows/test-litellm.yml
@@ -1,4 +1,4 @@
-name: LiteLLM Tests
+name: LiteLLM Mock Tests (folder - tests/litellm)
 
 on:
   pull_request:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fb37f32524..bceedb41aa 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,10 +14,12 @@ repos:
       types: [python]
       files: litellm/.*\.py
       exclude: ^litellm/__init__.py$
-- repo: https://github.com/psf/black
-  rev: 24.2.0
-  hooks:
-  - id: black
+  - id: black
+    name: black
+    entry: poetry run black
+    language: system
+    types: [python]
+    files: litellm/.*\.py
 - repo: https://github.com/pycqa/flake8
   rev: 7.0.0  # The version of flake8 to use
   hooks:
diff --git a/enterprise/enterprise_hooks/secret_detection.py b/enterprise/enterprise_hooks/secret_detection.py
index 459fd374d1..158f26efa3 100644
--- a/enterprise/enterprise_hooks/secret_detection.py
+++ b/enterprise/enterprise_hooks/secret_detection.py
@@ -444,9 +444,7 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
         detected_secrets = []
         for file in secrets.files:
-
             for found_secret in secrets[file]:
-
                 if found_secret.secret_value is None:
                     continue
                 detected_secrets.append(
@@ -471,14 +469,12 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
         data: dict,
         call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
     ):
-
         if await self.should_run_check(user_api_key_dict) is False:
             return
 
         if "messages" in data and isinstance(data["messages"], list):
             for message in data["messages"]:
                 if "content" in message and isinstance(message["content"], str):
-
                     detected_secrets = self.scan_message_for_secrets(message["content"])
 
                     for secret in detected_secrets:
diff --git a/litellm/__init__.py b/litellm/__init__.py
index a4903f828c..c2e366e2b1 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -122,19 +122,19 @@ langsmith_batch_size: Optional[int] = None
 prometheus_initialize_budget_metrics: Optional[bool] = False
 argilla_batch_size: Optional[int] = None
 datadog_use_v1: Optional[bool] = False  # if you want to use v1 datadog logged payload
-gcs_pub_sub_use_v1: Optional[bool] = (
-    False  # if you want to use v1 gcs pubsub logged payload
-)
+gcs_pub_sub_use_v1: Optional[ + bool +] = False # if you want to use v1 gcs pubsub logged payload argilla_transformation_object: Optional[Dict[str, Any]] = None -_async_input_callback: List[Union[str, Callable, CustomLogger]] = ( - [] -) # internal variable - async custom callbacks are routed here. -_async_success_callback: List[Union[str, Callable, CustomLogger]] = ( - [] -) # internal variable - async custom callbacks are routed here. -_async_failure_callback: List[Union[str, Callable, CustomLogger]] = ( - [] -) # internal variable - async custom callbacks are routed here. +_async_input_callback: List[ + Union[str, Callable, CustomLogger] +] = [] # internal variable - async custom callbacks are routed here. +_async_success_callback: List[ + Union[str, Callable, CustomLogger] +] = [] # internal variable - async custom callbacks are routed here. +_async_failure_callback: List[ + Union[str, Callable, CustomLogger] +] = [] # internal variable - async custom callbacks are routed here. pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] turn_off_message_logging: Optional[bool] = False @@ -142,18 +142,18 @@ log_raw_request_response: bool = False redact_messages_in_exceptions: Optional[bool] = False redact_user_api_key_info: Optional[bool] = False filter_invalid_headers: Optional[bool] = False -add_user_information_to_llm_headers: Optional[bool] = ( - None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers -) +add_user_information_to_llm_headers: Optional[ + bool +] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers store_audit_logs = False # Enterprise feature, allow users to see audit logs ### end of callbacks ############# -email: Optional[str] = ( - None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -) -token: Optional[str] = ( - None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -) +email: Optional[ + str +] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +token: Optional[ + str +] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 telemetry = True max_tokens = 256 # OpenAI Defaults drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False)) @@ -229,24 +229,20 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None enable_caching_on_provider_specific_optional_params: bool = ( False # feature-flag for caching on optional params - e.g. 
'top_k' ) -caching: bool = ( - False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -) -caching_with_models: bool = ( - False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -) -cache: Optional[Cache] = ( - None # cache object <- use this - https://docs.litellm.ai/docs/caching -) +caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +cache: Optional[ + Cache +] = None # cache object <- use this - https://docs.litellm.ai/docs/caching default_in_memory_ttl: Optional[float] = None default_redis_ttl: Optional[float] = None default_redis_batch_cache_expiry: Optional[float] = None model_alias_map: Dict[str, str] = {} model_group_alias_map: Dict[str, str] = {} max_budget: float = 0.0 # set the max budget across all providers -budget_duration: Optional[str] = ( - None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). -) +budget_duration: Optional[ + str +] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). default_soft_budget: float = ( 50.0 # by default all litellm proxy keys have a soft budget of 50.0 ) @@ -255,15 +251,11 @@ forward_traceparent_to_llm_provider: bool = False _current_cost = 0.0 # private variable, used if max budget is set error_logs: Dict = {} -add_function_to_prompt: bool = ( - False # if function calling not supported by api, append function call details to system prompt -) +add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt client_session: Optional[httpx.Client] = None aclient_session: Optional[httpx.AsyncClient] = None model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' -model_cost_map_url: str = ( - "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" -) +model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" suppress_debug_info = False dynamodb_table_name: Optional[str] = None s3_callback_params: Optional[Dict] = None @@ -285,9 +277,7 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None custom_prometheus_metadata_labels: List[str] = [] #### REQUEST PRIORITIZATION #### priority_reservation: Optional[Dict[str, float]] = None -force_ipv4: bool = ( - False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6. -) +force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6. module_level_aclient = AsyncHTTPHandler( timeout=request_timeout, client_alias="module level aclient" ) @@ -301,13 +291,13 @@ fallbacks: Optional[List] = None context_window_fallbacks: Optional[List] = None content_policy_fallbacks: Optional[List] = None allowed_fails: int = 3 -num_retries_per_request: Optional[int] = ( - None # for the request overall (incl. fallbacks + model retries) -) +num_retries_per_request: Optional[ + int +] = None # for the request overall (incl. 
fallbacks + model retries) ####### SECRET MANAGERS ##################### -secret_manager_client: Optional[Any] = ( - None # list of instantiated key management clients - e.g. azure kv, infisical, etc. -) +secret_manager_client: Optional[ + Any +] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. _google_kms_resource_name: Optional[str] = None _key_management_system: Optional[KeyManagementSystem] = None _key_management_settings: KeyManagementSettings = KeyManagementSettings() @@ -1056,10 +1046,10 @@ from .types.llms.custom_llm import CustomLLMItem from .types.utils import GenericStreamingChunk custom_provider_map: List[CustomLLMItem] = [] -_custom_providers: List[str] = ( - [] -) # internal helper util, used to track names of custom providers -disable_hf_tokenizer_download: Optional[bool] = ( - None # disable huggingface tokenizer download. Defaults to openai clk100 -) +_custom_providers: List[ + str +] = [] # internal helper util, used to track names of custom providers +disable_hf_tokenizer_download: Optional[ + bool +] = None # disable huggingface tokenizer download. Defaults to openai clk100 global_disable_no_log_param: bool = False diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index 0b4f22e210..8f835bea83 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -15,7 +15,7 @@ from .types.services import ServiceLoggerPayload, ServiceTypes if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] OTELClass = OpenTelemetry else: Span = Any diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 1ddcafce4c..f4f74c72fb 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -153,7 +153,6 @@ def create_batch( ) api_base: Optional[str] = None if custom_llm_provider == "openai": - # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( optional_params.api_base @@ -358,7 +357,6 @@ def retrieve_batch( _is_async = kwargs.pop("aretrieve_batch", False) is True api_base: Optional[str] = None if custom_llm_provider == "openai": - # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( optional_params.api_base diff --git a/litellm/caching/base_cache.py b/litellm/caching/base_cache.py index 7109951d15..5140b390f7 100644 --- a/litellm/caching/base_cache.py +++ b/litellm/caching/base_cache.py @@ -9,12 +9,12 @@ Has 4 methods: """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py index 09fabf1c12..14278de9cd 100644 --- a/litellm/caching/caching_handler.py +++ b/litellm/caching/caching_handler.py @@ -66,9 +66,7 @@ class CachingHandlerResponse(BaseModel): cached_result: Optional[Any] = None final_embedding_cached_response: Optional[EmbeddingResponse] = None - embedding_all_elements_cache_hit: bool = ( - False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call - ) + embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the 
final_embedding_cached_response no need to make an API call class LLMCachingHandler: @@ -738,7 +736,6 @@ class LLMCachingHandler: if self._should_store_result_in_cache( original_function=self.original_function, kwargs=new_kwargs ): - litellm.cache.add_cache(result, **new_kwargs) return @@ -865,9 +862,9 @@ class LLMCachingHandler: } if litellm.cache is not None: - litellm_params["preset_cache_key"] = ( - litellm.cache._get_preset_cache_key_from_kwargs(**kwargs) - ) + litellm_params[ + "preset_cache_key" + ] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs) else: litellm_params["preset_cache_key"] = None diff --git a/litellm/caching/disk_cache.py b/litellm/caching/disk_cache.py index abf3203f50..413ac2932d 100644 --- a/litellm/caching/disk_cache.py +++ b/litellm/caching/disk_cache.py @@ -1,12 +1,12 @@ import json -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from .base_cache import BaseCache if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/caching/dual_cache.py b/litellm/caching/dual_cache.py index 5f598f7d70..8bef333758 100644 --- a/litellm/caching/dual_cache.py +++ b/litellm/caching/dual_cache.py @@ -12,7 +12,7 @@ import asyncio import time import traceback from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Union import litellm from litellm._logging import print_verbose, verbose_logger @@ -24,7 +24,7 @@ from .redis_cache import RedisCache if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/caching/llm_caching_handler.py b/litellm/caching/llm_caching_handler.py index 429634b7b1..3bf1f80d08 100644 --- a/litellm/caching/llm_caching_handler.py +++ b/litellm/caching/llm_caching_handler.py @@ -8,7 +8,6 @@ from .in_memory_cache import InMemoryCache class LLMClientCache(InMemoryCache): - def update_cache_key_with_event_loop(self, key): """ Add the event loop to the cache key, to prevent event loop closed errors. 
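The repeated change from "Span = _Span" to "Span = Union[_Span, Any]" in the caching and logging modules above (and in the Redis caches below) follows one optional-dependency aliasing pattern so mypy accepts the alias while opentelemetry stays optional at runtime. A minimal standalone sketch of that pattern; the end_span function is illustrative and not taken from this patch:

    # Minimal sketch of the TYPE_CHECKING alias pattern applied across _service_logger.py,
    # base_cache.py, disk_cache.py, dual_cache.py, redis_cache.py, and similar modules.
    from typing import TYPE_CHECKING, Any, Optional, Union

    if TYPE_CHECKING:
        # Only evaluated by type checkers, so opentelemetry stays an optional dependency.
        from opentelemetry.trace import Span as _Span

        Span = Union[_Span, Any]
    else:
        # At runtime the alias degrades to Any and no import is attempted.
        Span = Any


    def end_span(span: Optional[Span] = None) -> None:
        """Accepts a real OTEL span, a stub object, or None (illustrative only)."""
        if span is not None and hasattr(span, "end"):
            span.end()
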
diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index 7378ed878a..63cd4d0959 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: cluster_pipeline = ClusterPipeline async_redis_client = Redis async_redis_cluster_client = RedisCluster - Span = _Span + Span = Union[_Span, Any] else: pipeline = Any cluster_pipeline = Any @@ -57,7 +57,6 @@ class RedisCache(BaseCache): socket_timeout: Optional[float] = 5.0, # default 5 second timeout **kwargs, ): - from litellm._service_logger import ServiceLogging from .._redis import get_redis_client, get_redis_connection_pool diff --git a/litellm/caching/redis_cluster_cache.py b/litellm/caching/redis_cluster_cache.py index 2e7d1de17f..21c3ab0366 100644 --- a/litellm/caching/redis_cluster_cache.py +++ b/litellm/caching/redis_cluster_cache.py @@ -5,7 +5,7 @@ Key differences: - RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created """ -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Union from litellm.caching.redis_cache import RedisCache @@ -16,7 +16,7 @@ if TYPE_CHECKING: pipeline = Pipeline async_redis_client = Redis - Span = _Span + Span = Union[_Span, Any] else: pipeline = Any async_redis_client = Any diff --git a/litellm/caching/redis_semantic_cache.py b/litellm/caching/redis_semantic_cache.py index f46bb661ef..c76f27377d 100644 --- a/litellm/caching/redis_semantic_cache.py +++ b/litellm/caching/redis_semantic_cache.py @@ -13,23 +13,27 @@ import ast import asyncio import json import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, cast import litellm from litellm._logging import print_verbose -from litellm.litellm_core_utils.prompt_templates.common_utils import get_str_from_messages +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + get_str_from_messages, +) +from litellm.types.utils import EmbeddingResponse + from .base_cache import BaseCache class RedisSemanticCache(BaseCache): """ - Redis-backed semantic cache for LLM responses. - - This cache uses vector similarity to find semantically similar prompts that have been + Redis-backed semantic cache for LLM responses. + + This cache uses vector similarity to find semantically similar prompts that have been previously sent to the LLM, allowing for cache hits even when prompts are not identical but carry similar meaning. 
""" - + DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index" def __init__( @@ -57,7 +61,7 @@ class RedisSemanticCache(BaseCache): index_name: Name for the Redis index ttl: Default time-to-live for cache entries in seconds **kwargs: Additional arguments passed to the Redis client - + Raises: Exception: If similarity_threshold is not provided or required Redis connection information is missing @@ -69,14 +73,14 @@ class RedisSemanticCache(BaseCache): index_name = self.DEFAULT_REDIS_INDEX_NAME print_verbose(f"Redis semantic-cache initializing index - {index_name}") - + # Validate similarity threshold if similarity_threshold is None: raise ValueError("similarity_threshold must be provided, passed None") - + # Store configuration self.similarity_threshold = similarity_threshold - + # Convert similarity threshold [0,1] to distance threshold [0,2] # For cosine distance: 0 = most similar, 2 = least similar # While similarity: 1 = most similar, 0 = least similar @@ -87,14 +91,16 @@ class RedisSemanticCache(BaseCache): if redis_url is None: try: # Attempt to use provided parameters or fallback to environment variables - host = host or os.environ['REDIS_HOST'] - port = port or os.environ['REDIS_PORT'] - password = password or os.environ['REDIS_PASSWORD'] + host = host or os.environ["REDIS_HOST"] + port = port or os.environ["REDIS_PORT"] + password = password or os.environ["REDIS_PASSWORD"] except KeyError as e: # Raise a more informative exception if any of the required keys are missing missing_var = e.args[0] - raise ValueError(f"Missing required Redis configuration: {missing_var}. " - f"Provide {missing_var} or redis_url.") from e + raise ValueError( + f"Missing required Redis configuration: {missing_var}. " + f"Provide {missing_var} or redis_url." + ) from e redis_url = f"redis://:{password}@{host}:{port}" @@ -114,7 +120,7 @@ class RedisSemanticCache(BaseCache): def _get_ttl(self, **kwargs) -> Optional[int]: """ Get the TTL (time-to-live) value for cache entries. - + Args: **kwargs: Keyword arguments that may contain a custom TTL @@ -125,22 +131,25 @@ class RedisSemanticCache(BaseCache): if ttl is not None: ttl = int(ttl) return ttl - + def _get_embedding(self, prompt: str) -> List[float]: """ Generate an embedding vector for the given prompt using the configured embedding model. - + Args: prompt: The text to generate an embedding for - + Returns: List[float]: The embedding vector """ # Create an embedding from prompt - embedding_response = litellm.embedding( - model=self.embedding_model, - input=prompt, - cache={"no-store": True, "no-cache": True}, + embedding_response = cast( + EmbeddingResponse, + litellm.embedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ), ) embedding = embedding_response["data"][0]["embedding"] return embedding @@ -148,10 +157,10 @@ class RedisSemanticCache(BaseCache): def _get_cache_logic(self, cached_response: Any) -> Any: """ Process the cached response to prepare it for use. - + Args: cached_response: The raw cached response - + Returns: The processed cache response, or None if input was None """ @@ -171,13 +180,13 @@ class RedisSemanticCache(BaseCache): except (ValueError, SyntaxError) as e: print_verbose(f"Error parsing cached response: {str(e)}") return None - + return cached_response def set_cache(self, key: str, value: Any, **kwargs) -> None: """ Store a value in the semantic cache. 
- + Args: key: The cache key (not directly used in semantic caching) value: The response value to cache @@ -186,13 +195,14 @@ class RedisSemanticCache(BaseCache): """ print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}") + value_str: Optional[str] = None try: # Extract the prompt from messages messages = kwargs.get("messages", []) if not messages: print_verbose("No messages provided for semantic caching") return - + prompt = get_str_from_messages(messages) value_str = str(value) @@ -203,16 +213,18 @@ class RedisSemanticCache(BaseCache): else: self.llmcache.store(prompt, value_str) except Exception as e: - print_verbose(f"Error setting {value_str} in the Redis semantic cache: {str(e)}") + print_verbose( + f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}" + ) def get_cache(self, key: str, **kwargs) -> Any: """ Retrieve a semantically similar cached response. - + Args: key: The cache key (not directly used in semantic caching) **kwargs: Additional arguments including 'messages' for the prompt - + Returns: The cached response if a semantically similar prompt is found, else None """ @@ -224,7 +236,7 @@ class RedisSemanticCache(BaseCache): if not messages: print_verbose("No messages provided for semantic cache lookup") return None - + prompt = get_str_from_messages(messages) # Check the cache for semantically similar prompts results = self.llmcache.check(prompt=prompt) @@ -236,12 +248,12 @@ class RedisSemanticCache(BaseCache): # Process the best matching result cache_hit = results[0] vector_distance = float(cache_hit["vector_distance"]) - + # Convert vector distance back to similarity score # For cosine distance: 0 = most similar, 2 = least similar # While similarity: 1 = most similar, 0 = least similar similarity = 1 - vector_distance - + cached_prompt = cache_hit["prompt"] cached_response = cache_hit["response"] @@ -251,19 +263,19 @@ class RedisSemanticCache(BaseCache): f"current prompt: {prompt}, " f"cached prompt: {cached_prompt}" ) - + return self._get_cache_logic(cached_response=cached_response) except Exception as e: print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}") - + async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]: """ Asynchronously generate an embedding for the given prompt. - + Args: prompt: The text to generate an embedding for **kwargs: Additional arguments that may contain metadata - + Returns: List[float]: The embedding vector """ @@ -275,7 +287,7 @@ class RedisSemanticCache(BaseCache): if llm_model_list is not None else [] ) - + try: if llm_router is not None and self.embedding_model in router_model_names: # Use the router for embedding generation @@ -307,7 +319,7 @@ class RedisSemanticCache(BaseCache): async def async_set_cache(self, key: str, value: Any, **kwargs) -> None: """ Asynchronously store a value in the semantic cache. 
- + Args: key: The cache key (not directly used in semantic caching) value: The response value to cache @@ -322,13 +334,13 @@ class RedisSemanticCache(BaseCache): if not messages: print_verbose("No messages provided for semantic caching") return - + prompt = get_str_from_messages(messages) value_str = str(value) # Generate embedding for the value (response) to cache prompt_embedding = await self._get_async_embedding(prompt, **kwargs) - + # Get TTL and store in Redis semantic cache ttl = self._get_ttl(**kwargs) if ttl is not None: @@ -336,13 +348,13 @@ class RedisSemanticCache(BaseCache): prompt, value_str, vector=prompt_embedding, # Pass through custom embedding - ttl=ttl + ttl=ttl, ) else: await self.llmcache.astore( prompt, value_str, - vector=prompt_embedding # Pass through custom embedding + vector=prompt_embedding, # Pass through custom embedding ) except Exception as e: print_verbose(f"Error in async_set_cache: {str(e)}") @@ -350,11 +362,11 @@ class RedisSemanticCache(BaseCache): async def async_get_cache(self, key: str, **kwargs) -> Any: """ Asynchronously retrieve a semantically similar cached response. - + Args: key: The cache key (not directly used in semantic caching) **kwargs: Additional arguments including 'messages' for the prompt - + Returns: The cached response if a semantically similar prompt is found, else None """ @@ -367,21 +379,20 @@ class RedisSemanticCache(BaseCache): print_verbose("No messages provided for semantic cache lookup") kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None - + prompt = get_str_from_messages(messages) - + # Generate embedding for the prompt prompt_embedding = await self._get_async_embedding(prompt, **kwargs) # Check the cache for semantically similar prompts - results = await self.llmcache.acheck( - prompt=prompt, - vector=prompt_embedding - ) + results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding) # handle results / cache hit if not results: - kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 # TODO why here but not above?? + kwargs.setdefault("metadata", {})[ + "semantic-similarity" + ] = 0.0 # TODO why here but not above?? return None cache_hit = results[0] @@ -404,7 +415,7 @@ class RedisSemanticCache(BaseCache): f"current prompt: {prompt}, " f"cached prompt: {cached_prompt}" ) - + return self._get_cache_logic(cached_response=cached_response) except Exception as e: print_verbose(f"Error in async_get_cache: {str(e)}") @@ -413,17 +424,19 @@ class RedisSemanticCache(BaseCache): async def _index_info(self) -> Dict[str, Any]: """ Get information about the Redis index. - + Returns: Dict[str, Any]: Information about the Redis index """ aindex = await self.llmcache._get_async_index() return await aindex.info() - async def async_set_cache_pipeline(self, cache_list: List[Tuple[str, Any]], **kwargs) -> None: + async def async_set_cache_pipeline( + self, cache_list: List[Tuple[str, Any]], **kwargs + ) -> None: """ Asynchronously store multiple values in the semantic cache. 
- + Args: cache_list: List of (key, value) tuples to cache **kwargs: Additional arguments diff --git a/litellm/caching/s3_cache.py b/litellm/caching/s3_cache.py index 301591c64f..c02e109136 100644 --- a/litellm/caching/s3_cache.py +++ b/litellm/caching/s3_cache.py @@ -123,7 +123,7 @@ class S3Cache(BaseCache): ) # Convert string to dictionary except Exception: cached_response = ast.literal_eval(cached_response) - if type(cached_response) is not dict: + if not isinstance(cached_response, dict): cached_response = dict(cached_response) verbose_logger.debug( f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}" diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index f343591443..a41fc364ab 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -580,7 +580,6 @@ def completion_cost( # noqa: PLR0915 - For un-mapped Replicate models, the cost is calculated based on the total time used for the request. """ try: - call_type = _infer_call_type(call_type, completion_response) or "completion" if ( diff --git a/litellm/fine_tuning/main.py b/litellm/fine_tuning/main.py index b726a394c2..09c070fffb 100644 --- a/litellm/fine_tuning/main.py +++ b/litellm/fine_tuning/main.py @@ -138,7 +138,6 @@ def create_fine_tuning_job( # OpenAI if custom_llm_provider == "openai": - # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( optional_params.api_base @@ -360,7 +359,6 @@ def cancel_fine_tuning_job( # OpenAI if custom_llm_provider == "openai": - # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( optional_params.api_base @@ -522,7 +520,6 @@ def list_fine_tuning_jobs( # OpenAI if custom_llm_provider == "openai": - # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( optional_params.api_base diff --git a/litellm/integrations/SlackAlerting/batching_handler.py b/litellm/integrations/SlackAlerting/batching_handler.py index e35cf61d63..fdce2e0479 100644 --- a/litellm/integrations/SlackAlerting/batching_handler.py +++ b/litellm/integrations/SlackAlerting/batching_handler.py @@ -19,7 +19,6 @@ else: def squash_payloads(queue): - squashed = {} if len(queue) == 0: return squashed diff --git a/litellm/integrations/SlackAlerting/slack_alerting.py b/litellm/integrations/SlackAlerting/slack_alerting.py index a2e6264760..50f0538cfd 100644 --- a/litellm/integrations/SlackAlerting/slack_alerting.py +++ b/litellm/integrations/SlackAlerting/slack_alerting.py @@ -195,12 +195,15 @@ class SlackAlerting(CustomBatchLogger): if self.alerting is None or self.alert_types is None: return - time_difference_float, model, api_base, messages = ( - self._response_taking_too_long_callback_helper( - kwargs=kwargs, - start_time=start_time, - end_time=end_time, - ) + ( + time_difference_float, + model, + api_base, + messages, + ) = self._response_taking_too_long_callback_helper( + kwargs=kwargs, + start_time=start_time, + end_time=end_time, ) if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions: messages = "Message not logged. 
litellm.redact_messages_in_exceptions=True" @@ -819,9 +822,9 @@ class SlackAlerting(CustomBatchLogger): ### UNIQUE CACHE KEY ### cache_key = provider + region_name - outage_value: Optional[ProviderRegionOutageModel] = ( - await self.internal_usage_cache.async_get_cache(key=cache_key) - ) + outage_value: Optional[ + ProviderRegionOutageModel + ] = await self.internal_usage_cache.async_get_cache(key=cache_key) if ( getattr(exception, "status_code", None) is None @@ -1402,9 +1405,9 @@ Model Info: self.alert_to_webhook_url is not None and alert_type in self.alert_to_webhook_url ): - slack_webhook_url: Optional[Union[str, List[str]]] = ( - self.alert_to_webhook_url[alert_type] - ) + slack_webhook_url: Optional[ + Union[str, List[str]] + ] = self.alert_to_webhook_url[alert_type] elif self.default_webhook_url is not None: slack_webhook_url = self.default_webhook_url else: @@ -1768,7 +1771,6 @@ Model Info: - Team Created, Updated, Deleted """ try: - message = f"`{event_name}`\n" key_event_dict = key_event.model_dump() diff --git a/litellm/integrations/_types/open_inference.py b/litellm/integrations/_types/open_inference.py index b5076c0e42..bcfabe9b7b 100644 --- a/litellm/integrations/_types/open_inference.py +++ b/litellm/integrations/_types/open_inference.py @@ -283,4 +283,4 @@ class OpenInferenceSpanKindValues(Enum): class OpenInferenceMimeTypeValues(Enum): TEXT = "text/plain" - JSON = "application/json" \ No newline at end of file + JSON = "application/json" diff --git a/litellm/integrations/argilla.py b/litellm/integrations/argilla.py index 055ad90259..a362ce7e4d 100644 --- a/litellm/integrations/argilla.py +++ b/litellm/integrations/argilla.py @@ -98,7 +98,6 @@ class ArgillaLogger(CustomBatchLogger): argilla_dataset_name: Optional[str], argilla_base_url: Optional[str], ) -> ArgillaCredentialsObject: - _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY") if _credentials_api_key is None: raise Exception("Invalid Argilla API Key given. 
_credentials_api_key=None.") diff --git a/litellm/integrations/arize/_utils.py b/litellm/integrations/arize/_utils.py index 487304cce4..5a090968b4 100644 --- a/litellm/integrations/arize/_utils.py +++ b/litellm/integrations/arize/_utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from litellm._logging import verbose_logger from litellm.litellm_core_utils.safe_json_dumps import safe_dumps @@ -7,7 +7,7 @@ from litellm.types.utils import StandardLoggingPayload if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/integrations/arize/arize.py b/litellm/integrations/arize/arize.py index 7a0fb785a7..03b6966809 100644 --- a/litellm/integrations/arize/arize.py +++ b/litellm/integrations/arize/arize.py @@ -19,14 +19,13 @@ if TYPE_CHECKING: from litellm.types.integrations.arize import Protocol as _Protocol Protocol = _Protocol - Span = _Span + Span = Union[_Span, Any] else: Protocol = Any Span = Any class ArizeLogger(OpenTelemetry): - def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]): ArizeLogger.set_arize_attributes(span, kwargs, response_obj) return diff --git a/litellm/integrations/arize/arize_phoenix.py b/litellm/integrations/arize/arize_phoenix.py index d7b7d5812b..2b4909885a 100644 --- a/litellm/integrations/arize/arize_phoenix.py +++ b/litellm/integrations/arize/arize_phoenix.py @@ -1,17 +1,20 @@ import os -from typing import TYPE_CHECKING, Any -from litellm.integrations.arize import _utils +from typing import TYPE_CHECKING, Any, Union + from litellm._logging import verbose_logger +from litellm.integrations.arize import _utils from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig if TYPE_CHECKING: - from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig - from litellm.types.integrations.arize import Protocol as _Protocol from opentelemetry.trace import Span as _Span + from litellm.types.integrations.arize import Protocol as _Protocol + + from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig + Protocol = _Protocol OpenTelemetryConfig = _OpenTelemetryConfig - Span = _Span + Span = Union[_Span, Any] else: Protocol = Any OpenTelemetryConfig = Any @@ -20,6 +23,7 @@ else: ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://app.phoenix.arize.com/v1/traces" + class ArizePhoenixLogger: @staticmethod def set_arize_phoenix_attributes(span: Span, kwargs, response_obj): @@ -49,7 +53,7 @@ class ArizePhoenixLogger: protocol = "otlp_grpc" else: endpoint = ARIZE_HOSTED_PHOENIX_ENDPOINT - protocol = "otlp_http" + protocol = "otlp_http" verbose_logger.debug( f"No PHOENIX_COLLECTOR_ENDPOINT or PHOENIX_COLLECTOR_HTTP_ENDPOINT found, using default endpoint with http: {ARIZE_HOSTED_PHOENIX_ENDPOINT}" ) @@ -57,17 +61,16 @@ class ArizePhoenixLogger: otlp_auth_headers = None # If the endpoint is the Arize hosted Phoenix endpoint, use the api_key as the auth header as currently it is uses # a slightly different auth header format than self hosted phoenix - if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT: + if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT: if api_key is None: - raise ValueError("PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used.") + raise ValueError( + "PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used." 
+ ) otlp_auth_headers = f"api_key={api_key}" elif api_key is not None: # api_key/auth is optional for self hosted phoenix otlp_auth_headers = f"Authorization=Bearer {api_key}" return ArizePhoenixConfig( - otlp_auth_headers=otlp_auth_headers, - protocol=protocol, - endpoint=endpoint + otlp_auth_headers=otlp_auth_headers, protocol=protocol, endpoint=endpoint ) - diff --git a/litellm/integrations/athina.py b/litellm/integrations/athina.py index 705dc11f1d..49b9e9e687 100644 --- a/litellm/integrations/athina.py +++ b/litellm/integrations/athina.py @@ -12,7 +12,10 @@ class AthinaLogger: "athina-api-key": self.athina_api_key, "Content-Type": "application/json", } - self.athina_logging_url = os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + "/api/v1/log/inference" + self.athina_logging_url = ( + os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + + "/api/v1/log/inference" + ) self.additional_keys = [ "environment", "prompt_slug", diff --git a/litellm/integrations/azure_storage/azure_storage.py b/litellm/integrations/azure_storage/azure_storage.py index ddc46b117f..27f5e0e112 100644 --- a/litellm/integrations/azure_storage/azure_storage.py +++ b/litellm/integrations/azure_storage/azure_storage.py @@ -50,12 +50,12 @@ class AzureBlobStorageLogger(CustomBatchLogger): self.azure_storage_file_system: str = _azure_storage_file_system # Internal variables used for Token based authentication - self.azure_auth_token: Optional[str] = ( - None # the Azure AD token to use for Azure Storage API requests - ) - self.token_expiry: Optional[datetime] = ( - None # the expiry time of the currentAzure AD token - ) + self.azure_auth_token: Optional[ + str + ] = None # the Azure AD token to use for Azure Storage API requests + self.token_expiry: Optional[ + datetime + ] = None # the expiry time of the currentAzure AD token asyncio.create_task(self.periodic_flush()) self.flush_lock = asyncio.Lock() @@ -153,7 +153,6 @@ class AzureBlobStorageLogger(CustomBatchLogger): 3. 
Flush the data """ try: - if self.azure_storage_account_key: await self.upload_to_azure_data_lake_with_azure_account_key( payload=payload diff --git a/litellm/integrations/braintrust_logging.py b/litellm/integrations/braintrust_logging.py index 281fbda01e..0961eab02b 100644 --- a/litellm/integrations/braintrust_logging.py +++ b/litellm/integrations/braintrust_logging.py @@ -4,7 +4,7 @@ import copy import os from datetime import datetime -from typing import Optional, Dict +from typing import Dict, Optional import httpx from pydantic import BaseModel @@ -19,7 +19,9 @@ from litellm.llms.custom_httpx.http_handler import ( ) from litellm.utils import print_verbose -global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback) +global_braintrust_http_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback +) global_braintrust_sync_http_handler = HTTPHandler() API_BASE = "https://api.braintrustdata.com/v1" @@ -35,7 +37,9 @@ def get_utc_datetime(): class BraintrustLogger(CustomLogger): - def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None: + def __init__( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> None: super().__init__() self.validate_environment(api_key=api_key) self.api_base = api_base or API_BASE @@ -45,7 +49,9 @@ class BraintrustLogger(CustomLogger): "Authorization": "Bearer " + self.api_key, "Content-Type": "application/json", } - self._project_id_cache: Dict[str, str] = {} # Cache mapping project names to IDs + self._project_id_cache: Dict[ + str, str + ] = {} # Cache mapping project names to IDs def validate_environment(self, api_key: Optional[str]): """ @@ -71,7 +77,9 @@ class BraintrustLogger(CustomLogger): try: response = global_braintrust_sync_http_handler.post( - f"{self.api_base}/project", headers=self.headers, json={"name": project_name} + f"{self.api_base}/project", + headers=self.headers, + json={"name": project_name}, ) project_dict = response.json() project_id = project_dict["id"] @@ -89,7 +97,9 @@ class BraintrustLogger(CustomLogger): try: response = await global_braintrust_http_handler.post( - f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name} + f"{self.api_base}/project/register", + headers=self.headers, + json={"name": project_name}, ) project_dict = response.json() project_id = project_dict["id"] @@ -116,15 +126,21 @@ class BraintrustLogger(CustomLogger): if metadata is None: metadata = {} - proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {} + proxy_headers = ( + litellm_params.get("proxy_server_request", {}).get("headers", {}) or {} + ) for metadata_param_key in proxy_headers: if metadata_param_key.startswith("braintrust"): trace_param_key = metadata_param_key.replace("braintrust", "", 1) if trace_param_key in metadata: - verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header") + verbose_logger.warning( + f"Overwriting Braintrust `{trace_param_key}` from request header" + ) else: - verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header") + verbose_logger.debug( + f"Found Braintrust `{trace_param_key}` in request header" + ) metadata[trace_param_key] = proxy_headers.get(metadata_param_key) return metadata @@ -157,24 +173,35 @@ class BraintrustLogger(CustomLogger): output = None choices = [] if response_obj is not None and ( - kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, 
litellm.EmbeddingResponse) + kwargs.get("call_type", None) == "embedding" + or isinstance(response_obj, litellm.EmbeddingResponse) ): output = None - elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.ModelResponse + ): output = response_obj["choices"][0]["message"].json() choices = response_obj["choices"] - elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.TextCompletionResponse + ): output = response_obj.choices[0].text choices = response_obj.choices - elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.ImageResponse + ): output = response_obj["data"] litellm_params = kwargs.get("litellm_params", {}) - metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None + metadata = ( + litellm_params.get("metadata", {}) or {} + ) # if litellm_params['metadata'] == None metadata = self.add_metadata_from_header(litellm_params, metadata) clean_metadata = {} try: - metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata + metadata = copy.deepcopy( + metadata + ) # Avoid modifying the original metadata except Exception: new_metadata = {} for key, value in metadata.items(): @@ -192,7 +219,9 @@ class BraintrustLogger(CustomLogger): project_id = metadata.get("project_id") if project_id is None: project_name = metadata.get("project_name") - project_id = self.get_project_id_sync(project_name) if project_name else None + project_id = ( + self.get_project_id_sync(project_name) if project_name else None + ) if project_id is None: if self.default_project_id is None: @@ -234,7 +263,8 @@ class BraintrustLogger(CustomLogger): "completion_tokens": usage_obj.completion_tokens, "total_tokens": usage_obj.total_tokens, "total_cost": cost, - "time_to_first_token": end_time.timestamp() - start_time.timestamp(), + "time_to_first_token": end_time.timestamp() + - start_time.timestamp(), "start": start_time.timestamp(), "end": end_time.timestamp(), } @@ -255,7 +285,9 @@ class BraintrustLogger(CustomLogger): request_data["metrics"] = metrics try: - print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}") + print_verbose( + f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}" + ) global_braintrust_sync_http_handler.post( url=f"{self.api_base}/project_logs/{project_id}/insert", json={"events": [request_data]}, @@ -276,20 +308,29 @@ class BraintrustLogger(CustomLogger): output = None choices = [] if response_obj is not None and ( - kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse) + kwargs.get("call_type", None) == "embedding" + or isinstance(response_obj, litellm.EmbeddingResponse) ): output = None - elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.ModelResponse + ): output = response_obj["choices"][0]["message"].json() choices = response_obj["choices"] - elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse): + elif response_obj is not None and isinstance( + response_obj, litellm.TextCompletionResponse + ): output = response_obj.choices[0].text choices = response_obj.choices - elif 
response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ImageResponse
+        ):
             output = response_obj["data"]

         litellm_params = kwargs.get("litellm_params", {})
-        metadata = litellm_params.get("metadata", {}) or {}  # if litellm_params['metadata'] == None
+        metadata = (
+            litellm_params.get("metadata", {}) or {}
+        )  # if litellm_params['metadata'] == None
         metadata = self.add_metadata_from_header(litellm_params, metadata)
         clean_metadata = {}
         new_metadata = {}
@@ -313,7 +354,11 @@ class BraintrustLogger(CustomLogger):
         project_id = metadata.get("project_id")
         if project_id is None:
             project_name = metadata.get("project_name")
-            project_id = await self.get_project_id_async(project_name) if project_name else None
+            project_id = (
+                await self.get_project_id_async(project_name)
+                if project_name
+                else None
+            )

         if project_id is None:
             if self.default_project_id is None:
@@ -362,8 +407,14 @@ class BraintrustLogger(CustomLogger):

             api_call_start_time = kwargs.get("api_call_start_time")
             completion_start_time = kwargs.get("completion_start_time")
-            if api_call_start_time is not None and completion_start_time is not None:
-                metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp()
+            if (
+                api_call_start_time is not None
+                and completion_start_time is not None
+            ):
+                metrics["time_to_first_token"] = (
+                    completion_start_time.timestamp()
+                    - api_call_start_time.timestamp()
+                )

             request_data = {
                 "id": litellm_call_id,
diff --git a/litellm/integrations/custom_batch_logger.py b/litellm/integrations/custom_batch_logger.py
index 3cfdf82cab..f9d4496c21 100644
--- a/litellm/integrations/custom_batch_logger.py
+++ b/litellm/integrations/custom_batch_logger.py
@@ -14,7 +14,6 @@ from litellm.integrations.custom_logger import CustomLogger


 class CustomBatchLogger(CustomLogger):
-
     def __init__(
         self,
         flush_lock: Optional[asyncio.Lock] = None,
diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py
index 4421664bfc..41a3800116 100644
--- a/litellm/integrations/custom_guardrail.py
+++ b/litellm/integrations/custom_guardrail.py
@@ -7,7 +7,6 @@ from litellm.types.utils import StandardLoggingGuardrailInformation


 class CustomGuardrail(CustomLogger):
-
     def __init__(
         self,
         guardrail_name: Optional[str] = None,
diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index 6f1ec88d01..ddb8094285 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -31,7 +31,7 @@ from litellm.types.utils import (
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

diff --git a/litellm/integrations/datadog/datadog.py b/litellm/integrations/datadog/datadog.py
index 4f4b05c84e..e9b6b6b164 100644
--- a/litellm/integrations/datadog/datadog.py
+++ b/litellm/integrations/datadog/datadog.py
@@ -233,7 +233,6 @@ class DataDogLogger(
         pass

     async def _log_async_event(self, kwargs, response_obj, start_time, end_time):
-
         dd_payload = self.create_datadog_logging_payload(
             kwargs=kwargs,
             response_obj=response_obj,
diff --git a/litellm/integrations/gcs_bucket/gcs_bucket_base.py b/litellm/integrations/gcs_bucket/gcs_bucket_base.py
index 66995d8482..0ce845ecb2 100644
--- a/litellm/integrations/gcs_bucket/gcs_bucket_base.py
+++ b/litellm/integrations/gcs_bucket/gcs_bucket_base.py
@@ -125,9 +125,9 @@ class GCSBucketBase(CustomBatchLogger):
         if kwargs is None:
             kwargs = {}

-        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
-            kwargs.get("standard_callback_dynamic_params", None)
-        )
+        standard_callback_dynamic_params: Optional[
+            StandardCallbackDynamicParams
+        ] = kwargs.get("standard_callback_dynamic_params", None)

         bucket_name: str
         path_service_account: Optional[str]
diff --git a/litellm/integrations/gcs_pubsub/pub_sub.py b/litellm/integrations/gcs_pubsub/pub_sub.py
index 1b078df7bc..bdaedcd908 100644
--- a/litellm/integrations/gcs_pubsub/pub_sub.py
+++ b/litellm/integrations/gcs_pubsub/pub_sub.py
@@ -70,12 +70,13 @@ class GcsPubSubLogger(CustomBatchLogger):
         """Construct authorization headers using Vertex AI auth"""
         from litellm import vertex_chat_completion

-        _auth_header, vertex_project = (
-            await vertex_chat_completion._ensure_access_token_async(
-                credentials=self.path_service_account_json,
-                project_id=None,
-                custom_llm_provider="vertex_ai",
-            )
+        (
+            _auth_header,
+            vertex_project,
+        ) = await vertex_chat_completion._ensure_access_token_async(
+            credentials=self.path_service_account_json,
+            project_id=None,
+            custom_llm_provider="vertex_ai",
         )

         auth_header, _ = vertex_chat_completion._get_token_and_url(
diff --git a/litellm/integrations/humanloop.py b/litellm/integrations/humanloop.py
index fd3463f9e3..4651238af4 100644
--- a/litellm/integrations/humanloop.py
+++ b/litellm/integrations/humanloop.py
@@ -155,11 +155,7 @@ class HumanloopLogger(CustomLogger):
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[
-        str,
-        List[AllMessageValues],
-        dict,
-    ]:
+    ) -> Tuple[str, List[AllMessageValues], dict,]:
         humanloop_api_key = dynamic_callback_params.get(
             "humanloop_api_key"
         ) or get_secret_str("HUMANLOOP_API_KEY")
diff --git a/litellm/integrations/langfuse/langfuse.py b/litellm/integrations/langfuse/langfuse.py
index f990a316c4..d0472ee638 100644
--- a/litellm/integrations/langfuse/langfuse.py
+++ b/litellm/integrations/langfuse/langfuse.py
@@ -471,9 +471,9 @@ class LangFuseLogger:
         # we clean out all extra litellm metadata params before logging
         clean_metadata: Dict[str, Any] = {}
         if prompt_management_metadata is not None:
-            clean_metadata["prompt_management_metadata"] = (
-                prompt_management_metadata
-            )
+            clean_metadata[
+                "prompt_management_metadata"
+            ] = prompt_management_metadata
         if isinstance(metadata, dict):
             for key, value in metadata.items():
                 # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
diff --git a/litellm/integrations/langfuse/langfuse_handler.py b/litellm/integrations/langfuse/langfuse_handler.py
index aebe1461b0..f9d27f6cf0 100644
--- a/litellm/integrations/langfuse/langfuse_handler.py
+++ b/litellm/integrations/langfuse/langfuse_handler.py
@@ -19,7 +19,6 @@ else:


 class LangFuseHandler:
-
     @staticmethod
     def get_langfuse_logger_for_request(
         standard_callback_dynamic_params: StandardCallbackDynamicParams,
@@ -87,7 +86,9 @@ class LangFuseHandler:
         if globalLangfuseLogger is not None:
             return globalLangfuseLogger

-        credentials_dict: Dict[str, Any] = (
+        credentials_dict: Dict[
+            str, Any
+        ] = (
             {}
         )  # the global langfuse logger uses Environment Variables, there are no dynamic credentials
         globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache(
diff --git a/litellm/integrations/langfuse/langfuse_prompt_management.py b/litellm/integrations/langfuse/langfuse_prompt_management.py
index 1f4ca84db3..30f991ebd6 100644
--- a/litellm/integrations/langfuse/langfuse_prompt_management.py
+++ b/litellm/integrations/langfuse/langfuse_prompt_management.py
@@ -172,11 +172,7 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[
-        str,
-        List[AllMessageValues],
-        dict,
-    ]:
+    ) -> Tuple[str, List[AllMessageValues], dict,]:
         return self.get_chat_completion_prompt(
             model,
             messages,
diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py
index 1ef90c1822..0914150db9 100644
--- a/litellm/integrations/langsmith.py
+++ b/litellm/integrations/langsmith.py
@@ -75,7 +75,6 @@ class LangsmithLogger(CustomBatchLogger):
         langsmith_project: Optional[str] = None,
         langsmith_base_url: Optional[str] = None,
     ) -> LangsmithCredentialsObject:
-
         _credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
         if _credentials_api_key is None:
             raise Exception(
@@ -443,9 +442,9 @@ class LangsmithLogger(CustomBatchLogger):
         Otherwise, use the default credentials.
         """

-        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
-            kwargs.get("standard_callback_dynamic_params", None)
-        )
+        standard_callback_dynamic_params: Optional[
+            StandardCallbackDynamicParams
+        ] = kwargs.get("standard_callback_dynamic_params", None)
         if standard_callback_dynamic_params is not None:
             credentials = self.get_credentials_from_env(
                 langsmith_api_key=standard_callback_dynamic_params.get(
@@ -481,7 +480,6 @@ class LangsmithLogger(CustomBatchLogger):
         asyncio.run(self.async_send_batch())

     def get_run_by_id(self, run_id):
-
         langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]

         langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
diff --git a/litellm/integrations/langtrace.py b/litellm/integrations/langtrace.py
index 51cd272ff1..ac1069f440 100644
--- a/litellm/integrations/langtrace.py
+++ b/litellm/integrations/langtrace.py
@@ -1,12 +1,12 @@
 import json
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union

 from litellm.proxy._types import SpanAttributes

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

diff --git a/litellm/integrations/lunary.py b/litellm/integrations/lunary.py
index fcd781e44e..b24a24e088 100644
--- a/litellm/integrations/lunary.py
+++ b/litellm/integrations/lunary.py
@@ -20,7 +20,6 @@ def parse_tool_calls(tool_calls):
         return None

     def clean_tool_call(tool_call):
-
         serialized = {
             "type": tool_call.type,
             "id": tool_call.id,
@@ -36,7 +35,6 @@ def parse_tool_calls(tool_calls):


 def parse_messages(input):
-
     if input is None:
         return None

diff --git a/litellm/integrations/mlflow.py b/litellm/integrations/mlflow.py
index 193d1c4ea2..e7a458accf 100644
--- a/litellm/integrations/mlflow.py
+++ b/litellm/integrations/mlflow.py
@@ -48,14 +48,17 @@ class MlflowLogger(CustomLogger):

     def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
         try:
-            from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
+            from mlflow.tracing.utils import set_span_chat_messages  # type: ignore
+            from mlflow.tracing.utils import set_span_chat_tools  # type: ignore
         except ImportError:
             return

         inputs = self._construct_input(kwargs)
         input_messages = inputs.get("messages", [])
-        output_messages = [c.message.model_dump(exclude_none=True)
-                           for c in getattr(response_obj, "choices", [])]
+        output_messages = [
+            c.message.model_dump(exclude_none=True)
+            for c in getattr(response_obj, "choices", [])
+        ]
if messages := [*input_messages, *output_messages]: set_span_chat_messages(span, messages) if tools := inputs.get("tools"): diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 1572eb81f5..f4fe40738b 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -1,7 +1,7 @@ import os from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast import litellm from litellm._logging import verbose_logger @@ -23,10 +23,10 @@ if TYPE_CHECKING: ) from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth - Span = _Span - SpanExporter = _SpanExporter - UserAPIKeyAuth = _UserAPIKeyAuth - ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload + Span = Union[_Span, Any] + SpanExporter = Union[_SpanExporter, Any] + UserAPIKeyAuth = Union[_UserAPIKeyAuth, Any] + ManagementEndpointLoggingPayload = Union[_ManagementEndpointLoggingPayload, Any] else: Span = Any SpanExporter = Any @@ -46,7 +46,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request" @dataclass class OpenTelemetryConfig: - exporter: Union[str, SpanExporter] = "console" endpoint: Optional[str] = None headers: Optional[str] = None @@ -154,7 +153,6 @@ class OpenTelemetry(CustomLogger): end_time: Optional[Union[datetime, float]] = None, event_metadata: Optional[dict] = None, ): - from opentelemetry import trace from opentelemetry.trace import Status, StatusCode @@ -215,7 +213,6 @@ class OpenTelemetry(CustomLogger): end_time: Optional[Union[float, datetime]] = None, event_metadata: Optional[dict] = None, ): - from opentelemetry import trace from opentelemetry.trace import Status, StatusCode @@ -353,9 +350,9 @@ class OpenTelemetry(CustomLogger): """ from opentelemetry import trace - standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( - kwargs.get("standard_callback_dynamic_params") - ) + standard_callback_dynamic_params: Optional[ + StandardCallbackDynamicParams + ] = kwargs.get("standard_callback_dynamic_params") if not standard_callback_dynamic_params: return @@ -722,7 +719,6 @@ class OpenTelemetry(CustomLogger): span.set_attribute(key, primitive_value) def set_raw_request_attributes(self, span: Span, kwargs, response_obj): - kwargs.get("optional_params", {}) litellm_params = kwargs.get("litellm_params", {}) or {} custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown") @@ -843,12 +839,14 @@ class OpenTelemetry(CustomLogger): headers=dynamic_headers or self.OTEL_HEADERS ) - if isinstance(self.OTEL_EXPORTER, SpanExporter): + if hasattr( + self.OTEL_EXPORTER, "export" + ): # Check if it has the export method that SpanExporter requires verbose_logger.debug( "OpenTelemetry: intiializing SpanExporter. 
Value of OTEL_EXPORTER: %s", self.OTEL_EXPORTER, ) - return SimpleSpanProcessor(self.OTEL_EXPORTER) + return SimpleSpanProcessor(cast(SpanExporter, self.OTEL_EXPORTER)) if self.OTEL_EXPORTER == "console": verbose_logger.debug( @@ -907,7 +905,6 @@ class OpenTelemetry(CustomLogger): logging_payload: ManagementEndpointLoggingPayload, parent_otel_span: Optional[Span] = None, ): - from opentelemetry import trace from opentelemetry.trace import Status, StatusCode @@ -961,7 +958,6 @@ class OpenTelemetry(CustomLogger): logging_payload: ManagementEndpointLoggingPayload, parent_otel_span: Optional[Span] = None, ): - from opentelemetry import trace from opentelemetry.trace import Status, StatusCode diff --git a/litellm/integrations/opik/opik.py b/litellm/integrations/opik/opik.py index 1f7f18f336..8cbfb9e653 100644 --- a/litellm/integrations/opik/opik.py +++ b/litellm/integrations/opik/opik.py @@ -185,7 +185,6 @@ class OpikLogger(CustomBatchLogger): def _create_opik_payload( # noqa: PLR0915 self, kwargs, response_obj, start_time, end_time ) -> List[Dict]: - # Get metadata _litellm_params = kwargs.get("litellm_params", {}) or {} litellm_params_metadata = _litellm_params.get("metadata", {}) or {} diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index d6e47b87ce..5ac8c80eb3 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -988,9 +988,9 @@ class PrometheusLogger(CustomLogger): ): try: verbose_logger.debug("setting remaining tokens requests metric") - standard_logging_payload: Optional[StandardLoggingPayload] = ( - request_kwargs.get("standard_logging_object") - ) + standard_logging_payload: Optional[ + StandardLoggingPayload + ] = request_kwargs.get("standard_logging_object") if standard_logging_payload is None: return diff --git a/litellm/integrations/prompt_management_base.py b/litellm/integrations/prompt_management_base.py index 3fe3b31ed8..07b6720ffe 100644 --- a/litellm/integrations/prompt_management_base.py +++ b/litellm/integrations/prompt_management_base.py @@ -14,7 +14,6 @@ class PromptManagementClient(TypedDict): class PromptManagementBase(ABC): - @property @abstractmethod def integration_name(self) -> str: @@ -83,11 +82,7 @@ class PromptManagementBase(ABC): prompt_id: str, prompt_variables: Optional[dict], dynamic_callback_params: StandardCallbackDynamicParams, - ) -> Tuple[ - str, - List[AllMessageValues], - dict, - ]: + ) -> Tuple[str, List[AllMessageValues], dict,]: if not self.should_run_prompt_management( prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params ): diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py index 4a0c27354f..01b9248e03 100644 --- a/litellm/integrations/s3.py +++ b/litellm/integrations/s3.py @@ -38,7 +38,7 @@ class S3Logger: if litellm.s3_callback_params is not None: # read in .env variables - example os.environ/AWS_BUCKET_NAME for key, value in litellm.s3_callback_params.items(): - if type(value) is str and value.startswith("os.environ/"): + if isinstance(value, str) and value.startswith("os.environ/"): litellm.s3_callback_params[key] = litellm.get_secret(value) # now set s3 params from litellm.s3_logger_params s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name") diff --git a/litellm/integrations/weights_biases.py b/litellm/integrations/weights_biases.py index 5fcbab04b3..63d87c9bd9 100644 --- a/litellm/integrations/weights_biases.py +++ b/litellm/integrations/weights_biases.py @@ -21,11 +21,11 @@ try: # contains a (known) object attribute 
object: Literal["chat.completion", "edit", "text_completion"] - def __getitem__(self, key: K) -> V: ... # noqa + def __getitem__(self, key: K) -> V: + ... # noqa - def get( # noqa - self, key: K, default: Optional[V] = None - ) -> Optional[V]: ... # pragma: no cover + def get(self, key: K, default: Optional[V] = None) -> Optional[V]: # noqa + ... # pragma: no cover class OpenAIRequestResponseResolver: def __call__( diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 2036b93692..275c53ad30 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -10,7 +10,7 @@ from litellm.types.llms.openai import AllMessageValues if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/litellm_core_utils/default_encoding.py b/litellm/litellm_core_utils/default_encoding.py index 05bf78a6a9..93b3132912 100644 --- a/litellm/litellm_core_utils/default_encoding.py +++ b/litellm/litellm_core_utils/default_encoding.py @@ -11,7 +11,9 @@ except (ImportError, AttributeError): # Old way to access resources, which setuptools deprecated some time ago import pkg_resources # type: ignore - filename = pkg_resources.resource_filename(__name__, "litellm_core_utils/tokenizers") + filename = pkg_resources.resource_filename( + __name__, "litellm_core_utils/tokenizers" + ) os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv( "CUSTOM_TIKTOKEN_CACHE_DIR", filename diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 3565c4468c..dcd3ae3a64 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -239,9 +239,9 @@ class Logging(LiteLLMLoggingBaseClass): self.litellm_trace_id = litellm_trace_id self.function_id = function_id self.streaming_chunks: List[Any] = [] # for generating complete stream response - self.sync_streaming_chunks: List[Any] = ( - [] - ) # for generating complete stream response + self.sync_streaming_chunks: List[ + Any + ] = [] # for generating complete stream response self.log_raw_request_response = log_raw_request_response # Initialize dynamic callbacks @@ -452,18 +452,19 @@ class Logging(LiteLLMLoggingBaseClass): prompt_id: str, prompt_variables: Optional[dict], ) -> Tuple[str, List[AllMessageValues], dict]: - custom_logger = self.get_custom_logger_for_prompt_management(model) if custom_logger: - model, messages, non_default_params = ( - custom_logger.get_chat_completion_prompt( - model=model, - messages=messages, - non_default_params=non_default_params, - prompt_id=prompt_id, - prompt_variables=prompt_variables, - dynamic_callback_params=self.standard_callback_dynamic_params, - ) + ( + model, + messages, + non_default_params, + ) = custom_logger.get_chat_completion_prompt( + model=model, + messages=messages, + non_default_params=non_default_params, + prompt_id=prompt_id, + prompt_variables=prompt_variables, + dynamic_callback_params=self.standard_callback_dynamic_params, ) self.messages = messages return model, messages, non_default_params @@ -541,12 +542,11 @@ class Logging(LiteLLMLoggingBaseClass): model ): # if model name was changes pre-call, overwrite the initial model call name with the new one self.model_call_details["model"] = model - self.model_call_details["litellm_params"]["api_base"] = ( - self._get_masked_api_base(additional_args.get("api_base", "")) - ) + self.model_call_details["litellm_params"][ + 
"api_base" + ] = self._get_masked_api_base(additional_args.get("api_base", "")) def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 - # Log the exact input to the LLM API litellm.error_logs["PRE_CALL"] = locals() try: @@ -568,19 +568,16 @@ class Logging(LiteLLMLoggingBaseClass): self.log_raw_request_response is True or log_raw_request_response is True ): - _litellm_params = self.model_call_details.get("litellm_params", {}) _metadata = _litellm_params.get("metadata", {}) or {} try: # [Non-blocking Extra Debug Information in metadata] if turn_off_message_logging is True: - - _metadata["raw_request"] = ( - "redacted by litellm. \ + _metadata[ + "raw_request" + ] = "redacted by litellm. \ 'litellm.turn_off_message_logging=True'" - ) else: - curl_command = self._get_request_curl_command( api_base=additional_args.get("api_base", ""), headers=additional_args.get("headers", {}), @@ -590,33 +587,33 @@ class Logging(LiteLLMLoggingBaseClass): _metadata["raw_request"] = str(curl_command) # split up, so it's easier to parse in the UI - self.model_call_details["raw_request_typed_dict"] = ( - RawRequestTypedDict( - raw_request_api_base=str( - additional_args.get("api_base") or "" - ), - raw_request_body=self._get_raw_request_body( - additional_args.get("complete_input_dict", {}) - ), - raw_request_headers=self._get_masked_headers( - additional_args.get("headers", {}) or {}, - ignore_sensitive_headers=True, - ), - error=None, - ) + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + raw_request_api_base=str( + additional_args.get("api_base") or "" + ), + raw_request_body=self._get_raw_request_body( + additional_args.get("complete_input_dict", {}) + ), + raw_request_headers=self._get_masked_headers( + additional_args.get("headers", {}) or {}, + ignore_sensitive_headers=True, + ), + error=None, ) except Exception as e: - self.model_call_details["raw_request_typed_dict"] = ( - RawRequestTypedDict( - error=str(e), - ) + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + error=str(e), ) traceback.print_exc() - _metadata["raw_request"] = ( - "Unable to Log \ + _metadata[ + "raw_request" + ] = "Unable to Log \ raw request: {}".format( - str(e) - ) + str(e) ) if self.logger_fn and callable(self.logger_fn): try: @@ -941,9 +938,9 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details["response_cost_failure_debug_information"] = ( - debug_info - ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info return None try: @@ -968,9 +965,9 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details["response_cost_failure_debug_information"] = ( - debug_info - ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info return None @@ -995,7 +992,6 @@ class Logging(LiteLLMLoggingBaseClass): def should_run_callback( self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str ) -> bool: - if litellm.global_disable_no_log_param: return True @@ -1027,9 +1023,9 @@ class Logging(LiteLLMLoggingBaseClass): end_time = datetime.datetime.now() if self.completion_start_time is None: self.completion_start_time = end_time - self.model_call_details["completion_start_time"] = ( - self.completion_start_time - ) + self.model_call_details[ + "completion_start_time" + ] = 
self.completion_start_time self.model_call_details["log_event_type"] = "successful_api_call" self.model_call_details["end_time"] = end_time self.model_call_details["cache_hit"] = cache_hit @@ -1083,39 +1079,39 @@ class Logging(LiteLLMLoggingBaseClass): "response_cost" ] else: - self.model_call_details["response_cost"] = ( - self._response_cost_calculator(result=result) - ) + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator(result=result) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) elif isinstance(result, dict): # pass-through endpoints ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) elif standard_logging_object is not None: - self.model_call_details["standard_logging_object"] = ( - standard_logging_object - ) + self.model_call_details[ + "standard_logging_object" + ] = standard_logging_object else: # streaming chunks + image gen. 
self.model_call_details["response_cost"] = None @@ -1154,7 +1150,6 @@ class Logging(LiteLLMLoggingBaseClass): standard_logging_object=kwargs.get("standard_logging_object", None), ) try: - ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response: Optional[ Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse] @@ -1172,23 +1167,23 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( "Logging Details LiteLLM-Success Call streaming complete" ) - self.model_call_details["complete_streaming_response"] = ( - complete_streaming_response - ) - self.model_call_details["response_cost"] = ( - self._response_cost_calculator(result=complete_streaming_response) - ) + self.model_call_details[ + "complete_streaming_response" + ] = complete_streaming_response + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator(result=complete_streaming_response) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_success_callbacks, @@ -1207,7 +1202,6 @@ class Logging(LiteLLMLoggingBaseClass): ## LOGGING HOOK ## for callback in callbacks: if isinstance(callback, CustomLogger): - self.model_call_details, result = callback.logging_hook( kwargs=self.model_call_details, result=result, @@ -1538,10 +1532,10 @@ class Logging(LiteLLMLoggingBaseClass): ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = ( - self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} ) result = self.model_call_details["complete_response"] openMeterLogger.log_success_event( @@ -1581,10 +1575,10 @@ class Logging(LiteLLMLoggingBaseClass): ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = ( - self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} ) result = self.model_call_details["complete_response"] @@ -1659,7 +1653,6 @@ class Logging(LiteLLMLoggingBaseClass): if self.call_type == CallTypes.aretrieve_batch.value and isinstance( result, LiteLLMBatch ): - response_cost, batch_usage, batch_models = await _handle_completed_batch( batch=result, custom_llm_provider=self.custom_llm_provider ) @@ -1692,9 +1685,9 @@ class Logging(LiteLLMLoggingBaseClass): if complete_streaming_response is not None: print_verbose("Async success callbacks: Got a complete streaming response") - self.model_call_details["async_complete_streaming_response"] = ( - complete_streaming_response - ) + self.model_call_details[ + "async_complete_streaming_response" + ] = complete_streaming_response try: if 
self.model_call_details.get("cache_hit", False) is True: self.model_call_details["response_cost"] = 0.0 @@ -1704,10 +1697,10 @@ class Logging(LiteLLMLoggingBaseClass): model_call_details=self.model_call_details ) # base_model defaults to None if not set on model_info - self.model_call_details["response_cost"] = ( - self._response_cost_calculator( - result=complete_streaming_response - ) + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator( + result=complete_streaming_response ) verbose_logger.debug( @@ -1720,16 +1713,16 @@ class Logging(LiteLLMLoggingBaseClass): self.model_call_details["response_cost"] = None ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_async_success_callbacks, @@ -1935,18 +1928,18 @@ class Logging(LiteLLMLoggingBaseClass): ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj={}, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="failure", - error_str=str(exception), - original_exception=exception, - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj={}, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="failure", + error_str=str(exception), + original_exception=exception, + standard_built_in_tools_params=self.standard_built_in_tools_params, ) return start_time, end_time @@ -2084,7 +2077,6 @@ class Logging(LiteLLMLoggingBaseClass): ) is not True ): # custom logger class - callback.log_failure_event( start_time=start_time, end_time=end_time, @@ -2713,9 +2705,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 endpoint=arize_config.endpoint, ) - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - f"space_key={arize_config.space_key},api_key={arize_config.api_key}" - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}" for callback in _in_memory_loggers: if ( isinstance(callback, ArizeLogger) @@ -2739,9 +2731,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 # auth can be disabled on local deployments of arize phoenix if arize_phoenix_config.otlp_auth_headers is not None: - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - arize_phoenix_config.otlp_auth_headers - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = arize_phoenix_config.otlp_auth_headers for callback in _in_memory_loggers: if ( @@ -2832,9 +2824,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 exporter="otlp_http", endpoint="https://langtrace.ai/api/trace", ) - 
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - f"api_key={os.getenv('LANGTRACE_API_KEY')}" - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}" for callback in _in_memory_loggers: if ( isinstance(callback, OpenTelemetry) @@ -3223,7 +3215,6 @@ class StandardLoggingPayloadSetup: custom_llm_provider: Optional[str], init_response_obj: Union[Any, BaseModel, dict], ) -> StandardLoggingModelInformation: - model_cost_name = _select_model_name_for_cost_calc( model=None, completion_response=init_response_obj, # type: ignore @@ -3286,7 +3277,6 @@ class StandardLoggingPayloadSetup: def get_additional_headers( additiona_headers: Optional[dict], ) -> Optional[StandardLoggingAdditionalHeaders]: - if additiona_headers is None: return None @@ -3322,10 +3312,10 @@ class StandardLoggingPayloadSetup: for key in StandardLoggingHiddenParams.__annotations__.keys(): if key in hidden_params: if key == "additional_headers": - clean_hidden_params["additional_headers"] = ( - StandardLoggingPayloadSetup.get_additional_headers( - hidden_params[key] - ) + clean_hidden_params[ + "additional_headers" + ] = StandardLoggingPayloadSetup.get_additional_headers( + hidden_params[key] ) else: clean_hidden_params[key] = hidden_params[key] # type: ignore @@ -3463,12 +3453,14 @@ def get_standard_logging_object_payload( ) # cleanup timestamps - start_time_float, end_time_float, completion_start_time_float = ( - StandardLoggingPayloadSetup.cleanup_timestamps( - start_time=start_time, - end_time=end_time, - completion_start_time=completion_start_time, - ) + ( + start_time_float, + end_time_float, + completion_start_time_float, + ) = StandardLoggingPayloadSetup.cleanup_timestamps( + start_time=start_time, + end_time=end_time, + completion_start_time=completion_start_time, ) response_time = StandardLoggingPayloadSetup.get_response_time( start_time_float=start_time_float, @@ -3495,7 +3487,6 @@ def get_standard_logging_object_payload( saved_cache_cost: float = 0.0 if cache_hit is True: - id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id saved_cache_cost = ( logging_obj._response_cost_calculator( @@ -3658,9 +3649,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): ): for k, v in metadata["user_api_key_metadata"].items(): if k == "logging": # prevent logging user logging keys - cleaned_user_api_key_metadata[k] = ( - "scrubbed_by_litellm_for_sensitive_keys" - ) + cleaned_user_api_key_metadata[ + k + ] = "scrubbed_by_litellm_for_sensitive_keys" else: cleaned_user_api_key_metadata[k] = v diff --git a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py index d33af2a477..f3f4ce6ef4 100644 --- a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py +++ b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py @@ -258,14 +258,12 @@ def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[s class LiteLLMResponseObjectHandler: - @staticmethod def convert_to_image_response( response_object: dict, model_response_object: Optional[ImageResponse] = None, hidden_params: Optional[dict] = None, ) -> ImageResponse: - response_object.update({"hidden_params": hidden_params}) if model_response_object is None: @@ -481,9 +479,9 @@ def convert_to_model_response_object( # noqa: PLR0915 provider_specific_fields["thinking_blocks"] = thinking_blocks if reasoning_content: - 
provider_specific_fields["reasoning_content"] = ( - reasoning_content - ) + provider_specific_fields[ + "reasoning_content" + ] = reasoning_content message = Message( content=content, diff --git a/litellm/litellm_core_utils/model_param_helper.py b/litellm/litellm_core_utils/model_param_helper.py index b7d8fc19d1..c96b4a3f5b 100644 --- a/litellm/litellm_core_utils/model_param_helper.py +++ b/litellm/litellm_core_utils/model_param_helper.py @@ -17,7 +17,6 @@ from litellm.types.rerank import RerankRequest class ModelParamHelper: - @staticmethod def get_standard_logging_model_parameters( model_parameters: dict, diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index c8745f5119..4170d3c1e1 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -257,7 +257,6 @@ def _insert_assistant_continue_message( and message.get("role") == "user" # Current is user and messages[i + 1].get("role") == "user" ): # Next is user - # Insert assistant message continue_message = ( assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index 28e09d7ac8..3c89141b5e 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -1042,10 +1042,10 @@ def convert_to_gemini_tool_call_invoke( if tool_calls is not None: for tool in tool_calls: if "function" in tool: - gemini_function_call: Optional[VertexFunctionCall] = ( - _gemini_tool_call_invoke_helper( - function_call_params=tool["function"] - ) + gemini_function_call: Optional[ + VertexFunctionCall + ] = _gemini_tool_call_invoke_helper( + function_call_params=tool["function"] ) if gemini_function_call is not None: _parts_list.append( @@ -1432,9 +1432,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_content_element["cache_control"] = ( - _content_element["cache_control"] - ) + _anthropic_content_element[ + "cache_control" + ] = _content_element["cache_control"] user_content.append(_anthropic_content_element) elif m.get("type", "") == "text": m = cast(ChatCompletionTextObject, m) @@ -1466,9 +1466,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_content_text_element["cache_control"] = ( - _content_element["cache_control"] - ) + _anthropic_content_text_element[ + "cache_control" + ] = _content_element["cache_control"] user_content.append(_anthropic_content_text_element) @@ -1533,7 +1533,6 @@ def anthropic_messages_pt( # noqa: PLR0915 "content" ] # don't pass empty text blocks. anthropic api raises errors. 
): - _anthropic_text_content_element = AnthropicMessagesTextParam( type="text", text=assistant_content_block["content"], @@ -1569,7 +1568,6 @@ def anthropic_messages_pt( # noqa: PLR0915 msg_i += 1 if assistant_content: - new_messages.append({"role": "assistant", "content": assistant_content}) if msg_i == init_msg_i: # prevent infinite loops @@ -2245,7 +2243,6 @@ class BedrockImageProcessor: @staticmethod async def get_image_details_async(image_url) -> Tuple[str, str]: try: - client = get_async_httpx_client( llm_provider=httpxSpecialProvider.PromptFactory, params={"concurrent_limit": 1}, @@ -2612,7 +2609,6 @@ def get_user_message_block_or_continue_message( for item in modified_content_block: # Check if the list is empty if item["type"] == "text": - if not item["text"].strip(): # Replace empty text with continue message _user_continue_message = ChatCompletionUserMessage( @@ -3207,7 +3203,6 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915 assistant_content: List[BedrockContentBlock] = [] ## MERGE CONSECUTIVE ASSISTANT CONTENT ## while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": - assistant_message_block = get_assistant_message_block_or_continue_message( message=messages[msg_i], assistant_continue_message=assistant_continue_message, @@ -3410,7 +3405,6 @@ def response_schema_prompt(model: str, response_schema: dict) -> str: {"role": "user", "content": "{}".format(response_schema)} ] if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict: - custom_prompt_details = litellm.custom_prompt_dict[ f"{model}/response_schema_prompt" ] # allow user to define custom response schema prompt by model diff --git a/litellm/litellm_core_utils/realtime_streaming.py b/litellm/litellm_core_utils/realtime_streaming.py index aebd049692..e84c720441 100644 --- a/litellm/litellm_core_utils/realtime_streaming.py +++ b/litellm/litellm_core_utils/realtime_streaming.py @@ -122,7 +122,6 @@ class RealTimeStreaming: pass async def bidirectional_forward(self): - forward_task = asyncio.create_task(self.backend_to_client_send_messages()) try: await self.client_ack_messages() diff --git a/litellm/litellm_core_utils/redact_messages.py b/litellm/litellm_core_utils/redact_messages.py index 50e0e0b575..a62031a9c9 100644 --- a/litellm/litellm_core_utils/redact_messages.py +++ b/litellm/litellm_core_utils/redact_messages.py @@ -135,9 +135,9 @@ def _get_turn_off_message_logging_from_dynamic_params( handles boolean and string values of `turn_off_message_logging` """ - standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( - model_call_details.get("standard_callback_dynamic_params", None) - ) + standard_callback_dynamic_params: Optional[ + StandardCallbackDynamicParams + ] = model_call_details.get("standard_callback_dynamic_params", None) if standard_callback_dynamic_params: _turn_off_message_logging = standard_callback_dynamic_params.get( "turn_off_message_logging" diff --git a/litellm/litellm_core_utils/sensitive_data_masker.py b/litellm/litellm_core_utils/sensitive_data_masker.py index 7800e5304f..23b9ec32fc 100644 --- a/litellm/litellm_core_utils/sensitive_data_masker.py +++ b/litellm/litellm_core_utils/sensitive_data_masker.py @@ -1,4 +1,5 @@ from typing import Any, Dict, Optional, Set + from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH @@ -40,7 +41,10 @@ class SensitiveDataMasker: return result def mask_dict( - self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH + self, + data: Dict[str, Any], + depth: int = 0, + max_depth: int = 
DEFAULT_MAX_RECURSE_DEPTH, ) -> Dict[str, Any]: if depth >= max_depth: return data diff --git a/litellm/litellm_core_utils/streaming_chunk_builder_utils.py b/litellm/litellm_core_utils/streaming_chunk_builder_utils.py index 7a5ee3e41e..1ca2bfe45e 100644 --- a/litellm/litellm_core_utils/streaming_chunk_builder_utils.py +++ b/litellm/litellm_core_utils/streaming_chunk_builder_utils.py @@ -104,7 +104,6 @@ class ChunkProcessor: def get_combined_tool_content( self, tool_call_chunks: List[Dict[str, Any]] ) -> List[ChatCompletionMessageToolCall]: - argument_list: List[str] = [] delta = tool_call_chunks[0]["choices"][0]["delta"] id = None diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index a11e5af12b..42106135cc 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -84,9 +84,9 @@ class CustomStreamWrapper: self.system_fingerprint: Optional[str] = None self.received_finish_reason: Optional[str] = None - self.intermittent_finish_reason: Optional[str] = ( - None # finish reasons that show up mid-stream - ) + self.intermittent_finish_reason: Optional[ + str + ] = None # finish reasons that show up mid-stream self.special_tokens = [ "<|assistant|>", "<|system|>", @@ -814,7 +814,6 @@ class CustomStreamWrapper: model_response: ModelResponseStream, response_obj: Dict[str, Any], ): - print_verbose( f"completion_obj: {completion_obj}, model_response.choices[0]: {model_response.choices[0]}, response_obj: {response_obj}" ) @@ -1008,7 +1007,6 @@ class CustomStreamWrapper: self.custom_llm_provider and self.custom_llm_provider in litellm._custom_providers ): - if self.received_finish_reason is not None: if "provider_specific_fields" not in chunk: raise StopIteration @@ -1379,9 +1377,9 @@ class CustomStreamWrapper: _json_delta = delta.model_dump() print_verbose(f"_json_delta: {_json_delta}") if "role" not in _json_delta or _json_delta["role"] is None: - _json_delta["role"] = ( - "assistant" # mistral's api returns role as None - ) + _json_delta[ + "role" + ] = "assistant" # mistral's api returns role as None if "tool_calls" in _json_delta and isinstance( _json_delta["tool_calls"], list ): @@ -1758,9 +1756,9 @@ class CustomStreamWrapper: chunk = next(self.completion_stream) if chunk is not None and chunk != b"": print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") - processed_chunk: Optional[ModelResponseStream] = ( - self.chunk_creator(chunk=chunk) - ) + processed_chunk: Optional[ + ModelResponseStream + ] = self.chunk_creator(chunk=chunk) print_verbose( f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}" ) diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index f2c5f390d7..7625292e6e 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -290,7 +290,6 @@ class AnthropicChatCompletion(BaseLLM): headers={}, client=None, ): - optional_params = copy.deepcopy(optional_params) stream = optional_params.pop("stream", None) json_mode: bool = optional_params.pop("json_mode", False) @@ -491,7 +490,6 @@ class ModelResponseIterator: def _handle_usage( self, anthropic_usage_chunk: Union[dict, UsageDelta] ) -> AnthropicChatCompletionUsageBlock: - usage_block = AnthropicChatCompletionUsageBlock( prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0), completion_tokens=anthropic_usage_chunk.get("output_tokens", 0), @@ -515,7 +513,9 @@ class ModelResponseIterator: return usage_block - def 
_content_block_delta_helper(self, chunk: dict) -> Tuple[ + def _content_block_delta_helper( + self, chunk: dict + ) -> Tuple[ str, Optional[ChatCompletionToolCallChunk], List[ChatCompletionThinkingBlock], @@ -592,9 +592,12 @@ class ModelResponseIterator: Anthropic content chunk chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}} """ - text, tool_use, thinking_blocks, provider_specific_fields = ( - self._content_block_delta_helper(chunk=chunk) - ) + ( + text, + tool_use, + thinking_blocks, + provider_specific_fields, + ) = self._content_block_delta_helper(chunk=chunk) if thinking_blocks: reasoning_content = self._handle_reasoning_content( thinking_blocks=thinking_blocks @@ -620,7 +623,6 @@ class ModelResponseIterator: "index": self.tool_index, } elif type_chunk == "content_block_stop": - ContentBlockStop(**chunk) # type: ignore # check if tool call content block is_empty = self.check_empty_tool_call_args() diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index dcbc6775dc..a8f36cdcad 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -49,9 +49,9 @@ class AnthropicConfig(BaseConfig): to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} """ - max_tokens: Optional[int] = ( - 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default) - ) + max_tokens: Optional[ + int + ] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default) stop_sequences: Optional[list] = None temperature: Optional[int] = None top_p: Optional[int] = None @@ -104,7 +104,6 @@ class AnthropicConfig(BaseConfig): def get_json_schema_from_pydantic_object( self, response_format: Union[Any, Dict, None] ) -> Optional[dict]: - return type_to_response_format_param( response_format, ref_template="/$defs/{model}" ) # Relevant issue: https://github.com/BerriAI/litellm/issues/7755 @@ -125,7 +124,6 @@ class AnthropicConfig(BaseConfig): is_vertex_request: bool = False, user_anthropic_beta_headers: Optional[List[str]] = None, ) -> dict: - betas = set() if prompt_caching_set: betas.add("prompt-caching-2024-07-31") @@ -300,7 +298,6 @@ class AnthropicConfig(BaseConfig): model: str, drop_params: bool, ) -> dict: - is_thinking_enabled = self.is_thinking_enabled( non_default_params=non_default_params ) @@ -321,11 +318,11 @@ class AnthropicConfig(BaseConfig): optional_params=optional_params, tools=tool_value ) if param == "tool_choice" or param == "parallel_tool_calls": - _tool_choice: Optional[AnthropicMessagesToolChoice] = ( - self._map_tool_choice( - tool_choice=non_default_params.get("tool_choice"), - parallel_tool_use=non_default_params.get("parallel_tool_calls"), - ) + _tool_choice: Optional[ + AnthropicMessagesToolChoice + ] = self._map_tool_choice( + tool_choice=non_default_params.get("tool_choice"), + parallel_tool_use=non_default_params.get("parallel_tool_calls"), ) if _tool_choice is not None: @@ -341,7 +338,6 @@ class AnthropicConfig(BaseConfig): if param == "top_p": optional_params["top_p"] = value if param == "response_format" and isinstance(value, dict): - ignore_response_format_types = ["text"] if value["type"] in ignore_response_format_types: # value is a no-op continue @@ -470,9 +466,9 @@ class AnthropicConfig(BaseConfig): text=system_message_block["content"], ) if "cache_control" in system_message_block: - anthropic_system_message_content["cache_control"] = ( - 
system_message_block["cache_control"] - ) + anthropic_system_message_content[ + "cache_control" + ] = system_message_block["cache_control"] anthropic_system_message_list.append( anthropic_system_message_content ) @@ -486,9 +482,9 @@ class AnthropicConfig(BaseConfig): ) ) if "cache_control" in _content: - anthropic_system_message_content["cache_control"] = ( - _content["cache_control"] - ) + anthropic_system_message_content[ + "cache_control" + ] = _content["cache_control"] anthropic_system_message_list.append( anthropic_system_message_content @@ -597,7 +593,9 @@ class AnthropicConfig(BaseConfig): ) return _message - def extract_response_content(self, completion_response: dict) -> Tuple[ + def extract_response_content( + self, completion_response: dict + ) -> Tuple[ str, Optional[List[Any]], Optional[List[ChatCompletionThinkingBlock]], @@ -693,9 +691,13 @@ class AnthropicConfig(BaseConfig): reasoning_content: Optional[str] = None tool_calls: List[ChatCompletionToolCallChunk] = [] - text_content, citations, thinking_blocks, reasoning_content, tool_calls = ( - self.extract_response_content(completion_response=completion_response) - ) + ( + text_content, + citations, + thinking_blocks, + reasoning_content, + tool_calls, + ) = self.extract_response_content(completion_response=completion_response) _message = litellm.Message( tool_calls=tool_calls, diff --git a/litellm/llms/anthropic/completion/transformation.py b/litellm/llms/anthropic/completion/transformation.py index 7a260b6f94..5cbc0b5fd8 100644 --- a/litellm/llms/anthropic/completion/transformation.py +++ b/litellm/llms/anthropic/completion/transformation.py @@ -54,9 +54,9 @@ class AnthropicTextConfig(BaseConfig): to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} """ - max_tokens_to_sample: Optional[int] = ( - litellm.max_tokens - ) # anthropic requires a default + max_tokens_to_sample: Optional[ + int + ] = litellm.max_tokens # anthropic requires a default stop_sequences: Optional[list] = None temperature: Optional[int] = None top_p: Optional[int] = None diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py index a7dfff74d9..099a2acdae 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py @@ -25,7 +25,6 @@ from litellm.utils import ProviderConfigManager, client class AnthropicMessagesHandler: - @staticmethod async def _handle_anthropic_streaming( response: httpx.Response, @@ -74,19 +73,22 @@ async def anthropic_messages( """ # Use provided client or create a new one optional_params = GenericLiteLLMParams(**kwargs) - model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = ( - litellm.get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=optional_params.api_base, - api_key=optional_params.api_key, - ) + ( + model, + _custom_llm_provider, + dynamic_api_key, + dynamic_api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=optional_params.api_base, + api_key=optional_params.api_key, ) - anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = ( - ProviderConfigManager.get_provider_anthropic_messages_config( - model=model, - provider=litellm.LlmProviders(_custom_llm_provider), - ) + anthropic_messages_provider_config: Optional[ + BaseAnthropicMessagesConfig + ] = 
ProviderConfigManager.get_provider_anthropic_messages_config( + model=model, + provider=litellm.LlmProviders(_custom_llm_provider), ) if anthropic_messages_provider_config is None: raise ValueError( diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 03c5cc09eb..aed813fdab 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -654,7 +654,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): ) -> EmbeddingResponse: response = None try: - openai_aclient = self.get_azure_openai_client( api_version=api_version, api_base=api_base, @@ -835,7 +834,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): "2023-10-01-preview", ] ): # CREATE + POLL for azure dall-e-2 calls - api_base = modify_url( original_url=api_base, new_path="/openai/images/generations:submit" ) @@ -867,7 +865,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): ) while response.json()["status"] not in ["succeeded", "failed"]: if time.time() - start_time > timeout_secs: - raise AzureOpenAIError( status_code=408, message="Operation polling timed out." ) @@ -935,7 +932,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): "2023-10-01-preview", ] ): # CREATE + POLL for azure dall-e-2 calls - api_base = modify_url( original_url=api_base, new_path="/openai/images/generations:submit" ) @@ -1199,7 +1195,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): client=None, litellm_params: Optional[dict] = None, ) -> HttpxBinaryResponseContent: - max_retries = optional_params.pop("max_retries", 2) if aspeech is not None and aspeech is True: @@ -1253,7 +1248,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): client=None, litellm_params: Optional[dict] = None, ) -> HttpxBinaryResponseContent: - azure_client: AsyncAzureOpenAI = self.get_azure_openai_client( api_base=api_base, api_version=api_version, diff --git a/litellm/llms/azure/batches/handler.py b/litellm/llms/azure/batches/handler.py index 1b93c526d5..7fc6388ba8 100644 --- a/litellm/llms/azure/batches/handler.py +++ b/litellm/llms/azure/batches/handler.py @@ -50,15 +50,15 @@ class AzureBatchesAPI(BaseAzureLLM): client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, litellm_params: Optional[dict] = None, ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]: - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - litellm_params=litellm_params or {}, - ) + azure_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( @@ -96,15 +96,15 @@ class AzureBatchesAPI(BaseAzureLLM): client: Optional[AzureOpenAI] = None, litellm_params: Optional[dict] = None, ): - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - litellm_params=litellm_params or {}, - ) + azure_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( @@ -144,15 +144,15 @@ class AzureBatchesAPI(BaseAzureLLM): client: 
Optional[AzureOpenAI] = None, litellm_params: Optional[dict] = None, ): - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - litellm_params=litellm_params or {}, - ) + azure_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( @@ -183,15 +183,15 @@ class AzureBatchesAPI(BaseAzureLLM): client: Optional[AzureOpenAI] = None, litellm_params: Optional[dict] = None, ): - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - litellm_params=litellm_params or {}, - ) + azure_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index 71092c8b99..5d61557c21 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -306,7 +306,6 @@ class BaseAzureLLM(BaseOpenAILLM): api_version: Optional[str], is_async: bool, ) -> dict: - azure_ad_token_provider: Optional[Callable[[], str]] = None # If we have api_key, then we have higher priority azure_ad_token = litellm_params.get("azure_ad_token") diff --git a/litellm/llms/azure/files/handler.py b/litellm/llms/azure/files/handler.py index d45ac9a315..98407d05d5 100644 --- a/litellm/llms/azure/files/handler.py +++ b/litellm/llms/azure/files/handler.py @@ -46,16 +46,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM): client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, litellm_params: Optional[dict] = None, ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]: - - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - litellm_params=litellm_params or {}, - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + openai_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, ) if openai_client is None: raise ValueError( @@ -95,15 +94,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM): ) -> Union[ HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent] ]: - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - litellm_params=litellm_params or {}, - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + openai_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, ) if openai_client is None: raise ValueError( @@ -145,15 +144,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM): client: Optional[Union[AzureOpenAI, 
AsyncAzureOpenAI]] = None, litellm_params: Optional[dict] = None, ): - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - litellm_params=litellm_params or {}, - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + openai_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, ) if openai_client is None: raise ValueError( @@ -197,15 +196,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM): client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, litellm_params: Optional[dict] = None, ): - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - litellm_params=litellm_params or {}, - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + openai_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, ) if openai_client is None: raise ValueError( @@ -251,15 +250,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM): client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, litellm_params: Optional[dict] = None, ): - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - litellm_params=litellm_params or {}, - api_key=api_key, - api_base=api_base, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + openai_client: Optional[ + Union[AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, ) if openai_client is None: raise ValueError( diff --git a/litellm/llms/azure/fine_tuning/handler.py b/litellm/llms/azure/fine_tuning/handler.py index 3d7cc336fb..429b834989 100644 --- a/litellm/llms/azure/fine_tuning/handler.py +++ b/litellm/llms/azure/fine_tuning/handler.py @@ -25,14 +25,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM): _is_async: bool = False, api_version: Optional[str] = None, litellm_params: Optional[dict] = None, - ) -> Optional[ - Union[ - OpenAI, - AsyncOpenAI, - AzureOpenAI, - AsyncAzureOpenAI, - ] - ]: + ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]: # Override to use Azure-specific client initialization if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI): client = None diff --git a/litellm/llms/azure_ai/chat/transformation.py b/litellm/llms/azure_ai/chat/transformation.py index 154f345537..a1fd24efa1 100644 --- a/litellm/llms/azure_ai/chat/transformation.py +++ b/litellm/llms/azure_ai/chat/transformation.py @@ -145,7 +145,6 @@ class AzureAIStudioConfig(OpenAIConfig): 2. 
If message contains an image or audio, send as is (user-intended) """ for message in messages: - # Do nothing if the message contains an image or audio if _audio_or_image_in_message_content(message): continue diff --git a/litellm/llms/azure_ai/embed/cohere_transformation.py b/litellm/llms/azure_ai/embed/cohere_transformation.py index 38b0dbbe23..64433c21b6 100644 --- a/litellm/llms/azure_ai/embed/cohere_transformation.py +++ b/litellm/llms/azure_ai/embed/cohere_transformation.py @@ -22,7 +22,6 @@ class AzureAICohereConfig: pass def _map_azure_model_group(self, model: str) -> str: - if model == "offer-cohere-embed-multili-paygo": return "Cohere-embed-v3-multilingual" elif model == "offer-cohere-embed-english-paygo": diff --git a/litellm/llms/azure_ai/embed/handler.py b/litellm/llms/azure_ai/embed/handler.py index f33c979ca2..da39c5f3b8 100644 --- a/litellm/llms/azure_ai/embed/handler.py +++ b/litellm/llms/azure_ai/embed/handler.py @@ -17,7 +17,6 @@ from .cohere_transformation import AzureAICohereConfig class AzureAIEmbedding(OpenAIChatCompletion): - def _process_response( self, image_embedding_responses: Optional[List], @@ -145,7 +144,6 @@ class AzureAIEmbedding(OpenAIChatCompletion): api_base: Optional[str] = None, client=None, ) -> EmbeddingResponse: - ( image_embeddings_request, v1_embeddings_request, diff --git a/litellm/llms/azure_ai/rerank/transformation.py b/litellm/llms/azure_ai/rerank/transformation.py index 842511f30d..4465e0d70a 100644 --- a/litellm/llms/azure_ai/rerank/transformation.py +++ b/litellm/llms/azure_ai/rerank/transformation.py @@ -17,6 +17,7 @@ class AzureAIRerankConfig(CohereRerankConfig): """ Azure AI Rerank - Follows the same Spec as Cohere Rerank """ + def get_complete_url(self, api_base: Optional[str], model: str) -> str: if api_base is None: raise ValueError( diff --git a/litellm/llms/base.py b/litellm/llms/base.py index deced222ca..abc314bba0 100644 --- a/litellm/llms/base.py +++ b/litellm/llms/base.py @@ -9,7 +9,6 @@ from litellm.types.utils import ModelResponse, TextCompletionResponse class BaseLLM: - _client_session: Optional[httpx.Client] = None def process_response( diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py index 45ea06b9e4..b4b120776c 100644 --- a/litellm/llms/base_llm/chat/transformation.py +++ b/litellm/llms/base_llm/chat/transformation.py @@ -218,7 +218,6 @@ class BaseConfig(ABC): json_schema = value["json_schema"]["schema"] if json_schema and not is_response_format_supported: - _tool_choice = ChatCompletionToolChoiceObjectParam( type="function", function=ChatCompletionToolChoiceFunctionParam( diff --git a/litellm/llms/base_llm/responses/transformation.py b/litellm/llms/base_llm/responses/transformation.py index 29555c55da..e98a579845 100644 --- a/litellm/llms/base_llm/responses/transformation.py +++ b/litellm/llms/base_llm/responses/transformation.py @@ -58,7 +58,6 @@ class BaseResponsesAPIConfig(ABC): model: str, drop_params: bool, ) -> Dict: - pass @abstractmethod diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py index a4230177b5..7f529c637a 100644 --- a/litellm/llms/bedrock/chat/converse_handler.py +++ b/litellm/llms/bedrock/chat/converse_handler.py @@ -81,7 +81,6 @@ def make_sync_call( class BedrockConverseLLM(BaseAWSLLM): - def __init__(self) -> None: super().__init__() @@ -114,7 +113,6 @@ class BedrockConverseLLM(BaseAWSLLM): fake_stream: bool = False, json_mode: Optional[bool] = False, ) -> CustomStreamWrapper: - request_data 
= await litellm.AmazonConverseConfig()._async_transform_request( model=model, messages=messages, @@ -179,7 +177,6 @@ class BedrockConverseLLM(BaseAWSLLM): headers: dict = {}, client: Optional[AsyncHTTPHandler] = None, ) -> Union[ModelResponse, CustomStreamWrapper]: - request_data = await litellm.AmazonConverseConfig()._async_transform_request( model=model, messages=messages, @@ -265,7 +262,6 @@ class BedrockConverseLLM(BaseAWSLLM): extra_headers: Optional[dict] = None, client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, ): - ## SETUP ## stream = optional_params.pop("stream", None) unencoded_model_id = optional_params.pop("model_id", None) @@ -301,9 +297,9 @@ class BedrockConverseLLM(BaseAWSLLM): aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) optional_params.pop("aws_region_name", None) - litellm_params["aws_region_name"] = ( - aws_region_name # [DO NOT DELETE] important for async calls - ) + litellm_params[ + "aws_region_name" + ] = aws_region_name # [DO NOT DELETE] important for async calls credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index ced9c469b3..05386c62b5 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -223,7 +223,6 @@ class AmazonConverseConfig(BaseConfig): ) for param, value in non_default_params.items(): if param == "response_format" and isinstance(value, dict): - ignore_response_format_types = ["text"] if value["type"] in ignore_response_format_types: # value is a no-op continue @@ -715,9 +714,9 @@ class AmazonConverseConfig(BaseConfig): chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} content_str = "" tools: List[ChatCompletionToolCallChunk] = [] - reasoningContentBlocks: Optional[List[BedrockConverseReasoningContentBlock]] = ( - None - ) + reasoningContentBlocks: Optional[ + List[BedrockConverseReasoningContentBlock] + ] = None if message is not None: for idx, content in enumerate(message["content"]): @@ -727,7 +726,6 @@ class AmazonConverseConfig(BaseConfig): if "text" in content: content_str += content["text"] if "toolUse" in content: - ## check tool name was formatted by litellm _response_tool_name = content["toolUse"]["name"] response_tool_name = get_bedrock_tool_name( @@ -754,12 +752,12 @@ class AmazonConverseConfig(BaseConfig): chat_completion_message["provider_specific_fields"] = { "reasoningContentBlocks": reasoningContentBlocks, } - chat_completion_message["reasoning_content"] = ( - self._transform_reasoning_content(reasoningContentBlocks) - ) - chat_completion_message["thinking_blocks"] = ( - self._transform_thinking_blocks(reasoningContentBlocks) - ) + chat_completion_message[ + "reasoning_content" + ] = self._transform_reasoning_content(reasoningContentBlocks) + chat_completion_message[ + "thinking_blocks" + ] = self._transform_thinking_blocks(reasoningContentBlocks) chat_completion_message["content"] = content_str if json_mode is True and tools is not None and len(tools) == 1: # to support 'json_schema' logic on bedrock models diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py index 5b02fd3158..09bdd63572 100644 --- a/litellm/llms/bedrock/chat/invoke_handler.py +++ b/litellm/llms/bedrock/chat/invoke_handler.py @@ -496,9 +496,9 @@ class BedrockLLM(BaseAWSLLM): content=None, ) model_response.choices[0].message = 
_message # type: ignore - model_response._hidden_params["original_response"] = ( - outputText # allow user to access raw anthropic tool calling response - ) + model_response._hidden_params[ + "original_response" + ] = outputText # allow user to access raw anthropic tool calling response if ( _is_function_call is True and stream is not None @@ -806,9 +806,9 @@ class BedrockLLM(BaseAWSLLM): ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v if stream is True: - inference_params["stream"] = ( - True # cohere requires stream = True in inference params - ) + inference_params[ + "stream" + ] = True # cohere requires stream = True in inference params data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "anthropic": if model.startswith("anthropic.claude-3"): @@ -1205,7 +1205,6 @@ class BedrockLLM(BaseAWSLLM): def get_response_stream_shape(): global _response_stream_shape_cache if _response_stream_shape_cache is None: - from botocore.loaders import Loader from botocore.model import ServiceModel @@ -1539,7 +1538,6 @@ class AmazonDeepSeekR1StreamDecoder(AWSEventStreamDecoder): model: str, sync_stream: bool, ) -> None: - super().__init__(model=model) from litellm.llms.bedrock.chat.invoke_transformations.amazon_deepseek_transformation import ( AmazonDeepseekR1ResponseIterator, diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py index 133eb659df..d1212705d8 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py @@ -225,9 +225,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v if stream is True: - inference_params["stream"] = ( - True # cohere requires stream = True in inference params - ) + inference_params[ + "stream" + ] = True # cohere requires stream = True in inference params request_data = {"prompt": prompt, **inference_params} elif provider == "anthropic": return litellm.AmazonAnthropicClaude3Config().transform_request( @@ -311,7 +311,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: - try: completion_response = raw_response.json() except Exception: diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index 4677a579ed..f4a1170660 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -314,7 +314,6 @@ def get_bedrock_tool_name(response_tool_name: str) -> str: class BedrockModelInfo(BaseLLMModelInfo): - global_config = AmazonBedrockGlobalConfig() all_global_regions = global_config.get_all_regions() diff --git a/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py b/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py index 6c1147f24a..338029adc3 100644 --- a/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py +++ b/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py @@ -33,9 +33,9 @@ class AmazonTitanMultimodalEmbeddingG1Config: ) -> dict: for k, v in non_default_params.items(): if k == "dimensions": - optional_params["embeddingConfig"] = ( - AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v) - ) + 
optional_params[ + "embeddingConfig" + ] = AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v) return optional_params def _transform_request( @@ -58,7 +58,6 @@ class AmazonTitanMultimodalEmbeddingG1Config: def _transform_response( self, response_list: List[dict], model: str ) -> EmbeddingResponse: - total_prompt_tokens = 0 transformed_responses: List[Embedding] = [] for index, response in enumerate(response_list): diff --git a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py index 33249f9af8..b331dd1b1d 100644 --- a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py +++ b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py @@ -1,12 +1,16 @@ import types -from typing import List, Optional +from typing import Any, Dict, List, Optional from openai.types.image import Image from litellm.types.llms.bedrock import ( - AmazonNovaCanvasTextToImageRequest, AmazonNovaCanvasTextToImageResponse, - AmazonNovaCanvasTextToImageParams, AmazonNovaCanvasRequestBase, AmazonNovaCanvasColorGuidedGenerationParams, + AmazonNovaCanvasColorGuidedGenerationParams, AmazonNovaCanvasColorGuidedRequest, + AmazonNovaCanvasImageGenerationConfig, + AmazonNovaCanvasRequestBase, + AmazonNovaCanvasTextToImageParams, + AmazonNovaCanvasTextToImageRequest, + AmazonNovaCanvasTextToImageResponse, ) from litellm.types.utils import ImageResponse @@ -23,7 +27,7 @@ class AmazonNovaCanvasConfig: k: v for k, v in cls.__dict__.items() if not k.startswith("__") - and not isinstance( + and not isinstance( v, ( types.FunctionType, @@ -32,13 +36,12 @@ class AmazonNovaCanvasConfig: staticmethod, ), ) - and v is not None + and v is not None } @classmethod def get_supported_openai_params(cls, model: Optional[str] = None) -> List: - """ - """ + """ """ return ["n", "size", "quality"] @classmethod @@ -56,7 +59,7 @@ class AmazonNovaCanvasConfig: @classmethod def transform_request_body( - cls, text: str, optional_params: dict + cls, text: str, optional_params: dict ) -> AmazonNovaCanvasRequestBase: """ Transform the request body for Amazon Nova Canvas model @@ -65,18 +68,64 @@ class AmazonNovaCanvasConfig: image_generation_config = optional_params.pop("imageGenerationConfig", {}) image_generation_config = {**image_generation_config, **optional_params} if task_type == "TEXT_IMAGE": - text_to_image_params = image_generation_config.pop("textToImageParams", {}) - text_to_image_params = {"text" :text, **text_to_image_params} - text_to_image_params = AmazonNovaCanvasTextToImageParams(**text_to_image_params) - return AmazonNovaCanvasTextToImageRequest(textToImageParams=text_to_image_params, taskType=task_type, - imageGenerationConfig=image_generation_config) + text_to_image_params: Dict[str, Any] = image_generation_config.pop( + "textToImageParams", {} + ) + text_to_image_params = {"text": text, **text_to_image_params} + try: + text_to_image_params_typed = AmazonNovaCanvasTextToImageParams( + **text_to_image_params # type: ignore + ) + except Exception as e: + raise ValueError( + f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}" + ) + + try: + image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig( + **image_generation_config + ) + except Exception as e: + raise ValueError( + f"Error transforming image generation config: {e}. 
Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}" + ) + + return AmazonNovaCanvasTextToImageRequest( + textToImageParams=text_to_image_params_typed, + taskType=task_type, + imageGenerationConfig=image_generation_config_typed, + ) if task_type == "COLOR_GUIDED_GENERATION": - color_guided_generation_params = image_generation_config.pop("colorGuidedGenerationParams", {}) - color_guided_generation_params = {"text": text, **color_guided_generation_params} - color_guided_generation_params = AmazonNovaCanvasColorGuidedGenerationParams(**color_guided_generation_params) - return AmazonNovaCanvasColorGuidedRequest(taskType=task_type, - colorGuidedGenerationParams=color_guided_generation_params, - imageGenerationConfig=image_generation_config) + color_guided_generation_params: Dict[ + str, Any + ] = image_generation_config.pop("colorGuidedGenerationParams", {}) + color_guided_generation_params = { + "text": text, + **color_guided_generation_params, + } + try: + color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams( + **color_guided_generation_params # type: ignore + ) + except Exception as e: + raise ValueError( + f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}" + ) + + try: + image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig( + **image_generation_config + ) + except Exception as e: + raise ValueError( + f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}" + ) + + return AmazonNovaCanvasColorGuidedRequest( + taskType=task_type, + colorGuidedGenerationParams=color_guided_generation_params_typed, + imageGenerationConfig=image_generation_config_typed, + ) raise NotImplementedError(f"Task type {task_type} is not supported") @classmethod @@ -87,7 +136,9 @@ class AmazonNovaCanvasConfig: _size = non_default_params.get("size") if _size is not None: width, height = _size.split("x") - optional_params["width"], optional_params["height"] = int(width), int(height) + optional_params["width"], optional_params["height"] = int(width), int( + height + ) if non_default_params.get("n") is not None: optional_params["numberOfImages"] = non_default_params.get("n") if non_default_params.get("quality") is not None: @@ -99,7 +150,7 @@ class AmazonNovaCanvasConfig: @classmethod def transform_response_dict_to_openai_response( - cls, model_response: ImageResponse, response_dict: dict + cls, model_response: ImageResponse, response_dict: dict ) -> ImageResponse: """ Transform the response dict to the OpenAI response diff --git a/litellm/llms/bedrock/image/image_handler.py b/litellm/llms/bedrock/image/image_handler.py index 8f7762e547..27258aa20f 100644 --- a/litellm/llms/bedrock/image/image_handler.py +++ b/litellm/llms/bedrock/image/image_handler.py @@ -267,7 +267,11 @@ class BedrockImageGeneration(BaseAWSLLM): **inference_params, } elif provider == "amazon": - return dict(litellm.AmazonNovaCanvasConfig.transform_request_body(text=prompt, optional_params=optional_params)) + return dict( + litellm.AmazonNovaCanvasConfig.transform_request_body( + text=prompt, optional_params=optional_params + ) + ) else: raise BedrockError( status_code=422, message=f"Unsupported model={model}, passed in" @@ -303,8 +307,11 @@ class BedrockImageGeneration(BaseAWSLLM): config_class = ( 
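
[Editor's sketch, not part of the patch.] The amazon_nova_canvas_transformation.py hunk replaces ad-hoc dict merging with typed request construction wrapped in try/except, so a malformed textToImageParams or imageGenerationConfig surfaces as a ValueError that names the offending params and the expected annotations. A rough sketch of that pattern, modeled with a dataclass so bad keys actually raise (the real code uses litellm's Bedrock TypedDicts; field names below are an illustrative subset):

from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class TextToImageParams:
    # Illustrative subset of the real request fields.
    text: str
    negativeText: Optional[str] = None


def build_text_to_image_params(
    text: str, optional_params: Dict[str, Any]
) -> TextToImageParams:
    raw: Dict[str, Any] = optional_params.pop("textToImageParams", {})
    raw = {"text": text, **raw}
    try:
        return TextToImageParams(**raw)
    except Exception as e:  # unexpected or missing keys raise TypeError here
        raise ValueError(
            f"Error transforming text to image params: {e}. "
            f"Got params: {raw}, "
            f"Expected params: {TextToImageParams.__annotations__}"
        )


# A well-formed request builds; an unknown key would trigger the ValueError above.
params = build_text_to_image_params(
    "a lighthouse at dusk", {"textToImageParams": {"negativeText": "blurry"}}
)
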
litellm.AmazonStability3Config if litellm.AmazonStability3Config._is_stability_3_model(model=model) - else litellm.AmazonNovaCanvasConfig if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model) - else litellm.AmazonStabilityConfig + else ( + litellm.AmazonNovaCanvasConfig + if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model) + else litellm.AmazonStabilityConfig + ) ) config_class.transform_response_dict_to_openai_response( model_response=model_response, diff --git a/litellm/llms/bedrock/rerank/handler.py b/litellm/llms/bedrock/rerank/handler.py index cd8be6912c..f5a532bec1 100644 --- a/litellm/llms/bedrock/rerank/handler.py +++ b/litellm/llms/bedrock/rerank/handler.py @@ -60,7 +60,6 @@ class BedrockRerankHandler(BaseAWSLLM): extra_headers: Optional[dict] = None, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: - request_data = RerankRequest( model=model, query=query, diff --git a/litellm/llms/bedrock/rerank/transformation.py b/litellm/llms/bedrock/rerank/transformation.py index a5380febe9..be8250a967 100644 --- a/litellm/llms/bedrock/rerank/transformation.py +++ b/litellm/llms/bedrock/rerank/transformation.py @@ -29,7 +29,6 @@ from litellm.types.rerank import ( class BedrockRerankConfig: - def _transform_sources( self, documents: List[Union[str, dict]] ) -> List[BedrockRerankSource]: diff --git a/litellm/llms/codestral/completion/handler.py b/litellm/llms/codestral/completion/handler.py index fc6d2886a9..555f7fccfb 100644 --- a/litellm/llms/codestral/completion/handler.py +++ b/litellm/llms/codestral/completion/handler.py @@ -314,7 +314,6 @@ class CodestralTextCompletion: return _response ### SYNC COMPLETION else: - response = litellm.module_level_client.post( url=completion_url, headers=headers, @@ -352,13 +351,11 @@ class CodestralTextCompletion: logger_fn=None, headers={}, ) -> TextCompletionResponse: - async_handler = get_async_httpx_client( llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL, params={"timeout": timeout}, ) try: - response = await async_handler.post( api_base, headers=headers, data=json.dumps(data) ) diff --git a/litellm/llms/codestral/completion/transformation.py b/litellm/llms/codestral/completion/transformation.py index 5955e91deb..fc7b6f5dbb 100644 --- a/litellm/llms/codestral/completion/transformation.py +++ b/litellm/llms/codestral/completion/transformation.py @@ -78,7 +78,6 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig): return optional_params def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk: - text = "" is_finished = False finish_reason = None diff --git a/litellm/llms/cohere/chat/transformation.py b/litellm/llms/cohere/chat/transformation.py index 3ceec2dbba..fbaedca8f6 100644 --- a/litellm/llms/cohere/chat/transformation.py +++ b/litellm/llms/cohere/chat/transformation.py @@ -180,7 +180,6 @@ class CohereChatConfig(BaseConfig): litellm_params: dict, headers: dict, ) -> dict: - ## Load Config for k, v in litellm.CohereChatConfig.get_config().items(): if ( @@ -222,7 +221,6 @@ class CohereChatConfig(BaseConfig): api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: - try: raw_response_json = raw_response.json() model_response.choices[0].message.content = raw_response_json["text"] # type: ignore diff --git a/litellm/llms/cohere/embed/handler.py b/litellm/llms/cohere/embed/handler.py index e7f22ea72a..7a25bf7e54 100644 --- a/litellm/llms/cohere/embed/handler.py +++ b/litellm/llms/cohere/embed/handler.py @@ -56,7 +56,6 @@ async def 
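
[Editor's sketch, not part of the patch.] In image_handler.py the chained conditional expression that picks the config class only gains explicit parentheses around the inner if/else; conditional expressions group to the right, so the selected branch is the same before and after. A tiny sketch with placeholder predicates (the prefix checks stand in for _is_stability_3_model / _is_nova_model and are assumptions, not the real logic):

def pick_config(model: str) -> str:
    # Placeholder predicates standing in for the real model checks.
    is_stability_3 = model.startswith("stability.sd3")
    is_nova = model.startswith("amazon.nova-canvas")

    # Chained form (old layout)...
    chained = "stability3" if is_stability_3 else "nova" if is_nova else "stability"

    # ...and the explicitly parenthesized form (new layout) pick the same branch.
    explicit = (
        "stability3"
        if is_stability_3
        else ("nova" if is_nova else "stability")
    )
    assert chained == explicit
    return explicit


print(pick_config("amazon.nova-canvas-v1:0"))  # -> nova
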
async_embedding( encoding: Callable, client: Optional[AsyncHTTPHandler] = None, ): - ## LOGGING logging_obj.pre_call( input=input, diff --git a/litellm/llms/cohere/embed/transformation.py b/litellm/llms/cohere/embed/transformation.py index 22e157a0fd..837dd5e006 100644 --- a/litellm/llms/cohere/embed/transformation.py +++ b/litellm/llms/cohere/embed/transformation.py @@ -72,7 +72,6 @@ class CohereEmbeddingConfig: return transformed_request def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage: - input_tokens = 0 text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens") @@ -111,7 +110,6 @@ class CohereEmbeddingConfig: encoding: Any, input: list, ) -> EmbeddingResponse: - response_json = response.json() ## LOGGING logging_obj.post_call( diff --git a/litellm/llms/cohere/rerank/transformation.py b/litellm/llms/cohere/rerank/transformation.py index f3624d9216..22782c1300 100644 --- a/litellm/llms/cohere/rerank/transformation.py +++ b/litellm/llms/cohere/rerank/transformation.py @@ -148,4 +148,4 @@ class CohereRerankConfig(BaseRerankConfig): def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: - return CohereError(message=error_message, status_code=status_code) \ No newline at end of file + return CohereError(message=error_message, status_code=status_code) diff --git a/litellm/llms/cohere/rerank_v2/transformation.py b/litellm/llms/cohere/rerank_v2/transformation.py index a93cb982a7..74e760460d 100644 --- a/litellm/llms/cohere/rerank_v2/transformation.py +++ b/litellm/llms/cohere/rerank_v2/transformation.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union from litellm.llms.cohere.rerank.transformation import CohereRerankConfig from litellm.types.rerank import OptionalRerankParams, RerankRequest + class CohereRerankV2Config(CohereRerankConfig): """ Reference: https://docs.cohere.com/v2/reference/rerank @@ -77,4 +78,4 @@ class CohereRerankV2Config(CohereRerankConfig): return_documents=optional_rerank_params.get("return_documents", None), max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None), ) - return rerank_request.model_dump(exclude_none=True) \ No newline at end of file + return rerank_request.model_dump(exclude_none=True) diff --git a/litellm/llms/custom_httpx/aiohttp_handler.py b/litellm/llms/custom_httpx/aiohttp_handler.py index c865fee17e..9568ce7185 100644 --- a/litellm/llms/custom_httpx/aiohttp_handler.py +++ b/litellm/llms/custom_httpx/aiohttp_handler.py @@ -32,7 +32,6 @@ DEFAULT_TIMEOUT = 600 class BaseLLMAIOHTTPHandler: - def __init__(self): self.client_session: Optional[aiohttp.ClientSession] = None @@ -110,7 +109,6 @@ class BaseLLMAIOHTTPHandler: content: Any = None, params: Optional[dict] = None, ) -> httpx.Response: - max_retry_on_unprocessable_entity_error = ( provider_config.max_retry_on_unprocessable_entity_error ) diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 34d70434d5..23d7fe4b4d 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -114,7 +114,6 @@ class AsyncHTTPHandler: event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]], ssl_verify: Optional[VerifyTypes] = None, ) -> httpx.AsyncClient: - # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts. 
# /path/to/certificate.pem if ssl_verify is None: @@ -590,7 +589,6 @@ class HTTPHandler: timeout: Optional[Union[float, httpx.Timeout]] = None, ): try: - if timeout is not None: req = self.client.build_request( "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore @@ -609,7 +607,6 @@ class HTTPHandler: llm_provider="litellm-httpx-handler", ) except httpx.HTTPStatusError as e: - if stream is True: setattr(e, "message", mask_sensitive_info(e.response.read())) setattr(e, "text", mask_sensitive_info(e.response.read())) @@ -635,7 +632,6 @@ class HTTPHandler: timeout: Optional[Union[float, httpx.Timeout]] = None, ): try: - if timeout is not None: req = self.client.build_request( "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 872626c747..12736640f1 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -41,7 +41,6 @@ else: class BaseLLMHTTPHandler: - async def _make_common_async_call( self, async_httpx_client: AsyncHTTPHandler, @@ -109,7 +108,6 @@ class BaseLLMHTTPHandler: logging_obj: LiteLLMLoggingObj, stream: bool = False, ) -> httpx.Response: - max_retry_on_unprocessable_entity_error = ( provider_config.max_retry_on_unprocessable_entity_error ) @@ -599,7 +597,6 @@ class BaseLLMHTTPHandler: aembedding: bool = False, headers={}, ) -> EmbeddingResponse: - provider_config = ProviderConfigManager.get_provider_embedding_config( model=model, provider=litellm.LlmProviders(custom_llm_provider) ) @@ -742,7 +739,6 @@ class BaseLLMHTTPHandler: api_base: Optional[str] = None, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: - # get config from model, custom llm provider headers = provider_config.validate_environment( api_key=api_key, @@ -828,7 +824,6 @@ class BaseLLMHTTPHandler: timeout: Optional[Union[float, httpx.Timeout]] = None, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: - if client is None or not isinstance(client, AsyncHTTPHandler): async_httpx_client = get_async_httpx_client( llm_provider=litellm.LlmProviders(custom_llm_provider) diff --git a/litellm/llms/databricks/common_utils.py b/litellm/llms/databricks/common_utils.py index e8481e25b2..76bd281d4d 100644 --- a/litellm/llms/databricks/common_utils.py +++ b/litellm/llms/databricks/common_utils.py @@ -16,9 +16,9 @@ class DatabricksBase: api_base = api_base or f"{databricks_client.config.host}/serving-endpoints" if api_key is None: - databricks_auth_headers: dict[str, str] = ( - databricks_client.config.authenticate() - ) + databricks_auth_headers: dict[ + str, str + ] = databricks_client.config.authenticate() headers = {**databricks_auth_headers, **headers} return api_base, headers diff --git a/litellm/llms/databricks/embed/transformation.py b/litellm/llms/databricks/embed/transformation.py index 53e3b30dd2..a113a349cc 100644 --- a/litellm/llms/databricks/embed/transformation.py +++ b/litellm/llms/databricks/embed/transformation.py @@ -11,9 +11,9 @@ class DatabricksEmbeddingConfig: Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task """ - instruction: Optional[str] = ( - None # An optional instruction to pass to the embedding model. 
BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries - ) + instruction: Optional[ + str + ] = None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries def __init__(self, instruction: Optional[str] = None) -> None: locals_ = locals().copy() diff --git a/litellm/llms/databricks/streaming_utils.py b/litellm/llms/databricks/streaming_utils.py index 2db53df908..eebe318288 100644 --- a/litellm/llms/databricks/streaming_utils.py +++ b/litellm/llms/databricks/streaming_utils.py @@ -55,7 +55,6 @@ class ModelResponseIterator: usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None) if usage_chunk is not None: - usage = ChatCompletionUsageBlock( prompt_tokens=usage_chunk.prompt_tokens, completion_tokens=usage_chunk.completion_tokens, diff --git a/litellm/llms/deepgram/audio_transcription/transformation.py b/litellm/llms/deepgram/audio_transcription/transformation.py index 90720a77f7..20599e3994 100644 --- a/litellm/llms/deepgram/audio_transcription/transformation.py +++ b/litellm/llms/deepgram/audio_transcription/transformation.py @@ -126,9 +126,9 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig): # Add additional metadata matching OpenAI format response["task"] = "transcribe" - response["language"] = ( - "english" # Deepgram auto-detects but doesn't return language - ) + response[ + "language" + ] = "english" # Deepgram auto-detects but doesn't return language response["duration"] = response_json["metadata"]["duration"] # Transform words to match OpenAI format diff --git a/litellm/llms/deepseek/chat/transformation.py b/litellm/llms/deepseek/chat/transformation.py index 180cf7dc69..fe70ebe77e 100644 --- a/litellm/llms/deepseek/chat/transformation.py +++ b/litellm/llms/deepseek/chat/transformation.py @@ -14,7 +14,6 @@ from ...openai.chat.gpt_transformation import OpenAIGPTConfig class DeepSeekChatConfig(OpenAIGPTConfig): - def _transform_messages( self, messages: List[AllMessageValues], model: str ) -> List[AllMessageValues]: diff --git a/litellm/llms/deprecated_providers/aleph_alpha.py b/litellm/llms/deprecated_providers/aleph_alpha.py index 81ad134641..4cfede2a1b 100644 --- a/litellm/llms/deprecated_providers/aleph_alpha.py +++ b/litellm/llms/deprecated_providers/aleph_alpha.py @@ -77,9 +77,9 @@ class AlephAlphaConfig: - `control_log_additive` (boolean; default value: true): Method of applying control to attention scores. 
""" - maximum_tokens: Optional[int] = ( - litellm.max_tokens - ) # aleph alpha requires max tokens + maximum_tokens: Optional[ + int + ] = litellm.max_tokens # aleph alpha requires max tokens minimum_tokens: Optional[int] = None echo: Optional[bool] = None temperature: Optional[int] = None diff --git a/litellm/llms/fireworks_ai/chat/transformation.py b/litellm/llms/fireworks_ai/chat/transformation.py index 1c82f24ac0..4def12adb7 100644 --- a/litellm/llms/fireworks_ai/chat/transformation.py +++ b/litellm/llms/fireworks_ai/chat/transformation.py @@ -88,7 +88,6 @@ class FireworksAIConfig(OpenAIGPTConfig): model: str, drop_params: bool, ) -> dict: - supported_openai_params = self.get_supported_openai_params(model=model) is_tools_set = any( param == "tools" and value is not None @@ -104,7 +103,6 @@ class FireworksAIConfig(OpenAIGPTConfig): # pass through the value of tool choice optional_params["tool_choice"] = value elif param == "response_format": - if ( is_tools_set ): # fireworks ai doesn't support tools and response_format together @@ -223,7 +221,6 @@ class FireworksAIConfig(OpenAIGPTConfig): return api_base, dynamic_api_key def get_models(self, api_key: Optional[str] = None, api_base: Optional[str] = None): - api_base, api_key = self._get_openai_compatible_provider_info( api_base=api_base, api_key=api_key ) diff --git a/litellm/llms/gemini/chat/transformation.py b/litellm/llms/gemini/chat/transformation.py index fbc1916dcc..0d5956122e 100644 --- a/litellm/llms/gemini/chat/transformation.py +++ b/litellm/llms/gemini/chat/transformation.py @@ -90,7 +90,6 @@ class GoogleAIStudioGeminiConfig(VertexGeminiConfig): model: str, drop_params: bool, ) -> Dict: - if litellm.vertex_ai_safety_settings is not None: optional_params["safety_settings"] = litellm.vertex_ai_safety_settings return super().map_openai_params( diff --git a/litellm/llms/gemini/common_utils.py b/litellm/llms/gemini/common_utils.py index 7f266c0536..4c3357a500 100644 --- a/litellm/llms/gemini/common_utils.py +++ b/litellm/llms/gemini/common_utils.py @@ -25,7 +25,6 @@ class GeminiModelInfo(BaseLLMModelInfo): def get_models( self, api_key: Optional[str] = None, api_base: Optional[str] = None ) -> List[str]: - api_base = GeminiModelInfo.get_api_base(api_base) api_key = GeminiModelInfo.get_api_key(api_key) if api_base is None or api_key is None: diff --git a/litellm/llms/groq/chat/transformation.py b/litellm/llms/groq/chat/transformation.py index 5b24f7d112..b0ee69bed2 100644 --- a/litellm/llms/groq/chat/transformation.py +++ b/litellm/llms/groq/chat/transformation.py @@ -18,7 +18,6 @@ from ...openai.chat.gpt_transformation import OpenAIGPTConfig class GroqChatConfig(OpenAIGPTConfig): - frequency_penalty: Optional[int] = None function_call: Optional[Union[str, dict]] = None functions: Optional[list] = None diff --git a/litellm/llms/groq/stt/transformation.py b/litellm/llms/groq/stt/transformation.py index c4dbd8d0ca..b467fab14f 100644 --- a/litellm/llms/groq/stt/transformation.py +++ b/litellm/llms/groq/stt/transformation.py @@ -9,7 +9,6 @@ import litellm class GroqSTTConfig: - frequency_penalty: Optional[int] = None function_call: Optional[Union[str, dict]] = None functions: Optional[list] = None diff --git a/litellm/llms/huggingface/chat/transformation.py b/litellm/llms/huggingface/chat/transformation.py index 858fda473e..082960b2c2 100644 --- a/litellm/llms/huggingface/chat/transformation.py +++ b/litellm/llms/huggingface/chat/transformation.py @@ -40,17 +40,17 @@ class HuggingfaceChatConfig(BaseConfig): Reference: 
https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate """ - hf_task: Optional[hf_tasks] = ( - None # litellm-specific param, used to know the api spec to use when calling huggingface api - ) + hf_task: Optional[ + hf_tasks + ] = None # litellm-specific param, used to know the api spec to use when calling huggingface api best_of: Optional[int] = None decoder_input_details: Optional[bool] = None details: Optional[bool] = True # enables returning logprobs + best of max_new_tokens: Optional[int] = None repetition_penalty: Optional[float] = None - return_full_text: Optional[bool] = ( - False # by default don't return the input as part of the output - ) + return_full_text: Optional[ + bool + ] = False # by default don't return the input as part of the output seed: Optional[int] = None temperature: Optional[float] = None top_k: Optional[int] = None @@ -120,9 +120,9 @@ class HuggingfaceChatConfig(BaseConfig): optional_params["top_p"] = value if param == "n": optional_params["best_of"] = value - optional_params["do_sample"] = ( - True # Need to sample if you want best of for hf inference endpoints - ) + optional_params[ + "do_sample" + ] = True # Need to sample if you want best of for hf inference endpoints if param == "stream": optional_params["stream"] = value if param == "stop": @@ -362,9 +362,9 @@ class HuggingfaceChatConfig(BaseConfig): "content-type": "application/json", } if api_key is not None: - default_headers["Authorization"] = ( - f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens - ) + default_headers[ + "Authorization" + ] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens headers = {**headers, **default_headers} return headers diff --git a/litellm/llms/maritalk.py b/litellm/llms/maritalk.py index 5f2b8d71bc..418d13b344 100644 --- a/litellm/llms/maritalk.py +++ b/litellm/llms/maritalk.py @@ -17,7 +17,6 @@ class MaritalkError(BaseLLMException): class MaritalkConfig(OpenAIGPTConfig): - def __init__( self, frequency_penalty: Optional[float] = None, diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py index b4db95cfa1..b007bbb2bc 100644 --- a/litellm/llms/ollama/completion/transformation.py +++ b/litellm/llms/ollama/completion/transformation.py @@ -89,9 +89,9 @@ class OllamaConfig(BaseConfig): repeat_penalty: Optional[float] = None temperature: Optional[float] = None seed: Optional[int] = None - stop: Optional[list] = ( - None # stop is a list based on this - https://github.com/ollama/ollama/pull/442 - ) + stop: Optional[ + list + ] = None # stop is a list based on this - https://github.com/ollama/ollama/pull/442 tfs_z: Optional[float] = None num_predict: Optional[int] = None top_k: Optional[int] = None diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py index c765f97979..1fde23c9c2 100644 --- a/litellm/llms/openai/chat/gpt_transformation.py +++ b/litellm/llms/openai/chat/gpt_transformation.py @@ -391,7 +391,6 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig): class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator): - def chunk_parser(self, chunk: dict) -> ModelResponseStream: try: return ModelResponseStream( diff --git a/litellm/llms/openai/completion/handler.py b/litellm/llms/openai/completion/handler.py index 2e60f55b57..fa31c487cd 100644 --- a/litellm/llms/openai/completion/handler.py +++ 
b/litellm/llms/openai/completion/handler.py @@ -220,7 +220,6 @@ class OpenAITextCompletion(BaseLLM): client=None, organization=None, ): - if client is None: openai_client = OpenAI( api_key=api_key, diff --git a/litellm/llms/openai/completion/transformation.py b/litellm/llms/openai/completion/transformation.py index 1aef72d3fa..43fbc1f219 100644 --- a/litellm/llms/openai/completion/transformation.py +++ b/litellm/llms/openai/completion/transformation.py @@ -111,9 +111,9 @@ class OpenAITextCompletionConfig(BaseTextCompletionConfig, OpenAIGPTConfig): if "model" in response_object: model_response_object.model = response_object["model"] - model_response_object._hidden_params["original_response"] = ( - response_object # track original response, if users make a litellm.text_completion() request, we can return the original response - ) + model_response_object._hidden_params[ + "original_response" + ] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response return model_response_object except Exception as e: raise e diff --git a/litellm/llms/openai/fine_tuning/handler.py b/litellm/llms/openai/fine_tuning/handler.py index 97b237c757..2b697f85d2 100644 --- a/litellm/llms/openai/fine_tuning/handler.py +++ b/litellm/llms/openai/fine_tuning/handler.py @@ -28,14 +28,7 @@ class OpenAIFineTuningAPI: _is_async: bool = False, api_version: Optional[str] = None, litellm_params: Optional[dict] = None, - ) -> Optional[ - Union[ - OpenAI, - AsyncOpenAI, - AzureOpenAI, - AsyncAzureOpenAI, - ] - ]: + ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]: received_args = locals() openai_client: Optional[ Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index deb70b481e..0545542ead 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -266,7 +266,6 @@ class OpenAIConfig(BaseConfig): api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: - logging_obj.post_call(original_response=raw_response.text) logging_obj.model_call_details["response_headers"] = raw_response.headers final_response_obj = cast( @@ -320,7 +319,6 @@ class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator): class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): - def __init__(self) -> None: super().__init__() @@ -513,7 +511,6 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): custom_llm_provider: Optional[str] = None, drop_params: Optional[bool] = None, ): - super().completion() try: fake_stream: bool = False @@ -553,7 +550,6 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): for _ in range( 2 ): # if call fails due to alternating messages, retry with reformatted message - if provider_config is not None: data = provider_config.transform_request( model=model, @@ -649,13 +645,14 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): }, ) - headers, response = ( - self.make_sync_openai_chat_completion_request( - openai_client=openai_client, - data=data, - timeout=timeout, - logging_obj=logging_obj, - ) + ( + headers, + response, + ) = self.make_sync_openai_chat_completion_request( + openai_client=openai_client, + data=data, + timeout=timeout, + logging_obj=logging_obj, ) logging_obj.model_call_details["response_headers"] = headers @@ -763,7 +760,6 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): for _ in range( 2 ): # if call fails due to alternating messages, retry with reformatted message - try: 
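
[Editor's sketch, not part of the patch.] The openai.py hunk just above rewrites the two-value unpacking so the wrapping sits on a parenthesized target tuple rather than on the call; the unpacking itself is unchanged. A small standalone illustration (make_chat_completion_request is a stand-in, not the real helper):

from typing import Dict, Tuple


def make_chat_completion_request(data: dict) -> Tuple[Dict[str, str], dict]:
    # Stand-in for make_sync_openai_chat_completion_request(...).
    return {"x-request-id": "abc123"}, {"choices": []}


# Old layout: the call is parenthesized and wrapped.
headers, response = (
    make_chat_completion_request({"model": "gpt-4o"})
)

# New layout: the target tuple is parenthesized and wrapped; same unpacking.
(
    headers,
    response,
) = make_chat_completion_request({"model": "gpt-4o"})
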
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore is_async=True, @@ -973,7 +969,6 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): except ( Exception ) as e: # need to exception handle here. async exceptions don't get caught in sync functions. - if isinstance(e, OpenAIError): raise e @@ -1246,7 +1241,6 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): ): response = None try: - openai_aclient = self._get_openai_client( is_async=True, api_key=api_key, @@ -1333,7 +1327,6 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): ) return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore except OpenAIError as e: - ## LOGGING logging_obj.post_call( input=prompt, @@ -1372,7 +1365,6 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): aspeech: Optional[bool] = None, client=None, ) -> HttpxBinaryResponseContent: - if aspeech is not None and aspeech is True: return self.async_audio_speech( model=model, @@ -1419,7 +1411,6 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): timeout: Union[float, httpx.Timeout], client=None, ) -> HttpxBinaryResponseContent: - openai_client = cast( AsyncOpenAI, self._get_openai_client( diff --git a/litellm/llms/openai/transcriptions/whisper_transformation.py b/litellm/llms/openai/transcriptions/whisper_transformation.py index 5a7d6481a8..2d3d611dac 100644 --- a/litellm/llms/openai/transcriptions/whisper_transformation.py +++ b/litellm/llms/openai/transcriptions/whisper_transformation.py @@ -81,9 +81,9 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig): if "response_format" not in data or ( data["response_format"] == "text" or data["response_format"] == "json" ): - data["response_format"] = ( - "verbose_json" # ensures 'duration' is received - used for cost calculation - ) + data[ + "response_format" + ] = "verbose_json" # ensures 'duration' is received - used for cost calculation return data diff --git a/litellm/llms/openrouter/chat/transformation.py b/litellm/llms/openrouter/chat/transformation.py index ab4d3c52b9..0b47167524 100644 --- a/litellm/llms/openrouter/chat/transformation.py +++ b/litellm/llms/openrouter/chat/transformation.py @@ -19,7 +19,6 @@ from ..common_utils import OpenRouterException class OpenrouterConfig(OpenAIGPTConfig): - def map_openai_params( self, non_default_params: dict, @@ -42,9 +41,9 @@ class OpenrouterConfig(OpenAIGPTConfig): extra_body["models"] = models if route is not None: extra_body["route"] = route - mapped_openai_params["extra_body"] = ( - extra_body # openai client supports `extra_body` param - ) + mapped_openai_params[ + "extra_body" + ] = extra_body # openai client supports `extra_body` param return mapped_openai_params def get_error_class( @@ -70,7 +69,6 @@ class OpenrouterConfig(OpenAIGPTConfig): class OpenRouterChatCompletionStreamingHandler(BaseModelResponseIterator): - def chunk_parser(self, chunk: dict) -> ModelResponseStream: try: new_choices = [] diff --git a/litellm/llms/petals/completion/transformation.py b/litellm/llms/petals/completion/transformation.py index 08ec15de33..a9e37d27fc 100644 --- a/litellm/llms/petals/completion/transformation.py +++ b/litellm/llms/petals/completion/transformation.py @@ -37,9 +37,9 @@ class PetalsConfig(BaseConfig): """ max_length: Optional[int] = None - max_new_tokens: Optional[int] = ( - litellm.max_tokens - ) # petals requires max tokens to be set + max_new_tokens: Optional[ + int + ] = litellm.max_tokens # petals requires max tokens 
to be set do_sample: Optional[bool] = None temperature: Optional[float] = None top_k: Optional[int] = None diff --git a/litellm/llms/predibase/chat/handler.py b/litellm/llms/predibase/chat/handler.py index 43f4b06745..cd80fa53e4 100644 --- a/litellm/llms/predibase/chat/handler.py +++ b/litellm/llms/predibase/chat/handler.py @@ -394,7 +394,6 @@ class PredibaseChatCompletion: logger_fn=None, headers={}, ) -> ModelResponse: - async_handler = get_async_httpx_client( llm_provider=litellm.LlmProviders.PREDIBASE, params={"timeout": timeout}, diff --git a/litellm/llms/predibase/chat/transformation.py b/litellm/llms/predibase/chat/transformation.py index f574238696..f1a2163d24 100644 --- a/litellm/llms/predibase/chat/transformation.py +++ b/litellm/llms/predibase/chat/transformation.py @@ -30,9 +30,9 @@ class PredibaseConfig(BaseConfig): 256 # openai default - requests hang if max_new_tokens not given ) repetition_penalty: Optional[float] = None - return_full_text: Optional[bool] = ( - False # by default don't return the input as part of the output - ) + return_full_text: Optional[ + bool + ] = False # by default don't return the input as part of the output seed: Optional[int] = None stop: Optional[List[str]] = None temperature: Optional[float] = None @@ -99,9 +99,9 @@ class PredibaseConfig(BaseConfig): optional_params["top_p"] = value if param == "n": optional_params["best_of"] = value - optional_params["do_sample"] = ( - True # Need to sample if you want best of for hf inference endpoints - ) + optional_params[ + "do_sample" + ] = True # Need to sample if you want best of for hf inference endpoints if param == "stream": optional_params["stream"] = value if param == "stop": diff --git a/litellm/llms/replicate/chat/handler.py b/litellm/llms/replicate/chat/handler.py index f52eb2ee05..526f376b89 100644 --- a/litellm/llms/replicate/chat/handler.py +++ b/litellm/llms/replicate/chat/handler.py @@ -244,7 +244,6 @@ async def async_completion( print_verbose, headers: dict, ) -> Union[ModelResponse, CustomStreamWrapper]: - prediction_url = replicate_config.get_complete_url( api_base=api_base, model=model, diff --git a/litellm/llms/sagemaker/chat/handler.py b/litellm/llms/sagemaker/chat/handler.py index c827a8a5f7..b86cda7aea 100644 --- a/litellm/llms/sagemaker/chat/handler.py +++ b/litellm/llms/sagemaker/chat/handler.py @@ -13,7 +13,6 @@ from .transformation import SagemakerChatConfig class SagemakerChatHandler(BaseAWSLLM): - def _load_credentials( self, optional_params: dict, @@ -128,7 +127,6 @@ class SagemakerChatHandler(BaseAWSLLM): headers: dict = {}, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ): - # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker credentials, aws_region_name = self._load_credentials(optional_params) inference_params = deepcopy(optional_params) diff --git a/litellm/llms/sagemaker/common_utils.py b/litellm/llms/sagemaker/common_utils.py index 9884f420c3..ce0c6c9506 100644 --- a/litellm/llms/sagemaker/common_utils.py +++ b/litellm/llms/sagemaker/common_utils.py @@ -34,7 +34,6 @@ class AWSEventStreamDecoder: def _chunk_parser_messages_api( self, chunk_data: dict ) -> StreamingChatCompletionChunk: - openai_chunk = StreamingChatCompletionChunk(**chunk_data) return openai_chunk @@ -192,7 +191,6 @@ class AWSEventStreamDecoder: def get_response_stream_shape(): global _response_stream_shape_cache if _response_stream_shape_cache is None: - from botocore.loaders import Loader from botocore.model import ServiceModel diff --git 
a/litellm/llms/sagemaker/completion/handler.py b/litellm/llms/sagemaker/completion/handler.py index 909caf73c3..296689c31c 100644 --- a/litellm/llms/sagemaker/completion/handler.py +++ b/litellm/llms/sagemaker/completion/handler.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union, cast import httpx @@ -35,7 +35,6 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = "" # set os.environ['AWS_REGION_NAME'] = class SagemakerLLM(BaseAWSLLM): - def _load_credentials( self, optional_params: dict, @@ -154,7 +153,6 @@ class SagemakerLLM(BaseAWSLLM): acompletion: bool = False, headers: dict = {}, ): - # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker credentials, aws_region_name = self._load_credentials(optional_params) inference_params = deepcopy(optional_params) @@ -437,10 +435,14 @@ class SagemakerLLM(BaseAWSLLM): prepared_request.headers.update( {"X-Amzn-SageMaker-Inference-Component": model_id} ) + + if not prepared_request.body: + raise ValueError("Prepared request body is empty") + completion_stream = await self.make_async_call( api_base=prepared_request.url, headers=prepared_request.headers, # type: ignore - data=prepared_request.body, + data=cast(str, prepared_request.body), logging_obj=logging_obj, ) streaming_response = CustomStreamWrapper( @@ -625,7 +627,7 @@ class SagemakerLLM(BaseAWSLLM): response = client.invoke_endpoint( EndpointName={model}, ContentType="application/json", - Body={data}, # type: ignore + Body=f"{data!r}", # Use !r for safe representation CustomAttributes="accept_eula=true", )""" # type: ignore logging_obj.pre_call( diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py index d0ab5d0697..9923c0e45d 100644 --- a/litellm/llms/sagemaker/completion/transformation.py +++ b/litellm/llms/sagemaker/completion/transformation.py @@ -88,9 +88,9 @@ class SagemakerConfig(BaseConfig): optional_params["top_p"] = value if param == "n": optional_params["best_of"] = value - optional_params["do_sample"] = ( - True # Need to sample if you want best of for hf inference endpoints - ) + optional_params[ + "do_sample" + ] = True # Need to sample if you want best of for hf inference endpoints if param == "stream": optional_params["stream"] = value if param == "stop": diff --git a/litellm/llms/together_ai/rerank/transformation.py b/litellm/llms/together_ai/rerank/transformation.py index 4714376979..1fdb772add 100644 --- a/litellm/llms/together_ai/rerank/transformation.py +++ b/litellm/llms/together_ai/rerank/transformation.py @@ -19,7 +19,6 @@ from litellm.types.rerank import ( class TogetherAIRerankConfig: def _transform_response(self, response: dict) -> RerankResponse: - _billed_units = RerankBilledUnits(**response.get("usage", {})) _tokens = RerankTokens(**response.get("usage", {})) rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens) diff --git a/litellm/llms/topaz/image_variations/transformation.py b/litellm/llms/topaz/image_variations/transformation.py index 8b95deed04..6188101015 100644 --- a/litellm/llms/topaz/image_variations/transformation.py +++ b/litellm/llms/topaz/image_variations/transformation.py @@ -121,7 +121,6 @@ class TopazImageVariationConfig(BaseImageVariationConfig): optional_params: dict, headers: dict, ) -> HttpHandlerRequestFields: - request_params = HttpHandlerRequestFields( files={"image": self.prepare_file_tuple(image)}, 
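
[Editor's sketch, not part of the patch.] Two behavior-adjacent fixes sit in the sagemaker completion handler above: the async streaming path now refuses an empty prepared_request.body (and casts it to str for mypy), and the example request string written to the logs switches from interpolating the raw payload to f"{data!r}", the safer repr-style interpolation flagged by ruff. A condensed sketch under stated assumptions (make_async_call and start_stream are illustrative stand-ins, not the real signatures):

import json
from typing import Optional, cast


async def make_async_call(api_base: str, data: str) -> None:
    # Stand-in for the real streaming request helper.
    ...


async def start_stream(url: str, body: Optional[str], payload: dict) -> None:
    # Guard added by the patch: refuse to stream an empty prepared-request body.
    if not body:
        raise ValueError("Prepared request body is empty")

    # cast() mirrors the mypy fix: it narrows the declared type without
    # converting anything at runtime.
    await make_async_call(api_base=url, data=cast(str, body))

    # Logged example call: !r interpolates a repr of the payload instead of the
    # raw string, so quotes and braces in the data cannot mangle the snippet.
    data = json.dumps(payload)
    logged_call = f"""
    response = client.invoke_endpoint(
        EndpointName="my-endpoint",
        ContentType="application/json",
        Body={data!r},
        CustomAttributes="accept_eula=true",
    )"""
    print(logged_call)
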
data=optional_params, @@ -134,7 +133,6 @@ class TopazImageVariationConfig(BaseImageVariationConfig): image_content: bytes, response_ms: float, ) -> ImageResponse: - # Convert to base64 base64_image = base64.b64encode(image_content).decode("utf-8") diff --git a/litellm/llms/triton/completion/transformation.py b/litellm/llms/triton/completion/transformation.py index 56151f89ef..46b607d455 100644 --- a/litellm/llms/triton/completion/transformation.py +++ b/litellm/llms/triton/completion/transformation.py @@ -244,7 +244,6 @@ class TritonInferConfig(TritonConfig): litellm_params: dict, headers: dict, ) -> dict: - text_input = messages[0].get("content", "") data_for_triton = { "inputs": [ diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index b82268bef6..dc3f93857a 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -35,7 +35,6 @@ class VertexAIBatchPrediction(VertexLLM): timeout: Union[float, httpx.Timeout], max_retries: Optional[int], ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]: - sync_handler = _get_httpx_client() access_token, project_id = self._ensure_access_token( @@ -69,10 +68,8 @@ class VertexAIBatchPrediction(VertexLLM): "Authorization": f"Bearer {access_token}", } - vertex_batch_request: VertexAIBatchPredictionJob = ( - VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request( - request=create_batch_data - ) + vertex_batch_request: VertexAIBatchPredictionJob = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request( + request=create_batch_data ) if _is_async is True: diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index 337445777a..f2cd1ef557 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -243,7 +243,7 @@ def convert_anyof_null_to_nullable(schema, depth=0): # remove null type anyof.remove(atype) contains_null = True - + if len(anyof) == 0: # Edge case: response schema with only null type present is invalid in Vertex AI raise ValueError( @@ -251,12 +251,10 @@ def convert_anyof_null_to_nullable(schema, depth=0): "Please provide a non-null type." 
) - if contains_null: # set all types to nullable following guidance found here: https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-controlled-generation-response-schema-3#generativeaionvertexai_gemini_controlled_generation_response_schema_3-python for atype in anyof: atype["nullable"] = True - properties = schema.get("properties", None) if properties is not None: diff --git a/litellm/llms/vertex_ai/files/handler.py b/litellm/llms/vertex_ai/files/handler.py index 266169cdfb..7000cf151d 100644 --- a/litellm/llms/vertex_ai/files/handler.py +++ b/litellm/llms/vertex_ai/files/handler.py @@ -49,10 +49,11 @@ class VertexAIFilesHandler(GCSBucketBase): service_account_json=gcs_logging_config["path_service_account"], ) bucket_name = gcs_logging_config["bucket_name"] - logging_payload, object_name = ( - vertex_ai_files_transformation.transform_openai_file_content_to_vertex_ai_file_content( - openai_file_content=create_file_data.get("file") - ) + ( + logging_payload, + object_name, + ) = vertex_ai_files_transformation.transform_openai_file_content_to_vertex_ai_file_content( + openai_file_content=create_file_data.get("file") ) gcs_upload_response = await self._log_json_data_on_gcs( headers=headers, diff --git a/litellm/llms/vertex_ai/fine_tuning/handler.py b/litellm/llms/vertex_ai/fine_tuning/handler.py index 3cf409c78e..7ea8527fd4 100644 --- a/litellm/llms/vertex_ai/fine_tuning/handler.py +++ b/litellm/llms/vertex_ai/fine_tuning/handler.py @@ -36,7 +36,6 @@ class VertexFineTuningAPI(VertexLLM): def convert_response_created_at(self, response: ResponseTuningJob): try: - create_time_str = response.get("createTime", "") or "" create_time_datetime = datetime.fromisoformat( create_time_str.replace("Z", "+00:00") @@ -65,9 +64,9 @@ class VertexFineTuningAPI(VertexLLM): ) if create_fine_tuning_job_data.validation_file: - supervised_tuning_spec["validation_dataset"] = ( - create_fine_tuning_job_data.validation_file - ) + supervised_tuning_spec[ + "validation_dataset" + ] = create_fine_tuning_job_data.validation_file _vertex_hyperparameters = ( self._transform_openai_hyperparameters_to_vertex_hyperparameters( @@ -175,7 +174,6 @@ class VertexFineTuningAPI(VertexLLM): headers: dict, request_data: FineTuneJobCreate, ): - try: verbose_logger.debug( "about to create fine tuning job: %s, request_data: %s", @@ -229,7 +227,6 @@ class VertexFineTuningAPI(VertexLLM): kwargs: Optional[dict] = None, original_hyperparameters: Optional[dict] = {}, ): - verbose_logger.debug( "creating fine tuning job, args= %s", create_fine_tuning_job_data ) @@ -346,9 +343,9 @@ class VertexFineTuningAPI(VertexLLM): elif "cachedContents" in request_route: _model = request_data.get("model") if _model is not None and "/publishers/google/models/" not in _model: - request_data["model"] = ( - f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}" - ) + request_data[ + "model" + ] = f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}" url = f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}{request_route}" else: diff --git a/litellm/llms/vertex_ai/gemini/transformation.py b/litellm/llms/vertex_ai/gemini/transformation.py index d6bafc7c60..96b33ee187 100644 --- a/litellm/llms/vertex_ai/gemini/transformation.py +++ b/litellm/llms/vertex_ai/gemini/transformation.py @@ -85,7 +85,6 @@ def _process_gemini_image(image_url: str, format: Optional[str] = None) -> PartT and 
(image_type := format or _get_image_mime_type_from_url(image_url)) is not None ): - file_data = FileDataType(file_uri=image_url, mime_type=image_type) return PartType(file_data=file_data) elif "http://" in image_url or "https://" in image_url or "base64" in image_url: @@ -414,18 +413,19 @@ async def async_transform_request_body( context_caching_endpoints = ContextCachingEndpoints() if gemini_api_key is not None: - messages, cached_content = ( - await context_caching_endpoints.async_check_and_create_cache( - messages=messages, - api_key=gemini_api_key, - api_base=api_base, - model=model, - client=client, - timeout=timeout, - extra_headers=extra_headers, - cached_content=optional_params.pop("cached_content", None), - logging_obj=logging_obj, - ) + ( + messages, + cached_content, + ) = await context_caching_endpoints.async_check_and_create_cache( + messages=messages, + api_key=gemini_api_key, + api_base=api_base, + model=model, + client=client, + timeout=timeout, + extra_headers=extra_headers, + cached_content=optional_params.pop("cached_content", None), + logging_obj=logging_obj, ) else: # [TODO] implement context caching for gemini as well cached_content = optional_params.pop("cached_content", None) diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 90c66f69a3..860dec9eb2 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -246,9 +246,9 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): value = _remove_strict_from_schema(value) for tool in value: - openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = ( - None - ) + openai_function_object: Optional[ + ChatCompletionToolParamFunctionChunk + ] = None if "function" in tool: # tools list _openai_function_object = ChatCompletionToolParamFunctionChunk( # type: ignore **tool["function"] @@ -813,15 +813,15 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): ## ADD SAFETY RATINGS ## setattr(model_response, "vertex_ai_safety_results", safety_ratings) - model_response._hidden_params["vertex_ai_safety_results"] = ( - safety_ratings # older approach - maintaining to prevent regressions - ) + model_response._hidden_params[ + "vertex_ai_safety_results" + ] = safety_ratings # older approach - maintaining to prevent regressions ## ADD CITATION METADATA ## setattr(model_response, "vertex_ai_citation_metadata", citation_metadata) - model_response._hidden_params["vertex_ai_citation_metadata"] = ( - citation_metadata # older approach - maintaining to prevent regressions - ) + model_response._hidden_params[ + "vertex_ai_citation_metadata" + ] = citation_metadata # older approach - maintaining to prevent regressions except Exception as e: raise VertexAIError( diff --git a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py index 0fe5145a14..ecfe2ee8b4 100644 --- a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py +++ b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py @@ -47,7 +47,6 @@ class GoogleBatchEmbeddings(VertexLLM): timeout=300, client=None, ) -> EmbeddingResponse: - _auth_header, vertex_project = self._ensure_access_token( credentials=vertex_credentials, project_id=vertex_project, diff --git a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py 
b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py index 592dac5846..2c0f5dad22 100644 --- a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py +++ b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py @@ -52,7 +52,6 @@ def process_response( model: str, _predictions: VertexAIBatchEmbeddingsResponseObject, ) -> EmbeddingResponse: - openai_embeddings: List[Embedding] = [] for embedding in _predictions["embeddings"]: openai_embedding = Embedding( diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py index 34879ae9ac..88d7339449 100644 --- a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py +++ b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py @@ -50,7 +50,6 @@ class VertexMultimodalEmbedding(VertexLLM): timeout=300, client=None, ) -> EmbeddingResponse: - _auth_header, vertex_project = self._ensure_access_token( credentials=vertex_credentials, project_id=vertex_project, diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py b/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py index 8f96ca2bb0..afa58c7e5c 100644 --- a/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py +++ b/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py @@ -260,7 +260,6 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig): def transform_embedding_response_to_openai( self, predictions: MultimodalPredictions ) -> List[Embedding]: - openai_embeddings: List[Embedding] = [] if "predictions" in predictions: for idx, _prediction in enumerate(predictions["predictions"]): diff --git a/litellm/llms/vertex_ai/vertex_ai_non_gemini.py b/litellm/llms/vertex_ai/vertex_ai_non_gemini.py index 744e1eb317..df267d9623 100644 --- a/litellm/llms/vertex_ai/vertex_ai_non_gemini.py +++ b/litellm/llms/vertex_ai/vertex_ai_non_gemini.py @@ -323,7 +323,6 @@ def completion( # noqa: PLR0915 ) completion_response = chat.send_message(prompt, **optional_params).text elif mode == "text": - if fake_stream is not True and stream is True: request_str += ( f"llm_model.predict_streaming({prompt}, **{optional_params})\n" @@ -506,7 +505,6 @@ async def async_completion( # noqa: PLR0915 Add support for acompletion calls for gemini-pro """ try: - response_obj = None completion_response = None if mode == "chat": diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py index fb2393631b..b8d2658f80 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py @@ -110,7 +110,6 @@ class VertexAIPartnerModels(VertexBase): message="""Upgrade vertex ai. 
Run `pip install "google-cloud-aiplatform>=1.38"`""", ) try: - vertex_httpx_logic = VertexLLM() access_token, project_id = vertex_httpx_logic._ensure_access_token( diff --git a/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py b/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py index 2e8051d4d2..1167ca285f 100644 --- a/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py +++ b/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py @@ -86,10 +86,8 @@ class VertexEmbedding(VertexBase): mode="embedding", ) headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) - vertex_request: VertexEmbeddingRequest = ( - litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( - input=input, optional_params=optional_params, model=model - ) + vertex_request: VertexEmbeddingRequest = litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( + input=input, optional_params=optional_params, model=model ) _client_params = {} @@ -178,10 +176,8 @@ class VertexEmbedding(VertexBase): mode="embedding", ) headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) - vertex_request: VertexEmbeddingRequest = ( - litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( - input=input, optional_params=optional_params, model=model - ) + vertex_request: VertexEmbeddingRequest = litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( + input=input, optional_params=optional_params, model=model ) _async_client_params = {} diff --git a/litellm/llms/vertex_ai/vertex_embeddings/transformation.py b/litellm/llms/vertex_ai/vertex_embeddings/transformation.py index d9e84fca03..97af558041 100644 --- a/litellm/llms/vertex_ai/vertex_embeddings/transformation.py +++ b/litellm/llms/vertex_ai/vertex_embeddings/transformation.py @@ -212,7 +212,6 @@ class VertexAITextEmbeddingConfig(BaseModel): embedding_response = [] input_tokens: int = 0 for idx, element in enumerate(_predictions): - embedding = element["embeddings"] embedding_response.append( { diff --git a/litellm/llms/vertex_ai/vertex_model_garden/main.py b/litellm/llms/vertex_ai/vertex_model_garden/main.py index 7b54d4e34b..1c57096734 100644 --- a/litellm/llms/vertex_ai/vertex_model_garden/main.py +++ b/litellm/llms/vertex_ai/vertex_model_garden/main.py @@ -76,7 +76,6 @@ class VertexAIModelGardenModels(VertexBase): VertexLLM, ) except Exception as e: - raise VertexAIError( status_code=400, message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. 
Got error: {e}""", diff --git a/litellm/llms/watsonx/chat/transformation.py b/litellm/llms/watsonx/chat/transformation.py index f253da6f5b..2ff1dd6a68 100644 --- a/litellm/llms/watsonx/chat/transformation.py +++ b/litellm/llms/watsonx/chat/transformation.py @@ -15,7 +15,6 @@ from ..common_utils import IBMWatsonXMixin class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig): - def get_supported_openai_params(self, model: str) -> List: return [ "temperature", # equivalent to temperature diff --git a/litellm/main.py b/litellm/main.py index 1e6c36aa6c..f69454aaad 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -946,14 +946,16 @@ def completion( # type: ignore # noqa: PLR0915 ## PROMPT MANAGEMENT HOOKS ## if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None: - model, messages, optional_params = ( - litellm_logging_obj.get_chat_completion_prompt( - model=model, - messages=messages, - non_default_params=non_default_params, - prompt_id=prompt_id, - prompt_variables=prompt_variables, - ) + ( + model, + messages, + optional_params, + ) = litellm_logging_obj.get_chat_completion_prompt( + model=model, + messages=messages, + non_default_params=non_default_params, + prompt_id=prompt_id, + prompt_variables=prompt_variables, ) try: @@ -1246,7 +1248,6 @@ def completion( # type: ignore # noqa: PLR0915 optional_params["max_retries"] = max_retries if litellm.AzureOpenAIO1Config().is_o_series_model(model=model): - ## LOAD CONFIG - if set config = litellm.AzureOpenAIO1Config.get_config() for k, v in config.items(): @@ -2654,9 +2655,9 @@ def completion( # type: ignore # noqa: PLR0915 "aws_region_name" not in optional_params or optional_params["aws_region_name"] is None ): - optional_params["aws_region_name"] = ( - aws_bedrock_client.meta.region_name - ) + optional_params[ + "aws_region_name" + ] = aws_bedrock_client.meta.region_name bedrock_route = BedrockModelInfo.get_bedrock_route(model) if bedrock_route == "converse": @@ -4362,9 +4363,9 @@ def adapter_completion( new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs) response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore - translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = ( - None - ) + translated_response: Optional[ + Union[BaseModel, AdapterCompletionStreamWrapper] + ] = None if isinstance(response, ModelResponse): translated_response = translation_obj.translate_completion_output_params( response=response @@ -4436,13 +4437,16 @@ async def amoderation( optional_params = GenericLiteLLMParams(**kwargs) try: - model, _custom_llm_provider, _dynamic_api_key, _dynamic_api_base = ( - litellm.get_llm_provider( - model=model or "", - custom_llm_provider=custom_llm_provider, - api_base=optional_params.api_base, - api_key=optional_params.api_key, - ) + ( + model, + _custom_llm_provider, + _dynamic_api_key, + _dynamic_api_base, + ) = litellm.get_llm_provider( + model=model or "", + custom_llm_provider=custom_llm_provider, + api_base=optional_params.api_base, + api_key=optional_params.api_key, ) except litellm.BadRequestError: # `model` is optional field for moderation - get_llm_provider will throw BadRequestError if model is not set / not recognized @@ -5405,7 +5409,6 @@ def speech( # noqa: PLR0915 litellm_params=litellm_params_dict, ) elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta": - generic_optional_params = GenericLiteLLMParams(**kwargs) api_base = generic_optional_params.api_base or "" @@ -5460,7 
+5463,6 @@ def speech( # noqa: PLR0915 async def ahealth_check_wildcard_models( model: str, custom_llm_provider: str, model_params: dict ) -> dict: - # this is a wildcard model, we need to pick a random model from the provider cheapest_models = pick_cheapest_chat_models_from_llm_provider( custom_llm_provider=custom_llm_provider, n=3 @@ -5783,9 +5785,9 @@ def stream_chunk_builder( # noqa: PLR0915 ] if len(content_chunks) > 0: - response["choices"][0]["message"]["content"] = ( - processor.get_combined_content(content_chunks) - ) + response["choices"][0]["message"][ + "content" + ] = processor.get_combined_content(content_chunks) audio_chunks = [ chunk diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index e7d6bec004..16b45f3837 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -38,7 +38,7 @@ from .types_utils.utils import get_instance_fn, validate_custom_validate_return_ if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any @@ -615,9 +615,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): allowed_cache_controls: Optional[list] = [] config: Optional[dict] = {} permissions: Optional[dict] = {} - model_max_budget: Optional[dict] = ( - {} - ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} + model_max_budget: Optional[ + dict + ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} model_config = ConfigDict(protected_namespaces=()) model_rpm_limit: Optional[dict] = None @@ -873,12 +873,12 @@ class NewCustomerRequest(BudgetNewRequest): alias: Optional[str] = None # human-friendly alias blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model @model_validator(mode="before") @classmethod @@ -900,12 +900,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model class DeleteCustomerRequest(LiteLLMPydanticObjectBase): @@ -1040,9 +1040,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase): callback_name: str - callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( - "success_and_failure" - ) + callback_type: Optional[ + Literal["success", "failure", "success_and_failure"] + ] = "success_and_failure" callback_vars: Dict[str, str] 
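
Several hunks in litellm/main.py above (for example the get_chat_completion_prompt and get_llm_provider call sites) switch to unpacking long multi-value returns through a parenthesized target list, which keeps each name on its own line under black while preserving the same runtime behavior. A minimal, self-contained sketch of that shape, using a hypothetical resolve_provider function in place of the real LiteLLM helpers:

from typing import Optional, Tuple


def resolve_provider(model: str) -> Tuple[str, str, Optional[str], Optional[str]]:
    # Hypothetical stand-in returning (model, provider, api_key, api_base).
    return model, "openai", None, None


(
    model,
    custom_llm_provider,
    dynamic_api_key,
    dynamic_api_base,
) = resolve_provider("gpt-4o-mini")

print(model, custom_llm_provider, dynamic_api_key, dynamic_api_base)

This is only an illustration of the formatting pattern; resolve_provider and its return shape are invented for the sketch and are not part of the patch.
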
@model_validator(mode="before") @@ -1299,9 +1299,9 @@ class ConfigList(LiteLLMPydanticObjectBase): stored_in_db: Optional[bool] field_default_value: Any premium_field: bool = False - nested_fields: Optional[List[FieldDetail]] = ( - None # For nested dictionary or Pydantic fields - ) + nested_fields: Optional[ + List[FieldDetail] + ] = None # For nested dictionary or Pydantic fields class ConfigGeneralSettings(LiteLLMPydanticObjectBase): @@ -1567,9 +1567,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): budget_id: Optional[str] = None created_at: datetime updated_at: datetime - user: Optional[Any] = ( - None # You might want to replace 'Any' with a more specific type if available - ) + user: Optional[ + Any + ] = None # You might want to replace 'Any' with a more specific type if available litellm_budget_table: Optional[LiteLLM_BudgetTable] = None model_config = ConfigDict(protected_namespaces=()) @@ -2306,9 +2306,9 @@ class TeamModelDeleteRequest(BaseModel): # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str - max_budget_in_organization: Optional[float] = ( - None # Users max budget within the organization - ) + max_budget_in_organization: Optional[ + float + ] = None # Users max budget within the organization class OrganizationMemberDeleteRequest(MemberDeleteRequest): @@ -2497,9 +2497,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase): Maps provider names to their budget configs. """ - providers: Dict[str, ProviderBudgetResponseObject] = ( - {} - ) # Dictionary mapping provider names to their budget configurations + providers: Dict[ + str, ProviderBudgetResponseObject + ] = {} # Dictionary mapping provider names to their budget configurations class ProxyStateVariables(TypedDict): @@ -2627,9 +2627,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): enforce_rbac: bool = False roles_jwt_field: Optional[str] = None # v2 on role mappings role_mappings: Optional[List[RoleMapping]] = None - object_id_jwt_field: Optional[str] = ( - None # can be either user / team, inferred from the role mapping - ) + object_id_jwt_field: Optional[ + str + ] = None # can be either user / team, inferred from the role mapping scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False enforce_team_based_model_access: bool = False diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index efbfe8d90c..ddd1008bd0 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -11,7 +11,7 @@ Run checks for: import asyncio import re import time -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast from fastapi import Request, status from pydantic import BaseModel @@ -49,7 +49,7 @@ from .auth_checks_organization import organization_role_based_access_check if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any @@ -551,7 +551,6 @@ def _get_role_based_permissions( return None for role_based_permission in role_based_permissions: - if role_based_permission.role == rbac_role: return getattr(role_based_permission, key) @@ -867,7 +866,6 @@ async def _get_team_object_from_cache( proxy_logging_obj is not None and proxy_logging_obj.internal_usage_cache.dual_cache ): - cached_team_obj = ( await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache( key=key, 
parent_otel_span=parent_otel_span @@ -1202,7 +1200,6 @@ async def can_user_call_model( llm_router: Optional[Router], user_object: Optional[LiteLLM_UserTable], ) -> Literal[True]: - if user_object is None: return True diff --git a/litellm/proxy/auth/auth_checks_organization.py b/litellm/proxy/auth/auth_checks_organization.py index 3da3d8ddd1..e96a5c61fc 100644 --- a/litellm/proxy/auth/auth_checks_organization.py +++ b/litellm/proxy/auth/auth_checks_organization.py @@ -44,9 +44,10 @@ def organization_role_based_access_check( # Checks if route is an Org Admin Only Route if route in LiteLLMRoutes.org_admin_only_routes.value: - _user_organizations, _user_organization_role_mapping = ( - get_user_organization_info(user_object) - ) + ( + _user_organizations, + _user_organization_role_mapping, + ) = get_user_organization_info(user_object) if user_object.organization_memberships is None: raise ProxyException( @@ -84,9 +85,10 @@ def organization_role_based_access_check( ) elif route == "/team/new": # if user is part of multiple teams, then they need to specify the organization_id - _user_organizations, _user_organization_role_mapping = ( - get_user_organization_info(user_object) - ) + ( + _user_organizations, + _user_organization_role_mapping, + ) = get_user_organization_info(user_object) if ( user_object.organization_memberships is not None and len(user_object.organization_memberships) > 0 diff --git a/litellm/proxy/auth/auth_exception_handler.py b/litellm/proxy/auth/auth_exception_handler.py index 05797381c6..7c97655141 100644 --- a/litellm/proxy/auth/auth_exception_handler.py +++ b/litellm/proxy/auth/auth_exception_handler.py @@ -3,7 +3,7 @@ Handles Authentication Errors """ import asyncio -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from fastapi import HTTPException, Request, status @@ -17,13 +17,12 @@ from litellm.types.services import ServiceTypes if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any class UserAPIKeyAuthExceptionHandler: - @staticmethod async def _handle_authentication_error( e: Exception, diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index 2c4b122d3a..0200457ef9 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -14,7 +14,6 @@ from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS def _get_request_ip_address( request: Request, use_x_forwarded_for: Optional[bool] = False ) -> Optional[str]: - client_ip = None if use_x_forwarded_for is True and "x-forwarded-for" in request.headers: client_ip = request.headers["x-forwarded-for"] @@ -469,7 +468,6 @@ def should_run_auth_on_pass_through_provider_route(route: str) -> bool: from litellm.proxy.proxy_server import general_settings, premium_user if premium_user is not True: - return False # premium use has opted into using client credentials diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index cc41050198..783c2f1553 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -166,7 +166,6 @@ class JWTHandler: self, token: dict, default_value: Optional[str] ) -> Optional[str]: try: - if self.litellm_jwtauth.end_user_id_jwt_field is not None: user_id = token[self.litellm_jwtauth.end_user_id_jwt_field] else: @@ -339,7 +338,6 @@ class JWTHandler: return scopes async def get_public_key(self, kid: Optional[str]) -> dict: - keys_url = os.getenv("JWT_PUBLIC_KEY_URL") if 
keys_url is None: @@ -348,7 +346,6 @@ class JWTHandler: keys_url_list = [url.strip() for url in keys_url.split(",")] for key_url in keys_url_list: - cache_key = f"litellm_jwt_auth_keys_{key_url}" cached_keys = await self.user_api_key_cache.async_get_cache(cache_key) @@ -923,7 +920,6 @@ class JWTAuthManager: object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None) if rbac_role and object_id: - if rbac_role == LitellmUserRoles.TEAM: team_id = object_id elif rbac_role == LitellmUserRoles.INTERNAL_USER: @@ -940,15 +936,16 @@ class JWTAuthManager: ## SPECIFIC TEAM ID if not team_id: - team_id, team_object = ( - await JWTAuthManager.find_and_validate_specific_team_id( - jwt_handler, - jwt_valid_token, - prisma_client, - user_api_key_cache, - parent_otel_span, - proxy_logging_obj, - ) + ( + team_id, + team_object, + ) = await JWTAuthManager.find_and_validate_specific_team_id( + jwt_handler, + jwt_valid_token, + prisma_client, + user_api_key_cache, + parent_otel_span, + proxy_logging_obj, ) if not team_object and not team_id: diff --git a/litellm/proxy/auth/litellm_license.py b/litellm/proxy/auth/litellm_license.py index 67ec91f51a..d962aad2c0 100644 --- a/litellm/proxy/auth/litellm_license.py +++ b/litellm/proxy/auth/litellm_license.py @@ -45,7 +45,6 @@ class LicenseCheck: verbose_proxy_logger.error(f"Error reading public key: {str(e)}") def _verify(self, license_str: str) -> bool: - verbose_proxy_logger.debug( "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format( self.base_url, license_str diff --git a/litellm/proxy/auth/model_checks.py b/litellm/proxy/auth/model_checks.py index a48ef6ae87..f0f730138f 100644 --- a/litellm/proxy/auth/model_checks.py +++ b/litellm/proxy/auth/model_checks.py @@ -178,7 +178,6 @@ def _get_wildcard_models( all_wildcard_models = [] for model in unique_models: if _check_wildcard_routing(model=model): - if ( return_wildcard_routes ): # will add the wildcard route to the list eg: anthropic/*. 
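
The Span = Union[_Span, Any] change repeated in the proxy auth modules patched above (auth_checks.py and auth_exception_handler.py) keeps the alias meaningful to mypy when opentelemetry is installed while degrading to Any at runtime, so annotated code imports cleanly without a hard opentelemetry dependency. A minimal sketch of the pattern, with a hypothetical trace_request helper added only to show the alias in an annotation:

from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
    # Imported only for static type checking; never executed at runtime.
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    # At runtime the alias collapses to Any, so no opentelemetry import is needed.
    Span = Any


def trace_request(parent_otel_span: Optional[Span] = None) -> None:
    # Hypothetical helper: the annotation references Span without requiring the package.
    if parent_otel_span is not None:
        print("received a parent span")

trace_request itself is an assumption for the sketch; the conditional alias is the pattern the patch applies.
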
diff --git a/litellm/proxy/auth/route_checks.py b/litellm/proxy/auth/route_checks.py index 8f956abb72..41529512b6 100644 --- a/litellm/proxy/auth/route_checks.py +++ b/litellm/proxy/auth/route_checks.py @@ -16,7 +16,6 @@ from .auth_checks_organization import _user_is_org_admin class RouteChecks: - @staticmethod def non_proxy_admin_allowed_routes_check( user_obj: Optional[LiteLLM_UserTable], @@ -67,7 +66,6 @@ class RouteChecks: and getattr(valid_token, "permissions", None) is not None and "get_spend_routes" in getattr(valid_token, "permissions", []) ): - pass elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value: if RouteChecks.is_llm_api_route(route=route): @@ -80,7 +78,6 @@ class RouteChecks: ): # the Admin Viewer is only allowed to call /user/update for their own user_id and can only update if route == "/user/update": - # Check the Request params are valid for PROXY_ADMIN_VIEW_ONLY if request_data is not None and isinstance(request_data, dict): _params_updated = request_data.keys() diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index b58353bf05..eddbf4e0d9 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -206,7 +206,6 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str: def get_model_from_request(request_data: dict, route: str) -> Optional[str]: - # First try to get model from request_data model = request_data.get("model") @@ -229,7 +228,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 azure_apim_header: Optional[str], request_data: dict, ) -> UserAPIKeyAuth: - from litellm.proxy.proxy_server import ( general_settings, jwt_handler, @@ -251,7 +249,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 valid_token: Optional[UserAPIKeyAuth] = None try: - # get the request body await pre_db_read_auth_checks( @@ -514,23 +511,23 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 proxy_logging_obj=proxy_logging_obj, ) if _end_user_object is not None: - end_user_params["allowed_model_region"] = ( - _end_user_object.allowed_model_region - ) + end_user_params[ + "allowed_model_region" + ] = _end_user_object.allowed_model_region if _end_user_object.litellm_budget_table is not None: budget_info = _end_user_object.litellm_budget_table if budget_info.tpm_limit is not None: - end_user_params["end_user_tpm_limit"] = ( - budget_info.tpm_limit - ) + end_user_params[ + "end_user_tpm_limit" + ] = budget_info.tpm_limit if budget_info.rpm_limit is not None: - end_user_params["end_user_rpm_limit"] = ( - budget_info.rpm_limit - ) + end_user_params[ + "end_user_rpm_limit" + ] = budget_info.rpm_limit if budget_info.max_budget is not None: - end_user_params["end_user_max_budget"] = ( - budget_info.max_budget - ) + end_user_params[ + "end_user_max_budget" + ] = budget_info.max_budget except Exception as e: if isinstance(e, litellm.BudgetExceededError): raise e @@ -801,7 +798,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 # Check 3. 
Check if user is in their team budget if valid_token.team_member_spend is not None: if prisma_client is not None: - _cache_key = f"{valid_token.team_id}_{valid_token.user_id}" team_member_info = await user_api_key_cache.async_get_cache( diff --git a/litellm/proxy/common_utils/encrypt_decrypt_utils.py b/litellm/proxy/common_utils/encrypt_decrypt_utils.py index ec9279a089..3452734867 100644 --- a/litellm/proxy/common_utils/encrypt_decrypt_utils.py +++ b/litellm/proxy/common_utils/encrypt_decrypt_utils.py @@ -21,7 +21,6 @@ def _get_salt_key(): def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None): - signing_key = new_encryption_key or _get_salt_key() try: @@ -41,7 +40,6 @@ def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None): def decrypt_value_helper(value: str): - signing_key = _get_salt_key() try: diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index 5736ee2152..7220ccaa65 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -142,10 +142,10 @@ def check_file_size_under_limit( if llm_router is not None and request_data["model"] in router_model_names: try: - deployment: Optional[Deployment] = ( - llm_router.get_deployment_by_model_group_name( - model_group_name=request_data["model"] - ) + deployment: Optional[ + Deployment + ] = llm_router.get_deployment_by_model_group_name( + model_group_name=request_data["model"] ) if ( deployment diff --git a/litellm/proxy/db/log_db_metrics.py b/litellm/proxy/db/log_db_metrics.py index 9bd3350793..5c79515532 100644 --- a/litellm/proxy/db/log_db_metrics.py +++ b/litellm/proxy/db/log_db_metrics.py @@ -36,7 +36,6 @@ def log_db_metrics(func): @wraps(func) async def wrapper(*args, **kwargs): - start_time: datetime = datetime.now() try: diff --git a/litellm/proxy/db/redis_update_buffer.py b/litellm/proxy/db/redis_update_buffer.py index f77c839aaf..f98fc9300f 100644 --- a/litellm/proxy/db/redis_update_buffer.py +++ b/litellm/proxy/db/redis_update_buffer.py @@ -43,9 +43,9 @@ class RedisUpdateBuffer: """ from litellm.proxy.proxy_server import general_settings - _use_redis_transaction_buffer: Optional[Union[bool, str]] = ( - general_settings.get("use_redis_transaction_buffer", False) - ) + _use_redis_transaction_buffer: Optional[ + Union[bool, str] + ] = general_settings.get("use_redis_transaction_buffer", False) if isinstance(_use_redis_transaction_buffer, str): _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer) if _use_redis_transaction_buffer is None: @@ -78,15 +78,13 @@ class RedisUpdateBuffer: "redis_cache is None, skipping store_in_memory_spend_updates_in_redis" ) return - db_spend_update_transactions: DBSpendUpdateTransactions = ( - DBSpendUpdateTransactions( - user_list_transactions=prisma_client.user_list_transactions, - end_user_list_transactions=prisma_client.end_user_list_transactions, - key_list_transactions=prisma_client.key_list_transactions, - team_list_transactions=prisma_client.team_list_transactions, - team_member_list_transactions=prisma_client.team_member_list_transactions, - org_list_transactions=prisma_client.org_list_transactions, - ) + db_spend_update_transactions: DBSpendUpdateTransactions = DBSpendUpdateTransactions( + user_list_transactions=prisma_client.user_list_transactions, + end_user_list_transactions=prisma_client.end_user_list_transactions, + key_list_transactions=prisma_client.key_list_transactions, + 
team_list_transactions=prisma_client.team_list_transactions, + team_member_list_transactions=prisma_client.team_member_list_transactions, + org_list_transactions=prisma_client.org_list_transactions, ) # only store in redis if there are any updates to commit diff --git a/litellm/proxy/guardrails/guardrail_helpers.py b/litellm/proxy/guardrails/guardrail_helpers.py index c351f9f762..e970311460 100644 --- a/litellm/proxy/guardrails/guardrail_helpers.py +++ b/litellm/proxy/guardrails/guardrail_helpers.py @@ -45,7 +45,6 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b # v1 implementation of this if isinstance(request_guardrails, dict): - # get guardrail configs from `init_guardrails.py` # for all requested guardrails -> get their associated callbacks for _guardrail_name, should_run in request_guardrails.items(): diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index 7686fba7cf..5c6b53be25 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -192,7 +192,6 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): async def make_bedrock_api_request( self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None ): - credentials, aws_region_name = self._load_credentials() bedrock_request_data: dict = dict( self.convert_to_bedrock_format( diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py index 5d3b8be334..2dd8a3154a 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py @@ -148,9 +148,9 @@ class lakeraAI_Moderation(CustomGuardrail): text = "" _json_data: str = "" if "messages" in data and isinstance(data["messages"], list): - prompt_injection_obj: Optional[GuardrailItem] = ( - litellm.guardrail_name_config_map.get("prompt_injection") - ) + prompt_injection_obj: Optional[ + GuardrailItem + ] = litellm.guardrail_name_config_map.get("prompt_injection") if prompt_injection_obj is not None: enabled_roles = prompt_injection_obj.enabled_roles else: diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index 86d2c8b25a..0c7d2a1fe6 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -95,9 +95,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): presidio_analyzer_api_base: Optional[str] = None, presidio_anonymizer_api_base: Optional[str] = None, ): - self.presidio_analyzer_api_base: Optional[str] = ( - presidio_analyzer_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None) # type: ignore - ) + self.presidio_analyzer_api_base: Optional[ + str + ] = presidio_analyzer_api_base or get_secret( + "PRESIDIO_ANALYZER_API_BASE", None + ) # type: ignore self.presidio_anonymizer_api_base: Optional[ str ] = presidio_anonymizer_api_base or litellm.get_secret( @@ -168,7 +170,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): async with session.post( analyze_url, json=analyze_payload ) as response: - analyze_results = await response.json() # Make the second request to /anonymize @@ -228,7 +229,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): """ try: - content_safety = data.get("content_safety", None) verbose_proxy_logger.debug("content_safety: %s", content_safety) presidio_config = 
self.get_presidio_settings_from_request_data(data) diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py index 15a9bc1ba8..e06366d02b 100644 --- a/litellm/proxy/hooks/dynamic_rate_limiter.py +++ b/litellm/proxy/hooks/dynamic_rate_limiter.py @@ -71,7 +71,6 @@ class DynamicRateLimiterCache: class _PROXY_DynamicRateLimitHandler(CustomLogger): - # Class variables or attributes def __init__(self, internal_usage_cache: DualCache): self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache) @@ -121,12 +120,13 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger): active_projects = await self.internal_usage_cache.async_get_cache( model=model ) - current_model_tpm, current_model_rpm = ( - await self.llm_router.get_model_group_usage(model_group=model) - ) - model_group_info: Optional[ModelGroupInfo] = ( - self.llm_router.get_model_group_info(model_group=model) - ) + ( + current_model_tpm, + current_model_rpm, + ) = await self.llm_router.get_model_group_usage(model_group=model) + model_group_info: Optional[ + ModelGroupInfo + ] = self.llm_router.get_model_group_info(model_group=model) total_model_tpm: Optional[int] = None total_model_rpm: Optional[int] = None if model_group_info is not None: @@ -210,10 +210,14 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger): key_priority: Optional[str] = user_api_key_dict.metadata.get( "priority", None ) - available_tpm, available_rpm, model_tpm, model_rpm, active_projects = ( - await self.check_available_usage( - model=data["model"], priority=key_priority - ) + ( + available_tpm, + available_rpm, + model_tpm, + model_rpm, + active_projects, + ) = await self.check_available_usage( + model=data["model"], priority=key_priority ) ### CHECK TPM ### if available_tpm is not None and available_tpm == 0: @@ -267,21 +271,25 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger): key_priority: Optional[str] = user_api_key_dict.metadata.get( "priority", None ) - available_tpm, available_rpm, model_tpm, model_rpm, active_projects = ( - await self.check_available_usage( - model=model_info["model_name"], priority=key_priority - ) - ) - response._hidden_params["additional_headers"] = ( - { # Add additional response headers - easier debugging - "x-litellm-model_group": model_info["model_name"], - "x-ratelimit-remaining-litellm-project-tokens": available_tpm, - "x-ratelimit-remaining-litellm-project-requests": available_rpm, - "x-ratelimit-remaining-model-tokens": model_tpm, - "x-ratelimit-remaining-model-requests": model_rpm, - "x-ratelimit-current-active-projects": active_projects, - } + ( + available_tpm, + available_rpm, + model_tpm, + model_rpm, + active_projects, + ) = await self.check_available_usage( + model=model_info["model_name"], priority=key_priority ) + response._hidden_params[ + "additional_headers" + ] = { # Add additional response headers - easier debugging + "x-litellm-model_group": model_info["model_name"], + "x-ratelimit-remaining-litellm-project-tokens": available_tpm, + "x-ratelimit-remaining-litellm-project-requests": available_rpm, + "x-ratelimit-remaining-model-tokens": model_tpm, + "x-ratelimit-remaining-model-requests": model_rpm, + "x-ratelimit-current-active-projects": active_projects, + } return response return await super().async_post_call_success_hook( diff --git a/litellm/proxy/hooks/key_management_event_hooks.py b/litellm/proxy/hooks/key_management_event_hooks.py index 2030cb2a45..c2c4f0669f 100644 --- a/litellm/proxy/hooks/key_management_event_hooks.py +++ 
b/litellm/proxy/hooks/key_management_event_hooks.py @@ -28,7 +28,6 @@ LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/" class KeyManagementEventHooks: - @staticmethod async def async_key_generated_hook( data: GenerateKeyRequest, diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 06f3b6afe5..242c013d67 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache - Span = _Span + Span = Union[_Span, Any] InternalUsageCache = _InternalUsageCache else: Span = Any @@ -201,7 +201,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): if rpm_limit is None: rpm_limit = sys.maxsize - values_to_update_in_cache: List[Tuple[Any, Any]] = ( + values_to_update_in_cache: List[ + Tuple[Any, Any] + ] = ( [] ) # values that need to get updated in cache, will run a batch_set_cache after this function @@ -678,9 +680,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): try: self.print_verbose("Inside Max Parallel Request Failure Hook") - litellm_parent_otel_span: Union[Span, None] = ( - _get_parent_otel_span_from_kwargs(kwargs=kwargs) - ) + litellm_parent_otel_span: Union[ + Span, None + ] = _get_parent_otel_span_from_kwargs(kwargs=kwargs) _metadata = kwargs["litellm_params"].get("metadata", {}) or {} global_max_parallel_requests = _metadata.get( "global_max_parallel_requests", None @@ -807,11 +809,11 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): current_minute = datetime.now().strftime("%M") precise_minute = f"{current_date}-{current_hour}-{current_minute}" request_count_api_key = f"{api_key}::{precise_minute}::request_count" - current: Optional[CurrentItemRateLimit] = ( - await self.internal_usage_cache.async_get_cache( - key=request_count_api_key, - litellm_parent_otel_span=user_api_key_dict.parent_otel_span, - ) + current: Optional[ + CurrentItemRateLimit + ] = await self.internal_usage_cache.async_get_cache( + key=request_count_api_key, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, ) key_remaining_rpm_limit: Optional[int] = None @@ -843,15 +845,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): _additional_headers = _hidden_params.get("additional_headers", {}) or {} if key_remaining_rpm_limit is not None: - _additional_headers["x-ratelimit-remaining-requests"] = ( - key_remaining_rpm_limit - ) + _additional_headers[ + "x-ratelimit-remaining-requests" + ] = key_remaining_rpm_limit if key_rpm_limit is not None: _additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit if key_remaining_tpm_limit is not None: - _additional_headers["x-ratelimit-remaining-tokens"] = ( - key_remaining_tpm_limit - ) + _additional_headers[ + "x-ratelimit-remaining-tokens" + ] = key_remaining_tpm_limit if key_tpm_limit is not None: _additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py index b1b2bbee5c..b8fa8466a3 100644 --- a/litellm/proxy/hooks/prompt_injection_detection.py +++ b/litellm/proxy/hooks/prompt_injection_detection.py @@ -196,7 +196,6 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger): return data except HTTPException as e: - if ( e.status_code == 400 and isinstance(e.detail, dict) diff --git 
a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index f205b0146f..39c1eeace9 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -51,10 +51,10 @@ class _ProxyDBLogger(CustomLogger): ) _metadata["user_api_key"] = user_api_key_dict.api_key _metadata["status"] = "failure" - _metadata["error_information"] = ( - StandardLoggingPayloadSetup.get_error_information( - original_exception=original_exception, - ) + _metadata[ + "error_information" + ] = StandardLoggingPayloadSetup.get_error_information( + original_exception=original_exception, ) existing_metadata: dict = request_data.get("metadata", None) or {} diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index ece5ecf4b7..6427be5a6e 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -346,11 +346,11 @@ class LiteLLMProxyRequestSetup: ## KEY-LEVEL SPEND LOGS / TAGS if "tags" in key_metadata and key_metadata["tags"] is not None: - data[_metadata_variable_name]["tags"] = ( - LiteLLMProxyRequestSetup._merge_tags( - request_tags=data[_metadata_variable_name].get("tags"), - tags_to_add=key_metadata["tags"], - ) + data[_metadata_variable_name][ + "tags" + ] = LiteLLMProxyRequestSetup._merge_tags( + request_tags=data[_metadata_variable_name].get("tags"), + tags_to_add=key_metadata["tags"], ) if "spend_logs_metadata" in key_metadata and isinstance( key_metadata["spend_logs_metadata"], dict @@ -556,9 +556,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915 data[_metadata_variable_name]["litellm_api_version"] = version if general_settings is not None: - data[_metadata_variable_name]["global_max_parallel_requests"] = ( - general_settings.get("global_max_parallel_requests", None) - ) + data[_metadata_variable_name][ + "global_max_parallel_requests" + ] = general_settings.get("global_max_parallel_requests", None) ### KEY-LEVEL Controls key_metadata = user_api_key_dict.metadata diff --git a/litellm/proxy/management_endpoints/budget_management_endpoints.py b/litellm/proxy/management_endpoints/budget_management_endpoints.py index 20aa1c6bbf..65b0156afe 100644 --- a/litellm/proxy/management_endpoints/budget_management_endpoints.py +++ b/litellm/proxy/management_endpoints/budget_management_endpoints.py @@ -197,7 +197,6 @@ async def budget_settings( for field_name, field_info in BudgetNewRequest.model_fields.items(): if field_name in allowed_args: - _stored_in_db = True _response_obj = ConfigList( diff --git a/litellm/proxy/management_endpoints/common_utils.py b/litellm/proxy/management_endpoints/common_utils.py index d80a06c597..550ff44616 100644 --- a/litellm/proxy/management_endpoints/common_utils.py +++ b/litellm/proxy/management_endpoints/common_utils.py @@ -16,7 +16,6 @@ def _is_user_team_admin( if ( member.user_id is not None and member.user_id == user_api_key_dict.user_id ) and member.role == "admin": - return True return False diff --git a/litellm/proxy/management_endpoints/customer_endpoints.py b/litellm/proxy/management_endpoints/customer_endpoints.py index 976ff8581f..1f6f846bc7 100644 --- a/litellm/proxy/management_endpoints/customer_endpoints.py +++ b/litellm/proxy/management_endpoints/customer_endpoints.py @@ -230,7 +230,6 @@ async def new_end_user( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) try: - ## VALIDATION ## if data.default_model is not None: if llm_router is None: diff --git 
a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 79de6da1fd..90444013a8 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -82,9 +82,9 @@ def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> d data_json["user_id"] = str(uuid.uuid4()) auto_create_key = data_json.pop("auto_create_key", True) if auto_create_key is False: - data_json["table_name"] = ( - "user" # only create a user, don't create key if 'auto_create_key' set to False - ) + data_json[ + "table_name" + ] = "user" # only create a user, don't create key if 'auto_create_key' set to False is_internal_user = False if data.user_role and data.user_role.is_internal_user_role: @@ -370,7 +370,6 @@ async def ui_get_available_role( _data_to_return = {} for role in LitellmUserRoles: - # We only show a subset of roles on UI if role in [ LitellmUserRoles.PROXY_ADMIN, @@ -652,9 +651,9 @@ def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> di "budget_duration" not in non_default_values ): # applies internal user limits, if user role updated if is_internal_user and litellm.internal_user_budget_duration is not None: - non_default_values["budget_duration"] = ( - litellm.internal_user_budget_duration - ) + non_default_values[ + "budget_duration" + ] = litellm.internal_user_budget_duration duration_s = duration_in_seconds( duration=non_default_values["budget_duration"] ) @@ -965,13 +964,13 @@ async def get_users( "in": user_id_list, # Now passing a list of strings as required by Prisma } - users: Optional[List[LiteLLM_UserTable]] = ( - await prisma_client.db.litellm_usertable.find_many( - where=where_conditions, - skip=skip, - take=page_size, - order={"created_at": "desc"}, - ) + users: Optional[ + List[LiteLLM_UserTable] + ] = await prisma_client.db.litellm_usertable.find_many( + where=where_conditions, + skip=skip, + take=page_size, + order={"created_at": "desc"}, ) # Get total count of user rows @@ -1226,13 +1225,13 @@ async def ui_view_users( } # Query users with pagination and filters - users: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_usertable.find_many( - where=where_conditions, - skip=skip, - take=page_size, - order={"created_at": "desc"}, - ) + users: Optional[ + List[BaseModel] + ] = await prisma_client.db.litellm_usertable.find_many( + where=where_conditions, + skip=skip, + take=page_size, + order={"created_at": "desc"}, ) if not users: diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 9141d9d14a..b0bf1fb619 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -227,7 +227,6 @@ def _personal_key_membership_check( def _personal_key_generation_check( user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest ): - if ( litellm.key_generation_settings is None or litellm.key_generation_settings.get("personal_key_generation") is None @@ -568,9 +567,9 @@ async def generate_key_fn( # noqa: PLR0915 request_type="key", **data_json, table_name="key" ) - response["soft_budget"] = ( - data.soft_budget - ) # include the user-input soft budget in the response + response[ + "soft_budget" + ] = data.soft_budget # include the user-input soft budget in the response response = GenerateKeyResponse(**response) @@ -1448,10 
+1447,10 @@ async def delete_verification_tokens( try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": tokens}} - ) + _keys_being_deleted: List[ + LiteLLM_VerificationToken + ] = await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} ) # Assuming 'db' is your Prisma Client instance @@ -1553,9 +1552,9 @@ async def _rotate_master_key( from litellm.proxy.proxy_server import proxy_config try: - models: Optional[List] = ( - await prisma_client.db.litellm_proxymodeltable.find_many() - ) + models: Optional[ + List + ] = await prisma_client.db.litellm_proxymodeltable.find_many() except Exception: models = None # 2. process model table @@ -1677,7 +1676,6 @@ async def regenerate_key_fn( Note: This is an Enterprise feature. It requires a premium license to use. """ try: - from litellm.proxy.proxy_server import ( hash_token, master_key, @@ -1824,7 +1822,6 @@ async def validate_key_list_check( key_alias: Optional[str], prisma_client: PrismaClient, ) -> Optional[LiteLLM_UserTable]: - if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value: return None @@ -1835,11 +1832,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[BaseModel] = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, - ) + complete_user_info_db_obj: Optional[ + BaseModel + ] = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, ) if complete_user_info_db_obj is None: @@ -1900,10 +1897,10 @@ async def get_admin_team_ids( if complete_user_info is None: return [] # Get all teams that user is an admin of - teams: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} - ) + teams: Optional[ + List[BaseModel] + ] = await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": complete_user_info.teams}} ) if teams is None: return [] diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 88245e36d1..0e8a9e4cc8 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -399,7 +399,6 @@ class ModelManagementAuthChecks: prisma_client: PrismaClient, premium_user: bool, ) -> Literal[True]: - ## Check team model auth if ( model_params.model_info is not None @@ -579,7 +578,6 @@ async def add_new_model( ) try: - if prisma_client is None: raise HTTPException( status_code=500, @@ -717,7 +715,6 @@ async def update_model( ) try: - if prisma_client is None: raise HTTPException( status_code=500, diff --git a/litellm/proxy/management_endpoints/organization_endpoints.py b/litellm/proxy/management_endpoints/organization_endpoints.py index c202043fbe..37de12a9d2 100644 --- a/litellm/proxy/management_endpoints/organization_endpoints.py +++ b/litellm/proxy/management_endpoints/organization_endpoints.py @@ -358,11 +358,11 @@ async def info_organization(organization_id: str): if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No db connected"}) - response: 
Optional[LiteLLM_OrganizationTableWithMembers] = ( - await prisma_client.db.litellm_organizationtable.find_unique( - where={"organization_id": organization_id}, - include={"litellm_budget_table": True, "members": True, "teams": True}, - ) + response: Optional[ + LiteLLM_OrganizationTableWithMembers + ] = await prisma_client.db.litellm_organizationtable.find_unique( + where={"organization_id": organization_id}, + include={"litellm_budget_table": True, "members": True, "teams": True}, ) if response is None: @@ -486,12 +486,13 @@ async def organization_member_add( updated_organization_memberships: List[LiteLLM_OrganizationMembershipTable] = [] for member in members: - updated_user, updated_organization_membership = ( - await add_member_to_organization( - member=member, - organization_id=data.organization_id, - prisma_client=prisma_client, - ) + ( + updated_user, + updated_organization_membership, + ) = await add_member_to_organization( + member=member, + organization_id=data.organization_id, + prisma_client=prisma_client, ) updated_users.append(updated_user) @@ -657,16 +658,16 @@ async def organization_member_update( }, data={"budget_id": budget_id}, ) - final_organization_membership: Optional[BaseModel] = ( - await prisma_client.db.litellm_organizationmembership.find_unique( - where={ - "user_id_organization_id": { - "user_id": data.user_id, - "organization_id": data.organization_id, - } - }, - include={"litellm_budget_table": True}, - ) + final_organization_membership: Optional[ + BaseModel + ] = await prisma_client.db.litellm_organizationmembership.find_unique( + where={ + "user_id_organization_id": { + "user_id": data.user_id, + "organization_id": data.organization_id, + } + }, + include={"litellm_budget_table": True}, ) if final_organization_membership is None: diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index f5bcc6ba11..842b5c8e75 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -506,12 +506,12 @@ async def update_team( updated_kv["model_id"] = _model_id updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv) - team_row: Optional[LiteLLM_TeamTable] = ( - await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, - data=updated_kv, - include={"litellm_model_table": True}, # type: ignore - ) + team_row: Optional[ + LiteLLM_TeamTable + ] = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, + data=updated_kv, + include={"litellm_model_table": True}, # type: ignore ) if team_row is None or team_row.team_id is None: @@ -1137,10 +1137,10 @@ async def delete_team( team_rows: List[LiteLLM_TeamTable] = [] for team_id in data.team_ids: try: - team_row_base: Optional[BaseModel] = ( - await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id} - ) + team_row_base: Optional[ + BaseModel + ] = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} ) if team_row_base is None: raise Exception @@ -1298,10 +1298,10 @@ async def team_info( ) try: - team_info: Optional[BaseModel] = ( - await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id} - ) + team_info: Optional[ + BaseModel + ] = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} ) if team_info is None: raise Exception diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py 
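
The organization and team endpoint hunks above repeatedly annotate an awaited Prisma query with an explicit Optional[...] type and then branch on None before using the row. A small self-contained sketch of that shape, assuming pydantic is available (as it is elsewhere in this patch) and using fake_find_unique as a stand-in for the real prisma_client.db...find_unique call:

import asyncio
from typing import Optional

from pydantic import BaseModel


class TeamRow(BaseModel):
    team_id: str
    team_alias: Optional[str] = None


async def fake_find_unique(team_id: str) -> Optional[TeamRow]:
    # Stand-in for prisma_client.db.litellm_teamtable.find_unique(...).
    return TeamRow(team_id=team_id) if team_id == "known-team" else None


async def team_info(team_id: str) -> TeamRow:
    # The explicit Optional annotation keeps mypy aware that the lookup can miss.
    team_row: Optional[TeamRow] = await fake_find_unique(team_id)
    if team_row is None:
        raise ValueError(f"team {team_id} not found")
    return team_row


if __name__ == "__main__":
    print(asyncio.run(team_info("known-team")))

TeamRow and fake_find_unique are invented for the sketch; only the annotate-then-check-None structure mirrors the patched endpoints.
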
index 86dec9fcaf..d38ff6b536 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -213,9 +213,9 @@ async def google_login(request: Request): # noqa: PLR0915 if state: redirect_params["state"] = state elif "okta" in generic_authorization_endpoint: - redirect_params["state"] = ( - uuid.uuid4().hex - ) # set state param for okta - required + redirect_params[ + "state" + ] = uuid.uuid4().hex # set state param for okta - required return await generic_sso.get_login_redirect(**redirect_params) # type: ignore elif ui_username is not None: # No Google, Microsoft SSO @@ -725,9 +725,9 @@ async def insert_sso_user( if user_defined_values.get("max_budget") is None: user_defined_values["max_budget"] = litellm.max_internal_user_budget if user_defined_values.get("budget_duration") is None: - user_defined_values["budget_duration"] = ( - litellm.internal_user_budget_duration - ) + user_defined_values[ + "budget_duration" + ] = litellm.internal_user_budget_duration if user_defined_values["user_role"] is None: user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 69a5cf9141..cb8e079b76 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -179,7 +179,7 @@ def _delete_api_key_from_cache(kwargs): user_api_key_cache.delete_cache(key=update_request.key) # delete key request - if isinstance(update_request, KeyRequest): + if isinstance(update_request, KeyRequest) and update_request.keys: for key in update_request.keys: user_api_key_cache.delete_cache(key=key) pass @@ -251,7 +251,6 @@ async def send_management_endpoint_alert( proxy_logging_obj is not None and proxy_logging_obj.slack_alerting_instance is not None ): - # Virtual Key Events if function_name in management_function_to_event_name: _event_name: AlertType = management_function_to_event_name[function_name] diff --git a/litellm/proxy/openai_files_endpoints/files_endpoints.py b/litellm/proxy/openai_files_endpoints/files_endpoints.py index ffbca91c69..e810ba026e 100644 --- a/litellm/proxy/openai_files_endpoints/files_endpoints.py +++ b/litellm/proxy/openai_files_endpoints/files_endpoints.py @@ -316,7 +316,6 @@ async def get_file_content( data: Dict = {} try: - # Include original request and headers in the data data = await add_litellm_data_to_request( data=data, diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index 51845956fc..d6f2a01712 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -25,7 +25,6 @@ else: class AnthropicPassthroughLoggingHandler: - @staticmethod def anthropic_passthrough_handler( httpx_response: httpx.Response, @@ -123,9 +122,9 @@ class AnthropicPassthroughLoggingHandler: litellm_model_response.id = logging_obj.litellm_call_id litellm_model_response.model = model logging_obj.model_call_details["model"] = model - logging_obj.model_call_details["custom_llm_provider"] = ( - litellm.LlmProviders.ANTHROPIC.value - ) + logging_obj.model_call_details[ + "custom_llm_provider" + ] = litellm.LlmProviders.ANTHROPIC.value return kwargs except Exception as e: 
verbose_proxy_logger.exception( diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index 9443563738..a20f39e65c 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -3,6 +3,7 @@ import re from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from urllib.parse import urlparse + import httpx import litellm @@ -222,7 +223,9 @@ class VertexPassthroughLoggingHandler: @staticmethod def _get_custom_llm_provider_from_url(url: str) -> str: parsed_url = urlparse(url) - if parsed_url.hostname and parsed_url.hostname.endswith("generativelanguage.googleapis.com"): + if parsed_url.hostname and parsed_url.hostname.endswith( + "generativelanguage.googleapis.com" + ): return litellm.LlmProviders.GEMINI.value return litellm.LlmProviders.VERTEX_AI.value diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index a13b0dc216..a6b1b3e614 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -373,7 +373,6 @@ async def pass_through_request( # noqa: PLR0915 litellm_call_id = str(uuid.uuid4()) url: Optional[httpx.URL] = None try: - from litellm.litellm_core_utils.litellm_logging import Logging from litellm.proxy.proxy_server import proxy_logging_obj @@ -384,7 +383,6 @@ async def pass_through_request( # noqa: PLR0915 ) if merge_query_params: - # Create a new URL with the merged query params url = url.copy_with( query=urlencode( @@ -771,7 +769,6 @@ def _is_streaming_response(response: httpx.Response) -> bool: async def initialize_pass_through_endpoints(pass_through_endpoints: list): - verbose_proxy_logger.debug("initializing pass through endpoints") from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes from litellm.proxy.proxy_server import app, premium_user diff --git a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py index 89cccfc071..a02cacc3cc 100644 --- a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py +++ b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py @@ -130,9 +130,9 @@ class PassthroughEndpointRouter: vertex_location=location, vertex_credentials=vertex_credentials, ) - self.deployment_key_to_vertex_credentials[deployment_key] = ( - vertex_pass_through_credentials - ) + self.deployment_key_to_vertex_credentials[ + deployment_key + ] = vertex_pass_through_credentials def _get_deployment_key( self, project_id: Optional[str], location: Optional[str] diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index b022bf1d25..2c11e4a2dd 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -21,7 +21,6 @@ from .types import EndpointType class PassThroughStreamingHandler: - @staticmethod async def chunk_processor( response: httpx.Response, diff --git a/litellm/proxy/prisma_migration.py b/litellm/proxy/prisma_migration.py index 22fa4da9de..c1e2220c15 100644 --- 
a/litellm/proxy/prisma_migration.py +++ b/litellm/proxy/prisma_migration.py @@ -9,8 +9,8 @@ import time sys.path.insert( 0, os.path.abspath("./") ) # Adds the parent directory to the system path -from litellm.secret_managers.aws_secret_manager import decrypt_env_var from litellm._logging import verbose_proxy_logger +from litellm.secret_managers.aws_secret_manager import decrypt_env_var if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True": ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV @@ -39,7 +39,9 @@ if not database_url: ) exit(1) else: - verbose_proxy_logger.info("Using existing DATABASE_URL environment variable") # Log existing DATABASE_URL + verbose_proxy_logger.info( + "Using existing DATABASE_URL environment variable" + ) # Log existing DATABASE_URL # Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations direct_url = os.getenv("DIRECT_URL") @@ -63,12 +65,18 @@ while retry_count < max_retries and exit_code != 0: # run prisma generate verbose_proxy_logger.info("Running 'prisma generate'...") result = subprocess.run(["prisma", "generate"], capture_output=True, text=True) - verbose_proxy_logger.info(f"'prisma generate' stdout: {result.stdout}") # Log stdout + verbose_proxy_logger.info( + f"'prisma generate' stdout: {result.stdout}" + ) # Log stdout exit_code = result.returncode if exit_code != 0: - verbose_proxy_logger.info(f"'prisma generate' failed with exit code {exit_code}.") - verbose_proxy_logger.error(f"'prisma generate' stderr: {result.stderr}") # Log stderr + verbose_proxy_logger.info( + f"'prisma generate' failed with exit code {exit_code}." + ) + verbose_proxy_logger.error( + f"'prisma generate' stderr: {result.stderr}" + ) # Log stderr # Run the Prisma db push command verbose_proxy_logger.info("Running 'prisma db push --accept-data-loss'...") @@ -79,14 +87,20 @@ while retry_count < max_retries and exit_code != 0: exit_code = result.returncode if exit_code != 0: - verbose_proxy_logger.info(f"'prisma db push' stderr: {result.stderr}") # Log stderr - verbose_proxy_logger.error(f"'prisma db push' failed with exit code {exit_code}.") + verbose_proxy_logger.info( + f"'prisma db push' stderr: {result.stderr}" + ) # Log stderr + verbose_proxy_logger.error( + f"'prisma db push' failed with exit code {exit_code}." + ) if retry_count < max_retries: verbose_proxy_logger.info("Retrying in 10 seconds...") time.sleep(10) if retry_count == max_retries and exit_code != 0: - verbose_proxy_logger.error(f"Unable to push database changes after {max_retries} retries.") + verbose_proxy_logger.error( + f"Unable to push database changes after {max_retries} retries." 
+ ) exit(1) verbose_proxy_logger.info("Database push successful!") diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 89c15f413d..f59d117181 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -18,6 +18,7 @@ from typing import ( List, Optional, Tuple, + Union, cast, get_args, get_origin, @@ -36,7 +37,7 @@ if TYPE_CHECKING: from litellm.integrations.opentelemetry import OpenTelemetry - Span = _Span + Span = Union[_Span, Any] else: Span = Any OpenTelemetry = Any @@ -763,9 +764,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter( dual_cache=user_api_key_cache ) litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter) -redis_usage_cache: Optional[RedisCache] = ( - None # redis cache used for tracking spend, tpm/rpm limits -) +redis_usage_cache: Optional[ + RedisCache +] = None # redis cache used for tracking spend, tpm/rpm limits user_custom_auth = None user_custom_key_generate = None user_custom_sso = None @@ -818,7 +819,6 @@ async def check_request_disconnection(request: Request, llm_api_call_task): while time.time() - start_time < 600: await asyncio.sleep(1) if await request.is_disconnected(): - # cancel the LLM API Call task if any passed - this is passed from individual providers # Example OpenAI, Azure, VertexAI etc llm_api_call_task.cancel() @@ -1092,9 +1092,9 @@ async def update_cache( # noqa: PLR0915 _id = "team_id:{}".format(team_id) try: # Fetch the existing cost for the given user - existing_spend_obj: Optional[LiteLLM_TeamTable] = ( - await user_api_key_cache.async_get_cache(key=_id) - ) + existing_spend_obj: Optional[ + LiteLLM_TeamTable + ] = await user_api_key_cache.async_get_cache(key=_id) if existing_spend_obj is None: # do nothing if team not in api key cache return @@ -1589,7 +1589,7 @@ class ProxyConfig: # users can pass os.environ/ variables on the proxy - we should read them from the env for key, value in cache_params.items(): - if type(value) is str and value.startswith("os.environ/"): + if isinstance(value, str) and value.startswith("os.environ/"): cache_params[key] = get_secret(value) ## to pass a complete url, or set ssl=True, etc. 
just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables @@ -1610,7 +1610,6 @@ class ProxyConfig: litellm.guardrail_name_config_map = guardrail_name_config_map elif key == "callbacks": - initialize_callbacks_on_proxy( value=value, premium_user=premium_user, @@ -2765,9 +2764,9 @@ async def initialize( # noqa: PLR0915 user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base if api_version: - os.environ["AZURE_API_VERSION"] = ( - api_version # set this for azure - litellm can read this from the env - ) + os.environ[ + "AZURE_API_VERSION" + ] = api_version # set this for azure - litellm can read this from the env if max_tokens: # model-specific param dynamic_config[user_model]["max_tokens"] = max_tokens if temperature: # model-specific param @@ -2810,7 +2809,6 @@ async def async_assistants_data_generator( try: time.time() async with response as chunk: - ### CALL HOOKS ### - modify outgoing data chunk = await proxy_logging_obj.async_post_call_streaming_hook( user_api_key_dict=user_api_key_dict, response=chunk @@ -4675,7 +4673,6 @@ async def get_thread( global proxy_logging_obj data: Dict = {} try: - # Include original request and headers in the data data = await add_litellm_data_to_request( data=data, @@ -6385,7 +6382,6 @@ async def alerting_settings( for field_name, field_info in SlackAlertingArgs.model_fields.items(): if field_name in allowed_args: - _stored_in_db: Optional[bool] = None if field_name in alerting_args_dict: _stored_in_db = True @@ -7333,7 +7329,6 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915 "success_callback" in updated_litellm_settings and "success_callback" in config["litellm_settings"] ): - # check both success callback are lists if isinstance( config["litellm_settings"]["success_callback"], list @@ -7588,7 +7583,6 @@ async def get_config_list( for field_name, field_info in ConfigGeneralSettings.model_fields.items(): if field_name in allowed_args: - ## HANDLE TYPED DICT typed_dict_type = allowed_args[field_name]["type"] @@ -7621,9 +7615,9 @@ async def get_config_list( hasattr(sub_field_info, "description") and sub_field_info.description is not None ): - nested_fields[idx].field_description = ( - sub_field_info.description - ) + nested_fields[ + idx + ].field_description = sub_field_info.description idx += 1 _stored_in_db = None diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index 4c0e22aef7..4690b6cbd8 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -286,7 +286,6 @@ async def get_global_activity( user_api_key_dict, start_date_obj, end_date_obj ) else: - sql_query = """ SELECT date_trunc('day', "startTime") AS date, @@ -453,7 +452,6 @@ async def get_global_activity_model( user_api_key_dict, start_date_obj, end_date_obj ) else: - sql_query = """ SELECT model_group, @@ -1096,7 +1094,6 @@ async def get_global_spend_report( start_date_obj, end_date_obj, team_id, customer_id, prisma_client ) if group_by == "team": - # first get data from spend logs -> SpendByModelApiKey # then read data from "SpendByModelApiKey" to format the response obj sql_query = """ @@ -1689,7 +1686,6 @@ async def ui_view_spend_logs( # noqa: PLR0915 ) try: - # Convert the date strings to datetime objects start_date_obj = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S").replace( tzinfo=timezone.utc @@ -2160,7 +2156,6 @@ async def 
global_spend_for_internal_user( code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) try: - user_id = user_api_key_dict.user_id if user_id is None: raise ValueError("/global/spend/logs Error: User ID is None") @@ -2293,7 +2288,6 @@ async def global_spend(): from litellm.proxy.proxy_server import prisma_client try: - total_spend = 0.0 if prisma_client is None: diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 4c82f586fb..900b26f3f1 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -78,7 +78,7 @@ from litellm.types.utils import CallTypes, LoggedLiteLLMParams if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any @@ -491,7 +491,6 @@ class ProxyLogging: try: for callback in litellm.callbacks: - _callback = None if isinstance(callback, str): _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( @@ -1197,9 +1196,9 @@ class PrismaClient: api_requests=1, ) - self.daily_user_spend_transactions[daily_transaction_key] = ( - daily_transaction - ) + self.daily_user_spend_transactions[ + daily_transaction_key + ] = daily_transaction except Exception as e: raise e diff --git a/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py b/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py index 0c91c326f5..684e2ad061 100644 --- a/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py @@ -69,10 +69,10 @@ async def langfuse_proxy_route( request=request, api_key="Bearer {}".format(api_key) ) - callback_settings_obj: Optional[TeamCallbackMetadata] = ( - _get_dynamic_logging_metadata( - user_api_key_dict=user_api_key_dict, proxy_config=proxy_config - ) + callback_settings_obj: Optional[ + TeamCallbackMetadata + ] = _get_dynamic_logging_metadata( + user_api_key_dict=user_api_key_dict, proxy_config=proxy_config ) dynamic_langfuse_public_key: Optional[str] = None diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index ce8ae21c82..9307ce5a55 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -107,13 +107,16 @@ def rerank( # noqa: PLR0915 k for k, v in unique_version_params.items() if v is not None ] - model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = ( - litellm.get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=optional_params.api_base, - api_key=optional_params.api_key, - ) + ( + model, + _custom_llm_provider, + dynamic_api_key, + dynamic_api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=optional_params.api_base, + api_key=optional_params.api_key, ) rerank_provider_config: BaseRerankConfig = ( @@ -272,7 +275,6 @@ def rerank( # noqa: PLR0915 _is_async=_is_async, ) elif _custom_llm_provider == "jina_ai": - if dynamic_api_key is None: raise ValueError( "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment" diff --git a/litellm/responses/main.py b/litellm/responses/main.py index aec2f8fe4a..70b651f376 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -165,21 +165,24 @@ def responses( # get llm provider logic litellm_params = GenericLiteLLMParams(**kwargs) - model, custom_llm_provider, dynamic_api_key, dynamic_api_base = ( - litellm.get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=litellm_params.api_base, - api_key=litellm_params.api_key, - ) + ( + model, + custom_llm_provider, + 
dynamic_api_key, + dynamic_api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=litellm_params.api_base, + api_key=litellm_params.api_key, ) # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=model, - provider=litellm.LlmProviders(custom_llm_provider), - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=model, + provider=litellm.LlmProviders(custom_llm_provider), ) if responses_api_provider_config is None: diff --git a/litellm/router.py b/litellm/router.py index f739bc381d..78ad2afe1a 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -148,7 +148,7 @@ from .router_utils.pattern_match_deployments import PatternMatchRouter if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any @@ -333,9 +333,9 @@ class Router: ) # names of models under litellm_params. ex. azure/chatgpt-v-2 self.deployment_latency_map = {} ### CACHING ### - cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = ( - "local" # default to an in-memory cache - ) + cache_type: Literal[ + "local", "redis", "redis-semantic", "s3", "disk" + ] = "local" # default to an in-memory cache redis_cache = None cache_config: Dict[str, Any] = {} @@ -556,9 +556,9 @@ class Router: ) ) - self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = ( - model_group_retry_policy - ) + self.model_group_retry_policy: Optional[ + Dict[str, RetryPolicy] + ] = model_group_retry_policy self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None if allowed_fails_policy is not None: @@ -1093,9 +1093,9 @@ class Router: """ Adds default litellm params to kwargs, if set. """ - self.default_litellm_params[metadata_variable_name] = ( - self.default_litellm_params.pop("metadata", {}) - ) + self.default_litellm_params[ + metadata_variable_name + ] = self.default_litellm_params.pop("metadata", {}) for k, v in self.default_litellm_params.items(): if ( k not in kwargs and v is not None @@ -1678,14 +1678,16 @@ class Router: f"Prompt variables is set but not a dictionary. Got={prompt_variables}, type={type(prompt_variables)}" ) - model, messages, optional_params = ( - litellm_logging_object.get_chat_completion_prompt( - model=litellm_model, - messages=messages, - non_default_params=get_non_default_completion_params(kwargs=kwargs), - prompt_id=prompt_id, - prompt_variables=prompt_variables, - ) + ( + model, + messages, + optional_params, + ) = litellm_logging_object.get_chat_completion_prompt( + model=litellm_model, + messages=messages, + non_default_params=get_non_default_completion_params(kwargs=kwargs), + prompt_id=prompt_id, + prompt_variables=prompt_variables, ) kwargs = {**kwargs, **optional_params} @@ -2924,7 +2926,6 @@ class Router: Future Improvement - cache the result. 
""" try: - filtered_model_list = self.get_model_list() if filtered_model_list is None: raise Exception("Router not yet initialized.") @@ -3211,11 +3212,11 @@ class Router: if isinstance(e, litellm.ContextWindowExceededError): if context_window_fallbacks is not None: - fallback_model_group: Optional[List[str]] = ( - self._get_fallback_model_group_from_fallbacks( - fallbacks=context_window_fallbacks, - model_group=model_group, - ) + fallback_model_group: Optional[ + List[str] + ] = self._get_fallback_model_group_from_fallbacks( + fallbacks=context_window_fallbacks, + model_group=model_group, ) if fallback_model_group is None: raise original_exception @@ -3247,11 +3248,11 @@ class Router: e.message += "\n{}".format(error_message) elif isinstance(e, litellm.ContentPolicyViolationError): if content_policy_fallbacks is not None: - fallback_model_group: Optional[List[str]] = ( - self._get_fallback_model_group_from_fallbacks( - fallbacks=content_policy_fallbacks, - model_group=model_group, - ) + fallback_model_group: Optional[ + List[str] + ] = self._get_fallback_model_group_from_fallbacks( + fallbacks=content_policy_fallbacks, + model_group=model_group, ) if fallback_model_group is None: raise original_exception @@ -3282,11 +3283,12 @@ class Router: e.message += "\n{}".format(error_message) if fallbacks is not None and model_group is not None: verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") - fallback_model_group, generic_fallback_idx = ( - get_fallback_model_group( - fallbacks=fallbacks, # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}] - model_group=cast(str, model_group), - ) + ( + fallback_model_group, + generic_fallback_idx, + ) = get_fallback_model_group( + fallbacks=fallbacks, # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}] + model_group=cast(str, model_group), ) ## if none, check for generic fallback if ( @@ -3444,11 +3446,12 @@ class Router: """ Retry Logic """ - _healthy_deployments, _all_deployments = ( - await self._async_get_healthy_deployments( - model=kwargs.get("model") or "", - parent_otel_span=parent_otel_span, - ) + ( + _healthy_deployments, + _all_deployments, + ) = await self._async_get_healthy_deployments( + model=kwargs.get("model") or "", + parent_otel_span=parent_otel_span, ) # raises an exception if this error should not be retries @@ -3513,11 +3516,12 @@ class Router: remaining_retries = num_retries - current_attempt _model: Optional[str] = kwargs.get("model") # type: ignore if _model is not None: - _healthy_deployments, _ = ( - await self._async_get_healthy_deployments( - model=_model, - parent_otel_span=parent_otel_span, - ) + ( + _healthy_deployments, + _, + ) = await self._async_get_healthy_deployments( + model=_model, + parent_otel_span=parent_otel_span, ) else: _healthy_deployments = [] @@ -3884,7 +3888,6 @@ class Router: ) if exception_headers is not None: - _time_to_cooldown = ( litellm.utils._get_retry_after_from_exception_header( response_headers=exception_headers @@ -6131,7 +6134,6 @@ class Router: try: model_id = deployment.get("model_info", {}).get("id", None) if response is None: - # update self.deployment_stats if model_id is not None: self._update_usage( diff --git a/litellm/router_strategy/base_routing_strategy.py b/litellm/router_strategy/base_routing_strategy.py index a39d17e386..ea87e25eba 100644 --- a/litellm/router_strategy/base_routing_strategy.py +++ b/litellm/router_strategy/base_routing_strategy.py @@ -38,9 +38,9 @@ class BaseRoutingStrategy(ABC): except RuntimeError: # No event loop in current thread 
self._create_sync_thread(default_sync_interval) - self.in_memory_keys_to_update: set[str] = ( - set() - ) # Set with max size of 1000 keys + self.in_memory_keys_to_update: set[ + str + ] = set() # Set with max size of 1000 keys async def _increment_value_in_current_window( self, key: str, value: Union[int, float], ttl: int diff --git a/litellm/router_strategy/budget_limiter.py b/litellm/router_strategy/budget_limiter.py index 4f123df282..9e4001b67b 100644 --- a/litellm/router_strategy/budget_limiter.py +++ b/litellm/router_strategy/budget_limiter.py @@ -53,9 +53,9 @@ class RouterBudgetLimiting(CustomLogger): self.dual_cache = dual_cache self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = [] asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis()) - self.provider_budget_config: Optional[GenericBudgetConfigType] = ( - provider_budget_config - ) + self.provider_budget_config: Optional[ + GenericBudgetConfigType + ] = provider_budget_config self.deployment_budget_config: Optional[GenericBudgetConfigType] = None self.tag_budget_config: Optional[GenericBudgetConfigType] = None self._init_provider_budgets() @@ -94,11 +94,13 @@ class RouterBudgetLimiting(CustomLogger): potential_deployments: List[Dict] = [] - cache_keys, provider_configs, deployment_configs = ( - await self._async_get_cache_keys_for_router_budget_limiting( - healthy_deployments=healthy_deployments, - request_kwargs=request_kwargs, - ) + ( + cache_keys, + provider_configs, + deployment_configs, + ) = await self._async_get_cache_keys_for_router_budget_limiting( + healthy_deployments=healthy_deployments, + request_kwargs=request_kwargs, ) # Single cache read for all spend values @@ -114,17 +116,18 @@ class RouterBudgetLimiting(CustomLogger): for idx, key in enumerate(cache_keys): spend_map[key] = float(current_spends[idx] or 0.0) - potential_deployments, deployment_above_budget_info = ( - self._filter_out_deployments_above_budget( - healthy_deployments=healthy_deployments, - provider_configs=provider_configs, - deployment_configs=deployment_configs, - spend_map=spend_map, - potential_deployments=potential_deployments, - request_tags=_get_tags_from_request_kwargs( - request_kwargs=request_kwargs - ), - ) + ( + potential_deployments, + deployment_above_budget_info, + ) = self._filter_out_deployments_above_budget( + healthy_deployments=healthy_deployments, + provider_configs=provider_configs, + deployment_configs=deployment_configs, + spend_map=spend_map, + potential_deployments=potential_deployments, + request_tags=_get_tags_from_request_kwargs( + request_kwargs=request_kwargs + ), ) if len(potential_deployments) == 0: diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py index b049c94264..55ca98843d 100644 --- a/litellm/router_strategy/lowest_latency.py +++ b/litellm/router_strategy/lowest_latency.py @@ -14,7 +14,7 @@ from litellm.types.utils import LiteLLMPydanticObjectBase if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index d1a46b7ea8..9e6c139314 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -20,7 +20,7 @@ from .base_routing_strategy import BaseRoutingStrategy if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any @@ -69,7 +69,6 @@ class 
LowestTPMLoggingHandler_v2(BaseRoutingStrategy, CustomLogger): Raises - RateLimitError if deployment over defined RPM limit """ try: - # ------------ # Setup values # ------------ diff --git a/litellm/router_utils/cooldown_cache.py b/litellm/router_utils/cooldown_cache.py index f096b026c0..13d6318fc4 100644 --- a/litellm/router_utils/cooldown_cache.py +++ b/litellm/router_utils/cooldown_cache.py @@ -3,7 +3,7 @@ Wrapper around router cache. Meant to handle model cooldown logic """ import time -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict, Union from litellm import verbose_logger from litellm.caching.caching import DualCache @@ -12,7 +12,7 @@ from litellm.caching.in_memory_cache import InMemoryCache if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - Span = _Span + Span = Union[_Span, Any] else: Span = Any diff --git a/litellm/router_utils/cooldown_callbacks.py b/litellm/router_utils/cooldown_callbacks.py index 54a016d3ec..5961a04feb 100644 --- a/litellm/router_utils/cooldown_callbacks.py +++ b/litellm/router_utils/cooldown_callbacks.py @@ -59,9 +59,9 @@ async def router_cooldown_event_callback( pass # get the prometheus logger from in memory loggers - prometheusLogger: Optional[PrometheusLogger] = ( - _get_prometheus_logger_from_callbacks() - ) + prometheusLogger: Optional[ + PrometheusLogger + ] = _get_prometheus_logger_from_callbacks() if prometheusLogger is not None: prometheusLogger.set_deployment_complete_outage( diff --git a/litellm/router_utils/cooldown_handlers.py b/litellm/router_utils/cooldown_handlers.py index 52babc27f2..ed9c2dd229 100644 --- a/litellm/router_utils/cooldown_handlers.py +++ b/litellm/router_utils/cooldown_handlers.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from litellm.router import Router as _Router LitellmRouter = _Router - Span = _Span + Span = Union[_Span, Any] else: LitellmRouter = Any Span = Any diff --git a/litellm/router_utils/handle_error.py b/litellm/router_utils/handle_error.py index 132440cbc3..c331da70ac 100644 --- a/litellm/router_utils/handle_error.py +++ b/litellm/router_utils/handle_error.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from litellm._logging import verbose_router_logger from litellm.router_utils.cooldown_handlers import ( @@ -13,7 +13,7 @@ if TYPE_CHECKING: from litellm.router import Router as _Router LitellmRouter = _Router - Span = _Span + Span = Union[_Span, Any] else: LitellmRouter = Any Span = Any diff --git a/litellm/router_utils/pattern_match_deployments.py b/litellm/router_utils/pattern_match_deployments.py index 729510574a..c6804b1ad4 100644 --- a/litellm/router_utils/pattern_match_deployments.py +++ b/litellm/router_utils/pattern_match_deployments.py @@ -105,13 +105,11 @@ class PatternMatchRouter: new_deployments = [] for deployment in deployments: new_deployment = copy.deepcopy(deployment) - new_deployment["litellm_params"]["model"] = ( - PatternMatchRouter.set_deployment_model_name( - matched_pattern=matched_pattern, - litellm_deployment_litellm_model=deployment["litellm_params"][ - "model" - ], - ) + new_deployment["litellm_params"][ + "model" + ] = PatternMatchRouter.set_deployment_model_name( + matched_pattern=matched_pattern, + litellm_deployment_litellm_model=deployment["litellm_params"]["model"], ) new_deployments.append(new_deployment) diff --git a/litellm/router_utils/prompt_caching_cache.py 
b/litellm/router_utils/prompt_caching_cache.py index 1bf686d694..6a96b85e8a 100644 --- a/litellm/router_utils/prompt_caching_cache.py +++ b/litellm/router_utils/prompt_caching_cache.py @@ -4,7 +4,7 @@ Wrapper around router cache. Meant to store model id when prompt caching support import hashlib import json -from typing import TYPE_CHECKING, Any, List, Optional, TypedDict +from typing import TYPE_CHECKING, Any, List, Optional, TypedDict, Union from litellm.caching.caching import DualCache from litellm.caching.in_memory_cache import InMemoryCache @@ -16,7 +16,7 @@ if TYPE_CHECKING: from litellm.router import Router litellm_router = Router - Span = _Span + Span = Union[_Span, Any] else: Span = Any litellm_router = Any diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py index fd89d6c558..327dbf3d19 100644 --- a/litellm/secret_managers/aws_secret_manager_v2.py +++ b/litellm/secret_managers/aws_secret_manager_v2.py @@ -50,7 +50,6 @@ class AWSSecretsManagerV2(BaseAWSLLM, BaseSecretManager): if use_aws_secret_manager is None or use_aws_secret_manager is False: return try: - cls.validate_environment() litellm.secret_manager_client = cls() litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER diff --git a/litellm/secret_managers/hashicorp_secret_manager.py b/litellm/secret_managers/hashicorp_secret_manager.py index a3d129f89c..e0b4a08ce8 100644 --- a/litellm/secret_managers/hashicorp_secret_manager.py +++ b/litellm/secret_managers/hashicorp_secret_manager.py @@ -34,7 +34,7 @@ class HashicorpSecretManager(BaseSecretManager): # Validate environment if not self.vault_token: raise ValueError( - "Missing Vault token. Please set VAULT_TOKEN in your environment." + "Missing Vault token. Please set HCP_VAULT_TOKEN in your environment." ) litellm.secret_manager_client = self @@ -90,9 +90,9 @@ class HashicorpSecretManager(BaseSecretManager): headers["X-Vault-Namespace"] = self.vault_namespace try: # We use the client cert and key for mutual TLS - resp = httpx.post( + client = httpx.Client(cert=(self.tls_cert_path, self.tls_key_path)) + resp = client.post( login_url, - cert=(self.tls_cert_path, self.tls_key_path), headers=headers, json=self._get_tls_cert_auth_body(), ) diff --git a/litellm/types/integrations/arize_phoenix.py b/litellm/types/integrations/arize_phoenix.py index 4566022d17..a8a1fed5a6 100644 --- a/litellm/types/integrations/arize_phoenix.py +++ b/litellm/types/integrations/arize_phoenix.py @@ -1,9 +1,11 @@ from typing import TYPE_CHECKING, Literal, Optional from pydantic import BaseModel + from .arize import Protocol + class ArizePhoenixConfig(BaseModel): otlp_auth_headers: Optional[str] = None protocol: Protocol - endpoint: str + endpoint: str diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 1c5552637c..3ba5a3a4e0 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -722,12 +722,12 @@ class OpenAIChatCompletionChunk(ChatCompletionChunk): class Hyperparameters(BaseModel): batch_size: Optional[Union[str, int]] = None # "Number of examples in each batch." 
- learning_rate_multiplier: Optional[Union[str, float]] = ( - None # Scaling factor for the learning rate - ) - n_epochs: Optional[Union[str, int]] = ( - None # "The number of epochs to train the model for" - ) + learning_rate_multiplier: Optional[ + Union[str, float] + ] = None # Scaling factor for the learning rate + n_epochs: Optional[ + Union[str, int] + ] = None # "The number of epochs to train the model for" class FineTuningJobCreate(BaseModel): @@ -754,18 +754,18 @@ class FineTuningJobCreate(BaseModel): model: str # "The name of the model to fine-tune." training_file: str # "The ID of an uploaded file that contains training data." - hyperparameters: Optional[Hyperparameters] = ( - None # "The hyperparameters used for the fine-tuning job." - ) - suffix: Optional[str] = ( - None # "A string of up to 18 characters that will be added to your fine-tuned model name." - ) - validation_file: Optional[str] = ( - None # "The ID of an uploaded file that contains validation data." - ) - integrations: Optional[List[str]] = ( - None # "A list of integrations to enable for your fine-tuning job." - ) + hyperparameters: Optional[ + Hyperparameters + ] = None # "The hyperparameters used for the fine-tuning job." + suffix: Optional[ + str + ] = None # "A string of up to 18 characters that will be added to your fine-tuned model name." + validation_file: Optional[ + str + ] = None # "The ID of an uploaded file that contains validation data." + integrations: Optional[ + List[str] + ] = None # "A list of integrations to enable for your fine-tuning job." seed: Optional[int] = None # "The seed controls the reproducibility of the job." diff --git a/litellm/types/rerank.py b/litellm/types/rerank.py index 8e2a8cc334..fb6dae0d1d 100644 --- a/litellm/types/rerank.py +++ b/litellm/types/rerank.py @@ -21,7 +21,6 @@ class RerankRequest(BaseModel): max_tokens_per_doc: Optional[int] = None - class OptionalRerankParams(TypedDict, total=False): query: str top_n: Optional[int] @@ -60,9 +59,9 @@ class RerankResponseResult(TypedDict, total=False): class RerankResponse(BaseModel): id: Optional[str] = None - results: Optional[List[RerankResponseResult]] = ( - None # Contains index and relevance_score - ) + results: Optional[ + List[RerankResponseResult] + ] = None # Contains index and relevance_score meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units # Define private attributes using PrivateAttr diff --git a/litellm/types/router.py b/litellm/types/router.py index dcd547def2..45a8a3fcf6 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -95,18 +95,16 @@ class ModelInfo(BaseModel): id: Optional[ str ] # Allow id to be optional on input, but it will always be present as a str in the model instance - db_model: bool = ( - False # used for proxy - to separate models which are stored in the db vs. config. - ) + db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config. 
updated_at: Optional[datetime.datetime] = None updated_by: Optional[str] = None created_at: Optional[datetime.datetime] = None created_by: Optional[str] = None - base_model: Optional[str] = ( - None # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking - ) + base_model: Optional[ + str + ] = None # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking tier: Optional[Literal["free", "paid"]] = None """ @@ -171,12 +169,12 @@ class GenericLiteLLMParams(CredentialLiteLLMParams): custom_llm_provider: Optional[str] = None tpm: Optional[int] = None rpm: Optional[int] = None - timeout: Optional[Union[float, str, httpx.Timeout]] = ( - None # if str, pass in as os.environ/ - ) - stream_timeout: Optional[Union[float, str]] = ( - None # timeout when making stream=True calls, if str, pass in as os.environ/ - ) + timeout: Optional[ + Union[float, str, httpx.Timeout] + ] = None # if str, pass in as os.environ/ + stream_timeout: Optional[ + Union[float, str] + ] = None # timeout when making stream=True calls, if str, pass in as os.environ/ max_retries: Optional[int] = None organization: Optional[str] = None # for openai orgs configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None @@ -251,9 +249,9 @@ class GenericLiteLLMParams(CredentialLiteLLMParams): if max_retries is not None and isinstance(max_retries, str): max_retries = int(max_retries) # cast to int # We need to keep max_retries in args since it's a parameter of GenericLiteLLMParams - args["max_retries"] = ( - max_retries # Put max_retries back in args after popping it - ) + args[ + "max_retries" + ] = max_retries # Put max_retries back in args after popping it super().__init__(**args, **params) def __contains__(self, key): @@ -577,7 +575,6 @@ class AssistantsTypedDict(TypedDict): class FineTuningConfig(BaseModel): - custom_llm_provider: Literal["azure", "openai"] diff --git a/litellm/types/utils.py b/litellm/types/utils.py index fe6330f8bd..7f84a41cd5 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -482,7 +482,6 @@ from openai.types.chat.chat_completion_audio import ChatCompletionAudio class ChatCompletionAudioResponse(ChatCompletionAudio): - def __init__( self, data: str, @@ -927,7 +926,6 @@ class StreamingChoices(OpenAIObject): self.finish_reason = None self.index = index if delta is not None: - if isinstance(delta, Delta): self.delta = delta elif isinstance(delta, dict): @@ -961,7 +959,6 @@ class StreamingChoices(OpenAIObject): class StreamingChatCompletionChunk(OpenAIChatCompletionChunk): def __init__(self, **kwargs): - new_choices = [] for choice in kwargs["choices"]: new_choice = StreamingChoices(**choice).model_dump() diff --git a/litellm/utils.py b/litellm/utils.py index 3c8b6667f9..777352ed34 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -482,7 +482,6 @@ def get_dynamic_callbacks( def function_setup( # noqa: PLR0915 original_function: str, rules_obj, start_time, *args, **kwargs ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. 
- ### NOTICES ### from litellm import Logging as LiteLLMLogging from litellm.litellm_core_utils.litellm_logging import set_callbacks @@ -504,9 +503,9 @@ def function_setup( # noqa: PLR0915 function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None ## DYNAMIC CALLBACKS ## - dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = ( - kwargs.pop("callbacks", None) - ) + dynamic_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = kwargs.pop("callbacks", None) all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks) if len(all_callbacks) > 0: @@ -1190,9 +1189,9 @@ def client(original_function): # noqa: PLR0915 exception=e, retry_policy=kwargs.get("retry_policy"), ) - kwargs["retry_policy"] = ( - reset_retry_policy() - ) # prevent infinite loops + kwargs[ + "retry_policy" + ] = reset_retry_policy() # prevent infinite loops litellm.num_retries = ( None # set retries to None to prevent infinite loops ) @@ -1404,7 +1403,6 @@ def client(original_function): # noqa: PLR0915 if ( num_retries and not _is_litellm_router_call ): # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying - try: litellm.num_retries = ( None # set retries to None to prevent infinite loops @@ -1425,7 +1423,6 @@ def client(original_function): # noqa: PLR0915 and context_window_fallback_dict and model in context_window_fallback_dict ): - if len(args) > 0: args[0] = context_window_fallback_dict[model] # type: ignore else: @@ -1521,7 +1518,6 @@ def _select_tokenizer( @lru_cache(maxsize=128) def _select_tokenizer_helper(model: str) -> SelectTokenizerResponse: - if litellm.disable_hf_tokenizer_download is True: return _return_openai_tokenizer(model) @@ -2990,16 +2986,16 @@ def get_optional_params( # noqa: PLR0915 True # so that main.py adds the function call to the prompt ) if "tools" in non_default_params: - optional_params["functions_unsupported_model"] = ( - non_default_params.pop("tools") - ) + optional_params[ + "functions_unsupported_model" + ] = non_default_params.pop("tools") non_default_params.pop( "tool_choice", None ) # causes ollama requests to hang elif "functions" in non_default_params: - optional_params["functions_unsupported_model"] = ( - non_default_params.pop("functions") - ) + optional_params[ + "functions_unsupported_model" + ] = non_default_params.pop("functions") elif ( litellm.add_function_to_prompt ): # if user opts to add it to prompt instead @@ -3022,10 +3018,10 @@ def get_optional_params( # noqa: PLR0915 if "response_format" in non_default_params: if provider_config is not None: - non_default_params["response_format"] = ( - provider_config.get_json_schema_from_pydantic_object( - response_format=non_default_params["response_format"] - ) + non_default_params[ + "response_format" + ] = provider_config.get_json_schema_from_pydantic_object( + response_format=non_default_params["response_format"] ) else: non_default_params["response_format"] = type_to_response_format_param( @@ -3177,7 +3173,6 @@ def get_optional_params( # noqa: PLR0915 ), ) elif custom_llm_provider == "replicate": - optional_params = litellm.ReplicateConfig().map_openai_params( non_default_params=non_default_params, optional_params=optional_params, @@ -3211,7 +3206,6 @@ def get_optional_params( # noqa: PLR0915 ), ) elif custom_llm_provider == "together_ai": - optional_params = litellm.TogetherAIConfig().map_openai_params( non_default_params=non_default_params, optional_params=optional_params, @@ -3279,7 +3273,6 @@ def get_optional_params( # 
noqa: PLR0915 ), ) elif custom_llm_provider == "vertex_ai": - if model in litellm.vertex_mistral_models: if "codestral" in model: optional_params = ( @@ -3358,7 +3351,6 @@ def get_optional_params( # noqa: PLR0915 elif "anthropic" in bedrock_base_model and bedrock_route == "invoke": if bedrock_base_model.startswith("anthropic.claude-3"): - optional_params = ( litellm.AmazonAnthropicClaude3Config().map_openai_params( non_default_params=non_default_params, @@ -3395,7 +3387,6 @@ def get_optional_params( # noqa: PLR0915 ), ) elif custom_llm_provider == "cloudflare": - optional_params = litellm.CloudflareChatConfig().map_openai_params( model=model, non_default_params=non_default_params, @@ -3407,7 +3398,6 @@ def get_optional_params( # noqa: PLR0915 ), ) elif custom_llm_provider == "ollama": - optional_params = litellm.OllamaConfig().map_openai_params( non_default_params=non_default_params, optional_params=optional_params, @@ -3419,7 +3409,6 @@ def get_optional_params( # noqa: PLR0915 ), ) elif custom_llm_provider == "ollama_chat": - optional_params = litellm.OllamaChatConfig().map_openai_params( model=model, non_default_params=non_default_params, @@ -4005,9 +3994,9 @@ def _count_characters(text: str) -> int: def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str: - _choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = ( - response_obj.choices - ) + _choices: Union[ + List[Union[Choices, StreamingChoices]], List[StreamingChoices] + ] = response_obj.choices response_str = "" for choice in _choices: @@ -4405,7 +4394,6 @@ def _get_model_info_helper( # noqa: PLR0915 ): _model_info = None if _model_info is None and model in litellm.model_cost: - key = model _model_info = _get_model_info_from_model_cost(key=key) if not _check_provider_match( @@ -4416,7 +4404,6 @@ def _get_model_info_helper( # noqa: PLR0915 _model_info is None and combined_stripped_model_name in litellm.model_cost ): - key = combined_stripped_model_name _model_info = _get_model_info_from_model_cost(key=key) if not _check_provider_match( @@ -4424,7 +4411,6 @@ def _get_model_info_helper( # noqa: PLR0915 ): _model_info = None if _model_info is None and stripped_model_name in litellm.model_cost: - key = stripped_model_name _model_info = _get_model_info_from_model_cost(key=key) if not _check_provider_match( @@ -4432,7 +4418,6 @@ def _get_model_info_helper( # noqa: PLR0915 ): _model_info = None if _model_info is None and split_model in litellm.model_cost: - key = split_model _model_info = _get_model_info_from_model_cost(key=key) if not _check_provider_match( diff --git a/poetry.lock b/poetry.lock index 64e8eaa8ff..d659d66952 100644 --- a/poetry.lock +++ b/poetry.lock @@ -436,7 +436,7 @@ files = [ name = "cffi" version = "1.17.1" description = "Foreign Function Interface for Python calling C code." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, @@ -658,7 +658,7 @@ cron = ["capturer (>=2.4)"] name = "cryptography" version = "43.0.3" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." 
-optional = true +optional = false python-versions = ">=3.7" files = [ {file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"}, @@ -2563,7 +2563,7 @@ files = [ name = "pycparser" version = "2.22" description = "C parser in Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, @@ -3326,6 +3326,32 @@ files = [ [package.dependencies] pyasn1 = ">=0.1.3" +[[package]] +name = "ruff" +version = "0.1.15" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.1.15-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5fe8d54df166ecc24106db7dd6a68d44852d14eb0729ea4672bb4d96c320b7df"}, + {file = "ruff-0.1.15-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6f0bfbb53c4b4de117ac4d6ddfd33aa5fc31beeaa21d23c45c6dd249faf9126f"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0d432aec35bfc0d800d4f70eba26e23a352386be3a6cf157083d18f6f5881c8"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9405fa9ac0e97f35aaddf185a1be194a589424b8713e3b97b762336ec79ff807"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c66ec24fe36841636e814b8f90f572a8c0cb0e54d8b5c2d0e300d28a0d7bffec"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6f8ad828f01e8dd32cc58bc28375150171d198491fc901f6f98d2a39ba8e3ff5"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86811954eec63e9ea162af0ffa9f8d09088bab51b7438e8b6488b9401863c25e"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd4025ac5e87d9b80e1f300207eb2fd099ff8200fa2320d7dc066a3f4622dc6b"}, + {file = "ruff-0.1.15-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b17b93c02cdb6aeb696effecea1095ac93f3884a49a554a9afa76bb125c114c1"}, + {file = "ruff-0.1.15-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ddb87643be40f034e97e97f5bc2ef7ce39de20e34608f3f829db727a93fb82c5"}, + {file = "ruff-0.1.15-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:abf4822129ed3a5ce54383d5f0e964e7fef74a41e48eb1dfad404151efc130a2"}, + {file = "ruff-0.1.15-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6c629cf64bacfd136c07c78ac10a54578ec9d1bd2a9d395efbee0935868bf852"}, + {file = "ruff-0.1.15-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1bab866aafb53da39c2cadfb8e1c4550ac5340bb40300083eb8967ba25481447"}, + {file = "ruff-0.1.15-py3-none-win32.whl", hash = "sha256:2417e1cb6e2068389b07e6fa74c306b2810fe3ee3476d5b8a96616633f40d14f"}, + {file = "ruff-0.1.15-py3-none-win_amd64.whl", hash = "sha256:3837ac73d869efc4182d9036b1405ef4c73d9b1f88da2413875e34e0d6919587"}, + {file = "ruff-0.1.15-py3-none-win_arm64.whl", hash = "sha256:9a933dfb1c14ec7a33cceb1e49ec4a16b51ce3c20fd42663198746efc0427360"}, + {file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"}, +] + [[package]] name = "s3transfer" version = "0.10.4" @@ -3603,6 +3629,111 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] +[[package]] +name = "types-cffi" +version = "1.16.0.20241221" 
+description = "Typing stubs for cffi" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types_cffi-1.16.0.20241221-py3-none-any.whl", hash = "sha256:e5b76b4211d7a9185f6ab8d06a106d56c7eb80af7cdb8bfcb4186ade10fb112f"}, + {file = "types_cffi-1.16.0.20241221.tar.gz", hash = "sha256:1c96649618f4b6145f58231acb976e0b448be6b847f7ab733dabe62dfbff6591"}, +] + +[package.dependencies] +types-setuptools = "*" + +[[package]] +name = "types-pyopenssl" +version = "24.1.0.20240722" +description = "Typing stubs for pyOpenSSL" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-pyOpenSSL-24.1.0.20240722.tar.gz", hash = "sha256:47913b4678a01d879f503a12044468221ed8576263c1540dcb0484ca21b08c39"}, + {file = "types_pyOpenSSL-24.1.0.20240722-py3-none-any.whl", hash = "sha256:6a7a5d2ec042537934cfb4c9d4deb0e16c4c6250b09358df1f083682fe6fda54"}, +] + +[package.dependencies] +cryptography = ">=35.0.0" +types-cffi = "*" + +[[package]] +name = "types-pyyaml" +version = "6.0.12.20241230" +description = "Typing stubs for PyYAML" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types_PyYAML-6.0.12.20241230-py3-none-any.whl", hash = "sha256:fa4d32565219b68e6dee5f67534c722e53c00d1cfc09c435ef04d7353e1e96e6"}, + {file = "types_pyyaml-6.0.12.20241230.tar.gz", hash = "sha256:7f07622dbd34bb9c8b264fe860a17e0efcad00d50b5f27e93984909d9363498c"}, +] + +[[package]] +name = "types-redis" +version = "4.6.0.20241004" +description = "Typing stubs for redis" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-redis-4.6.0.20241004.tar.gz", hash = "sha256:5f17d2b3f9091ab75384153bfa276619ffa1cf6a38da60e10d5e6749cc5b902e"}, + {file = "types_redis-4.6.0.20241004-py3-none-any.whl", hash = "sha256:ef5da68cb827e5f606c8f9c0b49eeee4c2669d6d97122f301d3a55dc6a63f6ed"}, +] + +[package.dependencies] +cryptography = ">=35.0.0" +types-pyOpenSSL = "*" + +[[package]] +name = "types-requests" +version = "2.31.0.6" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.7" +files = [ + {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"}, + {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"}, +] + +[package.dependencies] +types-urllib3 = "*" + +[[package]] +name = "types-requests" +version = "2.32.0.20241016" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95"}, + {file = "types_requests-2.32.0.20241016-py3-none-any.whl", hash = "sha256:4195d62d6d3e043a4eaaf08ff8a62184584d2e8684e9d2aa178c7915a7da3747"}, +] + +[package.dependencies] +urllib3 = ">=2" + +[[package]] +name = "types-setuptools" +version = "75.8.0.20250110" +description = "Typing stubs for setuptools" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types_setuptools-75.8.0.20250110-py3-none-any.whl", hash = "sha256:a9f12980bbf9bcdc23ecd80755789085bad6bfce4060c2275bc2b4ca9f2bc480"}, + {file = "types_setuptools-75.8.0.20250110.tar.gz", hash = "sha256:96f7ec8bbd6e0a54ea180d66ad68ad7a1d7954e7281a710ea2de75e355545271"}, +] + +[[package]] +name = "types-urllib3" +version = "1.26.25.14" +description = "Typing stubs for urllib3" +optional = false +python-versions = "*" +files = [ + {file = "types-urllib3-1.26.25.14.tar.gz", hash = 
"sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"}, + {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"}, +] + [[package]] name = "typing-extensions" version = "4.13.0" @@ -3993,4 +4124,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi", [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0, !=3.9.7" -content-hash = "0f195796116a7c7a4a04d9958a7662d74baccb5266f531bb58a4403fd4db4e0e" +content-hash = "36792478ff4afec5c8e748caf9b2ae6bebf3dd223e78bea2626b6589ef3277e4" diff --git a/pyproject.toml b/pyproject.toml index 2dbfcc39de..5eb4a71160 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ Documentation = "https://docs.litellm.ai" [tool.poetry.dependencies] python = ">=3.8.1,<4.0, !=3.9.7" httpx = ">=0.23.0" -openai = ">=1.66.1" +openai = ">=1.68.2" python-dotenv = ">=0.2.0" tiktoken = ">=0.7.0" importlib-metadata = ">=6.8.0" @@ -100,6 +100,11 @@ pytest = "^7.4.3" pytest-mock = "^3.12.0" pytest-asyncio = "^0.21.1" respx = "^0.20.2" +ruff = "^0.1.0" +types-requests = "*" +types-setuptools = "*" +types-redis = "*" +types-PyYAML = "*" [tool.poetry.group.proxy-dev.dependencies] prisma = "0.11.0" diff --git a/tests/litellm_utils_tests/test_hashicorp.py b/tests/litellm_utils_tests/test_hashicorp.py index 612af5a79c..c61168c99d 100644 --- a/tests/litellm_utils_tests/test_hashicorp.py +++ b/tests/litellm_utils_tests/test_hashicorp.py @@ -165,19 +165,26 @@ async def test_hashicorp_secret_manager_delete_secret(): ) -def test_hashicorp_secret_manager_tls_cert_auth(): - with patch("httpx.post") as mock_post: - # Configure the mock response for TLS auth - mock_auth_response = MagicMock() - mock_auth_response.json.return_value = { +def test_hashicorp_secret_manager_tls_cert_auth(monkeypatch): + monkeypatch.setenv("HCP_VAULT_TOKEN", "test-client-token-12345") + print("HCP_VAULT_TOKEN=", os.getenv("HCP_VAULT_TOKEN")) + # Mock both httpx.post and httpx.Client + with patch("httpx.Client") as mock_client: + # Configure the mock client and response + mock_response = MagicMock() + mock_response.json.return_value = { "auth": { "client_token": "test-client-token-12345", "lease_duration": 3600, "renewable": True, } } - mock_auth_response.raise_for_status.return_value = None - mock_post.return_value = mock_auth_response + mock_response.raise_for_status.return_value = None + + # Configure the mock client's post method + mock_client_instance = MagicMock() + mock_client_instance.post.return_value = mock_response + mock_client.return_value = mock_client_instance # Create a new instance with TLS cert config test_manager = HashicorpSecretManager() @@ -185,19 +192,22 @@ def test_hashicorp_secret_manager_tls_cert_auth(): test_manager.tls_key_path = "key.pem" test_manager.vault_cert_role = "test-role" test_manager.vault_namespace = "test-namespace" + # Test the TLS auth method token = test_manager._auth_via_tls_cert() - # Verify the token and request parameters + # Verify the token assert token == "test-client-token-12345" - mock_post.assert_called_once_with( + + # Verify Client was created with correct cert tuple + mock_client.assert_called_once_with(cert=("cert.pem", "key.pem")) + + # Verify post was called with correct parameters + mock_client_instance.post.assert_called_once_with( f"{test_manager.vault_addr}/v1/auth/cert/login", - cert=("cert.pem", "key.pem"), headers={"X-Vault-Namespace": "test-namespace"}, json={"name": "test-role"}, ) # Verify the token 
was cached - assert ( - test_manager.cache.get_cache("hcp_vault_token") == "test-client-token-12345" - ) + assert test_manager.cache.get_cache("hcp_vault_token") == "test-client-token-12345"
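
One behavioral change in this section, exercised by the test hunk above, is the Hashicorp Vault mutual-TLS login fix: the client certificate is now bound to an httpx.Client instead of being passed per request. Below is a minimal illustrative sketch of that flow, not taken from the repository; the function name, parameters, and Vault address handling are placeholders chosen for the example.

import httpx
from typing import Optional


def vault_login_via_tls_cert(
    vault_addr: str,
    cert_path: str,
    key_path: str,
    cert_role: str,
    namespace: Optional[str] = None,
) -> str:
    """Return a Vault client token obtained via TLS certificate auth (sketch)."""
    headers = {}
    if namespace:
        headers["X-Vault-Namespace"] = namespace
    # The client certificate/key pair is attached to the Client itself,
    # mirroring the change in hashicorp_secret_manager.py above, rather than
    # being passed as a per-request argument.
    with httpx.Client(cert=(cert_path, key_path)) as client:
        resp = client.post(
            f"{vault_addr}/v1/auth/cert/login",
            headers=headers,
            json={"name": cert_role},
        )
        resp.raise_for_status()
        # Vault returns the token under auth.client_token on a successful login.
        return resp.json()["auth"]["client_token"]

The test in tests/litellm_utils_tests/test_hashicorp.py patches httpx.Client accordingly, asserting that the Client is constructed with cert=("cert.pem", "key.pem") and that post() no longer receives a cert argument.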