diff --git a/docs/my-website/docs/proxy/guardrails/secret_detection.md b/docs/my-website/docs/proxy/guardrails/secret_detection.md new file mode 100644 index 000000000..a70c35d96 --- /dev/null +++ b/docs/my-website/docs/proxy/guardrails/secret_detection.md @@ -0,0 +1,557 @@ +# ✨ Secret Detection/Redaction (Enterprise-only) +❓ Use this to REDACT API Keys, Secrets sent in requests to an LLM. + +Example if you want to redact the value of `OPENAI_API_KEY` in the following request + +#### Incoming Request + +```json +{ + "messages": [ + { + "role": "user", + "content": "Hey, how's it going, API_KEY = 'sk_1234567890abcdef'", + } + ] +} +``` + +#### Request after Moderation + +```json +{ + "messages": [ + { + "role": "user", + "content": "Hey, how's it going, API_KEY = '[REDACTED]'", + } + ] +} +``` + +**Usage** + +**Step 1** Add this to your config.yaml + +```yaml +guardrails: + - guardrail_name: "my-custom-name" + litellm_params: + guardrail: "hide-secrets" # supported values: "aporia", "lakera", .. + mode: "pre_call" +``` + +**Step 2** Run litellm proxy with `--detailed_debug` to see the server logs + +``` +litellm --config config.yaml --detailed_debug +``` + +**Step 3** Test it with request + +Send this request +```shell +curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "fake-claude-endpoint", + "messages": [ + { + "role": "user", + "content": "what is the value of my open ai key? openai_api_key=sk-1234998222" + } + ], + "guardrails": ["my-custom-name"] +}' +``` + + +Expect to see the following warning on your litellm server logs + +```shell +LiteLLM Proxy:WARNING: secret_detection.py:88 - Detected and redacted secrets in message: ['Secret Keyword'] +``` + + +You can also see the raw request sent from litellm to the API Provider with (`--detailed_debug`). +```json +POST Request Sent from LiteLLM: +curl -X POST \ +https://api.groq.com/openai/v1/ \ +-H 'Authorization: Bearer gsk_mySVchjY********************************************' \ +-d { + "model": "llama3-8b-8192", + "messages": [ + { + "role": "user", + "content": "what is the time today, openai_api_key=[REDACTED]" + } + ], + "stream": false, + "extra_body": {} +} +``` + +## Turn on/off per project (API KEY/Team) + +[**See Here**](./quick_start.md#-control-guardrails-per-project-api-key) + +## Control secret detectors + +LiteLLM uses the [`detect-secrets`](https://github.com/Yelp/detect-secrets) library for secret detection. See [all plugins run by default](#default-config-used) + + +### Usage + +Here's how to control which plugins are run per request. This is useful if developers complain about secret detection impacting response quality. + +**1. Set-up config.yaml** + +```yaml +guardrails: + - guardrail_name: "hide-secrets" + litellm_params: + guardrail: "hide-secrets" # supported values: "aporia", "lakera" + mode: "pre_call" + detect_secrets_config: { + "plugins_used": [ + {"name": "SoftlayerDetector"}, + {"name": "StripeDetector"}, + {"name": "NpmDetector"} + ] + } +``` + +**2. Start proxy** + +Run with `--detailed_debug` for more detailed logs. Use in dev only. + +```bash +litellm --config /path/to/config.yaml --detailed_debug +``` + +**3. Test it!** + +```bash +curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "fake-claude-endpoint", + "messages": [ + { + "role": "user", + "content": "what is the value of my open ai key? 
openai_api_key=sk-1234998222" + } + ], + "guardrails": ["hide-secrets"] +}' +``` + +**Expected Logs** + +Look for this in your logs, to confirm your changes worked as expected. + +``` +No secrets detected on input. +``` + +### Default Config Used + +``` +_default_detect_secrets_config = { + "plugins_used": [ + {"name": "SoftlayerDetector"}, + {"name": "StripeDetector"}, + {"name": "NpmDetector"}, + {"name": "IbmCosHmacDetector"}, + {"name": "DiscordBotTokenDetector"}, + {"name": "BasicAuthDetector"}, + {"name": "AzureStorageKeyDetector"}, + {"name": "ArtifactoryDetector"}, + {"name": "AWSKeyDetector"}, + {"name": "CloudantDetector"}, + {"name": "IbmCloudIamDetector"}, + {"name": "JwtTokenDetector"}, + {"name": "MailchimpDetector"}, + {"name": "SquareOAuthDetector"}, + {"name": "PrivateKeyDetector"}, + {"name": "TwilioKeyDetector"}, + { + "name": "AdafruitKeyDetector", + "path": _custom_plugins_path + "/adafruit.py", + }, + { + "name": "AdobeSecretDetector", + "path": _custom_plugins_path + "/adobe.py", + }, + { + "name": "AgeSecretKeyDetector", + "path": _custom_plugins_path + "/age_secret_key.py", + }, + { + "name": "AirtableApiKeyDetector", + "path": _custom_plugins_path + "/airtable_api_key.py", + }, + { + "name": "AlgoliaApiKeyDetector", + "path": _custom_plugins_path + "/algolia_api_key.py", + }, + { + "name": "AlibabaSecretDetector", + "path": _custom_plugins_path + "/alibaba.py", + }, + { + "name": "AsanaSecretDetector", + "path": _custom_plugins_path + "/asana.py", + }, + { + "name": "AtlassianApiTokenDetector", + "path": _custom_plugins_path + "/atlassian_api_token.py", + }, + { + "name": "AuthressAccessKeyDetector", + "path": _custom_plugins_path + "/authress_access_key.py", + }, + { + "name": "BittrexDetector", + "path": _custom_plugins_path + "/beamer_api_token.py", + }, + { + "name": "BitbucketDetector", + "path": _custom_plugins_path + "/bitbucket.py", + }, + { + "name": "BeamerApiTokenDetector", + "path": _custom_plugins_path + "/bittrex.py", + }, + { + "name": "ClojarsApiTokenDetector", + "path": _custom_plugins_path + "/clojars_api_token.py", + }, + { + "name": "CodecovAccessTokenDetector", + "path": _custom_plugins_path + "/codecov_access_token.py", + }, + { + "name": "CoinbaseAccessTokenDetector", + "path": _custom_plugins_path + "/coinbase_access_token.py", + }, + { + "name": "ConfluentDetector", + "path": _custom_plugins_path + "/confluent.py", + }, + { + "name": "ContentfulApiTokenDetector", + "path": _custom_plugins_path + "/contentful_api_token.py", + }, + { + "name": "DatabricksApiTokenDetector", + "path": _custom_plugins_path + "/databricks_api_token.py", + }, + { + "name": "DatadogAccessTokenDetector", + "path": _custom_plugins_path + "/datadog_access_token.py", + }, + { + "name": "DefinedNetworkingApiTokenDetector", + "path": _custom_plugins_path + "/defined_networking_api_token.py", + }, + { + "name": "DigitaloceanDetector", + "path": _custom_plugins_path + "/digitalocean.py", + }, + { + "name": "DopplerApiTokenDetector", + "path": _custom_plugins_path + "/doppler_api_token.py", + }, + { + "name": "DroneciAccessTokenDetector", + "path": _custom_plugins_path + "/droneci_access_token.py", + }, + { + "name": "DuffelApiTokenDetector", + "path": _custom_plugins_path + "/duffel_api_token.py", + }, + { + "name": "DynatraceApiTokenDetector", + "path": _custom_plugins_path + "/dynatrace_api_token.py", + }, + { + "name": "DiscordDetector", + "path": _custom_plugins_path + "/discord.py", + }, + { + "name": "DropboxDetector", + "path": _custom_plugins_path + "/dropbox.py", 
+ }, + { + "name": "EasyPostDetector", + "path": _custom_plugins_path + "/easypost.py", + }, + { + "name": "EtsyAccessTokenDetector", + "path": _custom_plugins_path + "/etsy_access_token.py", + }, + { + "name": "FacebookAccessTokenDetector", + "path": _custom_plugins_path + "/facebook_access_token.py", + }, + { + "name": "FastlyApiKeyDetector", + "path": _custom_plugins_path + "/fastly_api_token.py", + }, + { + "name": "FinicityDetector", + "path": _custom_plugins_path + "/finicity.py", + }, + { + "name": "FinnhubAccessTokenDetector", + "path": _custom_plugins_path + "/finnhub_access_token.py", + }, + { + "name": "FlickrAccessTokenDetector", + "path": _custom_plugins_path + "/flickr_access_token.py", + }, + { + "name": "FlutterwaveDetector", + "path": _custom_plugins_path + "/flutterwave.py", + }, + { + "name": "FrameIoApiTokenDetector", + "path": _custom_plugins_path + "/frameio_api_token.py", + }, + { + "name": "FreshbooksAccessTokenDetector", + "path": _custom_plugins_path + "/freshbooks_access_token.py", + }, + { + "name": "GCPApiKeyDetector", + "path": _custom_plugins_path + "/gcp_api_key.py", + }, + { + "name": "GitHubTokenCustomDetector", + "path": _custom_plugins_path + "/github_token.py", + }, + { + "name": "GitLabDetector", + "path": _custom_plugins_path + "/gitlab.py", + }, + { + "name": "GitterAccessTokenDetector", + "path": _custom_plugins_path + "/gitter_access_token.py", + }, + { + "name": "GoCardlessApiTokenDetector", + "path": _custom_plugins_path + "/gocardless_api_token.py", + }, + { + "name": "GrafanaDetector", + "path": _custom_plugins_path + "/grafana.py", + }, + { + "name": "HashiCorpTFApiTokenDetector", + "path": _custom_plugins_path + "/hashicorp_tf_api_token.py", + }, + { + "name": "HerokuApiKeyDetector", + "path": _custom_plugins_path + "/heroku_api_key.py", + }, + { + "name": "HubSpotApiTokenDetector", + "path": _custom_plugins_path + "/hubspot_api_key.py", + }, + { + "name": "HuggingFaceDetector", + "path": _custom_plugins_path + "/huggingface.py", + }, + { + "name": "IntercomApiTokenDetector", + "path": _custom_plugins_path + "/intercom_api_key.py", + }, + { + "name": "JFrogDetector", + "path": _custom_plugins_path + "/jfrog.py", + }, + { + "name": "JWTBase64Detector", + "path": _custom_plugins_path + "/jwt.py", + }, + { + "name": "KrakenAccessTokenDetector", + "path": _custom_plugins_path + "/kraken_access_token.py", + }, + { + "name": "KucoinDetector", + "path": _custom_plugins_path + "/kucoin.py", + }, + { + "name": "LaunchdarklyAccessTokenDetector", + "path": _custom_plugins_path + "/launchdarkly_access_token.py", + }, + { + "name": "LinearDetector", + "path": _custom_plugins_path + "/linear.py", + }, + { + "name": "LinkedInDetector", + "path": _custom_plugins_path + "/linkedin.py", + }, + { + "name": "LobDetector", + "path": _custom_plugins_path + "/lob.py", + }, + { + "name": "MailgunDetector", + "path": _custom_plugins_path + "/mailgun.py", + }, + { + "name": "MapBoxApiTokenDetector", + "path": _custom_plugins_path + "/mapbox_api_token.py", + }, + { + "name": "MattermostAccessTokenDetector", + "path": _custom_plugins_path + "/mattermost_access_token.py", + }, + { + "name": "MessageBirdDetector", + "path": _custom_plugins_path + "/messagebird.py", + }, + { + "name": "MicrosoftTeamsWebhookDetector", + "path": _custom_plugins_path + "/microsoft_teams_webhook.py", + }, + { + "name": "NetlifyAccessTokenDetector", + "path": _custom_plugins_path + "/netlify_access_token.py", + }, + { + "name": "NewRelicDetector", + "path": _custom_plugins_path + 
"/new_relic.py", + }, + { + "name": "NYTimesAccessTokenDetector", + "path": _custom_plugins_path + "/nytimes_access_token.py", + }, + { + "name": "OktaAccessTokenDetector", + "path": _custom_plugins_path + "/okta_access_token.py", + }, + { + "name": "OpenAIApiKeyDetector", + "path": _custom_plugins_path + "/openai_api_key.py", + }, + { + "name": "PlanetScaleDetector", + "path": _custom_plugins_path + "/planetscale.py", + }, + { + "name": "PostmanApiTokenDetector", + "path": _custom_plugins_path + "/postman_api_token.py", + }, + { + "name": "PrefectApiTokenDetector", + "path": _custom_plugins_path + "/prefect_api_token.py", + }, + { + "name": "PulumiApiTokenDetector", + "path": _custom_plugins_path + "/pulumi_api_token.py", + }, + { + "name": "PyPiUploadTokenDetector", + "path": _custom_plugins_path + "/pypi_upload_token.py", + }, + { + "name": "RapidApiAccessTokenDetector", + "path": _custom_plugins_path + "/rapidapi_access_token.py", + }, + { + "name": "ReadmeApiTokenDetector", + "path": _custom_plugins_path + "/readme_api_token.py", + }, + { + "name": "RubygemsApiTokenDetector", + "path": _custom_plugins_path + "/rubygems_api_token.py", + }, + { + "name": "ScalingoApiTokenDetector", + "path": _custom_plugins_path + "/scalingo_api_token.py", + }, + { + "name": "SendbirdDetector", + "path": _custom_plugins_path + "/sendbird.py", + }, + { + "name": "SendGridApiTokenDetector", + "path": _custom_plugins_path + "/sendgrid_api_token.py", + }, + { + "name": "SendinBlueApiTokenDetector", + "path": _custom_plugins_path + "/sendinblue_api_token.py", + }, + { + "name": "SentryAccessTokenDetector", + "path": _custom_plugins_path + "/sentry_access_token.py", + }, + { + "name": "ShippoApiTokenDetector", + "path": _custom_plugins_path + "/shippo_api_token.py", + }, + { + "name": "ShopifyDetector", + "path": _custom_plugins_path + "/shopify.py", + }, + { + "name": "SlackDetector", + "path": _custom_plugins_path + "/slack.py", + }, + { + "name": "SnykApiTokenDetector", + "path": _custom_plugins_path + "/snyk_api_token.py", + }, + { + "name": "SquarespaceAccessTokenDetector", + "path": _custom_plugins_path + "/squarespace_access_token.py", + }, + { + "name": "SumoLogicDetector", + "path": _custom_plugins_path + "/sumologic.py", + }, + { + "name": "TelegramBotApiTokenDetector", + "path": _custom_plugins_path + "/telegram_bot_api_token.py", + }, + { + "name": "TravisCiAccessTokenDetector", + "path": _custom_plugins_path + "/travisci_access_token.py", + }, + { + "name": "TwitchApiTokenDetector", + "path": _custom_plugins_path + "/twitch_api_token.py", + }, + { + "name": "TwitterDetector", + "path": _custom_plugins_path + "/twitter.py", + }, + { + "name": "TypeformApiTokenDetector", + "path": _custom_plugins_path + "/typeform_api_token.py", + }, + { + "name": "VaultDetector", + "path": _custom_plugins_path + "/vault.py", + }, + { + "name": "YandexDetector", + "path": _custom_plugins_path + "/yandex.py", + }, + { + "name": "ZendeskSecretKeyDetector", + "path": _custom_plugins_path + "/zendesk_secret_key.py", + }, + {"name": "Base64HighEntropyString", "limit": 3.0}, + {"name": "HexHighEntropyString", "limit": 3.0}, + ] +} +``` \ No newline at end of file diff --git a/docs/my-website/docs/proxy/ui.md b/docs/my-website/docs/proxy/ui.md index d678d550c..a6ed9f6de 100644 --- a/docs/my-website/docs/proxy/ui.md +++ b/docs/my-website/docs/proxy/ui.md @@ -297,3 +297,14 @@ Set your colors to any of the following colors: https://www.tremor.so/docs/layou - Deploy LiteLLM Proxy Server + +## Disable Admin UI + +Set 
`DISABLE_ADMIN_UI="True"` in your environment to disable the Admin UI. + +Useful, if your security team has additional restrictions on UI usage. + + +**Expected Response** + + \ No newline at end of file diff --git a/docs/my-website/img/admin_ui_disabled.png b/docs/my-website/img/admin_ui_disabled.png new file mode 100644 index 000000000..da2da2c55 Binary files /dev/null and b/docs/my-website/img/admin_ui_disabled.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 6dafb5478..813fc75a6 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -81,6 +81,7 @@ const sidebars = { "proxy/guardrails/lakera_ai", "proxy/guardrails/bedrock", "proxy/guardrails/pii_masking_v2", + "proxy/guardrails/secret_detection", "proxy/guardrails/custom_guardrail", "prompt_injection" ], diff --git a/enterprise/enterprise_hooks/secret_detection.py b/enterprise/enterprise_hooks/secret_detection.py index 2289858bd..0574d3a05 100644 --- a/enterprise/enterprise_hooks/secret_detection.py +++ b/enterprise/enterprise_hooks/secret_detection.py @@ -5,39 +5,24 @@ # +-------------------------------------------------------------+ # Thank you users! We ❤️ you! - Krrish & Ishaan -import sys, os +import sys +import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -from typing import Optional, Literal, Union -import litellm, traceback, sys, uuid +from typing import Optional from litellm.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth -from litellm.integrations.custom_logger import CustomLogger -from fastapi import HTTPException -from litellm._logging import verbose_proxy_logger -from litellm.utils import ( - ModelResponse, - EmbeddingResponse, - ImageResponse, - StreamingChoices, -) -from datetime import datetime -import aiohttp, asyncio from litellm._logging import verbose_proxy_logger import tempfile -from litellm._logging import verbose_proxy_logger - - -litellm.set_verbose = True +from litellm.integrations.custom_guardrail import CustomGuardrail GUARDRAIL_NAME = "hide_secrets" _custom_plugins_path = "file://" + os.path.join( os.path.dirname(os.path.abspath(__file__)), "secrets_plugins" ) -print("custom plugins path", _custom_plugins_path) _default_detect_secrets_config = { "plugins_used": [ {"name": "SoftlayerDetector"}, @@ -434,9 +419,10 @@ _default_detect_secrets_config = { } -class _ENTERPRISE_SecretDetection(CustomLogger): - def __init__(self): - pass +class _ENTERPRISE_SecretDetection(CustomGuardrail): + def __init__(self, detect_secrets_config: Optional[dict] = None, **kwargs): + self.user_defined_detect_secrets_config = detect_secrets_config + super().__init__(**kwargs) def scan_message_for_secrets(self, message_content: str): from detect_secrets import SecretsCollection @@ -447,7 +433,11 @@ class _ENTERPRISE_SecretDetection(CustomLogger): temp_file.close() secrets = SecretsCollection() - with transient_settings(_default_detect_secrets_config): + + detect_secrets_config = ( + self.user_defined_detect_secrets_config or _default_detect_secrets_config + ) + with transient_settings(detect_secrets_config): secrets.scan_file(temp_file.name) os.remove(temp_file.name) @@ -484,9 +474,12 @@ class _ENTERPRISE_SecretDetection(CustomLogger): from detect_secrets import SecretsCollection from detect_secrets.settings import default_settings + print("INSIDE SECRET DETECTION PRE-CALL HOOK!") + if await self.should_run_check(user_api_key_dict) is False: return + print("RUNNING CHECK!") if "messages" in data and 
isinstance(data["messages"], list): for message in data["messages"]: if "content" in message and isinstance(message["content"], str): @@ -503,6 +496,8 @@ class _ENTERPRISE_SecretDetection(CustomLogger): verbose_proxy_logger.warning( f"Detected and redacted secrets in message: {secret_types}" ) + else: + verbose_proxy_logger.debug("No secrets detected on input.") if "prompt" in data: if isinstance(data["prompt"], str): diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 598de09be..0a935a290 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -2,7 +2,7 @@ ## File for 'response_cost' calculation in Logging import time import traceback -from typing import List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union from pydantic import BaseModel @@ -100,7 +100,7 @@ def cost_per_token( "rerank", "arerank", ] = "completion", -) -> Tuple[float, float]: +) -> Tuple[float, float]: # type: ignore """ Calculates the cost per token for a given model, prompt tokens, and completion tokens. @@ -425,20 +425,24 @@ def get_model_params_and_category(model_name) -> str: return model_name -def get_replicate_completion_pricing(completion_response=None, total_time=0.0): +def get_replicate_completion_pricing(completion_response: dict, total_time=0.0): # see https://replicate.com/pricing # for all litellm currently supported LLMs, almost all requests go to a100_80gb a100_80gb_price_per_second_public = ( 0.001400 # assume all calls sent to A100 80GB for now ) if total_time == 0.0: # total time is in ms - start_time = completion_response["created"] + start_time = completion_response.get("created", time.time()) end_time = getattr(completion_response, "ended", time.time()) total_time = end_time - start_time return a100_80gb_price_per_second_public * total_time / 1000 +def has_hidden_params(obj: Any) -> bool: + return hasattr(obj, "_hidden_params") + + def _select_model_name_for_cost_calc( model: Optional[str], completion_response: Union[BaseModel, dict, str], @@ -463,12 +467,14 @@ def _select_model_name_for_cost_calc( elif return_model is None: return_model = completion_response.get("model", "") # type: ignore - if hasattr(completion_response, "_hidden_params"): + hidden_params = getattr(completion_response, "_hidden_params", None) + + if hidden_params is not None: if ( - completion_response._hidden_params.get("model", None) is not None - and len(completion_response._hidden_params["model"]) > 0 + hidden_params.get("model", None) is not None + and len(hidden_params["model"]) > 0 ): - return_model = completion_response._hidden_params.get("model", model) + return_model = hidden_params.get("model", model) return return_model @@ -558,7 +564,7 @@ def completion_cost( or isinstance(completion_response, dict) ): # tts returns a custom class - usage_obj: Optional[Union[dict, litellm.Usage]] = completion_response.get( + usage_obj: Optional[Union[dict, litellm.Usage]] = completion_response.get( # type: ignore "usage", {} ) if isinstance(usage_obj, BaseModel) and not isinstance( @@ -569,17 +575,17 @@ def completion_cost( "usage", litellm.Usage(**usage_obj.model_dump()), ) + if usage_obj is None: + _usage = {} + elif isinstance(usage_obj, BaseModel): + _usage = usage_obj.model_dump() + else: + _usage = usage_obj # get input/output tokens from completion_response - prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0) - completion_tokens = completion_response.get("usage", {}).get( - "completion_tokens", 0 - ) - 
cache_creation_input_tokens = completion_response.get("usage", {}).get( - "cache_creation_input_tokens", 0 - ) - cache_read_input_tokens = completion_response.get("usage", {}).get( - "cache_read_input_tokens", 0 - ) + prompt_tokens = _usage.get("prompt_tokens", 0) + completion_tokens = _usage.get("completion_tokens", 0) + cache_creation_input_tokens = _usage.get("cache_creation_input_tokens", 0) + cache_read_input_tokens = _usage.get("cache_read_input_tokens", 0) total_time = getattr(completion_response, "_response_ms", 0) verbose_logger.debug( @@ -588,24 +594,19 @@ def completion_cost( model = _select_model_name_for_cost_calc( model=model, completion_response=completion_response ) - if hasattr(completion_response, "_hidden_params"): - custom_llm_provider = completion_response._hidden_params.get( + hidden_params = getattr(completion_response, "_hidden_params", None) + if hidden_params is not None: + custom_llm_provider = hidden_params.get( "custom_llm_provider", custom_llm_provider or None ) - region_name = completion_response._hidden_params.get( - "region_name", region_name - ) - size = completion_response._hidden_params.get( - "optional_params", {} - ).get( + region_name = hidden_params.get("region_name", region_name) + size = hidden_params.get("optional_params", {}).get( "size", "1024-x-1024" ) # openai default - quality = completion_response._hidden_params.get( - "optional_params", {} - ).get( + quality = hidden_params.get("optional_params", {}).get( "quality", "standard" ) # openai default - n = completion_response._hidden_params.get("optional_params", {}).get( + n = hidden_params.get("optional_params", {}).get( "n", 1 ) # openai default else: @@ -643,6 +644,8 @@ def completion_cost( # Vertex Charges Flat $0.20 per image return 0.020 + if size is None: + size = "1024-x-1024" # openai default # fix size to match naming convention if "x" in size and "-x-" not in size: size = size.replace("x", "-x-") @@ -697,7 +700,7 @@ def completion_cost( model in litellm.replicate_models or "replicate" in model ) and model not in litellm.model_cost: # for unmapped replicate model, default to replicate's time tracking logic - return get_replicate_completion_pricing(completion_response, total_time) + return get_replicate_completion_pricing(completion_response, total_time) # type: ignore if model is None: raise ValueError( @@ -847,7 +850,9 @@ def rerank_cost( Returns - float or None: cost of response OR none if error. 
""" - _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model) + _, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=model, custom_llm_provider=custom_llm_provider + ) try: if custom_llm_provider == "cohere": diff --git a/litellm/integrations/custom_batch_logger.py b/litellm/integrations/custom_batch_logger.py index 23c63e951..aa7f0bba2 100644 --- a/litellm/integrations/custom_batch_logger.py +++ b/litellm/integrations/custom_batch_logger.py @@ -45,6 +45,9 @@ class CustomBatchLogger(CustomLogger): await self.flush_queue() async def flush_queue(self): + if self.flush_lock is None: + return + async with self.flush_lock: if self.log_queue: verbose_logger.debug( @@ -54,5 +57,5 @@ class CustomBatchLogger(CustomLogger): self.log_queue.clear() self.last_flush_time = time.time() - async def async_send_batch(self): + async def async_send_batch(self, *args, **kwargs): pass diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py index 39c8f2b1e..1782ca08c 100644 --- a/litellm/integrations/custom_guardrail.py +++ b/litellm/integrations/custom_guardrail.py @@ -29,12 +29,13 @@ class CustomGuardrail(CustomLogger): ) if ( - self.guardrail_name not in requested_guardrails + self.event_hook + and self.guardrail_name not in requested_guardrails and event_type.value != "logging_only" ): return False - if self.event_hook != event_type.value: + if self.event_hook and self.event_hook != event_type.value: return False return True diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py index 3c7280f88..62ee16117 100644 --- a/litellm/integrations/langsmith.py +++ b/litellm/integrations/langsmith.py @@ -8,7 +8,7 @@ import traceback import types import uuid from datetime import datetime, timezone -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, TypedDict, Union import dotenv # type: ignore import httpx @@ -23,6 +23,7 @@ from litellm.llms.custom_httpx.http_handler import ( get_async_httpx_client, httpxSpecialProvider, ) +from litellm.types.utils import StandardLoggingPayload class LangsmithInputs(BaseModel): @@ -46,6 +47,12 @@ class LangsmithInputs(BaseModel): user_api_key_team_alias: Optional[str] = None +class LangsmithCredentialsObject(TypedDict): + LANGSMITH_API_KEY: str + LANGSMITH_PROJECT: str + LANGSMITH_BASE_URL: str + + def is_serializable(value): non_serializable_types = ( types.CoroutineType, @@ -57,15 +64,27 @@ def is_serializable(value): class LangsmithLogger(CustomBatchLogger): - def __init__(self, **kwargs): - self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY") - self.langsmith_project = os.getenv("LANGSMITH_PROJECT", "litellm-completion") + def __init__( + self, + langsmith_api_key: Optional[str] = None, + langsmith_project: Optional[str] = None, + langsmith_base_url: Optional[str] = None, + **kwargs, + ): + self.default_credentials = self.get_credentials_from_env( + langsmith_api_key=langsmith_api_key, + langsmith_project=langsmith_project, + langsmith_base_url=langsmith_base_url, + ) + self.sampling_rate: float = ( + float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore + if os.getenv("LANGSMITH_SAMPLING_RATE") is not None + and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore + else 1.0 + ) self.langsmith_default_run_name = os.getenv( "LANGSMITH_DEFAULT_RUN_NAME", "LLMRun" ) - self.langsmith_base_url = os.getenv( - "LANGSMITH_BASE_URL", "https://api.smith.langchain.com" - ) self.async_httpx_client = get_async_httpx_client( 
llm_provider=httpxSpecialProvider.LoggingCallback ) @@ -78,126 +97,160 @@ class LangsmithLogger(CustomBatchLogger): self.flush_lock = asyncio.Lock() super().__init__(**kwargs, flush_lock=self.flush_lock) + def get_credentials_from_env( + self, + langsmith_api_key: Optional[str], + langsmith_project: Optional[str], + langsmith_base_url: Optional[str], + ) -> LangsmithCredentialsObject: + + _credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY") + if _credentials_api_key is None: + raise Exception( + "Invalid Langsmith API Key given. _credentials_api_key=None." + ) + _credentials_project = ( + langsmith_project or os.getenv("LANGSMITH_PROJECT") or "litellm-completion" + ) + if _credentials_project is None: + raise Exception( + "Invalid Langsmith API Key given. _credentials_project=None." + ) + _credentials_base_url = ( + langsmith_base_url + or os.getenv("LANGSMITH_BASE_URL") + or "https://api.smith.langchain.com" + ) + if _credentials_base_url is None: + raise Exception( + "Invalid Langsmith API Key given. _credentials_base_url=None." + ) + + return LangsmithCredentialsObject( + LANGSMITH_API_KEY=_credentials_api_key, + LANGSMITH_BASE_URL=_credentials_base_url, + LANGSMITH_PROJECT=_credentials_project, + ) + def _prepare_log_data(self, kwargs, response_obj, start_time, end_time): - import datetime + import json from datetime import datetime as dt - from datetime import timezone - - metadata = kwargs.get("litellm_params", {}).get("metadata", {}) or {} - new_metadata = {} - for key, value in metadata.items(): - if ( - isinstance(value, list) - or isinstance(value, str) - or isinstance(value, int) - or isinstance(value, float) - ): - new_metadata[key] = value - elif isinstance(value, BaseModel): - new_metadata[key] = value.model_dump_json() - elif isinstance(value, dict): - for k, v in value.items(): - if isinstance(v, dt): - value[k] = v.isoformat() - new_metadata[key] = value - - metadata = new_metadata - - kwargs["user_api_key"] = metadata.get("user_api_key", None) - kwargs["user_api_key_user_id"] = metadata.get("user_api_key_user_id", None) - kwargs["user_api_key_team_alias"] = metadata.get( - "user_api_key_team_alias", None - ) - - project_name = metadata.get("project_name", self.langsmith_project) - run_name = metadata.get("run_name", self.langsmith_default_run_name) - run_id = metadata.get("id", None) - parent_run_id = metadata.get("parent_run_id", None) - trace_id = metadata.get("trace_id", None) - session_id = metadata.get("session_id", None) - dotted_order = metadata.get("dotted_order", None) - tags = metadata.get("tags", []) or [] - verbose_logger.debug( - f"Langsmith Logging - project_name: {project_name}, run_name {run_name}" - ) try: - start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat() - end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat() - except: - start_time = datetime.datetime.utcnow().isoformat() - end_time = datetime.datetime.utcnow().isoformat() + _litellm_params = kwargs.get("litellm_params", {}) or {} + metadata = _litellm_params.get("metadata", {}) or {} + new_metadata = {} + for key, value in metadata.items(): + if ( + isinstance(value, list) + or isinstance(value, str) + or isinstance(value, int) + or isinstance(value, float) + ): + new_metadata[key] = value + elif isinstance(value, BaseModel): + new_metadata[key] = value.model_dump_json() + elif isinstance(value, dict): + for k, v in value.items(): + if isinstance(v, dt): + value[k] = v.isoformat() + new_metadata[key] = value - # filter out kwargs to not include 
any dicts, langsmith throws an erros when trying to log kwargs - logged_kwargs = LangsmithInputs(**kwargs) - kwargs = logged_kwargs.model_dump() + metadata = new_metadata - new_kwargs = {} - for key in kwargs: - value = kwargs[key] - if key == "start_time" or key == "end_time" or value is None: - pass - elif key == "original_response" and not isinstance(value, str): - new_kwargs[key] = str(value) - elif type(value) == datetime.datetime: - new_kwargs[key] = value.isoformat() - elif type(value) != dict and is_serializable(value=value): - new_kwargs[key] = value - elif not is_serializable(value=value): - continue + kwargs["user_api_key"] = metadata.get("user_api_key", None) + kwargs["user_api_key_user_id"] = metadata.get("user_api_key_user_id", None) + kwargs["user_api_key_team_alias"] = metadata.get( + "user_api_key_team_alias", None + ) - if isinstance(response_obj, BaseModel): - try: - response_obj = response_obj.model_dump() - except: - response_obj = response_obj.dict() # type: ignore + project_name = metadata.get( + "project_name", self.default_credentials["LANGSMITH_PROJECT"] + ) + run_name = metadata.get("run_name", self.langsmith_default_run_name) + run_id = metadata.get("id", None) + parent_run_id = metadata.get("parent_run_id", None) + trace_id = metadata.get("trace_id", None) + session_id = metadata.get("session_id", None) + dotted_order = metadata.get("dotted_order", None) + tags = metadata.get("tags", []) or [] + verbose_logger.debug( + f"Langsmith Logging - project_name: {project_name}, run_name {run_name}" + ) - data = { - "name": run_name, - "run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain" - "inputs": new_kwargs, - "outputs": response_obj, - "session_name": project_name, - "start_time": start_time, - "end_time": end_time, - "tags": tags, - "extra": metadata, - } + # filter out kwargs to not include any dicts, langsmith throws an erros when trying to log kwargs + # logged_kwargs = LangsmithInputs(**kwargs) + # kwargs = logged_kwargs.model_dump() - if run_id: - data["id"] = run_id + # new_kwargs = {} + # Ensure everything in the payload is converted to str + payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) - if parent_run_id: - data["parent_run_id"] = parent_run_id + if payload is None: + raise Exception("Error logging request payload. Payload=none.") - if trace_id: - data["trace_id"] = trace_id + new_kwargs = payload + metadata = payload[ + "metadata" + ] # ensure logged metadata is json serializable - if session_id: - data["session_id"] = session_id + data = { + "name": run_name, + "run_type": "llm", # this should always be llm, since litellm always logs llm calls. 
Langsmith allow us to log "chain" + "inputs": new_kwargs, + "outputs": new_kwargs["response"], + "session_name": project_name, + "start_time": new_kwargs["startTime"], + "end_time": new_kwargs["endTime"], + "tags": tags, + "extra": metadata, + } - if dotted_order: - data["dotted_order"] = dotted_order + if payload["error_str"] is not None and payload["status"] == "failure": + data["error"] = payload["error_str"] - if "id" not in data or data["id"] is None: - """ - for /batch langsmith requires id, trace_id and dotted_order passed as params - """ - run_id = uuid.uuid4() - data["id"] = str(run_id) - data["trace_id"] = str(run_id) - data["dotted_order"] = self.make_dot_order(run_id=run_id) + if run_id: + data["id"] = run_id - verbose_logger.debug("Langsmith Logging data on langsmith: %s", data) + if parent_run_id: + data["parent_run_id"] = parent_run_id - return data + if trace_id: + data["trace_id"] = trace_id + + if session_id: + data["session_id"] = session_id + + if dotted_order: + data["dotted_order"] = dotted_order + + if "id" not in data or data["id"] is None: + """ + for /batch langsmith requires id, trace_id and dotted_order passed as params + """ + run_id = str(uuid.uuid4()) + data["id"] = str(run_id) + data["trace_id"] = str(run_id) + data["dotted_order"] = self.make_dot_order(run_id=run_id) + + verbose_logger.debug("Langsmith Logging data on langsmith: %s", data) + + return data + except Exception: + raise def _send_batch(self): if not self.log_queue: return - url = f"{self.langsmith_base_url}/runs/batch" - headers = {"x-api-key": self.langsmith_api_key} + langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"] + langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"] + + url = f"{langsmith_api_base}/runs/batch" + + headers = {"x-api-key": langsmith_api_key} try: response = requests.post( @@ -216,15 +269,15 @@ class LangsmithLogger(CustomBatchLogger): ) self.log_queue.clear() - except Exception as e: - verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}") + except Exception: + verbose_logger.exception("Langsmith Layer Error - Error sending batch.") def log_success_event(self, kwargs, response_obj, start_time, end_time): try: sampling_rate = ( - float(os.getenv("LANGSMITH_SAMPLING_RATE")) + float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore if os.getenv("LANGSMITH_SAMPLING_RATE") is not None - and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() + and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore else 1.0 ) random_sample = random.random() @@ -249,17 +302,12 @@ class LangsmithLogger(CustomBatchLogger): if len(self.log_queue) >= self.batch_size: self._send_batch() - except: - verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}") + except Exception: + verbose_logger.exception("Langsmith Layer Error - log_success_event error") async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): try: - sampling_rate = ( - float(os.getenv("LANGSMITH_SAMPLING_RATE")) - if os.getenv("LANGSMITH_SAMPLING_RATE") is not None - and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() - else 1.0 - ) + sampling_rate = self.sampling_rate random_sample = random.random() if random_sample > sampling_rate: verbose_logger.info( @@ -282,8 +330,36 @@ class LangsmithLogger(CustomBatchLogger): ) if len(self.log_queue) >= self.batch_size: await self.flush_queue() - except: - verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}") + except Exception: + 
verbose_logger.exception( + "Langsmith Layer Error - error logging async success event." + ) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + sampling_rate = self.sampling_rate + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.info("Langsmith Failure Event Logging!") + try: + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + self.log_queue.append(data) + verbose_logger.debug( + "Langsmith logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, + ) + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + except Exception: + verbose_logger.exception( + "Langsmith Layer Error - error logging async failure event." + ) async def async_send_batch(self): """ @@ -295,13 +371,16 @@ class LangsmithLogger(CustomBatchLogger): Raises: Does not raise an exception, will only verbose_logger.exception() """ - import json - if not self.log_queue: return - url = f"{self.langsmith_base_url}/runs/batch" - headers = {"x-api-key": self.langsmith_api_key} + langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"] + + url = f"{langsmith_api_base}/runs/batch" + + langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"] + + headers = {"x-api-key": langsmith_api_key} try: response = await self.async_httpx_client.post( @@ -332,10 +411,14 @@ class LangsmithLogger(CustomBatchLogger): def get_run_by_id(self, run_id): - url = f"{self.langsmith_base_url}/runs/{run_id}" + langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"] + + langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"] + + url = f"{langsmith_api_base}/runs/{run_id}" response = requests.get( url=url, - headers={"x-api-key": self.langsmith_api_key}, + headers={"x-api-key": langsmith_api_key}, ) return response.json() diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 3992614c8..f97947941 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -43,6 +43,7 @@ from litellm.types.utils import ( StandardLoggingMetadata, StandardLoggingModelInformation, StandardLoggingPayload, + StandardLoggingPayloadStatus, StandardPassThroughResponseObject, TextCompletionResponse, TranscriptionResponse, @@ -668,6 +669,7 @@ class Logging: start_time=start_time, end_time=end_time, logging_obj=self, + status="success", ) ) elif isinstance(result, dict): # pass-through endpoints @@ -679,6 +681,7 @@ class Logging: start_time=start_time, end_time=end_time, logging_obj=self, + status="success", ) ) else: # streaming chunks + image gen. 
@@ -762,6 +765,7 @@ class Logging: start_time=start_time, end_time=end_time, logging_obj=self, + status="success", ) ) if self.dynamic_success_callbacks is not None and isinstance( @@ -1390,6 +1394,7 @@ class Logging: start_time=start_time, end_time=end_time, logging_obj=self, + status="success", ) ) if self.dynamic_async_success_callbacks is not None and isinstance( @@ -1645,6 +1650,20 @@ class Logging: self.model_call_details["litellm_params"].get("metadata", {}) or {} ) metadata.update(exception.headers) + + ## STANDARDIZED LOGGING PAYLOAD + + self.model_call_details["standard_logging_object"] = ( + get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj={}, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="failure", + error_str=str(exception), + ) + ) return start_time, end_time async def special_failure_handlers(self, exception: Exception): @@ -2347,10 +2366,12 @@ def is_valid_sha256_hash(value: str) -> bool: def get_standard_logging_object_payload( kwargs: Optional[dict], - init_response_obj: Any, + init_response_obj: Union[Any, BaseModel, dict], start_time: dt_object, end_time: dt_object, logging_obj: Logging, + status: StandardLoggingPayloadStatus, + error_str: Optional[str] = None, ) -> Optional[StandardLoggingPayload]: try: if kwargs is None: @@ -2467,7 +2488,7 @@ def get_standard_logging_object_payload( custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params) model_cost_name = _select_model_name_for_cost_calc( model=None, - completion_response=init_response_obj, + completion_response=init_response_obj, # type: ignore base_model=base_model, custom_pricing=custom_pricing, ) @@ -2498,6 +2519,7 @@ def get_standard_logging_object_payload( id=str(id), call_type=call_type or "", cache_hit=cache_hit, + status=status, saved_cache_cost=saved_cache_cost, startTime=start_time_float, endTime=end_time_float, @@ -2517,11 +2539,12 @@ def get_standard_logging_object_payload( requester_ip_address=clean_metadata.get("requester_ip_address", None), messages=kwargs.get("messages"), response=( # type: ignore - response_obj if len(response_obj.keys()) > 0 else init_response_obj + response_obj if len(response_obj.keys()) > 0 else init_response_obj # type: ignore ), model_parameters=kwargs.get("optional_params", None), hidden_params=clean_hidden_params, model_map_information=model_cost_information, + error_str=error_str, ) verbose_logger.debug( diff --git a/litellm/llms/azure_ai/README.md b/litellm/llms/azure_ai/README.md new file mode 100644 index 000000000..8c521519d --- /dev/null +++ b/litellm/llms/azure_ai/README.md @@ -0,0 +1 @@ +`/chat/completion` calls routed via `openai.py`. 
\ No newline at end of file diff --git a/litellm/llms/azure_ai/rerank/__init__.py b/litellm/llms/azure_ai/rerank/__init__.py new file mode 100644 index 000000000..a25d34b1c --- /dev/null +++ b/litellm/llms/azure_ai/rerank/__init__.py @@ -0,0 +1 @@ +from .handler import AzureAIRerank diff --git a/litellm/llms/azure_ai/rerank/handler.py b/litellm/llms/azure_ai/rerank/handler.py new file mode 100644 index 000000000..523448eec --- /dev/null +++ b/litellm/llms/azure_ai/rerank/handler.py @@ -0,0 +1,52 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.cohere.rerank import CohereRerank +from litellm.rerank_api.types import RerankResponse + + +class AzureAIRerank(CohereRerank): + def rerank( + self, + model: str, + api_key: str, + api_base: str, + query: str, + documents: List[Union[str, Dict[str, Any]]], + headers: Optional[dict], + litellm_logging_obj: LiteLLMLoggingObj, + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + _is_async: Optional[bool] = False, + ) -> RerankResponse: + + if headers is None: + headers = {"Authorization": "Bearer {}".format(api_key)} + else: + headers = {**headers, "Authorization": "Bearer {}".format(api_key)} + + # Assuming api_base is a string representing the base URL + api_base_url = httpx.URL(api_base) + + # Replace the path with '/v1/rerank' if it doesn't already end with it + if not api_base_url.path.endswith("/v1/rerank"): + api_base = str(api_base_url.copy_with(path="/v1/rerank")) + + return super().rerank( + model=model, + api_key=api_key, + api_base=api_base, + query=query, + documents=documents, + top_n=top_n, + rank_fields=rank_fields, + return_documents=return_documents, + max_chunks_per_doc=max_chunks_per_doc, + _is_async=_is_async, + headers=headers, + litellm_logging_obj=litellm_logging_obj, + ) diff --git a/litellm/llms/azure_ai/rerank/transformation.py b/litellm/llms/azure_ai/rerank/transformation.py new file mode 100644 index 000000000..b5aad0ca2 --- /dev/null +++ b/litellm/llms/azure_ai/rerank/transformation.py @@ -0,0 +1,3 @@ +""" +Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format. +""" diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py index 64afcae4f..069cf3968 100644 --- a/litellm/llms/cohere/rerank.py +++ b/litellm/llms/cohere/rerank.py @@ -10,6 +10,7 @@ import httpx from pydantic import BaseModel import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.base import BaseLLM from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, @@ -19,6 +20,23 @@ from litellm.rerank_api.types import RerankRequest, RerankResponse class CohereRerank(BaseLLM): + def validate_environment(self, api_key: str, headers: Optional[dict]) -> dict: + default_headers = { + "accept": "application/json", + "content-type": "application/json", + "Authorization": f"bearer {api_key}", + } + + if headers is None: + return default_headers + + # If 'Authorization' is provided in headers, it overrides the default. 
+ if "Authorization" in headers: + default_headers["Authorization"] = headers["Authorization"] + + # Merge other headers, overriding any default ones except Authorization + return {**default_headers, **headers} + def rerank( self, model: str, @@ -26,12 +44,16 @@ class CohereRerank(BaseLLM): api_base: str, query: str, documents: List[Union[str, Dict[str, Any]]], + headers: Optional[dict], + litellm_logging_obj: LiteLLMLoggingObj, top_n: Optional[int] = None, rank_fields: Optional[List[str]] = None, return_documents: Optional[bool] = True, max_chunks_per_doc: Optional[int] = None, _is_async: Optional[bool] = False, # New parameter ) -> RerankResponse: + headers = self.validate_environment(api_key=api_key, headers=headers) + request_data = RerankRequest( model=model, query=query, @@ -45,16 +67,22 @@ class CohereRerank(BaseLLM): request_data_dict = request_data.dict(exclude_none=True) if _is_async: - return self.async_rerank(request_data_dict, api_key, api_base) # type: ignore # Call async method + return self.async_rerank(request_data_dict=request_data_dict, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method + ## LOGGING + litellm_logging_obj.pre_call( + input=request_data_dict, + api_key=api_key, + additional_args={ + "complete_input_dict": request_data_dict, + "api_base": api_base, + "headers": headers, + }, + ) client = _get_httpx_client() response = client.post( api_base, - headers={ - "accept": "application/json", - "content-type": "application/json", - "Authorization": f"bearer {api_key}", - }, + headers=headers, json=request_data_dict, ) @@ -65,16 +93,13 @@ class CohereRerank(BaseLLM): request_data_dict: Dict[str, Any], api_key: str, api_base: str, + headers: dict, ) -> RerankResponse: client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE) response = await client.post( api_base, - headers={ - "accept": "application/json", - "content-type": "application/json", - "Authorization": f"bearer {api_key}", - }, + headers=headers, json=request_data_dict, ) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 3502a786b..c039b94af 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -20,30 +20,15 @@ model_list: litellm_params: model: o1-preview -litellm_settings: - drop_params: True - json_logs: True - store_audit_logs: True - log_raw_request_response: True - return_response_headers: True - num_retries: 5 - request_timeout: 200 - callbacks: ["custom_callbacks.proxy_handler_instance"] - guardrails: - - guardrail_name: "presidio-pre-guard" + - guardrail_name: "hide-secrets" litellm_params: - guardrail: presidio # supported values: "aporia", "bedrock", "lakera", "presidio" - mode: "logging_only" - mock_redacted_text: { - "text": "My name is , who are you? 
Say my name in your response", - "items": [ - { - "start": 11, - "end": 19, - "entity_type": "PERSON", - "text": "", - "operator": "replace", - } - ], - } \ No newline at end of file + guardrail: "hide-secrets" # supported values: "aporia", "lakera" + mode: "pre_call" + # detect_secrets_config: { + # "plugins_used": [ + # {"name": "SoftlayerDetector"}, + # {"name": "StripeDetector"}, + # {"name": "NpmDetector"} + # ] + # } \ No newline at end of file diff --git a/litellm/proxy/common_utils/admin_ui_utils.py b/litellm/proxy/common_utils/admin_ui_utils.py index 3845c78ce..bd45fc627 100644 --- a/litellm/proxy/common_utils/admin_ui_utils.py +++ b/litellm/proxy/common_utils/admin_ui_utils.py @@ -166,3 +166,76 @@ def missing_keys_form(missing_key_names: str): """ return missing_keys_html_form.format(missing_keys=missing_key_names) + + +def admin_ui_disabled(): + from fastapi.responses import HTMLResponse + + ui_disabled_html = """ + + + + + + + Admin UI Disabled + + +
+    <body>
+        <div>
+            <h1>Admin UI is Disabled</h1>
+            <p>The Admin UI has been disabled by the administrator. To re-enable it, please update the following environment variable:</p>
+            <pre>
+    DISABLE_ADMIN_UI="False" # Set this to "False" to enable the Admin UI.
+            </pre>
+            <p>After making this change, restart the application for it to take effect.</p>
+            <p>Need Help? Support</p>
+            <p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw">https://discord.com/invite/wuPM9dRgDw</a></p>
+            <p>Docs: <a href="https://docs.litellm.ai/docs/">https://docs.litellm.ai/docs/</a></p>
+        </div>
+    </body>
+ + + """ + + return HTMLResponse( + content=ui_disabled_html, + status_code=200, + ) diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index c46300990..9c8443119 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -189,6 +189,18 @@ def init_guardrails_v2( litellm.callbacks.append(_success_callback) # type: ignore litellm.callbacks.append(_presidio_callback) # type: ignore + elif litellm_params["guardrail"] == "hide-secrets": + from enterprise.enterprise_hooks.secret_detection import ( + _ENTERPRISE_SecretDetection, + ) + + _secret_detection_object = _ENTERPRISE_SecretDetection( + detect_secrets_config=litellm_params.get("detect_secrets_config"), + event_hook=litellm_params["mode"], + guardrail_name=guardrail["guardrail_name"], + ) + + litellm.callbacks.append(_secret_detection_object) # type: ignore elif ( isinstance(litellm_params["guardrail"], str) and "." in litellm_params["guardrail"] diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 7bd1c652f..e0fb79b15 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -19,12 +19,8 @@ model_list: model: openai/429 api_key: fake-key api_base: https://exampleopenaiendpoint-production.up.railway.app + tags: ["fake"] general_settings: - master_key: sk-1234 -litellm_settings: - success_callback: ["datadog"] - service_callback: ["datadog"] - cache: True - + master_key: sk-1234 \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index de1530071..20dab118b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -139,6 +139,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth ## Import All Misc routes here ## from litellm.proxy.caching_routes import router as caching_router from litellm.proxy.common_utils.admin_ui_utils import ( + admin_ui_disabled, html_form, show_missing_vars_in_env, ) @@ -250,7 +251,7 @@ from litellm.secret_managers.aws_secret_manager import ( load_aws_secret_manager, ) from litellm.secret_managers.google_kms import load_google_kms -from litellm.secret_managers.main import get_secret +from litellm.secret_managers.main import get_secret, str_to_bool from litellm.types.llms.anthropic import ( AnthropicMessagesRequest, AnthropicResponse, @@ -652,7 +653,7 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False): return try: - from azure.identity import ClientSecretCredential + from azure.identity import ClientSecretCredential, DefaultAzureCredential from azure.keyvault.secrets import SecretClient # Set your Azure Key Vault URI @@ -670,9 +671,10 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False): and tenant_id is not None ): # Initialize the ClientSecretCredential - credential = ClientSecretCredential( - client_id=client_id, client_secret=client_secret, tenant_id=tenant_id - ) + # credential = ClientSecretCredential( + # client_id=client_id, client_secret=client_secret, tenant_id=tenant_id + # ) + credential = DefaultAzureCredential() # Create the SecretClient using the credential client = SecretClient(vault_url=KVUri, credential=credential) @@ -7967,6 +7969,13 @@ async def google_login(request: Request): google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + ####### Check if UI is disabled ####### + _disable_ui_flag = os.getenv("DISABLE_ADMIN_UI") + if _disable_ui_flag is not None: + 
is_disabled = str_to_bool(value=_disable_ui_flag) + if is_disabled: + return admin_ui_disabled() + ####### Check if user is a Enterprise / Premium User ####### if ( microsoft_client_id is not None diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index f0ee7ea9f..44ae71b15 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -414,6 +414,7 @@ class ProxyLogging: is not True ): continue + response = await _callback.async_pre_call_hook( user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], @@ -468,7 +469,9 @@ class ProxyLogging: ################################################################ # V1 implementation - backwards compatibility - if callback.event_hook is None: + if callback.event_hook is None and hasattr( + callback, "moderation_check" + ): if callback.moderation_check == "pre_call": # type: ignore return else: @@ -975,12 +978,13 @@ class PrismaClient: ] required_view = "LiteLLM_VerificationTokenView" expected_views_str = ", ".join(f"'{view}'" for view in expected_views) + pg_schema = os.getenv("DATABASE_SCHEMA", "public") ret = await self.db.query_raw( f""" WITH existing_views AS ( SELECT viewname FROM pg_views - WHERE schemaname = 'public' AND viewname IN ( + WHERE schemaname = '{pg_schema}' AND viewname IN ( {expected_views_str} ) ) diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 462208cfc..d58e3c34f 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -5,11 +5,13 @@ from typing import Any, Coroutine, Dict, List, Literal, Optional, Union import litellm from litellm._logging import verbose_logger +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.azure_ai.rerank import AzureAIRerank from litellm.llms.cohere.rerank import CohereRerank from litellm.llms.togetherai.rerank import TogetherAIRerank from litellm.secret_managers.main import get_secret from litellm.types.router import * -from litellm.utils import client, supports_httpx_timeout +from litellm.utils import client, exception_type, supports_httpx_timeout from .types import RerankRequest, RerankResponse @@ -17,6 +19,7 @@ from .types import RerankRequest, RerankResponse # Initialize any necessary instances or variables here cohere_rerank = CohereRerank() together_rerank = TogetherAIRerank() +azure_ai_rerank = AzureAIRerank() ################################################# @@ -70,7 +73,7 @@ def rerank( model: str, query: str, documents: List[Union[str, Dict[str, Any]]], - custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None, + custom_llm_provider: Optional[Literal["cohere", "together_ai", "azure_ai"]] = None, top_n: Optional[int] = None, rank_fields: Optional[List[str]] = None, return_documents: Optional[bool] = True, @@ -80,11 +83,18 @@ def rerank( """ Reranks a list of documents based on their relevance to the query """ + headers: Optional[dict] = kwargs.get("headers") # type: ignore + litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore + litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None) + proxy_server_request = kwargs.get("proxy_server_request", None) + model_info = kwargs.get("model_info", None) + metadata = kwargs.get("metadata", {}) + user = kwargs.get("user", None) try: _is_async = kwargs.pop("arerank", False) is True optional_params = GenericLiteLLMParams(**kwargs) - model, _custom_llm_provider, dynamic_api_key, api_base = ( + model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = 
             litellm.get_llm_provider(
                 model=model,
                 custom_llm_provider=custom_llm_provider,
@@ -93,31 +103,52 @@
             )
         )

+        litellm_logging_obj.update_environment_variables(
+            model=model,
+            user=user,
+            optional_params=optional_params.model_dump(),
+            litellm_params={
+                "litellm_call_id": litellm_call_id,
+                "proxy_server_request": proxy_server_request,
+                "model_info": model_info,
+                "metadata": metadata,
+                "preset_cache_key": None,
+                "stream_response": {},
+            },
+            custom_llm_provider=_custom_llm_provider,
+        )
+
         # Implement rerank logic here based on the custom_llm_provider
         if _custom_llm_provider == "cohere":
             # Implement Cohere rerank logic
-            cohere_key = (
+            api_key: Optional[str] = (
                 dynamic_api_key
                 or optional_params.api_key
                 or litellm.cohere_key
-                or get_secret("COHERE_API_KEY")
-                or get_secret("CO_API_KEY")
+                or get_secret("COHERE_API_KEY")  # type: ignore
+                or get_secret("CO_API_KEY")  # type: ignore
                 or litellm.api_key
             )

-            if cohere_key is None:
+            if api_key is None:
                 raise ValueError(
                     "Cohere API key is required, please set 'COHERE_API_KEY' in your environment"
                 )

-            api_base = (
-                optional_params.api_base
+            api_base: Optional[str] = (
+                dynamic_api_base
+                or optional_params.api_base
                 or litellm.api_base
-                or get_secret("COHERE_API_BASE")
+                or get_secret("COHERE_API_BASE")  # type: ignore
                 or "https://api.cohere.com/v1/rerank"
             )

-            headers: Dict = litellm.headers or {}
+            if api_base is None:
+                raise Exception(
+                    "Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
+                )
+
+            headers = headers or litellm.headers or {}

             response = cohere_rerank.rerank(
                 model=model,
@@ -127,22 +158,72 @@
                 rank_fields=rank_fields,
                 return_documents=return_documents,
                 max_chunks_per_doc=max_chunks_per_doc,
-                api_key=cohere_key,
+                api_key=api_key,
                 api_base=api_base,
                 _is_async=_is_async,
+                headers=headers,
+                litellm_logging_obj=litellm_logging_obj,
+            )
+        elif _custom_llm_provider == "azure_ai":
+            api_base = (
+                dynamic_api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret("AZURE_AI_API_BASE")  # type: ignore
+            )
+            # set API KEY
+            api_key = (
+                dynamic_api_key
+                or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
+                or litellm.openai_key
+                or get_secret("AZURE_AI_API_KEY")  # type: ignore
+            )
+
+            headers = headers or litellm.headers or {}
+
+            if api_key is None:
+                raise ValueError(
+                    "Azure AI API key is required, please set 'AZURE_AI_API_KEY' in your environment"
+                )
+
+            if api_base is None:
+                raise Exception(
+                    "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
+                )
+
+            ## LOAD CONFIG - if set
+            config = litellm.OpenAIConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in optional_params
+                ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
+                    optional_params[k] = v
+
+            response = azure_ai_rerank.rerank(
+                model=model,
+                query=query,
+                documents=documents,
+                top_n=top_n,
+                rank_fields=rank_fields,
+                return_documents=return_documents,
+                max_chunks_per_doc=max_chunks_per_doc,
+                api_key=api_key,
+                api_base=api_base,
+                _is_async=_is_async,
+                headers=headers,
+                litellm_logging_obj=litellm_logging_obj,
             )
-            pass
         elif _custom_llm_provider == "together_ai":
             # Implement Together AI rerank logic
-            together_key = (
+            api_key = (
                 dynamic_api_key
                 or optional_params.api_key
                 or litellm.togetherai_api_key
-                or get_secret("TOGETHERAI_API_KEY")
+                or get_secret("TOGETHERAI_API_KEY")  # type: ignore
                 or litellm.api_key
             )

-            if together_key is None:
+            if api_key is None:
                 raise ValueError(
                     "TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
                 )
@@ -155,7 +236,7 @@
                 rank_fields=rank_fields,
                 return_documents=return_documents,
                 max_chunks_per_doc=max_chunks_per_doc,
-                api_key=together_key,
+                api_key=api_key,
                 _is_async=_is_async,
             )

@@ -166,4 +247,6 @@
         return response
     except Exception as e:
         verbose_logger.error(f"Error in rerank: {str(e)}")
-        raise e
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e
+        )
diff --git a/litellm/router.py b/litellm/router.py
index ec6159da4..2a3f583fa 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1290,7 +1290,7 @@ class Router:
             raise e

     async def _aimage_generation(self, prompt: str, model: str, **kwargs):
-        model_name = ""
+        model_name = model
         try:
             verbose_router_logger.debug(
                 f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index be377b78b..e6d602e72 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -1351,6 +1351,37 @@ def test_logging_async_cache_hit_sync_call():
     assert standard_logging_object["saved_cache_cost"] > 0


+def test_logging_standard_payload_failure_call():
+    from litellm.types.utils import StandardLoggingPayload
+
+    customHandler = CompletionCustomHandler()
+    litellm.callbacks = [customHandler]
+
+    with patch.object(
+        customHandler, "log_failure_event", new=MagicMock()
+    ) as mock_client:
+        try:
+            resp = litellm.completion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                mock_response="litellm.RateLimitError",
+            )
+        except litellm.RateLimitError:
+            pass
+
+        mock_client.assert_called_once()
+
+        assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
+        assert (
+            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
+            is not None
+        )
+
+        standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
+            "kwargs"
+        ]["standard_logging_object"]
+
+
 def test_logging_key_masking_gemini():
     customHandler = CompletionCustomHandler()
     litellm.callbacks = [customHandler]
diff --git a/litellm/tests/test_guardrails_config.py b/litellm/tests/test_guardrails_config.py
index a086c8081..bd68f71e3 100644
--- a/litellm/tests/test_guardrails_config.py
+++ b/litellm/tests/test_guardrails_config.py
@@ -63,6 +63,7 @@ def test_guardrail_masking_logging_only():
     assert response.choices[0].message.content == "Hi Peter!"  # type: ignore
+    time.sleep(3)

     mock_call.assert_called_once()

     print(mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"])
diff --git a/litellm/tests/test_rerank.py b/litellm/tests/test_rerank.py
index 4d0fdfb34..c46f536a9 100644
--- a/litellm/tests/test_rerank.py
+++ b/litellm/tests/test_rerank.py
@@ -129,6 +129,47 @@ async def test_basic_rerank_together_ai(sync_mode):
         assert_response_shape(response, custom_llm_provider="together_ai")


+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_basic_rerank_azure_ai(sync_mode):
+    import os
+
+    litellm.set_verbose = True
+
+    if sync_mode is True:
+        response = litellm.rerank(
+            model="azure_ai/Cohere-rerank-v3-multilingual-ko",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+            api_key=os.getenv("AZURE_AI_COHERE_API_KEY"),
+            api_base=os.getenv("AZURE_AI_COHERE_API_BASE"),
+        )
+
+        print("re rank response: ", response)
+
+        assert response.id is not None
+        assert response.results is not None
+
+        assert_response_shape(response, custom_llm_provider="together_ai")
+    else:
+        response = await litellm.arerank(
+            model="azure_ai/Cohere-rerank-v3-multilingual-ko",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+            api_key=os.getenv("AZURE_AI_COHERE_API_KEY"),
+            api_base=os.getenv("AZURE_AI_COHERE_API_BASE"),
+        )
+
+        print("async re rank response: ", response)
+
+        assert response.id is not None
+        assert response.results is not None
+
+        assert_response_shape(response, custom_llm_provider="together_ai")
+
+
 @pytest.mark.asyncio()
 async def test_rerank_custom_api_base():
     mock_response = AsyncMock()
diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py
index 57be6b0c4..4bf606487 100644
--- a/litellm/types/guardrails.py
+++ b/litellm/types/guardrails.py
@@ -89,6 +89,9 @@ class LitellmParams(TypedDict):
     presidio_ad_hoc_recognizers: Optional[str]
     mock_redacted_text: Optional[dict]

+    # hide secrets params
+    detect_secrets_config: Optional[dict]
+

 class Guardrail(TypedDict):
     guardrail_name: str
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 6d5da5c68..d606ffeef 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1278,10 +1278,14 @@ class StandardLoggingModelInformation(TypedDict):
     model_map_value: Optional[ModelInfo]


+StandardLoggingPayloadStatus = Literal["success", "failure"]
+
+
 class StandardLoggingPayload(TypedDict):
     id: str
     call_type: str
     response_cost: float
+    status: StandardLoggingPayloadStatus
     total_tokens: int
     prompt_tokens: int
     completion_tokens: int
@@ -1302,6 +1306,7 @@ class StandardLoggingPayload(TypedDict):
     requester_ip_address: Optional[str]
     messages: Optional[Union[str, list, dict]]
     response: Optional[Union[str, list, dict]]
+    error_str: Optional[str]
     model_parameters: dict
     hidden_params: StandardLoggingHiddenParams
diff --git a/pyrightconfig.json b/pyrightconfig.json
index 051dcb9fc..86a21c65e 100644
--- a/pyrightconfig.json
+++ b/pyrightconfig.json
@@ -1,6 +1,6 @@
 {
     "ignore": [],
-    "exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py"],
+    "exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py", "litellm/types/utils.py"],
     "reportMissingImports": false
 }
\ No newline at end of file
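
Reviewer note (not part of the patch): the `rerank_api/main.py` and `test_rerank.py` changes above add `azure_ai` as a rerank provider. Below is a minimal usage sketch, assuming the `AZURE_AI_API_KEY` / `AZURE_AI_API_BASE` environment variables that the new branch falls back to; the endpoint URL and deployment name are placeholders modeled on the new `test_basic_rerank_azure_ai` test.

```python
import os

import litellm

# Assumed: these env vars are read by the new azure_ai branch when api_key /
# api_base are not passed explicitly (values below are placeholders).
os.environ["AZURE_AI_API_KEY"] = "my-azure-ai-key"
os.environ["AZURE_AI_API_BASE"] = "https://my-endpoint.eastus.models.ai.azure.com"

# The "azure_ai/" prefix routes the call to the AzureAIRerank handler
# registered in rerank_api/main.py (model name mirrors the new test).
response = litellm.rerank(
    model="azure_ai/Cohere-rerank-v3-multilingual-ko",
    query="hello",
    documents=["hello", "world"],
    top_n=3,
)

print(response.id)
print(response.results)
```

Per the diff, a missing key or base raises a `ValueError` / `Exception` before any request is sent, and other errors are now wrapped via `exception_type` instead of being re-raised bare.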