LiteLLM Minor Fixes & Improvements (09/17/2024) (#5742)

* fix(proxy_server.py): use default azure credentials to support azure non-client secret kms

* fix(langsmith.py): raise error if credentials missing

* feat(langsmith.py): support error logging for langsmith + standard logging payload

Fixes https://github.com/BerriAI/litellm/issues/5738

* Fix hardcoding of schema in view check (#5749)

* fix - deal with case when check view exists returns None (#5740)

* Revert "fix - deal with case when check view exists returns None (#5740)" (#5741)

This reverts commit 535228159b.

* test(test_router_debug_logs.py): move to mock response

* Fix hardcoding of schema

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>

* fix(proxy_server.py): allow admin to disable ui via `DISABLE_ADMIN_UI` flag

* fix(router.py): fix default model name value

Fixes 55db19a1e4 (r1763712148)

* fix(utils.py): fix unbound variable error

* feat(rerank/main.py): add azure ai rerank endpoints

Closes https://github.com/BerriAI/litellm/issues/5667

* feat(secret_detection.py): Allow configuring secret detection params

Allows the admin to control which plugins run for secret detection, preventing overzealous secret detection.

* docs(secret_detection.md): add secret detection guardrail docs

* fix: fix linting errors

* fix - deal with case when check view exists returns None (#5740)

* Revert "fix - deal with case when check view exists returns None (#5740)" (#5741)

This reverts commit 535228159b.

* Litellm fix router testing (#5748)

* test: fix testing - azure changed content policy error logic

* test: fix tests to use mock responses

* test(test_image_generation.py): handle api instability

* test(test_image_generation.py): handle azure api instability

* fix(utils.py): fix unbound variable error

* fix(utils.py): fix unbound variable error

* test: refactor test to use mock response

* test: mark flaky azure tests

* Bump next from 14.1.1 to 14.2.10 in /ui/litellm-dashboard (#5753)

Bumps [next](https://github.com/vercel/next.js) from 14.1.1 to 14.2.10.
- [Release notes](https://github.com/vercel/next.js/releases)
- [Changelog](https://github.com/vercel/next.js/blob/canary/release.js)
- [Commits](https://github.com/vercel/next.js/compare/v14.1.1...v14.2.10)

---
updated-dependencies:
- dependency-name: next
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [Fix] o1-mini causes pydantic warnings on `reasoning_tokens`  (#5754)

* add requester_metadata in standard logging payload

* log requester_metadata in metadata

* use StandardLoggingPayload for logging

* docs StandardLoggingPayload

* fix import

* include standard logging object in failure

* add test for requester metadata

* handle completion_tokens_details

* add test for completion_tokens_details

* [Feat-Proxy-DataDog] Log Redis, Postgres Failure events on DataDog  (#5750)

* dd - start tracking redis status on dd

* add async_service_success_hook / failure hook in custom logger

* add async_service_failure_hook

* log service failures on dd

* fix import error

* add test for redis errors / warning

* [Fix] Router/ Proxy - Tag Based routing, raise correct error when no deployments found and tag filtering is on  (#5745)

* fix tag routing - raise correct error when no model found with tag based routing

* fix error string from tag based routing

* test router tag based routing

* raise 401 error when no tags available for deployment

* linting fix

* [Feat] Log Request metadata on gcs bucket logging (#5743)

* add requester_metadata in standard logging payload

* log requester_metadata in metadata

* use StandardLoggingPayload for logging

* docs StandardLoggingPayload

* fix import

* include standard logging object in failure

* add test for requester metadata

* fix(litellm_logging.py): fix logging message

* fix(rerank_api/main.py): fix linting errors

* fix(custom_guardrails.py): maintain backwards compatibility for older guardrails

* fix(rerank_api/main.py): fix cost tracking for rerank endpoints

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: steffen-sbt <148480574+steffen-sbt@users.noreply.github.com>
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Krish Dholakia 2024-09-17 23:00:04 -07:00, committed by GitHub
parent c5c64a6c04
commit 98c335acd0
29 changed files with 1261 additions and 257 deletions


@ -0,0 +1,557 @@
# ✨ Secret Detection/Redaction (Enterprise-only)
❓ Use this to redact API keys and secrets sent in requests to an LLM.

For example, if you want to redact the value of `OPENAI_API_KEY` in the following request:
#### Incoming Request
```json
{
  "messages": [
    {
      "role": "user",
      "content": "Hey, how's it going, API_KEY = 'sk_1234567890abcdef'"
    }
  ]
}
```
#### Request after Moderation
```json
{
  "messages": [
    {
      "role": "user",
      "content": "Hey, how's it going, API_KEY = '[REDACTED]'"
    }
  ]
}
```
**Usage**
**Step 1** Add this to your config.yaml
```yaml
guardrails:
  - guardrail_name: "my-custom-name"
    litellm_params:
      guardrail: "hide-secrets" # supported values: "aporia", "lakera", ..
      mode: "pre_call"
```
**Step 2** Run litellm proxy with `--detailed_debug` to see the server logs
```
litellm --config config.yaml --detailed_debug
```
**Step 3** Test it with a request
Send this request:
```shell
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "fake-claude-endpoint",
"messages": [
{
"role": "user",
"content": "what is the value of my open ai key? openai_api_key=sk-1234998222"
}
],
"guardrails": ["my-custom-name"]
}'
```
Expect to see the following warning in your LiteLLM server logs:
```shell
LiteLLM Proxy:WARNING: secret_detection.py:88 - Detected and redacted secrets in message: ['Secret Keyword']
```
You can also see the raw request sent from LiteLLM to the API provider (with `--detailed_debug`):
```json
POST Request Sent from LiteLLM:
curl -X POST \
https://api.groq.com/openai/v1/ \
-H 'Authorization: Bearer gsk_mySVchjY********************************************' \
-d {
"model": "llama3-8b-8192",
"messages": [
{
"role": "user",
"content": "what is the time today, openai_api_key=[REDACTED]"
}
],
"stream": false,
"extra_body": {}
}
```
## Turn on/off per project (API KEY/Team)
[**See Here**](./quick_start.md#-control-guardrails-per-project-api-key)
## Control secret detectors
LiteLLM uses the [`detect-secrets`](https://github.com/Yelp/detect-secrets) library for secret detection. See [all plugins run by default](#default-config-used).
### Usage
Here's how to control which plugins are run per request. This is useful if developers complain about secret detection impacting response quality.
**1. Set-up config.yaml**
```yaml
guardrails:
  - guardrail_name: "hide-secrets"
    litellm_params:
      guardrail: "hide-secrets" # supported values: "aporia", "lakera"
      mode: "pre_call"
      detect_secrets_config: {
        "plugins_used": [
          {"name": "SoftlayerDetector"},
          {"name": "StripeDetector"},
          {"name": "NpmDetector"}
        ]
      }
```
**2. Start proxy**
Run with `--detailed_debug` for more detailed logs. Use in dev only.
```bash
litellm --config /path/to/config.yaml --detailed_debug
```
**3. Test it!**
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "fake-claude-endpoint",
"messages": [
{
"role": "user",
"content": "what is the value of my open ai key? openai_api_key=sk-1234998222"
}
],
"guardrails": ["hide-secrets"]
}'
```
**Expected Logs**
Look for this in your logs to confirm your changes worked as expected.
```
No secrets detected on input.
```
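
For reference, here is a minimal sketch (not LiteLLM's exact code; the `scan_text_for_secrets` helper is illustrative) of how a `detect_secrets_config` like the one above is applied with the `detect-secrets` library. With only these three plugins enabled, an OpenAI-style key is not flagged, which is why the log above reports no secrets.

```python
import os
import tempfile

from detect_secrets import SecretsCollection
from detect_secrets.settings import transient_settings

# Same shape as the detect_secrets_config in the config.yaml above
detect_secrets_config = {
    "plugins_used": [
        {"name": "SoftlayerDetector"},
        {"name": "StripeDetector"},
        {"name": "NpmDetector"},
    ]
}


def scan_text_for_secrets(message_content: str) -> list:
    # detect-secrets scans files, so write the message to a temp file first
    temp_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
    temp_file.write(message_content)
    temp_file.close()

    secrets = SecretsCollection()
    # transient_settings applies only the plugins listed in the config
    with transient_settings(detect_secrets_config):
        secrets.scan_file(temp_file.name)
    os.remove(temp_file.name)

    # return the names of the detectors that fired
    return [
        found_secret.type
        for file_name in secrets.files
        for found_secret in secrets[file_name]
    ]


# With only the three plugins above, an OpenAI-style key is not flagged -> prints []
print(scan_text_for_secrets("openai_api_key=sk-1234998222"))
```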
### Default Config Used
```
_default_detect_secrets_config = {
"plugins_used": [
{"name": "SoftlayerDetector"},
{"name": "StripeDetector"},
{"name": "NpmDetector"},
{"name": "IbmCosHmacDetector"},
{"name": "DiscordBotTokenDetector"},
{"name": "BasicAuthDetector"},
{"name": "AzureStorageKeyDetector"},
{"name": "ArtifactoryDetector"},
{"name": "AWSKeyDetector"},
{"name": "CloudantDetector"},
{"name": "IbmCloudIamDetector"},
{"name": "JwtTokenDetector"},
{"name": "MailchimpDetector"},
{"name": "SquareOAuthDetector"},
{"name": "PrivateKeyDetector"},
{"name": "TwilioKeyDetector"},
{
"name": "AdafruitKeyDetector",
"path": _custom_plugins_path + "/adafruit.py",
},
{
"name": "AdobeSecretDetector",
"path": _custom_plugins_path + "/adobe.py",
},
{
"name": "AgeSecretKeyDetector",
"path": _custom_plugins_path + "/age_secret_key.py",
},
{
"name": "AirtableApiKeyDetector",
"path": _custom_plugins_path + "/airtable_api_key.py",
},
{
"name": "AlgoliaApiKeyDetector",
"path": _custom_plugins_path + "/algolia_api_key.py",
},
{
"name": "AlibabaSecretDetector",
"path": _custom_plugins_path + "/alibaba.py",
},
{
"name": "AsanaSecretDetector",
"path": _custom_plugins_path + "/asana.py",
},
{
"name": "AtlassianApiTokenDetector",
"path": _custom_plugins_path + "/atlassian_api_token.py",
},
{
"name": "AuthressAccessKeyDetector",
"path": _custom_plugins_path + "/authress_access_key.py",
},
{
"name": "BittrexDetector",
"path": _custom_plugins_path + "/beamer_api_token.py",
},
{
"name": "BitbucketDetector",
"path": _custom_plugins_path + "/bitbucket.py",
},
{
"name": "BeamerApiTokenDetector",
"path": _custom_plugins_path + "/bittrex.py",
},
{
"name": "ClojarsApiTokenDetector",
"path": _custom_plugins_path + "/clojars_api_token.py",
},
{
"name": "CodecovAccessTokenDetector",
"path": _custom_plugins_path + "/codecov_access_token.py",
},
{
"name": "CoinbaseAccessTokenDetector",
"path": _custom_plugins_path + "/coinbase_access_token.py",
},
{
"name": "ConfluentDetector",
"path": _custom_plugins_path + "/confluent.py",
},
{
"name": "ContentfulApiTokenDetector",
"path": _custom_plugins_path + "/contentful_api_token.py",
},
{
"name": "DatabricksApiTokenDetector",
"path": _custom_plugins_path + "/databricks_api_token.py",
},
{
"name": "DatadogAccessTokenDetector",
"path": _custom_plugins_path + "/datadog_access_token.py",
},
{
"name": "DefinedNetworkingApiTokenDetector",
"path": _custom_plugins_path + "/defined_networking_api_token.py",
},
{
"name": "DigitaloceanDetector",
"path": _custom_plugins_path + "/digitalocean.py",
},
{
"name": "DopplerApiTokenDetector",
"path": _custom_plugins_path + "/doppler_api_token.py",
},
{
"name": "DroneciAccessTokenDetector",
"path": _custom_plugins_path + "/droneci_access_token.py",
},
{
"name": "DuffelApiTokenDetector",
"path": _custom_plugins_path + "/duffel_api_token.py",
},
{
"name": "DynatraceApiTokenDetector",
"path": _custom_plugins_path + "/dynatrace_api_token.py",
},
{
"name": "DiscordDetector",
"path": _custom_plugins_path + "/discord.py",
},
{
"name": "DropboxDetector",
"path": _custom_plugins_path + "/dropbox.py",
},
{
"name": "EasyPostDetector",
"path": _custom_plugins_path + "/easypost.py",
},
{
"name": "EtsyAccessTokenDetector",
"path": _custom_plugins_path + "/etsy_access_token.py",
},
{
"name": "FacebookAccessTokenDetector",
"path": _custom_plugins_path + "/facebook_access_token.py",
},
{
"name": "FastlyApiKeyDetector",
"path": _custom_plugins_path + "/fastly_api_token.py",
},
{
"name": "FinicityDetector",
"path": _custom_plugins_path + "/finicity.py",
},
{
"name": "FinnhubAccessTokenDetector",
"path": _custom_plugins_path + "/finnhub_access_token.py",
},
{
"name": "FlickrAccessTokenDetector",
"path": _custom_plugins_path + "/flickr_access_token.py",
},
{
"name": "FlutterwaveDetector",
"path": _custom_plugins_path + "/flutterwave.py",
},
{
"name": "FrameIoApiTokenDetector",
"path": _custom_plugins_path + "/frameio_api_token.py",
},
{
"name": "FreshbooksAccessTokenDetector",
"path": _custom_plugins_path + "/freshbooks_access_token.py",
},
{
"name": "GCPApiKeyDetector",
"path": _custom_plugins_path + "/gcp_api_key.py",
},
{
"name": "GitHubTokenCustomDetector",
"path": _custom_plugins_path + "/github_token.py",
},
{
"name": "GitLabDetector",
"path": _custom_plugins_path + "/gitlab.py",
},
{
"name": "GitterAccessTokenDetector",
"path": _custom_plugins_path + "/gitter_access_token.py",
},
{
"name": "GoCardlessApiTokenDetector",
"path": _custom_plugins_path + "/gocardless_api_token.py",
},
{
"name": "GrafanaDetector",
"path": _custom_plugins_path + "/grafana.py",
},
{
"name": "HashiCorpTFApiTokenDetector",
"path": _custom_plugins_path + "/hashicorp_tf_api_token.py",
},
{
"name": "HerokuApiKeyDetector",
"path": _custom_plugins_path + "/heroku_api_key.py",
},
{
"name": "HubSpotApiTokenDetector",
"path": _custom_plugins_path + "/hubspot_api_key.py",
},
{
"name": "HuggingFaceDetector",
"path": _custom_plugins_path + "/huggingface.py",
},
{
"name": "IntercomApiTokenDetector",
"path": _custom_plugins_path + "/intercom_api_key.py",
},
{
"name": "JFrogDetector",
"path": _custom_plugins_path + "/jfrog.py",
},
{
"name": "JWTBase64Detector",
"path": _custom_plugins_path + "/jwt.py",
},
{
"name": "KrakenAccessTokenDetector",
"path": _custom_plugins_path + "/kraken_access_token.py",
},
{
"name": "KucoinDetector",
"path": _custom_plugins_path + "/kucoin.py",
},
{
"name": "LaunchdarklyAccessTokenDetector",
"path": _custom_plugins_path + "/launchdarkly_access_token.py",
},
{
"name": "LinearDetector",
"path": _custom_plugins_path + "/linear.py",
},
{
"name": "LinkedInDetector",
"path": _custom_plugins_path + "/linkedin.py",
},
{
"name": "LobDetector",
"path": _custom_plugins_path + "/lob.py",
},
{
"name": "MailgunDetector",
"path": _custom_plugins_path + "/mailgun.py",
},
{
"name": "MapBoxApiTokenDetector",
"path": _custom_plugins_path + "/mapbox_api_token.py",
},
{
"name": "MattermostAccessTokenDetector",
"path": _custom_plugins_path + "/mattermost_access_token.py",
},
{
"name": "MessageBirdDetector",
"path": _custom_plugins_path + "/messagebird.py",
},
{
"name": "MicrosoftTeamsWebhookDetector",
"path": _custom_plugins_path + "/microsoft_teams_webhook.py",
},
{
"name": "NetlifyAccessTokenDetector",
"path": _custom_plugins_path + "/netlify_access_token.py",
},
{
"name": "NewRelicDetector",
"path": _custom_plugins_path + "/new_relic.py",
},
{
"name": "NYTimesAccessTokenDetector",
"path": _custom_plugins_path + "/nytimes_access_token.py",
},
{
"name": "OktaAccessTokenDetector",
"path": _custom_plugins_path + "/okta_access_token.py",
},
{
"name": "OpenAIApiKeyDetector",
"path": _custom_plugins_path + "/openai_api_key.py",
},
{
"name": "PlanetScaleDetector",
"path": _custom_plugins_path + "/planetscale.py",
},
{
"name": "PostmanApiTokenDetector",
"path": _custom_plugins_path + "/postman_api_token.py",
},
{
"name": "PrefectApiTokenDetector",
"path": _custom_plugins_path + "/prefect_api_token.py",
},
{
"name": "PulumiApiTokenDetector",
"path": _custom_plugins_path + "/pulumi_api_token.py",
},
{
"name": "PyPiUploadTokenDetector",
"path": _custom_plugins_path + "/pypi_upload_token.py",
},
{
"name": "RapidApiAccessTokenDetector",
"path": _custom_plugins_path + "/rapidapi_access_token.py",
},
{
"name": "ReadmeApiTokenDetector",
"path": _custom_plugins_path + "/readme_api_token.py",
},
{
"name": "RubygemsApiTokenDetector",
"path": _custom_plugins_path + "/rubygems_api_token.py",
},
{
"name": "ScalingoApiTokenDetector",
"path": _custom_plugins_path + "/scalingo_api_token.py",
},
{
"name": "SendbirdDetector",
"path": _custom_plugins_path + "/sendbird.py",
},
{
"name": "SendGridApiTokenDetector",
"path": _custom_plugins_path + "/sendgrid_api_token.py",
},
{
"name": "SendinBlueApiTokenDetector",
"path": _custom_plugins_path + "/sendinblue_api_token.py",
},
{
"name": "SentryAccessTokenDetector",
"path": _custom_plugins_path + "/sentry_access_token.py",
},
{
"name": "ShippoApiTokenDetector",
"path": _custom_plugins_path + "/shippo_api_token.py",
},
{
"name": "ShopifyDetector",
"path": _custom_plugins_path + "/shopify.py",
},
{
"name": "SlackDetector",
"path": _custom_plugins_path + "/slack.py",
},
{
"name": "SnykApiTokenDetector",
"path": _custom_plugins_path + "/snyk_api_token.py",
},
{
"name": "SquarespaceAccessTokenDetector",
"path": _custom_plugins_path + "/squarespace_access_token.py",
},
{
"name": "SumoLogicDetector",
"path": _custom_plugins_path + "/sumologic.py",
},
{
"name": "TelegramBotApiTokenDetector",
"path": _custom_plugins_path + "/telegram_bot_api_token.py",
},
{
"name": "TravisCiAccessTokenDetector",
"path": _custom_plugins_path + "/travisci_access_token.py",
},
{
"name": "TwitchApiTokenDetector",
"path": _custom_plugins_path + "/twitch_api_token.py",
},
{
"name": "TwitterDetector",
"path": _custom_plugins_path + "/twitter.py",
},
{
"name": "TypeformApiTokenDetector",
"path": _custom_plugins_path + "/typeform_api_token.py",
},
{
"name": "VaultDetector",
"path": _custom_plugins_path + "/vault.py",
},
{
"name": "YandexDetector",
"path": _custom_plugins_path + "/yandex.py",
},
{
"name": "ZendeskSecretKeyDetector",
"path": _custom_plugins_path + "/zendesk_secret_key.py",
},
{"name": "Base64HighEntropyString", "limit": 3.0},
{"name": "HexHighEntropyString", "limit": 3.0},
]
}
```


@ -297,3 +297,14 @@ Set your colors to any of the following colors: https://www.tremor.so/docs/layou
- Deploy LiteLLM Proxy Server
## Disable Admin UI
Set `DISABLE_ADMIN_UI="True"` in your environment to disable the Admin UI.
Useful if your security team has additional restrictions on UI usage.
**Expected Response**
<Image img={require('../../img/admin_ui_disabled.png')}/>
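
A minimal sketch of how the flag can gate the dashboard route (the `/ui` path and the env check below are illustrative assumptions; the actual wiring lives in `proxy_server.py`, which imports the `admin_ui_disabled()` helper added in this commit):

```python
import os

from fastapi import FastAPI
from litellm.proxy.common_utils.admin_ui_utils import admin_ui_disabled

app = FastAPI()


@app.get("/ui")
async def serve_admin_ui():
    # Hypothetical gate: serve the "Admin UI is Disabled" page when the flag is set
    if os.getenv("DISABLE_ADMIN_UI", "False").lower() == "true":
        return admin_ui_disabled()  # returns an HTMLResponse
    ...  # otherwise fall through to the normal dashboard
```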

(New binary file added: admin_ui_disabled.png screenshot, 238 KiB; image not shown.)


@ -81,6 +81,7 @@ const sidebars = {
            "proxy/guardrails/lakera_ai",
            "proxy/guardrails/bedrock",
            "proxy/guardrails/pii_masking_v2",
+           "proxy/guardrails/secret_detection",
            "proxy/guardrails/custom_guardrail",
            "prompt_injection"
          ],


@ -5,39 +5,24 @@
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan

-import sys, os
+import sys
+import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
-from typing import Optional, Literal, Union
-import litellm, traceback, sys, uuid
+from typing import Optional

from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
-from litellm.integrations.custom_logger import CustomLogger
-from fastapi import HTTPException
-from litellm._logging import verbose_proxy_logger
-from litellm.utils import (
-    ModelResponse,
-    EmbeddingResponse,
-    ImageResponse,
-    StreamingChoices,
-)
-from datetime import datetime
-import aiohttp, asyncio
from litellm._logging import verbose_proxy_logger
import tempfile
-from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_guardrail import CustomGuardrail

-litellm.set_verbose = True

GUARDRAIL_NAME = "hide_secrets"

_custom_plugins_path = "file://" + os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "secrets_plugins"
)
-print("custom plugins path", _custom_plugins_path)

_default_detect_secrets_config = {
    "plugins_used": [
        {"name": "SoftlayerDetector"},
@ -434,9 +419,10 @@ _default_detect_secrets_config = {
}

-class _ENTERPRISE_SecretDetection(CustomLogger):
-    def __init__(self):
-        pass
+class _ENTERPRISE_SecretDetection(CustomGuardrail):
+    def __init__(self, detect_secrets_config: Optional[dict] = None, **kwargs):
+        self.user_defined_detect_secrets_config = detect_secrets_config
+        super().__init__(**kwargs)

    def scan_message_for_secrets(self, message_content: str):
        from detect_secrets import SecretsCollection
@ -447,7 +433,11 @@ class _ENTERPRISE_SecretDetection(CustomLogger):
        temp_file.close()

        secrets = SecretsCollection()
-        with transient_settings(_default_detect_secrets_config):
+        detect_secrets_config = (
+            self.user_defined_detect_secrets_config or _default_detect_secrets_config
+        )
+        with transient_settings(detect_secrets_config):
            secrets.scan_file(temp_file.name)

        os.remove(temp_file.name)
@ -484,9 +474,12 @@ class _ENTERPRISE_SecretDetection(CustomLogger):
        from detect_secrets import SecretsCollection
        from detect_secrets.settings import default_settings

-        print("INSIDE SECRET DETECTION PRE-CALL HOOK!")
        if await self.should_run_check(user_api_key_dict) is False:
            return

-        print("RUNNING CHECK!")
        if "messages" in data and isinstance(data["messages"], list):
            for message in data["messages"]:
                if "content" in message and isinstance(message["content"], str):
@ -503,6 +496,8 @@ class _ENTERPRISE_SecretDetection(CustomLogger):
                        verbose_proxy_logger.warning(
                            f"Detected and redacted secrets in message: {secret_types}"
                        )
+                    else:
+                        verbose_proxy_logger.debug("No secrets detected on input.")

        if "prompt" in data:
            if isinstance(data["prompt"], str):


@ -2,7 +2,7 @@
## File for 'response_cost' calculation in Logging ## File for 'response_cost' calculation in Logging
import time import time
import traceback import traceback
from typing import List, Literal, Optional, Tuple, Union from typing import Any, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -100,7 +100,7 @@ def cost_per_token(
"rerank", "rerank",
"arerank", "arerank",
] = "completion", ] = "completion",
) -> Tuple[float, float]: ) -> Tuple[float, float]: # type: ignore
""" """
Calculates the cost per token for a given model, prompt tokens, and completion tokens. Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@ -425,20 +425,24 @@ def get_model_params_and_category(model_name) -> str:
return model_name return model_name
def get_replicate_completion_pricing(completion_response=None, total_time=0.0): def get_replicate_completion_pricing(completion_response: dict, total_time=0.0):
# see https://replicate.com/pricing # see https://replicate.com/pricing
# for all litellm currently supported LLMs, almost all requests go to a100_80gb # for all litellm currently supported LLMs, almost all requests go to a100_80gb
a100_80gb_price_per_second_public = ( a100_80gb_price_per_second_public = (
0.001400 # assume all calls sent to A100 80GB for now 0.001400 # assume all calls sent to A100 80GB for now
) )
if total_time == 0.0: # total time is in ms if total_time == 0.0: # total time is in ms
start_time = completion_response["created"] start_time = completion_response.get("created", time.time())
end_time = getattr(completion_response, "ended", time.time()) end_time = getattr(completion_response, "ended", time.time())
total_time = end_time - start_time total_time = end_time - start_time
return a100_80gb_price_per_second_public * total_time / 1000 return a100_80gb_price_per_second_public * total_time / 1000
def has_hidden_params(obj: Any) -> bool:
return hasattr(obj, "_hidden_params")
def _select_model_name_for_cost_calc( def _select_model_name_for_cost_calc(
model: Optional[str], model: Optional[str],
completion_response: Union[BaseModel, dict, str], completion_response: Union[BaseModel, dict, str],
@ -463,12 +467,14 @@ def _select_model_name_for_cost_calc(
elif return_model is None: elif return_model is None:
return_model = completion_response.get("model", "") # type: ignore return_model = completion_response.get("model", "") # type: ignore
if hasattr(completion_response, "_hidden_params"): hidden_params = getattr(completion_response, "_hidden_params", None)
if hidden_params is not None:
if ( if (
completion_response._hidden_params.get("model", None) is not None hidden_params.get("model", None) is not None
and len(completion_response._hidden_params["model"]) > 0 and len(hidden_params["model"]) > 0
): ):
return_model = completion_response._hidden_params.get("model", model) return_model = hidden_params.get("model", model)
return return_model return return_model
@ -558,7 +564,7 @@ def completion_cost(
or isinstance(completion_response, dict) or isinstance(completion_response, dict)
): # tts returns a custom class ): # tts returns a custom class
usage_obj: Optional[Union[dict, litellm.Usage]] = completion_response.get( usage_obj: Optional[Union[dict, litellm.Usage]] = completion_response.get( # type: ignore
"usage", {} "usage", {}
) )
if isinstance(usage_obj, BaseModel) and not isinstance( if isinstance(usage_obj, BaseModel) and not isinstance(
@ -569,17 +575,17 @@ def completion_cost(
"usage", "usage",
litellm.Usage(**usage_obj.model_dump()), litellm.Usage(**usage_obj.model_dump()),
) )
if usage_obj is None:
_usage = {}
elif isinstance(usage_obj, BaseModel):
_usage = usage_obj.model_dump()
else:
_usage = usage_obj
# get input/output tokens from completion_response # get input/output tokens from completion_response
prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0) prompt_tokens = _usage.get("prompt_tokens", 0)
completion_tokens = completion_response.get("usage", {}).get( completion_tokens = _usage.get("completion_tokens", 0)
"completion_tokens", 0 cache_creation_input_tokens = _usage.get("cache_creation_input_tokens", 0)
) cache_read_input_tokens = _usage.get("cache_read_input_tokens", 0)
cache_creation_input_tokens = completion_response.get("usage", {}).get(
"cache_creation_input_tokens", 0
)
cache_read_input_tokens = completion_response.get("usage", {}).get(
"cache_read_input_tokens", 0
)
total_time = getattr(completion_response, "_response_ms", 0) total_time = getattr(completion_response, "_response_ms", 0)
verbose_logger.debug( verbose_logger.debug(
@ -588,24 +594,19 @@ def completion_cost(
model = _select_model_name_for_cost_calc( model = _select_model_name_for_cost_calc(
model=model, completion_response=completion_response model=model, completion_response=completion_response
) )
if hasattr(completion_response, "_hidden_params"): hidden_params = getattr(completion_response, "_hidden_params", None)
custom_llm_provider = completion_response._hidden_params.get( if hidden_params is not None:
custom_llm_provider = hidden_params.get(
"custom_llm_provider", custom_llm_provider or None "custom_llm_provider", custom_llm_provider or None
) )
region_name = completion_response._hidden_params.get( region_name = hidden_params.get("region_name", region_name)
"region_name", region_name size = hidden_params.get("optional_params", {}).get(
)
size = completion_response._hidden_params.get(
"optional_params", {}
).get(
"size", "1024-x-1024" "size", "1024-x-1024"
) # openai default ) # openai default
quality = completion_response._hidden_params.get( quality = hidden_params.get("optional_params", {}).get(
"optional_params", {}
).get(
"quality", "standard" "quality", "standard"
) # openai default ) # openai default
n = completion_response._hidden_params.get("optional_params", {}).get( n = hidden_params.get("optional_params", {}).get(
"n", 1 "n", 1
) # openai default ) # openai default
else: else:
@ -643,6 +644,8 @@ def completion_cost(
# Vertex Charges Flat $0.20 per image # Vertex Charges Flat $0.20 per image
return 0.020 return 0.020
if size is None:
size = "1024-x-1024" # openai default
# fix size to match naming convention # fix size to match naming convention
if "x" in size and "-x-" not in size: if "x" in size and "-x-" not in size:
size = size.replace("x", "-x-") size = size.replace("x", "-x-")
@ -697,7 +700,7 @@ def completion_cost(
model in litellm.replicate_models or "replicate" in model model in litellm.replicate_models or "replicate" in model
) and model not in litellm.model_cost: ) and model not in litellm.model_cost:
# for unmapped replicate model, default to replicate's time tracking logic # for unmapped replicate model, default to replicate's time tracking logic
return get_replicate_completion_pricing(completion_response, total_time) return get_replicate_completion_pricing(completion_response, total_time) # type: ignore
if model is None: if model is None:
raise ValueError( raise ValueError(
@ -847,7 +850,9 @@ def rerank_cost(
Returns Returns
- float or None: cost of response OR none if error. - float or None: cost of response OR none if error.
""" """
_, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model) _, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider
)
try: try:
if custom_llm_provider == "cohere": if custom_llm_provider == "cohere":


@ -45,6 +45,9 @@ class CustomBatchLogger(CustomLogger):
            await self.flush_queue()

    async def flush_queue(self):
+        if self.flush_lock is None:
+            return
        async with self.flush_lock:
            if self.log_queue:
                verbose_logger.debug(
@ -54,5 +57,5 @@ class CustomBatchLogger(CustomLogger):
                self.log_queue.clear()
                self.last_flush_time = time.time()

-    async def async_send_batch(self):
+    async def async_send_batch(self, *args, **kwargs):
        pass


@ -29,12 +29,13 @@ class CustomGuardrail(CustomLogger):
        )

        if (
-            self.guardrail_name not in requested_guardrails
+            self.event_hook
+            and self.guardrail_name not in requested_guardrails
            and event_type.value != "logging_only"
        ):
            return False

-        if self.event_hook != event_type.value:
+        if self.event_hook and self.event_hook != event_type.value:
            return False

        return True


@ -8,7 +8,7 @@ import traceback
import types import types
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, List, Optional, Union from typing import Any, Dict, List, Optional, TypedDict, Union
import dotenv # type: ignore import dotenv # type: ignore
import httpx import httpx
@ -23,6 +23,7 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client, get_async_httpx_client,
httpxSpecialProvider, httpxSpecialProvider,
) )
from litellm.types.utils import StandardLoggingPayload
class LangsmithInputs(BaseModel): class LangsmithInputs(BaseModel):
@ -46,6 +47,12 @@ class LangsmithInputs(BaseModel):
user_api_key_team_alias: Optional[str] = None user_api_key_team_alias: Optional[str] = None
class LangsmithCredentialsObject(TypedDict):
LANGSMITH_API_KEY: str
LANGSMITH_PROJECT: str
LANGSMITH_BASE_URL: str
def is_serializable(value): def is_serializable(value):
non_serializable_types = ( non_serializable_types = (
types.CoroutineType, types.CoroutineType,
@ -57,15 +64,27 @@ def is_serializable(value):
class LangsmithLogger(CustomBatchLogger): class LangsmithLogger(CustomBatchLogger):
def __init__(self, **kwargs): def __init__(
self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY") self,
self.langsmith_project = os.getenv("LANGSMITH_PROJECT", "litellm-completion") langsmith_api_key: Optional[str] = None,
langsmith_project: Optional[str] = None,
langsmith_base_url: Optional[str] = None,
**kwargs,
):
self.default_credentials = self.get_credentials_from_env(
langsmith_api_key=langsmith_api_key,
langsmith_project=langsmith_project,
langsmith_base_url=langsmith_base_url,
)
self.sampling_rate: float = (
float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore
else 1.0
)
self.langsmith_default_run_name = os.getenv( self.langsmith_default_run_name = os.getenv(
"LANGSMITH_DEFAULT_RUN_NAME", "LLMRun" "LANGSMITH_DEFAULT_RUN_NAME", "LLMRun"
) )
self.langsmith_base_url = os.getenv(
"LANGSMITH_BASE_URL", "https://api.smith.langchain.com"
)
self.async_httpx_client = get_async_httpx_client( self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback llm_provider=httpxSpecialProvider.LoggingCallback
) )
@ -78,126 +97,160 @@ class LangsmithLogger(CustomBatchLogger):
self.flush_lock = asyncio.Lock() self.flush_lock = asyncio.Lock()
super().__init__(**kwargs, flush_lock=self.flush_lock) super().__init__(**kwargs, flush_lock=self.flush_lock)
def get_credentials_from_env(
self,
langsmith_api_key: Optional[str],
langsmith_project: Optional[str],
langsmith_base_url: Optional[str],
) -> LangsmithCredentialsObject:
_credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
if _credentials_api_key is None:
raise Exception(
"Invalid Langsmith API Key given. _credentials_api_key=None."
)
_credentials_project = (
langsmith_project or os.getenv("LANGSMITH_PROJECT") or "litellm-completion"
)
if _credentials_project is None:
raise Exception(
"Invalid Langsmith API Key given. _credentials_project=None."
)
_credentials_base_url = (
langsmith_base_url
or os.getenv("LANGSMITH_BASE_URL")
or "https://api.smith.langchain.com"
)
if _credentials_base_url is None:
raise Exception(
"Invalid Langsmith API Key given. _credentials_base_url=None."
)
return LangsmithCredentialsObject(
LANGSMITH_API_KEY=_credentials_api_key,
LANGSMITH_BASE_URL=_credentials_base_url,
LANGSMITH_PROJECT=_credentials_project,
)
def _prepare_log_data(self, kwargs, response_obj, start_time, end_time): def _prepare_log_data(self, kwargs, response_obj, start_time, end_time):
import datetime import json
from datetime import datetime as dt from datetime import datetime as dt
from datetime import timezone
metadata = kwargs.get("litellm_params", {}).get("metadata", {}) or {}
new_metadata = {}
for key, value in metadata.items():
if (
isinstance(value, list)
or isinstance(value, str)
or isinstance(value, int)
or isinstance(value, float)
):
new_metadata[key] = value
elif isinstance(value, BaseModel):
new_metadata[key] = value.model_dump_json()
elif isinstance(value, dict):
for k, v in value.items():
if isinstance(v, dt):
value[k] = v.isoformat()
new_metadata[key] = value
metadata = new_metadata
kwargs["user_api_key"] = metadata.get("user_api_key", None)
kwargs["user_api_key_user_id"] = metadata.get("user_api_key_user_id", None)
kwargs["user_api_key_team_alias"] = metadata.get(
"user_api_key_team_alias", None
)
project_name = metadata.get("project_name", self.langsmith_project)
run_name = metadata.get("run_name", self.langsmith_default_run_name)
run_id = metadata.get("id", None)
parent_run_id = metadata.get("parent_run_id", None)
trace_id = metadata.get("trace_id", None)
session_id = metadata.get("session_id", None)
dotted_order = metadata.get("dotted_order", None)
tags = metadata.get("tags", []) or []
verbose_logger.debug(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
try: try:
start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat() _litellm_params = kwargs.get("litellm_params", {}) or {}
end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat() metadata = _litellm_params.get("metadata", {}) or {}
except: new_metadata = {}
start_time = datetime.datetime.utcnow().isoformat() for key, value in metadata.items():
end_time = datetime.datetime.utcnow().isoformat() if (
isinstance(value, list)
or isinstance(value, str)
or isinstance(value, int)
or isinstance(value, float)
):
new_metadata[key] = value
elif isinstance(value, BaseModel):
new_metadata[key] = value.model_dump_json()
elif isinstance(value, dict):
for k, v in value.items():
if isinstance(v, dt):
value[k] = v.isoformat()
new_metadata[key] = value
# filter out kwargs to not include any dicts, langsmith throws an erros when trying to log kwargs metadata = new_metadata
logged_kwargs = LangsmithInputs(**kwargs)
kwargs = logged_kwargs.model_dump()
new_kwargs = {} kwargs["user_api_key"] = metadata.get("user_api_key", None)
for key in kwargs: kwargs["user_api_key_user_id"] = metadata.get("user_api_key_user_id", None)
value = kwargs[key] kwargs["user_api_key_team_alias"] = metadata.get(
if key == "start_time" or key == "end_time" or value is None: "user_api_key_team_alias", None
pass )
elif key == "original_response" and not isinstance(value, str):
new_kwargs[key] = str(value)
elif type(value) == datetime.datetime:
new_kwargs[key] = value.isoformat()
elif type(value) != dict and is_serializable(value=value):
new_kwargs[key] = value
elif not is_serializable(value=value):
continue
if isinstance(response_obj, BaseModel): project_name = metadata.get(
try: "project_name", self.default_credentials["LANGSMITH_PROJECT"]
response_obj = response_obj.model_dump() )
except: run_name = metadata.get("run_name", self.langsmith_default_run_name)
response_obj = response_obj.dict() # type: ignore run_id = metadata.get("id", None)
parent_run_id = metadata.get("parent_run_id", None)
trace_id = metadata.get("trace_id", None)
session_id = metadata.get("session_id", None)
dotted_order = metadata.get("dotted_order", None)
tags = metadata.get("tags", []) or []
verbose_logger.debug(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
data = { # filter out kwargs to not include any dicts, langsmith throws an erros when trying to log kwargs
"name": run_name, # logged_kwargs = LangsmithInputs(**kwargs)
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain" # kwargs = logged_kwargs.model_dump()
"inputs": new_kwargs,
"outputs": response_obj,
"session_name": project_name,
"start_time": start_time,
"end_time": end_time,
"tags": tags,
"extra": metadata,
}
if run_id: # new_kwargs = {}
data["id"] = run_id # Ensure everything in the payload is converted to str
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if parent_run_id: if payload is None:
data["parent_run_id"] = parent_run_id raise Exception("Error logging request payload. Payload=none.")
if trace_id: new_kwargs = payload
data["trace_id"] = trace_id metadata = payload[
"metadata"
] # ensure logged metadata is json serializable
if session_id: data = {
data["session_id"] = session_id "name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": new_kwargs,
"outputs": new_kwargs["response"],
"session_name": project_name,
"start_time": new_kwargs["startTime"],
"end_time": new_kwargs["endTime"],
"tags": tags,
"extra": metadata,
}
if dotted_order: if payload["error_str"] is not None and payload["status"] == "failure":
data["dotted_order"] = dotted_order data["error"] = payload["error_str"]
if "id" not in data or data["id"] is None: if run_id:
""" data["id"] = run_id
for /batch langsmith requires id, trace_id and dotted_order passed as params
"""
run_id = uuid.uuid4()
data["id"] = str(run_id)
data["trace_id"] = str(run_id)
data["dotted_order"] = self.make_dot_order(run_id=run_id)
verbose_logger.debug("Langsmith Logging data on langsmith: %s", data) if parent_run_id:
data["parent_run_id"] = parent_run_id
return data if trace_id:
data["trace_id"] = trace_id
if session_id:
data["session_id"] = session_id
if dotted_order:
data["dotted_order"] = dotted_order
if "id" not in data or data["id"] is None:
"""
for /batch langsmith requires id, trace_id and dotted_order passed as params
"""
run_id = str(uuid.uuid4())
data["id"] = str(run_id)
data["trace_id"] = str(run_id)
data["dotted_order"] = self.make_dot_order(run_id=run_id)
verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
return data
except Exception:
raise
def _send_batch(self): def _send_batch(self):
if not self.log_queue: if not self.log_queue:
return return
url = f"{self.langsmith_base_url}/runs/batch" langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
headers = {"x-api-key": self.langsmith_api_key} langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
url = f"{langsmith_api_base}/runs/batch"
headers = {"x-api-key": langsmith_api_key}
try: try:
response = requests.post( response = requests.post(
@ -216,15 +269,15 @@ class LangsmithLogger(CustomBatchLogger):
) )
self.log_queue.clear() self.log_queue.clear()
except Exception as e: except Exception:
verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}") verbose_logger.exception("Langsmith Layer Error - Error sending batch.")
def log_success_event(self, kwargs, response_obj, start_time, end_time): def log_success_event(self, kwargs, response_obj, start_time, end_time):
try: try:
sampling_rate = ( sampling_rate = (
float(os.getenv("LANGSMITH_SAMPLING_RATE")) float(os.getenv("LANGSMITH_SAMPLING_RATE")) # type: ignore
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit() # type: ignore
else 1.0 else 1.0
) )
random_sample = random.random() random_sample = random.random()
@ -249,17 +302,12 @@ class LangsmithLogger(CustomBatchLogger):
if len(self.log_queue) >= self.batch_size: if len(self.log_queue) >= self.batch_size:
self._send_batch() self._send_batch()
except: except Exception:
verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}") verbose_logger.exception("Langsmith Layer Error - log_success_event error")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try: try:
sampling_rate = ( sampling_rate = self.sampling_rate
float(os.getenv("LANGSMITH_SAMPLING_RATE"))
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit()
else 1.0
)
random_sample = random.random() random_sample = random.random()
if random_sample > sampling_rate: if random_sample > sampling_rate:
verbose_logger.info( verbose_logger.info(
@ -282,8 +330,36 @@ class LangsmithLogger(CustomBatchLogger):
) )
if len(self.log_queue) >= self.batch_size: if len(self.log_queue) >= self.batch_size:
await self.flush_queue() await self.flush_queue()
except: except Exception:
verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}") verbose_logger.exception(
"Langsmith Layer Error - error logging async success event."
)
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
sampling_rate = self.sampling_rate
random_sample = random.random()
if random_sample > sampling_rate:
verbose_logger.info(
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
sampling_rate, random_sample
)
)
return # Skip logging
verbose_logger.info("Langsmith Failure Event Logging!")
try:
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s",
len(self.log_queue),
self.batch_size,
)
if len(self.log_queue) >= self.batch_size:
await self.flush_queue()
except Exception:
verbose_logger.exception(
"Langsmith Layer Error - error logging async failure event."
)
async def async_send_batch(self): async def async_send_batch(self):
""" """
@ -295,13 +371,16 @@ class LangsmithLogger(CustomBatchLogger):
Raises: Does not raise an exception, will only verbose_logger.exception() Raises: Does not raise an exception, will only verbose_logger.exception()
""" """
import json
if not self.log_queue: if not self.log_queue:
return return
url = f"{self.langsmith_base_url}/runs/batch" langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
headers = {"x-api-key": self.langsmith_api_key}
url = f"{langsmith_api_base}/runs/batch"
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
headers = {"x-api-key": langsmith_api_key}
try: try:
response = await self.async_httpx_client.post( response = await self.async_httpx_client.post(
@ -332,10 +411,14 @@ class LangsmithLogger(CustomBatchLogger):
def get_run_by_id(self, run_id): def get_run_by_id(self, run_id):
url = f"{self.langsmith_base_url}/runs/{run_id}" langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
url = f"{langsmith_api_base}/runs/{run_id}"
response = requests.get( response = requests.get(
url=url, url=url,
headers={"x-api-key": self.langsmith_api_key}, headers={"x-api-key": langsmith_api_key},
) )
return response.json() return response.json()
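
Given the constructor change above, Langsmith credentials can now be passed explicitly instead of only through environment variables. A minimal usage sketch (assuming the logger is registered as a `litellm` callback; the key is a placeholder, and any omitted argument falls back to the corresponding `LANGSMITH_*` env var):

```python
import litellm
from litellm.integrations.langsmith import LangsmithLogger

langsmith_logger = LangsmithLogger(
    langsmith_api_key="lsv2_pt_...",                       # placeholder
    langsmith_project="litellm-completion",
    langsmith_base_url="https://api.smith.langchain.com",
)

# register as a callback so success/failure events get batched to /runs/batch
litellm.callbacks = [langsmith_logger]
```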


@ -43,6 +43,7 @@ from litellm.types.utils import (
StandardLoggingMetadata, StandardLoggingMetadata,
StandardLoggingModelInformation, StandardLoggingModelInformation,
StandardLoggingPayload, StandardLoggingPayload,
StandardLoggingPayloadStatus,
StandardPassThroughResponseObject, StandardPassThroughResponseObject,
TextCompletionResponse, TextCompletionResponse,
TranscriptionResponse, TranscriptionResponse,
@ -668,6 +669,7 @@ class Logging:
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
logging_obj=self, logging_obj=self,
status="success",
) )
) )
elif isinstance(result, dict): # pass-through endpoints elif isinstance(result, dict): # pass-through endpoints
@ -679,6 +681,7 @@ class Logging:
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
logging_obj=self, logging_obj=self,
status="success",
) )
) )
else: # streaming chunks + image gen. else: # streaming chunks + image gen.
@ -762,6 +765,7 @@ class Logging:
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
logging_obj=self, logging_obj=self,
status="success",
) )
) )
if self.dynamic_success_callbacks is not None and isinstance( if self.dynamic_success_callbacks is not None and isinstance(
@ -1390,6 +1394,7 @@ class Logging:
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
logging_obj=self, logging_obj=self,
status="success",
) )
) )
if self.dynamic_async_success_callbacks is not None and isinstance( if self.dynamic_async_success_callbacks is not None and isinstance(
@ -1645,6 +1650,20 @@ class Logging:
self.model_call_details["litellm_params"].get("metadata", {}) or {} self.model_call_details["litellm_params"].get("metadata", {}) or {}
) )
metadata.update(exception.headers) metadata.update(exception.headers)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
)
)
return start_time, end_time return start_time, end_time
async def special_failure_handlers(self, exception: Exception): async def special_failure_handlers(self, exception: Exception):
@ -2347,10 +2366,12 @@ def is_valid_sha256_hash(value: str) -> bool:
def get_standard_logging_object_payload( def get_standard_logging_object_payload(
kwargs: Optional[dict], kwargs: Optional[dict],
init_response_obj: Any, init_response_obj: Union[Any, BaseModel, dict],
start_time: dt_object, start_time: dt_object,
end_time: dt_object, end_time: dt_object,
logging_obj: Logging, logging_obj: Logging,
status: StandardLoggingPayloadStatus,
error_str: Optional[str] = None,
) -> Optional[StandardLoggingPayload]: ) -> Optional[StandardLoggingPayload]:
try: try:
if kwargs is None: if kwargs is None:
@ -2467,7 +2488,7 @@ def get_standard_logging_object_payload(
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params) custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
model_cost_name = _select_model_name_for_cost_calc( model_cost_name = _select_model_name_for_cost_calc(
model=None, model=None,
completion_response=init_response_obj, completion_response=init_response_obj, # type: ignore
base_model=base_model, base_model=base_model,
custom_pricing=custom_pricing, custom_pricing=custom_pricing,
) )
@ -2498,6 +2519,7 @@ def get_standard_logging_object_payload(
id=str(id), id=str(id),
call_type=call_type or "", call_type=call_type or "",
cache_hit=cache_hit, cache_hit=cache_hit,
status=status,
saved_cache_cost=saved_cache_cost, saved_cache_cost=saved_cache_cost,
startTime=start_time_float, startTime=start_time_float,
endTime=end_time_float, endTime=end_time_float,
@ -2517,11 +2539,12 @@ def get_standard_logging_object_payload(
requester_ip_address=clean_metadata.get("requester_ip_address", None), requester_ip_address=clean_metadata.get("requester_ip_address", None),
messages=kwargs.get("messages"), messages=kwargs.get("messages"),
response=( # type: ignore response=( # type: ignore
response_obj if len(response_obj.keys()) > 0 else init_response_obj response_obj if len(response_obj.keys()) > 0 else init_response_obj # type: ignore
), ),
model_parameters=kwargs.get("optional_params", None), model_parameters=kwargs.get("optional_params", None),
hidden_params=clean_hidden_params, hidden_params=clean_hidden_params,
model_map_information=model_cost_information, model_map_information=model_cost_information,
error_str=error_str,
) )
verbose_logger.debug( verbose_logger.debug(


@ -0,0 +1 @@
`/chat/completion` calls routed via `openai.py`.


@ -0,0 +1 @@
from .handler import AzureAIRerank


@ -0,0 +1,52 @@
from typing import Any, Dict, List, Optional, Union

import httpx

from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank import CohereRerank
from litellm.rerank_api.types import RerankResponse


class AzureAIRerank(CohereRerank):
    def rerank(
        self,
        model: str,
        api_key: str,
        api_base: str,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        headers: Optional[dict],
        litellm_logging_obj: LiteLLMLoggingObj,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        _is_async: Optional[bool] = False,
    ) -> RerankResponse:

        if headers is None:
            headers = {"Authorization": "Bearer {}".format(api_key)}
        else:
            headers = {**headers, "Authorization": "Bearer {}".format(api_key)}

        # Assuming api_base is a string representing the base URL
        api_base_url = httpx.URL(api_base)

        # Replace the path with '/v1/rerank' if it doesn't already end with it
        if not api_base_url.path.endswith("/v1/rerank"):
            api_base = str(api_base_url.copy_with(path="/v1/rerank"))

        return super().rerank(
            model=model,
            api_key=api_key,
            api_base=api_base,
            query=query,
            documents=documents,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
            max_chunks_per_doc=max_chunks_per_doc,
            _is_async=_is_async,
            headers=headers,
            litellm_logging_obj=litellm_logging_obj,
        )
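
A hedged usage sketch for the new Azure AI rerank route; the `azure_ai/` model string, deployment name, and endpoint URL below are illustrative assumptions rather than values taken from this commit:

```python
import litellm

# The handler appends /v1/rerank to api_base if the path is not already present.
response = litellm.rerank(
    model="azure_ai/cohere-rerank-v3-english",  # hypothetical deployment name
    api_key="your-azure-ai-api-key",
    api_base="https://your-endpoint.inference.ai.azure.com",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is the capital of Germany."],
    top_n=1,
)
print(response)
```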


@ -0,0 +1,3 @@
"""
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
"""


@ -10,6 +10,7 @@ import httpx
from pydantic import BaseModel from pydantic import BaseModel
import litellm import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base import BaseLLM from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client, _get_httpx_client,
@ -19,6 +20,23 @@ from litellm.rerank_api.types import RerankRequest, RerankResponse
class CohereRerank(BaseLLM): class CohereRerank(BaseLLM):
def validate_environment(self, api_key: str, headers: Optional[dict]) -> dict:
default_headers = {
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"bearer {api_key}",
}
if headers is None:
return default_headers
# If 'Authorization' is provided in headers, it overrides the default.
if "Authorization" in headers:
default_headers["Authorization"] = headers["Authorization"]
# Merge other headers, overriding any default ones except Authorization
return {**default_headers, **headers}
def rerank( def rerank(
self, self,
model: str, model: str,
@ -26,12 +44,16 @@ class CohereRerank(BaseLLM):
api_base: str, api_base: str,
query: str, query: str,
documents: List[Union[str, Dict[str, Any]]], documents: List[Union[str, Dict[str, Any]]],
headers: Optional[dict],
litellm_logging_obj: LiteLLMLoggingObj,
top_n: Optional[int] = None, top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None, rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True, return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None, max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False, # New parameter _is_async: Optional[bool] = False, # New parameter
) -> RerankResponse: ) -> RerankResponse:
headers = self.validate_environment(api_key=api_key, headers=headers)
request_data = RerankRequest( request_data = RerankRequest(
model=model, model=model,
query=query, query=query,
@ -45,16 +67,22 @@ class CohereRerank(BaseLLM):
request_data_dict = request_data.dict(exclude_none=True) request_data_dict = request_data.dict(exclude_none=True)
if _is_async: if _is_async:
return self.async_rerank(request_data_dict, api_key, api_base) # type: ignore # Call async method return self.async_rerank(request_data_dict=request_data_dict, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method
## LOGGING
litellm_logging_obj.pre_call(
input=request_data_dict,
api_key=api_key,
additional_args={
"complete_input_dict": request_data_dict,
"api_base": api_base,
"headers": headers,
},
)
client = _get_httpx_client() client = _get_httpx_client()
response = client.post( response = client.post(
api_base, api_base,
headers={ headers=headers,
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"bearer {api_key}",
},
json=request_data_dict, json=request_data_dict,
) )
@ -65,16 +93,13 @@ class CohereRerank(BaseLLM):
request_data_dict: Dict[str, Any], request_data_dict: Dict[str, Any],
api_key: str, api_key: str,
api_base: str, api_base: str,
headers: dict,
) -> RerankResponse: ) -> RerankResponse:
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE) client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
response = await client.post( response = await client.post(
api_base, api_base,
headers={ headers=headers,
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"bearer {api_key}",
},
json=request_data_dict, json=request_data_dict,
) )


@ -20,30 +20,15 @@ model_list:
    litellm_params:
      model: o1-preview

-litellm_settings:
-  drop_params: True
-  json_logs: True
-  store_audit_logs: True
-  log_raw_request_response: True
-  return_response_headers: True
-  num_retries: 5
-  request_timeout: 200
-  callbacks: ["custom_callbacks.proxy_handler_instance"]

guardrails:
-  - guardrail_name: "presidio-pre-guard"
+  - guardrail_name: "hide-secrets"
    litellm_params:
-      guardrail: presidio # supported values: "aporia", "bedrock", "lakera", "presidio"
-      mode: "logging_only"
-      mock_redacted_text: {
-          "text": "My name is <PERSON>, who are you? Say my name in your response",
-          "items": [
-              {
-                  "start": 11,
-                  "end": 19,
-                  "entity_type": "PERSON",
-                  "text": "<PERSON>",
-                  "operator": "replace",
-              }
-          ],
-      }
+      guardrail: "hide-secrets" # supported values: "aporia", "lakera"
+      mode: "pre_call"
+      # detect_secrets_config: {
+      #   "plugins_used": [
+      #     {"name": "SoftlayerDetector"},
+      #     {"name": "StripeDetector"},
+      #     {"name": "NpmDetector"}
+      #   ]
+      # }


@ -166,3 +166,76 @@ def missing_keys_form(missing_key_names: str):
    </html>
    """
    return missing_keys_html_form.format(missing_keys=missing_key_names)
def admin_ui_disabled():
from fastapi.responses import HTMLResponse
ui_disabled_html = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}}
.container {{
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
font-size: 24px;
margin-bottom: 20px;
}}
pre {{
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}}
.env-var {{
font-weight: normal;
}}
.comment {{
font-weight: normal;
color: #777;
}}
</style>
<title>Admin UI Disabled</title>
</head>
<body>
<div class="container">
<h1>Admin UI is Disabled</h1>
<p>The Admin UI has been disabled by the administrator. To re-enable it, please update the following environment variable:</p>
<pre>
<span class="env-var">DISABLE_ADMIN_UI="False"</span> <span class="comment"># Set this to "False" to enable the Admin UI.</span>
</pre>
<p>After making this change, restart the application for it to take effect.</p>
</div>
<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""
return HTMLResponse(
content=ui_disabled_html,
status_code=200,
)

View file

@@ -189,6 +189,18 @@ def init_guardrails_v2(
            litellm.callbacks.append(_success_callback)  # type: ignore
            litellm.callbacks.append(_presidio_callback)  # type: ignore
+       elif litellm_params["guardrail"] == "hide-secrets":
+           from enterprise.enterprise_hooks.secret_detection import (
+               _ENTERPRISE_SecretDetection,
+           )
+
+           _secret_detection_object = _ENTERPRISE_SecretDetection(
+               detect_secrets_config=litellm_params.get("detect_secrets_config"),
+               event_hook=litellm_params["mode"],
+               guardrail_name=guardrail["guardrail_name"],
+           )
+
+           litellm.callbacks.append(_secret_detection_object)  # type: ignore
        elif (
            isinstance(litellm_params["guardrail"], str)
            and "." in litellm_params["guardrail"]
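A hedged sketch of the parsed guardrail entry this branch expects; the field names come from the diff, while the surrounding dict shape is an assumption for illustration.

    # Assumed shape of one parsed guardrail entry as init_guardrails_v2 sees it.
    guardrail = {
        "guardrail_name": "hide-secrets",
        "litellm_params": {
            "guardrail": "hide-secrets",
            "mode": "pre_call",
            "detect_secrets_config": None,  # optional; plugin defaults apply when unset
        },
    }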

View file

@@ -19,12 +19,8 @@ model_list:
      model: openai/429
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app
+      tags: ["fake"]

general_settings:
  master_key: sk-1234

-litellm_settings:
-  success_callback: ["datadog"]
-  service_callback: ["datadog"]
-  cache: True

View file

@@ -139,6 +139,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
## Import All Misc routes here ##
from litellm.proxy.caching_routes import router as caching_router
from litellm.proxy.common_utils.admin_ui_utils import (
+   admin_ui_disabled,
    html_form,
    show_missing_vars_in_env,
)

@@ -250,7 +251,7 @@ from litellm.secret_managers.aws_secret_manager import (
    load_aws_secret_manager,
)
from litellm.secret_managers.google_kms import load_google_kms
-from litellm.secret_managers.main import get_secret
+from litellm.secret_managers.main import get_secret, str_to_bool
from litellm.types.llms.anthropic import (
    AnthropicMessagesRequest,
    AnthropicResponse,

@@ -652,7 +653,7 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
        return
    try:
-       from azure.identity import ClientSecretCredential
+       from azure.identity import ClientSecretCredential, DefaultAzureCredential
        from azure.keyvault.secrets import SecretClient

        # Set your Azure Key Vault URI

@@ -670,9 +671,10 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
            and tenant_id is not None
        ):
            # Initialize the ClientSecretCredential
-           credential = ClientSecretCredential(
-               client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
-           )
+           # credential = ClientSecretCredential(
+           #     client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
+           # )
+           credential = DefaultAzureCredential()

            # Create the SecretClient using the credential
            client = SecretClient(vault_url=KVUri, credential=credential)

@@ -7967,6 +7969,13 @@ async def google_login(request: Request):
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)

+   ####### Check if UI is disabled #######
+   _disable_ui_flag = os.getenv("DISABLE_ADMIN_UI")
+   if _disable_ui_flag is not None:
+       is_disabled = str_to_bool(value=_disable_ui_flag)
+       if is_disabled:
+           return admin_ui_disabled()
+
    ####### Check if user is a Enterprise / Premium User #######
    if (
        microsoft_client_id is not None
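A minimal sketch of flipping the flag; the environment variable name comes from the diff, while the accepted truthy spellings are an assumption based on `str_to_bool`'s name.

    import os

    # Assumption: str_to_bool treats the usual spellings ("True"/"true"/"1") as truthy.
    os.environ["DISABLE_ADMIN_UI"] = "True"    # login endpoint returns the static "Admin UI is Disabled" page
    # os.environ["DISABLE_ADMIN_UI"] = "False" # restores the normal login flow after a restart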

View file

@@ -414,6 +414,7 @@ class ProxyLogging:
                    is not True
                ):
                    continue
+
                response = await _callback.async_pre_call_hook(
                    user_api_key_dict=user_api_key_dict,
                    cache=self.call_details["user_api_key_cache"],

@@ -468,7 +469,9 @@ class ProxyLogging:
            ################################################################
            # V1 implementation - backwards compatibility
-           if callback.event_hook is None:
+           if callback.event_hook is None and hasattr(
+               callback, "moderation_check"
+           ):
                if callback.moderation_check == "pre_call":  # type: ignore
                    return
            else:

@@ -975,12 +978,13 @@ class PrismaClient:
        ]
        required_view = "LiteLLM_VerificationTokenView"
        expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
+       pg_schema = os.getenv("DATABASE_SCHEMA", "public")
        ret = await self.db.query_raw(
            f"""
            WITH existing_views AS (
                SELECT viewname
                FROM pg_views
-               WHERE schemaname = 'public' AND viewname IN (
+               WHERE schemaname = '{pg_schema}' AND viewname IN (
                    {expected_views_str}
                )
            )
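A small sketch of how the schema substitution plays out with a non-default Postgres schema; the env var name comes from the diff, the schema value is made up.

    import os

    # Hypothetical: point the view check at a custom schema instead of 'public'.
    os.environ["DATABASE_SCHEMA"] = "litellm"

    pg_schema = os.getenv("DATABASE_SCHEMA", "public")
    print(f"WHERE schemaname = '{pg_schema}' AND viewname IN (...)")
    # -> WHERE schemaname = 'litellm' AND viewname IN (...)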

View file

@@ -5,11 +5,13 @@ from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
import litellm
from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.togetherai.rerank import TogetherAIRerank
from litellm.secret_managers.main import get_secret
from litellm.types.router import *
-from litellm.utils import client, supports_httpx_timeout
+from litellm.utils import client, exception_type, supports_httpx_timeout

from .types import RerankRequest, RerankResponse
@@ -17,6 +19,7 @@ from .types import RerankRequest, RerankResponse
# Initialize any necessary instances or variables here
cohere_rerank = CohereRerank()
together_rerank = TogetherAIRerank()
+azure_ai_rerank = AzureAIRerank()

#################################################
@@ -70,7 +73,7 @@ def rerank(
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
-   custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
+   custom_llm_provider: Optional[Literal["cohere", "together_ai", "azure_ai"]] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = True,
@@ -80,11 +83,18 @@ def rerank(
    """
    Reranks a list of documents based on their relevance to the query
    """
+   headers: Optional[dict] = kwargs.get("headers")  # type: ignore
+   litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
+   litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
+   proxy_server_request = kwargs.get("proxy_server_request", None)
+   model_info = kwargs.get("model_info", None)
+   metadata = kwargs.get("metadata", {})
+   user = kwargs.get("user", None)
    try:
        _is_async = kwargs.pop("arerank", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)
-       model, _custom_llm_provider, dynamic_api_key, api_base = (
+       model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
            litellm.get_llm_provider(
                model=model,
                custom_llm_provider=custom_llm_provider,
@@ -93,31 +103,52 @@ def rerank(
            )
        )

+       litellm_logging_obj.update_environment_variables(
+           model=model,
+           user=user,
+           optional_params=optional_params.model_dump(),
+           litellm_params={
+               "litellm_call_id": litellm_call_id,
+               "proxy_server_request": proxy_server_request,
+               "model_info": model_info,
+               "metadata": metadata,
+               "preset_cache_key": None,
+               "stream_response": {},
+           },
+           custom_llm_provider=_custom_llm_provider,
+       )
+
        # Implement rerank logic here based on the custom_llm_provider
        if _custom_llm_provider == "cohere":
            # Implement Cohere rerank logic
-           cohere_key = (
+           api_key: Optional[str] = (
                dynamic_api_key
                or optional_params.api_key
                or litellm.cohere_key
-               or get_secret("COHERE_API_KEY")
-               or get_secret("CO_API_KEY")
+               or get_secret("COHERE_API_KEY")  # type: ignore
+               or get_secret("CO_API_KEY")  # type: ignore
                or litellm.api_key
            )

-           if cohere_key is None:
+           if api_key is None:
                raise ValueError(
                    "Cohere API key is required, please set 'COHERE_API_KEY' in your environment"
                )

-           api_base = (
-               optional_params.api_base
+           api_base: Optional[str] = (
+               dynamic_api_base
+               or optional_params.api_base
                or litellm.api_base
-               or get_secret("COHERE_API_BASE")
+               or get_secret("COHERE_API_BASE")  # type: ignore
                or "https://api.cohere.com/v1/rerank"
            )

-           headers: Dict = litellm.headers or {}
+           if api_base is None:
+               raise Exception(
+                   "Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
+               )
+
+           headers = headers or litellm.headers or {}

            response = cohere_rerank.rerank(
                model=model,
@@ -127,22 +158,72 @@ def rerank(
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
-               api_key=cohere_key,
+               api_key=api_key,
                api_base=api_base,
                _is_async=_is_async,
+               headers=headers,
+               litellm_logging_obj=litellm_logging_obj,
            )
+       elif _custom_llm_provider == "azure_ai":
+           api_base = (
+               dynamic_api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
+               or optional_params.api_base
+               or litellm.api_base
+               or get_secret("AZURE_AI_API_BASE")  # type: ignore
+           )
+           # set API KEY
+           api_key = (
+               dynamic_api_key
+               or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
+               or litellm.openai_key
+               or get_secret("AZURE_AI_API_KEY")  # type: ignore
+           )
+
+           headers = headers or litellm.headers or {}
+
+           if api_key is None:
+               raise ValueError(
+                   "Azure AI API key is required, please set 'AZURE_AI_API_KEY' in your environment"
+               )
+
+           if api_base is None:
+               raise Exception(
+                   "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
+               )
+
+           ## LOAD CONFIG - if set
+           config = litellm.OpenAIConfig.get_config()
+           for k, v in config.items():
+               if (
+                   k not in optional_params
+               ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
+                   optional_params[k] = v
+
+           response = azure_ai_rerank.rerank(
+               model=model,
+               query=query,
+               documents=documents,
+               top_n=top_n,
+               rank_fields=rank_fields,
+               return_documents=return_documents,
+               max_chunks_per_doc=max_chunks_per_doc,
+               api_key=api_key,
+               api_base=api_base,
+               _is_async=_is_async,
+               headers=headers,
+               litellm_logging_obj=litellm_logging_obj,
+           )
+           pass
        elif _custom_llm_provider == "together_ai":
            # Implement Together AI rerank logic
-           together_key = (
+           api_key = (
                dynamic_api_key
                or optional_params.api_key
                or litellm.togetherai_api_key
-               or get_secret("TOGETHERAI_API_KEY")
+               or get_secret("TOGETHERAI_API_KEY")  # type: ignore
                or litellm.api_key
            )

-           if together_key is None:
+           if api_key is None:
                raise ValueError(
                    "TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
                )
@@ -155,7 +236,7 @@ def rerank(
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
-               api_key=together_key,
+               api_key=api_key,
                _is_async=_is_async,
            )
@@ -166,4 +247,6 @@ def rerank(
        return response
    except Exception as e:
        verbose_logger.error(f"Error in rerank: {str(e)}")
-       raise e
+       raise exception_type(
+           model=model, custom_llm_provider=custom_llm_provider, original_exception=e
+       )
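Putting the new provider branch together, a usage sketch that mirrors the test added further down; the deployment name is the illustrative one from that test, and the env var names match the `AZURE_AI_API_KEY` / `AZURE_AI_API_BASE` fallbacks above.

    import os
    import litellm

    # Sketch: rerank against a Cohere rerank deployment hosted on Azure AI.
    # The `azure_ai/` prefix routes through the new AzureAIRerank handler.
    response = litellm.rerank(
        model="azure_ai/Cohere-rerank-v3-multilingual-ko",  # example deployment name
        query="hello",
        documents=["hello", "world"],
        top_n=2,
        api_key=os.getenv("AZURE_AI_API_KEY"),
        api_base=os.getenv("AZURE_AI_API_BASE"),
    )
    print(response.results)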

View file

@@ -1290,7 +1290,7 @@ class Router:
            raise e

    async def _aimage_generation(self, prompt: str, model: str, **kwargs):
-       model_name = ""
+       model_name = model
        try:
            verbose_router_logger.debug(
                f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"

View file

@@ -1351,6 +1351,37 @@ def test_logging_async_cache_hit_sync_call():
    assert standard_logging_object["saved_cache_cost"] > 0

+def test_logging_standard_payload_failure_call():
+    from litellm.types.utils import StandardLoggingPayload
+
+    customHandler = CompletionCustomHandler()
+    litellm.callbacks = [customHandler]
+
+    with patch.object(
+        customHandler, "log_failure_event", new=MagicMock()
+    ) as mock_client:
+        try:
+            resp = litellm.completion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                mock_response="litellm.RateLimitError",
+            )
+        except litellm.RateLimitError:
+            pass
+
+        mock_client.assert_called_once()
+
+        assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
+        assert (
+            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
+            is not None
+        )
+
+        standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
+            "kwargs"
+        ]["standard_logging_object"]
+

def test_logging_key_masking_gemini():
    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]

View file

@@ -63,6 +63,7 @@ def test_guardrail_masking_logging_only():
        assert response.choices[0].message.content == "Hi Peter!"  # type: ignore
+       time.sleep(3)
        mock_call.assert_called_once()

        print(mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"])

View file

@@ -129,6 +129,47 @@ async def test_basic_rerank_together_ai(sync_mode):
    assert_response_shape(response, custom_llm_provider="together_ai")

+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_basic_rerank_azure_ai(sync_mode):
+    import os
+
+    litellm.set_verbose = True
+    if sync_mode is True:
+        response = litellm.rerank(
+            model="azure_ai/Cohere-rerank-v3-multilingual-ko",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+            api_key=os.getenv("AZURE_AI_COHERE_API_KEY"),
+            api_base=os.getenv("AZURE_AI_COHERE_API_BASE"),
+        )
+
+        print("re rank response: ", response)
+
+        assert response.id is not None
+        assert response.results is not None
+
+        assert_response_shape(response, custom_llm_provider="together_ai")
+    else:
+        response = await litellm.arerank(
+            model="azure_ai/Cohere-rerank-v3-multilingual-ko",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+            api_key=os.getenv("AZURE_AI_COHERE_API_KEY"),
+            api_base=os.getenv("AZURE_AI_COHERE_API_BASE"),
+        )
+
+        print("async re rank response: ", response)
+
+        assert response.id is not None
+        assert response.results is not None
+
+        assert_response_shape(response, custom_llm_provider="together_ai")
+

@pytest.mark.asyncio()
async def test_rerank_custom_api_base():
    mock_response = AsyncMock()

View file

@@ -89,6 +89,9 @@ class LitellmParams(TypedDict):
    presidio_ad_hoc_recognizers: Optional[str]
    mock_redacted_text: Optional[dict]

+   # hide secrets params
+   detect_secrets_config: Optional[dict]
+

class Guardrail(TypedDict):
    guardrail_name: str

View file

@@ -1278,10 +1278,14 @@ class StandardLoggingModelInformation(TypedDict):
    model_map_value: Optional[ModelInfo]

+StandardLoggingPayloadStatus = Literal["success", "failure"]
+

class StandardLoggingPayload(TypedDict):
    id: str
    call_type: str
    response_cost: float
+   status: StandardLoggingPayloadStatus
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int

@@ -1302,6 +1306,7 @@ class StandardLoggingPayload(TypedDict):
    requester_ip_address: Optional[str]
    messages: Optional[Union[str, list, dict]]
    response: Optional[Union[str, list, dict]]
+   error_str: Optional[str]
    model_parameters: dict
    hidden_params: StandardLoggingHiddenParams
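A hedged sketch of consuming the new fields from a custom callback; the hook name and signature follow litellm's custom-logger pattern, and reading the payload out of `kwargs["standard_logging_object"]` matches the failure-logging test added earlier in this change set.

    from litellm.integrations.custom_logger import CustomLogger


    class FailureAwareLogger(CustomLogger):
        def log_failure_event(self, kwargs, response_obj, start_time, end_time):
            # standard_logging_object is now populated for failed calls as well.
            payload = kwargs.get("standard_logging_object") or {}
            if payload.get("status") == "failure":
                print("call failed:", payload.get("error_str"))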

View file

@@ -1,6 +1,6 @@
{
    "ignore": [],
-   "exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py"],
+   "exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py", "litellm/types/utils.py"],
    "reportMissingImports": false
}