LiteLLM Minor Fixes & Improvements (09/17/2024) (#5742)

* fix(proxy_server.py): use default azure credentials to support azure non-client secret kms

* fix(langsmith.py): raise error if credentials missing

* feat(langsmith.py): support error logging for langsmith + standard logging payload

Fixes https://github.com/BerriAI/litellm/issues/5738

* Fix hardcoding of schema in view check (#5749)

* fix - deal with case when check view exists returns None (#5740)

* Revert "fix - deal with case when check view exists returns None (#5740)" (#5741)

This reverts commit 535228159b.

* test(test_router_debug_logs.py): move to mock response

* Fix hardcoding of schema

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>

* fix(proxy_server.py): allow admin to disable ui via `DISABLE_ADMIN_UI` flag

* fix(router.py): fix default model name value

Fixes 55db19a1e4 (r1763712148)

* fix(utils.py): fix unbound variable error

* feat(rerank/main.py): add azure ai rerank endpoints

Closes https://github.com/BerriAI/litellm/issues/5667

* feat(secret_detection.py): Allow configuring secret detection params

Allows admin to control what plugins to run for secret detection. Prevents overzealous secret detection.

* docs(secret_detection.md): add secret detection guardrail docs

* fix: fix linting errors

* fix - deal with case when check view exists returns None (#5740)

* Revert "fix - deal with case when check view exists returns None (#5740)" (#5741)

This reverts commit 535228159b.

* Litellm fix router testing (#5748)

* test: fix testing - azure changed content policy error logic

* test: fix tests to use mock responses

* test(test_image_generation.py): handle api instability

* test(test_image_generation.py): handle azure api instability

* fix(utils.py): fix unbounded variable error

* fix(utils.py): fix unbounded variable error

* test: refactor test to use mock response

* test: mark flaky azure tests

* Bump next from 14.1.1 to 14.2.10 in /ui/litellm-dashboard (#5753)

Bumps [next](https://github.com/vercel/next.js) from 14.1.1 to 14.2.10.
- [Release notes](https://github.com/vercel/next.js/releases)
- [Changelog](https://github.com/vercel/next.js/blob/canary/release.js)
- [Commits](https://github.com/vercel/next.js/compare/v14.1.1...v14.2.10)

---
updated-dependencies:
- dependency-name: next
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [Fix] o1-mini causes pydantic warnings on `reasoning_tokens`  (#5754)

* add requester_metadata in standard logging payload

* log requester_metadata in metadata

* use StandardLoggingPayload for logging

* docs StandardLoggingPayload

* fix import

* include standard logging object in failure

* add test for requester metadata

* handle completion_tokens_details

* add test for completion_tokens_details

* [Feat-Proxy-DataDog] Log Redis, Postgres Failure events on DataDog  (#5750)

* dd - start tracking redis status on dd

* add async_service_succes_hook / failure hook in custom logger

* add async_service_failure_hook

* log service failures on dd

* fix import error

* add test for redis errors / warning

* [Fix] Router/ Proxy - Tag Based routing, raise correct error when no deployments found and tag filtering is on  (#5745)

* fix tag routing - raise correct error when no model with tag based routing

* fix error string from tag based routing

* test router tag based routing

* raise 401 error when no tags avialable for deploymen

* linting fix

* [Feat] Log Request metadata on gcs bucket logging (#5743)

* add requester_metadata in standard logging payload

* log requester_metadata in metadata

* use StandardLoggingPayload for logging

* docs StandardLoggingPayload

* fix import

* include standard logging object in failure

* add test for requester metadata

* fix(litellm_logging.py): fix logging message

* fix(rerank_api/main.py): fix linting errors

* fix(custom_guardrails.py): maintain backwards compatibility for older guardrails

* fix(rerank_api/main.py): fix cost tracking for rerank endpoints

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: steffen-sbt <148480574+steffen-sbt@users.noreply.github.com>
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
This commit is contained in:
Krish Dholakia 2024-09-17 23:00:04 -07:00 committed by GitHub
parent c5c64a6c04
commit 98c335acd0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 1261 additions and 257 deletions

View file

@ -0,0 +1,557 @@
# ✨ Secret Detection/Redaction (Enterprise-only)
❓ Use this to REDACT API Keys, Secrets sent in requests to an LLM.
Example if you want to redact the value of `OPENAI_API_KEY` in the following request
#### Incoming Request
```json
{
"messages": [
{
"role": "user",
"content": "Hey, how's it going, API_KEY = 'sk_1234567890abcdef'",
}
]
}
```
#### Request after Moderation
```json
{
"messages": [
{
"role": "user",
"content": "Hey, how's it going, API_KEY = '[REDACTED]'",
}
]
}
```
**Usage**
**Step 1** Add this to your config.yaml
```yaml
guardrails:
- guardrail_name: "my-custom-name"
litellm_params:
guardrail: "hide-secrets" # supported values: "aporia", "lakera", ..
mode: "pre_call"
```
**Step 2** Run litellm proxy with `--detailed_debug` to see the server logs
```
litellm --config config.yaml --detailed_debug
```
**Step 3** Test it with request
Send this request
```shell
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "fake-claude-endpoint",
"messages": [
{
"role": "user",
"content": "what is the value of my open ai key? openai_api_key=sk-1234998222"
}
],
"guardrails": ["my-custom-name"]
}'
```
Expect to see the following warning on your litellm server logs
```shell
LiteLLM Proxy:WARNING: secret_detection.py:88 - Detected and redacted secrets in message: ['Secret Keyword']
```
You can also see the raw request sent from litellm to the API Provider with (`--detailed_debug`).
```json
POST Request Sent from LiteLLM:
curl -X POST \
https://api.groq.com/openai/v1/ \
-H 'Authorization: Bearer gsk_mySVchjY********************************************' \
-d {
"model": "llama3-8b-8192",
"messages": [
{
"role": "user",
"content": "what is the time today, openai_api_key=[REDACTED]"
}
],
"stream": false,
"extra_body": {}
}
```
## Turn on/off per project (API KEY/Team)
[**See Here**](./quick_start.md#-control-guardrails-per-project-api-key)
## Control secret detectors
LiteLLM uses the [`detect-secrets`](https://github.com/Yelp/detect-secrets) library for secret detection. See [all plugins run by default](#default-config-used)
### Usage
Here's how to control which plugins are run per request. This is useful if developers complain about secret detection impacting response quality.
**1. Set-up config.yaml**
```yaml
guardrails:
- guardrail_name: "hide-secrets"
litellm_params:
guardrail: "hide-secrets" # supported values: "aporia", "lakera"
mode: "pre_call"
detect_secrets_config: {
"plugins_used": [
{"name": "SoftlayerDetector"},
{"name": "StripeDetector"},
{"name": "NpmDetector"}
]
}
```
**2. Start proxy**
Run with `--detailed_debug` for more detailed logs. Use in dev only.
```bash
litellm --config /path/to/config.yaml --detailed_debug
```
**3. Test it!**
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "fake-claude-endpoint",
"messages": [
{
"role": "user",
"content": "what is the value of my open ai key? openai_api_key=sk-1234998222"
}
],
"guardrails": ["hide-secrets"]
}'
```
**Expected Logs**
Look for this in your logs, to confirm your changes worked as expected.
```
No secrets detected on input.
```
### Default Config Used
```
_default_detect_secrets_config = {
"plugins_used": [
{"name": "SoftlayerDetector"},
{"name": "StripeDetector"},
{"name": "NpmDetector"},
{"name": "IbmCosHmacDetector"},
{"name": "DiscordBotTokenDetector"},
{"name": "BasicAuthDetector"},
{"name": "AzureStorageKeyDetector"},
{"name": "ArtifactoryDetector"},
{"name": "AWSKeyDetector"},
{"name": "CloudantDetector"},
{"name": "IbmCloudIamDetector"},
{"name": "JwtTokenDetector"},
{"name": "MailchimpDetector"},
{"name": "SquareOAuthDetector"},
{"name": "PrivateKeyDetector"},
{"name": "TwilioKeyDetector"},
{
"name": "AdafruitKeyDetector",
"path": _custom_plugins_path + "/adafruit.py",
},
{
"name": "AdobeSecretDetector",
"path": _custom_plugins_path + "/adobe.py",
},
{
"name": "AgeSecretKeyDetector",
"path": _custom_plugins_path + "/age_secret_key.py",
},
{
"name": "AirtableApiKeyDetector",
"path": _custom_plugins_path + "/airtable_api_key.py",
},
{
"name": "AlgoliaApiKeyDetector",
"path": _custom_plugins_path + "/algolia_api_key.py",
},
{
"name": "AlibabaSecretDetector",
"path": _custom_plugins_path + "/alibaba.py",
},
{
"name": "AsanaSecretDetector",
"path": _custom_plugins_path + "/asana.py",
},
{
"name": "AtlassianApiTokenDetector",
"path": _custom_plugins_path + "/atlassian_api_token.py",
},
{
"name": "AuthressAccessKeyDetector",
"path": _custom_plugins_path + "/authress_access_key.py",
},
{
"name": "BittrexDetector",
"path": _custom_plugins_path + "/beamer_api_token.py",
},
{
"name": "BitbucketDetector",
"path": _custom_plugins_path + "/bitbucket.py",
},
{
"name": "BeamerApiTokenDetector",
"path": _custom_plugins_path + "/bittrex.py",
},
{
"name": "ClojarsApiTokenDetector",
"path": _custom_plugins_path + "/clojars_api_token.py",
},
{
"name": "CodecovAccessTokenDetector",
"path": _custom_plugins_path + "/codecov_access_token.py",
},
{
"name": "CoinbaseAccessTokenDetector",
"path": _custom_plugins_path + "/coinbase_access_token.py",
},
{
"name": "ConfluentDetector",
"path": _custom_plugins_path + "/confluent.py",
},
{
"name": "ContentfulApiTokenDetector",
"path": _custom_plugins_path + "/contentful_api_token.py",
},
{
"name": "DatabricksApiTokenDetector",
"path": _custom_plugins_path + "/databricks_api_token.py",
},
{
"name": "DatadogAccessTokenDetector",
"path": _custom_plugins_path + "/datadog_access_token.py",
},
{
"name": "DefinedNetworkingApiTokenDetector",
"path": _custom_plugins_path + "/defined_networking_api_token.py",
},
{
"name": "DigitaloceanDetector",
"path": _custom_plugins_path + "/digitalocean.py",
},
{
"name": "DopplerApiTokenDetector",
"path": _custom_plugins_path + "/doppler_api_token.py",
},
{
"name": "DroneciAccessTokenDetector",
"path": _custom_plugins_path + "/droneci_access_token.py",
},
{
"name": "DuffelApiTokenDetector",
"path": _custom_plugins_path + "/duffel_api_token.py",
},
{
"name": "DynatraceApiTokenDetector",
"path": _custom_plugins_path + "/dynatrace_api_token.py",
},
{
"name": "DiscordDetector",
"path": _custom_plugins_path + "/discord.py",
},
{
"name": "DropboxDetector",
"path": _custom_plugins_path + "/dropbox.py",
},
{
"name": "EasyPostDetector",
"path": _custom_plugins_path + "/easypost.py",
},
{
"name": "EtsyAccessTokenDetector",
"path": _custom_plugins_path + "/etsy_access_token.py",
},
{
"name": "FacebookAccessTokenDetector",
"path": _custom_plugins_path + "/facebook_access_token.py",
},
{
"name": "FastlyApiKeyDetector",
"path": _custom_plugins_path + "/fastly_api_token.py",
},
{
"name": "FinicityDetector",
"path": _custom_plugins_path + "/finicity.py",
},
{
"name": "FinnhubAccessTokenDetector",
"path": _custom_plugins_path + "/finnhub_access_token.py",
},
{
"name": "FlickrAccessTokenDetector",
"path": _custom_plugins_path + "/flickr_access_token.py",
},
{
"name": "FlutterwaveDetector",
"path": _custom_plugins_path + "/flutterwave.py",
},
{
"name": "FrameIoApiTokenDetector",
"path": _custom_plugins_path + "/frameio_api_token.py",
},
{
"name": "FreshbooksAccessTokenDetector",
"path": _custom_plugins_path + "/freshbooks_access_token.py",
},
{
"name": "GCPApiKeyDetector",
"path": _custom_plugins_path + "/gcp_api_key.py",
},
{
"name": "GitHubTokenCustomDetector",
"path": _custom_plugins_path + "/github_token.py",
},
{
"name": "GitLabDetector",
"path": _custom_plugins_path + "/gitlab.py",
},
{
"name": "GitterAccessTokenDetector",
"path": _custom_plugins_path + "/gitter_access_token.py",
},
{
"name": "GoCardlessApiTokenDetector",
"path": _custom_plugins_path + "/gocardless_api_token.py",
},
{
"name": "GrafanaDetector",
"path": _custom_plugins_path + "/grafana.py",
},
{
"name": "HashiCorpTFApiTokenDetector",
"path": _custom_plugins_path + "/hashicorp_tf_api_token.py",
},
{
"name": "HerokuApiKeyDetector",
"path": _custom_plugins_path + "/heroku_api_key.py",
},
{
"name": "HubSpotApiTokenDetector",
"path": _custom_plugins_path + "/hubspot_api_key.py",
},
{
"name": "HuggingFaceDetector",
"path": _custom_plugins_path + "/huggingface.py",
},
{
"name": "IntercomApiTokenDetector",
"path": _custom_plugins_path + "/intercom_api_key.py",
},
{
"name": "JFrogDetector",
"path": _custom_plugins_path + "/jfrog.py",
},
{
"name": "JWTBase64Detector",
"path": _custom_plugins_path + "/jwt.py",
},
{
"name": "KrakenAccessTokenDetector",
"path": _custom_plugins_path + "/kraken_access_token.py",
},
{
"name": "KucoinDetector",
"path": _custom_plugins_path + "/kucoin.py",
},
{
"name": "LaunchdarklyAccessTokenDetector",
"path": _custom_plugins_path + "/launchdarkly_access_token.py",
},
{
"name": "LinearDetector",
"path": _custom_plugins_path + "/linear.py",
},
{
"name": "LinkedInDetector",
"path": _custom_plugins_path + "/linkedin.py",
},
{
"name": "LobDetector",
"path": _custom_plugins_path + "/lob.py",
},
{
"name": "MailgunDetector",
"path": _custom_plugins_path + "/mailgun.py",
},
{
"name": "MapBoxApiTokenDetector",
"path": _custom_plugins_path + "/mapbox_api_token.py",
},
{
"name": "MattermostAccessTokenDetector",
"path": _custom_plugins_path + "/mattermost_access_token.py",
},
{
"name": "MessageBirdDetector",
"path": _custom_plugins_path + "/messagebird.py",
},
{
"name": "MicrosoftTeamsWebhookDetector",
"path": _custom_plugins_path + "/microsoft_teams_webhook.py",
},
{
"name": "NetlifyAccessTokenDetector",
"path": _custom_plugins_path + "/netlify_access_token.py",
},
{
"name": "NewRelicDetector",
"path": _custom_plugins_path + "/new_relic.py",
},
{
"name": "NYTimesAccessTokenDetector",
"path": _custom_plugins_path + "/nytimes_access_token.py",
},
{
"name": "OktaAccessTokenDetector",
"path": _custom_plugins_path + "/okta_access_token.py",
},
{
"name": "OpenAIApiKeyDetector",
"path": _custom_plugins_path + "/openai_api_key.py",
},
{
"name": "PlanetScaleDetector",
"path": _custom_plugins_path + "/planetscale.py",
},
{
"name": "PostmanApiTokenDetector",
"path": _custom_plugins_path + "/postman_api_token.py",
},
{
"name": "PrefectApiTokenDetector",
"path": _custom_plugins_path + "/prefect_api_token.py",
},
{
"name": "PulumiApiTokenDetector",
"path": _custom_plugins_path + "/pulumi_api_token.py",
},
{
"name": "PyPiUploadTokenDetector",
"path": _custom_plugins_path + "/pypi_upload_token.py",
},
{
"name": "RapidApiAccessTokenDetector",
"path": _custom_plugins_path + "/rapidapi_access_token.py",
},
{
"name": "ReadmeApiTokenDetector",
"path": _custom_plugins_path + "/readme_api_token.py",
},
{
"name": "RubygemsApiTokenDetector",
"path": _custom_plugins_path + "/rubygems_api_token.py",
},
{
"name": "ScalingoApiTokenDetector",
"path": _custom_plugins_path + "/scalingo_api_token.py",
},
{
"name": "SendbirdDetector",
"path": _custom_plugins_path + "/sendbird.py",
},
{
"name": "SendGridApiTokenDetector",
"path": _custom_plugins_path + "/sendgrid_api_token.py",
},
{
"name": "SendinBlueApiTokenDetector",
"path": _custom_plugins_path + "/sendinblue_api_token.py",
},
{
"name": "SentryAccessTokenDetector",
"path": _custom_plugins_path + "/sentry_access_token.py",
},
{
"name": "ShippoApiTokenDetector",
"path": _custom_plugins_path + "/shippo_api_token.py",
},
{
"name": "ShopifyDetector",
"path": _custom_plugins_path + "/shopify.py",
},
{
"name": "SlackDetector",
"path": _custom_plugins_path + "/slack.py",
},
{
"name": "SnykApiTokenDetector",
"path": _custom_plugins_path + "/snyk_api_token.py",
},
{
"name": "SquarespaceAccessTokenDetector",
"path": _custom_plugins_path + "/squarespace_access_token.py",
},
{
"name": "SumoLogicDetector",
"path": _custom_plugins_path + "/sumologic.py",
},
{
"name": "TelegramBotApiTokenDetector",
"path": _custom_plugins_path + "/telegram_bot_api_token.py",
},
{
"name": "TravisCiAccessTokenDetector",
"path": _custom_plugins_path + "/travisci_access_token.py",
},
{
"name": "TwitchApiTokenDetector",
"path": _custom_plugins_path + "/twitch_api_token.py",
},
{
"name": "TwitterDetector",
"path": _custom_plugins_path + "/twitter.py",
},
{
"name": "TypeformApiTokenDetector",
"path": _custom_plugins_path + "/typeform_api_token.py",
},
{
"name": "VaultDetector",
"path": _custom_plugins_path + "/vault.py",
},
{
"name": "YandexDetector",
"path": _custom_plugins_path + "/yandex.py",
},
{
"name": "ZendeskSecretKeyDetector",
"path": _custom_plugins_path + "/zendesk_secret_key.py",
},
{"name": "Base64HighEntropyString", "limit": 3.0},
{"name": "HexHighEntropyString", "limit": 3.0},
]
}
```

View file

@ -297,3 +297,14 @@ Set your colors to any of the following colors: https://www.tremor.so/docs/layou
- Deploy LiteLLM Proxy Server
## Disable Admin UI
Set `DISABLE_ADMIN_UI="True"` in your environment to disable the Admin UI.
Useful, if your security team has additional restrictions on UI usage.
**Expected Response**
<Image img={require('../../img/admin_ui_disabled.png')}/>

Binary file not shown.

After

Width:  |  Height:  |  Size: 238 KiB

View file

@ -81,6 +81,7 @@ const sidebars = {
"proxy/guardrails/lakera_ai",
"proxy/guardrails/bedrock",
"proxy/guardrails/pii_masking_v2",
"proxy/guardrails/secret_detection",
"proxy/guardrails/custom_guardrail",
"prompt_injection"
],

View file

@ -5,39 +5,24 @@
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, os
import sys
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid
from typing import Optional
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.utils import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio
from litellm._logging import verbose_proxy_logger
import tempfile
from litellm._logging import verbose_proxy_logger
litellm.set_verbose = True
from litellm.integrations.custom_guardrail import CustomGuardrail
GUARDRAIL_NAME = "hide_secrets"
_custom_plugins_path = "file://" + os.path.join(
os.path.dirname(os.path.abspath(__file__)), "secrets_plugins"
)
print("custom plugins path", _custom_plugins_path)
_default_detect_secrets_config = {
"plugins_used": [
{"name": "SoftlayerDetector"},
@ -434,9 +419,10 @@ _default_detect_secrets_config = {
}
class _ENTERPRISE_SecretDetection(CustomLogger):
def __init__(self):
pass
class _ENTERPRISE_SecretDetection(CustomGuardrail):
def __init__(self, detect_secrets_config: Optional[dict] = None, **kwargs):
self.user_defined_detect_secrets_config = detect_secrets_config
super().__init__(**kwargs)
def scan_message_for_secrets(self, message_content: str):
from detect_secrets import SecretsCollection
@ -447,7 +433,11 @@ class _ENTERPRISE_SecretDetection(CustomLogger):
temp_file.close()
secrets = SecretsCollection()
with transient_settings(_default_detect_secrets_config):
detect_secrets_config = (
self.user_defined_detect_secrets_config or _default_detect_secrets_config
)
with transient_settings(detect_secrets_config):
secrets.scan_file(temp_file.name)
os.remove(temp_file.name)
@ -484,9 +474,12 @@ class _ENTERPRISE_SecretDetection(CustomLogger):
from detect_secrets import SecretsCollection
from detect_secrets.settings import default_settings
print("INSIDE SECRET DETECTION PRE-CALL HOOK!")
if await self.should_run_check(user_api_key_dict) is False:
return
print("RUNNING CHECK!")
if "messages" in data and isinstance(data["messages"], list):
for message in data["messages"]:
if "content" in message and isinstance(message["content"], str):
@ -503,6 +496,8 @@ class _ENTERPRISE_SecretDetection(CustomLogger):
verbose_proxy_logger.warning(
f"Detected and redacted secrets in message: {secret_types}"
)
else:
verbose_proxy_logger.debug("No secrets detected on input.")
if "prompt" in data:
if isinstance(data["prompt"], str):

View file

@ -2,7 +2,7 @@
## File for 'response_cost' calculation in Logging
import time
import traceback
from typing import List, Literal, Optional, Tuple, Union
from typing import Any, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel
@ -100,7 +100,7 @@ def cost_per_token(
"rerank",
"arerank",
] = "completion",
) -> Tuple[float, float]:
) -> Tuple[float, float]: # type: ignore
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@ -425,20 +425,24 @@ def get_model_params_and_category(model_name) -> str:
return model_name
def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
def get_replicate_completion_pricing(completion_response: dict, total_time=0.0):
# see https://replicate.com/pricing
# for all litellm currently supported LLMs, almost all requests go to a100_80gb
a100_80gb_price_per_second_public = (
0.001400 # assume all calls sent to A100 80GB for now
)
if total_time == 0.0: # total time is in ms
start_time = completion_response["created"]
start_time = completion_response.get("created", time.time())
end_time = getattr(completion_response, "ended", time.time())
total_time = end_time - start_time
return a100_80gb_price_per_second_public * total_time / 1000
def has_hidden_params(obj: Any) -> bool:
return hasattr(obj, "_hidden_params")
def _select_model_name_for_cost_calc(
model: Optional[str],
completion_response: Union[BaseModel, dict, str],
@ -463,12 +467,14 @@ def _select_model_name_for_cost_calc(
elif return_model is None:
return_model = completion_response.get("model", "") # type: ignore
if hasattr(completion_response, "_hidden_params"):
hidden_params = getattr(completion_response, "_hidden_params", None)
if hidden_params is not None:
if (
completion_response._hidden_params.get("model", None) is not None
and len(completion_response._hidden_params["model"]) > 0
hidden_params.get("model", None) is not None
and len(hidden_params["model"]) > 0
):
return_model = completion_response._hidden_params.get("model", model)
return_model = hidden_params.get("model", model)
return return_model
@ -558,7 +564,7 @@ def completion_cost(
or isinstance(completion_response, dict)
): # tts returns a custom class
usage_obj: Optional[Union[dict, litellm.Usage]] = completion_response.get(
usage_obj: Optional[Union[dict, litellm.Usage]] = completion_response.get( # type: ignore
"usage", {}
)
if isinstance(usage_obj, BaseModel) and not isinstance(
@ -569,17 +575,17 @@ def completion_cost(
"usage",
litellm.Usage(**usage_obj.model_dump()),
)
if usage_obj is None:
_usage = {}
elif isinstance(usage_obj, BaseModel):
_usage = usage_obj.model_dump()
else:
_usage = usage_obj
# get input/output tokens from completion_response
prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = completion_response.get("usage", {}).get(
"completion_tokens", 0
)
cache_creation_input_tokens = completion_response.get("usage", {}).get(
"cache_creation_input_tokens", 0
)
cache_read_input_tokens = completion_response.get("usage", {}).get(
"cache_read_input_tokens", 0
)
prompt_tokens = _usage.get("prompt_tokens", 0)
completion_tokens = _usage.get("completion_tokens", 0)
cache_creation_input_tokens = _usage.get("cache_creation_input_tokens", 0)
cache_read_input_tokens = _usage.get("cache_read_input_tokens", 0)
total_time = getattr(completion_response, "_response_ms", 0)
verbose_logger.debug(
@ -588,24 +594,19 @@ def completion_cost(
model = _select_model_name_for_cost_calc(
model=model, completion_response=completion_response
)
if hasattr(completion_response, "_hidden_params"):
custom_llm_provider = completion_response._hidden_params.get(
hidden_params = getattr(completion_response, "_hidden_params", None)
if hidden_params is not None:
custom_llm_provider = hidden_params.get(
"custom_llm_provider", custom_llm_provider or None
)
region_name = completion_response._hidden_params.get(
"region_name", region_name
)
size = completion_response._hidden_params.get(
"optional_params", {}
).get(
region_name = hidden_params.get("region_name", region_name)
size = hidden_params.get("optional_params", {}).get(
"size", "1024-x-1024"
) # openai default
quality = completion_response._hidden_params.get(
"optional_params", {}
).get(
quality = hidden_params.get("optional_params", {}).get(
"quality", "standard"
) # openai default
n = completion_response._hidden_params.get("optional_params", {}).get(
n = hidden_params.get("optional_params", {}).get(
"n", 1
) # openai default
else:
@ -643,6 +644,8 @@ def completion_cost(
# Vertex Charges Flat $0.20 per image
return 0.020
if size is None:
size = "1024-x-1024" # openai default
# fix size to match naming convention
if "x" in size and "-x-" not in size:
size = size.replace("x", "-x-")
@ -697,7 +700,7 @@ def completion_cost(
model in litellm.replicate_models or "replicate" in model
) and model not in litellm.model_cost:
# for unmapped replicate model, default to replicate's time tracking logic
return get_replicate_completion_pricing(completion_response, total_time)
return get_replicate_completion_pricing(completion_response, total_time) # type: ignore
if model is None:
raise ValueError(
@ -847,7 +850,9 @@ def rerank_cost(
Returns
- float or None: cost of response OR none if error.
"""
_, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider
)
try:
if custom_llm_provider == "cohere":

View file

@ -45,6 +45,9 @@ class CustomBatchLogger(CustomLogger):
await self.flush_queue()
async def flush_queue(self):
if self.flush_lock is None:
return
async with self.flush_lock:
if self.log_queue:
verbose_logger.debug(
@ -54,5 +57,5 @@ class CustomBatchLogger(CustomLogger):
self.log_queue.clear()
self.last_flush_time = time.time()
async def async_send_batch(self):
async def async_send_batch(self, *args, **kwargs):
pass

View file

@ -29,12 +29,13 @@ class CustomGuardrail(CustomLogger):
)
if (
self.guardrail_name not in requested_guardrails
self.event_hook
and self.guardrail_name not in requested_guardrails
and event_type.value != "logging_only"
):
return False
if self.event_hook != event_type.value:
if self.event_hook and self.event_hook != event_type.value:
return False
return True

View file

@ -8,7 +8,7 @@ import traceback
import types
import uuid
from datetime import datetime, timezone
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, TypedDict, Union
import dotenv # type: ignore
import httpx
@ -23,6 +23,7 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload
class LangsmithInputs(BaseModel):
@ -46,6 +47,12 @@ class LangsmithInputs(BaseModel):
user_api_key_team_alias: Optional[str] = None
class LangsmithCredentialsObject(TypedDict):
LANGSMITH_API_KEY: str
LANGSMITH_PROJECT: str
LANGSMITH_BASE_URL: str
def is_serializable(value):
non_serializable_types = (
types.CoroutineType,
@ -57,15 +64,27 @@ def is_serializable(value):
class LangsmithLogger(CustomBatchLogger):
def __init__(self, **kwargs):
self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
self.langsmith_project = os.getenv("LANGSMITH_PROJECT", "litellm-completion")
def __init__(
self,
langsmith_api_key: Optional[str] = None,
langsmith_project: Optional[str] = None,
langsmith_base_url: Optional[str] = None,
**kwargs,
):
self.default_credentials = self.get_credentials_from_env(
langsmith_api_key=langsmith_api_key,
langsmith_project=langsmith_project,
langsmith_base_url=langsmith_base_url,
)
self.sampling_rate: float = (
float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore
else 1.0
)
self.langsmith_default_run_name = os.getenv(
"LANGSMITH_DEFAULT_RUN_NAME", "LLMRun"
)
self.langsmith_base_url = os.getenv(
"LANGSMITH_BASE_URL", "https://api.smith.langchain.com"
)
self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
@ -78,126 +97,160 @@ class LangsmithLogger(CustomBatchLogger):
self.flush_lock = asyncio.Lock()
super().__init__(**kwargs, flush_lock=self.flush_lock)
def get_credentials_from_env(
self,
langsmith_api_key: Optional[str],
langsmith_project: Optional[str],
langsmith_base_url: Optional[str],
) -> LangsmithCredentialsObject:
_credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
if _credentials_api_key is None:
raise Exception(
"Invalid Langsmith API Key given. _credentials_api_key=None."
)
_credentials_project = (
langsmith_project or os.getenv("LANGSMITH_PROJECT") or "litellm-completion"
)
if _credentials_project is None:
raise Exception(
"Invalid Langsmith API Key given. _credentials_project=None."
)
_credentials_base_url = (
langsmith_base_url
or os.getenv("LANGSMITH_BASE_URL")
or "https://api.smith.langchain.com"
)
if _credentials_base_url is None:
raise Exception(
"Invalid Langsmith API Key given. _credentials_base_url=None."
)
return LangsmithCredentialsObject(
LANGSMITH_API_KEY=_credentials_api_key,
LANGSMITH_BASE_URL=_credentials_base_url,
LANGSMITH_PROJECT=_credentials_project,
)
def _prepare_log_data(self, kwargs, response_obj, start_time, end_time):
import datetime
import json
from datetime import datetime as dt
from datetime import timezone
metadata = kwargs.get("litellm_params", {}).get("metadata", {}) or {}
new_metadata = {}
for key, value in metadata.items():
if (
isinstance(value, list)
or isinstance(value, str)
or isinstance(value, int)
or isinstance(value, float)
):
new_metadata[key] = value
elif isinstance(value, BaseModel):
new_metadata[key] = value.model_dump_json()
elif isinstance(value, dict):
for k, v in value.items():
if isinstance(v, dt):
value[k] = v.isoformat()
new_metadata[key] = value
metadata = new_metadata
kwargs["user_api_key"] = metadata.get("user_api_key", None)
kwargs["user_api_key_user_id"] = metadata.get("user_api_key_user_id", None)
kwargs["user_api_key_team_alias"] = metadata.get(
"user_api_key_team_alias", None
)
project_name = metadata.get("project_name", self.langsmith_project)
run_name = metadata.get("run_name", self.langsmith_default_run_name)
run_id = metadata.get("id", None)
parent_run_id = metadata.get("parent_run_id", None)
trace_id = metadata.get("trace_id", None)
session_id = metadata.get("session_id", None)
dotted_order = metadata.get("dotted_order", None)
tags = metadata.get("tags", []) or []
verbose_logger.debug(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
try:
start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat()
end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat()
except:
start_time = datetime.datetime.utcnow().isoformat()
end_time = datetime.datetime.utcnow().isoformat()
_litellm_params = kwargs.get("litellm_params", {}) or {}
metadata = _litellm_params.get("metadata", {}) or {}
new_metadata = {}
for key, value in metadata.items():
if (
isinstance(value, list)
or isinstance(value, str)
or isinstance(value, int)
or isinstance(value, float)
):
new_metadata[key] = value
elif isinstance(value, BaseModel):
new_metadata[key] = value.model_dump_json()
elif isinstance(value, dict):
for k, v in value.items():
if isinstance(v, dt):
value[k] = v.isoformat()
new_metadata[key] = value
# filter out kwargs to not include any dicts, langsmith throws an erros when trying to log kwargs
logged_kwargs = LangsmithInputs(**kwargs)
kwargs = logged_kwargs.model_dump()
metadata = new_metadata
new_kwargs = {}
for key in kwargs:
value = kwargs[key]
if key == "start_time" or key == "end_time" or value is None:
pass
elif key == "original_response" and not isinstance(value, str):
new_kwargs[key] = str(value)
elif type(value) == datetime.datetime:
new_kwargs[key] = value.isoformat()
elif type(value) != dict and is_serializable(value=value):
new_kwargs[key] = value
elif not is_serializable(value=value):
continue
kwargs["user_api_key"] = metadata.get("user_api_key", None)
kwargs["user_api_key_user_id"] = metadata.get("user_api_key_user_id", None)
kwargs["user_api_key_team_alias"] = metadata.get(
"user_api_key_team_alias", None
)
if isinstance(response_obj, BaseModel):
try:
response_obj = response_obj.model_dump()
except:
response_obj = response_obj.dict() # type: ignore
project_name = metadata.get(
"project_name", self.default_credentials["LANGSMITH_PROJECT"]
)
run_name = metadata.get("run_name", self.langsmith_default_run_name)
run_id = metadata.get("id", None)
parent_run_id = metadata.get("parent_run_id", None)
trace_id = metadata.get("trace_id", None)
session_id = metadata.get("session_id", None)
dotted_order = metadata.get("dotted_order", None)
tags = metadata.get("tags", []) or []
verbose_logger.debug(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
data = {
"name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": new_kwargs,
"outputs": response_obj,
"session_name": project_name,
"start_time": start_time,
"end_time": end_time,
"tags": tags,
"extra": metadata,
}
# filter out kwargs to not include any dicts, langsmith throws an erros when trying to log kwargs
# logged_kwargs = LangsmithInputs(**kwargs)
# kwargs = logged_kwargs.model_dump()
if run_id:
data["id"] = run_id
# new_kwargs = {}
# Ensure everything in the payload is converted to str
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if parent_run_id:
data["parent_run_id"] = parent_run_id
if payload is None:
raise Exception("Error logging request payload. Payload=none.")
if trace_id:
data["trace_id"] = trace_id
new_kwargs = payload
metadata = payload[
"metadata"
] # ensure logged metadata is json serializable
if session_id:
data["session_id"] = session_id
data = {
"name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": new_kwargs,
"outputs": new_kwargs["response"],
"session_name": project_name,
"start_time": new_kwargs["startTime"],
"end_time": new_kwargs["endTime"],
"tags": tags,
"extra": metadata,
}
if dotted_order:
data["dotted_order"] = dotted_order
if payload["error_str"] is not None and payload["status"] == "failure":
data["error"] = payload["error_str"]
if "id" not in data or data["id"] is None:
"""
for /batch langsmith requires id, trace_id and dotted_order passed as params
"""
run_id = uuid.uuid4()
data["id"] = str(run_id)
data["trace_id"] = str(run_id)
data["dotted_order"] = self.make_dot_order(run_id=run_id)
if run_id:
data["id"] = run_id
verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
if parent_run_id:
data["parent_run_id"] = parent_run_id
return data
if trace_id:
data["trace_id"] = trace_id
if session_id:
data["session_id"] = session_id
if dotted_order:
data["dotted_order"] = dotted_order
if "id" not in data or data["id"] is None:
"""
for /batch langsmith requires id, trace_id and dotted_order passed as params
"""
run_id = str(uuid.uuid4())
data["id"] = str(run_id)
data["trace_id"] = str(run_id)
data["dotted_order"] = self.make_dot_order(run_id=run_id)
verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
return data
except Exception:
raise
def _send_batch(self):
if not self.log_queue:
return
url = f"{self.langsmith_base_url}/runs/batch"
headers = {"x-api-key": self.langsmith_api_key}
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
url = f"{langsmith_api_base}/runs/batch"
headers = {"x-api-key": langsmith_api_key}
try:
response = requests.post(
@ -216,15 +269,15 @@ class LangsmithLogger(CustomBatchLogger):
)
self.log_queue.clear()
except Exception as e:
verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}")
except Exception:
verbose_logger.exception("Langsmith Layer Error - Error sending batch.")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
sampling_rate = (
float(os.getenv("LANGSMITH_SAMPLING_RATE"))
float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit()
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore
else 1.0
)
random_sample = random.random()
@ -249,17 +302,12 @@ class LangsmithLogger(CustomBatchLogger):
if len(self.log_queue) >= self.batch_size:
self._send_batch()
except:
verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}")
except Exception:
verbose_logger.exception("Langsmith Layer Error - log_success_event error")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
sampling_rate = (
float(os.getenv("LANGSMITH_SAMPLING_RATE"))
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit()
else 1.0
)
sampling_rate = self.sampling_rate
random_sample = random.random()
if random_sample > sampling_rate:
verbose_logger.info(
@ -282,8 +330,36 @@ class LangsmithLogger(CustomBatchLogger):
)
if len(self.log_queue) >= self.batch_size:
await self.flush_queue()
except:
verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}")
except Exception:
verbose_logger.exception(
"Langsmith Layer Error - error logging async success event."
)
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
sampling_rate = self.sampling_rate
random_sample = random.random()
if random_sample > sampling_rate:
verbose_logger.info(
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
sampling_rate, random_sample
)
)
return # Skip logging
verbose_logger.info("Langsmith Failure Event Logging!")
try:
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s",
len(self.log_queue),
self.batch_size,
)
if len(self.log_queue) >= self.batch_size:
await self.flush_queue()
except Exception:
verbose_logger.exception(
"Langsmith Layer Error - error logging async failure event."
)
async def async_send_batch(self):
"""
@ -295,13 +371,16 @@ class LangsmithLogger(CustomBatchLogger):
Raises: Does not raise an exception, will only verbose_logger.exception()
"""
import json
if not self.log_queue:
return
url = f"{self.langsmith_base_url}/runs/batch"
headers = {"x-api-key": self.langsmith_api_key}
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
url = f"{langsmith_api_base}/runs/batch"
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
headers = {"x-api-key": langsmith_api_key}
try:
response = await self.async_httpx_client.post(
@ -332,10 +411,14 @@ class LangsmithLogger(CustomBatchLogger):
def get_run_by_id(self, run_id):
url = f"{self.langsmith_base_url}/runs/{run_id}"
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
url = f"{langsmith_api_base}/runs/{run_id}"
response = requests.get(
url=url,
headers={"x-api-key": self.langsmith_api_key},
headers={"x-api-key": langsmith_api_key},
)
return response.json()

View file

@ -43,6 +43,7 @@ from litellm.types.utils import (
StandardLoggingMetadata,
StandardLoggingModelInformation,
StandardLoggingPayload,
StandardLoggingPayloadStatus,
StandardPassThroughResponseObject,
TextCompletionResponse,
TranscriptionResponse,
@ -668,6 +669,7 @@ class Logging:
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
)
)
elif isinstance(result, dict): # pass-through endpoints
@ -679,6 +681,7 @@ class Logging:
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
)
)
else: # streaming chunks + image gen.
@ -762,6 +765,7 @@ class Logging:
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
)
)
if self.dynamic_success_callbacks is not None and isinstance(
@ -1390,6 +1394,7 @@ class Logging:
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
)
)
if self.dynamic_async_success_callbacks is not None and isinstance(
@ -1645,6 +1650,20 @@ class Logging:
self.model_call_details["litellm_params"].get("metadata", {}) or {}
)
metadata.update(exception.headers)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
)
)
return start_time, end_time
async def special_failure_handlers(self, exception: Exception):
@ -2347,10 +2366,12 @@ def is_valid_sha256_hash(value: str) -> bool:
def get_standard_logging_object_payload(
kwargs: Optional[dict],
init_response_obj: Any,
init_response_obj: Union[Any, BaseModel, dict],
start_time: dt_object,
end_time: dt_object,
logging_obj: Logging,
status: StandardLoggingPayloadStatus,
error_str: Optional[str] = None,
) -> Optional[StandardLoggingPayload]:
try:
if kwargs is None:
@ -2467,7 +2488,7 @@ def get_standard_logging_object_payload(
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
model_cost_name = _select_model_name_for_cost_calc(
model=None,
completion_response=init_response_obj,
completion_response=init_response_obj, # type: ignore
base_model=base_model,
custom_pricing=custom_pricing,
)
@ -2498,6 +2519,7 @@ def get_standard_logging_object_payload(
id=str(id),
call_type=call_type or "",
cache_hit=cache_hit,
status=status,
saved_cache_cost=saved_cache_cost,
startTime=start_time_float,
endTime=end_time_float,
@ -2517,11 +2539,12 @@ def get_standard_logging_object_payload(
requester_ip_address=clean_metadata.get("requester_ip_address", None),
messages=kwargs.get("messages"),
response=( # type: ignore
response_obj if len(response_obj.keys()) > 0 else init_response_obj
response_obj if len(response_obj.keys()) > 0 else init_response_obj # type: ignore
),
model_parameters=kwargs.get("optional_params", None),
hidden_params=clean_hidden_params,
model_map_information=model_cost_information,
error_str=error_str,
)
verbose_logger.debug(

View file

@ -0,0 +1 @@
`/chat/completion` calls routed via `openai.py`.

View file

@ -0,0 +1 @@
from .handler import AzureAIRerank

View file

@ -0,0 +1,52 @@
from typing import Any, Dict, List, Optional, Union
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank import CohereRerank
from litellm.rerank_api.types import RerankResponse
class AzureAIRerank(CohereRerank):
def rerank(
self,
model: str,
api_key: str,
api_base: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
headers: Optional[dict],
litellm_logging_obj: LiteLLMLoggingObj,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False,
) -> RerankResponse:
if headers is None:
headers = {"Authorization": "Bearer {}".format(api_key)}
else:
headers = {**headers, "Authorization": "Bearer {}".format(api_key)}
# Assuming api_base is a string representing the base URL
api_base_url = httpx.URL(api_base)
# Replace the path with '/v1/rerank' if it doesn't already end with it
if not api_base_url.path.endswith("/v1/rerank"):
api_base = str(api_base_url.copy_with(path="/v1/rerank"))
return super().rerank(
model=model,
api_key=api_key,
api_base=api_base,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
_is_async=_is_async,
headers=headers,
litellm_logging_obj=litellm_logging_obj,
)

View file

@ -0,0 +1,3 @@
"""
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
"""

View file

@ -10,6 +10,7 @@ import httpx
from pydantic import BaseModel
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
@ -19,6 +20,23 @@ from litellm.rerank_api.types import RerankRequest, RerankResponse
class CohereRerank(BaseLLM):
def validate_environment(self, api_key: str, headers: Optional[dict]) -> dict:
default_headers = {
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"bearer {api_key}",
}
if headers is None:
return default_headers
# If 'Authorization' is provided in headers, it overrides the default.
if "Authorization" in headers:
default_headers["Authorization"] = headers["Authorization"]
# Merge other headers, overriding any default ones except Authorization
return {**default_headers, **headers}
def rerank(
self,
model: str,
@ -26,12 +44,16 @@ class CohereRerank(BaseLLM):
api_base: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
headers: Optional[dict],
litellm_logging_obj: LiteLLMLoggingObj,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False, # New parameter
) -> RerankResponse:
headers = self.validate_environment(api_key=api_key, headers=headers)
request_data = RerankRequest(
model=model,
query=query,
@ -45,16 +67,22 @@ class CohereRerank(BaseLLM):
request_data_dict = request_data.dict(exclude_none=True)
if _is_async:
return self.async_rerank(request_data_dict, api_key, api_base) # type: ignore # Call async method
return self.async_rerank(request_data_dict=request_data_dict, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method
## LOGGING
litellm_logging_obj.pre_call(
input=request_data_dict,
api_key=api_key,
additional_args={
"complete_input_dict": request_data_dict,
"api_base": api_base,
"headers": headers,
},
)
client = _get_httpx_client()
response = client.post(
api_base,
headers={
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"bearer {api_key}",
},
headers=headers,
json=request_data_dict,
)
@ -65,16 +93,13 @@ class CohereRerank(BaseLLM):
request_data_dict: Dict[str, Any],
api_key: str,
api_base: str,
headers: dict,
) -> RerankResponse:
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
response = await client.post(
api_base,
headers={
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"bearer {api_key}",
},
headers=headers,
json=request_data_dict,
)

View file

@ -20,30 +20,15 @@ model_list:
litellm_params:
model: o1-preview
litellm_settings:
drop_params: True
json_logs: True
store_audit_logs: True
log_raw_request_response: True
return_response_headers: True
num_retries: 5
request_timeout: 200
callbacks: ["custom_callbacks.proxy_handler_instance"]
guardrails:
- guardrail_name: "presidio-pre-guard"
- guardrail_name: "hide-secrets"
litellm_params:
guardrail: presidio # supported values: "aporia", "bedrock", "lakera", "presidio"
mode: "logging_only"
mock_redacted_text: {
"text": "My name is <PERSON>, who are you? Say my name in your response",
"items": [
{
"start": 11,
"end": 19,
"entity_type": "PERSON",
"text": "<PERSON>",
"operator": "replace",
}
],
}
guardrail: "hide-secrets" # supported values: "aporia", "lakera"
mode: "pre_call"
# detect_secrets_config: {
# "plugins_used": [
# {"name": "SoftlayerDetector"},
# {"name": "StripeDetector"},
# {"name": "NpmDetector"}
# ]
# }

View file

@ -166,3 +166,76 @@ def missing_keys_form(missing_key_names: str):
</html>
"""
return missing_keys_html_form.format(missing_keys=missing_key_names)
def admin_ui_disabled():
from fastapi.responses import HTMLResponse
ui_disabled_html = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}}
.container {{
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
font-size: 24px;
margin-bottom: 20px;
}}
pre {{
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}}
.env-var {{
font-weight: normal;
}}
.comment {{
font-weight: normal;
color: #777;
}}
</style>
<title>Admin UI Disabled</title>
</head>
<body>
<div class="container">
<h1>Admin UI is Disabled</h1>
<p>The Admin UI has been disabled by the administrator. To re-enable it, please update the following environment variable:</p>
<pre>
<span class="env-var">DISABLE_ADMIN_UI="False"</span> <span class="comment"># Set this to "False" to enable the Admin UI.</span>
</pre>
<p>After making this change, restart the application for it to take effect.</p>
</div>
<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""
return HTMLResponse(
content=ui_disabled_html,
status_code=200,
)

View file

@ -189,6 +189,18 @@ def init_guardrails_v2(
litellm.callbacks.append(_success_callback) # type: ignore
litellm.callbacks.append(_presidio_callback) # type: ignore
elif litellm_params["guardrail"] == "hide-secrets":
from enterprise.enterprise_hooks.secret_detection import (
_ENTERPRISE_SecretDetection,
)
_secret_detection_object = _ENTERPRISE_SecretDetection(
detect_secrets_config=litellm_params.get("detect_secrets_config"),
event_hook=litellm_params["mode"],
guardrail_name=guardrail["guardrail_name"],
)
litellm.callbacks.append(_secret_detection_object) # type: ignore
elif (
isinstance(litellm_params["guardrail"], str)
and "." in litellm_params["guardrail"]

View file

@ -19,12 +19,8 @@ model_list:
model: openai/429
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app
tags: ["fake"]
general_settings:
master_key: sk-1234
litellm_settings:
success_callback: ["datadog"]
service_callback: ["datadog"]
cache: True

View file

@ -139,6 +139,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
## Import All Misc routes here ##
from litellm.proxy.caching_routes import router as caching_router
from litellm.proxy.common_utils.admin_ui_utils import (
admin_ui_disabled,
html_form,
show_missing_vars_in_env,
)
@ -250,7 +251,7 @@ from litellm.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
)
from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import get_secret
from litellm.secret_managers.main import get_secret, str_to_bool
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
AnthropicResponse,
@ -652,7 +653,7 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
return
try:
from azure.identity import ClientSecretCredential
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
# Set your Azure Key Vault URI
@ -670,9 +671,10 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
and tenant_id is not None
):
# Initialize the ClientSecretCredential
credential = ClientSecretCredential(
client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
)
# credential = ClientSecretCredential(
# client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
# )
credential = DefaultAzureCredential()
# Create the SecretClient using the credential
client = SecretClient(vault_url=KVUri, credential=credential)
@ -7967,6 +7969,13 @@ async def google_login(request: Request):
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
####### Check if UI is disabled #######
_disable_ui_flag = os.getenv("DISABLE_ADMIN_UI")
if _disable_ui_flag is not None:
is_disabled = str_to_bool(value=_disable_ui_flag)
if is_disabled:
return admin_ui_disabled()
####### Check if user is a Enterprise / Premium User #######
if (
microsoft_client_id is not None

View file

@ -414,6 +414,7 @@ class ProxyLogging:
is not True
):
continue
response = await _callback.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=self.call_details["user_api_key_cache"],
@ -468,7 +469,9 @@ class ProxyLogging:
################################################################
# V1 implementation - backwards compatibility
if callback.event_hook is None:
if callback.event_hook is None and hasattr(
callback, "moderation_check"
):
if callback.moderation_check == "pre_call": # type: ignore
return
else:
@ -975,12 +978,13 @@ class PrismaClient:
]
required_view = "LiteLLM_VerificationTokenView"
expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
pg_schema = os.getenv("DATABASE_SCHEMA", "public")
ret = await self.db.query_raw(
f"""
WITH existing_views AS (
SELECT viewname
FROM pg_views
WHERE schemaname = 'public' AND viewname IN (
WHERE schemaname = '{pg_schema}' AND viewname IN (
{expected_views_str}
)
)

View file

@ -5,11 +5,13 @@ from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.togetherai.rerank import TogetherAIRerank
from litellm.secret_managers.main import get_secret
from litellm.types.router import *
from litellm.utils import client, supports_httpx_timeout
from litellm.utils import client, exception_type, supports_httpx_timeout
from .types import RerankRequest, RerankResponse
@ -17,6 +19,7 @@ from .types import RerankRequest, RerankResponse
# Initialize any necessary instances or variables here
cohere_rerank = CohereRerank()
together_rerank = TogetherAIRerank()
azure_ai_rerank = AzureAIRerank()
#################################################
@ -70,7 +73,7 @@ def rerank(
model: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
custom_llm_provider: Optional[Literal["cohere", "together_ai", "azure_ai"]] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
@ -80,11 +83,18 @@ def rerank(
"""
Reranks a list of documents based on their relevance to the query
"""
headers: Optional[dict] = kwargs.get("headers") # type: ignore
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
user = kwargs.get("user", None)
try:
_is_async = kwargs.pop("arerank", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
model, _custom_llm_provider, dynamic_api_key, api_base = (
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
litellm.get_llm_provider(
model=model,
custom_llm_provider=custom_llm_provider,
@ -93,31 +103,52 @@ def rerank(
)
)
litellm_logging_obj.update_environment_variables(
model=model,
user=user,
optional_params=optional_params.model_dump(),
litellm_params={
"litellm_call_id": litellm_call_id,
"proxy_server_request": proxy_server_request,
"model_info": model_info,
"metadata": metadata,
"preset_cache_key": None,
"stream_response": {},
},
custom_llm_provider=_custom_llm_provider,
)
# Implement rerank logic here based on the custom_llm_provider
if _custom_llm_provider == "cohere":
# Implement Cohere rerank logic
cohere_key = (
api_key: Optional[str] = (
dynamic_api_key
or optional_params.api_key
or litellm.cohere_key
or get_secret("COHERE_API_KEY")
or get_secret("CO_API_KEY")
or get_secret("COHERE_API_KEY") # type: ignore
or get_secret("CO_API_KEY") # type: ignore
or litellm.api_key
)
if cohere_key is None:
if api_key is None:
raise ValueError(
"Cohere API key is required, please set 'COHERE_API_KEY' in your environment"
)
api_base = (
optional_params.api_base
api_base: Optional[str] = (
dynamic_api_base
or optional_params.api_base
or litellm.api_base
or get_secret("COHERE_API_BASE")
or get_secret("COHERE_API_BASE") # type: ignore
or "https://api.cohere.com/v1/rerank"
)
headers: Dict = litellm.headers or {}
if api_base is None:
raise Exception(
"Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
)
headers = headers or litellm.headers or {}
response = cohere_rerank.rerank(
model=model,
@ -127,22 +158,72 @@ def rerank(
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
api_key=cohere_key,
api_key=api_key,
api_base=api_base,
_is_async=_is_async,
headers=headers,
litellm_logging_obj=litellm_logging_obj,
)
elif _custom_llm_provider == "azure_ai":
api_base = (
dynamic_api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
or optional_params.api_base
or litellm.api_base
or get_secret("AZURE_AI_API_BASE") # type: ignore
)
# set API KEY
api_key = (
dynamic_api_key
or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or get_secret("AZURE_AI_API_KEY") # type: ignore
)
headers = headers or litellm.headers or {}
if api_key is None:
raise ValueError(
"Azure AI API key is required, please set 'AZURE_AI_API_KEY' in your environment"
)
if api_base is None:
raise Exception(
"Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
)
## LOAD CONFIG - if set
config = litellm.OpenAIConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
response = azure_ai_rerank.rerank(
model=model,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
api_key=api_key,
api_base=api_base,
_is_async=_is_async,
headers=headers,
litellm_logging_obj=litellm_logging_obj,
)
pass
elif _custom_llm_provider == "together_ai":
# Implement Together AI rerank logic
together_key = (
api_key = (
dynamic_api_key
or optional_params.api_key
or litellm.togetherai_api_key
or get_secret("TOGETHERAI_API_KEY")
or get_secret("TOGETHERAI_API_KEY") # type: ignore
or litellm.api_key
)
if together_key is None:
if api_key is None:
raise ValueError(
"TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
)
@ -155,7 +236,7 @@ def rerank(
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
api_key=together_key,
api_key=api_key,
_is_async=_is_async,
)
@ -166,4 +247,6 @@ def rerank(
return response
except Exception as e:
verbose_logger.error(f"Error in rerank: {str(e)}")
raise e
raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e
)

View file

@ -1290,7 +1290,7 @@ class Router:
raise e
async def _aimage_generation(self, prompt: str, model: str, **kwargs):
model_name = ""
model_name = model
try:
verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"

View file

@ -1351,6 +1351,37 @@ def test_logging_async_cache_hit_sync_call():
assert standard_logging_object["saved_cache_cost"] > 0
def test_logging_standard_payload_failure_call():
from litellm.types.utils import StandardLoggingPayload
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
with patch.object(
customHandler, "log_failure_event", new=MagicMock()
) as mock_client:
try:
resp = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_response="litellm.RateLimitError",
)
except litellm.RateLimitError:
pass
mock_client.assert_called_once()
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
"kwargs"
]["standard_logging_object"]
def test_logging_key_masking_gemini():
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]

View file

@ -63,6 +63,7 @@ def test_guardrail_masking_logging_only():
assert response.choices[0].message.content == "Hi Peter!" # type: ignore
time.sleep(3)
mock_call.assert_called_once()
print(mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"])

View file

@ -129,6 +129,47 @@ async def test_basic_rerank_together_ai(sync_mode):
assert_response_shape(response, custom_llm_provider="together_ai")
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank_azure_ai(sync_mode):
import os
litellm.set_verbose = True
if sync_mode is True:
response = litellm.rerank(
model="azure_ai/Cohere-rerank-v3-multilingual-ko",
query="hello",
documents=["hello", "world"],
top_n=3,
api_key=os.getenv("AZURE_AI_COHERE_API_KEY"),
api_base=os.getenv("AZURE_AI_COHERE_API_BASE"),
)
print("re rank response: ", response)
assert response.id is not None
assert response.results is not None
assert_response_shape(response, custom_llm_provider="together_ai")
else:
response = await litellm.arerank(
model="azure_ai/Cohere-rerank-v3-multilingual-ko",
query="hello",
documents=["hello", "world"],
top_n=3,
api_key=os.getenv("AZURE_AI_COHERE_API_KEY"),
api_base=os.getenv("AZURE_AI_COHERE_API_BASE"),
)
print("async re rank response: ", response)
assert response.id is not None
assert response.results is not None
assert_response_shape(response, custom_llm_provider="together_ai")
@pytest.mark.asyncio()
async def test_rerank_custom_api_base():
mock_response = AsyncMock()

View file

@ -89,6 +89,9 @@ class LitellmParams(TypedDict):
presidio_ad_hoc_recognizers: Optional[str]
mock_redacted_text: Optional[dict]
# hide secrets params
detect_secrets_config: Optional[dict]
class Guardrail(TypedDict):
guardrail_name: str

View file

@ -1278,10 +1278,14 @@ class StandardLoggingModelInformation(TypedDict):
model_map_value: Optional[ModelInfo]
StandardLoggingPayloadStatus = Literal["success", "failure"]
class StandardLoggingPayload(TypedDict):
id: str
call_type: str
response_cost: float
status: StandardLoggingPayloadStatus
total_tokens: int
prompt_tokens: int
completion_tokens: int
@ -1302,6 +1306,7 @@ class StandardLoggingPayload(TypedDict):
requester_ip_address: Optional[str]
messages: Optional[Union[str, list, dict]]
response: Optional[Union[str, list, dict]]
error_str: Optional[str]
model_parameters: dict
hidden_params: StandardLoggingHiddenParams

View file

@ -1,6 +1,6 @@
{
"ignore": [],
"exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py"],
"exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py", "litellm/types/utils.py"],
"reportMissingImports": false
}