diff --git a/.circleci/config.yml b/.circleci/config.yml index 5dfeedcaa2..40d498d6e7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -66,7 +66,7 @@ jobs: pip install "pydantic==2.7.1" pip install "diskcache==5.6.1" pip install "Pillow==10.3.0" - pip install "ijson==3.2.3" + pip install "jsonschema==4.22.0" - save_cache: paths: - ./venv @@ -128,7 +128,7 @@ jobs: pip install jinja2 pip install tokenizers pip install openai - pip install ijson + pip install jsonschema - run: name: Run tests command: | @@ -183,7 +183,7 @@ jobs: pip install numpydoc pip install prisma pip install fastapi - pip install ijson + pip install jsonschema pip install "httpx==0.24.1" pip install "gunicorn==21.2.0" pip install "anyio==3.7.1" @@ -212,6 +212,7 @@ jobs: -e AWS_REGION_NAME=$AWS_REGION_NAME \ -e AUTO_INFER_REGION=True \ -e OPENAI_API_KEY=$OPENAI_API_KEY \ + -e LITELLM_LICENSE=$LITELLM_LICENSE \ -e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \ -e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \ -e LANGFUSE_PROJECT1_SECRET=$LANGFUSE_PROJECT1_SECRET \ diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md index db29319092..5e2bd60794 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -50,7 +50,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea |Huggingface| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |Openrouter| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | ✅ | | | | | |AI21| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | -|VertexAI| ✅ | ✅ | | ✅ | ✅ | | | | | | | | | | ✅ | | | +|VertexAI| ✅ | ✅ | | ✅ | ✅ | | | | | | | | | ✅ | ✅ | | | |Bedrock| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | ✅ (for anthropic) | | |Sagemaker| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | ✅ | diff --git a/docs/my-website/docs/completion/token_usage.md b/docs/my-website/docs/completion/token_usage.md index 807ccfd91e..0bec6b3f90 100644 --- a/docs/my-website/docs/completion/token_usage.md +++ b/docs/my-website/docs/completion/token_usage.md @@ -1,7 +1,21 @@ # Completion Token Usage & Cost By default LiteLLM returns token usage in all completion requests ([See here](https://litellm.readthedocs.io/en/latest/output/)) -However, we also expose some helper functions + **[NEW]** an API to calculate token usage across providers: +LiteLLM returns `response_cost` in all calls. + +```python +from litellm import completion + +response = litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + mock_response="Hello world", + ) + +print(response._hidden_params["response_cost"]) +``` + +LiteLLM also exposes some helper functions: - `encode`: This encodes the text passed in, using the model-specific tokenizer. [**Jump to code**](#1-encode) @@ -23,7 +37,7 @@ However, we also expose some helper functions + **[NEW]** an API to calculate to - `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#10-apilitellmai) -📣 This is a community maintained list. Contributions are welcome! ❤️ +📣 [This is a community maintained list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). Contributions are welcome! 
❤️ ## Example Usage diff --git a/docs/my-website/docs/enterprise.md b/docs/my-website/docs/enterprise.md index 875aec57f0..5bd09ec156 100644 --- a/docs/my-website/docs/enterprise.md +++ b/docs/my-website/docs/enterprise.md @@ -2,26 +2,39 @@ For companies that need SSO, user management and professional support for LiteLLM Proxy :::info - +Interested in Enterprise? Schedule a meeting with us here 👉 [Talk to founders](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) ::: This covers: -- ✅ **Features under the [LiteLLM Commercial License (Content Mod, Custom Tags, etc.)](https://docs.litellm.ai/docs/proxy/enterprise)** -- ✅ [**Secure UI access with Single Sign-On**](../docs/proxy/ui.md#setup-ssoauth-for-ui) -- ✅ [**Audit Logs with retention policy**](../docs/proxy/enterprise.md#audit-logs) -- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md) -- ✅ [**Control available public, private routes**](../docs/proxy/enterprise.md#control-available-public-private-routes) -- ✅ [**Guardrails, Content Moderation, PII Masking, Secret/API Key Masking**](../docs/proxy/enterprise.md#prompt-injection-detection---lakeraai) -- ✅ [**Prompt Injection Detection**](../docs/proxy/enterprise.md#prompt-injection-detection---lakeraai) -- ✅ [**Invite Team Members to access `/spend` Routes**](../docs/proxy/cost_tracking#allowing-non-proxy-admins-to-access-spend-endpoints) +- **Enterprise Features** + - **Security** + - ✅ [SSO for Admin UI](./proxy/ui#✨-enterprise-features) + - ✅ [Audit Logs with retention policy](./proxy/enterprise#audit-logs) + - ✅ [JWT-Auth](../docs/proxy/token_auth.md) + - ✅ [Control available public, private routes](./proxy/enterprise#control-available-public-private-routes) + - ✅ [[BETA] AWS Key Manager v2 - Key Decryption](./proxy/enterprise#beta-aws-key-manager---key-decryption) + - ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](./proxy/pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints) + - ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](./proxy/enterprise#enforce-required-params-for-llm-requests) + - **Spend Tracking** + - ✅ [Tracking Spend for Custom Tags](./proxy/enterprise#tracking-spend-for-custom-tags) + - ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](./proxy/cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend) + - **Advanced Metrics** + - ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](./proxy/prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens) + - **Guardrails, PII Masking, Content Moderation** + - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](./proxy/enterprise#content-moderation) + - ✅ [Prompt Injection Detection (with LakeraAI API)](./proxy/enterprise#prompt-injection-detection---lakeraai) + - ✅ Reject calls from Blocked User list + - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. 
competitors) + - **Custom Branding** + - ✅ [Custom Branding + Routes on Swagger Docs](./proxy/enterprise#swagger-docs---custom-routes--branding) + - ✅ [Public Model Hub](../docs/proxy/enterprise.md#public-model-hub) + - ✅ [Custom Email Branding](../docs/proxy/email.md#customizing-email-branding) - ✅ **Feature Prioritization** - ✅ **Custom Integrations** - ✅ **Professional Support - Dedicated discord + slack** -- ✅ [**Custom Swagger**](../docs/proxy/enterprise.md#swagger-docs---custom-routes--branding) -- ✅ [**Public Model Hub**](../docs/proxy/enterprise.md#public-model-hub) -- ✅ [**Custom Email Branding**](../docs/proxy/email.md#customizing-email-branding) + diff --git a/docs/my-website/docs/providers/anthropic.md b/docs/my-website/docs/providers/anthropic.md index 3b9e679698..a662129d03 100644 --- a/docs/my-website/docs/providers/anthropic.md +++ b/docs/my-website/docs/providers/anthropic.md @@ -168,11 +168,15 @@ print(response) ## Supported Models +`Model Name` 👉 Human-friendly name. +`Function Call` 👉 How to call the model in LiteLLM. + | Model Name | Function Call | |------------------|--------------------------------------------| +| claude-3-5-sonnet | `completion('claude-3-5-sonnet-20240620', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-3-haiku | `completion('claude-3-haiku-20240307', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-3-opus | `completion('claude-3-opus-20240229', messages)` | `os.environ['ANTHROPIC_API_KEY']` | -| claude-3-5-sonnet | `completion('claude-3-5-sonnet-20240620', messages)` | `os.environ['ANTHROPIC_API_KEY']` | +| claude-3-5-sonnet-20240620 | `completion('claude-3-5-sonnet-20240620', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-3-sonnet | `completion('claude-3-sonnet-20240229', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-2.1 | `completion('claude-2.1', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-2 | `completion('claude-2', messages)` | `os.environ['ANTHROPIC_API_KEY']` | diff --git a/docs/my-website/docs/providers/azure_ai.md b/docs/my-website/docs/providers/azure_ai.md index 87b8041ef5..26c965a0cb 100644 --- a/docs/my-website/docs/providers/azure_ai.md +++ b/docs/my-website/docs/providers/azure_ai.md @@ -14,7 +14,7 @@ LiteLLM supports all models on Azure AI Studio ### ENV VAR ```python import os -os.environ["AZURE_API_API_KEY"] = "" +os.environ["AZURE_AI_API_KEY"] = "" os.environ["AZURE_AI_API_BASE"] = "" ``` @@ -24,7 +24,7 @@ os.environ["AZURE_AI_API_BASE"] = "" from litellm import completion import os ## set ENV variables -os.environ["AZURE_API_API_KEY"] = "azure ai key" +os.environ["AZURE_AI_API_KEY"] = "azure ai key" os.environ["AZURE_AI_API_BASE"] = "azure ai base url" # e.g.: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/ # predibase llama-3 call diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md index f380a6a50e..b72dac10bc 100644 --- a/docs/my-website/docs/providers/bedrock.md +++ b/docs/my-website/docs/providers/bedrock.md @@ -549,6 +549,10 @@ response = completion( This is a deprecated flow. Boto3 is not async. And boto3.client does not let us make the http call through httpx. Pass in your aws params through the method above 👆. 
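For reference, a minimal sketch of the recommended flow above, passing AWS credentials directly to `completion` (the model ID and credential values below are placeholders, not part of this change):

```python
import os
from litellm import completion

# Pass AWS params to completion() directly instead of an external boto3 client
response = completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # example Bedrock model ID
    messages=[{"role": "user", "content": "Hello from litellm"}],
    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    aws_region_name="us-east-1",
)
print(response.choices[0].message.content)
```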
[See Auth Code](https://github.com/BerriAI/litellm/blob/55a20c7cce99a93d36a82bf3ae90ba3baf9a7f89/litellm/llms/bedrock_httpx.py#L284) [Add new auth flow](https://github.com/BerriAI/litellm/issues) + +Experimental - 2024-Jun-23: + `aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token` will be extracted from boto3.client and be passed into the httpx client + ::: Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth. diff --git a/docs/my-website/docs/providers/databricks.md b/docs/my-website/docs/providers/databricks.md index 24c7c40cff..633350d220 100644 --- a/docs/my-website/docs/providers/databricks.md +++ b/docs/my-website/docs/providers/databricks.md @@ -27,7 +27,7 @@ import os os.environ["DATABRICKS_API_KEY"] = "databricks key" os.environ["DATABRICKS_API_BASE"] = "databricks base url" # e.g.: https://adb-3064715882934586.6.azuredatabricks.net/serving-endpoints -# predibase llama-3 call +# Databricks dbrx-instruct call response = completion( model="databricks/databricks-dbrx-instruct", messages = [{ "content": "Hello, how are you?","role": "user"}] @@ -143,13 +143,13 @@ response = completion( model_list: - model_name: llama-3 litellm_params: - model: predibase/llama-3-8b-instruct - api_key: os.environ/PREDIBASE_API_KEY + model: databricks/databricks-meta-llama-3-70b-instruct + api_key: os.environ/DATABRICKS_API_KEY max_tokens: 20 temperature: 0.5 ``` -## Passings Database specific params - 'instruction' +## Passings Databricks specific params - 'instruction' For embedding models, databricks lets you pass in an additional param 'instruction'. [Full Spec](https://github.com/BerriAI/litellm/blob/43353c28b341df0d9992b45c6ce464222ebd7984/litellm/llms/databricks.py#L164) @@ -162,7 +162,7 @@ import os os.environ["DATABRICKS_API_KEY"] = "databricks key" os.environ["DATABRICKS_API_BASE"] = "databricks url" -# predibase llama3 call +# Databricks bge-large-en call response = litellm.embedding( model="databricks/databricks-bge-large-en", input=["good morning from litellm"], @@ -184,7 +184,6 @@ response = litellm.embedding( ## Supported Databricks Chat Completion Models -Here's an example of using a Databricks models with LiteLLM | Model Name | Command | |----------------------------|------------------------------------------------------------------| @@ -196,8 +195,8 @@ Here's an example of using a Databricks models with LiteLLM | databricks-mpt-7b-instruct | `completion(model='databricks/databricks-mpt-7b-instruct', messages=messages)` | ## Supported Databricks Embedding Models -Here's an example of using a databricks models with LiteLLM | Model Name | Command | |----------------------------|------------------------------------------------------------------| -| databricks-bge-large-en | `completion(model='databricks/databricks-bge-large-en', messages=messages)` | +| databricks-bge-large-en | `embedding(model='databricks/databricks-bge-large-en', messages=messages)` | +| databricks-gte-large-en | `embedding(model='databricks/databricks-gte-large-en', messages=messages)` | diff --git a/docs/my-website/docs/providers/openai_compatible.md b/docs/my-website/docs/providers/openai_compatible.md index ff0e857099..33ab8fb411 100644 --- a/docs/my-website/docs/providers/openai_compatible.md +++ b/docs/my-website/docs/providers/openai_compatible.md @@ -18,7 +18,7 @@ import litellm import os response = litellm.completion( - 
model="openai/mistral, # add `openai/` prefix to model so litellm knows to route to OpenAI + model="openai/mistral", # add `openai/` prefix to model so litellm knows to route to OpenAI api_key="sk-1234", # api key to your openai compatible endpoint api_base="http://0.0.0.0:4000", # set API Base of your Custom OpenAI Endpoint messages=[ @@ -115,3 +115,18 @@ Here's how to call an OpenAI-Compatible Endpoint with the LiteLLM Proxy Server + + +### Advanced - Disable System Messages + +Some VLLM models (e.g. gemma) don't support system messages. To map those requests to 'user' messages, use the `supports_system_message` flag. + +```yaml +model_list: +- model_name: my-custom-model + litellm_params: + model: openai/google/gemma + api_base: http://my-custom-base + api_key: "" + supports_system_message: False # 👈 KEY CHANGE +``` \ No newline at end of file diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index de1b5811f1..ce9e73bab1 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -123,6 +123,182 @@ print(completion(**data)) ### **JSON Schema** +From v`1.40.1+` LiteLLM supports sending `response_schema` as a param for Gemini-1.5-Pro on Vertex AI. For other models (e.g. `gemini-1.5-flash` or `claude-3-5-sonnet`), LiteLLM adds the schema to the message list with a user-controlled prompt. + +**Response Schema** + + + +```python +from litellm import completion +import json + +## SETUP ENVIRONMENT +# !gcloud auth application-default login - run this to add vertex credentials to your env + +messages = [ + { + "role": "user", + "content": "List 5 popular cookie recipes." + } +] + +response_schema = { + "type": "array", + "items": { + "type": "object", + "properties": { + "recipe_name": { + "type": "string", + }, + }, + "required": ["recipe_name"], + }, + } + + +completion( + model="vertex_ai_beta/gemini-1.5-pro", + messages=messages, + response_format={"type": "json_object", "response_schema": response_schema} # 👈 KEY CHANGE + ) + +print(json.loads(completion.choices[0].message.content)) +``` + + + + +1. Add model to config.yaml +```yaml +model_list: + - model_name: gemini-pro + litellm_params: + model: vertex_ai_beta/gemini-1.5-pro + vertex_project: "project-id" + vertex_location: "us-central1" + vertex_credentials: "/path/to/service_account.json" # [OPTIONAL] Do this OR `!gcloud auth application-default login` - run this to add vertex credentials to your env +``` + +2. Start Proxy + +``` +$ litellm --config /path/to/config.yaml +``` + +3. Make Request! + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-D '{ + "model": "gemini-pro", + "messages": [ + {"role": "user", "content": "List 5 popular cookie recipes."} + ], + "response_format": {"type": "json_object", "response_schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "recipe_name": { + "type": "string", + }, + }, + "required": ["recipe_name"], + }, + }} +} +' +``` + + + + +**Validate Schema** + +To validate the response_schema, set `enforce_validation: true`. 
+ + + + +```python +from litellm import completion, JSONSchemaValidationError +try: + completion( + model="vertex_ai_beta/gemini-1.5-pro", + messages=messages, + response_format={ + "type": "json_object", + "response_schema": response_schema, + "enforce_validation": true # 👈 KEY CHANGE + } + ) +except JSONSchemaValidationError as e: + print("Raw Response: {}".format(e.raw_response)) + raise e +``` + + + +1. Add model to config.yaml +```yaml +model_list: + - model_name: gemini-pro + litellm_params: + model: vertex_ai_beta/gemini-1.5-pro + vertex_project: "project-id" + vertex_location: "us-central1" + vertex_credentials: "/path/to/service_account.json" # [OPTIONAL] Do this OR `!gcloud auth application-default login` - run this to add vertex credentials to your env +``` + +2. Start Proxy + +``` +$ litellm --config /path/to/config.yaml +``` + +3. Make Request! + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-D '{ + "model": "gemini-pro", + "messages": [ + {"role": "user", "content": "List 5 popular cookie recipes."} + ], + "response_format": {"type": "json_object", "response_schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "recipe_name": { + "type": "string", + }, + }, + "required": ["recipe_name"], + }, + }, + "enforce_validation": true + } +} +' +``` + + + + +LiteLLM will validate the response against the schema, and raise a `JSONSchemaValidationError` if the response does not match the schema. + +JSONSchemaValidationError inherits from `openai.APIError` + +Access the raw response with `e.raw_response` + +**Add to prompt yourself** + ```python from litellm import completion @@ -645,6 +821,86 @@ assert isinstance( ``` +## Usage - PDF / Videos / etc. Files + +Pass any file supported by Vertex AI, through LiteLLM. + + + + +```python +from litellm import completion + +response = completion( + model="vertex_ai/gemini-1.5-flash", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "You are a very professional document summarization specialist. Please summarize the given document."}, + { + "type": "image_url", + "image_url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf", + }, + ], + } + ], + max_tokens=300, +) + +print(response.choices[0]) + +``` + + + +1. Add model to config + +```yaml +- model_name: gemini-1.5-flash + litellm_params: + model: vertex_ai/gemini-1.5-flash + vertex_credentials: "/path/to/service_account.json" +``` + +2. Start Proxy + +``` +litellm --config /path/to/config.yaml +``` + +3. Test it! + +```bash +curl http://0.0.0.0:4000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer " \ + -d '{ + "model": "gemini-1.5-flash", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "You are a very professional document summarization specialist. 
Please summarize the given document" + }, + { + "type": "image_url", + "image_url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf", + }, + } + ] + } + ], + "max_tokens": 300 + }' + +``` + + + + ## Chat Models | Model Name | Function Call | |------------------|--------------------------------------| diff --git a/docs/my-website/docs/providers/volcano.md b/docs/my-website/docs/providers/volcano.md new file mode 100644 index 0000000000..1742a43d81 --- /dev/null +++ b/docs/my-website/docs/providers/volcano.md @@ -0,0 +1,98 @@ +# Volcano Engine (Volcengine) +https://www.volcengine.com/docs/82379/1263482 + +:::tip + +**We support ALL Volcengine NIM models, just set `model=volcengine/` as a prefix when sending litellm requests** + +::: + +## API Key +```python +# env variable +os.environ['VOLCENGINE_API_KEY'] +``` + +## Sample Usage +```python +from litellm import completion +import os + +os.environ['VOLCENGINE_API_KEY'] = "" +response = completion( + model="volcengine/", + messages=[ + { + "role": "user", + "content": "What's the weather like in Boston today in Fahrenheit?", + } + ], + temperature=0.2, # optional + top_p=0.9, # optional + frequency_penalty=0.1, # optional + presence_penalty=0.1, # optional + max_tokens=10, # optional + stop=["\n\n"], # optional +) +print(response) +``` + +## Sample Usage - Streaming +```python +from litellm import completion +import os + +os.environ['VOLCENGINE_API_KEY'] = "" +response = completion( + model="volcengine/", + messages=[ + { + "role": "user", + "content": "What's the weather like in Boston today in Fahrenheit?", + } + ], + stream=True, + temperature=0.2, # optional + top_p=0.9, # optional + frequency_penalty=0.1, # optional + presence_penalty=0.1, # optional + max_tokens=10, # optional + stop=["\n\n"], # optional +) + +for chunk in response: + print(chunk) +``` + + +## Supported Models - 💥 ALL Volcengine NIM Models Supported! +We support ALL `volcengine` models, just set `volcengine/` as a prefix when sending completion requests + +## Sample Usage - LiteLLM Proxy + +### Config.yaml setting + +```yaml +model_list: + - model_name: volcengine-model + litellm_params: + model: volcengine/ + api_key: os.environ/VOLCENGINE_API_KEY +``` + +### Send Request + +```shell +curl --location 'http://localhost:4000/chat/completions' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "volcengine-model", + "messages": [ + { + "role": "user", + "content": "here is my api key. openai_api_key=sk-1234" + } + ] +}' +``` \ No newline at end of file diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index 9381a14a44..00457dbc4d 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -277,17 +277,65 @@ curl --location 'http://0.0.0.0:4000/v1/model/info' \ --data '' ``` +## Wildcard Model Name (Add ALL MODELS from env) + +Dynamically call any model from any given provider without the need to predefine it in the config YAML file. As long as the relevant keys are in the environment (see [providers list](../providers/)), LiteLLM will make the call correctly. + + + +1. Setup config.yaml +``` +model_list: + - model_name: "*" # all requests where model not in your config go to this deployment + litellm_params: + model: "openai/*" # passes our validation check that a real provider is given +``` + +2. Start LiteLLM proxy + +``` +litellm --config /path/to/config.yaml +``` + +3. 
Try claude 3-5 sonnet from anthropic + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-D '{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + {"role": "user", "content": "Hey, how'\''s it going?"}, + { + "role": "assistant", + "content": "I'\''m doing well. Would like to hear the rest of the story?" + }, + {"role": "user", "content": "Na"}, + { + "role": "assistant", + "content": "No problem, is there anything else i can help you with today?" + }, + { + "role": "user", + "content": "I think you'\''re getting cut off sometimes" + } + ] +} +' +``` + ## Load Balancing :::info -For more on this, go to [this page](./load_balancing.md) +For more on this, go to [this page](https://docs.litellm.ai/docs/proxy/load_balancing) ::: -Use this to call multiple instances of the same model and configure things like [routing strategy](../routing.md#advanced). +Use this to call multiple instances of the same model and configure things like [routing strategy](https://docs.litellm.ai/docs/routing#advanced). For optimal performance: - Set `tpm/rpm` per model deployment. Weighted picks are then based on the established tpm/rpm. -- Select your optimal routing strategy in `router_settings:routing_strategy`. +- Select your optimal routing strategy in `router_settings:routing_strategy`. LiteLLM supports ```python @@ -427,7 +475,7 @@ model_list: ```shell $ litellm --config /path/to/config.yaml -``` +``` ## Setting Embedding Models diff --git a/docs/my-website/docs/proxy/cost_tracking.md b/docs/my-website/docs/proxy/cost_tracking.md index f01e1042e3..fe3a462508 100644 --- a/docs/my-website/docs/proxy/cost_tracking.md +++ b/docs/my-website/docs/proxy/cost_tracking.md @@ -114,6 +114,16 @@ print(response) **Step3 - Verify Spend Tracked** That's IT. 
Now Verify your spend was tracked + + + +Expect to see `x-litellm-response-cost` in the response headers with calculated cost + + + + + + The following spend gets tracked in Table `LiteLLM_SpendLogs` ```json @@ -137,12 +147,16 @@ Navigate to the Usage Tab on the LiteLLM UI (found on https://your-proxy-endpoin -## API Endpoints to get Spend + + + +## ✨ (Enterprise) API Endpoints to get Spend #### Getting Spend Reports - To Charge Other Teams, Customers Use the `/global/spend/report` endpoint to get daily spend report per -- team -- customer [this is `user` passed to `/chat/completions` request](#how-to-track-spend-with-litellm) +- Team +- Customer [this is `user` passed to `/chat/completions` request](#how-to-track-spend-with-litellm) +- [LiteLLM API key](virtual_keys.md) @@ -325,6 +339,61 @@ curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end ``` + + + + + +👉 Key Change: Specify `group_by=api_key` + + +```shell +curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30&group_by=api_key' \ + -H 'Authorization: Bearer sk-1234' +``` + +##### Example Response + + +```shell +[ + { + "api_key": "ad64768847d05d978d62f623d872bff0f9616cc14b9c1e651c84d14fe3b9f539", + "total_cost": 0.0002157, + "total_input_tokens": 45.0, + "total_output_tokens": 1375.0, + "model_details": [ + { + "model": "gpt-3.5-turbo", + "total_cost": 0.0001095, + "total_input_tokens": 9, + "total_output_tokens": 70 + }, + { + "model": "llama3-8b-8192", + "total_cost": 0.0001062, + "total_input_tokens": 36, + "total_output_tokens": 1305 + } + ] + }, + { + "api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b", + "total_cost": 0.00012924, + "total_input_tokens": 36.0, + "total_output_tokens": 1593.0, + "model_details": [ + { + "model": "llama3-8b-8192", + "total_cost": 0.00012924, + "total_input_tokens": 36, + "total_output_tokens": 1593 + } + ] + } +] +``` + diff --git a/docs/my-website/docs/proxy/debugging.md b/docs/my-website/docs/proxy/debugging.md index 571a97c0ec..38680982a3 100644 --- a/docs/my-website/docs/proxy/debugging.md +++ b/docs/my-website/docs/proxy/debugging.md @@ -88,4 +88,31 @@ Expected Output: ```bash # no info statements -``` \ No newline at end of file +``` + +## Common Errors + +1. "No available deployments..." + +``` +No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a. +``` + +This can be caused due to all your models hitting rate limit errors, causing the cooldown to kick in. + +How to control this? +- Adjust the cooldown time + +```yaml +router_settings: + cooldown_time: 0 # 👈 KEY CHANGE +``` + +- Disable Cooldowns [NOT RECOMMENDED] + +```yaml +router_settings: + disable_cooldowns: True +``` + +This is not recommended, as it will lead to requests being routed to deployments over their tpm/rpm limit. 
\ No newline at end of file diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md index 9fff879e54..5dabba5ed3 100644 --- a/docs/my-website/docs/proxy/enterprise.md +++ b/docs/my-website/docs/proxy/enterprise.md @@ -6,21 +6,34 @@ import TabItem from '@theme/TabItem'; :::tip -Get in touch with us [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) +To get a license, get in touch with us [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) ::: Features: -- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features) -- ✅ [Audit Logs](#audit-logs) -- ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags) -- ✅ [Control available public, private routes](#control-available-public-private-routes) -- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation) -- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai) -- ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding) -- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests) -- ✅ Reject calls from Blocked User list -- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors) + +- **Security** + - ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features) + - ✅ [Audit Logs with retention policy](#audit-logs) + - ✅ [JWT-Auth](../docs/proxy/token_auth.md) + - ✅ [Control available public, private routes](#control-available-public-private-routes) + - ✅ [[BETA] AWS Key Manager v2 - Key Decryption](#beta-aws-key-manager---key-decryption) + - ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints) + - ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests) +- **Spend Tracking** + - ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags) + - ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend) +- **Advanced Metrics** + - ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens) +- **Guardrails, PII Masking, Content Moderation** + - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation) + - ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai) + - ✅ Reject calls from Blocked User list + - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors) +- **Custom Branding** + - ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding) + - ✅ [Public Model Hub](../docs/proxy/enterprise.md#public-model-hub) + - ✅ [Custom Email Branding](../docs/proxy/email.md#customizing-email-branding) ## Audit Logs @@ -1019,4 +1032,35 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ Share a public page of available models for users - \ No newline at end of file + + + +## [BETA] AWS Key Manager - Key Decryption + +This is a beta feature, and subject to changes. + + +**Step 1.** Add `USE_AWS_KMS` to env + +```env +USE_AWS_KMS="True" +``` + +**Step 2.** Add `aws_kms/` to encrypted keys in env + +```env +DATABASE_URL="aws_kms/AQICAH.." 
+``` + +**Step 3.** Start proxy + +``` +$ litellm +``` + +How it works? +- Key Decryption runs before server starts up. [**Code**](https://github.com/BerriAI/litellm/blob/8571cb45e80cc561dc34bc6aa89611eb96b9fe3e/litellm/proxy/proxy_cli.py#L445) +- It adds the decrypted value to the `os.environ` for the python process. + +**Note:** Setting an environment variable within a Python script using os.environ will not make that variable accessible via SSH sessions or any other new processes that are started independently of the Python script. Environment variables set this way only affect the current process and its child processes. + diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md index f9ed5db3dd..83bf8ee95d 100644 --- a/docs/my-website/docs/proxy/logging.md +++ b/docs/my-website/docs/proxy/logging.md @@ -1188,6 +1188,7 @@ litellm_settings: s3_region_name: us-west-2 # AWS Region Name for S3 s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/ to pass environment variables. This is AWS Access Key ID for S3 s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3 + s3_path: my-test-path # [OPTIONAL] set path in bucket you want to write logs to s3_endpoint_url: https://s3.amazonaws.com # [OPTIONAL] S3 endpoint URL, if you want to use Backblaze/cloudflare s3 buckets ``` diff --git a/docs/my-website/docs/proxy/pass_through.md b/docs/my-website/docs/proxy/pass_through.md new file mode 100644 index 0000000000..1348a2fc1c --- /dev/null +++ b/docs/my-website/docs/proxy/pass_through.md @@ -0,0 +1,220 @@ +import Image from '@theme/IdealImage'; + +# ➡️ Create Pass Through Endpoints + +Add pass through routes to LiteLLM Proxy + +**Example:** Add a route `/v1/rerank` that forwards requests to `https://api.cohere.com/v1/rerank` through LiteLLM Proxy + + +💡 This allows making the following Request to LiteLLM Proxy +```shell +curl --request POST \ + --url http://localhost:4000/v1/rerank \ + --header 'accept: application/json' \ + --header 'content-type: application/json' \ + --data '{ + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": ["Carson City is the capital city of the American state of Nevada."] + }' +``` + +## Tutorial - Pass through Cohere Re-Rank Endpoint + +**Step 1** Define pass through routes on [litellm config.yaml](configs.md) + +```yaml +general_settings: + master_key: sk-1234 + pass_through_endpoints: + - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server + target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to + headers: # headers to forward to this URL + Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint + content-type: application/json # (Optional) Extra Headers to pass to this endpoint + accept: application/json +``` + +**Step 2** Start Proxy Server in detailed_debug mode + +```shell +litellm --config config.yaml --detailed_debug +``` +**Step 3** Make Request to pass through endpoint + +Here `http://localhost:4000` is your litellm proxy endpoint + +```shell +curl --request POST \ + --url http://localhost:4000/v1/rerank \ + --header 'accept: application/json' \ + --header 'content-type: application/json' \ + --data '{ + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": ["Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern 
Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", + "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."] + }' +``` + + +🎉 **Expected Response** + +This request got forwarded from LiteLLM Proxy -> Defined Target URL (with headers) + +```shell +{ + "id": "37103a5b-8cfb-48d3-87c7-da288bedd429", + "results": [ + { + "index": 2, + "relevance_score": 0.999071 + }, + { + "index": 4, + "relevance_score": 0.7867867 + }, + { + "index": 0, + "relevance_score": 0.32713068 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } +} +``` + +## Tutorial - Pass Through Langfuse Requests + + +**Step 1** Define pass through routes on [litellm config.yaml](configs.md) + +```yaml +general_settings: + master_key: sk-1234 + pass_through_endpoints: + - path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server + target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward + headers: + LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY" # your langfuse account public key + LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY" # your langfuse account secret key +``` + +**Step 2** Start Proxy Server in detailed_debug mode + +```shell +litellm --config config.yaml --detailed_debug +``` +**Step 3** Make Request to pass through endpoint + +Run this code to make a sample trace +```python +from langfuse import Langfuse + +langfuse = Langfuse( + host="http://localhost:4000", # your litellm proxy endpoint + public_key="anything", # no key required since this is a pass through + secret_key="anything", # no key required since this is a pass through +) + +print("sending langfuse trace request") +trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough") +print("flushing langfuse request") +langfuse.flush() + +print("flushed langfuse request") +``` + + +🎉 **Expected Response** + +On success +Expect to see the following Trace Generated on your Langfuse Dashboard + + + +You will see the following endpoint called on your litellm proxy server logs + +```shell +POST /api/public/ingestion HTTP/1.1" 207 Multi-Status +``` + + +## ✨ [Enterprise] - Use LiteLLM keys/authentication on Pass Through Endpoints + +Use this if you want the pass through endpoint to honour LiteLLM keys/authentication + +Usage - set `auth: true` on the config +```yaml +general_settings: + master_key: sk-1234 + pass_through_endpoints: + - path: "/v1/rerank" + target: "https://api.cohere.com/v1/rerank" + auth: true # 👈 Key change to use LiteLLM Auth / Keys + headers: + Authorization: "bearer os.environ/COHERE_API_KEY" + content-type: application/json + accept: application/json +``` + +Test Request with LiteLLM Key + +```shell +curl --request POST \ + --url http://localhost:4000/v1/rerank \ + --header 'accept: application/json' \ + --header 'Authorization: Bearer sk-1234'\ + --header 'content-type: application/json' \ + --data '{ + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": 
["Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", + "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."] + }' +``` + +## `pass_through_endpoints` Spec on config.yaml + +All possible values for `pass_through_endpoints` and what they mean + +**Example config** +```yaml +general_settings: + pass_through_endpoints: + - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server + target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to + headers: # headers to forward to this URL + Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint + content-type: application/json # (Optional) Extra Headers to pass to this endpoint + accept: application/json +``` + +**Spec** + +* `pass_through_endpoints` *list*: A collection of endpoint configurations for request forwarding. + * `path` *string*: The route to be added to the LiteLLM Proxy Server. + * `target` *string*: The URL to which requests for this path should be forwarded. + * `headers` *object*: Key-value pairs of headers to be forwarded with the request. You can set any key value pair here and it will be forwarded to your target endpoint + * `Authorization` *string*: The authentication header for the target API. + * `content-type` *string*: The format specification for the request body. + * `accept` *string*: The expected response format from the server. + * `LANGFUSE_PUBLIC_KEY` *string*: Your Langfuse account public key - only set this when forwarding to Langfuse. + * `LANGFUSE_SECRET_KEY` *string*: Your Langfuse account secret key - only set this when forwarding to Langfuse. 
+ * `` *string*: Pass any custom header key/value pair \ No newline at end of file diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md index 2c7481f4c6..6790b25b02 100644 --- a/docs/my-website/docs/proxy/prometheus.md +++ b/docs/my-website/docs/proxy/prometheus.md @@ -1,3 +1,6 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + # 📈 Prometheus metrics [BETA] LiteLLM Exposes a `/metrics` endpoint for Prometheus to Poll @@ -61,6 +64,56 @@ http://localhost:4000/metrics | `litellm_remaining_api_key_budget_metric` | Remaining Budget for API Key (A key Created on LiteLLM)| +### ✨ (Enterprise) LLM Remaining Requests and Remaining Tokens +Set this on your config.yaml to allow you to track how close you are to hitting your TPM / RPM limits on each model group + +```yaml +litellm_settings: + success_callback: ["prometheus"] + failure_callback: ["prometheus"] + return_response_headers: true # ensures the LLM API calls track the response headers +``` + +| Metric Name | Description | +|----------------------|--------------------------------------| +| `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment | +| `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` return from LLM API Deployment | + +Example Metric + + + + +```shell +litellm_remaining_requests +{ + api_base="https://api.openai.com/v1", + api_provider="openai", + litellm_model_name="gpt-3.5-turbo", + model_group="gpt-3.5-turbo" +} +8998.0 +``` + + + + + +```shell +litellm_remaining_tokens +{ + api_base="https://api.openai.com/v1", + api_provider="openai", + litellm_model_name="gpt-3.5-turbo", + model_group="gpt-3.5-turbo" +} +999981.0 +``` + + + + + ## Monitor System Health To monitor the health of litellm adjacent services (redis / postgres), do: diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md index 27fefc7f4f..4349f985a2 100644 --- a/docs/my-website/docs/proxy/self_serve.md +++ b/docs/my-website/docs/proxy/self_serve.md @@ -4,7 +4,7 @@ import TabItem from '@theme/TabItem'; # 🤗 UI - Self-Serve -Allow users to creat their own keys on [Proxy UI](./ui.md). +Allow users to create their own keys on [Proxy UI](./ui.md). 1. Add user with permissions to a team on proxy diff --git a/docs/my-website/docs/proxy/user_keys.md b/docs/my-website/docs/proxy/user_keys.md index cda3a46af9..cc1d5fe821 100644 --- a/docs/my-website/docs/proxy/user_keys.md +++ b/docs/my-website/docs/proxy/user_keys.md @@ -152,6 +152,27 @@ response = chat(messages) print(response) ``` + + + +```js +import { ChatOpenAI } from "@langchain/openai"; + + +const model = new ChatOpenAI({ + modelName: "gpt-4", + openAIApiKey: "sk-1234", + modelKwargs: {"metadata": "hello world"} // 👈 PASS Additional params here +}, { + basePath: "http://0.0.0.0:4000", +}); + +const message = await model.invoke("Hi there!"); + +console.log(message); + +``` + diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md index 240e6c8e04..905954e979 100644 --- a/docs/my-website/docs/routing.md +++ b/docs/my-website/docs/routing.md @@ -815,6 +815,35 @@ model_list: +**Expected Response** + +``` +No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a. 
+``` + +#### **Disable cooldowns** + + + + + +```python +from litellm import Router + + +router = Router(..., disable_cooldowns=True) +``` + + + +```yaml +router_settings: + disable_cooldowns: True +``` + + + + ### Retries For both async + sync functions, we support retrying failed requests. diff --git a/docs/my-website/docs/secret.md b/docs/my-website/docs/secret.md index 08c2e89d1f..91ae383686 100644 --- a/docs/my-website/docs/secret.md +++ b/docs/my-website/docs/secret.md @@ -8,7 +8,13 @@ LiteLLM supports reading secrets from Azure Key Vault and Infisical - [Infisical Secret Manager](#infisical-secret-manager) - [.env Files](#env-files) -## AWS Key Management Service +## AWS Key Management V1 + +:::tip + +[BETA] AWS Key Management v2 is on the enterprise tier. Go [here for docs](./proxy/enterprise.md#beta-aws-key-manager---key-decryption) + +::: Use AWS KMS to storing a hashed copy of your Proxy Master Key in the environment. diff --git a/docs/my-website/docs/text_to_speech.md b/docs/my-website/docs/text_to_speech.md index f4adf15eb5..73a12c4345 100644 --- a/docs/my-website/docs/text_to_speech.md +++ b/docs/my-website/docs/text_to_speech.md @@ -14,14 +14,6 @@ response = speech( model="openai/tts-1", voice="alloy", input="the quick brown fox jumped over the lazy dogs", - api_base=None, - api_key=None, - organization=None, - project=None, - max_retries=1, - timeout=600, - client=None, - optional_params={}, ) response.stream_to_file(speech_file_path) ``` @@ -84,4 +76,37 @@ curl http://0.0.0.0:4000/v1/audio/speech \ litellm --config /path/to/config.yaml # RUNNING on http://0.0.0.0:4000 +``` + +## Azure Usage + +**PROXY** + +```yaml + - model_name: azure/tts-1 + litellm_params: + model: azure/tts-1 + api_base: "os.environ/AZURE_API_BASE_TTS" + api_key: "os.environ/AZURE_API_KEY_TTS" + api_version: "os.environ/AZURE_API_VERSION" +``` + +**SDK** + +```python +from litellm import completion + +## set ENV variables +os.environ["AZURE_API_KEY"] = "" +os.environ["AZURE_API_BASE"] = "" +os.environ["AZURE_API_VERSION"] = "" + +# azure call +speech_file_path = Path(__file__).parent / "speech.mp3" +response = speech( + model="azure/ str: + return "Adafruit API Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:adafruit)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9_-]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/adobe.py b/enterprise/enterprise_hooks/secrets_plugins/adobe.py new file mode 100644 index 0000000000..7a58ccdf90 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/adobe.py @@ -0,0 +1,26 @@ +""" +This plugin searches for Adobe keys +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class AdobeSecretDetector(RegexBasedDetector): + """Scans for Adobe client keys.""" + + @property + def secret_type(self) -> str: + return "Adobe Client Keys" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Adobe Client ID (OAuth Web) + re.compile( + r"""(?i)(?:adobe)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Adobe Client Secret + re.compile(r"(?i)\b((p8e-)[a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)"), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/age_secret_key.py b/enterprise/enterprise_hooks/secrets_plugins/age_secret_key.py new file mode 100644 index 
0000000000..2c0c179102 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/age_secret_key.py @@ -0,0 +1,21 @@ +""" +This plugin searches for Age secret keys +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class AgeSecretKeyDetector(RegexBasedDetector): + """Scans for Age secret keys.""" + + @property + def secret_type(self) -> str: + return "Age Secret Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile(r"""AGE-SECRET-KEY-1[QPZRY9X8GF2TVDW0S3JN54KHCE6MUA7L]{58}"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/airtable_api_key.py b/enterprise/enterprise_hooks/secrets_plugins/airtable_api_key.py new file mode 100644 index 0000000000..8abf4f6e44 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/airtable_api_key.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Airtable API keys +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class AirtableApiKeyDetector(RegexBasedDetector): + """Scans for Airtable API keys.""" + + @property + def secret_type(self) -> str: + return "Airtable API Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:airtable)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{17})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/algolia_api_key.py b/enterprise/enterprise_hooks/secrets_plugins/algolia_api_key.py new file mode 100644 index 0000000000..cd6c16a8c0 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/algolia_api_key.py @@ -0,0 +1,21 @@ +""" +This plugin searches for Algolia API keys +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class AlgoliaApiKeyDetector(RegexBasedDetector): + """Scans for Algolia API keys.""" + + @property + def secret_type(self) -> str: + return "Algolia API Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile(r"""(?i)\b((LTAI)[a-z0-9]{20})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/alibaba.py b/enterprise/enterprise_hooks/secrets_plugins/alibaba.py new file mode 100644 index 0000000000..5d071f1a9b --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/alibaba.py @@ -0,0 +1,26 @@ +""" +This plugin searches for Alibaba secrets +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class AlibabaSecretDetector(RegexBasedDetector): + """Scans for Alibaba AccessKey IDs and Secret Keys.""" + + @property + def secret_type(self) -> str: + return "Alibaba Secrets" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Alibaba AccessKey ID + re.compile(r"""(?i)\b((LTAI)[a-z0-9]{20})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), + # For Alibaba Secret Key + re.compile( + r"""(?i)(?:alibaba)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{30})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/asana.py b/enterprise/enterprise_hooks/secrets_plugins/asana.py new file mode 100644 index 0000000000..fd96872c63 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/asana.py @@ -0,0 +1,28 @@ +""" +This plugin searches for Asana secrets +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class AsanaSecretDetector(RegexBasedDetector): + 
"""Scans for Asana Client IDs and Client Secrets.""" + + @property + def secret_type(self) -> str: + return "Asana Secrets" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Asana Client ID + re.compile( + r"""(?i)(?:asana)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9]{16})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # For Asana Client Secret + re.compile( + r"""(?i)(?:asana)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/atlassian_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/atlassian_api_token.py new file mode 100644 index 0000000000..42fd291ff4 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/atlassian_api_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Atlassian API tokens +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class AtlassianApiTokenDetector(RegexBasedDetector): + """Scans for Atlassian API tokens.""" + + @property + def secret_type(self) -> str: + return "Atlassian API token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Atlassian API token + re.compile( + r"""(?i)(?:atlassian|confluence|jira)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{24})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/authress_access_key.py b/enterprise/enterprise_hooks/secrets_plugins/authress_access_key.py new file mode 100644 index 0000000000..ff7466fc44 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/authress_access_key.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Authress Service Client Access Keys +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class AuthressAccessKeyDetector(RegexBasedDetector): + """Scans for Authress Service Client Access Keys.""" + + @property + def secret_type(self) -> str: + return "Authress Service Client Access Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Authress Service Client Access Key + re.compile( + r"""(?i)\b((?:sc|ext|scauth|authress)_[a-z0-9]{5,30}\.[a-z0-9]{4,6}\.acc[_-][a-z0-9-]{10,32}\.[a-z0-9+/_=-]{30,120})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/beamer_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/beamer_api_token.py new file mode 100644 index 0000000000..5303e6262f --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/beamer_api_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Beamer API tokens +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class BeamerApiTokenDetector(RegexBasedDetector): + """Scans for Beamer API tokens.""" + + @property + def secret_type(self) -> str: + return "Beamer API token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Beamer API token + re.compile( + r"""(?i)(?:beamer)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(b_[a-z0-9=_\-]{44})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/bitbucket.py b/enterprise/enterprise_hooks/secrets_plugins/bitbucket.py new file mode 100644 index 0000000000..aae28dcc7d --- /dev/null +++ 
b/enterprise/enterprise_hooks/secrets_plugins/bitbucket.py @@ -0,0 +1,28 @@ +""" +This plugin searches for Bitbucket Client ID and Client Secret +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class BitbucketDetector(RegexBasedDetector): + """Scans for Bitbucket Client ID and Client Secret.""" + + @property + def secret_type(self) -> str: + return "Bitbucket Secrets" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Bitbucket Client ID + re.compile( + r"""(?i)(?:bitbucket)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # For Bitbucket Client Secret + re.compile( + r"""(?i)(?:bitbucket)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/bittrex.py b/enterprise/enterprise_hooks/secrets_plugins/bittrex.py new file mode 100644 index 0000000000..e8bd3347bb --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/bittrex.py @@ -0,0 +1,28 @@ +""" +This plugin searches for Bittrex Access Key and Secret Key +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class BittrexDetector(RegexBasedDetector): + """Scans for Bittrex Access Key and Secret Key.""" + + @property + def secret_type(self) -> str: + return "Bittrex Secrets" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Bittrex Access Key + re.compile( + r"""(?i)(?:bittrex)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # For Bittrex Secret Key + re.compile( + r"""(?i)(?:bittrex)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/clojars_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/clojars_api_token.py new file mode 100644 index 0000000000..6eb41ec4bb --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/clojars_api_token.py @@ -0,0 +1,22 @@ +""" +This plugin searches for Clojars API tokens +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class ClojarsApiTokenDetector(RegexBasedDetector): + """Scans for Clojars API tokens.""" + + @property + def secret_type(self) -> str: + return "Clojars API token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Clojars API token + re.compile(r"(?i)(CLOJARS_)[a-z0-9]{60}"), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/codecov_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/codecov_access_token.py new file mode 100644 index 0000000000..51001675f0 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/codecov_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Codecov Access Token +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class CodecovAccessTokenDetector(RegexBasedDetector): + """Scans for Codecov Access Token.""" + + @property + def secret_type(self) -> str: + return "Codecov Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Codecov Access Token + re.compile( + r"""(?i)(?:codecov)(?:[0-9a-z\-_\t 
.]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/coinbase_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/coinbase_access_token.py new file mode 100644 index 0000000000..0af631be99 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/coinbase_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Coinbase Access Token +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class CoinbaseAccessTokenDetector(RegexBasedDetector): + """Scans for Coinbase Access Token.""" + + @property + def secret_type(self) -> str: + return "Coinbase Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Coinbase Access Token + re.compile( + r"""(?i)(?:coinbase)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9_-]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/confluent.py b/enterprise/enterprise_hooks/secrets_plugins/confluent.py new file mode 100644 index 0000000000..aefbd42b94 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/confluent.py @@ -0,0 +1,28 @@ +""" +This plugin searches for Confluent Access Token and Confluent Secret Key +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class ConfluentDetector(RegexBasedDetector): + """Scans for Confluent Access Token and Confluent Secret Key.""" + + @property + def secret_type(self) -> str: + return "Confluent Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # For Confluent Access Token + re.compile( + r"""(?i)(?:confluent)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{16})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # For Confluent Secret Key + re.compile( + r"""(?i)(?:confluent)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/contentful_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/contentful_api_token.py new file mode 100644 index 0000000000..33817dc4d8 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/contentful_api_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Contentful delivery API token. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class ContentfulApiTokenDetector(RegexBasedDetector): + """Scans for Contentful delivery API token.""" + + @property + def secret_type(self) -> str: + return "Contentful API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:contentful)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{43})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/databricks_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/databricks_api_token.py new file mode 100644 index 0000000000..9e47355b1c --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/databricks_api_token.py @@ -0,0 +1,21 @@ +""" +This plugin searches for Databricks API token. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class DatabricksApiTokenDetector(RegexBasedDetector): + """Scans for Databricks API token.""" + + @property + def secret_type(self) -> str: + return "Databricks API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile(r"""(?i)\b(dapi[a-h0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/datadog_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/datadog_access_token.py new file mode 100644 index 0000000000..bdb430d9bc --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/datadog_access_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Datadog Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class DatadogAccessTokenDetector(RegexBasedDetector): + """Scans for Datadog Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Datadog Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:datadog)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/defined_networking_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/defined_networking_api_token.py new file mode 100644 index 0000000000..b23cdb4543 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/defined_networking_api_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Defined Networking API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class DefinedNetworkingApiTokenDetector(RegexBasedDetector): + """Scans for Defined Networking API Tokens.""" + + @property + def secret_type(self) -> str: + return "Defined Networking API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:dnkey)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(dnkey-[a-z0-9=_\-]{26}-[a-z0-9=_\-]{52})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/digitalocean.py b/enterprise/enterprise_hooks/secrets_plugins/digitalocean.py new file mode 100644 index 0000000000..5ffc4f600e --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/digitalocean.py @@ -0,0 +1,26 @@ +""" +This plugin searches for DigitalOcean tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class DigitaloceanDetector(RegexBasedDetector): + """Scans for various DigitalOcean Tokens.""" + + @property + def secret_type(self) -> str: + return "DigitalOcean Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # OAuth Access Token + re.compile(r"""(?i)\b(doo_v1_[a-f0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), + # Personal Access Token + re.compile(r"""(?i)\b(dop_v1_[a-f0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), + # OAuth Refresh Token + re.compile(r"""(?i)\b(dor_v1_[a-f0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/discord.py b/enterprise/enterprise_hooks/secrets_plugins/discord.py new file mode 100644 index 0000000000..c51406b606 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/discord.py @@ -0,0 +1,32 @@ +""" +This plugin searches for Discord Client tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class DiscordDetector(RegexBasedDetector): + """Scans for various Discord Client Tokens.""" + + @property + def secret_type(self) -> str: + return "Discord Client Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Discord API key + re.compile( + r"""(?i)(?:discord)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Discord client ID + re.compile( + r"""(?i)(?:discord)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9]{18})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Discord client secret + re.compile( + r"""(?i)(?:discord)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/doppler_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/doppler_api_token.py new file mode 100644 index 0000000000..56c594fc1f --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/doppler_api_token.py @@ -0,0 +1,22 @@ +""" +This plugin searches for Doppler API tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class DopplerApiTokenDetector(RegexBasedDetector): + """Scans for Doppler API Tokens.""" + + @property + def secret_type(self) -> str: + return "Doppler API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Doppler API token + re.compile(r"""(?i)dp\.pt\.[a-z0-9]{43}"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/droneci_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/droneci_access_token.py new file mode 100644 index 0000000000..8afffb8026 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/droneci_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Droneci Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class DroneciAccessTokenDetector(RegexBasedDetector): + """Scans for Droneci Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Droneci Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Droneci Access Token + re.compile( + r"""(?i)(?:droneci)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/dropbox.py b/enterprise/enterprise_hooks/secrets_plugins/dropbox.py new file mode 100644 index 0000000000..b19815b26d --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/dropbox.py @@ -0,0 +1,32 @@ +""" +This plugin searches for Dropbox tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class DropboxDetector(RegexBasedDetector): + """Scans for various Dropbox Tokens.""" + + @property + def secret_type(self) -> str: + return "Dropbox Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Dropbox API secret + re.compile( + r"""(?i)(?:dropbox)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{15})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Dropbox long-lived API token + re.compile( + r"""(?i)(?:dropbox)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{11}(AAAAAAAAAA)[a-z0-9\-_=]{43})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Dropbox short-lived API token + re.compile( + r"""(?i)(?:dropbox)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(sl\.[a-z0-9\-=_]{135})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/duffel_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/duffel_api_token.py new file mode 100644 index 0000000000..aab681598c --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/duffel_api_token.py @@ -0,0 +1,22 @@ +""" +This plugin searches for Duffel API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class DuffelApiTokenDetector(RegexBasedDetector): + """Scans for Duffel API Tokens.""" + + @property + def secret_type(self) -> str: + return "Duffel API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Duffel API Token + re.compile(r"""(?i)duffel_(test|live)_[a-z0-9_\-=]{43}"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/dynatrace_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/dynatrace_api_token.py new file mode 100644 index 0000000000..caf7dd7197 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/dynatrace_api_token.py @@ -0,0 +1,22 @@ +""" +This plugin searches for Dynatrace API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class DynatraceApiTokenDetector(RegexBasedDetector): + """Scans for Dynatrace API Tokens.""" + + @property + def secret_type(self) -> str: + return "Dynatrace API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Dynatrace API Token + re.compile(r"""(?i)dt0c01\.[a-z0-9]{24}\.[a-z0-9]{64}"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/easypost.py b/enterprise/enterprise_hooks/secrets_plugins/easypost.py new file mode 100644 index 0000000000..73d27cb491 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/easypost.py @@ -0,0 +1,24 @@ +""" +This plugin searches for EasyPost tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class EasyPostDetector(RegexBasedDetector): + """Scans for various EasyPost Tokens.""" + + @property + def secret_type(self) -> str: + return "EasyPost Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # EasyPost API token + re.compile(r"""(?i)\bEZAK[a-z0-9]{54}"""), + # EasyPost test API token + re.compile(r"""(?i)\bEZTK[a-z0-9]{54}"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/etsy_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/etsy_access_token.py new file mode 100644 index 0000000000..1775a4b41d --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/etsy_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Etsy Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class EtsyAccessTokenDetector(RegexBasedDetector): + """Scans for Etsy Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Etsy Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Etsy Access Token + re.compile( + r"""(?i)(?:etsy)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{24})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/facebook_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/facebook_access_token.py new file mode 100644 index 0000000000..edc7d080c6 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/facebook_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Facebook Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class FacebookAccessTokenDetector(RegexBasedDetector): + """Scans for Facebook Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Facebook Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Facebook Access Token + re.compile( + r"""(?i)(?:facebook)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/fastly_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/fastly_api_token.py new file mode 100644 index 0000000000..4d451cb746 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/fastly_api_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Fastly API keys. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class FastlyApiKeyDetector(RegexBasedDetector): + """Scans for Fastly API keys.""" + + @property + def secret_type(self) -> str: + return "Fastly API Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Fastly API key + re.compile( + r"""(?i)(?:fastly)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/finicity.py b/enterprise/enterprise_hooks/secrets_plugins/finicity.py new file mode 100644 index 0000000000..97414352fc --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/finicity.py @@ -0,0 +1,28 @@ +""" +This plugin searches for Finicity API tokens and Client Secrets. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class FinicityDetector(RegexBasedDetector): + """Scans for Finicity API tokens and Client Secrets.""" + + @property + def secret_type(self) -> str: + return "Finicity Credentials" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Finicity API token + re.compile( + r"""(?i)(?:finicity)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Finicity Client Secret + re.compile( + r"""(?i)(?:finicity)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{20})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/finnhub_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/finnhub_access_token.py new file mode 100644 index 0000000000..eeb09682b0 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/finnhub_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Finnhub Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class FinnhubAccessTokenDetector(RegexBasedDetector): + """Scans for Finnhub Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Finnhub Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Finnhub Access Token + re.compile( + r"""(?i)(?:finnhub)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{20})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/flickr_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/flickr_access_token.py new file mode 100644 index 0000000000..530628547b --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/flickr_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Flickr Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class FlickrAccessTokenDetector(RegexBasedDetector): + """Scans for Flickr Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Flickr Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Flickr Access Token + re.compile( + r"""(?i)(?:flickr)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/flutterwave.py b/enterprise/enterprise_hooks/secrets_plugins/flutterwave.py new file mode 100644 index 0000000000..fc46ba2222 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/flutterwave.py @@ -0,0 +1,26 @@ +""" +This plugin searches for Flutterwave API keys. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class FlutterwaveDetector(RegexBasedDetector): + """Scans for Flutterwave API Keys.""" + + @property + def secret_type(self) -> str: + return "Flutterwave API Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Flutterwave Encryption Key + re.compile(r"""(?i)FLWSECK_TEST-[a-h0-9]{12}"""), + # Flutterwave Public Key + re.compile(r"""(?i)FLWPUBK_TEST-[a-h0-9]{32}-X"""), + # Flutterwave Secret Key + re.compile(r"""(?i)FLWSECK_TEST-[a-h0-9]{32}-X"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/frameio_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/frameio_api_token.py new file mode 100644 index 0000000000..9524e873d4 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/frameio_api_token.py @@ -0,0 +1,22 @@ +""" +This plugin searches for Frame.io API tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class FrameIoApiTokenDetector(RegexBasedDetector): + """Scans for Frame.io API Tokens.""" + + @property + def secret_type(self) -> str: + return "Frame.io API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Frame.io API token + re.compile(r"""(?i)fio-u-[a-z0-9\-_=]{64}"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/freshbooks_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/freshbooks_access_token.py new file mode 100644 index 0000000000..b6b16e2b83 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/freshbooks_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Freshbooks Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class FreshbooksAccessTokenDetector(RegexBasedDetector): + """Scans for Freshbooks Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Freshbooks Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Freshbooks Access Token + re.compile( + r"""(?i)(?:freshbooks)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/gcp_api_key.py b/enterprise/enterprise_hooks/secrets_plugins/gcp_api_key.py new file mode 100644 index 0000000000..6055cc2622 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/gcp_api_key.py @@ -0,0 +1,24 @@ +""" +This plugin searches for GCP API keys. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class GCPApiKeyDetector(RegexBasedDetector): + """Scans for GCP API keys.""" + + @property + def secret_type(self) -> str: + return "GCP API Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # GCP API Key + re.compile( + r"""(?i)\b(AIza[0-9A-Za-z\\-_]{35})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/github_token.py b/enterprise/enterprise_hooks/secrets_plugins/github_token.py new file mode 100644 index 0000000000..acb5e3fc76 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/github_token.py @@ -0,0 +1,26 @@ +""" +This plugin searches for GitHub tokens +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class GitHubTokenCustomDetector(RegexBasedDetector): + """Scans for GitHub tokens.""" + + @property + def secret_type(self) -> str: + return "GitHub Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # GitHub App/Personal Access/OAuth Access/Refresh Token + # ref. https://github.blog/2021-04-05-behind-githubs-new-authentication-token-formats/ + re.compile(r"(?:ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9_]{36}"), + # GitHub Fine-Grained Personal Access Token + re.compile(r"github_pat_[0-9a-zA-Z_]{82}"), + re.compile(r"gho_[0-9a-zA-Z]{36}"), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/gitlab.py b/enterprise/enterprise_hooks/secrets_plugins/gitlab.py new file mode 100644 index 0000000000..2277d8a2d3 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/gitlab.py @@ -0,0 +1,26 @@ +""" +This plugin searches for GitLab secrets. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class GitLabDetector(RegexBasedDetector): + """Scans for GitLab Secrets.""" + + @property + def secret_type(self) -> str: + return "GitLab Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # GitLab Personal Access Token + re.compile(r"""glpat-[0-9a-zA-Z\-\_]{20}"""), + # GitLab Pipeline Trigger Token + re.compile(r"""glptt-[0-9a-f]{40}"""), + # GitLab Runner Registration Token + re.compile(r"""GR1348941[0-9a-zA-Z\-\_]{20}"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/gitter_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/gitter_access_token.py new file mode 100644 index 0000000000..1febe70cb9 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/gitter_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Gitter Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class GitterAccessTokenDetector(RegexBasedDetector): + """Scans for Gitter Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Gitter Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Gitter Access Token + re.compile( + r"""(?i)(?:gitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9_-]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/gocardless_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/gocardless_api_token.py new file mode 100644 index 0000000000..240f6e4c58 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/gocardless_api_token.py @@ -0,0 +1,25 @@ +""" +This plugin searches for GoCardless API tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class GoCardlessApiTokenDetector(RegexBasedDetector): + """Scans for GoCardless API Tokens.""" + + @property + def secret_type(self) -> str: + return "GoCardless API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # GoCardless API token + re.compile( + r"""(?:gocardless)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(live_[a-z0-9\-_=]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""", + re.IGNORECASE, + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/grafana.py b/enterprise/enterprise_hooks/secrets_plugins/grafana.py new file mode 100644 index 0000000000..fd37f0f639 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/grafana.py @@ -0,0 +1,32 @@ +""" +This plugin searches for Grafana secrets. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class GrafanaDetector(RegexBasedDetector): + """Scans for Grafana Secrets.""" + + @property + def secret_type(self) -> str: + return "Grafana Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Grafana API key or Grafana Cloud API key + re.compile( + r"""(?i)\b(eyJrIjoi[A-Za-z0-9]{70,400}={0,2})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Grafana Cloud API token + re.compile( + r"""(?i)\b(glc_[A-Za-z0-9+/]{32,400}={0,2})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Grafana Service Account token + re.compile( + r"""(?i)\b(glsa_[A-Za-z0-9]{32}_[A-Fa-f0-9]{8})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/hashicorp_tf_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/hashicorp_tf_api_token.py new file mode 100644 index 0000000000..97013fd846 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/hashicorp_tf_api_token.py @@ -0,0 +1,22 @@ +""" +This plugin searches for HashiCorp Terraform user/org API tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class HashiCorpTFApiTokenDetector(RegexBasedDetector): + """Scans for HashiCorp Terraform User/Org API Tokens.""" + + @property + def secret_type(self) -> str: + return "HashiCorp Terraform API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # HashiCorp Terraform user/org API token + re.compile(r"""(?i)[a-z0-9]{14}\.atlasv1\.[a-z0-9\-_=]{60,70}"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/heroku_api_key.py b/enterprise/enterprise_hooks/secrets_plugins/heroku_api_key.py new file mode 100644 index 0000000000..53be8aa486 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/heroku_api_key.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Heroku API Keys. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class HerokuApiKeyDetector(RegexBasedDetector): + """Scans for Heroku API Keys.""" + + @property + def secret_type(self) -> str: + return "Heroku API Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:heroku)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/hubspot_api_key.py b/enterprise/enterprise_hooks/secrets_plugins/hubspot_api_key.py new file mode 100644 index 0000000000..230ef659ba --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/hubspot_api_key.py @@ -0,0 +1,24 @@ +""" +This plugin searches for HubSpot API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class HubSpotApiTokenDetector(RegexBasedDetector): + """Scans for HubSpot API Tokens.""" + + @property + def secret_type(self) -> str: + return "HubSpot API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # HubSpot API Token + re.compile( + r"""(?i)(?:hubspot)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9A-F]{8}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/huggingface.py b/enterprise/enterprise_hooks/secrets_plugins/huggingface.py new file mode 100644 index 0000000000..be83a3a0d5 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/huggingface.py @@ -0,0 +1,26 @@ +""" +This plugin searches for Hugging Face Access and Organization API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class HuggingFaceDetector(RegexBasedDetector): + """Scans for Hugging Face Tokens.""" + + @property + def secret_type(self) -> str: + return "Hugging Face Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Hugging Face Access token + re.compile(r"""(?:^|[\\'"` >=:])(hf_[a-zA-Z]{34})(?:$|[\\'"` <])"""), + # Hugging Face Organization API token + re.compile( + r"""(?:^|[\\'"` >=:\(,)])(api_org_[a-zA-Z]{34})(?:$|[\\'"` <\),])""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/intercom_api_key.py b/enterprise/enterprise_hooks/secrets_plugins/intercom_api_key.py new file mode 100644 index 0000000000..24e16fc73a --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/intercom_api_key.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Intercom API Tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class IntercomApiTokenDetector(RegexBasedDetector): + """Scans for Intercom API Tokens.""" + + @property + def secret_type(self) -> str: + return "Intercom API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:intercom)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{60})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/jfrog.py b/enterprise/enterprise_hooks/secrets_plugins/jfrog.py new file mode 100644 index 0000000000..3eabbfe3a4 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/jfrog.py @@ -0,0 +1,28 @@ +""" +This plugin searches for JFrog-related secrets like API Key and Identity Token. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class JFrogDetector(RegexBasedDetector): + """Scans for JFrog-related secrets.""" + + @property + def secret_type(self) -> str: + return "JFrog Secrets" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # JFrog API Key + re.compile( + r"""(?i)(?:jfrog|artifactory|bintray|xray)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{73})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # JFrog Identity Token + re.compile( + r"""(?i)(?:jfrog|artifactory|bintray|xray)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/jwt.py b/enterprise/enterprise_hooks/secrets_plugins/jwt.py new file mode 100644 index 0000000000..6658a09502 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/jwt.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Base64-encoded JSON Web Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class JWTBase64Detector(RegexBasedDetector): + """Scans for Base64-encoded JSON Web Tokens.""" + + @property + def secret_type(self) -> str: + return "Base64-encoded JSON Web Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Base64-encoded JSON Web Token + re.compile( + r"""\bZXlK(?:(?PaGJHY2lPaU)|(?PaGNIVWlPaU)|(?PaGNIWWlPaU)|(?PaGRXUWlPaU)|(?PaU5qUWlP)|(?PamNtbDBJanBi)|(?PamRIa2lPaU)|(?PbGNHc2lPbn)|(?PbGJtTWlPaU)|(?PcWEzVWlPaU)|(?PcWQyc2lPb)|(?PcGMzTWlPaU)|(?PcGRpSTZJ)|(?PcmFXUWlP)|(?PclpYbGZiM0J6SWpwY)|(?PcmRIa2lPaUp)|(?PdWIyNWpaU0k2)|(?Pd01tTWlP)|(?Pd01uTWlPaU)|(?Pd2NIUWlPaU)|(?PemRXSWlPaU)|(?PemRuUWlP)|(?PMFlXY2lPaU)|(?PMGVYQWlPaUp)|(?PMWNtd2l)|(?PMWMyVWlPaUp)|(?PMlpYSWlPaU)|(?PMlpYSnphVzl1SWpv)|(?PNElqb2)|(?PNE5XTWlP)|(?PNE5YUWlPaU)|(?PNE5YUWpVekkxTmlJNkl)|(?PNE5YVWlPaU)|(?PNmFYQWlPaU))[a-zA-Z0-9\/\\_+\-\r\n]{40,}={0,2}""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/kraken_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/kraken_access_token.py new file mode 100644 index 0000000000..cb7357cfd9 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/kraken_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Kraken Access Tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class KrakenAccessTokenDetector(RegexBasedDetector): + """Scans for Kraken Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Kraken Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Kraken Access Token + re.compile( + r"""(?i)(?:kraken)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9\/=_\+\-]{80,90})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/kucoin.py b/enterprise/enterprise_hooks/secrets_plugins/kucoin.py new file mode 100644 index 0000000000..02e990bd8b --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/kucoin.py @@ -0,0 +1,28 @@ +""" +This plugin searches for Kucoin Access Tokens and Secret Keys. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class KucoinDetector(RegexBasedDetector): + """Scans for Kucoin Access Tokens and Secret Keys.""" + + @property + def secret_type(self) -> str: + return "Kucoin Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Kucoin Access Token + re.compile( + r"""(?i)(?:kucoin)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{24})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Kucoin Secret Key + re.compile( + r"""(?i)(?:kucoin)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/launchdarkly_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/launchdarkly_access_token.py new file mode 100644 index 0000000000..9779909847 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/launchdarkly_access_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Launchdarkly Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class LaunchdarklyAccessTokenDetector(RegexBasedDetector): + """Scans for Launchdarkly Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Launchdarkly Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:launchdarkly)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/linear.py b/enterprise/enterprise_hooks/secrets_plugins/linear.py new file mode 100644 index 0000000000..1224b5ec46 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/linear.py @@ -0,0 +1,26 @@ +""" +This plugin searches for Linear API Tokens and Linear Client Secrets. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class LinearDetector(RegexBasedDetector): + """Scans for Linear secrets.""" + + @property + def secret_type(self) -> str: + return "Linear Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Linear API Token + re.compile(r"""(?i)lin_api_[a-z0-9]{40}"""), + # Linear Client Secret + re.compile( + r"""(?i)(?:linear)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/linkedin.py b/enterprise/enterprise_hooks/secrets_plugins/linkedin.py new file mode 100644 index 0000000000..53ff0c30aa --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/linkedin.py @@ -0,0 +1,28 @@ +""" +This plugin searches for LinkedIn Client IDs and LinkedIn Client secrets. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class LinkedInDetector(RegexBasedDetector): + """Scans for LinkedIn secrets.""" + + @property + def secret_type(self) -> str: + return "LinkedIn Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # LinkedIn Client ID + re.compile( + r"""(?i)(?:linkedin|linked-in)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{14})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # LinkedIn Client secret + re.compile( + r"""(?i)(?:linkedin|linked-in)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{16})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/lob.py b/enterprise/enterprise_hooks/secrets_plugins/lob.py new file mode 100644 index 0000000000..623ac4f1f9 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/lob.py @@ -0,0 +1,28 @@ +""" +This plugin searches for Lob API secrets and Lob Publishable API keys. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class LobDetector(RegexBasedDetector): + """Scans for Lob secrets.""" + + @property + def secret_type(self) -> str: + return "Lob Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Lob API Key + re.compile( + r"""(?i)(?:lob)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}((live|test)_[a-f0-9]{35})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Lob Publishable API Key + re.compile( + r"""(?i)(?:lob)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}((test|live)_pub_[a-f0-9]{31})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/mailgun.py b/enterprise/enterprise_hooks/secrets_plugins/mailgun.py new file mode 100644 index 0000000000..c403d24546 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/mailgun.py @@ -0,0 +1,32 @@ +""" +This plugin searches for Mailgun API secrets, public validation keys, and webhook signing keys. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class MailgunDetector(RegexBasedDetector): + """Scans for Mailgun secrets.""" + + @property + def secret_type(self) -> str: + return "Mailgun Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Mailgun Private API Token + re.compile( + r"""(?i)(?:mailgun)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(key-[a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Mailgun Public Validation Key + re.compile( + r"""(?i)(?:mailgun)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(pubkey-[a-f0-9]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Mailgun Webhook Signing Key + re.compile( + r"""(?i)(?:mailgun)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-h0-9]{32}-[a-h0-9]{8}-[a-h0-9]{8})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/mapbox_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/mapbox_api_token.py new file mode 100644 index 0000000000..0326b7102a --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/mapbox_api_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for MapBox API tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class MapBoxApiTokenDetector(RegexBasedDetector): + """Scans for MapBox API tokens.""" + + @property + def secret_type(self) -> str: + return "MapBox API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # MapBox API Token + re.compile( + r"""(?i)(?:mapbox)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(pk\.[a-z0-9]{60}\.[a-z0-9]{22})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/mattermost_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/mattermost_access_token.py new file mode 100644 index 0000000000..d65b0e7554 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/mattermost_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Mattermost Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class MattermostAccessTokenDetector(RegexBasedDetector): + """Scans for Mattermost Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Mattermost Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Mattermost Access Token + re.compile( + r"""(?i)(?:mattermost)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{26})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/messagebird.py b/enterprise/enterprise_hooks/secrets_plugins/messagebird.py new file mode 100644 index 0000000000..6adc8317a8 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/messagebird.py @@ -0,0 +1,28 @@ +""" +This plugin searches for MessageBird API tokens and client IDs. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class MessageBirdDetector(RegexBasedDetector): + """Scans for MessageBird secrets.""" + + @property + def secret_type(self) -> str: + return "MessageBird Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # MessageBird API Token + re.compile( + r"""(?i)(?:messagebird|message-bird|message_bird)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{25})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # MessageBird Client ID + re.compile( + r"""(?i)(?:messagebird|message-bird|message_bird)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/microsoft_teams_webhook.py b/enterprise/enterprise_hooks/secrets_plugins/microsoft_teams_webhook.py new file mode 100644 index 0000000000..298fd81b0a --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/microsoft_teams_webhook.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Microsoft Teams Webhook URLs. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class MicrosoftTeamsWebhookDetector(RegexBasedDetector): + """Scans for Microsoft Teams Webhook URLs.""" + + @property + def secret_type(self) -> str: + return "Microsoft Teams Webhook" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Microsoft Teams Webhook + re.compile( + r"""https:\/\/[a-z0-9]+\.webhook\.office\.com\/webhookb2\/[a-z0-9]{8}-([a-z0-9]{4}-){3}[a-z0-9]{12}@[a-z0-9]{8}-([a-z0-9]{4}-){3}[a-z0-9]{12}\/IncomingWebhook\/[a-z0-9]{32}\/[a-z0-9]{8}-([a-z0-9]{4}-){3}[a-z0-9]{12}""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/netlify_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/netlify_access_token.py new file mode 100644 index 0000000000..cc7a575a42 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/netlify_access_token.py @@ -0,0 +1,24 @@ +""" +This plugin searches for Netlify Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class NetlifyAccessTokenDetector(RegexBasedDetector): + """Scans for Netlify Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Netlify Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Netlify Access Token + re.compile( + r"""(?i)(?:netlify)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{40,46})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/new_relic.py b/enterprise/enterprise_hooks/secrets_plugins/new_relic.py new file mode 100644 index 0000000000..cef640155c --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/new_relic.py @@ -0,0 +1,32 @@ +""" +This plugin searches for New Relic API tokens and keys. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class NewRelicDetector(RegexBasedDetector): + """Scans for New Relic API tokens and keys.""" + + @property + def secret_type(self) -> str: + return "New Relic API Secrets" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # New Relic ingest browser API token + re.compile( + r"""(?i)(?:new-relic|newrelic|new_relic)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(NRJS-[a-f0-9]{19})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # New Relic user API ID + re.compile( + r"""(?i)(?:new-relic|newrelic|new_relic)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # New Relic user API Key + re.compile( + r"""(?i)(?:new-relic|newrelic|new_relic)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(NRAK-[a-z0-9]{27})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/nytimes_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/nytimes_access_token.py new file mode 100644 index 0000000000..567b885e5a --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/nytimes_access_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for New York Times Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class NYTimesAccessTokenDetector(RegexBasedDetector): + """Scans for New York Times Access Tokens.""" + + @property + def secret_type(self) -> str: + return "New York Times Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:nytimes|new-york-times,|newyorktimes)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{32})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/okta_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/okta_access_token.py new file mode 100644 index 0000000000..97109767b0 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/okta_access_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Okta Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class OktaAccessTokenDetector(RegexBasedDetector): + """Scans for Okta Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Okta Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:okta)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9=_\-]{42})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/openai_api_key.py b/enterprise/enterprise_hooks/secrets_plugins/openai_api_key.py new file mode 100644 index 0000000000..c5d20f7590 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/openai_api_key.py @@ -0,0 +1,19 @@ +""" +This plugin searches for OpenAI API Keys. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class OpenAIApiKeyDetector(RegexBasedDetector): + """Scans for OpenAI API Keys.""" + + @property + def secret_type(self) -> str: + return "Strict OpenAI API Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [re.compile(r"""(sk-[a-zA-Z0-9]{5,})""")] diff --git a/enterprise/enterprise_hooks/secrets_plugins/planetscale.py b/enterprise/enterprise_hooks/secrets_plugins/planetscale.py new file mode 100644 index 0000000000..23a53667e3 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/planetscale.py @@ -0,0 +1,32 @@ +""" +This plugin searches for PlanetScale API tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class PlanetScaleDetector(RegexBasedDetector): + """Scans for PlanetScale API Tokens.""" + + @property + def secret_type(self) -> str: + return "PlanetScale API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # the PlanetScale API token + re.compile( + r"""(?i)\b(pscale_tkn_[a-z0-9=\-_\.]{32,64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # the PlanetScale OAuth token + re.compile( + r"""(?i)\b(pscale_oauth_[a-z0-9=\-_\.]{32,64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # the PlanetScale password + re.compile( + r"""(?i)\b(pscale_pw_[a-z0-9=\-_\.]{32,64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/postman_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/postman_api_token.py new file mode 100644 index 0000000000..9469e8191c --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/postman_api_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Postman API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class PostmanApiTokenDetector(RegexBasedDetector): + """Scans for Postman API Tokens.""" + + @property + def secret_type(self) -> str: + return "Postman API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)\b(PMAK-[a-f0-9]{24}-[a-f0-9]{34})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/prefect_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/prefect_api_token.py new file mode 100644 index 0000000000..35cdb71cae --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/prefect_api_token.py @@ -0,0 +1,19 @@ +""" +This plugin searches for Prefect API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class PrefectApiTokenDetector(RegexBasedDetector): + """Scans for Prefect API Tokens.""" + + @property + def secret_type(self) -> str: + return "Prefect API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [re.compile(r"""(?i)\b(pnu_[a-z0-9]{36})(?:['|\"|\n|\r|\s|\x60|;]|$)""")] diff --git a/enterprise/enterprise_hooks/secrets_plugins/pulumi_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/pulumi_api_token.py new file mode 100644 index 0000000000..bae4ce211b --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/pulumi_api_token.py @@ -0,0 +1,19 @@ +""" +This plugin searches for Pulumi API Tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class PulumiApiTokenDetector(RegexBasedDetector): + """Scans for Pulumi API Tokens.""" + + @property + def secret_type(self) -> str: + return "Pulumi API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [re.compile(r"""(?i)\b(pul-[a-f0-9]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""")] diff --git a/enterprise/enterprise_hooks/secrets_plugins/pypi_upload_token.py b/enterprise/enterprise_hooks/secrets_plugins/pypi_upload_token.py new file mode 100644 index 0000000000..d4cc913857 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/pypi_upload_token.py @@ -0,0 +1,19 @@ +""" +This plugin searches for PyPI Upload Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class PyPiUploadTokenDetector(RegexBasedDetector): + """Scans for PyPI Upload Tokens.""" + + @property + def secret_type(self) -> str: + return "PyPI Upload Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [re.compile(r"""pypi-AgEIcHlwaS5vcmc[A-Za-z0-9\-_]{50,1000}""")] diff --git a/enterprise/enterprise_hooks/secrets_plugins/rapidapi_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/rapidapi_access_token.py new file mode 100644 index 0000000000..18b2346148 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/rapidapi_access_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for RapidAPI Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class RapidApiAccessTokenDetector(RegexBasedDetector): + """Scans for RapidAPI Access Tokens.""" + + @property + def secret_type(self) -> str: + return "RapidAPI Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:rapidapi)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9_-]{50})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/readme_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/readme_api_token.py new file mode 100644 index 0000000000..47bdffb120 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/readme_api_token.py @@ -0,0 +1,21 @@ +""" +This plugin searches for Readme API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class ReadmeApiTokenDetector(RegexBasedDetector): + """Scans for Readme API Tokens.""" + + @property + def secret_type(self) -> str: + return "Readme API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile(r"""(?i)\b(rdme_[a-z0-9]{70})(?:['|\"|\n|\r|\s|\x60|;]|$)""") + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/rubygems_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/rubygems_api_token.py new file mode 100644 index 0000000000..d49c58e73e --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/rubygems_api_token.py @@ -0,0 +1,21 @@ +""" +This plugin searches for Rubygem API Tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class RubygemsApiTokenDetector(RegexBasedDetector): + """Scans for Rubygem API Tokens.""" + + @property + def secret_type(self) -> str: + return "Rubygem API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile(r"""(?i)\b(rubygems_[a-f0-9]{48})(?:['|\"|\n|\r|\s|\x60|;]|$)""") + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/scalingo_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/scalingo_api_token.py new file mode 100644 index 0000000000..3f8a59ee41 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/scalingo_api_token.py @@ -0,0 +1,19 @@ +""" +This plugin searches for Scalingo API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class ScalingoApiTokenDetector(RegexBasedDetector): + """Scans for Scalingo API Tokens.""" + + @property + def secret_type(self) -> str: + return "Scalingo API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [re.compile(r"""\btk-us-[a-zA-Z0-9-_]{48}\b""")] diff --git a/enterprise/enterprise_hooks/secrets_plugins/sendbird.py b/enterprise/enterprise_hooks/secrets_plugins/sendbird.py new file mode 100644 index 0000000000..4b270d71e5 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/sendbird.py @@ -0,0 +1,28 @@ +""" +This plugin searches for Sendbird Access IDs and Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class SendbirdDetector(RegexBasedDetector): + """Scans for Sendbird Access IDs and Tokens.""" + + @property + def secret_type(self) -> str: + return "Sendbird Credential" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Sendbird Access ID + re.compile( + r"""(?i)(?:sendbird)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Sendbird Access Token + re.compile( + r"""(?i)(?:sendbird)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/sendgrid_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/sendgrid_api_token.py new file mode 100644 index 0000000000..bf974f4fd7 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/sendgrid_api_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for SendGrid API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class SendGridApiTokenDetector(RegexBasedDetector): + """Scans for SendGrid API Tokens.""" + + @property + def secret_type(self) -> str: + return "SendGrid API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)\b(SG\.[a-z0-9=_\-\.]{66})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/sendinblue_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/sendinblue_api_token.py new file mode 100644 index 0000000000..a6ed8c15ee --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/sendinblue_api_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for SendinBlue API Tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class SendinBlueApiTokenDetector(RegexBasedDetector): + """Scans for SendinBlue API Tokens.""" + + @property + def secret_type(self) -> str: + return "SendinBlue API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)\b(xkeysib-[a-f0-9]{64}-[a-z0-9]{16})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/sentry_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/sentry_access_token.py new file mode 100644 index 0000000000..181fad2c7f --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/sentry_access_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Sentry Access Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class SentryAccessTokenDetector(RegexBasedDetector): + """Scans for Sentry Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Sentry Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:sentry)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-f0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/shippo_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/shippo_api_token.py new file mode 100644 index 0000000000..4314c68768 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/shippo_api_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Shippo API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class ShippoApiTokenDetector(RegexBasedDetector): + """Scans for Shippo API Tokens.""" + + @property + def secret_type(self) -> str: + return "Shippo API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)\b(shippo_(live|test)_[a-f0-9]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/shopify.py b/enterprise/enterprise_hooks/secrets_plugins/shopify.py new file mode 100644 index 0000000000..f5f97c4478 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/shopify.py @@ -0,0 +1,31 @@ +""" +This plugin searches for Shopify Access Tokens, Custom Access Tokens, +Private App Access Tokens, and Shared Secrets. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class ShopifyDetector(RegexBasedDetector): + """Scans for Shopify Access Tokens, Custom Access Tokens, Private App Access Tokens, + and Shared Secrets. + """ + + @property + def secret_type(self) -> str: + return "Shopify Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Shopify access token + re.compile(r"""shpat_[a-fA-F0-9]{32}"""), + # Shopify custom access token + re.compile(r"""shpca_[a-fA-F0-9]{32}"""), + # Shopify private app access token + re.compile(r"""shppa_[a-fA-F0-9]{32}"""), + # Shopify shared secret + re.compile(r"""shpss_[a-fA-F0-9]{32}"""), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/slack.py b/enterprise/enterprise_hooks/secrets_plugins/slack.py new file mode 100644 index 0000000000..4896fd76b2 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/slack.py @@ -0,0 +1,38 @@ +""" +This plugin searches for Slack tokens and webhooks. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class SlackDetector(RegexBasedDetector): + """Scans for Slack tokens and webhooks.""" + + @property + def secret_type(self) -> str: + return "Slack Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Slack App-level token + re.compile(r"""(?i)(xapp-\d-[A-Z0-9]+-\d+-[a-z0-9]+)"""), + # Slack Bot token + re.compile(r"""(xoxb-[0-9]{10,13}\-[0-9]{10,13}[a-zA-Z0-9-]*)"""), + # Slack Configuration access token and refresh token + re.compile(r"""(?i)(xoxe.xox[bp]-\d-[A-Z0-9]{163,166})"""), + re.compile(r"""(?i)(xoxe-\d-[A-Z0-9]{146})"""), + # Slack Legacy bot token and token + re.compile(r"""(xoxb-[0-9]{8,14}\-[a-zA-Z0-9]{18,26})"""), + re.compile(r"""(xox[os]-\d+-\d+-\d+-[a-fA-F\d]+)"""), + # Slack Legacy Workspace token + re.compile(r"""(xox[ar]-(?:\d-)?[0-9a-zA-Z]{8,48})"""), + # Slack User token and enterprise token + re.compile(r"""(xox[pe](?:-[0-9]{10,13}){3}-[a-zA-Z0-9-]{28,34})"""), + # Slack Webhook URL + re.compile( + r"""(https?:\/\/)?hooks.slack.com\/(services|workflows)\/[A-Za-z0-9+\/]{43,46}""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/snyk_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/snyk_api_token.py new file mode 100644 index 0000000000..839bb57317 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/snyk_api_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Snyk API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class SnykApiTokenDetector(RegexBasedDetector): + """Scans for Snyk API Tokens.""" + + @property + def secret_type(self) -> str: + return "Snyk API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:snyk)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/squarespace_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/squarespace_access_token.py new file mode 100644 index 0000000000..0dc83ad91d --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/squarespace_access_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Squarespace Access Tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class SquarespaceAccessTokenDetector(RegexBasedDetector): + """Scans for Squarespace Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Squarespace Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:squarespace)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/sumologic.py b/enterprise/enterprise_hooks/secrets_plugins/sumologic.py new file mode 100644 index 0000000000..7117629acc --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/sumologic.py @@ -0,0 +1,22 @@ +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class SumoLogicDetector(RegexBasedDetector): + """Scans for SumoLogic Access ID and Access Token.""" + + @property + def secret_type(self) -> str: + return "SumoLogic" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i:(?:sumo)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3})(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(su[a-zA-Z0-9]{12})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + re.compile( + r"""(?i)(?:sumo)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{64})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/telegram_bot_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/telegram_bot_api_token.py new file mode 100644 index 0000000000..30854fda1d --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/telegram_bot_api_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Telegram Bot API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class TelegramBotApiTokenDetector(RegexBasedDetector): + """Scans for Telegram Bot API Tokens.""" + + @property + def secret_type(self) -> str: + return "Telegram Bot API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:^|[^0-9])([0-9]{5,16}:A[a-zA-Z0-9_\-]{34})(?:$|[^a-zA-Z0-9_\-])""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/travisci_access_token.py b/enterprise/enterprise_hooks/secrets_plugins/travisci_access_token.py new file mode 100644 index 0000000000..90f9b48f46 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/travisci_access_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Travis CI Access Tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class TravisCiAccessTokenDetector(RegexBasedDetector): + """Scans for Travis CI Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Travis CI Access Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:travis)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{22})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/twitch_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/twitch_api_token.py new file mode 100644 index 0000000000..1e0e3ccf8f --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/twitch_api_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Twitch API Tokens. +""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class TwitchApiTokenDetector(RegexBasedDetector): + """Scans for Twitch API Tokens.""" + + @property + def secret_type(self) -> str: + return "Twitch API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:twitch)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{30})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/twitter.py b/enterprise/enterprise_hooks/secrets_plugins/twitter.py new file mode 100644 index 0000000000..99ad170d1e --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/twitter.py @@ -0,0 +1,36 @@ +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class TwitterDetector(RegexBasedDetector): + """Scans for Twitter Access Secrets, Access Tokens, API Keys, API Secrets, and Bearer Tokens.""" + + @property + def secret_type(self) -> str: + return "Twitter Secret" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Twitter Access Secret + re.compile( + r"""(?i)(?:twitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{45})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Twitter Access Token + re.compile( + r"""(?i)(?:twitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([0-9]{15,25}-[a-zA-Z0-9]{20,40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Twitter API Key + re.compile( + r"""(?i)(?:twitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{25})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Twitter API Secret + re.compile( + r"""(?i)(?:twitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{50})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Twitter Bearer Token + re.compile( + r"""(?i)(?:twitter)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(A{22}[a-zA-Z0-9%]{80,100})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/typeform_api_token.py b/enterprise/enterprise_hooks/secrets_plugins/typeform_api_token.py new file mode 100644 index 0000000000..8d9dc0e875 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/typeform_api_token.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Typeform API Tokens. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class TypeformApiTokenDetector(RegexBasedDetector): + """Scans for Typeform API Tokens.""" + + @property + def secret_type(self) -> str: + return "Typeform API Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:typeform)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(tfp_[a-z0-9\-_\.=]{59})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/vault.py b/enterprise/enterprise_hooks/secrets_plugins/vault.py new file mode 100644 index 0000000000..5ca552cd9e --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/vault.py @@ -0,0 +1,24 @@ +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class VaultDetector(RegexBasedDetector): + """Scans for Vault Batch Tokens and Vault Service Tokens.""" + + @property + def secret_type(self) -> str: + return "Vault Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Vault Batch Token + re.compile( + r"""(?i)\b(hvb\.[a-z0-9_-]{138,212})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Vault Service Token + re.compile( + r"""(?i)\b(hvs\.[a-z0-9_-]{90,100})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/yandex.py b/enterprise/enterprise_hooks/secrets_plugins/yandex.py new file mode 100644 index 0000000000..a58faec0d1 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/yandex.py @@ -0,0 +1,28 @@ +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class YandexDetector(RegexBasedDetector): + """Scans for Yandex Access Tokens, API Keys, and AWS Access Tokens.""" + + @property + def secret_type(self) -> str: + return "Yandex Token" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + # Yandex Access Token + re.compile( + r"""(?i)(?:yandex)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(t1\.[A-Z0-9a-z_-]+[=]{0,2}\.[A-Z0-9a-z_-]{86}[=]{0,2})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Yandex API Key + re.compile( + r"""(?i)(?:yandex)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(AQVN[A-Za-z0-9_\-]{35,38})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + # Yandex AWS Access Token + re.compile( + r"""(?i)(?:yandex)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}(YC[a-zA-Z0-9_\-]{38})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ), + ] diff --git a/enterprise/enterprise_hooks/secrets_plugins/zendesk_secret_key.py b/enterprise/enterprise_hooks/secrets_plugins/zendesk_secret_key.py new file mode 100644 index 0000000000..42c087c5b6 --- /dev/null +++ b/enterprise/enterprise_hooks/secrets_plugins/zendesk_secret_key.py @@ -0,0 +1,23 @@ +""" +This plugin searches for Zendesk Secret Keys. 
+""" + +import re + +from detect_secrets.plugins.base import RegexBasedDetector + + +class ZendeskSecretKeyDetector(RegexBasedDetector): + """Scans for Zendesk Secret Keys.""" + + @property + def secret_type(self) -> str: + return "Zendesk Secret Key" + + @property + def denylist(self) -> list[re.Pattern]: + return [ + re.compile( + r"""(?i)(?:zendesk)(?:[0-9a-z\-_\t .]{0,20})(?:[\s|']|[\s|"]){0,3}(?:=|>|:{1,3}=|\|\|:|<=|=>|:|\?=)(?:'|\"|\s|=|\x60){0,5}([a-z0-9]{40})(?:['|\"|\n|\r|\s|\x60|;]|$)""" + ) + ] diff --git a/entrypoint.sh b/entrypoint.sh index 80adf8d077..a028e54262 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,48 +1,13 @@ -#!/bin/sh +#!/bin/bash +echo $(pwd) -# Check if DATABASE_URL is not set -if [ -z "$DATABASE_URL" ]; then - # Check if all required variables are provided - if [ -n "$DATABASE_HOST" ] && [ -n "$DATABASE_USERNAME" ] && [ -n "$DATABASE_PASSWORD" ] && [ -n "$DATABASE_NAME" ]; then - # Construct DATABASE_URL from the provided variables - DATABASE_URL="postgresql://${DATABASE_USERNAME}:${DATABASE_PASSWORD}@${DATABASE_HOST}/${DATABASE_NAME}" - export DATABASE_URL - else - echo "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL." - exit 1 - fi -fi +# Run the Python migration script +python3 litellm/proxy/prisma_migration.py -# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations -if [ -z "$DIRECT_URL" ]; then - export DIRECT_URL=$DATABASE_URL -fi - -# Apply migrations -retry_count=0 -max_retries=3 -exit_code=1 - -until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ] -do - retry_count=$((retry_count+1)) - echo "Attempt $retry_count..." - - # Run the Prisma db push command - prisma db push --accept-data-loss - - exit_code=$? - - if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then - echo "Retrying in 10 seconds..." - sleep 10 - fi -done - -if [ $exit_code -ne 0 ]; then - echo "Unable to push database changes after $max_retries retries." +# Check if the Python script executed successfully +if [ $? -eq 0 ]; then + echo "Migration script ran successfully!" +else + echo "Migration script failed!" exit 1 fi - -echo "Database push successful!" 
- diff --git a/litellm/__init__.py b/litellm/__init__.py index cee80a32df..a9e6b69ae6 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -125,6 +125,9 @@ llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" ################## ### PREVIEW FEATURES ### enable_preview_features: bool = False +return_response_headers: bool = ( + False # get response headers from LLM API providers - example: x-ratelimit-remaining-requests +) ################## logging: bool = True caching: bool = ( @@ -413,6 +416,7 @@ openai_compatible_providers: List = [ "mistral", "groq", "nvidia_nim", + "volcengine", "codestral", "deepseek", "deepinfra", @@ -643,6 +647,7 @@ provider_list: List = [ "mistral", "groq", "nvidia_nim", + "volcengine", "codestral", "text-completion-codestral", "deepseek", @@ -736,6 +741,7 @@ openai_image_generation_models = ["dall-e-2", "dall-e-3"] from .timeout import timeout from .cost_calculator import completion_cost from litellm.litellm_core_utils.litellm_logging import Logging +from litellm.litellm_core_utils.token_counter import get_modified_max_tokens from .utils import ( client, exception_type, @@ -746,6 +752,7 @@ from .utils import ( create_pretrained_tokenizer, create_tokenizer, supports_function_calling, + supports_response_schema, supports_parallel_function_calling, supports_vision, supports_system_messages, @@ -796,7 +803,11 @@ from .llms.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig from .llms.maritalk import MaritTalkConfig -from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig +from .llms.bedrock_httpx import ( + AmazonCohereChatConfig, + AmazonConverseConfig, + BEDROCK_CONVERSE_MODELS, +) from .llms.bedrock import ( AmazonTitanConfig, AmazonAI21Config, @@ -818,6 +829,7 @@ from .llms.openai import ( ) from .llms.nvidia_nim import NvidiaNimConfig from .llms.fireworks_ai import FireworksAIConfig +from .llms.volcengine import VolcEngineConfig from .llms.text_completion_codestral import MistralTextCompletionConfig from .llms.azure import ( AzureOpenAIConfig, @@ -844,6 +856,7 @@ from .exceptions import ( APIResponseValidationError, UnprocessableEntityError, InternalServerError, + JSONSchemaValidationError, LITELLM_EXCEPTION_TYPES, ) from .budget_manager import BudgetManager diff --git a/litellm/caching.py b/litellm/caching.py index 95cad01cfd..64488289a8 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -64,16 +64,55 @@ class BaseCache: class InMemoryCache(BaseCache): - def __init__(self): - # if users don't provider one, use the default litellm cache - self.cache_dict = {} - self.ttl_dict = {} + def __init__( + self, + max_size_in_memory: Optional[int] = 200, + default_ttl: Optional[ + int + ] = 600, # default ttl is 10 minutes; litellm's rate limiting logic needs objects in memory for at most 1 minute + ): + """ + max_size_in_memory [int]: Maximum number of items in the cache. Done to prevent memory leaks. Defaults to 200 items. + """ + self.max_size_in_memory = ( + max_size_in_memory or 200 + ) # set an upper bound of 200 items in-memory + self.default_ttl = default_ttl or 600 + + # in-memory cache + self.cache_dict: dict = {} + self.ttl_dict: dict = {} + + def evict_cache(self): + """ + Eviction policy: + - check if any items in ttl_dict are expired -> remove them from ttl_dict and cache_dict + + + This guarantees the following: + - 1. When item ttl is not set: each item will remain in memory for at least the default ttl (10 minutes) + - 2.
When ttl is set: the item will remain in memory for at least that amount of time + - 3. the size of in-memory cache is bounded + + """ + for key in list(self.ttl_dict.keys()): + if time.time() > self.ttl_dict[key]: + self.cache_dict.pop(key, None) + self.ttl_dict.pop(key, None) def set_cache(self, key, value, **kwargs): - print_verbose("InMemoryCache: set_cache") + print_verbose( + "InMemoryCache: set_cache. current size= {}".format(len(self.cache_dict)) + ) + if len(self.cache_dict) >= self.max_size_in_memory: + # only evict when cache is full + self.evict_cache() + self.cache_dict[key] = value if "ttl" in kwargs: self.ttl_dict[key] = time.time() + kwargs["ttl"] + else: + self.ttl_dict[key] = time.time() + self.default_ttl async def async_set_cache(self, key, value, **kwargs): self.set_cache(key=key, value=value, **kwargs) @@ -139,6 +178,7 @@ class InMemoryCache(BaseCache): init_value = await self.async_get_cache(key=key) or 0 value = init_value + value await self.async_set_cache(key, value, **kwargs) + return value def flush_cache(self): diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index d61e812d07..062e98be97 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -1,6 +1,7 @@ # What is this? ## File for 'response_cost' calculation in Logging import time +import traceback from typing import List, Literal, Optional, Tuple, Union import litellm @@ -101,8 +102,12 @@ def cost_per_token( if custom_llm_provider is not None: model_with_provider = custom_llm_provider + "/" + model if region_name is not None: - model_with_provider_and_region = f"{custom_llm_provider}/{region_name}/{model}" - if model_with_provider_and_region in model_cost_ref: # use region based pricing, if it's available + model_with_provider_and_region = ( + f"{custom_llm_provider}/{region_name}/{model}" + ) + if ( + model_with_provider_and_region in model_cost_ref + ): # use region based pricing, if it's available model_with_provider = model_with_provider_and_region else: _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model) @@ -118,7 +123,9 @@ def cost_per_token( Option2. model = "openai/gpt-4" - model = provider/model Option3. model = "anthropic.claude-3" - model = model """ - if model_with_provider in model_cost_ref: # Option 2. use model with provider, model = "openai/gpt-4" + if ( + model_with_provider in model_cost_ref + ): # Option 2. use model with provider, model = "openai/gpt-4" model = model_with_provider elif model in model_cost_ref: # Option 1. 
use model passed, model="gpt-4" model = model @@ -154,29 +161,45 @@ def cost_per_token( ) elif model in model_cost_ref: print_verbose(f"Success: model={model} in model_cost_map") - print_verbose(f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}") + print_verbose( + f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}" + ) if ( model_cost_ref[model].get("input_cost_per_token", None) is not None and model_cost_ref[model].get("output_cost_per_token", None) is not None ): ## COST PER TOKEN ## - prompt_tokens_cost_usd_dollar = model_cost_ref[model]["input_cost_per_token"] * prompt_tokens - completion_tokens_cost_usd_dollar = model_cost_ref[model]["output_cost_per_token"] * completion_tokens - elif model_cost_ref[model].get("output_cost_per_second", None) is not None and response_time_ms is not None: + prompt_tokens_cost_usd_dollar = ( + model_cost_ref[model]["input_cost_per_token"] * prompt_tokens + ) + completion_tokens_cost_usd_dollar = ( + model_cost_ref[model]["output_cost_per_token"] * completion_tokens + ) + elif ( + model_cost_ref[model].get("output_cost_per_second", None) is not None + and response_time_ms is not None + ): print_verbose( f"For model={model} - output_cost_per_second: {model_cost_ref[model].get('output_cost_per_second')}; response time: {response_time_ms}" ) ## COST PER SECOND ## prompt_tokens_cost_usd_dollar = 0 completion_tokens_cost_usd_dollar = ( - model_cost_ref[model]["output_cost_per_second"] * response_time_ms / 1000 + model_cost_ref[model]["output_cost_per_second"] + * response_time_ms + / 1000 ) - elif model_cost_ref[model].get("input_cost_per_second", None) is not None and response_time_ms is not None: + elif ( + model_cost_ref[model].get("input_cost_per_second", None) is not None + and response_time_ms is not None + ): print_verbose( f"For model={model} - input_cost_per_second: {model_cost_ref[model].get('input_cost_per_second')}; response time: {response_time_ms}" ) ## COST PER SECOND ## - prompt_tokens_cost_usd_dollar = model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000 + prompt_tokens_cost_usd_dollar = ( + model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000 + ) completion_tokens_cost_usd_dollar = 0.0 print_verbose( f"Returned custom cost for model={model} - prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}, completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}" @@ -185,40 +208,57 @@ def cost_per_token( elif "ft:gpt-3.5-turbo" in model: print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM") # fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm - prompt_tokens_cost_usd_dollar = model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens + prompt_tokens_cost_usd_dollar = ( + model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens + ) completion_tokens_cost_usd_dollar = ( - model_cost_ref["ft:gpt-3.5-turbo"]["output_cost_per_token"] * completion_tokens + model_cost_ref["ft:gpt-3.5-turbo"]["output_cost_per_token"] + * completion_tokens ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif "ft:gpt-4-0613" in model: print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM") # fuzzy match ft:gpt-4-0613:abcd-id-cool-litellm - prompt_tokens_cost_usd_dollar = model_cost_ref["ft:gpt-4-0613"]["input_cost_per_token"] * prompt_tokens - completion_tokens_cost_usd_dollar = model_cost_ref["ft:gpt-4-0613"]["output_cost_per_token"] * completion_tokens + prompt_tokens_cost_usd_dollar = ( 
+ model_cost_ref["ft:gpt-4-0613"]["input_cost_per_token"] * prompt_tokens + ) + completion_tokens_cost_usd_dollar = ( + model_cost_ref["ft:gpt-4-0613"]["output_cost_per_token"] * completion_tokens + ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif "ft:gpt-4o-2024-05-13" in model: print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM") # fuzzy match ft:gpt-4o-2024-05-13:abcd-id-cool-litellm - prompt_tokens_cost_usd_dollar = model_cost_ref["ft:gpt-4o-2024-05-13"]["input_cost_per_token"] * prompt_tokens + prompt_tokens_cost_usd_dollar = ( + model_cost_ref["ft:gpt-4o-2024-05-13"]["input_cost_per_token"] + * prompt_tokens + ) completion_tokens_cost_usd_dollar = ( - model_cost_ref["ft:gpt-4o-2024-05-13"]["output_cost_per_token"] * completion_tokens + model_cost_ref["ft:gpt-4o-2024-05-13"]["output_cost_per_token"] + * completion_tokens ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif "ft:davinci-002" in model: print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM") # fuzzy match ft:davinci-002:abcd-id-cool-litellm - prompt_tokens_cost_usd_dollar = model_cost_ref["ft:davinci-002"]["input_cost_per_token"] * prompt_tokens + prompt_tokens_cost_usd_dollar = ( + model_cost_ref["ft:davinci-002"]["input_cost_per_token"] * prompt_tokens + ) completion_tokens_cost_usd_dollar = ( - model_cost_ref["ft:davinci-002"]["output_cost_per_token"] * completion_tokens + model_cost_ref["ft:davinci-002"]["output_cost_per_token"] + * completion_tokens ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif "ft:babbage-002" in model: print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM") # fuzzy match ft:babbage-002:abcd-id-cool-litellm - prompt_tokens_cost_usd_dollar = model_cost_ref["ft:babbage-002"]["input_cost_per_token"] * prompt_tokens + prompt_tokens_cost_usd_dollar = ( + model_cost_ref["ft:babbage-002"]["input_cost_per_token"] * prompt_tokens + ) completion_tokens_cost_usd_dollar = ( - model_cost_ref["ft:babbage-002"]["output_cost_per_token"] * completion_tokens + model_cost_ref["ft:babbage-002"]["output_cost_per_token"] + * completion_tokens ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif model in litellm.azure_llms: @@ -227,17 +267,25 @@ def cost_per_token( verbose_logger.debug( f"applying cost={model_cost_ref[model]['input_cost_per_token']} for prompt_tokens={prompt_tokens}" ) - prompt_tokens_cost_usd_dollar = model_cost_ref[model]["input_cost_per_token"] * prompt_tokens + prompt_tokens_cost_usd_dollar = ( + model_cost_ref[model]["input_cost_per_token"] * prompt_tokens + ) verbose_logger.debug( f"applying cost={model_cost_ref[model]['output_cost_per_token']} for completion_tokens={completion_tokens}" ) - completion_tokens_cost_usd_dollar = model_cost_ref[model]["output_cost_per_token"] * completion_tokens + completion_tokens_cost_usd_dollar = ( + model_cost_ref[model]["output_cost_per_token"] * completion_tokens + ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif model in litellm.azure_embedding_models: verbose_logger.debug(f"Cost Tracking: {model} is an Azure Embedding Model") model = litellm.azure_embedding_models[model] - prompt_tokens_cost_usd_dollar = model_cost_ref[model]["input_cost_per_token"] * prompt_tokens - completion_tokens_cost_usd_dollar = model_cost_ref[model]["output_cost_per_token"] * completion_tokens + prompt_tokens_cost_usd_dollar = ( + model_cost_ref[model]["input_cost_per_token"] * prompt_tokens + ) + 
completion_tokens_cost_usd_dollar = ( + model_cost_ref[model]["output_cost_per_token"] * completion_tokens + ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar else: # if model is not in model_prices_and_context_window.json. Raise an exception-let users know @@ -261,7 +309,9 @@ def get_model_params_and_category(model_name) -> str: import re model_name = model_name.lower() - re_params_match = re.search(r"(\d+b)", model_name) # catch all decimals like 3b, 70b, etc + re_params_match = re.search( + r"(\d+b)", model_name + ) # catch all decimals like 3b, 70b, etc category = None if re_params_match is not None: params_match = str(re_params_match.group(1)) @@ -292,7 +342,9 @@ def get_model_params_and_category(model_name) -> str: def get_replicate_completion_pricing(completion_response=None, total_time=0.0): # see https://replicate.com/pricing # for all litellm currently supported LLMs, almost all requests go to a100_80gb - a100_80gb_price_per_second_public = 0.001400 # assume all calls sent to A100 80GB for now + a100_80gb_price_per_second_public = ( + 0.001400 # assume all calls sent to A100 80GB for now + ) if total_time == 0.0: # total time is in ms start_time = completion_response["created"] end_time = getattr(completion_response, "ended", time.time()) @@ -377,13 +429,16 @@ def completion_cost( prompt_characters = 0 completion_tokens = 0 completion_characters = 0 - custom_llm_provider = None if completion_response is not None: # get input/output tokens from completion_response prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0) - completion_tokens = completion_response.get("usage", {}).get("completion_tokens", 0) + completion_tokens = completion_response.get("usage", {}).get( + "completion_tokens", 0 + ) total_time = completion_response.get("_response_ms", 0) - verbose_logger.debug(f"completion_response response ms: {completion_response.get('_response_ms')} ") + verbose_logger.debug( + f"completion_response response ms: {completion_response.get('_response_ms')} " + ) model = model or completion_response.get( "model", None ) # check if user passed an override for model, if it's none check completion_response['model'] @@ -393,16 +448,30 @@ def completion_cost( and len(completion_response._hidden_params["model"]) > 0 ): model = completion_response._hidden_params.get("model", model) - custom_llm_provider = completion_response._hidden_params.get("custom_llm_provider", "") - region_name = completion_response._hidden_params.get("region_name", region_name) - size = completion_response._hidden_params.get("optional_params", {}).get( + custom_llm_provider = completion_response._hidden_params.get( + "custom_llm_provider", "" + ) + region_name = completion_response._hidden_params.get( + "region_name", region_name + ) + size = completion_response._hidden_params.get( + "optional_params", {} + ).get( "size", "1024-x-1024" ) # openai default - quality = completion_response._hidden_params.get("optional_params", {}).get( + quality = completion_response._hidden_params.get( + "optional_params", {} + ).get( "quality", "standard" ) # openai default - n = completion_response._hidden_params.get("optional_params", {}).get("n", 1) # openai default + n = completion_response._hidden_params.get("optional_params", {}).get( + "n", 1 + ) # openai default else: + if model is None: + raise ValueError( + f"Model is None and does not exist in passed completion_response. 
Passed completion_response={completion_response}, model={model}" + ) if len(messages) > 0: prompt_tokens = token_counter(model=model, messages=messages) elif len(prompt) > 0: @@ -413,7 +482,19 @@ def completion_cost( f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}" ) - if call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value: + if custom_llm_provider is None: + try: + _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model) + except Exception as e: + verbose_logger.error( + "litellm.cost_calculator.py::completion_cost() - Error inferring custom_llm_provider - {}".format( + str(e) + ) + ) + if ( + call_type == CallTypes.image_generation.value + or call_type == CallTypes.aimage_generation.value + ): ### IMAGE GENERATION COST CALCULATION ### if custom_llm_provider == "vertex_ai": # https://cloud.google.com/vertex-ai/generative-ai/pricing @@ -431,23 +512,43 @@ def completion_cost( height = int(size[0]) # if it's 1024-x-1024 vs. 1024x1024 width = int(size[1]) verbose_logger.debug(f"image_gen_model_name: {image_gen_model_name}") - verbose_logger.debug(f"image_gen_model_name_with_quality: {image_gen_model_name_with_quality}") + verbose_logger.debug( + f"image_gen_model_name_with_quality: {image_gen_model_name_with_quality}" + ) if image_gen_model_name in litellm.model_cost: - return litellm.model_cost[image_gen_model_name]["input_cost_per_pixel"] * height * width * n + return ( + litellm.model_cost[image_gen_model_name]["input_cost_per_pixel"] + * height + * width + * n + ) elif image_gen_model_name_with_quality in litellm.model_cost: return ( - litellm.model_cost[image_gen_model_name_with_quality]["input_cost_per_pixel"] * height * width * n + litellm.model_cost[image_gen_model_name_with_quality][ + "input_cost_per_pixel" + ] + * height + * width + * n ) else: - raise Exception(f"Model={image_gen_model_name} not found in completion cost model map") + raise Exception( + f"Model={image_gen_model_name} not found in completion cost model map" + ) # Calculate cost based on prompt_tokens, completion_tokens - if "togethercomputer" in model or "together_ai" in model or custom_llm_provider == "together_ai": + if ( + "togethercomputer" in model + or "together_ai" in model + or custom_llm_provider == "together_ai" + ): # together ai prices based on size of llm # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json model = get_model_params_and_category(model) # replicate llms are calculate based on time for request running # see https://replicate.com/pricing - elif (model in litellm.replicate_models or "replicate" in model) and model not in litellm.model_cost: + elif ( + model in litellm.replicate_models or "replicate" in model + ) and model not in litellm.model_cost: # for unmapped replicate model, default to replicate's time tracking logic return get_replicate_completion_pricing(completion_response, total_time) @@ -456,23 +557,25 @@ def completion_cost( f"Model is None and does not exist in passed completion_response. 
Passed completion_response={completion_response}, model={model}" ) - if ( - custom_llm_provider is not None - and custom_llm_provider == "vertex_ai" - and completion_response is not None - and isinstance(completion_response, ModelResponse) - ): + if custom_llm_provider is not None and custom_llm_provider == "vertex_ai": # Calculate the prompt characters + response characters if len("messages") > 0: - prompt_string = litellm.utils.get_formatted_prompt(data={"messages": messages}, call_type="completion") + prompt_string = litellm.utils.get_formatted_prompt( + data={"messages": messages}, call_type="completion" + ) else: prompt_string = "" prompt_characters = litellm.utils._count_characters(text=prompt_string) - - completion_string = litellm.utils.get_response_string(response_obj=completion_response) - - completion_characters = litellm.utils._count_characters(text=completion_string) + if completion_response is not None and isinstance( + completion_response, ModelResponse + ): + completion_string = litellm.utils.get_response_string( + response_obj=completion_response + ) + completion_characters = litellm.utils._count_characters( + text=completion_string + ) ( prompt_tokens_cost_usd_dollar, @@ -507,7 +610,7 @@ def response_cost_calculator( TextCompletionResponse, ], model: str, - custom_llm_provider: str, + custom_llm_provider: Optional[str], call_type: Literal[ "embedding", "aembedding", @@ -529,6 +632,10 @@ def response_cost_calculator( base_model: Optional[str] = None, custom_pricing: Optional[bool] = None, ) -> Optional[float]: + """ + Returns + - float or None: cost of response OR none if error. + """ try: response_cost: float = 0.0 if cache_hit is not None and cache_hit is True: @@ -544,7 +651,9 @@ def response_cost_calculator( ) else: if ( - model in litellm.model_cost and custom_pricing is not None and custom_llm_provider is True + model in litellm.model_cost + and custom_pricing is not None + and custom_llm_provider is True ): # override defaults if custom pricing is set base_model = model # base_model defaults to None if not set on model_info @@ -556,5 +665,14 @@ def response_cost_calculator( ) return response_cost except litellm.NotFoundError as e: - print_verbose(f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map.") + print_verbose( + f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map." 
+ ) + return None + except Exception as e: + verbose_logger.error( + "litellm.cost_calculator.py::response_cost_calculator - Exception occurred - {}/n{}".format( + str(e), traceback.format_exc() + ) + ) return None diff --git a/litellm/exceptions.py b/litellm/exceptions.py index 98b5192784..d85510b1d8 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -551,7 +551,7 @@ class APIError(openai.APIError): # type: ignore message, llm_provider, model, - request: httpx.Request, + request: Optional[httpx.Request] = None, litellm_debug_info: Optional[str] = None, max_retries: Optional[int] = None, num_retries: Optional[int] = None, @@ -563,6 +563,8 @@ class APIError(openai.APIError): # type: ignore self.litellm_debug_info = litellm_debug_info self.max_retries = max_retries self.num_retries = num_retries + if request is None: + request = httpx.Request(method="POST", url="https://api.openai.com/v1") super().__init__(self.message, request=request, body=None) # type: ignore def __str__(self): @@ -664,6 +666,22 @@ class OpenAIError(openai.OpenAIError): # type: ignore self.llm_provider = "openai" +class JSONSchemaValidationError(APIError): + def __init__( + self, model: str, llm_provider: str, raw_response: str, schema: str + ) -> None: + self.raw_response = raw_response + self.schema = schema + self.model = model + message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format( + model, raw_response, schema + ) + self.message = message + super().__init__( + model=model, message=message, llm_provider=llm_provider, status_code=500 + ) + + LITELLM_EXCEPTION_TYPES = [ AuthenticationError, NotFoundError, @@ -682,6 +700,7 @@ LITELLM_EXCEPTION_TYPES = [ APIResponseValidationError, OpenAIError, InternalServerError, + JSONSchemaValidationError, ] diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index eae8b8e22a..983ec39428 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -311,22 +311,17 @@ class LangFuseLogger: try: tags = [] - try: - metadata = copy.deepcopy( - metadata - ) # Avoid modifying the original metadata - except: - new_metadata = {} - for key, value in metadata.items(): - if ( - isinstance(value, list) - or isinstance(value, dict) - or isinstance(value, str) - or isinstance(value, int) - or isinstance(value, float) - ): - new_metadata[key] = copy.deepcopy(value) - metadata = new_metadata + new_metadata = {} + for key, value in metadata.items(): + if ( + isinstance(value, list) + or isinstance(value, dict) + or isinstance(value, str) + or isinstance(value, int) + or isinstance(value, float) + ): + new_metadata[key] = copy.deepcopy(value) + metadata = new_metadata supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3") supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3") diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 4f0ffa387e..6cd7469079 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -2,14 +2,20 @@ #### What this does #### # On success, log events to Prometheus -import dotenv, os -import requests # type: ignore +import datetime +import os +import subprocess +import sys import traceback -import datetime, subprocess, sys -import litellm, uuid -from litellm._logging import print_verbose, verbose_logger +import uuid from typing import Optional, Union +import dotenv +import requests # type: ignore + 
+import litellm +from litellm._logging import print_verbose, verbose_logger + class PrometheusLogger: # Class variables or attributes @@ -20,6 +26,8 @@ class PrometheusLogger: try: from prometheus_client import Counter, Gauge + from litellm.proxy.proxy_server import premium_user + self.litellm_llm_api_failed_requests_metric = Counter( name="litellm_llm_api_failed_requests_metric", documentation="Total number of failed LLM API calls via litellm", @@ -88,6 +96,31 @@ class PrometheusLogger: labelnames=["hashed_api_key", "api_key_alias"], ) + # Litellm-Enterprise Metrics + if premium_user is True: + # Remaining Rate Limit for model + self.litellm_remaining_requests_metric = Gauge( + "litellm_remaining_requests", + "remaining requests for model, returned from LLM API Provider", + labelnames=[ + "model_group", + "api_provider", + "api_base", + "litellm_model_name", + ], + ) + + self.litellm_remaining_tokens_metric = Gauge( + "litellm_remaining_tokens", + "remaining tokens for model, returned from LLM API Provider", + labelnames=[ + "model_group", + "api_provider", + "api_base", + "litellm_model_name", + ], + ) + except Exception as e: print_verbose(f"Got exception on init prometheus client {str(e)}") raise e @@ -104,6 +137,8 @@ class PrometheusLogger: ): try: # Define prometheus client + from litellm.proxy.proxy_server import premium_user + verbose_logger.debug( f"prometheus Logging - Enters logging function for model {kwargs}" ) @@ -199,6 +234,10 @@ class PrometheusLogger: user_api_key, user_api_key_alias ).set(_remaining_api_key_budget) + # set x-ratelimit headers + if premium_user is True: + self.set_remaining_tokens_requests_metric(kwargs) + ### FAILURE INCREMENT ### if "exception" in kwargs: self.litellm_llm_api_failed_requests_metric.labels( @@ -216,6 +255,58 @@ class PrometheusLogger: verbose_logger.debug(traceback.format_exc()) pass + def set_remaining_tokens_requests_metric(self, request_kwargs: dict): + try: + verbose_logger.debug("setting remaining tokens requests metric") + _response_headers = request_kwargs.get("response_headers") + _litellm_params = request_kwargs.get("litellm_params", {}) or {} + _metadata = _litellm_params.get("metadata", {}) + litellm_model_name = request_kwargs.get("model", None) + model_group = _metadata.get("model_group", None) + api_base = _metadata.get("api_base", None) + llm_provider = _litellm_params.get("custom_llm_provider", None) + + remaining_requests = None + remaining_tokens = None + # OpenAI / OpenAI Compatible headers + if ( + _response_headers + and "x-ratelimit-remaining-requests" in _response_headers + ): + remaining_requests = _response_headers["x-ratelimit-remaining-requests"] + if ( + _response_headers + and "x-ratelimit-remaining-tokens" in _response_headers + ): + remaining_tokens = _response_headers["x-ratelimit-remaining-tokens"] + verbose_logger.debug( + f"remaining requests: {remaining_requests}, remaining tokens: {remaining_tokens}" + ) + + if remaining_requests: + """ + "model_group", + "api_provider", + "api_base", + "litellm_model_name" + """ + self.litellm_remaining_requests_metric.labels( + model_group, llm_provider, api_base, litellm_model_name + ).set(remaining_requests) + + if remaining_tokens: + self.litellm_remaining_tokens_metric.labels( + model_group, llm_provider, api_base, litellm_model_name + ).set(remaining_tokens) + + except Exception as e: + verbose_logger.error( + "Prometheus Error: set_remaining_tokens_requests_metric. 
Exception occured - {}".format( + str(e) + ) + ) + return + def safe_get_remaining_budget( max_budget: Optional[float], spend: Optional[float] diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py index 0796d1048b..6e8c4a4e43 100644 --- a/litellm/integrations/s3.py +++ b/litellm/integrations/s3.py @@ -1,10 +1,14 @@ #### What this does #### # On success + failure, log events to Supabase +import datetime import os +import subprocess +import sys import traceback -import datetime, subprocess, sys -import litellm, uuid +import uuid + +import litellm from litellm._logging import print_verbose, verbose_logger @@ -54,6 +58,7 @@ class S3Logger: "s3_aws_session_token" ) s3_config = litellm.s3_callback_params.get("s3_config") + s3_path = litellm.s3_callback_params.get("s3_path") # done reading litellm.s3_callback_params self.bucket_name = s3_bucket_name diff --git a/litellm/integrations/slack_alerting.py b/litellm/integrations/slack_alerting.py index bce0fef8cd..04195705a0 100644 --- a/litellm/integrations/slack_alerting.py +++ b/litellm/integrations/slack_alerting.py @@ -606,6 +606,13 @@ class SlackAlerting(CustomLogger): and request_data.get("litellm_status", "") != "success" and request_data.get("litellm_status", "") != "fail" ): + ## CHECK IF CACHE IS UPDATED + litellm_call_id = request_data.get("litellm_call_id", "") + status: Optional[str] = await self.internal_usage_cache.async_get_cache( + key="request_status:{}".format(litellm_call_id), local_only=True + ) + if status is not None and (status == "success" or status == "fail"): + return if request_data.get("deployment", None) is not None and isinstance( request_data["deployment"], dict ): diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 7b911895d1..d8d551048b 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -1,5 +1,5 @@ # What is this? -## Helper utilities for the model response objects +## Helper utilities def map_finish_reason( @@ -26,7 +26,7 @@ def map_finish_reason( finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP" ): # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',] return "stop" - elif finish_reason == "SAFETY": # vertex ai + elif finish_reason == "SAFETY" or finish_reason == "RECITATION": # vertex ai return "content_filter" elif finish_reason == "STOP": # vertex ai return "stop" diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py new file mode 100644 index 0000000000..5ac26c7ae5 --- /dev/null +++ b/litellm/litellm_core_utils/exception_mapping_utils.py @@ -0,0 +1,40 @@ +import json +from typing import Optional + + +def get_error_message(error_obj) -> Optional[str]: + """ + OpenAI Returns Error message that is nested, this extract the message + + Example: + { + 'request': "", + 'message': "Error code: 400 - {\'error\': {\'message\': \"Invalid 'temperature': decimal above maximum value. Expected a value <= 2, but got 200 instead.\", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'decimal_above_max_value'}}", + 'body': { + 'message': "Invalid 'temperature': decimal above maximum value. 
Expected a value <= 2, but got 200 instead.", + 'type': 'invalid_request_error', + 'param': 'temperature', + 'code': 'decimal_above_max_value' + }, + 'code': 'decimal_above_max_value', + 'param': 'temperature', + 'type': 'invalid_request_error', + 'response': "", + 'status_code': 400, + 'request_id': 'req_f287898caa6364cd42bc01355f74dd2a' + } + """ + try: + # First, try to access the message directly from the 'body' key + if error_obj is None: + return None + + if hasattr(error_obj, "body"): + _error_obj_body = getattr(error_obj, "body") + if isinstance(_error_obj_body, dict): + return _error_obj_body.get("message") + + # If all else fails, return None + return None + except Exception as e: + return None diff --git a/litellm/litellm_core_utils/json_validation_rule.py b/litellm/litellm_core_utils/json_validation_rule.py new file mode 100644 index 0000000000..f19144aaf1 --- /dev/null +++ b/litellm/litellm_core_utils/json_validation_rule.py @@ -0,0 +1,23 @@ +import json + + +def validate_schema(schema: dict, response: str): + """ + Validate if the returned json response follows the schema. + + Params: + - schema - dict: JSON schema + - response - str: Received json response as string. + """ + from jsonschema import ValidationError, validate + + from litellm import JSONSchemaValidationError + + response_dict = json.loads(response) + + try: + validate(response_dict, schema=schema) + except ValidationError: + raise JSONSchemaValidationError( + model="", llm_provider="", raw_response=response, schema=json.dumps(schema) + ) diff --git a/litellm/litellm_core_utils/token_counter.py b/litellm/litellm_core_utils/token_counter.py new file mode 100644 index 0000000000..ebc0765c05 --- /dev/null +++ b/litellm/litellm_core_utils/token_counter.py @@ -0,0 +1,83 @@ +# What is this? +## Helper utilities for token counting +from typing import Optional + +import litellm +from litellm import verbose_logger + + +def get_modified_max_tokens( + model: str, + base_model: str, + messages: Optional[list], + user_max_tokens: Optional[int], + buffer_perc: Optional[float], + buffer_num: Optional[float], +) -> Optional[int]: + """ + Params: + + Returns the user's max output tokens, adjusted for: + - the size of input - for models where input + output can't exceed X + - model max output tokens - for models where there is a separate output token limit + """ + try: + if user_max_tokens is None: + return None + + ## MODEL INFO + _model_info = litellm.get_model_info(model=model) + + max_output_tokens = litellm.get_max_tokens( + model=base_model + ) # assume min context window is 4k tokens + + ## UNKNOWN MAX OUTPUT TOKENS - return user defined amount + if max_output_tokens is None: + return user_max_tokens + + input_tokens = litellm.token_counter(model=base_model, messages=messages) + + # token buffer + if buffer_perc is None: + buffer_perc = 0.1 + if buffer_num is None: + buffer_num = 10 + token_buffer = max( + buffer_perc * input_tokens, buffer_num + ) # give at least a 10 token buffer. token counting can be imprecise. + + input_tokens += int(token_buffer) + verbose_logger.debug( + f"max_output_tokens: {max_output_tokens}, user_max_tokens: {user_max_tokens}" + ) + ## CASE 1: model input + output can't exceed X - happens when max input = max output, e.g. 
gpt-3.5-turbo + if _model_info["max_input_tokens"] == max_output_tokens: + verbose_logger.debug( + f"input_tokens: {input_tokens}, max_output_tokens: {max_output_tokens}" + ) + if input_tokens > max_output_tokens: + pass # allow call to fail normally - don't set max_tokens to negative. + elif ( + user_max_tokens + input_tokens > max_output_tokens + ): # we can still modify to keep it positive but below the limit + verbose_logger.debug( + f"MODIFYING MAX TOKENS - user_max_tokens={user_max_tokens}, input_tokens={input_tokens}, max_output_tokens={max_output_tokens}" + ) + user_max_tokens = int(max_output_tokens - input_tokens) + ## CASE 2: user_max_tokens> model max output tokens + elif user_max_tokens > max_output_tokens: + user_max_tokens = max_output_tokens + + verbose_logger.debug( + f"litellm.litellm_core_utils.token_counter.py::get_modified_max_tokens() - user_max_tokens: {user_max_tokens}" + ) + + return user_max_tokens + except Exception as e: + verbose_logger.error( + "litellm.litellm_core_utils.token_counter.py::get_modified_max_tokens() - Error while checking max token limit: {}\nmodel={}, base_model={}".format( + str(e), model, base_model + ) + ) + return user_max_tokens diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index 808813c05e..1051a56b77 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -1,23 +1,28 @@ -import os, types +import copy import json -from enum import Enum -import requests, copy # type: ignore +import os import time +import types +from enum import Enum from functools import partial -from typing import Callable, Optional, List, Union -import litellm.litellm_core_utils -from litellm.utils import ModelResponse, Usage, CustomStreamWrapper -from litellm.litellm_core_utils.core_helpers import map_finish_reason +from typing import Callable, List, Optional, Union + +import httpx # type: ignore +import requests # type: ignore + import litellm -from .prompt_templates.factory import prompt_factory, custom_prompt +import litellm.litellm_core_utils +from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, _get_async_httpx_client, _get_httpx_client, ) -from .base import BaseLLM -import httpx # type: ignore from litellm.types.llms.anthropic import AnthropicMessagesToolChoice +from litellm.utils import CustomStreamWrapper, ModelResponse, Usage + +from .base import BaseLLM +from .prompt_templates.factory import custom_prompt, prompt_factory class AnthropicConstants(Enum): @@ -179,10 +184,19 @@ async def make_call( if client is None: client = _get_async_httpx_client() # Create a new client if none provided - response = await client.post(api_base, headers=headers, data=data, stream=True) + try: + response = await client.post(api_base, headers=headers, data=data, stream=True) + except httpx.HTTPStatusError as e: + raise AnthropicError( + status_code=e.response.status_code, message=await e.response.aread() + ) + except Exception as e: + raise AnthropicError(status_code=500, message=str(e)) if response.status_code != 200: - raise AnthropicError(status_code=response.status_code, message=response.text) + raise AnthropicError( + status_code=response.status_code, message=await response.aread() + ) completion_stream = response.aiter_lines() diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index b763a7c955..8932e44941 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -1,6 +1,7 @@ import asyncio import json import os +import time import 
types import uuid from typing import ( @@ -21,8 +22,10 @@ from openai import AsyncAzureOpenAI, AzureOpenAI from typing_extensions import overload import litellm -from litellm import OpenAIConfig +from litellm import ImageResponse, OpenAIConfig from litellm.caching import DualCache +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.utils import ( Choices, CustomStreamWrapper, @@ -32,6 +35,7 @@ from litellm.utils import ( UnsupportedParamsError, convert_to_model_response_object, get_secret, + modify_url, ) from ..types.llms.openai import ( @@ -42,6 +46,7 @@ from ..types.llms.openai import ( AsyncAssistantEventHandler, AsyncAssistantStreamManager, AsyncCursorPage, + HttpxBinaryResponseContent, MessageData, OpenAICreateThreadParamsMessage, OpenAIMessage, @@ -414,6 +419,79 @@ class AzureChatCompletion(BaseLLM): headers["Authorization"] = f"Bearer {azure_ad_token}" return headers + def _get_sync_azure_client( + self, + api_version: Optional[str], + api_base: Optional[str], + api_key: Optional[str], + azure_ad_token: Optional[str], + model: str, + max_retries: int, + timeout: Union[float, httpx.Timeout], + client: Optional[Any], + client_type: Literal["sync", "async"], + ): + # init AzureOpenAI Client + azure_client_params = { + "api_version": api_version, + "azure_endpoint": api_base, + "azure_deployment": model, + "http_client": litellm.client_session, + "max_retries": max_retries, + "timeout": timeout, + } + azure_client_params = select_azure_base_url_or_endpoint( + azure_client_params=azure_client_params + ) + if api_key is not None: + azure_client_params["api_key"] = api_key + elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + azure_client_params["azure_ad_token"] = azure_ad_token + if client is None: + if client_type == "sync": + azure_client = AzureOpenAI(**azure_client_params) # type: ignore + elif client_type == "async": + azure_client = AsyncAzureOpenAI(**azure_client_params) # type: ignore + else: + azure_client = client + if api_version is not None and isinstance(azure_client._custom_query, dict): + # set api_version to version passed by user + azure_client._custom_query.setdefault("api-version", api_version) + + return azure_client + + async def make_azure_openai_chat_completion_request( + self, + azure_client: AsyncAzureOpenAI, + data: dict, + timeout: Union[float, httpx.Timeout], + ): + """ + Helper to: + - call chat.completions.create.with_raw_response when litellm.return_response_headers is True + - call chat.completions.create by default + """ + try: + if litellm.return_response_headers is True: + raw_response = ( + await azure_client.chat.completions.with_raw_response.create( + **data, timeout=timeout + ) + ) + + headers = dict(raw_response.headers) + response = raw_response.parse() + return headers, response + else: + response = await azure_client.chat.completions.create( + **data, timeout=timeout + ) + return None, response + except Exception as e: + raise e + def completion( self, model: str, @@ -426,7 +504,7 @@ class AzureChatCompletion(BaseLLM): azure_ad_token: str, print_verbose: Callable, timeout: Union[float, httpx.Timeout], - logging_obj, + logging_obj: LiteLLMLoggingObj, optional_params, litellm_params, logger_fn, @@ -605,9 +683,9 @@ class AzureChatCompletion(BaseLLM): data: dict, timeout: Any, model_response: ModelResponse, + logging_obj: 
LiteLLMLoggingObj, azure_ad_token: Optional[str] = None, client=None, # this is the AsyncAzureOpenAI - logging_obj=None, ): response = None try: @@ -657,19 +735,52 @@ class AzureChatCompletion(BaseLLM): "complete_input_dict": data, }, ) - response = await azure_client.chat.completions.create( - **data, timeout=timeout + + headers, response = await self.make_azure_openai_chat_completion_request( + azure_client=azure_client, + data=data, + timeout=timeout, + ) + logging_obj.model_call_details["response_headers"] = headers + + stringified_response = response.model_dump() + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + original_response=stringified_response, + additional_args={"complete_input_dict": data}, ) return convert_to_model_response_object( - response_object=response.model_dump(), + response_object=stringified_response, model_response_object=model_response, ) except AzureOpenAIError as e: + ## LOGGING + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) exception_mapping_worked = True raise e except asyncio.CancelledError as e: + ## LOGGING + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) raise AzureOpenAIError(status_code=500, message=str(e)) except Exception as e: + ## LOGGING + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) if hasattr(e, "status_code"): raise e else: @@ -739,7 +850,7 @@ class AzureChatCompletion(BaseLLM): async def async_streaming( self, - logging_obj, + logging_obj: LiteLLMLoggingObj, api_base: str, api_key: str, api_version: str, @@ -788,9 +899,14 @@ class AzureChatCompletion(BaseLLM): "complete_input_dict": data, }, ) - response = await azure_client.chat.completions.create( - **data, timeout=timeout + + headers, response = await self.make_azure_openai_chat_completion_request( + azure_client=azure_client, + data=data, + timeout=timeout, ) + logging_obj.model_call_details["response_headers"] = headers + # return response streamwrapper = CustomStreamWrapper( completion_stream=response, @@ -812,7 +928,7 @@ class AzureChatCompletion(BaseLLM): azure_client_params: dict, api_key: str, input: list, - client=None, + client: Optional[AsyncAzureOpenAI] = None, logging_obj=None, timeout=None, ): @@ -911,6 +1027,7 @@ class AzureChatCompletion(BaseLLM): model_response=model_response, azure_client_params=azure_client_params, timeout=timeout, + client=client, ) return response if client is None: @@ -937,6 +1054,234 @@ class AzureChatCompletion(BaseLLM): else: raise AzureOpenAIError(status_code=500, message=str(e)) + async def make_async_azure_httpx_request( + self, + client: Optional[AsyncHTTPHandler], + timeout: Optional[Union[float, httpx.Timeout]], + api_base: str, + api_version: str, + api_key: str, + data: dict, + ) -> httpx.Response: + """ + Implemented for azure dall-e-2 image gen calls + + Alternative to needing a custom transport implementation + """ + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + async_handler = AsyncHTTPHandler(**_params) # type: ignore + else: + async_handler = client # type: ignore + + if ( + 
"images/generations" in api_base + and api_version + in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + "2023-06-01-preview", + "2023-07-01-preview", + "2023-08-01-preview", + "2023-09-01-preview", + "2023-10-01-preview", + ] + ): # CREATE + POLL for azure dall-e-2 calls + + api_base = modify_url( + original_url=api_base, new_path="/openai/images/generations:submit" + ) + + data.pop( + "model", None + ) # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview + response = await async_handler.post( + url=api_base, + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "api-key": api_key, + }, + ) + operation_location_url = response.headers["operation-location"] + response = await async_handler.get( + url=operation_location_url, + headers={ + "api-key": api_key, + }, + ) + + await response.aread() + + timeout_secs: int = 120 + start_time = time.time() + if "status" not in response.json(): + raise Exception( + "Expected 'status' in response. Got={}".format(response.json()) + ) + while response.json()["status"] not in ["succeeded", "failed"]: + if time.time() - start_time > timeout_secs: + timeout_msg = { + "error": { + "code": "Timeout", + "message": "Operation polling timed out.", + } + } + + raise AzureOpenAIError( + status_code=408, message="Operation polling timed out." + ) + + await asyncio.sleep(int(response.headers.get("retry-after") or 10)) + response = await async_handler.get( + url=operation_location_url, + headers={ + "api-key": api_key, + }, + ) + await response.aread() + + if response.json()["status"] == "failed": + error_data = response.json() + raise AzureOpenAIError(status_code=400, message=json.dumps(error_data)) + + return response + return await async_handler.post( + url=api_base, + json=data, + headers={ + "Content-Type": "application/json;", + "api-key": api_key, + }, + ) + + def make_sync_azure_httpx_request( + self, + client: Optional[HTTPHandler], + timeout: Optional[Union[float, httpx.Timeout]], + api_base: str, + api_version: str, + api_key: str, + data: dict, + ) -> httpx.Response: + """ + Implemented for azure dall-e-2 image gen calls + + Alternative to needing a custom transport implementation + """ + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + sync_handler = HTTPHandler(**_params) # type: ignore + else: + sync_handler = client # type: ignore + + if ( + "images/generations" in api_base + and api_version + in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + "2023-06-01-preview", + "2023-07-01-preview", + "2023-08-01-preview", + "2023-09-01-preview", + "2023-10-01-preview", + ] + ): # CREATE + POLL for azure dall-e-2 calls + + api_base = modify_url( + original_url=api_base, new_path="/openai/images/generations:submit" + ) + + data.pop( + "model", None + ) # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview + response = sync_handler.post( + url=api_base, + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "api-key": api_key, + }, + ) + operation_location_url = response.headers["operation-location"] + response = sync_handler.get( + 
url=operation_location_url, + headers={ + "api-key": api_key, + }, + ) + + response.read() + + timeout_secs: int = 120 + start_time = time.time() + if "status" not in response.json(): + raise Exception( + "Expected 'status' in response. Got={}".format(response.json()) + ) + while response.json()["status"] not in ["succeeded", "failed"]: + if time.time() - start_time > timeout_secs: + raise AzureOpenAIError( + status_code=408, message="Operation polling timed out." + ) + + time.sleep(int(response.headers.get("retry-after") or 10)) + response = sync_handler.get( + url=operation_location_url, + headers={ + "api-key": api_key, + }, + ) + response.read() + + if response.json()["status"] == "failed": + error_data = response.json() + raise AzureOpenAIError(status_code=400, message=json.dumps(error_data)) + + return response + return sync_handler.post( + url=api_base, + json=data, + headers={ + "Content-Type": "application/json;", + "api-key": api_key, + }, + ) + + def create_azure_base_url( + self, azure_client_params: dict, model: Optional[str] + ) -> str: + + api_base: str = azure_client_params.get( + "azure_endpoint", "" + ) # "https://example-endpoint.openai.azure.com" + if api_base.endswith("/"): + api_base = api_base.rstrip("/") + api_version: str = azure_client_params.get("api_version", "") + if model is None: + model = "" + new_api_base = ( + api_base + + "/openai/deployments/" + + model + + "/images/generations" + + "?api-version=" + + api_version + ) + + return new_api_base + async def aimage_generation( self, data: dict, @@ -948,30 +1293,40 @@ class AzureChatCompletion(BaseLLM): logging_obj=None, timeout=None, ): - response = None + response: Optional[dict] = None try: - if client is None: - client_session = litellm.aclient_session or httpx.AsyncClient( - transport=AsyncCustomHTTPTransport(), - ) - azure_client = AsyncAzureOpenAI( - http_client=client_session, **azure_client_params - ) - else: - azure_client = client + # response = await azure_client.images.generate(**data, timeout=timeout) + api_base: str = azure_client_params.get( + "api_base", "" + ) # "https://example-endpoint.openai.azure.com" + if api_base.endswith("/"): + api_base = api_base.rstrip("/") + api_version: str = azure_client_params.get("api_version", "") + img_gen_api_base = self.create_azure_base_url( + azure_client_params=azure_client_params, model=data.get("model", "") + ) + ## LOGGING logging_obj.pre_call( input=data["prompt"], - api_key=azure_client.api_key, + api_key=api_key, additional_args={ - "headers": {"api_key": azure_client.api_key}, - "api_base": azure_client._base_url._uri_reference, - "acompletion": True, "complete_input_dict": data, + "api_base": img_gen_api_base, + "headers": {"api_key": api_key}, }, ) - response = await azure_client.images.generate(**data, timeout=timeout) - stringified_response = response.model_dump() + httpx_response: httpx.Response = await self.make_async_azure_httpx_request( + client=None, + timeout=timeout, + api_base=img_gen_api_base, + api_version=api_version, + api_key=api_key, + data=data, + ) + response = httpx_response.json()["result"] + + stringified_response = response ## LOGGING logging_obj.post_call( input=input, @@ -1054,28 +1409,30 @@ class AzureChatCompletion(BaseLLM): response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore return response - if client is None: - client_session = litellm.client_session or 
httpx.Client( - transport=CustomHTTPTransport(), - ) - azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore - else: - azure_client = client + img_gen_api_base = self.create_azure_base_url( + azure_client_params=azure_client_params, model=data.get("model", "") + ) ## LOGGING logging_obj.pre_call( - input=prompt, - api_key=azure_client.api_key, + input=data["prompt"], + api_key=api_key, additional_args={ - "headers": {"api_key": azure_client.api_key}, - "api_base": azure_client._base_url._uri_reference, - "acompletion": False, "complete_input_dict": data, + "api_base": img_gen_api_base, + "headers": {"api_key": api_key}, }, ) + httpx_response: httpx.Response = self.make_sync_azure_httpx_request( + client=None, + timeout=timeout, + api_base=img_gen_api_base, + api_version=api_version or "", + api_key=api_key or "", + data=data, + ) + response = httpx_response.json()["result"] - ## COMPLETION CALL - response = azure_client.images.generate(**data, timeout=timeout) # type: ignore ## LOGGING logging_obj.post_call( input=prompt, @@ -1084,7 +1441,7 @@ class AzureChatCompletion(BaseLLM): original_response=response, ) # return response - return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation") # type: ignore + return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore except AzureOpenAIError as e: exception_mapping_worked = True raise e @@ -1247,6 +1604,96 @@ class AzureChatCompletion(BaseLLM): ) raise e + def audio_speech( + self, + model: str, + input: str, + voice: str, + optional_params: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + organization: Optional[str], + max_retries: int, + timeout: Union[float, httpx.Timeout], + azure_ad_token: Optional[str] = None, + aspeech: Optional[bool] = None, + client=None, + ) -> HttpxBinaryResponseContent: + + max_retries = optional_params.pop("max_retries", 2) + + if aspeech is not None and aspeech is True: + return self.async_audio_speech( + model=model, + input=input, + voice=voice, + optional_params=optional_params, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + max_retries=max_retries, + timeout=timeout, + client=client, + ) # type: ignore + + azure_client: AzureOpenAI = self._get_sync_azure_client( + api_base=api_base, + api_version=api_version, + api_key=api_key, + azure_ad_token=azure_ad_token, + model=model, + max_retries=max_retries, + timeout=timeout, + client=client, + client_type="sync", + ) # type: ignore + + response = azure_client.audio.speech.create( + model=model, + voice=voice, # type: ignore + input=input, + **optional_params, + ) + return response + + async def async_audio_speech( + self, + model: str, + input: str, + voice: str, + optional_params: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + max_retries: int, + timeout: Union[float, httpx.Timeout], + client=None, + ) -> HttpxBinaryResponseContent: + + azure_client: AsyncAzureOpenAI = self._get_sync_azure_client( + api_base=api_base, + api_version=api_version, + api_key=api_key, + azure_ad_token=azure_ad_token, + model=model, + max_retries=max_retries, + timeout=timeout, + client=client, + client_type="async", + ) # type: ignore + + response = await azure_client.audio.speech.create( + 
model=model, + voice=voice, # type: ignore + input=input, + **optional_params, + ) + + return response + def get_headers( self, model: Optional[str], diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index 14abec784f..7b4628a76e 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -60,6 +60,17 @@ from .prompt_templates.factory import ( prompt_factory, ) +BEDROCK_CONVERSE_MODELS = [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-opus-20240229-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0", + "anthropic.claude-v2", + "anthropic.claude-v2:1", + "anthropic.claude-v1", + "anthropic.claude-instant-v1", +] + iam_cache = DualCache() @@ -305,6 +316,7 @@ class BedrockLLM(BaseLLM): self, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, aws_region_name: Optional[str] = None, aws_session_name: Optional[str] = None, aws_profile_name: Optional[str] = None, @@ -320,6 +332,7 @@ class BedrockLLM(BaseLLM): params_to_check: List[Optional[str]] = [ aws_access_key_id, aws_secret_access_key, + aws_session_token, aws_region_name, aws_session_name, aws_profile_name, @@ -337,6 +350,7 @@ class BedrockLLM(BaseLLM): ( aws_access_key_id, aws_secret_access_key, + aws_session_token, aws_region_name, aws_session_name, aws_profile_name, @@ -430,6 +444,19 @@ class BedrockLLM(BaseLLM): client = boto3.Session(profile_name=aws_profile_name) return client.get_credentials() + elif ( + aws_access_key_id is not None + and aws_secret_access_key is not None + and aws_session_token is not None + ): ### CHECK FOR AWS SESSION TOKEN ### + from botocore.credentials import Credentials + + credentials = Credentials( + access_key=aws_access_key_id, + secret_key=aws_secret_access_key, + token=aws_session_token, + ) + return credentials else: session = boto3.Session( aws_access_key_id=aws_access_key_id, @@ -734,9 +761,10 @@ class BedrockLLM(BaseLLM): provider = model.split(".")[0] ## CREDENTIALS ## - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) aws_region_name = optional_params.pop("aws_region_name", None) aws_role_name = optional_params.pop("aws_role_name", None) aws_session_name = optional_params.pop("aws_session_name", None) @@ -768,6 +796,7 @@ class BedrockLLM(BaseLLM): credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, aws_region_name=aws_region_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, @@ -1422,6 +1451,7 @@ class BedrockConverseLLM(BaseLLM): self, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, aws_region_name: Optional[str] = None, aws_session_name: Optional[str] = None, aws_profile_name: Optional[str] = None, @@ -1437,6 +1467,7 @@ class BedrockConverseLLM(BaseLLM): params_to_check: List[Optional[str]] = [ aws_access_key_id, aws_secret_access_key, + aws_session_token, aws_region_name, 
aws_session_name, aws_profile_name, @@ -1454,6 +1485,7 @@ class BedrockConverseLLM(BaseLLM): ( aws_access_key_id, aws_secret_access_key, + aws_session_token, aws_region_name, aws_session_name, aws_profile_name, @@ -1547,6 +1579,19 @@ class BedrockConverseLLM(BaseLLM): client = boto3.Session(profile_name=aws_profile_name) return client.get_credentials() + elif ( + aws_access_key_id is not None + and aws_secret_access_key is not None + and aws_session_token is not None + ): ### CHECK FOR AWS SESSION TOKEN ### + from botocore.credentials import Credentials + + credentials = Credentials( + access_key=aws_access_key_id, + secret_key=aws_secret_access_key, + token=aws_session_token, + ) + return credentials else: session = boto3.Session( aws_access_key_id=aws_access_key_id, @@ -1682,6 +1727,7 @@ class BedrockConverseLLM(BaseLLM): # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) aws_region_name = optional_params.pop("aws_region_name", None) aws_role_name = optional_params.pop("aws_role_name", None) aws_session_name = optional_params.pop("aws_session_name", None) @@ -1713,6 +1759,7 @@ class BedrockConverseLLM(BaseLLM): credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, aws_region_name=aws_region_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py index f361ede5bf..a6726eb98c 100644 --- a/litellm/llms/custom_httpx/azure_dall_e_2.py +++ b/litellm/llms/custom_httpx/azure_dall_e_2.py @@ -1,4 +1,8 @@ -import time, json, httpx, asyncio +import asyncio +import json +import time + +import httpx class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): @@ -7,15 +11,18 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): """ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: - if "images/generations" in request.url.path and request.url.params[ - "api-version" - ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict - "2023-06-01-preview", - "2023-07-01-preview", - "2023-08-01-preview", - "2023-09-01-preview", - "2023-10-01-preview", - ]: + _api_version = request.url.params.get("api-version", "") + if ( + "images/generations" in request.url.path + and _api_version + in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + "2023-06-01-preview", + "2023-07-01-preview", + "2023-08-01-preview", + "2023-09-01-preview", + "2023-10-01-preview", + ] + ): request.url = request.url.copy_with( path="/openai/images/generations:submit" ) @@ -77,15 +84,18 @@ class CustomHTTPTransport(httpx.HTTPTransport): self, request: httpx.Request, ) -> httpx.Response: - if "images/generations" in request.url.path and request.url.params[ - "api-version" - ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict - "2023-06-01-preview", - "2023-07-01-preview", - "2023-08-01-preview", - "2023-09-01-preview", - "2023-10-01-preview", - ]: + _api_version = request.url.params.get("api-version", "") + if ( + "images/generations" in request.url.path + and _api_version + in [ # dall-e-3 starts 
from `2023-12-01-preview` so we should be able to avoid conflict + "2023-06-01-preview", + "2023-07-01-preview", + "2023-08-01-preview", + "2023-09-01-preview", + "2023-10-01-preview", + ] + ): request.url = request.url.copy_with( path="/openai/images/generations:submit" ) diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index a3c5865fa3..9b01c96b16 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -1,6 +1,11 @@ +import asyncio +import os +import traceback +from typing import Any, Mapping, Optional, Union + +import httpx + import litellm -import httpx, asyncio, traceback, os -from typing import Optional, Union, Mapping, Any # https://www.python-httpx.org/advanced/timeouts _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0) @@ -93,7 +98,7 @@ class AsyncHTTPHandler: response = await self.client.send(req, stream=stream) response.raise_for_status() return response - except httpx.RemoteProtocolError: + except (httpx.RemoteProtocolError, httpx.ConnectError): # Retry the request with a new session if there is a connection error new_client = self.create_client(timeout=self.timeout, concurrent_limit=1) try: @@ -109,6 +114,11 @@ class AsyncHTTPHandler: finally: await new_client.aclose() except httpx.HTTPStatusError as e: + setattr(e, "status_code", e.response.status_code) + if stream is True: + setattr(e, "message", await e.response.aread()) + else: + setattr(e, "message", e.response.text) raise e except Exception as e: raise e @@ -208,6 +218,7 @@ class HTTPHandler: headers: Optional[dict] = None, stream: bool = False, ): + req = self.client.build_request( "POST", url, data=data, json=json, params=params, headers=headers # type: ignore ) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 55a0d97daf..990ef2faeb 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -21,6 +21,7 @@ from pydantic import BaseModel from typing_extensions import overload, override import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.utils import ProviderField from litellm.utils import ( Choices, @@ -652,6 +653,36 @@ class OpenAIChatCompletion(BaseLLM): else: return client + async def make_openai_chat_completion_request( + self, + openai_aclient: AsyncOpenAI, + data: dict, + timeout: Union[float, httpx.Timeout], + ): + """ + Helper to: + - call chat.completions.create.with_raw_response when litellm.return_response_headers is True + - call chat.completions.create by default + """ + try: + if litellm.return_response_headers is True: + raw_response = ( + await openai_aclient.chat.completions.with_raw_response.create( + **data, timeout=timeout + ) + ) + + headers = dict(raw_response.headers) + response = raw_response.parse() + return headers, response + else: + response = await openai_aclient.chat.completions.create( + **data, timeout=timeout + ) + return None, response + except Exception as e: + raise e + def completion( self, model_response: ModelResponse, @@ -678,17 +709,17 @@ class OpenAIChatCompletion(BaseLLM): if headers: optional_params["extra_headers"] = headers if model is None or messages is None: - raise OpenAIError(status_code=422, message=f"Missing model or messages") + raise OpenAIError(status_code=422, message="Missing model or messages") if not isinstance(timeout, float) and not isinstance( timeout, httpx.Timeout ): raise OpenAIError( status_code=422, - message=f"Timeout needs to be a float or 
httpx.Timeout", + message="Timeout needs to be a float or httpx.Timeout", ) - if custom_llm_provider != "openai": + if custom_llm_provider is not None and custom_llm_provider != "openai": model_response.model = f"{custom_llm_provider}/{model}" # process all OpenAI compatible provider logic here if custom_llm_provider == "mistral": @@ -836,13 +867,13 @@ class OpenAIChatCompletion(BaseLLM): self, data: dict, model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, timeout: Union[float, httpx.Timeout], api_key: Optional[str] = None, api_base: Optional[str] = None, organization: Optional[str] = None, client=None, max_retries=None, - logging_obj=None, headers=None, ): response = None @@ -869,8 +900,8 @@ class OpenAIChatCompletion(BaseLLM): }, ) - response = await openai_aclient.chat.completions.create( - **data, timeout=timeout + headers, response = await self.make_openai_chat_completion_request( + openai_aclient=openai_aclient, data=data, timeout=timeout ) stringified_response = response.model_dump() logging_obj.post_call( @@ -879,9 +910,11 @@ class OpenAIChatCompletion(BaseLLM): original_response=stringified_response, additional_args={"complete_input_dict": data}, ) + logging_obj.model_call_details["response_headers"] = headers return convert_to_model_response_object( response_object=stringified_response, model_response_object=model_response, + hidden_params={"headers": headers}, ) except Exception as e: raise e @@ -931,10 +964,10 @@ class OpenAIChatCompletion(BaseLLM): async def async_streaming( self, - logging_obj, timeout: Union[float, httpx.Timeout], data: dict, model: str, + logging_obj: LiteLLMLoggingObj, api_key: Optional[str] = None, api_base: Optional[str] = None, organization: Optional[str] = None, @@ -965,9 +998,10 @@ class OpenAIChatCompletion(BaseLLM): }, ) - response = await openai_aclient.chat.completions.create( - **data, timeout=timeout + headers, response = await self.make_openai_chat_completion_request( + openai_aclient=openai_aclient, data=data, timeout=timeout ) + logging_obj.model_call_details["response_headers"] = headers streamwrapper = CustomStreamWrapper( completion_stream=response, model=model, @@ -992,17 +1026,43 @@ class OpenAIChatCompletion(BaseLLM): else: raise OpenAIError(status_code=500, message=f"{str(e)}") + # Embedding + async def make_openai_embedding_request( + self, + openai_aclient: AsyncOpenAI, + data: dict, + timeout: Union[float, httpx.Timeout], + ): + """ + Helper to: + - call embeddings.create.with_raw_response when litellm.return_response_headers is True + - call embeddings.create by default + """ + try: + if litellm.return_response_headers is True: + raw_response = await openai_aclient.embeddings.with_raw_response.create( + **data, timeout=timeout + ) # type: ignore + headers = dict(raw_response.headers) + response = raw_response.parse() + return headers, response + else: + response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore + return None, response + except Exception as e: + raise e + async def aembedding( self, input: list, data: dict, - model_response: ModelResponse, + model_response: litellm.utils.EmbeddingResponse, timeout: float, + logging_obj: LiteLLMLoggingObj, api_key: Optional[str] = None, api_base: Optional[str] = None, - client=None, + client: Optional[AsyncOpenAI] = None, max_retries=None, - logging_obj=None, ): response = None try: @@ -1014,7 +1074,10 @@ class OpenAIChatCompletion(BaseLLM): max_retries=max_retries, client=client, ) - response = await 
openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore + headers, response = await self.make_openai_embedding_request( + openai_aclient=openai_aclient, data=data, timeout=timeout + ) + logging_obj.model_call_details["response_headers"] = headers stringified_response = response.model_dump() ## LOGGING logging_obj.post_call( @@ -1039,9 +1102,9 @@ class OpenAIChatCompletion(BaseLLM): input: list, timeout: float, logging_obj, + model_response: litellm.utils.EmbeddingResponse, api_key: Optional[str] = None, api_base: Optional[str] = None, - model_response: Optional[litellm.utils.EmbeddingResponse] = None, optional_params=None, client=None, aembedding=None, @@ -1062,7 +1125,17 @@ class OpenAIChatCompletion(BaseLLM): ) if aembedding is True: - response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore + response = self.aembedding( + data=data, + input=input, + logging_obj=logging_obj, + model_response=model_response, + api_base=api_base, + api_key=api_key, + timeout=timeout, + client=client, + max_retries=max_retries, + ) return response openai_client = self._get_openai_client( @@ -1219,6 +1292,34 @@ class OpenAIChatCompletion(BaseLLM): else: raise OpenAIError(status_code=500, message=str(e)) + # Audio Transcriptions + async def make_openai_audio_transcriptions_request( + self, + openai_aclient: AsyncOpenAI, + data: dict, + timeout: Union[float, httpx.Timeout], + ): + """ + Helper to: + - call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True + - call openai_aclient.audio.transcriptions.create by default + """ + try: + if litellm.return_response_headers is True: + raw_response = ( + await openai_aclient.audio.transcriptions.with_raw_response.create( + **data, timeout=timeout + ) + ) # type: ignore + headers = dict(raw_response.headers) + response = raw_response.parse() + return headers, response + else: + response = await openai_aclient.audio.transcriptions.create(**data, timeout=timeout) # type: ignore + return None, response + except Exception as e: + raise e + def audio_transcriptions( self, model: str, @@ -1276,11 +1377,11 @@ class OpenAIChatCompletion(BaseLLM): data: dict, model_response: TranscriptionResponse, timeout: float, + logging_obj: LiteLLMLoggingObj, api_key: Optional[str] = None, api_base: Optional[str] = None, client=None, max_retries=None, - logging_obj=None, ): try: openai_aclient = self._get_openai_client( @@ -1292,9 +1393,12 @@ class OpenAIChatCompletion(BaseLLM): client=client, ) - response = await openai_aclient.audio.transcriptions.create( - **data, timeout=timeout - ) # type: ignore + headers, response = await self.make_openai_audio_transcriptions_request( + openai_aclient=openai_aclient, + data=data, + timeout=timeout, + ) + logging_obj.model_call_details["response_headers"] = headers stringified_response = response.model_dump() ## LOGGING logging_obj.post_call( @@ -1487,9 +1591,9 @@ class OpenAITextCompletion(BaseLLM): model: str, messages: list, timeout: float, + logging_obj: LiteLLMLoggingObj, print_verbose: Optional[Callable] = None, api_base: Optional[str] = None, - logging_obj=None, acompletion: bool = False, optional_params=None, litellm_params=None, diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index a97d6812c8..87af2a6bdc 100644 --- a/litellm/llms/prompt_templates/factory.py +++ 
b/litellm/llms/prompt_templates/factory.py @@ -663,19 +663,23 @@ def convert_url_to_base64(url): image_bytes = response.content base64_image = base64.b64encode(image_bytes).decode("utf-8") - img_type = url.split(".")[-1].lower() - if img_type == "jpg" or img_type == "jpeg": - img_type = "image/jpeg" - elif img_type == "png": - img_type = "image/png" - elif img_type == "gif": - img_type = "image/gif" - elif img_type == "webp": - img_type = "image/webp" + image_type = response.headers.get("Content-Type", None) + if image_type is not None and image_type.startswith("image/"): + img_type = image_type else: - raise Exception( - f"Error: Unsupported image format. Format={img_type}. Supported types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']" - ) + img_type = url.split(".")[-1].lower() + if img_type == "jpg" or img_type == "jpeg": + img_type = "image/jpeg" + elif img_type == "png": + img_type = "image/png" + elif img_type == "gif": + img_type = "image/gif" + elif img_type == "webp": + img_type = "image/webp" + else: + raise Exception( + f"Error: Unsupported image format. Format={img_type}. Supported types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']" + ) return f"data:{img_type};base64,{base64_image}" else: @@ -2029,6 +2033,50 @@ def function_call_prompt(messages: list, functions: list): return messages +def response_schema_prompt(model: str, response_schema: dict) -> str: + """ + Decides if a user-defined custom prompt or default needs to be used + + Returns the prompt str that's passed to the model as a user message + """ + custom_prompt_details: Optional[dict] = None + response_schema_as_message = [ + {"role": "user", "content": "{}".format(response_schema)} + ] + if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict: + + custom_prompt_details = litellm.custom_prompt_dict[ + f"{model}/response_schema_prompt" + ] # allow user to define custom response schema prompt by model + elif "response_schema_prompt" in litellm.custom_prompt_dict: + custom_prompt_details = litellm.custom_prompt_dict["response_schema_prompt"] + + if custom_prompt_details is not None: + return custom_prompt( + role_dict=custom_prompt_details["roles"], + initial_prompt_value=custom_prompt_details["initial_prompt_value"], + final_prompt_value=custom_prompt_details["final_prompt_value"], + messages=response_schema_as_message, + ) + else: + return default_response_schema_prompt(response_schema=response_schema) + + +def default_response_schema_prompt(response_schema: dict) -> str: + """ + Used if provider/model doesn't support 'response_schema' param. + + This is the default prompt. Allow user to override this with a custom_prompt. 
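For reference, such an override is registered through `litellm.custom_prompt_dict`, which `response_schema_prompt()` above consults before falling back to the default wording. A minimal sketch of an override, assuming the `roles` dict follows the same `pre_message`/`post_message` shape LiteLLM uses for its other custom prompt templates (the wording itself is illustrative):

```python
import litellm

# Hypothetical override - response_schema_prompt() reads "roles",
# "initial_prompt_value" and "final_prompt_value" from this dict.
litellm.custom_prompt_dict["response_schema_prompt"] = {
    "roles": {"user": {"pre_message": "", "post_message": ""}},
    "initial_prompt_value": "Answer ONLY with JSON that matches this schema:\n",
    "final_prompt_value": "\nDo not include any other text.",
}

# A model-scoped key takes precedence over the global one, e.g.:
# litellm.custom_prompt_dict["my-model/response_schema_prompt"] = {...}
```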
+ """ + prompt_str = """Use this JSON schema: + ```json + {} + ```""".format( + response_schema + ) + return prompt_str + + # Custom prompt template def custom_prompt( role_dict: dict, diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 1dbd93048d..c1e628d175 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -12,6 +12,7 @@ import requests # type: ignore from pydantic import BaseModel import litellm +from litellm._logging import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.prompt_templates.factory import ( convert_to_anthropic_image_obj, @@ -328,80 +329,86 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]: contents: List[ContentType] = [] msg_i = 0 - while msg_i < len(messages): - user_content: List[PartType] = [] - init_msg_i = msg_i - ## MERGE CONSECUTIVE USER CONTENT ## - while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types: - if isinstance(messages[msg_i]["content"], list): - _parts: List[PartType] = [] - for element in messages[msg_i]["content"]: - if isinstance(element, dict): - if element["type"] == "text" and len(element["text"]) > 0: - _part = PartType(text=element["text"]) - _parts.append(_part) - elif element["type"] == "image_url": - image_url = element["image_url"]["url"] - _part = _process_gemini_image(image_url=image_url) - _parts.append(_part) # type: ignore - user_content.extend(_parts) - elif ( - isinstance(messages[msg_i]["content"], str) - and len(messages[msg_i]["content"]) > 0 + try: + while msg_i < len(messages): + user_content: List[PartType] = [] + init_msg_i = msg_i + ## MERGE CONSECUTIVE USER CONTENT ## + while ( + msg_i < len(messages) and messages[msg_i]["role"] in user_message_types ): - _part = PartType(text=messages[msg_i]["content"]) - user_content.append(_part) + if isinstance(messages[msg_i]["content"], list): + _parts: List[PartType] = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text" and len(element["text"]) > 0: + _part = PartType(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + _part = _process_gemini_image(image_url=image_url) + _parts.append(_part) # type: ignore + user_content.extend(_parts) + elif ( + isinstance(messages[msg_i]["content"], str) + and len(messages[msg_i]["content"]) > 0 + ): + _part = PartType(text=messages[msg_i]["content"]) + user_content.append(_part) - msg_i += 1 + msg_i += 1 - if user_content: - contents.append(ContentType(role="user", parts=user_content)) - assistant_content = [] - ## MERGE CONSECUTIVE ASSISTANT CONTENT ## - while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": - if isinstance(messages[msg_i]["content"], list): - _parts = [] - for element in messages[msg_i]["content"]: - if isinstance(element, dict): - if element["type"] == "text": - _part = PartType(text=element["text"]) - _parts.append(_part) - elif element["type"] == "image_url": - image_url = element["image_url"]["url"] - _part = _process_gemini_image(image_url=image_url) - _parts.append(_part) # type: ignore - assistant_content.extend(_parts) - elif messages[msg_i].get( - "tool_calls", [] - ): # support assistant tool invoke conversion - assistant_content.extend( - convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"]) + if user_content: + contents.append(ContentType(role="user", parts=user_content)) + assistant_content = [] + ## 
MERGE CONSECUTIVE ASSISTANT CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": + if isinstance(messages[msg_i]["content"], list): + _parts = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text": + _part = PartType(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + _part = _process_gemini_image(image_url=image_url) + _parts.append(_part) # type: ignore + assistant_content.extend(_parts) + elif messages[msg_i].get( + "tool_calls", [] + ): # support assistant tool invoke conversion + assistant_content.extend( + convert_to_gemini_tool_call_invoke( + messages[msg_i]["tool_calls"] + ) + ) + else: + assistant_text = ( + messages[msg_i].get("content") or "" + ) # either string or none + if assistant_text: + assistant_content.append(PartType(text=assistant_text)) + + msg_i += 1 + + if assistant_content: + contents.append(ContentType(role="model", parts=assistant_content)) + + ## APPEND TOOL CALL MESSAGES ## + if msg_i < len(messages) and messages[msg_i]["role"] == "tool": + _part = convert_to_gemini_tool_call_result(messages[msg_i]) + contents.append(ContentType(parts=[_part])) # type: ignore + msg_i += 1 + if msg_i == init_msg_i: # prevent infinite loops + raise Exception( + "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format( + messages[msg_i] + ) ) - else: - assistant_text = ( - messages[msg_i].get("content") or "" - ) # either string or none - if assistant_text: - assistant_content.append(PartType(text=assistant_text)) - - msg_i += 1 - - if assistant_content: - contents.append(ContentType(role="model", parts=assistant_content)) - - ## APPEND TOOL CALL MESSAGES ## - if msg_i < len(messages) and messages[msg_i]["role"] == "tool": - _part = convert_to_gemini_tool_call_result(messages[msg_i]) - contents.append(ContentType(parts=[_part])) # type: ignore - msg_i += 1 - if msg_i == init_msg_i: # prevent infinite loops - raise Exception( - "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format( - messages[msg_i] - ) - ) - - return contents + return contents + except Exception as e: + raise e def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str): @@ -437,7 +444,7 @@ def completion( except: raise VertexAIError( status_code=400, - message="vertexai import failed please run `pip install google-cloud-aiplatform`", + message="vertexai import failed please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM", ) if not ( diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py index ee6653afcb..6b39716f18 100644 --- a/litellm/llms/vertex_ai_anthropic.py +++ b/litellm/llms/vertex_ai_anthropic.py @@ -1,24 +1,32 @@ # What is this? 
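Stepping back to `_gemini_convert_messages_with_history` (refactored above into a try/except): the merging behaviour itself is unchanged — consecutive user turns still collapse into one `user` content block and consecutive assistant turns into one `model` block. A quick illustration of that invariant, calling the private helper directly for demonstration only (`ContentType` is a TypedDict, hence the dict-style access):

```python
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history

messages = [
    {"role": "user", "content": "Hi there"},
    {"role": "user", "content": "Please answer in one word."},
    {"role": "assistant", "content": "Okay"},
]

contents = _gemini_convert_messages_with_history(messages=messages)

# Two consecutive user messages are merged into one 'user' entry,
# followed by one 'model' entry -> 2 content blocks total.
print(len(contents))                  # 2
print([c["role"] for c in contents])  # ['user', 'model']
```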
## Handler file for calling claude-3 on vertex ai -import os, types +import copy import json +import os +import time +import types +import uuid from enum import Enum -import requests, copy # type: ignore -import time, uuid -from typing import Callable, Optional, List -from litellm.utils import ModelResponse, Usage, CustomStreamWrapper -from litellm.litellm_core_utils.core_helpers import map_finish_reason +from typing import Any, Callable, List, Optional, Tuple + +import httpx # type: ignore +import requests # type: ignore + import litellm +from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.types.utils import ResponseFormatChunk +from litellm.utils import CustomStreamWrapper, ModelResponse, Usage + from .prompt_templates.factory import ( - contains_tag, - prompt_factory, - custom_prompt, construct_tool_use_system_prompt, + contains_tag, + custom_prompt, extract_between_tags, parse_xml_params, + prompt_factory, + response_schema_prompt, ) -import httpx # type: ignore class VertexAIError(Exception): @@ -104,6 +112,7 @@ class VertexAIAnthropicConfig: "stop", "temperature", "top_p", + "response_format", ] def map_openai_params(self, non_default_params: dict, optional_params: dict): @@ -120,6 +129,8 @@ class VertexAIAnthropicConfig: optional_params["temperature"] = value if param == "top_p": optional_params["top_p"] = value + if param == "response_format" and "response_schema" in value: + optional_params["response_format"] = ResponseFormatChunk(**value) # type: ignore return optional_params @@ -129,7 +140,6 @@ class VertexAIAnthropicConfig: """ -# makes headers for API call def refresh_auth( credentials, ) -> str: # used when user passes in credentials as json string @@ -144,6 +154,40 @@ def refresh_auth( return credentials.token +def get_vertex_client( + client: Any, + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_credentials: Optional[str], +) -> Tuple[Any, Optional[str]]: + args = locals() + from litellm.llms.vertex_httpx import VertexLLM + + try: + from anthropic import AnthropicVertex + except Exception: + raise VertexAIError( + status_code=400, + message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""", + ) + + access_token: Optional[str] = None + + if client is None: + _credentials, cred_project_id = VertexLLM().load_auth( + credentials=vertex_credentials, project_id=vertex_project + ) + vertex_ai_client = AnthropicVertex( + project_id=vertex_project or cred_project_id, + region=vertex_location or "us-central1", + access_token=_credentials.token, + ) + else: + vertex_ai_client = client + + return vertex_ai_client, access_token + + def completion( model: str, messages: list, @@ -151,10 +195,10 @@ def completion( print_verbose: Callable, encoding, logging_obj, + optional_params: dict, vertex_project=None, vertex_location=None, vertex_credentials=None, - optional_params=None, litellm_params=None, logger_fn=None, acompletion: bool = False, @@ -178,6 +222,13 @@ def completion( ) try: + vertex_ai_client, access_token = get_vertex_client( + client=client, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + ) + ## Load Config config = litellm.VertexAIAnthropicConfig.get_config() for k, v in config.items(): @@ -186,6 +237,7 @@ def completion( ## Format Prompt _is_function_call = False + _is_json_schema = False messages = copy.deepcopy(messages) 
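With `response_format` now in `VertexAIAnthropicConfig`'s supported params, JSON output for Claude on Vertex can be requested through the usual OpenAI-style argument (this route requires the `anthropic[vertex]` extra, per the import check in `get_vertex_client` above). A minimal sketch, assuming a Claude 3 model on the `vertex_ai/` route and placeholder project/location values; when the model has no native `response_schema` support, the handler folds the schema into the prompt via `response_schema_prompt` and prefills an assistant `{`:

```python
import litellm

address_schema = {  # hypothetical schema, for illustration only
    "type": "object",
    "properties": {"city": {"type": "string"}, "country": {"type": "string"}},
    "required": ["city", "country"],
}

response = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",
    messages=[{"role": "user", "content": "Where is the Eiffel Tower? Answer as JSON."}],
    response_format={"type": "json_object", "response_schema": address_schema},
    vertex_project="my-gcp-project",   # placeholder
    vertex_location="us-east5",        # placeholder
)
print(response.choices[0].message.content)  # e.g. {"city": "Paris", "country": "France"}
```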
optional_params = copy.deepcopy(optional_params) # Separate system prompt from rest of message @@ -200,6 +252,29 @@ def completion( messages.pop(idx) if len(system_prompt) > 0: optional_params["system"] = system_prompt + # Checks for 'response_schema' support - if passed in + if "response_format" in optional_params: + response_format_chunk = ResponseFormatChunk( + **optional_params["response_format"] # type: ignore + ) + supports_response_schema = litellm.supports_response_schema( + model=model, custom_llm_provider="vertex_ai" + ) + if ( + supports_response_schema is False + and response_format_chunk["type"] == "json_object" + and "response_schema" in response_format_chunk + ): + _is_json_schema = True + user_response_schema_message = response_schema_prompt( + model=model, + response_schema=response_format_chunk["response_schema"], + ) + messages.append( + {"role": "user", "content": user_response_schema_message} + ) + messages.append({"role": "assistant", "content": "{"}) + optional_params.pop("response_format") # Format rest of message according to anthropic guidelines try: messages = prompt_factory( @@ -233,32 +308,6 @@ def completion( print_verbose( f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}" ) - access_token = None - if client is None: - if vertex_credentials is not None and isinstance(vertex_credentials, str): - import google.oauth2.service_account - - try: - json_obj = json.loads(vertex_credentials) - except json.JSONDecodeError: - json_obj = json.load(open(vertex_credentials)) - - creds = ( - google.oauth2.service_account.Credentials.from_service_account_info( - json_obj, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - ) - ### CHECK IF ACCESS - access_token = refresh_auth(credentials=creds) - - vertex_ai_client = AnthropicVertex( - project_id=vertex_project, - region=vertex_location, - access_token=access_token, - ) - else: - vertex_ai_client = client if acompletion == True: """ @@ -315,7 +364,16 @@ def completion( ) message = vertex_ai_client.messages.create(**data) # type: ignore - text_content = message.content[0].text + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=message, + additional_args={"complete_input_dict": data}, + ) + + text_content: str = message.content[0].text ## TOOL CALLING - OUTPUT PARSE if text_content is not None and contains_tag("invoke", text_content): function_name = extract_between_tags("tool_name", text_content)[0] @@ -339,7 +397,13 @@ def completion( ) model_response.choices[0].message = _message # type: ignore else: - model_response.choices[0].message.content = text_content # type: ignore + if ( + _is_json_schema + ): # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb + json_response = "{" + text_content[: text_content.rfind("}") + 1] + model_response.choices[0].message.content = json_response # type: ignore + else: + model_response.choices[0].message.content = text_content # type: ignore model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason) ## CALCULATING USAGE diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 856b05f61c..2ea0e199e8 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -12,7 +12,6 @@ from functools import partial from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import httpx # type: ignore -import ijson import requests # type: ignore 
import litellm @@ -21,7 +20,10 @@ import litellm.litellm_core_utils.litellm_logging from litellm import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.llms.prompt_templates.factory import convert_url_to_base64 +from litellm.llms.prompt_templates.factory import ( + convert_url_to_base64, + response_schema_prompt, +) from litellm.llms.vertex_ai import _gemini_convert_messages_with_history from litellm.types.llms.openai import ( ChatCompletionResponseMessage, @@ -183,10 +185,17 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty if param == "tools" and isinstance(value, list): gtool_func_declarations = [] for tool in value: + _parameters = tool.get("function", {}).get("parameters", {}) + _properties = _parameters.get("properties", {}) + if isinstance(_properties, dict): + for _, _property in _properties.items(): + if "enum" in _property and "format" not in _property: + _property["format"] = "enum" + gtool_func_declaration = FunctionDeclaration( name=tool["function"]["name"], description=tool["function"].get("description", ""), - parameters=tool["function"].get("parameters", {}), + parameters=_parameters, ) gtool_func_declarations.append(gtool_func_declaration) optional_params["tools"] = [ @@ -349,6 +358,7 @@ class VertexGeminiConfig: model: str, non_default_params: dict, optional_params: dict, + drop_params: bool, ): for param, value in non_default_params.items(): if param == "temperature": @@ -368,8 +378,13 @@ class VertexGeminiConfig: optional_params["stop_sequences"] = value if param == "max_tokens": optional_params["max_output_tokens"] = value - if param == "response_format" and value["type"] == "json_object": # type: ignore - optional_params["response_mime_type"] = "application/json" + if param == "response_format" and isinstance(value, dict): # type: ignore + if value["type"] == "json_object": + optional_params["response_mime_type"] = "application/json" + elif value["type"] == "text": + optional_params["response_mime_type"] = "text/plain" + if "response_schema" in value: + optional_params["response_schema"] = value["response_schema"] if param == "frequency_penalty": optional_params["frequency_penalty"] = value if param == "presence_penalty": @@ -460,7 +475,7 @@ async def make_call( raise VertexAIError(status_code=response.status_code, message=response.text) completion_stream = ModelResponseIterator( - streaming_response=response.aiter_bytes(), sync_stream=False + streaming_response=response.aiter_lines(), sync_stream=False ) # LOGGING logging_obj.post_call( @@ -491,7 +506,7 @@ def make_sync_call( raise VertexAIError(status_code=response.status_code, message=response.read()) completion_stream = ModelResponseIterator( - streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True + streaming_response=response.iter_lines(), sync_stream=True ) # LOGGING @@ -767,11 +782,11 @@ class VertexLLM(BaseLLM): return self.access_token, self.project_id if not self._credentials: - self._credentials, project_id = self.load_auth( + self._credentials, cred_project_id = self.load_auth( credentials=credentials, project_id=project_id ) if not self.project_id: - self.project_id = project_id + self.project_id = project_id or cred_project_id else: self.refresh_auth(self._credentials) @@ -811,12 +826,13 @@ class VertexLLM(BaseLLM): endpoint = "generateContent" if stream is True: endpoint = "streamGenerateContent" - - url = ( - 
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( + url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format( + _gemini_model_name, endpoint, gemini_api_key + ) + else: + url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( _gemini_model_name, endpoint, gemini_api_key ) - ) else: auth_header, vertex_project = self._ensure_access_token( credentials=vertex_credentials, project_id=vertex_project @@ -827,7 +843,9 @@ class VertexLLM(BaseLLM): endpoint = "generateContent" if stream is True: endpoint = "streamGenerateContent" - url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse" + else: + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" if ( api_base is not None @@ -840,6 +858,9 @@ class VertexLLM(BaseLLM): else: url = "{}:{}".format(api_base, endpoint) + if stream is True: + url = url + "?alt=sse" + return auth_header, url async def async_streaming( @@ -992,35 +1013,58 @@ class VertexLLM(BaseLLM): if len(system_prompt_indices) > 0: for idx in reversed(system_prompt_indices): messages.pop(idx) - content = _gemini_convert_messages_with_history(messages=messages) - tools: Optional[Tools] = optional_params.pop("tools", None) - tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) - safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( - "safety_settings", None - ) # type: ignore - generation_config: Optional[GenerationConfig] = GenerationConfig( - **optional_params - ) - data = RequestBody(contents=content) - if len(system_content_blocks) > 0: - system_instructions = SystemInstructions(parts=system_content_blocks) - data["system_instruction"] = system_instructions - if tools is not None: - data["tools"] = tools - if tool_choice is not None: - data["toolConfig"] = tool_choice - if safety_settings is not None: - data["safetySettings"] = safety_settings - if generation_config is not None: - data["generationConfig"] = generation_config - headers = { - "Content-Type": "application/json; charset=utf-8", - } - if auth_header is not None: - headers["Authorization"] = f"Bearer {auth_header}" - if extra_headers is not None: - headers.update(extra_headers) + # Checks for 'response_schema' support - if passed in + if "response_schema" in optional_params: + supports_response_schema = litellm.supports_response_schema( + model=model, custom_llm_provider="vertex_ai" + ) + if supports_response_schema is False: + user_response_schema_message = response_schema_prompt( + model=model, response_schema=optional_params.get("response_schema") # type: ignore + ) + messages.append( + {"role": "user", "content": user_response_schema_message} + ) + optional_params.pop("response_schema") + + try: + content = _gemini_convert_messages_with_history(messages=messages) + tools: Optional[Tools] = optional_params.pop("tools", None) + tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) + safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( + "safety_settings", None + ) # type: ignore + cached_content: Optional[str] = optional_params.pop( + "cached_content", None + ) + 
generation_config: Optional[GenerationConfig] = GenerationConfig( + **optional_params + ) + data = RequestBody(contents=content) + if len(system_content_blocks) > 0: + system_instructions = SystemInstructions(parts=system_content_blocks) + data["system_instruction"] = system_instructions + if tools is not None: + data["tools"] = tools + if tool_choice is not None: + data["toolConfig"] = tool_choice + if safety_settings is not None: + data["safetySettings"] = safety_settings + if generation_config is not None: + data["generationConfig"] = generation_config + if cached_content is not None: + data["cachedContent"] = cached_content + + headers = { + "Content-Type": "application/json", + } + if auth_header is not None: + headers["Authorization"] = f"Bearer {auth_header}" + if extra_headers is not None: + headers.update(extra_headers) + except Exception as e: + raise e ## LOGGING logging_obj.pre_call( @@ -1268,11 +1312,6 @@ class VertexLLM(BaseLLM): class ModelResponseIterator: def __init__(self, streaming_response, sync_stream: bool): self.streaming_response = streaming_response - if sync_stream: - self.response_iterator = iter(self.streaming_response) - - self.events = ijson.sendable_list() - self.coro = ijson.items_coro(self.events, "item") def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: try: @@ -1302,9 +1341,9 @@ class ModelResponseIterator: if "usageMetadata" in processed_chunk: usage = ChatCompletionUsageBlock( prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"], - completion_tokens=processed_chunk["usageMetadata"][ - "candidatesTokenCount" - ], + completion_tokens=processed_chunk["usageMetadata"].get( + "candidatesTokenCount", 0 + ), total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"], ) @@ -1322,31 +1361,36 @@ class ModelResponseIterator: # Sync iterator def __iter__(self): + self.response_iterator = self.streaming_response return self def __next__(self): try: chunk = self.response_iterator.__next__() - self.coro.send(chunk) - if self.events: - event = self.events.pop(0) - json_chunk = event - return self.chunk_parser(chunk=json_chunk) - return GenericStreamingChunk( - text="", - is_finished=False, - finish_reason="", - usage=None, - index=0, - tool_use=None, - ) except StopIteration: - if self.events: # flush the events - event = self.events.pop(0) # Remove the first event - return self.chunk_parser(chunk=event) raise StopIteration except ValueError as e: - raise RuntimeError(f"Error parsing chunk: {e}") + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + chunk = chunk.replace("data:", "") + chunk = chunk.strip() + if len(chunk) > 0: + json_chunk = json.loads(chunk) + return self.chunk_parser(chunk=json_chunk) + else: + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") # Async iterator def __aiter__(self): @@ -1356,23 +1400,27 @@ class ModelResponseIterator: async def __anext__(self): try: chunk = await self.async_response_iterator.__anext__() - self.coro.send(chunk) - if self.events: - event = self.events.pop(0) - json_chunk = event - return self.chunk_parser(chunk=json_chunk) - return GenericStreamingChunk( - text="", - is_finished=False, - finish_reason="", - usage=None, - index=0, - tool_use=None, - ) except StopAsyncIteration: - if self.events: # flush the events - event = self.events.pop(0) # 
Remove the first event - return self.chunk_parser(chunk=event) raise StopAsyncIteration except ValueError as e: - raise RuntimeError(f"Error parsing chunk: {e}") + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + chunk = chunk.replace("data:", "") + chunk = chunk.strip() + if len(chunk) > 0: + json_chunk = json.loads(chunk) + return self.chunk_parser(chunk=json_chunk) + else: + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") diff --git a/litellm/llms/volcengine.py b/litellm/llms/volcengine.py new file mode 100644 index 0000000000..eb289d1c49 --- /dev/null +++ b/litellm/llms/volcengine.py @@ -0,0 +1,87 @@ +import types +from typing import Literal, Optional, Union + +import litellm + + +class VolcEngineConfig: + frequency_penalty: Optional[int] = None + function_call: Optional[Union[str, dict]] = None + functions: Optional[list] = None + logit_bias: Optional[dict] = None + max_tokens: Optional[int] = None + n: Optional[int] = None + presence_penalty: Optional[int] = None + stop: Optional[Union[str, list]] = None + temperature: Optional[int] = None + top_p: Optional[int] = None + response_format: Optional[dict] = None + + def __init__( + self, + frequency_penalty: Optional[int] = None, + function_call: Optional[Union[str, dict]] = None, + functions: Optional[list] = None, + logit_bias: Optional[dict] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[int] = None, + stop: Optional[Union[str, list]] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + response_format: Optional[dict] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self, model: str) -> list: + return [ + "frequency_penalty", + "logit_bias", + "logprobs", + "top_logprobs", + "max_tokens", + "n", + "presence_penalty", + "seed", + "stop", + "stream", + "stream_options", + "temperature", + "top_p", + "tools", + "tool_choice", + "function_call", + "functions", + "max_retries", + "extra_headers", + ] # works across all models + + def map_openai_params( + self, non_default_params: dict, optional_params: dict, model: str + ) -> dict: + supported_openai_params = self.get_supported_openai_params(model) + for param, value in non_default_params.items(): + if param in supported_openai_params: + optional_params[param] = value + return optional_params diff --git a/litellm/main.py b/litellm/main.py index cf6f4c7106..880b0ebead 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -48,6 +48,7 @@ from litellm import ( # type: ignore get_litellm_params, get_optional_params, ) +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.utils import ( CustomStreamWrapper, Usage, @@ -349,6 +350,7 @@ async def acompletion( or custom_llm_provider == "perplexity" or custom_llm_provider == "groq" or custom_llm_provider == "nvidia_nim" + or custom_llm_provider == "volcengine" 
or custom_llm_provider == "codestral" or custom_llm_provider == "text-completion-codestral" or custom_llm_provider == "deepseek" @@ -475,6 +477,15 @@ def mock_completion( model=model, # type: ignore request=httpx.Request(method="POST", url="https://api.openai.com/v1/"), ) + elif ( + isinstance(mock_response, str) and mock_response == "litellm.RateLimitError" + ): + raise litellm.RateLimitError( + message="this is a mock rate limit error", + status_code=getattr(mock_response, "status_code", 429), # type: ignore + llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore + model=model, + ) time_delay = kwargs.get("mock_delay", None) if time_delay is not None: time.sleep(time_delay) @@ -675,6 +686,8 @@ def completion( client = kwargs.get("client", None) ### Admin Controls ### no_log = kwargs.get("no-log", False) + ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489 + messages = deepcopy(messages) ######## end of unpacking kwargs ########### openai_params = [ "functions", @@ -1024,7 +1037,7 @@ def completion( client=client, # pass AsyncAzureOpenAI, AzureOpenAI client ) - if optional_params.get("stream", False) or acompletion == True: + if optional_params.get("stream", False): ## LOGGING logging.post_call( input=messages, @@ -1192,6 +1205,7 @@ def completion( or custom_llm_provider == "perplexity" or custom_llm_provider == "groq" or custom_llm_provider == "nvidia_nim" + or custom_llm_provider == "volcengine" or custom_llm_provider == "codestral" or custom_llm_provider == "deepseek" or custom_llm_provider == "anyscale" @@ -1826,6 +1840,7 @@ def completion( logging_obj=logging, acompletion=acompletion, timeout=timeout, # type: ignore + custom_llm_provider="openrouter", ) ## LOGGING logging.post_call( @@ -2197,13 +2212,33 @@ def completion( # boto3 reads keys from .env custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - if ( - "aws_bedrock_client" in optional_params - ): # use old bedrock flow for aws_bedrock_client users. - response = bedrock.completion( + if "aws_bedrock_client" in optional_params: + verbose_logger.warning( + "'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication." 
+ ) + # Extract credentials for legacy boto3 client and pass thru to httpx + aws_bedrock_client = optional_params.pop("aws_bedrock_client") + creds = aws_bedrock_client._get_credentials().get_frozen_credentials() + + if creds.access_key: + optional_params["aws_access_key_id"] = creds.access_key + if creds.secret_key: + optional_params["aws_secret_access_key"] = creds.secret_key + if creds.token: + optional_params["aws_session_token"] = creds.token + if ( + "aws_region_name" not in optional_params + or optional_params["aws_region_name"] is None + ): + optional_params["aws_region_name"] = ( + aws_bedrock_client.meta.region_name + ) + + if model in litellm.BEDROCK_CONVERSE_MODELS: + response = bedrock_converse_chat_completion.completion( model=model, messages=messages, - custom_prompt_dict=litellm.custom_prompt_dict, + custom_prompt_dict=custom_prompt_dict, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, @@ -2213,63 +2248,27 @@ def completion( logging_obj=logging, extra_headers=extra_headers, timeout=timeout, + acompletion=acompletion, + client=client, + ) + else: + response = bedrock_chat_completion.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + acompletion=acompletion, + client=client, ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and not isinstance(response, CustomStreamWrapper) - ): - # don't try to access stream object, - if "ai21" in model: - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) - else: - response = CustomStreamWrapper( - iter(response), - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) - else: - if model.startswith("anthropic"): - response = bedrock_converse_chat_completion.completion( - model=model, - messages=messages, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - extra_headers=extra_headers, - timeout=timeout, - acompletion=acompletion, - client=client, - ) - else: - response = bedrock_chat_completion.completion( - model=model, - messages=messages, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - extra_headers=extra_headers, - timeout=timeout, - acompletion=acompletion, - client=client, - ) if optional_params.get("stream", False): ## LOGGING logging.post_call( @@ -2954,6 +2953,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse: or custom_llm_provider == "perplexity" or custom_llm_provider == "groq" or custom_llm_provider == "nvidia_nim" + or custom_llm_provider == "volcengine" or custom_llm_provider == "deepseek" or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "ollama" @@ -3533,6 +3533,7 @@ async def atext_completion( or custom_llm_provider == "perplexity" or custom_llm_provider == "groq" or custom_llm_provider == "nvidia_nim" + or custom_llm_provider == "volcengine" or custom_llm_provider == 
"text-completion-codestral" or custom_llm_provider == "deepseek" or custom_llm_provider == "fireworks_ai" @@ -4262,7 +4263,7 @@ def transcription( api_base: Optional[str] = None, api_version: Optional[str] = None, max_retries: Optional[int] = None, - litellm_logging_obj=None, + litellm_logging_obj: Optional[LiteLLMLoggingObj] = None, custom_llm_provider=None, **kwargs, ): @@ -4277,6 +4278,18 @@ def transcription( proxy_server_request = kwargs.get("proxy_server_request", None) model_info = kwargs.get("model_info", None) metadata = kwargs.get("metadata", {}) + client: Optional[ + Union[ + openai.AsyncOpenAI, + openai.OpenAI, + openai.AzureOpenAI, + openai.AsyncAzureOpenAI, + ] + ] = kwargs.pop("client", None) + + if litellm_logging_obj: + litellm_logging_obj.model_call_details["client"] = str(client) + if max_retries is None: max_retries = openai.DEFAULT_MAX_RETRIES @@ -4316,6 +4329,7 @@ def transcription( optional_params=optional_params, model_response=model_response, atranscription=atranscription, + client=client, timeout=timeout, logging_obj=litellm_logging_obj, api_base=api_base, @@ -4349,6 +4363,7 @@ def transcription( optional_params=optional_params, model_response=model_response, atranscription=atranscription, + client=client, timeout=timeout, logging_obj=litellm_logging_obj, max_retries=max_retries, @@ -4406,6 +4421,7 @@ def speech( voice: str, api_key: Optional[str] = None, api_base: Optional[str] = None, + api_version: Optional[str] = None, organization: Optional[str] = None, project: Optional[str] = None, max_retries: Optional[int] = None, @@ -4479,6 +4495,45 @@ def speech( client=client, # pass AsyncOpenAI, OpenAI client aspeech=aspeech, ) + elif custom_llm_provider == "azure": + # azure configs + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore + + api_version = ( + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") + ) # type: ignore + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) # type: ignore + + azure_ad_token: Optional[str] = optional_params.get("extra_body", {}).pop( # type: ignore + "azure_ad_token", None + ) or get_secret( + "AZURE_AD_TOKEN" + ) + + headers = headers or litellm.headers + + response = azure_chat_completions.audio_speech( + model=model, + input=input, + voice=voice, + optional_params=optional_params, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + organization=organization, + max_retries=max_retries, + timeout=timeout, + client=client, # pass AsyncOpenAI, OpenAI client + aspeech=aspeech, + ) if response is None: raise Exception( diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index acd03aeea8..7f08b9eb19 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -863,6 +863,46 @@ "litellm_provider": "deepseek", "mode": "chat" }, + "codestral/codestral-latest": { + "max_tokens": 8191, + "max_input_tokens": 32000, + "max_output_tokens": 8191, + "input_cost_per_token": 0.000000, + "output_cost_per_token": 0.000000, + "litellm_provider": "codestral", + "mode": "chat", + "source": "https://docs.mistral.ai/capabilities/code_generation/" + }, + "codestral/codestral-2405": { + "max_tokens": 8191, + "max_input_tokens": 32000, + "max_output_tokens": 8191, + "input_cost_per_token": 0.000000, + "output_cost_per_token": 
0.000000, + "litellm_provider": "codestral", + "mode": "chat", + "source": "https://docs.mistral.ai/capabilities/code_generation/" + }, + "text-completion-codestral/codestral-latest": { + "max_tokens": 8191, + "max_input_tokens": 32000, + "max_output_tokens": 8191, + "input_cost_per_token": 0.000000, + "output_cost_per_token": 0.000000, + "litellm_provider": "text-completion-codestral", + "mode": "completion", + "source": "https://docs.mistral.ai/capabilities/code_generation/" + }, + "text-completion-codestral/codestral-2405": { + "max_tokens": 8191, + "max_input_tokens": 32000, + "max_output_tokens": 8191, + "input_cost_per_token": 0.000000, + "output_cost_per_token": 0.000000, + "litellm_provider": "text-completion-codestral", + "mode": "completion", + "source": "https://docs.mistral.ai/capabilities/code_generation/" + }, "deepseek-coder": { "max_tokens": 4096, "max_input_tokens": 32000, @@ -1028,21 +1068,55 @@ "tool_use_system_prompt_tokens": 159 }, "text-bison": { - "max_tokens": 1024, + "max_tokens": 2048, "max_input_tokens": 8192, - "max_output_tokens": 1024, - "input_cost_per_token": 0.000000125, - "output_cost_per_token": 0.000000125, + "max_output_tokens": 2048, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-text-models", "mode": "completion", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "text-bison@001": { + "max_tokens": 1024, + "max_input_tokens": 8192, + "max_output_tokens": 1024, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "text-bison@002": { + "max_tokens": 1024, + "max_input_tokens": 8192, + "max_output_tokens": 1024, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "text-bison32k": { "max_tokens": 1024, "max_input_tokens": 8192, "max_output_tokens": 1024, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "text-bison32k@002": { + "max_tokens": 1024, + "max_input_tokens": 8192, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-text-models", "mode": "completion", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1073,6 +1147,8 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1083,6 +1159,8 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + 
"output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1093,6 +1171,8 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1103,6 +1183,20 @@ "max_output_tokens": 8192, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-chat-models", + "mode": "chat", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "chat-bison-32k@002": { + "max_tokens": 8192, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1113,6 +1207,8 @@ "max_output_tokens": 1024, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-code-text-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1123,6 +1219,44 @@ "max_output_tokens": 1024, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "code-bison@002": { + "max_tokens": 1024, + "max_input_tokens": 6144, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "code-bison32k": { + "max_tokens": 1024, + "max_input_tokens": 6144, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "code-bison-32k@002": { + "max_tokens": 1024, + "max_input_tokens": 6144, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-code-text-models", "mode": "completion", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1157,12 +1291,36 @@ "mode": "completion", "source": 
"https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, + "code-gecko-latest": { + "max_tokens": 64, + "max_input_tokens": 2048, + "max_output_tokens": 64, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "litellm_provider": "vertex_ai-code-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "codechat-bison@latest": { + "max_tokens": 1024, + "max_input_tokens": 6144, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-chat-models", + "mode": "chat", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, "codechat-bison": { "max_tokens": 1024, "max_input_tokens": 6144, "max_output_tokens": 1024, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-code-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1173,6 +1331,20 @@ "max_output_tokens": 1024, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-chat-models", + "mode": "chat", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "codechat-bison@002": { + "max_tokens": 1024, + "max_input_tokens": 6144, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-code-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1183,6 +1355,20 @@ "max_output_tokens": 8192, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-chat-models", + "mode": "chat", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "codechat-bison-32k@002": { + "max_tokens": 8192, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-code-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1232,6 +1418,36 @@ "supports_function_calling": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, + "gemini-1.0-ultra": { + "max_tokens": 8192, + "max_input_tokens": 8192, + "max_output_tokens": 2048, + "input_cost_per_image": 0.0025, + "input_cost_per_video_per_second": 0.002, + "input_cost_per_token": 0.0000005, + "input_cost_per_character": 0.000000125, + "output_cost_per_token": 0.0000015, + "output_cost_per_character": 0.000000375, + "litellm_provider": "vertex_ai-language-models", + "mode": "chat", + "supports_function_calling": true, + 
"source": "As of Jun, 2024. There is no available doc on vertex ai pricing gemini-1.0-ultra-001. Using gemini-1.0-pro pricing. Got max_tokens info here: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "gemini-1.0-ultra-001": { + "max_tokens": 8192, + "max_input_tokens": 8192, + "max_output_tokens": 2048, + "input_cost_per_image": 0.0025, + "input_cost_per_video_per_second": 0.002, + "input_cost_per_token": 0.0000005, + "input_cost_per_character": 0.000000125, + "output_cost_per_token": 0.0000015, + "output_cost_per_character": 0.000000375, + "litellm_provider": "vertex_ai-language-models", + "mode": "chat", + "supports_function_calling": true, + "source": "As of Jun, 2024. There is no available doc on vertex ai pricing gemini-1.0-ultra-001. Using gemini-1.0-pro pricing. Got max_tokens info here: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, "gemini-1.0-pro-002": { "max_tokens": 8192, "max_input_tokens": 32760, @@ -1249,7 +1465,7 @@ }, "gemini-1.5-pro": { "max_tokens": 8192, - "max_input_tokens": 1000000, + "max_input_tokens": 2097152, "max_output_tokens": 8192, "input_cost_per_image": 0.001315, "input_cost_per_audio_per_second": 0.000125, @@ -1270,6 +1486,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-001": { @@ -1295,6 +1512,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0514": { @@ -1320,6 +1538,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0215": { @@ -1345,6 +1564,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0409": { @@ -1368,7 +1588,8 @@ "litellm_provider": "vertex_ai-language-models", "mode": "chat", "supports_function_calling": true, - "supports_tool_choice": true, + "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-flash": { @@ -1779,7 +2000,7 @@ }, "gemini/gemini-1.5-pro": { "max_tokens": 8192, - "max_input_tokens": 1000000, + "max_input_tokens": 2097152, "max_output_tokens": 8192, "input_cost_per_token": 0.00000035, "input_cost_per_token_above_128k_tokens": 0.0000007, @@ -1791,6 +2012,7 @@ "supports_function_calling": true, "supports_vision": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini/gemini-1.5-pro-latest": { @@ -1807,6 +2029,7 @@ "supports_function_calling": true, "supports_vision": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://ai.google.dev/models/gemini" }, "gemini/gemini-pro-vision": { diff --git a/litellm/proxy/_new_secret_config.yaml 
b/litellm/proxy/_new_secret_config.yaml index 938e74b5e7..31c96a0661 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,54 +1,9 @@ -# model_list: -# - model_name: my-fake-model -# litellm_params: -# model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0 -# api_key: my-fake-key -# aws_bedrock_runtime_endpoint: http://127.0.0.1:8000 -# mock_response: "Hello world 1" -# model_info: -# max_input_tokens: 0 # trigger context window fallback -# - model_name: my-fake-model -# litellm_params: -# model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0 -# api_key: my-fake-key -# aws_bedrock_runtime_endpoint: http://127.0.0.1:8000 -# mock_response: "Hello world 2" -# model_info: -# max_input_tokens: 0 - -# router_settings: -# enable_pre_call_checks: True - - -# litellm_settings: -# failure_callback: ["langfuse"] - model_list: - - model_name: summarize + - model_name: claude-3-5-sonnet # all requests where model not in your config go to this deployment litellm_params: - model: openai/gpt-4o - rpm: 10000 - tpm: 12000000 - api_key: os.environ/OPENAI_API_KEY - mock_response: Hello world 1 + model: "openai/*" + mock_response: "Hello world!" - - model_name: summarize-l - litellm_params: - model: claude-3-5-sonnet-20240620 - rpm: 4000 - tpm: 400000 - api_key: os.environ/ANTHROPIC_API_KEY - mock_response: Hello world 2 - -litellm_settings: - num_retries: 3 - request_timeout: 120 - allowed_fails: 3 - # fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}] - # context_window_fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}] - - - -router_settings: - routing_strategy: simple-shuffle - enable_pre_call_checks: true. +general_settings: + alerting: ["slack"] + alerting_threshold: 10 \ No newline at end of file diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index 2060f61ca4..ede853094e 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -1,7 +1,11 @@ model_list: +- model_name: claude-3-5-sonnet + litellm_params: + model: anthropic/claude-3-5-sonnet - model_name: gemini-1.5-flash-gemini litellm_params: - model: gemini/gemini-1.5-flash + model: vertex_ai_beta/gemini-1.5-flash + api_base: https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/google-vertex-ai/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-flash - litellm_params: api_base: http://0.0.0.0:8080 api_key: '' @@ -18,7 +22,6 @@ model_list: api_key: os.environ/PREDIBASE_API_KEY tenant_id: os.environ/PREDIBASE_TENANT_ID max_new_tokens: 256 - # - litellm_params: # api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ # api_key: os.environ/AZURE_EUROPE_API_KEY diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 640c7695a0..1f1aaf0eea 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1622,7 +1622,7 @@ class ProxyException(Exception): } -class CommonProxyErrors(enum.Enum): +class CommonProxyErrors(str, enum.Enum): db_not_connected_error = "DB not connected" no_llm_router = "No models configured on proxy" not_allowed_access = "Admin-only endpoint. Not allowed to access this." 
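The `CommonProxyErrors` hunk above mixes `str` into the enum base. A brief, illustrative sketch of the standard Python behavior that change relies on (not code from this patch; member values are copied from the context lines shown above):

```python
import enum


class CommonProxyErrors(str, enum.Enum):
    db_not_connected_error = "DB not connected"
    no_llm_router = "No models configured on proxy"
    not_allowed_access = "Admin-only endpoint. Not allowed to access this."


err = CommonProxyErrors.db_not_connected_error
assert isinstance(err, str)          # members are real str instances
assert err == "DB not connected"     # so they compare equal to the raw string
detail = "Error: " + err             # and can be used wherever a plain str is expected
print(detail)                        # -> Error: DB not connected
print(err.value)                     # .value still returns the underlying string
```

With a plain `enum.Enum` base, the first three operations would fail (members are not strings and do not compare equal to their values), which is presumably why the proxy error constants were switched to the `str` mixin.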
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 2e670de852..aec6215ced 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -144,10 +144,13 @@ async def add_litellm_data_to_request( ) # do not store the original `sk-..` api key in the db data[_metadata_variable_name]["headers"] = _headers data[_metadata_variable_name]["endpoint"] = str(request.url) + + # OTEL Controls / Tracing # Add the OTEL Parent Trace before sending it LiteLLM data[_metadata_variable_name][ "litellm_parent_otel_span" ] = user_api_key_dict.parent_otel_span + _add_otel_traceparent_to_data(data, request=request) ### END-USER SPECIFIC PARAMS ### if user_api_key_dict.allowed_model_region is not None: @@ -169,3 +172,23 @@ } # add the team-specific configs to the completion call return data + + +def _add_otel_traceparent_to_data(data: dict, request: Request): + from litellm.proxy.proxy_server import open_telemetry_logger + if data is None: + return + if open_telemetry_logger is None: + # if the user is not using OTEL, don't send extra_headers + # relevant issue: https://github.com/BerriAI/litellm/issues/4448 + return + if request.headers: + if "traceparent" in request.headers: + # we want to forward this to the LLM Provider + # Relevant issue: https://github.com/BerriAI/litellm/issues/4419 + # pass this in extra_headers + if "extra_headers" not in data: + data["extra_headers"] = {} + _extra_headers = data["extra_headers"] + if "traceparent" not in _extra_headers: + _extra_headers["traceparent"] = request.headers["traceparent"] diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py new file mode 100644 index 0000000000..218032e012 --- /dev/null +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -0,0 +1,173 @@ +import ast +import traceback +from base64 import b64encode + +import httpx +from fastapi import ( + APIRouter, + Depends, + FastAPI, + HTTPException, + Request, + Response, + status, +) +from fastapi.responses import StreamingResponse + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ProxyException +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth + +async_client = httpx.AsyncClient() + + +async def set_env_variables_in_header(custom_headers: dict): + """ + checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc + + only runs for headers defined on config.yaml + + example header can be + + {"Authorization": "bearer os.environ/COHERE_API_KEY"} + """ + headers = {} + for key, value in custom_headers.items(): + # langfuse API requires base64 encoded headers - it's simpler to just ask litellm users to set their langfuse public and secret keys + # we can then get the b64 encoded keys here + if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY": + # langfuse requires b64 encoded headers - we construct that here + _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"] + _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"] + if isinstance( + _langfuse_public_key, str + ) and _langfuse_public_key.startswith("os.environ/"): + _langfuse_public_key = litellm.get_secret(_langfuse_public_key) + if isinstance( + _langfuse_secret_key, str + ) and _langfuse_secret_key.startswith("os.environ/"): + _langfuse_secret_key = litellm.get_secret(_langfuse_secret_key) +
headers["Authorization"] = "Basic " + b64encode( + f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8") + ).decode("ascii") + else: + # for all other headers + headers[key] = value + if isinstance(value, str) and "os.environ/" in value: + verbose_proxy_logger.debug( + "pass through endpoint - looking up 'os.environ/' variable" + ) + # get string section that is os.environ/ + start_index = value.find("os.environ/") + _variable_name = value[start_index:] + + verbose_proxy_logger.debug( + "pass through endpoint - getting secret for variable name: %s", + _variable_name, + ) + _secret_value = litellm.get_secret(_variable_name) + new_value = value.replace(_variable_name, _secret_value) + headers[key] = new_value + return headers + + +async def pass_through_request(request: Request, target: str, custom_headers: dict): + try: + + url = httpx.URL(target) + headers = custom_headers + + request_body = await request.body() + _parsed_body = ast.literal_eval(request_body.decode("utf-8")) + + verbose_proxy_logger.debug( + "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( + url, headers, _parsed_body + ) + ) + + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=request.query_params, + json=_parsed_body, + ) + + if response.status_code != 200: + raise HTTPException(status_code=response.status_code, detail=response.text) + + content = await response.aread() + return Response( + content=content, + status_code=response.status_code, + headers=dict(response.headers), + ) + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.pass through endpoint(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +def create_pass_through_route(endpoint, target, custom_headers=None): + async def endpoint_func(request: Request): + return await pass_through_request(request, target, custom_headers) + + return endpoint_func + + +async def initialize_pass_through_endpoints(pass_through_endpoints: list): + + verbose_proxy_logger.debug("initializing pass through endpoints") + from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes + from litellm.proxy.proxy_server import app, premium_user + + for endpoint in pass_through_endpoints: + _target = endpoint.get("target", None) + _path = endpoint.get("path", None) + _custom_headers = endpoint.get("headers", None) + _custom_headers = await set_env_variables_in_header( + custom_headers=_custom_headers + ) + _auth = endpoint.get("auth", None) + _dependencies = None + if _auth is not None and str(_auth).lower() == "true": + if premium_user is not True: + raise ValueError( + f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}" + ) + _dependencies = [Depends(user_api_key_auth)] + LiteLLMRoutes.openai_routes.value.append(_path) + + if _target is None: + continue + + verbose_proxy_logger.debug("adding pass through endpoint: %s", _path) + + app.add_api_route( + path=_path, + 
endpoint=create_pass_through_route(_path, _target, _custom_headers), + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + dependencies=_dependencies, + ) + + verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path) diff --git a/litellm/proxy/prisma_migration.py b/litellm/proxy/prisma_migration.py new file mode 100644 index 0000000000..6ee09c22b6 --- /dev/null +++ b/litellm/proxy/prisma_migration.py @@ -0,0 +1,68 @@ +# What is this? +## Script to apply initial prisma migration on Docker setup + +import os +import subprocess +import sys +import time + +sys.path.insert( + 0, os.path.abspath("./") +) # Adds the parent directory to the system path +from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var + +if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True": + ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV + new_env_var = decrypt_env_var() + + for k, v in new_env_var.items(): + os.environ[k] = v + +# Check if DATABASE_URL is not set +database_url = os.getenv("DATABASE_URL") +if not database_url: + # Check if all required variables are provided + database_host = os.getenv("DATABASE_HOST") + database_username = os.getenv("DATABASE_USERNAME") + database_password = os.getenv("DATABASE_PASSWORD") + database_name = os.getenv("DATABASE_NAME") + + if database_host and database_username and database_password and database_name: + # Construct DATABASE_URL from the provided variables + database_url = f"postgresql://{database_username}:{database_password}@{database_host}/{database_name}" + os.environ["DATABASE_URL"] = database_url + else: + print( # noqa + "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL." # noqa + ) + exit(1) + +# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations +direct_url = os.getenv("DIRECT_URL") +if not direct_url: + os.environ["DIRECT_URL"] = database_url + +# Apply migrations +retry_count = 0 +max_retries = 3 +exit_code = 1 + +while retry_count < max_retries and exit_code != 0: + retry_count += 1 + print(f"Attempt {retry_count}...") # noqa + + # Run the Prisma db push command + result = subprocess.run( + ["prisma", "db", "push", "--accept-data-loss"], capture_output=True + ) + exit_code = result.returncode + + if exit_code != 0 and retry_count < max_retries: + print("Retrying in 10 seconds...") # noqa + time.sleep(10) + +if exit_code != 0: + print(f"Unable to push database changes after {max_retries} retries.") # noqa + exit(1) + +print("Database push successful!") # noqa diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 6e6d1f4a9e..e987046428 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -442,6 +442,20 @@ def run_server( db_connection_pool_limit = 100 db_connection_timeout = 60 + ### DECRYPT ENV VAR ### + + from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var + + if ( + os.getenv("USE_AWS_KMS", None) is not None + and os.getenv("USE_AWS_KMS") == "True" + ): + ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV + new_env_var = decrypt_env_var() + + for k, v in new_env_var.items(): + os.environ[k] = v + if config is not None: """ Allow user to pass in db url via config @@ -459,6 +473,7 @@ def run_server( proxy_config = ProxyConfig() _config = asyncio.run(proxy_config.get_config(config_file_path=config)) + ### LITELLM SETTINGS ### litellm_settings = _config.get("litellm_settings", 
None) if ( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 0c0365f43d..9f2324e51c 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -20,11 +20,23 @@ model_list: general_settings: master_key: sk-1234 - alerting: ["slack", "email"] - public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"] - + pass_through_endpoints: + - path: "/v1/rerank" + target: "https://api.cohere.com/v1/rerank" + auth: true # 👈 Key change to use LiteLLM Auth / Keys + headers: + Authorization: "bearer os.environ/COHERE_API_KEY" + content-type: application/json + accept: application/json + - path: "/api/public/ingestion" + target: "https://us.cloud.langfuse.com/api/public/ingestion" + auth: true + headers: + LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY" + LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY" litellm_settings: + return_response_headers: true success_callback: ["prometheus"] callbacks: ["otel", "hide_secrets"] failure_callback: ["prometheus"] @@ -34,6 +46,5 @@ litellm_settings: - user - metadata - metadata.generation_name - cache: True diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c3b855c5f5..1ca1807223 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -161,6 +161,9 @@ from litellm.proxy.management_endpoints.key_management_endpoints import ( router as key_management_router, ) from litellm.proxy.management_endpoints.team_endpoints import router as team_router +from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + initialize_pass_through_endpoints, +) from litellm.proxy.secret_managers.aws_secret_manager import ( load_aws_kms, load_aws_secret_manager, @@ -433,6 +436,7 @@ def get_custom_headers( api_base: Optional[str] = None, version: Optional[str] = None, model_region: Optional[str] = None, + response_cost: Optional[Union[float, str]] = None, fastest_response_batch_completion: Optional[bool] = None, **kwargs, ) -> dict: @@ -443,6 +447,7 @@ def get_custom_headers( "x-litellm-model-api-base": api_base, "x-litellm-version": version, "x-litellm-model-region": model_region, + "x-litellm-response-cost": str(response_cost), "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit), "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit), "x-litellm-fastest_response_batch_completion": ( @@ -588,7 +593,7 @@ async def _PROXY_failure_handler( _model_id = _metadata.get("model_info", {}).get("id", "") _model_group = _metadata.get("model_group", "") api_base = litellm.get_api_base(model=_model, optional_params=_litellm_params) - _exception_string = str(_exception)[:500] + _exception_string = str(_exception) error_log = LiteLLM_ErrorLogs( request_id=str(uuid.uuid4()), @@ -1177,9 +1182,13 @@ async def _run_background_health_check(): Update health_check_results, based on this. 
""" global health_check_results, llm_model_list, health_check_interval + + # make 1 deep copy of llm_model_list -> use this for all background health checks + _llm_model_list = copy.deepcopy(llm_model_list) + while True: healthy_endpoints, unhealthy_endpoints = await perform_health_check( - model_list=llm_model_list + model_list=_llm_model_list ) # Update the global variable with the health check results @@ -1854,6 +1863,11 @@ class ProxyConfig: user_custom_key_generate = get_instance_fn( value=custom_key_generate, config_file_path=config_file_path ) + ## pass through endpoints + if general_settings.get("pass_through_endpoints", None) is not None: + await initialize_pass_through_endpoints( + pass_through_endpoints=general_settings["pass_through_endpoints"] + ) ## dynamodb database_type = general_settings.get("database_type", None) if database_type is not None and ( @@ -2954,6 +2968,11 @@ async def chat_completion( if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map: data["model"] = litellm.model_alias_map[data["model"]] + ### CALL HOOKS ### - modify/reject incoming data before calling the model + data = await proxy_logging_obj.pre_call_hook( # type: ignore + user_api_key_dict=user_api_key_dict, data=data, call_type="completion" + ) + ## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call data["litellm_call_id"] = str(uuid.uuid4()) logging_obj, data = litellm.utils.function_setup( @@ -2965,11 +2984,6 @@ async def chat_completion( data["litellm_logging_obj"] = logging_obj - ### CALL HOOKS ### - modify/reject incoming data before calling the model - data = await proxy_logging_obj.pre_call_hook( # type: ignore - user_api_key_dict=user_api_key_dict, data=data, call_type="completion" - ) - tasks = [] tasks.append( proxy_logging_obj.during_call_hook( @@ -3048,6 +3062,7 @@ async def chat_completion( model_id = hidden_params.get("model_id", None) or "" cache_key = hidden_params.get("cache_key", None) or "" api_base = hidden_params.get("api_base", None) or "" + response_cost = hidden_params.get("response_cost", None) or "" fastest_response_batch_completion = hidden_params.get( "fastest_response_batch_completion", None ) @@ -3055,8 +3070,11 @@ async def chat_completion( # Post Call Processing if llm_router is not None: data["deployment"] = llm_router.get_deployment(model_id=model_id) - data["litellm_status"] = "success" # used for alerting - + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) if ( "stream" in data and data["stream"] == True ): # use generate_responses to stream responses @@ -3066,6 +3084,7 @@ async def chat_completion( cache_key=cache_key, api_base=api_base, version=version, + response_cost=response_cost, model_region=getattr(user_api_key_dict, "allowed_model_region", ""), fastest_response_batch_completion=fastest_response_batch_completion, ) @@ -3095,6 +3114,7 @@ async def chat_completion( cache_key=cache_key, api_base=api_base, version=version, + response_cost=response_cost, model_region=getattr(user_api_key_dict, "allowed_model_region", ""), fastest_response_batch_completion=fastest_response_batch_completion, **additional_headers, @@ -3104,7 +3124,6 @@ async def chat_completion( return response except RejectedRequestError as e: _data = e.request_data - _data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, @@ -3137,7 +3156,6 
@@ async def chat_completion( _chat_response.usage = _usage # type: ignore return _chat_response except Exception as e: - data["litellm_status"] = "fail" # used for alerting verbose_proxy_logger.error( "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format( get_error_message_str(e=e), traceback.format_exc() @@ -3290,9 +3308,14 @@ async def completion( model_id = hidden_params.get("model_id", None) or "" cache_key = hidden_params.get("cache_key", None) or "" api_base = hidden_params.get("api_base", None) or "" + response_cost = hidden_params.get("response_cost", None) or "" ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) verbose_proxy_logger.debug("final response: %s", response) if ( @@ -3304,6 +3327,7 @@ async def completion( cache_key=cache_key, api_base=api_base, version=version, + response_cost=response_cost, ) selected_data_generator = select_data_generator( response=response, @@ -3323,13 +3347,13 @@ async def completion( cache_key=cache_key, api_base=api_base, version=version, + response_cost=response_cost, ) ) return response except RejectedRequestError as e: _data = e.request_data - _data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, @@ -3368,7 +3392,6 @@ async def completion( _response.choices[0].text = e.message return _response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -3520,13 +3543,18 @@ async def embeddings( ) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} model_id = hidden_params.get("model_id", None) or "" cache_key = hidden_params.get("cache_key", None) or "" api_base = hidden_params.get("api_base", None) or "" + response_cost = hidden_params.get("response_cost", None) or "" fastapi_response.headers.update( get_custom_headers( @@ -3535,13 +3563,13 @@ async def embeddings( cache_key=cache_key, api_base=api_base, version=version, + response_cost=response_cost, model_region=getattr(user_api_key_dict, "allowed_model_region", ""), ) ) return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -3669,13 +3697,17 @@ async def image_generation( ) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting - + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} model_id = hidden_params.get("model_id", None) or "" cache_key = hidden_params.get("cache_key", None) or "" api_base = hidden_params.get("api_base", None) or "" + response_cost = hidden_params.get("response_cost", None) or "" fastapi_response.headers.update( get_custom_headers( @@ -3684,13 +3716,13 @@ async def image_generation( cache_key=cache_key, api_base=api_base, 
version=version, + response_cost=response_cost, model_region=getattr(user_api_key_dict, "allowed_model_region", ""), ) ) return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -3805,13 +3837,18 @@ async def audio_speech( ) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} model_id = hidden_params.get("model_id", None) or "" cache_key = hidden_params.get("cache_key", None) or "" api_base = hidden_params.get("api_base", None) or "" + response_cost = hidden_params.get("response_cost", None) or "" # Printing each chunk size async def generate(_response: HttpxBinaryResponseContent): @@ -3825,6 +3862,7 @@ async def audio_speech( cache_key=cache_key, api_base=api_base, version=version, + response_cost=response_cost, model_region=getattr(user_api_key_dict, "allowed_model_region", ""), fastest_response_batch_completion=None, ) @@ -3969,13 +4007,18 @@ async def audio_transcriptions( os.remove(file.filename) # Delete the saved file ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} model_id = hidden_params.get("model_id", None) or "" cache_key = hidden_params.get("cache_key", None) or "" api_base = hidden_params.get("api_base", None) or "" + response_cost = hidden_params.get("response_cost", None) or "" fastapi_response.headers.update( get_custom_headers( @@ -3984,13 +4027,13 @@ async def audio_transcriptions( cache_key=cache_key, api_base=api_base, version=version, + response_cost=response_cost, model_region=getattr(user_api_key_dict, "allowed_model_region", ""), ) ) return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -4069,7 +4112,11 @@ async def get_assistants( response = await llm_router.aget_assistants(**data) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} @@ -4090,7 +4137,6 @@ async def get_assistants( return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -4161,7 +4207,11 @@ async def create_threads( response = await llm_router.acreate_thread(**data) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} @@ -4182,7 +4232,6 @@ async def create_threads( return response except Exception as e: - 
data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -4252,7 +4301,11 @@ async def get_thread( response = await llm_router.aget_thread(thread_id=thread_id, **data) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} @@ -4273,7 +4326,6 @@ async def get_thread( return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -4346,7 +4398,11 @@ async def add_messages( response = await llm_router.a_add_message(thread_id=thread_id, **data) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} @@ -4367,7 +4423,6 @@ async def add_messages( return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -4436,7 +4491,11 @@ async def get_messages( response = await llm_router.aget_messages(thread_id=thread_id, **data) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} @@ -4457,7 +4516,6 @@ async def get_messages( return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -4540,7 +4598,11 @@ async def run_thread( ) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} @@ -4561,7 +4623,6 @@ async def run_thread( return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -4651,7 +4712,11 @@ async def create_batch( ) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} @@ -4672,7 +4737,6 @@ async def create_batch( return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -4757,7 +4821,11 @@ async def retrieve_batch( ) ### ALERTING ### - 
data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} @@ -4778,7 +4846,6 @@ async def retrieve_batch( return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -4873,7 +4940,11 @@ async def create_file( ) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} @@ -4894,7 +4965,6 @@ async def create_file( return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -5017,7 +5087,11 @@ async def moderations( response = await litellm.amoderation(**data) ### ALERTING ### - data["litellm_status"] = "success" # used for alerting + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) ### RESPONSE HEADERS ### hidden_params = getattr(response, "_hidden_params", {}) or {} @@ -5038,7 +5112,6 @@ async def moderations( return response except Exception as e: - data["litellm_status"] = "fail" # used for alerting await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) @@ -6284,7 +6357,7 @@ async def model_info_v2( raise HTTPException( status_code=500, detail={ - "error": f"Invalid llm model list. llm_model_list={llm_model_list}" + "error": f"No model list passed, models={llm_model_list}. You can add a model through the config.yaml or on the LiteLLM Admin UI." 
}, ) @@ -7487,7 +7560,9 @@ async def login(request: Request): # Non SSO -> If user is using UI_USERNAME and UI_PASSWORD they are Proxy admin user_role = LitellmUserRoles.PROXY_ADMIN user_id = username - key_user_id = user_id + + # we want the key created to have PROXY_ADMIN_PERMISSIONS + key_user_id = litellm_proxy_admin_name if ( os.getenv("PROXY_ADMIN_ID", None) is not None and os.environ["PROXY_ADMIN_ID"] == user_id @@ -7507,7 +7582,17 @@ async def login(request: Request): if os.getenv("DATABASE_URL") is not None: response = await generate_key_helper_fn( request_type="key", - **{"user_role": LitellmUserRoles.PROXY_ADMIN, "duration": "2hr", "key_max_budget": 5, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"}, # type: ignore + **{ + "user_role": LitellmUserRoles.PROXY_ADMIN, + "duration": "2hr", + "key_max_budget": 5, + "models": [], + "aliases": {}, + "config": {}, + "spend": 0, + "user_id": key_user_id, + "team_id": "litellm-dashboard", + }, # type: ignore ) else: raise ProxyException( diff --git a/litellm/proxy/secret_managers/aws_secret_manager.py b/litellm/proxy/secret_managers/aws_secret_manager.py index 8dd6772cf7..c4afaedc21 100644 --- a/litellm/proxy/secret_managers/aws_secret_manager.py +++ b/litellm/proxy/secret_managers/aws_secret_manager.py @@ -8,9 +8,13 @@ Requires: * `pip install boto3>=1.28.57` """ -import litellm +import ast +import base64 import os -from typing import Optional +import re +from typing import Any, Dict, Optional + +import litellm from litellm.proxy._types import KeyManagementSystem @@ -57,3 +61,99 @@ def load_aws_kms(use_aws_kms: Optional[bool]): except Exception as e: raise e + + +class AWSKeyManagementService_V2: + """ + V2 Clean Class for decrypting keys from AWS KeyManagementService + """ + + def __init__(self) -> None: + self.validate_environment() + self.kms_client = self.load_aws_kms(use_aws_kms=True) + + def validate_environment( + self, + ): + if "AWS_REGION_NAME" not in os.environ: + raise ValueError("Missing required environment variable - AWS_REGION_NAME") + + ## CHECK IF LICENSE IN ENV ## - premium feature + if os.getenv("LITELLM_LICENSE", None) is None: + raise ValueError( + "AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your environment."
+ ) + + def load_aws_kms(self, use_aws_kms: Optional[bool]): + if use_aws_kms is None or use_aws_kms is False: + return + try: + import boto3 + + validate_environment() + + # Create a Secrets Manager client + kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME")) + + return kms_client + except Exception as e: + raise e + + def decrypt_value(self, secret_name: str) -> Any: + if self.kms_client is None: + raise ValueError("kms_client is None") + encrypted_value = os.getenv(secret_name, None) + if encrypted_value is None: + raise Exception( + "AWS KMS - Encrypted Value of Key={} is None".format(secret_name) + ) + if isinstance(encrypted_value, str) and encrypted_value.startswith("aws_kms/"): + encrypted_value = encrypted_value.replace("aws_kms/", "") + + # Decode the base64 encoded ciphertext + ciphertext_blob = base64.b64decode(encrypted_value) + + # Set up the parameters for the decrypt call + params = {"CiphertextBlob": ciphertext_blob} + # Perform the decryption + response = self.kms_client.decrypt(**params) + + # Extract and decode the plaintext + plaintext = response["Plaintext"] + secret = plaintext.decode("utf-8") + if isinstance(secret, str): + secret = secret.strip() + try: + secret_value_as_bool = ast.literal_eval(secret) + if isinstance(secret_value_as_bool, bool): + return secret_value_as_bool + except Exception: + pass + + return secret + + +""" +- look for all values in the env with `aws_kms/` +- decrypt keys +- rewrite env var with decrypted key (). Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist. +""" + + +def decrypt_env_var() -> Dict[str, Any]: + # setup client class + aws_kms = AWSKeyManagementService_V2() + # iterate through env - for `aws_kms/` + new_values = {} + for k, v in os.environ.items(): + if ( + k is not None + and isinstance(k, str) + and k.lower().startswith("litellm_secret_aws_kms") + ) or (v is not None and isinstance(v, str) and v.startswith("aws_kms/")): + decrypted_value = aws_kms.decrypt_value(secret_name=k) + # reset env var + k = re.sub("litellm_secret_aws_kms_", "", k, flags=re.IGNORECASE) + new_values[k] = decrypted_value + + return new_values diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index 1fbd95b3cf..87bd85078c 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -817,9 +817,9 @@ async def get_global_spend_report( default=None, description="Time till which to view spend", ), - group_by: Optional[Literal["team", "customer"]] = fastapi.Query( + group_by: Optional[Literal["team", "customer", "api_key"]] = fastapi.Query( default="team", - description="Group spend by internal team or customer", + description="Group spend by internal team or customer or api_key", ), ): """ @@ -860,7 +860,7 @@ async def get_global_spend_report( start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") - from litellm.proxy.proxy_server import prisma_client + from litellm.proxy.proxy_server import premium_user, prisma_client try: if prisma_client is None: @@ -868,6 +868,12 @@ async def get_global_spend_report( f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" ) + if premium_user is not True: + verbose_proxy_logger.debug("accessing /spend/report but not a premium user") + raise ValueError( + "/spend/report endpoint " + CommonProxyErrors.not_premium_user.value + ) + if group_by == "team": # first get data from spend logs -> SpendByModelApiKey # then read data from "SpendByModelApiKey" to format the response obj @@ -992,6 +998,48 @@ async def get_global_spend_report( return [] return db_response + elif group_by == "api_key": + sql_query = """ + WITH SpendByModelApiKey AS ( + SELECT + sl.api_key, + sl.model, + SUM(sl.spend) AS model_cost, + SUM(sl.prompt_tokens) AS model_input_tokens, + SUM(sl.completion_tokens) AS model_output_tokens + FROM + "LiteLLM_SpendLogs" sl + WHERE + sl."startTime" BETWEEN $1::date AND $2::date + GROUP BY + sl.api_key, + sl.model + ) + SELECT + api_key, + SUM(model_cost) AS total_cost, + SUM(model_input_tokens) AS total_input_tokens, + SUM(model_output_tokens) AS total_output_tokens, + jsonb_agg(jsonb_build_object( + 'model', model, + 'total_cost', model_cost, + 'total_input_tokens', model_input_tokens, + 'total_output_tokens', model_output_tokens + )) AS model_details + FROM + SpendByModelApiKey + GROUP BY + api_key + ORDER BY + total_cost DESC; + """ + db_response = await prisma_client.db.query_raw( + sql_query, start_date_obj, end_date_obj + ) + if db_response is None: + return [] + + return db_response except Exception as e: raise HTTPException( diff --git a/litellm/proxy/tests/test_pass_through_langfuse.py b/litellm/proxy/tests/test_pass_through_langfuse.py new file mode 100644 index 0000000000..dfc91ee1b1 --- /dev/null +++ b/litellm/proxy/tests/test_pass_through_langfuse.py @@ -0,0 +1,14 @@ +from langfuse import Langfuse + +langfuse = Langfuse( + host="http://localhost:4000", + public_key="anything", + secret_key="anything", +) + +print("sending langfuse trace request") +trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough") +print("flushing langfuse request") +langfuse.flush() + +print("flushed langfuse request") diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 96aeb4a816..179d094667 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -272,6 +272,16 @@ class ProxyLogging: callback_list=callback_list ) + async def update_request_status( + self, litellm_call_id: str, status: Literal["success", "fail"] + ): + await self.internal_usage_cache.async_set_cache( + key="request_status:{}".format(litellm_call_id), + value=status, + local_only=True, + ttl=3600, + ) + # The actual implementation of the function async def pre_call_hook( self, @@ -560,6 +570,9 @@ class ProxyLogging: """ ### ALERTING ### + await self.update_request_status( + litellm_call_id=request_data.get("litellm_call_id", ""), status="fail" + ) if "llm_exceptions" in self.alert_types and not isinstance( original_exception, HTTPException ): @@ -611,6 +624,7 @@ class ProxyLogging: Covers: 1. 
/chat/completions """ + for callback in litellm.callbacks: try: _callback: Optional[CustomLogger] = None diff --git a/litellm/router.py b/litellm/router.py index ec8cd09e9d..7839d0431a 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -106,7 +106,9 @@ class Router: def __init__( self, - model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None, + model_list: Optional[ + Union[List[DeploymentTypedDict], List[Dict[str, Any]]] + ] = None, ## ASSISTANTS API ## assistants_config: Optional[AssistantsTypedDict] = None, ## CACHING ## @@ -155,6 +157,7 @@ class Router: cooldown_time: Optional[ float ] = None, # (seconds) time to cooldown a deployment after failure + disable_cooldowns: Optional[bool] = None, routing_strategy: Literal[ "simple-shuffle", "least-busy", @@ -306,6 +309,7 @@ class Router: self.allowed_fails = allowed_fails or litellm.allowed_fails self.cooldown_time = cooldown_time or 60 + self.disable_cooldowns = disable_cooldowns self.failed_calls = ( InMemoryCache() ) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown @@ -2989,6 +2993,8 @@ class Router: the exception is not one that should be immediately retried (e.g. 401) """ + if self.disable_cooldowns is True: + return if deployment is None: return @@ -3029,24 +3035,50 @@ class Router: exception_status = 500 _should_retry = litellm._should_retry(status_code=exception_status) - if updated_fails > allowed_fails or _should_retry == False: + if updated_fails > allowed_fails or _should_retry is False: # get the current cooldown list for that minute cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls - cached_value = self.cache.get_cache(key=cooldown_key) + cached_value = self.cache.get_cache( + key=cooldown_key + ) # [(deployment_id, {last_error_str, last_error_status_code})] + cached_value_deployment_ids = [] + if ( + cached_value is not None + and isinstance(cached_value, list) + and len(cached_value) > 0 + and isinstance(cached_value[0], tuple) + ): + cached_value_deployment_ids = [cv[0] for cv in cached_value] verbose_router_logger.debug(f"adding {deployment} to cooldown models") # update value - try: - if deployment in cached_value: + if cached_value is not None and len(cached_value_deployment_ids) > 0: + if deployment in cached_value_deployment_ids: pass else: - cached_value = cached_value + [deployment] + cached_value = cached_value + [ + ( + deployment, + { + "Exception Received": str(original_exception), + "Status Code": str(exception_status), + }, + ) + ] # save updated value self.cache.set_cache( value=cached_value, key=cooldown_key, ttl=cooldown_time ) - except: - cached_value = [deployment] + else: + cached_value = [ + ( + deployment, + { + "Exception Received": str(original_exception), + "Status Code": str(exception_status), + }, + ) + ] # save updated value self.cache.set_cache( value=cached_value, key=cooldown_key, ttl=cooldown_time @@ -3062,7 +3094,33 @@ class Router: key=deployment, value=updated_fails, ttl=cooldown_time ) - async def _async_get_cooldown_deployments(self): + async def _async_get_cooldown_deployments(self) -> List[str]: + """ + Async implementation of '_get_cooldown_deployments' + """ + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + # get the current cooldown list for that minute + cooldown_key = f"{current_minute}:cooldown_models" + + # ---------------------- + # Return cooldown models + # ---------------------- + cooldown_models = 
await self.cache.async_get_cache(key=cooldown_key) or [] + + cached_value_deployment_ids = [] + if ( + cooldown_models is not None + and isinstance(cooldown_models, list) + and len(cooldown_models) > 0 + and isinstance(cooldown_models[0], tuple) + ): + cached_value_deployment_ids = [cv[0] for cv in cooldown_models] + + verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") + return cached_value_deployment_ids + + async def _async_get_cooldown_deployments_with_debug_info(self) -> List[tuple]: """ Async implementation of '_get_cooldown_deployments' """ @@ -3079,7 +3137,7 @@ class Router: verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") return cooldown_models - def _get_cooldown_deployments(self): + def _get_cooldown_deployments(self) -> List[str]: """ Get the list of models being cooled down for this minute """ @@ -3093,8 +3151,17 @@ class Router: # ---------------------- cooldown_models = self.cache.get_cache(key=cooldown_key) or [] + cached_value_deployment_ids = [] + if ( + cooldown_models is not None + and isinstance(cooldown_models, list) + and len(cooldown_models) > 0 + and isinstance(cooldown_models[0], tuple) + ): + cached_value_deployment_ids = [cv[0] for cv in cooldown_models] + verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") - return cooldown_models + return cached_value_deployment_ids def _get_healthy_deployments(self, model: str): _all_deployments: list = [] @@ -3969,16 +4036,36 @@ class Router: Augment litellm info with additional params set in `model_info`. + For azure models, ignore the `model:`. Only set max tokens, cost values if base_model is set. + Returns - ModelInfo - If found -> typed dict with max tokens, input cost, etc. + + Raises: + - ValueError -> If model is not mapped yet """ - ## SET MODEL NAME + ## GET BASE MODEL base_model = deployment.get("model_info", {}).get("base_model", None) if base_model is None: base_model = deployment.get("litellm_params", {}).get("base_model", None) - model = base_model or deployment.get("litellm_params", {}).get("model", None) - ## GET LITELLM MODEL INFO + model = base_model + + ## GET PROVIDER + _model, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=deployment.get("litellm_params", {}).get("model", ""), + litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})), + ) + + ## SET MODEL TO 'model=' - if base_model is None + not azure + if custom_llm_provider == "azure" and base_model is None: + verbose_router_logger.error( + "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models" + ) + elif custom_llm_provider != "azure": + model = _model + + ## GET LITELLM MODEL INFO - raises exception, if model is not mapped model_info = litellm.get_model_info(model=model) ## CHECK USER SET MODEL INFO @@ -4364,7 +4451,7 @@ class Router: """ Filter out model in model group, if: - - model context window < message length + - model context window < message length. For azure openai models, requires 'base_model' is set. 
- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models - filter models above rpm limits - if region given, filter out models not in that region / unknown region - [TODO] function call and model doesn't support function calling @@ -4381,6 +4468,11 @@ class Router: try: input_tokens = litellm.token_counter(messages=messages) except Exception as e: + verbose_router_logger.error( + "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format( + str(e) + ) + ) return _returned_deployments _context_window_error = False @@ -4424,7 +4516,7 @@ class Router: ) continue except Exception as e: - verbose_router_logger.debug("An error occurs - {}".format(str(e))) + verbose_router_logger.error("An error occurs - {}".format(str(e))) _litellm_params = deployment.get("litellm_params", {}) model_id = deployment.get("model_info", {}).get("id", "") @@ -4685,7 +4777,7 @@ class Router: if _allowed_model_region is None: _allowed_model_region = "n/a" raise ValueError( - f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}" + f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}" ) if ( diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index c9e5501a8c..c4705325b9 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -329,11 +329,14 @@ def test_vertex_ai(): "code-gecko@001", "code-gecko@002", "code-gecko@latest", + "codechat-bison@latest", "code-bison@001", "text-bison@001", "gemini-1.5-pro", "gemini-1.5-pro-preview-0215", - ]: + ] or ( + "gecko" in model or "32k" in model or "ultra" in model or "002" in model + ): # our account does not have access to this model continue print("making request", model) @@ -381,12 +384,15 @@ def test_vertex_ai_stream(): "code-gecko@001", "code-gecko@002", "code-gecko@latest", + "codechat-bison@latest", "code-bison@001", "text-bison@001", "gemini-1.5-pro", "gemini-1.5-pro-preview-0215", - ]: - # ouraccount does not have access to this model + ] or ( + "gecko" in model or "32k" in model or "ultra" in model or "002" in model + ): + # our account does not have access to this model continue print("making request", model) response = completion( @@ -433,11 +439,12 @@ async def test_async_vertexai_response(): "code-gecko@001", "code-gecko@002", "code-gecko@latest", + "codechat-bison@latest", "code-bison@001", "text-bison@001", "gemini-1.5-pro", "gemini-1.5-pro-preview-0215", - ]: + ] or ("gecko" in model or "32k" in model or "ultra" in model or "002" in model): # our account does not have access to this model continue try: @@ -479,11 +486,12 @@ async def test_async_vertexai_streaming_response(): "code-gecko@001", "code-gecko@002", "code-gecko@latest", + "codechat-bison@latest", "code-bison@001", "text-bison@001", "gemini-1.5-pro", "gemini-1.5-pro-preview-0215", - ]: + ] or ("gecko" in model or "32k" in model or "ultra" in model or "002" in model): # our account does not have access to this model continue try: @@ -872,6 +880,208 @@ Using this JSON schema: mock_call.assert_called_once() +def 
vertex_httpx_mock_post_valid_response(*args, **kwargs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": '[{"recipe_name": "Chocolate Chip Cookies"}, {"recipe_name": "Oatmeal Raisin Cookies"}, {"recipe_name": "Peanut Butter Cookies"}, {"recipe_name": "Sugar Cookies"}, {"recipe_name": "Snickerdoodles"}]\n' + } + ], + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.09790669, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.11736965, + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.1261379, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08601588, + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.083441176, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.0355444, + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.071981624, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08108212, + }, + ], + } + ], + "usageMetadata": { + "promptTokenCount": 60, + "candidatesTokenCount": 55, + "totalTokenCount": 115, + }, + } + return mock_response + + +def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + {"text": '[{"recipe_world": "Chocolate Chip Cookies"}]\n'} + ], + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.09790669, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.11736965, + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.1261379, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08601588, + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.083441176, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.0355444, + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.071981624, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08108212, + }, + ], + } + ], + "usageMetadata": { + "promptTokenCount": 60, + "candidatesTokenCount": 55, + "totalTokenCount": 115, + }, + } + return mock_response + + +@pytest.mark.parametrize( + "model, vertex_location, supports_response_schema", + [ + ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True), + ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False), + ], +) +@pytest.mark.parametrize( + "invalid_response", + [True, False], +) +@pytest.mark.parametrize( + "enforce_validation", + [True, False], +) +@pytest.mark.asyncio +async def test_gemini_pro_json_schema_args_sent_httpx( + model, + supports_response_schema, + vertex_location, + invalid_response, + enforce_validation, +): + load_vertex_ai_credentials() + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + + litellm.set_verbose = True + 
messages = [{"role": "user", "content": "List 5 cookie recipes"}] + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + response_schema = { + "type": "array", + "items": { + "type": "object", + "properties": { + "recipe_name": { + "type": "string", + }, + }, + "required": ["recipe_name"], + }, + } + + client = HTTPHandler() + + httpx_response = MagicMock() + if invalid_response is True: + httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response + else: + httpx_response.side_effect = vertex_httpx_mock_post_valid_response + with patch.object(client, "post", new=httpx_response) as mock_call: + try: + _ = completion( + model=model, + messages=messages, + response_format={ + "type": "json_object", + "response_schema": response_schema, + "enforce_validation": enforce_validation, + }, + vertex_location=vertex_location, + client=client, + ) + if invalid_response is True and enforce_validation is True: + pytest.fail("Expected this to fail") + except litellm.JSONSchemaValidationError as e: + if invalid_response is False and "claude-3" not in model: + pytest.fail("Expected this to pass. Got={}".format(e)) + + mock_call.assert_called_once() + print(mock_call.call_args.kwargs) + print(mock_call.call_args.kwargs["json"]["generationConfig"]) + + if supports_response_schema: + assert ( + "response_schema" + in mock_call.call_args.kwargs["json"]["generationConfig"] + ) + else: + assert ( + "response_schema" + not in mock_call.call_args.kwargs["json"]["generationConfig"] + ) + assert ( + "Use this JSON schema:" + in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"] + ) + + @pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", @pytest.mark.asyncio async def test_gemini_pro_httpx_custom_api_base(provider): diff --git a/litellm/tests/test_audio_speech.py b/litellm/tests/test_audio_speech.py index dde196d9cc..285334f7ef 100644 --- a/litellm/tests/test_audio_speech.py +++ b/litellm/tests/test_audio_speech.py @@ -1,8 +1,14 @@ # What is this? 
## unit tests for openai tts endpoint -import sys, os, asyncio, time, random, uuid +import asyncio +import os +import random +import sys +import time import traceback +import uuid + from dotenv import load_dotenv load_dotenv() @@ -11,23 +17,40 @@ import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -import pytest -import litellm, openai from pathlib import Path +import openai +import pytest -@pytest.mark.parametrize("sync_mode", [True, False]) +import litellm + + +@pytest.mark.parametrize( + "sync_mode", + [True, False], +) +@pytest.mark.parametrize( + "model, api_key, api_base", + [ + ( + "azure/azure-tts", + os.getenv("AZURE_SWEDEN_API_KEY"), + os.getenv("AZURE_SWEDEN_API_BASE"), + ), + ("openai/tts-1", os.getenv("OPENAI_API_KEY"), None), + ], +) # , @pytest.mark.asyncio -async def test_audio_speech_litellm(sync_mode): +async def test_audio_speech_litellm(sync_mode, model, api_base, api_key): speech_file_path = Path(__file__).parent / "speech.mp3" if sync_mode: response = litellm.speech( - model="openai/tts-1", + model=model, voice="alloy", input="the quick brown fox jumped over the lazy dogs", - api_base=None, - api_key=None, + api_base=api_base, + api_key=api_key, organization=None, project=None, max_retries=1, @@ -41,11 +64,11 @@ async def test_audio_speech_litellm(sync_mode): assert isinstance(response, HttpxBinaryResponseContent) else: response = await litellm.aspeech( - model="openai/tts-1", + model=model, voice="alloy", input="the quick brown fox jumped over the lazy dogs", - api_base=None, - api_key=None, + api_base=api_base, + api_key=api_key, organization=None, project=None, max_retries=1, diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 24eefceeff..fb4ba7556b 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -25,6 +25,7 @@ from litellm import ( completion_cost, embedding, ) +from litellm.llms.bedrock_httpx import BedrockLLM from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler # litellm.num_retries = 3 @@ -217,6 +218,234 @@ def test_completion_bedrock_claude_sts_client_auth(): pytest.fail(f"Error occurred: {e}") +@pytest.fixture() +def bedrock_session_token_creds(): + print("\ncalling oidc auto to get aws_session_token credentials") + import os + + aws_region_name = os.environ["AWS_REGION_NAME"] + aws_session_token = os.environ.get("AWS_SESSION_TOKEN") + + bllm = BedrockLLM() + if aws_session_token is not None: + # For local testing + creds = bllm.get_credentials( + aws_region_name=aws_region_name, + aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + aws_session_token=aws_session_token, + ) + else: + # For circle-ci testing + # aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"] + # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually + aws_role_name = ( + "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci" + ) + aws_web_identity_token = "oidc/circleci_v2/" + + creds = bllm.get_credentials( + aws_region_name=aws_region_name, + aws_web_identity_token=aws_web_identity_token, + aws_role_name=aws_role_name, + aws_session_name="my-test-session", + ) + return creds + + +def process_stream_response(res, messages): + import types + + if isinstance(res, litellm.utils.CustomStreamWrapper): + chunks = [] + for part in res: + chunks.append(part) + text = part.choices[0].delta.content or "" + 
print(text, end="") + res = litellm.stream_chunk_builder(chunks, messages=messages) + else: + raise ValueError("Response object is not a streaming response") + + return res + + +@pytest.mark.skipif( + os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None, + reason="Cannot run without being in CircleCI Runner", +) +def test_completion_bedrock_claude_aws_session_token(bedrock_session_token_creds): + print("\ncalling bedrock claude with aws_session_token auth") + + import os + + aws_region_name = os.environ["AWS_REGION_NAME"] + aws_access_key_id = bedrock_session_token_creds.access_key + aws_secret_access_key = bedrock_session_token_creds.secret_key + aws_session_token = bedrock_session_token_creds.token + + try: + litellm.set_verbose = True + + response_1 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=10, + temperature=0.1, + aws_region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + ) + print(response_1) + assert len(response_1.choices) > 0 + assert len(response_1.choices[0].message.content) > 0 + + # This second call is to verify that the cache isn't breaking anything + response_2 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=5, + temperature=0.2, + aws_region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + ) + print(response_2) + assert len(response_2.choices) > 0 + assert len(response_2.choices[0].message.content) > 0 + + # This third call is to verify that the cache isn't used for a different region + response_3 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=6, + temperature=0.3, + aws_region_name="us-east-1", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + ) + print(response_3) + assert len(response_3.choices) > 0 + assert len(response_3.choices[0].message.content) > 0 + + # This fourth call is to verify streaming api works + response_4 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=6, + temperature=0.3, + aws_region_name="us-east-1", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + stream=True, + ) + response_4 = process_stream_response(response_4, messages) + print(response_4) + assert len(response_4.choices) > 0 + assert len(response_4.choices[0].message.content) > 0 + + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +@pytest.mark.skipif( + os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None, + reason="Cannot run without being in CircleCI Runner", +) +def test_completion_bedrock_claude_aws_bedrock_client(bedrock_session_token_creds): + print("\ncalling bedrock claude with aws_session_token auth") + + import os + + import boto3 + from botocore.client import Config + + aws_region_name = os.environ["AWS_REGION_NAME"] + aws_access_key_id = bedrock_session_token_creds.access_key + aws_secret_access_key = bedrock_session_token_creds.secret_key + aws_session_token = bedrock_session_token_creds.token + + aws_bedrock_client_west = boto3.client( + service_name="bedrock-runtime", + region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + 
aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + config=Config(read_timeout=600), + ) + + try: + litellm.set_verbose = True + + response_1 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=10, + temperature=0.1, + aws_bedrock_client=aws_bedrock_client_west, + ) + print(response_1) + assert len(response_1.choices) > 0 + assert len(response_1.choices[0].message.content) > 0 + + # This second call is to verify that the cache isn't breaking anything + response_2 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=5, + temperature=0.2, + aws_bedrock_client=aws_bedrock_client_west, + ) + print(response_2) + assert len(response_2.choices) > 0 + assert len(response_2.choices[0].message.content) > 0 + + # This third call is to verify that the cache isn't used for a different region + aws_bedrock_client_east = boto3.client( + service_name="bedrock-runtime", + region_name="us-east-1", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + config=Config(read_timeout=600), + ) + + response_3 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=6, + temperature=0.3, + aws_bedrock_client=aws_bedrock_client_east, + ) + print(response_3) + assert len(response_3.choices) > 0 + assert len(response_3.choices[0].message.content) > 0 + + # This fourth call is to verify streaming api works + response_4 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=6, + temperature=0.3, + aws_bedrock_client=aws_bedrock_client_east, + stream=True, + ) + response_4 = process_stream_response(response_4, messages) + print(response_4) + assert len(response_4.choices) > 0 + assert len(response_4.choices[0].message.content) > 0 + + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + # test_completion_bedrock_claude_sts_client_auth() @@ -489,61 +718,6 @@ def test_completion_claude_3_base64(): pytest.fail(f"An exception occurred - {str(e)}") -def test_provisioned_throughput(): - try: - litellm.set_verbose = True - import io - import json - - import botocore - import botocore.session - from botocore.stub import Stubber - - bedrock_client = botocore.session.get_session().create_client( - "bedrock-runtime", region_name="us-east-1" - ) - - expected_params = { - "accept": "application/json", - "body": '{"prompt": "\\n\\nHuman: Hello, how are you?\\n\\nAssistant: ", ' - '"max_tokens_to_sample": 256}', - "contentType": "application/json", - "modelId": "provisioned-model-arn", - } - response_from_bedrock = { - "body": io.StringIO( - json.dumps( - { - "completion": " Here is a short poem about the sky:", - "stop_reason": "max_tokens", - "stop": None, - } - ) - ), - "contentType": "contentType", - "ResponseMetadata": {"HTTPStatusCode": 200}, - } - - with Stubber(bedrock_client) as stubber: - stubber.add_response( - "invoke_model", - service_response=response_from_bedrock, - expected_params=expected_params, - ) - response = litellm.completion( - model="bedrock/anthropic.claude-instant-v1", - model_id="provisioned-model-arn", - messages=[{"content": "Hello, how are you?", "role": "user"}], - aws_bedrock_client=bedrock_client, - ) - print("response stubbed", response) - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - -# test_provisioned_throughput() - - def 
test_completion_bedrock_mistral_completion_auth(): print("calling bedrock mistral completion params auth") import os @@ -682,3 +856,56 @@ async def test_bedrock_custom_prompt_template(): prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"] assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>" mock_client_post.assert_called_once() + + +def test_completion_bedrock_external_client_region(): + print("\ncalling bedrock claude external client auth") + import os + + aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"] + aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"] + aws_region_name = "us-east-1" + + os.environ.pop("AWS_ACCESS_KEY_ID", None) + os.environ.pop("AWS_SECRET_ACCESS_KEY", None) + + client = HTTPHandler() + + try: + import boto3 + + litellm.set_verbose = True + + bedrock = boto3.client( + service_name="bedrock-runtime", + region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com", + ) + with patch.object(client, "post", new=Mock()) as mock_client_post: + try: + response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=messages, + max_tokens=10, + temperature=0.1, + aws_bedrock_client=bedrock, + client=client, + ) + # Add any assertions here to check the response + print(response) + except Exception as e: + pass + + print(f"mock_client_post.call_args: {mock_client_post.call_args}") + assert "us-east-1" in mock_client_post.call_args.kwargs["url"] + + mock_client_post.assert_called_once() + + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index a3b0e6ea26..1c10ef461e 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -11,7 +11,7 @@ import os sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds-the parent directory to the system path import os from unittest.mock import MagicMock, patch @@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.prompt_templates.factory import anthropic_messages_pt -# litellm.num_retries = 3 +# litellm.num_retries=3 litellm.cache = None litellm.success_callback = [] user_message = "Write a short poem about the sky" @@ -1222,44 +1222,6 @@ def test_completion_fireworks_ai(): pytest.fail(f"Error occurred: {e}") -def test_fireworks_ai_tool_calling(): - litellm.set_verbose = True - model_name = "fireworks_ai/accounts/fireworks/models/firefunction-v2" - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. 
San Francisco, CA", - }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - }, - }, - } - ] - messages = [ - { - "role": "user", - "content": "What's the weather like in Boston today in Fahrenheit?", - } - ] - response = completion( - model=model_name, - messages=messages, - tools=tools, - tool_choice="required", - ) - print(response) - - @pytest.mark.skip(reason="this test is flaky") def test_completion_perplexity_api(): try: @@ -3508,6 +3470,30 @@ def test_completion_deep_infra_mistral(): # test_completion_deep_infra_mistral() +@pytest.mark.skip(reason="Local test - don't have a volcengine account as yet") +def test_completion_volcengine(): + litellm.set_verbose = True + model_name = "volcengine/" + try: + response = completion( + model=model_name, + messages=[ + { + "role": "user", + "content": "What's the weather like in Boston today in Fahrenheit?", + } + ], + api_key="", + ) + # Add any assertions here to check the response + print(response) + + except litellm.exceptions.Timeout as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + def test_completion_nvidia_nim(): model_name = "nvidia_nim/databricks/dbrx-instruct" try: diff --git a/litellm/tests/test_completion_cost.py b/litellm/tests/test_completion_cost.py index e854345b3b..bffb68e0e5 100644 --- a/litellm/tests/test_completion_cost.py +++ b/litellm/tests/test_completion_cost.py @@ -4,7 +4,9 @@ import traceback import litellm.cost_calculator -sys.path.insert(0, os.path.abspath("../..")) # Adds the parent directory to the system path +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path import asyncio import time from typing import Optional @@ -167,11 +169,15 @@ def test_cost_ft_gpt_35(): input_cost = model_cost["ft:gpt-3.5-turbo"]["input_cost_per_token"] output_cost = model_cost["ft:gpt-3.5-turbo"]["output_cost_per_token"] print(input_cost, output_cost) - expected_cost = (input_cost * resp.usage.prompt_tokens) + (output_cost * resp.usage.completion_tokens) + expected_cost = (input_cost * resp.usage.prompt_tokens) + ( + output_cost * resp.usage.completion_tokens + ) print("\n Excpected cost", expected_cost) assert cost == expected_cost except Exception as e: - pytest.fail(f"Cost Calc failed for ft:gpt-3.5. Expected {expected_cost}, Calculated cost {cost}") + pytest.fail( + f"Cost Calc failed for ft:gpt-3.5. Expected {expected_cost}, Calculated cost {cost}" + ) # test_cost_ft_gpt_35() @@ -200,15 +206,21 @@ def test_cost_azure_gpt_35(): usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38), ) - cost = litellm.completion_cost(completion_response=resp, model="azure/gpt-35-turbo") + cost = litellm.completion_cost( + completion_response=resp, model="azure/gpt-35-turbo" + ) print("\n Calculated Cost for azure/gpt-3.5-turbo", cost) input_cost = model_cost["azure/gpt-35-turbo"]["input_cost_per_token"] output_cost = model_cost["azure/gpt-35-turbo"]["output_cost_per_token"] - expected_cost = (input_cost * resp.usage.prompt_tokens) + (output_cost * resp.usage.completion_tokens) + expected_cost = (input_cost * resp.usage.prompt_tokens) + ( + output_cost * resp.usage.completion_tokens + ) print("\n Excpected cost", expected_cost) assert cost == expected_cost except Exception as e: - pytest.fail(f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}") + pytest.fail( + f"Cost Calc failed for azure/gpt-3.5-turbo. 
Expected {expected_cost}, Calculated cost {cost}" + ) # test_cost_azure_gpt_35() @@ -239,7 +251,9 @@ def test_cost_azure_embedding(): assert cost == expected_cost except Exception as e: - pytest.fail(f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}") + pytest.fail( + f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}" + ) # test_cost_azure_embedding() @@ -315,7 +329,9 @@ def test_cost_bedrock_pricing_actual_calls(): litellm.set_verbose = True model = "anthropic.claude-instant-v1" messages = [{"role": "user", "content": "Hey, how's it going?"}] - response = litellm.completion(model=model, messages=messages, mock_response="hello cool one") + response = litellm.completion( + model=model, messages=messages, mock_response="hello cool one" + ) print("response", response) cost = litellm.completion_cost( @@ -345,7 +361,8 @@ def test_whisper_openai(): print(f"cost: {cost}") print(f"whisper dict: {litellm.model_cost['whisper-1']}") expected_cost = round( - litellm.model_cost["whisper-1"]["output_cost_per_second"] * _total_time_in_seconds, + litellm.model_cost["whisper-1"]["output_cost_per_second"] + * _total_time_in_seconds, 5, ) assert cost == expected_cost @@ -365,12 +382,15 @@ def test_whisper_azure(): _total_time_in_seconds = 3 transcription._response_ms = _total_time_in_seconds * 1000 - cost = litellm.completion_cost(model="azure/azure-whisper", completion_response=transcription) + cost = litellm.completion_cost( + model="azure/azure-whisper", completion_response=transcription + ) print(f"cost: {cost}") print(f"whisper dict: {litellm.model_cost['whisper-1']}") expected_cost = round( - litellm.model_cost["whisper-1"]["output_cost_per_second"] * _total_time_in_seconds, + litellm.model_cost["whisper-1"]["output_cost_per_second"] + * _total_time_in_seconds, 5, ) assert cost == expected_cost @@ -401,7 +421,9 @@ def test_dalle_3_azure_cost_tracking(): response.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} response._hidden_params = {"model": "dall-e-3", "model_id": None} print(f"response hidden params: {response._hidden_params}") - cost = litellm.completion_cost(completion_response=response, call_type="image_generation") + cost = litellm.completion_cost( + completion_response=response, call_type="image_generation" + ) assert cost > 0 @@ -433,7 +455,9 @@ def test_replicate_llama3_cost_tracking(): model="replicate/meta/meta-llama-3-8b-instruct", object="chat.completion", system_fingerprint=None, - usage=litellm.utils.Usage(prompt_tokens=48, completion_tokens=31, total_tokens=79), + usage=litellm.utils.Usage( + prompt_tokens=48, completion_tokens=31, total_tokens=79 + ), ) cost = litellm.completion_cost( completion_response=response, @@ -443,8 +467,14 @@ def test_replicate_llama3_cost_tracking(): print(f"cost: {cost}") cost = round(cost, 5) expected_cost = round( - litellm.model_cost["replicate/meta/meta-llama-3-8b-instruct"]["input_cost_per_token"] * 48 - + litellm.model_cost["replicate/meta/meta-llama-3-8b-instruct"]["output_cost_per_token"] * 31, + litellm.model_cost["replicate/meta/meta-llama-3-8b-instruct"][ + "input_cost_per_token" + ] + * 48 + + litellm.model_cost["replicate/meta/meta-llama-3-8b-instruct"][ + "output_cost_per_token" + ] + * 31, 5, ) assert cost == expected_cost @@ -538,7 +568,9 @@ def test_together_ai_qwen_completion_cost(): "custom_cost_per_second": None, } - response = litellm.cost_calculator.get_model_params_and_category(model_name="qwen/Qwen2-72B-Instruct") + 
response = litellm.cost_calculator.get_model_params_and_category( + model_name="qwen/Qwen2-72B-Instruct" + ) assert response == "together-ai-41.1b-80b" @@ -576,8 +608,12 @@ def test_gemini_completion_cost(above_128k, provider): ), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format( model_name, model_info ) - input_cost = prompt_tokens * model_info["input_cost_per_token_above_128k_tokens"] - output_cost = output_tokens * model_info["output_cost_per_token_above_128k_tokens"] + input_cost = ( + prompt_tokens * model_info["input_cost_per_token_above_128k_tokens"] + ) + output_cost = ( + output_tokens * model_info["output_cost_per_token_above_128k_tokens"] + ) else: input_cost = prompt_tokens * model_info["input_cost_per_token"] output_cost = output_tokens * model_info["output_cost_per_token"] @@ -674,3 +710,32 @@ def test_vertex_ai_claude_completion_cost(): ) predicted_cost = input_tokens * 0.000003 + 0.000015 * output_tokens assert cost == predicted_cost + + + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_completion_cost_hidden_params(sync_mode): + if sync_mode: + response = litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + mock_response="Hello world", + ) + else: + response = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + mock_response="Hello world", + ) + + assert "response_cost" in response._hidden_params + assert isinstance(response._hidden_params["response_cost"], float) + +def test_vertex_ai_gemini_predict_cost(): + model = "gemini-1.5-flash" + messages = [{"role": "user", "content": "Hey, hows it going???"}] + predictive_cost = completion_cost(model=model, messages=messages) + + assert predictive_cost > 0 + diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py index 3d8cb3c2a3..fb390bb488 100644 --- a/litellm/tests/test_exceptions.py +++ b/litellm/tests/test_exceptions.py @@ -249,6 +249,25 @@ def test_completion_azure_exception(): # test_completion_azure_exception() +def test_azure_embedding_exceptions(): + try: + + response = litellm.embedding( + model="azure/azure-embedding-model", + input="hello", + messages="hello", + ) + pytest.fail(f"Bad request this should have failed but got {response}") + + except Exception as e: + print(vars(e)) + # CRUCIAL Test - Ensures our exceptions are readable and not overly complicated. some users have complained exceptions will randomly have another exception raised in our exception mapping + assert ( + e.message + == "litellm.APIError: AzureException APIError - Embeddings.create() got an unexpected keyword argument 'messages'" + ) + + async def asynctest_completion_azure_exception(): try: import openai diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 49ec18f24c..67857b8c86 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -1,20 +1,23 @@ # What this tests? 
## This tests the litellm support for the openai /generations endpoint -import sys, os -import traceback -from dotenv import load_dotenv import logging +import os +import sys +import traceback + +from dotenv import load_dotenv logging.basicConfig(level=logging.DEBUG) load_dotenv() -import os import asyncio +import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import pytest + import litellm @@ -39,13 +42,25 @@ def test_image_generation_openai(): # test_image_generation_openai() -def test_image_generation_azure(): +@pytest.mark.parametrize( + "sync_mode", + [True, False], +) # +@pytest.mark.asyncio +async def test_image_generation_azure(sync_mode): try: - response = litellm.image_generation( - prompt="A cute baby sea otter", - model="azure/", - api_version="2023-06-01-preview", - ) + if sync_mode: + response = litellm.image_generation( + prompt="A cute baby sea otter", + model="azure/", + api_version="2023-06-01-preview", + ) + else: + response = await litellm.aimage_generation( + prompt="A cute baby sea otter", + model="azure/", + api_version="2023-06-01-preview", + ) print(f"response: {response}") assert len(response.data) > 0 except litellm.RateLimitError as e: diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 72d4d7b1bf..e287946ae4 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -61,7 +61,6 @@ async def test_token_single_public_key(): import jwt jwt_handler = JWTHandler() - backend_keys = { "keys": [ { diff --git a/litellm/tests/test_lakera_ai_prompt_injection.py b/litellm/tests/test_lakera_ai_prompt_injection.py index 6227eabaa3..3e328c8244 100644 --- a/litellm/tests/test_lakera_ai_prompt_injection.py +++ b/litellm/tests/test_lakera_ai_prompt_injection.py @@ -1,10 +1,16 @@ # What is this? 
## This tests the Lakera AI integration -import sys, os, asyncio, time, random -from datetime import datetime +import asyncio +import os +import random +import sys +import time import traceback +from datetime import datetime + from dotenv import load_dotenv +from fastapi import HTTPException load_dotenv() import os @@ -12,17 +18,19 @@ import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +import logging + import pytest + import litellm +from litellm import Router, mock_completion +from litellm._logging import verbose_proxy_logger +from litellm.caching import DualCache +from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import ( _ENTERPRISE_lakeraAI_Moderation, ) -from litellm import Router, mock_completion from litellm.proxy.utils import ProxyLogging, hash_token -from litellm.proxy._types import UserAPIKeyAuth -from litellm.caching import DualCache -from litellm._logging import verbose_proxy_logger -import logging verbose_proxy_logger.setLevel(logging.DEBUG) @@ -55,10 +63,12 @@ async def test_lakera_prompt_injection_detection(): call_type="completion", ) pytest.fail(f"Should have failed") - except Exception as e: - print("Got exception: ", e) - assert "Violated content safety policy" in str(e) - pass + except HTTPException as http_exception: + print("http exception details=", http_exception.detail) + + # Assert that the laker ai response is in the exception raise + assert "lakera_ai_response" in http_exception.detail + assert "Violated content safety policy" in str(http_exception) @pytest.mark.asyncio diff --git a/litellm/tests/test_pass_through_endpoints.py b/litellm/tests/test_pass_through_endpoints.py new file mode 100644 index 0000000000..0f234dfa8b --- /dev/null +++ b/litellm/tests/test_pass_through_endpoints.py @@ -0,0 +1,85 @@ +import os +import sys + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds-the parent directory to the system path + +import asyncio + +import httpx + +from litellm.proxy.proxy_server import app, initialize_pass_through_endpoints + + +# Mock the async_client used in the pass_through_request function +async def mock_request(*args, **kwargs): + return httpx.Response(200, json={"message": "Mocked response"}) + + +@pytest.fixture +def client(): + return TestClient(app) + + +@pytest.mark.asyncio +async def test_pass_through_endpoint(client, monkeypatch): + # Mock the httpx.AsyncClient.request method + monkeypatch.setattr("httpx.AsyncClient.request", mock_request) + + # Define a pass-through endpoint + pass_through_endpoints = [ + { + "path": "/test-endpoint", + "target": "https://api.example.com/v1/chat/completions", + "headers": {"Authorization": "Bearer test-token"}, + } + ] + + # Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + + # Make a request to the pass-through endpoint + response = client.post("/test-endpoint", json={"prompt": "Hello, world!"}) + + # Assert the response + assert response.status_code == 200 + assert response.json() == {"message": "Mocked response"} + + +@pytest.mark.asyncio +async def test_pass_through_endpoint_rerank(client): + _cohere_api_key = os.environ.get("COHERE_API_KEY") + + # Define a pass-through endpoint + pass_through_endpoints = [ + { + "path": "/v1/rerank", + "target": "https://api.cohere.com/v1/rerank", + "headers": {"Authorization": f"bearer {_cohere_api_key}"}, + } + ] + 
+ # Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + + _json_data = { + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": [ + "Carson City is the capital city of the American state of Nevada." + ], + } + + # Make a request to the pass-through endpoint + response = client.post("/v1/rerank", json=_json_data) + + print("JSON response: ", _json_data) + + # Assert the response + assert response.status_code == 200 diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py index b3aafab6e6..5a368f92d3 100644 --- a/litellm/tests/test_prompt_factory.py +++ b/litellm/tests/test_prompt_factory.py @@ -1,7 +1,8 @@ #### What this tests #### # This tests if prompts are being correctly formatted -import sys import os +import sys + import pytest sys.path.insert(0, os.path.abspath("../..")) @@ -10,12 +11,13 @@ sys.path.insert(0, os.path.abspath("../..")) import litellm from litellm import completion from litellm.llms.prompt_templates.factory import ( - anthropic_pt, + _bedrock_tools_pt, anthropic_messages_pt, + anthropic_pt, claude_2_1_pt, + convert_url_to_base64, llama_2_chat_pt, prompt_factory, - _bedrock_tools_pt, ) @@ -153,3 +155,11 @@ def test_bedrock_tool_calling_pt(): converted_tools = _bedrock_tools_pt(tools=tools) print(converted_tools) + + +def test_convert_url_to_img(): + response_url = convert_url_to_base64( + url="https://images.pexels.com/photos/1319515/pexels-photo-1319515.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=1" + ) + + assert "image/jpeg" in response_url diff --git a/litellm/tests/test_proxy_exception_mapping.py b/litellm/tests/test_proxy_exception_mapping.py index 4988426616..4fb1e71349 100644 --- a/litellm/tests/test_proxy_exception_mapping.py +++ b/litellm/tests/test_proxy_exception_mapping.py @@ -1,25 +1,31 @@ # test that the proxy actually does exception mapping to the OpenAI format -import sys, os -from unittest import mock import json +import os +import sys +from unittest import mock + from dotenv import load_dotenv load_dotenv() -import os, io, asyncio +import asyncio +import io +import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +import openai import pytest -import litellm, openai -from fastapi.testclient import TestClient from fastapi import Response -from litellm.proxy.proxy_server import ( +from fastapi.testclient import TestClient + +import litellm +from litellm.proxy.proxy_server import ( # Replace with the actual module where your FastAPI router is defined + initialize, router, save_worker_config, - initialize, -) # Replace with the actual module where your FastAPI router is defined +) invalid_authentication_error_response = Response( status_code=401, @@ -66,6 +72,12 @@ def test_chat_completion_exception(client): json_response = response.json() print("keys in json response", json_response.keys()) assert json_response.keys() == {"error"} + print("ERROR=", json_response["error"]) + assert isinstance(json_response["error"]["message"], str) + assert ( + json_response["error"]["message"] + == "litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys." 
+ ) # make an openai client to call _make_status_error_from_response openai_client = openai.OpenAI(api_key="anything") diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 0bb866f549..ccc68921a5 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -16,6 +16,7 @@ sys.path.insert( import os from collections import defaultdict from concurrent.futures import ThreadPoolExecutor +from unittest.mock import AsyncMock, MagicMock, patch import httpx from dotenv import load_dotenv @@ -811,6 +812,7 @@ def test_router_context_window_check_pre_call_check(): "base_model": "azure/gpt-35-turbo", "mock_response": "Hello world 1!", }, + "model_info": {"base_model": "azure/gpt-35-turbo"}, }, { "model_name": "gpt-3.5-turbo", # openai model name @@ -1886,6 +1888,7 @@ async def test_router_model_usage(mock_response): raise e + @pytest.mark.asyncio async def test_is_proxy_set(): """ @@ -1922,3 +1925,106 @@ async def test_is_proxy_set(): ) # type: ignore assert check_proxy(client=model_client._client) is True + +@pytest.mark.parametrize( + "model, base_model, llm_provider", + [ + ("azure/gpt-4", None, "azure"), + ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"), + ("gpt-4", None, "openai"), + ], +) +def test_router_get_model_info(model, base_model, llm_provider): + """ + Test if router get model info works based on provider + + For azure -> only if base model set + For openai -> use model= + """ + router = Router( + model_list=[ + { + "model_name": "gpt-4", + "litellm_params": { + "model": model, + "api_key": "my-fake-key", + "api_base": "my-fake-base", + }, + "model_info": {"base_model": base_model, "id": "1"}, + } + ] + ) + + deployment = router.get_deployment(model_id="1") + + assert deployment is not None + + if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"): + router.get_router_model_info(deployment=deployment.to_json()) + else: + try: + router.get_router_model_info(deployment=deployment.to_json()) + pytest.fail("Expected this to raise model not mapped error") + except Exception as e: + if "This model isn't mapped yet" in str(e): + pass + + +@pytest.mark.parametrize( + "model, base_model, llm_provider", + [ + ("azure/gpt-4", None, "azure"), + ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"), + ("gpt-4", None, "openai"), + ], +) +def test_router_context_window_pre_call_check(model, base_model, llm_provider): + """ + - For an azure model + - if no base model set + - don't enforce context window limits + """ + try: + model_list = [ + { + "model_name": "gpt-4", + "litellm_params": { + "model": model, + "api_key": "my-fake-key", + "api_base": "my-fake-base", + }, + "model_info": {"base_model": base_model, "id": "1"}, + } + ] + router = Router( + model_list=model_list, + set_verbose=True, + enable_pre_call_checks=True, + num_retries=0, + ) + + litellm.token_counter = MagicMock() + + def token_counter_side_effect(*args, **kwargs): + # Process args and kwargs if needed + return 1000000 + + litellm.token_counter.side_effect = token_counter_side_effect + try: + updated_list = router._pre_call_checks( + model="gpt-4", + healthy_deployments=model_list, + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + if llm_provider == "azure" and base_model is None: + assert len(updated_list) == 1 + else: + pytest.fail("Expected to raise an error. 
Got={}".format(updated_list)) + except Exception as e: + if ( + llm_provider == "azure" and base_model is not None + ) or llm_provider == "openai": + pass + except Exception as e: + pytest.fail(f"Got unexpected exception on router! - {str(e)}") + diff --git a/litellm/tests/test_router_debug_logs.py b/litellm/tests/test_router_debug_logs.py index 09590e5ac6..20eccf5dea 100644 --- a/litellm/tests/test_router_debug_logs.py +++ b/litellm/tests/test_router_debug_logs.py @@ -1,16 +1,23 @@ -import sys, os, time -import traceback, asyncio +import asyncio +import os +import sys +import time +import traceback + import pytest sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -import litellm, asyncio, logging +import asyncio +import logging + +import litellm from litellm import Router # this tests debug logs from litellm router and litellm proxy server -from litellm._logging import verbose_router_logger, verbose_logger, verbose_proxy_logger +from litellm._logging import verbose_logger, verbose_proxy_logger, verbose_router_logger # this tests debug logs from litellm router and litellm proxy server @@ -81,7 +88,7 @@ def test_async_fallbacks(caplog): # Define the expected log messages # - error request, falling back notice, success notice expected_logs = [ - "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception litellm.AuthenticationError: AuthenticationError: OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m", + "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.\x1b[0m", "Falling back to model_group = azure/gpt-3.5-turbo", "litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m", "Successful fallback b/w models.", diff --git a/litellm/tests/test_secret_detect_hook.py b/litellm/tests/test_secret_detect_hook.py index a1bf10ebad..2c20071646 100644 --- a/litellm/tests/test_secret_detect_hook.py +++ b/litellm/tests/test_secret_detect_hook.py @@ -21,15 +21,20 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import pytest +from fastapi import Request, Response +from starlette.datastructures import URL import litellm from litellm import Router, mock_completion from litellm.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.enterprise.enterprise_hooks.secret_detection import ( _ENTERPRISE_SecretDetection, ) +from litellm.proxy.proxy_server import chat_completion from litellm.proxy.utils import ProxyLogging, hash_token +from litellm.router import Router ### UNIT TESTS FOR OpenAI Moderation ### @@ -64,6 +69,10 @@ async def test_basic_secret_detection_chat(): "role": "user", "content": "this is my OPENAI_API_KEY = 'sk_1234567890abcdef'", }, + { + "role": "user", + "content": "My hi API Key is sk-Pc4nlxVoMz41290028TbMCxx, does it seem to be in the correct format?", + }, {"role": "user", "content": "i think it is +1 412-555-5555"}, ], "model": "gpt-3.5-turbo", @@ -88,6 +97,10 @@ async def test_basic_secret_detection_chat(): "content": "Hello! I'm doing well. 
How can I assist you today?", }, {"role": "user", "content": "this is my OPENAI_API_KEY = '[REDACTED]'"}, + { + "role": "user", + "content": "My hi API Key is [REDACTED], does it seem to be in the correct format?", + }, {"role": "user", "content": "i think it is +1 412-555-5555"}, ], "model": "gpt-3.5-turbo", @@ -214,3 +227,82 @@ async def test_basic_secret_detection_embeddings_list(): ], "model": "gpt-3.5-turbo", } + + +class testLogger(CustomLogger): + + def __init__(self): + self.logged_message = None + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Async Success") + + self.logged_message = kwargs.get("messages") + + +router = Router( + model_list=[ + { + "model_name": "fake-model", + "litellm_params": { + "model": "openai/fake", + "api_base": "https://exampleopenaiendpoint-production.up.railway.app/", + "api_key": "sk-12345", + }, + } + ] +) + + +@pytest.mark.asyncio +async def test_chat_completion_request_with_redaction(): + """ + IMPORTANT Enterprise Test - Do not delete it: + Makes a /chat/completions request on LiteLLM Proxy + + Ensures that the secret is redacted EVEN on the callback + """ + from litellm.proxy import proxy_server + + setattr(proxy_server, "llm_router", router) + _test_logger = testLogger() + litellm.callbacks = [_ENTERPRISE_SecretDetection(), _test_logger] + litellm.set_verbose = True + + # Prepare the query string + query_params = "param1=value1¶m2=value2" + + # Create the Request object with query parameters + request = Request( + scope={ + "type": "http", + "method": "POST", + "headers": [(b"content-type", b"application/json")], + "query_string": query_params.encode(), + } + ) + + request._url = URL(url="/chat/completions") + + async def return_body(): + return b'{"model": "fake-model", "messages": [{"role": "user", "content": "Hello here is my OPENAI_API_KEY = sk-12345"}]}' + + request.body = return_body + + response = await chat_completion( + request=request, + user_api_key_dict=UserAPIKeyAuth( + api_key="sk-12345", + token="hashed_sk-12345", + ), + fastapi_response=Response(), + ) + + await asyncio.sleep(3) + + print("Info in callback after running request=", _test_logger.logged_message) + + assert _test_logger.logged_message == [ + {"role": "user", "content": "Hello here is my OPENAI_API_KEY = [REDACTED]"} + ] + pass diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 3042e91b34..1f1b253a06 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -742,7 +742,10 @@ def test_completion_palm_stream(): # test_completion_palm_stream() -@pytest.mark.parametrize("sync_mode", [False]) # True, +@pytest.mark.parametrize( + "sync_mode", + [True, False], +) # , @pytest.mark.asyncio async def test_completion_gemini_stream(sync_mode): try: @@ -807,49 +810,6 @@ async def test_completion_gemini_stream(sync_mode): pytest.fail(f"Error occurred: {e}") -@pytest.mark.asyncio -async def test_acompletion_gemini_stream(): - try: - litellm.set_verbose = True - print("Streaming gemini response") - messages = [ - # {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": "What do you know?", - }, - ] - print("testing gemini streaming") - response = await acompletion( - model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True - ) - print(f"type of response at the top: {response}") - complete_response = "" - idx = 0 - # Add any assertions here to check, the response - async for chunk in response: - print(f"chunk in 
acompletion gemini: {chunk}") - print(chunk.choices[0].delta) - chunk, finished = streaming_format_tests(idx, chunk) - if finished: - break - print(f"chunk: {chunk}") - complete_response += chunk - idx += 1 - print(f"completion_response: {complete_response}") - if complete_response.strip() == "": - raise Exception("Empty response received") - except litellm.APIError as e: - pass - except litellm.RateLimitError as e: - pass - except Exception as e: - if "429 Resource has been exhausted" in str(e): - pass - else: - pytest.fail(f"Error occurred: {e}") - - # asyncio.run(test_acompletion_gemini_stream()) @@ -1071,7 +1031,7 @@ def test_completion_claude_stream_bad_key(): # test_completion_replicate_stream() -@pytest.mark.parametrize("provider", ["vertex_ai"]) # "vertex_ai_beta" +@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "" def test_vertex_ai_stream(provider): from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials @@ -1080,14 +1040,27 @@ def test_vertex_ai_stream(provider): litellm.vertex_project = "adroit-crow-413218" import random - test_models = ["gemini-1.0-pro"] + test_models = ["gemini-1.5-pro"] for model in test_models: try: print("making request", model) response = completion( model="{}/{}".format(provider, model), messages=[ - {"role": "user", "content": "write 10 line code code for saying hi"} + {"role": "user", "content": "Hey, how's it going?"}, + { + "role": "assistant", + "content": "I'm doing well. Would like to hear the rest of the story?", + }, + {"role": "user", "content": "Na"}, + { + "role": "assistant", + "content": "No problem, is there anything else i can help you with today?", + }, + { + "role": "user", + "content": "I think you're getting cut off sometimes", + }, ], stream=True, ) @@ -1104,6 +1077,8 @@ def test_vertex_ai_stream(provider): raise Exception("Empty response received") print(f"completion_response: {complete_response}") assert is_finished == True + + assert False except litellm.RateLimitError as e: pass except Exception as e: @@ -1251,6 +1226,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode): messages=messages, max_tokens=10, # type: ignore stream=True, + num_retries=3, ) complete_response = "" # Add any assertions here to check the response @@ -1272,6 +1248,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode): messages=messages, max_tokens=100, # type: ignore stream=True, + num_retries=3, ) complete_response = "" # Add any assertions here to check the response @@ -1290,6 +1267,8 @@ async def test_completion_replicate_llama3_streaming(sync_mode): raise Exception("finish reason not set") if complete_response.strip() == "": raise Exception("Empty response received") + except litellm.UnprocessableEntityError as e: + pass except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_token_counter.py b/litellm/tests/test_token_counter.py index 2c3eb89fde..e617621315 100644 --- a/litellm/tests/test_token_counter.py +++ b/litellm/tests/test_token_counter.py @@ -1,15 +1,25 @@ #### What this tests #### # This tests litellm.token_counter() function -import sys, os +import os +import sys import traceback + import pytest sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import time -from litellm import token_counter, create_pretrained_tokenizer, encode, decode +from unittest.mock import AsyncMock, MagicMock, patch + +from litellm import ( + create_pretrained_tokenizer, + decode, + encode, + get_modified_max_tokens, + 
token_counter, +) from litellm.tests.large_text import text @@ -227,3 +237,55 @@ def test_openai_token_with_image_and_text(): token_count = token_counter(model=model, messages=messages) print(token_count) + + +@pytest.mark.parametrize( + "model, base_model, input_tokens, user_max_tokens, expected_value", + [ + ("random-model", "random-model", 1024, 1024, 1024), + ("command", "command", 1000000, None, None), # model max = 4096 + ("command", "command", 4000, 256, 96), # model max = 4096 + ("command", "command", 4000, 10, 10), # model max = 4096 + ("gpt-3.5-turbo", "gpt-3.5-turbo", 4000, 5000, 4096), # model max output = 4096 + ], +) +def test_get_modified_max_tokens( + model, base_model, input_tokens, user_max_tokens, expected_value +): + """ + - Test when max_output is not known => expect user_max_tokens + - Test when max_output == max_input, + - input > max_output, no max_tokens => expect None + - input + max_tokens > max_output => expect remainder + - input + max_tokens < max_output => expect max_tokens + - Test when max_tokens > max_output => expect max_output + """ + args = locals() + import litellm + + litellm.token_counter = MagicMock() + + def _mock_token_counter(*args, **kwargs): + return input_tokens + + litellm.token_counter.side_effect = _mock_token_counter + print(f"_mock_token_counter: {_mock_token_counter()}") + messages = [{"role": "user", "content": "Hello world!"}] + + calculated_value = get_modified_max_tokens( + model=model, + base_model=base_model, + messages=messages, + user_max_tokens=user_max_tokens, + buffer_perc=0, + buffer_num=0, + ) + + if expected_value is None: + assert calculated_value is None + else: + assert ( + calculated_value == expected_value + ), "Got={}, Expected={}, Params={}".format( + calculated_value, expected_value, args + ) diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py index 09715e6c16..8225b309dc 100644 --- a/litellm/tests/test_utils.py +++ b/litellm/tests/test_utils.py @@ -609,3 +609,83 @@ def test_logging_trace_id(langfuse_trace_id, langfuse_existing_trace_id): litellm_logging_obj._get_trace_id(service_name="langfuse") == litellm_call_id ) + + +def test_convert_model_response_object(): + """ + Unit test to ensure model response object correctly handles openrouter errors. 
+ """ + args = { + "response_object": { + "id": None, + "choices": None, + "created": None, + "model": None, + "object": None, + "service_tier": None, + "system_fingerprint": None, + "usage": None, + "error": { + "message": '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}', + "code": 400, + }, + }, + "model_response_object": litellm.ModelResponse( + id="chatcmpl-b88ce43a-7bfc-437c-b8cc-e90d59372cfb", + choices=[ + litellm.Choices( + finish_reason="stop", + index=0, + message=litellm.Message(content="default", role="assistant"), + ) + ], + created=1719376241, + model="openrouter/anthropic/claude-3.5-sonnet", + object="chat.completion", + system_fingerprint=None, + usage=litellm.Usage(), + ), + "response_type": "completion", + "stream": False, + "start_time": None, + "end_time": None, + "hidden_params": None, + } + + try: + litellm.convert_to_model_response_object(**args) + pytest.fail("Expected this to fail") + except Exception as e: + assert hasattr(e, "status_code") + assert e.status_code == 400 + assert hasattr(e, "message") + assert ( + e.message + == '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}' + ) + + +@pytest.mark.parametrize( + "model, expected_bool", + [ + ("vertex_ai/gemini-1.5-pro", True), + ("gemini/gemini-1.5-pro", True), + ("predibase/llama3-8b-instruct", True), + ("gpt-4o", False), + ], +) +def test_supports_response_schema(model, expected_bool): + """ + Unit tests for 'supports_response_schema' helper function. + + Should be true for gemini-1.5-pro on google ai studio / vertex ai AND predibase models + Should be false otherwise + """ + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + + from litellm.utils import supports_response_schema + + response = supports_response_schema(model=model, custom_llm_provider=None) + + assert expected_bool == response diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 2dda57c2e9..17fc26d60e 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -155,6 +155,16 @@ class ToolConfig(TypedDict): functionCallingConfig: FunctionCallingConfig +class TTL(TypedDict, total=False): + seconds: Required[float] + nano: float + + +class CachedContent(TypedDict, total=False): + ttl: TTL + expire_time: str + + class RequestBody(TypedDict, total=False): contents: Required[List[ContentType]] system_instruction: SystemInstructions @@ -162,6 +172,7 @@ class RequestBody(TypedDict, total=False): toolConfig: ToolConfig safetySettings: List[SafetSettingsConfig] generationConfig: GenerationConfig + cachedContent: str class SafetyRatings(TypedDict): diff --git a/litellm/types/utils.py b/litellm/types/utils.py index f2b161128c..d6b7bf7442 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -71,6 +71,7 @@ class ModelInfo(TypedDict, total=False): ] supported_openai_params: Required[Optional[List[str]]] supports_system_messages: Optional[bool] + supports_response_schema: Optional[bool] class GenericStreamingChunk(TypedDict): @@ -168,11 +169,13 @@ class Function(OpenAIObject): def __init__( self, - arguments: Union[Dict, str], + arguments: Optional[Union[Dict, str]], name: Optional[str] = None, **params, ): - if isinstance(arguments, Dict): + if arguments is None: + arguments = "" + elif isinstance(arguments, Dict): arguments = json.dumps(arguments) else: arguments = arguments @@ -992,3 +995,8 @@ class 
GenericImageParsingChunk(TypedDict): type: str media_type: str data: str + + +class ResponseFormatChunk(TypedDict, total=False): + type: Required[Literal["json_object", "text"]] + response_schema: dict diff --git a/litellm/utils.py b/litellm/utils.py index f5fe5964fc..bc307d9b01 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -50,12 +50,15 @@ from tokenizers import Tokenizer import litellm import litellm._service_logger # for storing API inputs, outputs, and metadata import litellm.litellm_core_utils +import litellm.litellm_core_utils.json_validation_rule from litellm.caching import DualCache from litellm.litellm_core_utils.core_helpers import map_finish_reason +from litellm.litellm_core_utils.exception_mapping_utils import get_error_message from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_logging, ) +from litellm.litellm_core_utils.token_counter import get_modified_max_tokens from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.utils import ( CallTypes, @@ -580,7 +583,7 @@ def client(original_function): else: return False - def post_call_processing(original_response, model): + def post_call_processing(original_response, model, optional_params: Optional[dict]): try: if original_response is None: pass @@ -595,11 +598,47 @@ def client(original_function): pass else: if isinstance(original_response, ModelResponse): - model_response = original_response.choices[ + model_response: Optional[str] = original_response.choices[ 0 - ].message.content - ### POST-CALL RULES ### - rules_obj.post_call_rules(input=model_response, model=model) + ].message.content # type: ignore + if model_response is not None: + ### POST-CALL RULES ### + rules_obj.post_call_rules( + input=model_response, model=model + ) + ### JSON SCHEMA VALIDATION ### + if ( + optional_params is not None + and "response_format" in optional_params + and isinstance( + optional_params["response_format"], dict + ) + and "type" in optional_params["response_format"] + and optional_params["response_format"]["type"] + == "json_object" + and "response_schema" + in optional_params["response_format"] + and isinstance( + optional_params["response_format"][ + "response_schema" + ], + dict, + ) + and "enforce_validation" + in optional_params["response_format"] + and optional_params["response_format"][ + "enforce_validation" + ] + is True + ): + # schema given, json response expected, and validation enforced + litellm.litellm_core_utils.json_validation_rule.validate_schema( + schema=optional_params["response_format"][ + "response_schema" + ], + response=model_response, + ) + except Exception as e: raise e @@ -815,7 +854,7 @@ def client(original_function): kwargs.get("max_tokens", None) is not None and model is not None and litellm.modify_params - == True # user is okay with params being modified + is True # user is okay with params being modified and ( call_type == CallTypes.acompletion.value or call_type == CallTypes.completion.value @@ -825,28 +864,19 @@ def client(original_function): base_model = model if kwargs.get("hf_model_name", None) is not None: base_model = f"huggingface/{kwargs.get('hf_model_name')}" - max_output_tokens = ( - get_max_tokens(model=base_model) or 4096 - ) # assume min context window is 4k tokens - user_max_tokens = kwargs.get("max_tokens") - ## Scenario 1: User limit + prompt > model limit messages = None if len(args) > 1: messages = args[1] elif 
kwargs.get("messages", None): messages = kwargs["messages"] - input_tokens = token_counter(model=base_model, messages=messages) - input_tokens += max( - 0.1 * input_tokens, 10 - ) # give at least a 10 token buffer. token counting can be imprecise. - if input_tokens > max_output_tokens: - pass # allow call to fail normally - elif user_max_tokens + input_tokens > max_output_tokens: - user_max_tokens = max_output_tokens - input_tokens - print_verbose(f"user_max_tokens: {user_max_tokens}") - kwargs["max_tokens"] = int( - round(user_max_tokens) - ) # make sure max tokens is always an int + user_max_tokens = kwargs.get("max_tokens") + modified_max_tokens = get_modified_max_tokens( + model=model, + base_model=base_model, + messages=messages, + user_max_tokens=user_max_tokens, + ) + kwargs["max_tokens"] = modified_max_tokens except Exception as e: print_verbose(f"Error while checking max token limit: {str(e)}") # MODEL CALL @@ -877,7 +907,11 @@ def client(original_function): return result ### POST-CALL RULES ### - post_call_processing(original_response=result, model=model or None) + post_call_processing( + original_response=result, + model=model or None, + optional_params=kwargs, + ) # [OPTIONAL] ADD TO CACHE if ( @@ -901,6 +935,17 @@ def client(original_function): model=model, optional_params=getattr(logging_obj, "optional_params", {}), ) + result._hidden_params["response_cost"] = ( + litellm.response_cost_calculator( + response_object=result, + model=getattr(logging_obj, "model", ""), + custom_llm_provider=getattr( + logging_obj, "custom_llm_provider", None + ), + call_type=getattr(logging_obj, "call_type", "completion"), + optional_params=getattr(logging_obj, "optional_params", {}), + ) + ) result._response_ms = ( end_time - start_time ).total_seconds() * 1000 # return response latency in ms like openai @@ -1294,6 +1339,17 @@ def client(original_function): model=model, optional_params=kwargs, ) + result._hidden_params["response_cost"] = ( + litellm.response_cost_calculator( + response_object=result, + model=getattr(logging_obj, "model", ""), + custom_llm_provider=getattr( + logging_obj, "custom_llm_provider", None + ), + call_type=getattr(logging_obj, "call_type", "completion"), + optional_params=getattr(logging_obj, "optional_params", {}), + ) + ) if ( isinstance(result, ModelResponse) or isinstance(result, EmbeddingResponse) @@ -1304,7 +1360,9 @@ def client(original_function): ).total_seconds() * 1000 # return response latency in ms like openai ### POST-CALL RULES ### - post_call_processing(original_response=result, model=model) + post_call_processing( + original_response=result, model=model, optional_params=kwargs + ) # [OPTIONAL] ADD TO CACHE if ( @@ -1835,9 +1893,10 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> Parameters: model (str): The model name to be checked. + custom_llm_provider (str): The provider to be checked. Returns: - bool: True if the model supports function calling, False otherwise. + bool: True if the model supports system messages, False otherwise. Raises: Exception: If the given model is not found in model_prices_and_context_window.json. @@ -1855,6 +1914,43 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> ) +def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> bool: + """ + Check if the given model + provider supports 'response_schema' as a param. + + Parameters: + model (str): The model name to be checked. + custom_llm_provider (str): The provider to be checked. 
+ + Returns: + bool: True if the model supports response_schema, False otherwise. + + Does not raise error. Defaults to 'False'. Outputs logging.error. + """ + try: + ## GET LLM PROVIDER ## + model, custom_llm_provider, _, _ = get_llm_provider( + model=model, custom_llm_provider=custom_llm_provider + ) + + if custom_llm_provider == "predibase": # predibase supports this globally + return True + + ## GET MODEL INFO + model_info = litellm.get_model_info( + model=model, custom_llm_provider=custom_llm_provider + ) + + if model_info.get("supports_response_schema", False) is True: + return True + return False + except Exception: + verbose_logger.error( + f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}." + ) + return False + + def supports_function_calling(model: str) -> bool: """ Check if the given model supports function calling and return a boolean value. @@ -2312,7 +2408,9 @@ def get_optional_params( elif k == "hf_model_name" and custom_llm_provider != "sagemaker": continue elif ( - k.startswith("vertex_") and custom_llm_provider != "vertex_ai" + k.startswith("vertex_") + and custom_llm_provider != "vertex_ai" + and custom_llm_provider != "vertex_ai_beta" ): # allow dynamically setting vertex ai init logic continue passed_params[k] = v @@ -2415,6 +2513,7 @@ def get_optional_params( and custom_llm_provider != "together_ai" and custom_llm_provider != "groq" and custom_llm_provider != "nvidia_nim" + and custom_llm_provider != "volcengine" and custom_llm_provider != "deepseek" and custom_llm_provider != "codestral" and custom_llm_provider != "mistral" @@ -2743,6 +2842,11 @@ def get_optional_params( non_default_params=non_default_params, optional_params=optional_params, model=model, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), ) elif ( custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models @@ -2811,12 +2915,7 @@ def get_optional_params( optional_params=optional_params, ) ) - else: - optional_params = litellm.AmazonAnthropicConfig().map_openai_params( - non_default_params=non_default_params, - optional_params=optional_params, - ) - else: # bedrock httpx route + elif model in litellm.BEDROCK_CONVERSE_MODELS: optional_params = litellm.AmazonConverseConfig().map_openai_params( model=model, non_default_params=non_default_params, @@ -2827,6 +2926,11 @@ def get_optional_params( else False ), ) + else: + optional_params = litellm.AmazonAnthropicConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + ) elif "amazon" in model: # amazon titan llms _check_valid_arg(supported_params=supported_params) # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large @@ -3091,6 +3195,17 @@ def get_optional_params( optional_params=optional_params, model=model, ) + elif custom_llm_provider == "volcengine": + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.VolcEngineConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + ) + elif custom_llm_provider == "groq": supported_params = get_supported_openai_params( model=model, custom_llm_provider=custom_llm_provider @@ -3661,6 +3776,8 @@ def get_supported_openai_params( return 
litellm.FireworksAIConfig().get_supported_openai_params() elif custom_llm_provider == "nvidia_nim": return litellm.NvidiaNimConfig().get_supported_openai_params() + elif custom_llm_provider == "volcengine": + return litellm.VolcEngineConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "groq": return [ "temperature", @@ -3672,6 +3789,8 @@ def get_supported_openai_params( "tool_choice", "response_format", "seed", + "extra_headers", + "extra_body", ] elif custom_llm_provider == "deepseek": return [ @@ -3727,23 +3846,18 @@ def get_supported_openai_params( return litellm.AzureOpenAIConfig().get_supported_openai_params() elif custom_llm_provider == "openrouter": return [ - "functions", - "function_call", "temperature", "top_p", - "n", - "stream", - "stop", - "max_tokens", - "presence_penalty", "frequency_penalty", - "logit_bias", - "user", - "response_format", + "presence_penalty", + "repetition_penalty", "seed", - "tools", - "tool_choice", - "max_retries", + "max_tokens", + "logit_bias", + "logprobs", + "top_logprobs", + "response_format", + "stop", ] elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral": # mistal and codestral api have the exact same params @@ -3761,6 +3875,10 @@ def get_supported_openai_params( "top_p", "stop", "seed", + "tools", + "tool_choice", + "functions", + "function_call", ] elif custom_llm_provider == "huggingface": return litellm.HuggingfaceConfig().get_supported_openai_params() @@ -4025,6 +4143,10 @@ def get_llm_provider( # nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = "https://integrate.api.nvidia.com/v1" dynamic_api_key = get_secret("NVIDIA_NIM_API_KEY") + elif custom_llm_provider == "volcengine": + # volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 + api_base = "https://ark.cn-beijing.volces.com/api/v3" + dynamic_api_key = get_secret("VOLCENGINE_API_KEY") elif custom_llm_provider == "codestral": # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1 api_base = "https://codestral.mistral.ai/v1" @@ -4334,7 +4456,7 @@ def get_utc_datetime(): return datetime.utcnow() # type: ignore -def get_max_tokens(model: str): +def get_max_tokens(model: str) -> Optional[int]: """ Get the maximum number of output tokens allowed for a given model. @@ -4388,7 +4510,8 @@ def get_max_tokens(model: str): return litellm.model_cost[model]["max_tokens"] else: raise Exception() - except: + return None + except Exception: raise Exception( f"Model {model} isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" ) @@ -4396,8 +4519,7 @@ def get_max_tokens(model: str): def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo: """ - Get a dict for the maximum tokens (context window), - input_cost_per_token, output_cost_per_token for a given model. + Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model. Parameters: - model (str): The name of the model. 
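The `get_llm_provider` change above treats `volcengine` as another OpenAI-compatible provider, defaulting the API base to `https://ark.cn-beijing.volces.com/api/v3` and reading the key from the `VOLCENGINE_API_KEY` environment variable. A small sketch of how that resolution behaves; the endpoint id is a placeholder, not something from this diff.

```python
from litellm import get_llm_provider

# "my-endpoint" is an illustrative model/endpoint id.
model, provider, dynamic_api_key, api_base = get_llm_provider(
    model="volcengine/my-endpoint"
)

print(provider)  # "volcengine"
print(api_base)  # "https://ark.cn-beijing.volces.com/api/v3" (default set in this diff)
# dynamic_api_key is populated from the VOLCENGINE_API_KEY env var, if it is set.
```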
@@ -4482,6 +4604,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod mode="chat", supported_openai_params=supported_openai_params, supports_system_messages=None, + supports_response_schema=None, ) else: """ @@ -4503,36 +4626,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception - return ModelInfo( - max_tokens=_model_info.get("max_tokens", None), - max_input_tokens=_model_info.get("max_input_tokens", None), - max_output_tokens=_model_info.get("max_output_tokens", None), - input_cost_per_token=_model_info.get("input_cost_per_token", 0), - input_cost_per_character=_model_info.get( - "input_cost_per_character", None - ), - input_cost_per_token_above_128k_tokens=_model_info.get( - "input_cost_per_token_above_128k_tokens", None - ), - output_cost_per_token=_model_info.get("output_cost_per_token", 0), - output_cost_per_character=_model_info.get( - "output_cost_per_character", None - ), - output_cost_per_token_above_128k_tokens=_model_info.get( - "output_cost_per_token_above_128k_tokens", None - ), - output_cost_per_character_above_128k_tokens=_model_info.get( - "output_cost_per_character_above_128k_tokens", None - ), - litellm_provider=_model_info.get( - "litellm_provider", custom_llm_provider - ), - mode=_model_info.get("mode"), - supported_openai_params=supported_openai_params, - supports_system_messages=_model_info.get( - "supports_system_messages", None - ), - ) elif model in litellm.model_cost: _model_info = litellm.model_cost[model] _model_info["supported_openai_params"] = supported_openai_params @@ -4546,36 +4639,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception - return ModelInfo( - max_tokens=_model_info.get("max_tokens", None), - max_input_tokens=_model_info.get("max_input_tokens", None), - max_output_tokens=_model_info.get("max_output_tokens", None), - input_cost_per_token=_model_info.get("input_cost_per_token", 0), - input_cost_per_character=_model_info.get( - "input_cost_per_character", None - ), - input_cost_per_token_above_128k_tokens=_model_info.get( - "input_cost_per_token_above_128k_tokens", None - ), - output_cost_per_token=_model_info.get("output_cost_per_token", 0), - output_cost_per_character=_model_info.get( - "output_cost_per_character", None - ), - output_cost_per_token_above_128k_tokens=_model_info.get( - "output_cost_per_token_above_128k_tokens", None - ), - output_cost_per_character_above_128k_tokens=_model_info.get( - "output_cost_per_character_above_128k_tokens", None - ), - litellm_provider=_model_info.get( - "litellm_provider", custom_llm_provider - ), - mode=_model_info.get("mode"), - supported_openai_params=supported_openai_params, - supports_system_messages=_model_info.get( - "supports_system_messages", None - ), - ) elif split_model in litellm.model_cost: _model_info = litellm.model_cost[split_model] _model_info["supported_openai_params"] = supported_openai_params @@ -4589,40 +4652,48 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception - return ModelInfo( - max_tokens=_model_info.get("max_tokens", None), - max_input_tokens=_model_info.get("max_input_tokens", None), - max_output_tokens=_model_info.get("max_output_tokens", None), - input_cost_per_token=_model_info.get("input_cost_per_token", 0), - input_cost_per_character=_model_info.get( - "input_cost_per_character", None - ), - input_cost_per_token_above_128k_tokens=_model_info.get( - 
"input_cost_per_token_above_128k_tokens", None - ), - output_cost_per_token=_model_info.get("output_cost_per_token", 0), - output_cost_per_character=_model_info.get( - "output_cost_per_character", None - ), - output_cost_per_token_above_128k_tokens=_model_info.get( - "output_cost_per_token_above_128k_tokens", None - ), - output_cost_per_character_above_128k_tokens=_model_info.get( - "output_cost_per_character_above_128k_tokens", None - ), - litellm_provider=_model_info.get( - "litellm_provider", custom_llm_provider - ), - mode=_model_info.get("mode"), - supported_openai_params=supported_openai_params, - supports_system_messages=_model_info.get( - "supports_system_messages", None - ), - ) else: raise ValueError( "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" ) + + ## PROVIDER-SPECIFIC INFORMATION + if custom_llm_provider == "predibase": + _model_info["supports_response_schema"] = True + + return ModelInfo( + max_tokens=_model_info.get("max_tokens", None), + max_input_tokens=_model_info.get("max_input_tokens", None), + max_output_tokens=_model_info.get("max_output_tokens", None), + input_cost_per_token=_model_info.get("input_cost_per_token", 0), + input_cost_per_character=_model_info.get( + "input_cost_per_character", None + ), + input_cost_per_token_above_128k_tokens=_model_info.get( + "input_cost_per_token_above_128k_tokens", None + ), + output_cost_per_token=_model_info.get("output_cost_per_token", 0), + output_cost_per_character=_model_info.get( + "output_cost_per_character", None + ), + output_cost_per_token_above_128k_tokens=_model_info.get( + "output_cost_per_token_above_128k_tokens", None + ), + output_cost_per_character_above_128k_tokens=_model_info.get( + "output_cost_per_character_above_128k_tokens", None + ), + litellm_provider=_model_info.get( + "litellm_provider", custom_llm_provider + ), + mode=_model_info.get("mode"), + supported_openai_params=supported_openai_params, + supports_system_messages=_model_info.get( + "supports_system_messages", None + ), + supports_response_schema=_model_info.get( + "supports_response_schema", None + ), + ) except Exception: raise Exception( "This model isn't mapped yet. 
Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" @@ -4746,6 +4817,12 @@ def function_to_dict(input_function): # noqa: C901 return result +def modify_url(original_url, new_path): + url = httpx.URL(original_url) + modified_url = url.copy_with(path=new_path) + return str(modified_url) + + def load_test_model( model: str, custom_llm_provider: str = "", @@ -4975,6 +5052,11 @@ def validate_environment(model: Optional[str] = None) -> dict: keys_in_environment = True else: missing_keys.append("NVIDIA_NIM_API_KEY") + elif custom_llm_provider == "volcengine": + if "VOLCENGINE_API_KEY" in os.environ: + keys_in_environment = True + else: + missing_keys.append("VOLCENGINE_API_KEY") elif ( custom_llm_provider == "codestral" or custom_llm_provider == "text-completion-codestral" @@ -5263,6 +5345,27 @@ def convert_to_model_response_object( hidden_params: Optional[dict] = None, ): received_args = locals() + ### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary + if ( + response_object is not None + and "error" in response_object + and response_object["error"] is not None + ): + error_args = {"status_code": 422, "message": "Error in response object"} + if isinstance(response_object["error"], dict): + if "code" in response_object["error"]: + error_args["status_code"] = response_object["error"]["code"] + if "message" in response_object["error"]: + if isinstance(response_object["error"]["message"], dict): + message_str = json.dumps(response_object["error"]["message"]) + else: + message_str = str(response_object["error"]["message"]) + error_args["message"] = message_str + raised_exception = Exception() + setattr(raised_exception, "status_code", error_args["status_code"]) + setattr(raised_exception, "message", error_args["message"]) + raise raised_exception + try: if response_type == "completion" and ( model_response_object is None @@ -5718,7 +5821,10 @@ def exception_type( print() # noqa try: if model: - error_str = str(original_exception) + if hasattr(original_exception, "message"): + error_str = str(original_exception.message) + else: + error_str = str(original_exception) if isinstance(original_exception, BaseException): exception_type = type(original_exception).__name__ else: @@ -5740,6 +5846,18 @@ def exception_type( _model_group = _metadata.get("model_group") _deployment = _metadata.get("deployment") extra_information = f"\nModel: {model}" + + exception_provider = "Unknown" + if ( + isinstance(custom_llm_provider, str) + and len(custom_llm_provider) > 0 + ): + exception_provider = ( + custom_llm_provider[0].upper() + + custom_llm_provider[1:] + + "Exception" + ) + if _api_base: extra_information += f"\nAPI Base: `{_api_base}`" if ( @@ -5790,10 +5908,13 @@ def exception_type( or custom_llm_provider in litellm.openai_compatible_providers ): # custom_llm_provider is openai, make it OpenAI - if hasattr(original_exception, "message"): - message = original_exception.message - else: - message = str(original_exception) + message = get_error_message(error_obj=original_exception) + if message is None: + if hasattr(original_exception, "message"): + message = original_exception.message + else: + message = str(original_exception) + if message is not None and isinstance(message, str): message = message.replace("OPENAI", custom_llm_provider.upper()) message = message.replace("openai", custom_llm_provider) @@ -6126,7 +6247,6 @@ def exception_type( ) elif ( original_exception.status_code == 400 - or original_exception.status_code == 422 or 
original_exception.status_code == 413 ): exception_mapping_worked = True @@ -6136,6 +6256,14 @@ def exception_type( llm_provider="replicate", response=original_exception.response, ) + elif original_exception.status_code == 422: + exception_mapping_worked = True + raise UnprocessableEntityError( + message=f"ReplicateException - {original_exception.message}", + model=model, + llm_provider="replicate", + response=original_exception.response, + ) elif original_exception.status_code == 408: exception_mapping_worked = True raise Timeout( @@ -7239,10 +7367,17 @@ def exception_type( request=original_exception.request, ) elif custom_llm_provider == "azure": + message = get_error_message(error_obj=original_exception) + if message is None: + if hasattr(original_exception, "message"): + message = original_exception.message + else: + message = str(original_exception) + if "Internal server error" in error_str: exception_mapping_worked = True raise litellm.InternalServerError( - message=f"AzureException Internal server error - {original_exception.message}", + message=f"AzureException Internal server error - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7255,7 +7390,7 @@ def exception_type( elif "This model's maximum context length is" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( - message=f"AzureException ContextWindowExceededError - {original_exception.message}", + message=f"AzureException ContextWindowExceededError - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7264,7 +7399,7 @@ def exception_type( elif "DeploymentNotFound" in error_str: exception_mapping_worked = True raise NotFoundError( - message=f"AzureException NotFoundError - {original_exception.message}", + message=f"AzureException NotFoundError - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7284,7 +7419,7 @@ def exception_type( ): exception_mapping_worked = True raise ContentPolicyViolationError( - message=f"litellm.ContentPolicyViolationError: AzureException - {original_exception.message}", + message=f"litellm.ContentPolicyViolationError: AzureException - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7293,7 +7428,7 @@ def exception_type( elif "invalid_request_error" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"AzureException BadRequestError - {original_exception.message}", + message=f"AzureException BadRequestError - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7305,7 +7440,7 @@ def exception_type( ): exception_mapping_worked = True raise AuthenticationError( - message=f"{exception_provider} AuthenticationError - {original_exception.message}", + message=f"{exception_provider} AuthenticationError - {message}", llm_provider=custom_llm_provider, model=model, litellm_debug_info=extra_information, @@ -7316,7 +7451,7 @@ def exception_type( if original_exception.status_code == 400: exception_mapping_worked = True raise BadRequestError( - message=f"AzureException - {original_exception.message}", + message=f"AzureException - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7325,7 +7460,7 @@ def exception_type( elif original_exception.status_code == 401: exception_mapping_worked = True raise AuthenticationError( - message=f"AzureException AuthenticationError - {original_exception.message}", + message=f"AzureException 
AuthenticationError - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7334,7 +7469,7 @@ def exception_type( elif original_exception.status_code == 408: exception_mapping_worked = True raise Timeout( - message=f"AzureException Timeout - {original_exception.message}", + message=f"AzureException Timeout - {message}", model=model, litellm_debug_info=extra_information, llm_provider="azure", @@ -7342,7 +7477,7 @@ def exception_type( elif original_exception.status_code == 422: exception_mapping_worked = True raise BadRequestError( - message=f"AzureException BadRequestError - {original_exception.message}", + message=f"AzureException BadRequestError - {message}", model=model, llm_provider="azure", litellm_debug_info=extra_information, @@ -7351,7 +7486,7 @@ def exception_type( elif original_exception.status_code == 429: exception_mapping_worked = True raise RateLimitError( - message=f"AzureException RateLimitError - {original_exception.message}", + message=f"AzureException RateLimitError - {message}", model=model, llm_provider="azure", litellm_debug_info=extra_information, @@ -7360,7 +7495,7 @@ def exception_type( elif original_exception.status_code == 503: exception_mapping_worked = True raise ServiceUnavailableError( - message=f"AzureException ServiceUnavailableError - {original_exception.message}", + message=f"AzureException ServiceUnavailableError - {message}", model=model, llm_provider="azure", litellm_debug_info=extra_information, @@ -7369,7 +7504,7 @@ def exception_type( elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True raise Timeout( - message=f"AzureException Timeout - {original_exception.message}", + message=f"AzureException Timeout - {message}", model=model, litellm_debug_info=extra_information, llm_provider="azure", @@ -7378,7 +7513,7 @@ def exception_type( exception_mapping_worked = True raise APIError( status_code=original_exception.status_code, - message=f"AzureException APIError - {original_exception.message}", + message=f"AzureException APIError - {message}", llm_provider="azure", litellm_debug_info=extra_information, model=model, @@ -7810,6 +7945,7 @@ class CustomStreamWrapper: "", "", "<|im_end|>", + "<|im_start|>", ] self.holding_chunk = "" self.complete_response = "" diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index acd03aeea8..7f08b9eb19 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -863,6 +863,46 @@ "litellm_provider": "deepseek", "mode": "chat" }, + "codestral/codestral-latest": { + "max_tokens": 8191, + "max_input_tokens": 32000, + "max_output_tokens": 8191, + "input_cost_per_token": 0.000000, + "output_cost_per_token": 0.000000, + "litellm_provider": "codestral", + "mode": "chat", + "source": "https://docs.mistral.ai/capabilities/code_generation/" + }, + "codestral/codestral-2405": { + "max_tokens": 8191, + "max_input_tokens": 32000, + "max_output_tokens": 8191, + "input_cost_per_token": 0.000000, + "output_cost_per_token": 0.000000, + "litellm_provider": "codestral", + "mode": "chat", + "source": "https://docs.mistral.ai/capabilities/code_generation/" + }, + "text-completion-codestral/codestral-latest": { + "max_tokens": 8191, + "max_input_tokens": 32000, + "max_output_tokens": 8191, + "input_cost_per_token": 0.000000, + "output_cost_per_token": 0.000000, + "litellm_provider": "text-completion-codestral", + "mode": "completion", + "source": 
"https://docs.mistral.ai/capabilities/code_generation/" + }, + "text-completion-codestral/codestral-2405": { + "max_tokens": 8191, + "max_input_tokens": 32000, + "max_output_tokens": 8191, + "input_cost_per_token": 0.000000, + "output_cost_per_token": 0.000000, + "litellm_provider": "text-completion-codestral", + "mode": "completion", + "source": "https://docs.mistral.ai/capabilities/code_generation/" + }, "deepseek-coder": { "max_tokens": 4096, "max_input_tokens": 32000, @@ -1028,21 +1068,55 @@ "tool_use_system_prompt_tokens": 159 }, "text-bison": { - "max_tokens": 1024, + "max_tokens": 2048, "max_input_tokens": 8192, - "max_output_tokens": 1024, - "input_cost_per_token": 0.000000125, - "output_cost_per_token": 0.000000125, + "max_output_tokens": 2048, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-text-models", "mode": "completion", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "text-bison@001": { + "max_tokens": 1024, + "max_input_tokens": 8192, + "max_output_tokens": 1024, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "text-bison@002": { + "max_tokens": 1024, + "max_input_tokens": 8192, + "max_output_tokens": 1024, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "text-bison32k": { "max_tokens": 1024, "max_input_tokens": 8192, "max_output_tokens": 1024, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "text-bison32k@002": { + "max_tokens": 1024, + "max_input_tokens": 8192, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-text-models", "mode": "completion", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1073,6 +1147,8 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1083,6 +1159,8 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1093,6 +1171,8 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": 
"vertex_ai-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1103,6 +1183,20 @@ "max_output_tokens": 8192, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-chat-models", + "mode": "chat", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "chat-bison-32k@002": { + "max_tokens": 8192, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1113,6 +1207,8 @@ "max_output_tokens": 1024, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-code-text-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1123,6 +1219,44 @@ "max_output_tokens": 1024, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "code-bison@002": { + "max_tokens": 1024, + "max_input_tokens": 6144, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "code-bison32k": { + "max_tokens": 1024, + "max_input_tokens": 6144, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "code-bison-32k@002": { + "max_tokens": 1024, + "max_input_tokens": 6144, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-code-text-models", "mode": "completion", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1157,12 +1291,36 @@ "mode": "completion", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, + "code-gecko-latest": { + "max_tokens": 64, + "max_input_tokens": 2048, + "max_output_tokens": 64, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "litellm_provider": "vertex_ai-code-text-models", + "mode": "completion", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + 
"codechat-bison@latest": { + "max_tokens": 1024, + "max_input_tokens": 6144, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-chat-models", + "mode": "chat", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, "codechat-bison": { "max_tokens": 1024, "max_input_tokens": 6144, "max_output_tokens": 1024, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-code-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1173,6 +1331,20 @@ "max_output_tokens": 1024, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-chat-models", + "mode": "chat", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "codechat-bison@002": { + "max_tokens": 1024, + "max_input_tokens": 6144, + "max_output_tokens": 1024, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-code-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1183,6 +1355,20 @@ "max_output_tokens": 8192, "input_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, + "litellm_provider": "vertex_ai-code-chat-models", + "mode": "chat", + "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "codechat-bison-32k@002": { + "max_tokens": 8192, + "max_input_tokens": 32000, + "max_output_tokens": 8192, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.000000125, + "input_cost_per_character": 0.00000025, + "output_cost_per_character": 0.0000005, "litellm_provider": "vertex_ai-code-chat-models", "mode": "chat", "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" @@ -1232,6 +1418,36 @@ "supports_function_calling": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, + "gemini-1.0-ultra": { + "max_tokens": 8192, + "max_input_tokens": 8192, + "max_output_tokens": 2048, + "input_cost_per_image": 0.0025, + "input_cost_per_video_per_second": 0.002, + "input_cost_per_token": 0.0000005, + "input_cost_per_character": 0.000000125, + "output_cost_per_token": 0.0000015, + "output_cost_per_character": 0.000000375, + "litellm_provider": "vertex_ai-language-models", + "mode": "chat", + "supports_function_calling": true, + "source": "As of Jun, 2024. There is no available doc on vertex ai pricing gemini-1.0-ultra-001. Using gemini-1.0-pro pricing. 
Got max_tokens info here: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, + "gemini-1.0-ultra-001": { + "max_tokens": 8192, + "max_input_tokens": 8192, + "max_output_tokens": 2048, + "input_cost_per_image": 0.0025, + "input_cost_per_video_per_second": 0.002, + "input_cost_per_token": 0.0000005, + "input_cost_per_character": 0.000000125, + "output_cost_per_token": 0.0000015, + "output_cost_per_character": 0.000000375, + "litellm_provider": "vertex_ai-language-models", + "mode": "chat", + "supports_function_calling": true, + "source": "As of Jun, 2024. There is no available doc on vertex ai pricing gemini-1.0-ultra-001. Using gemini-1.0-pro pricing. Got max_tokens info here: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" + }, "gemini-1.0-pro-002": { "max_tokens": 8192, "max_input_tokens": 32760, @@ -1249,7 +1465,7 @@ }, "gemini-1.5-pro": { "max_tokens": 8192, - "max_input_tokens": 1000000, + "max_input_tokens": 2097152, "max_output_tokens": 8192, "input_cost_per_image": 0.001315, "input_cost_per_audio_per_second": 0.000125, @@ -1270,6 +1486,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-001": { @@ -1295,6 +1512,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0514": { @@ -1320,6 +1538,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0215": { @@ -1345,6 +1564,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0409": { @@ -1368,7 +1588,8 @@ "litellm_provider": "vertex_ai-language-models", "mode": "chat", "supports_function_calling": true, - "supports_tool_choice": true, + "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-flash": { @@ -1779,7 +2000,7 @@ }, "gemini/gemini-1.5-pro": { "max_tokens": 8192, - "max_input_tokens": 1000000, + "max_input_tokens": 2097152, "max_output_tokens": 8192, "input_cost_per_token": 0.00000035, "input_cost_per_token_above_128k_tokens": 0.0000007, @@ -1791,6 +2012,7 @@ "supports_function_calling": true, "supports_vision": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini/gemini-1.5-pro-latest": { @@ -1807,6 +2029,7 @@ "supports_function_calling": true, "supports_vision": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://ai.google.dev/models/gemini" }, "gemini/gemini-pro-vision": { diff --git a/poetry.lock b/poetry.lock index 290d19f7a9..88927576c4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 
1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -343,13 +343,13 @@ uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "cachetools" -version = "5.3.3" +version = "5.3.1" description = "Extensible memoizing collections and decorators" -optional = true +optional = false python-versions = ">=3.7" files = [ - {file = "cachetools-5.3.3-py3-none-any.whl", hash = "sha256:0abad1021d3f8325b2fc1d2e9c8b9c9d57b04c3932657a72465447332c24d945"}, - {file = "cachetools-5.3.3.tar.gz", hash = "sha256:ba29e2dfa0b8b556606f097407ed1aa62080ee108ab0dc5ec9d6a723a007d105"}, + {file = "cachetools-5.3.1-py3-none-any.whl", hash = "sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590"}, + {file = "cachetools-5.3.1.tar.gz", hash = "sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b"}, ] [[package]] @@ -3300,4 +3300,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "cryptography", "fastapi", "fastapi- [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0, !=3.9.7" -content-hash = "f400d2f686954c2b12b0ee88546f31d52ebc8e323a3ec850dc46d74748d38cdf" +content-hash = "022481b965a1a6524cc25d52eff59592779aafdf03dc6159c834b9519079f549" diff --git a/pyproject.toml b/pyproject.toml index 321f44b23b..c698a18e16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.40.27" +version = "1.41.3" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -27,7 +27,7 @@ jinja2 = "^3.1.2" aiohttp = "*" requests = "^2.31.0" pydantic = "^2.0.0" -ijson = "*" +jsonschema = "^4.22.0" uvicorn = {version = "^0.22.0", optional = true} gunicorn = {version = "^22.0.0", optional = true} @@ -90,7 +90,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.40.27" +version = "1.41.3" version_files = [ "pyproject.toml:^version" ] diff --git a/requirements.txt b/requirements.txt index 00d3802da5..e71ab450bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,5 +46,5 @@ aiohttp==3.9.0 # for network calls aioboto3==12.3.0 # for async sagemaker calls tenacity==8.2.3 # for retrying requests, when litellm.num_retries set pydantic==2.7.1 # proxy + openai req. -ijson==3.2.3 # for google ai studio streaming +jsonschema==4.22.0 # validating json schema #### \ No newline at end of file diff --git a/tests/test_entrypoint.py b/tests/test_entrypoint.py new file mode 100644 index 0000000000..803135e35d --- /dev/null +++ b/tests/test_entrypoint.py @@ -0,0 +1,59 @@ +# What is this? 
+## Unit tests for 'entrypoint.sh' + +import pytest +import sys +import os + +sys.path.insert( + 0, os.path.abspath("../") +) # Adds the parent directory to the system path +import litellm +import subprocess + + +@pytest.mark.skip(reason="local test") +def test_decrypt_and_reset_env(): + os.environ["DATABASE_URL"] = ( + "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La" + ) + from litellm.proxy.secret_managers.aws_secret_manager import ( + decrypt_and_reset_env_var, + ) + + decrypt_and_reset_env_var() + + assert os.environ["DATABASE_URL"] is not None + assert isinstance(os.environ["DATABASE_URL"], str) + assert not os.environ["DATABASE_URL"].startswith("aws_kms/") + + print("DATABASE_URL={}".format(os.environ["DATABASE_URL"])) + + +@pytest.mark.skip(reason="local test") +def test_entrypoint_decrypt_and_reset(): + os.environ["DATABASE_URL"] = ( + "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La" + ) + command = "./entrypoint.sh" + directory = ".." # Relative to the current directory + + # Run the command using subprocess + result = subprocess.run( + command, shell=True, cwd=directory, capture_output=True, text=True + ) + + # Print the output for debugging purposes + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + + # Assert the script ran successfully + assert result.returncode == 0, "The shell script did not execute successfully" + assert ( + "DECRYPTS VALUE" in result.stdout + ), "Expected output not found in script output" + assert ( + "Database push successful!" 
in result.stdout + ), "Expected output not found in script output" diff --git a/tests/test_whisper.py b/tests/test_whisper.py index 1debbbc1db..09819f796c 100644 --- a/tests/test_whisper.py +++ b/tests/test_whisper.py @@ -8,6 +8,9 @@ from openai import AsyncOpenAI import sys, os, dotenv from typing import Optional from dotenv import load_dotenv +from litellm.integrations.custom_logger import CustomLogger +import litellm +import logging # Get the current directory of the file being run pwd = os.path.dirname(os.path.realpath(__file__)) @@ -84,9 +87,32 @@ async def test_transcription_async_openai(): assert isinstance(transcript.text, str) +# This file includes the custom callbacks for LiteLLM Proxy +# Once defined, these can be passed in proxy_config.yaml +class MyCustomHandler(CustomLogger): + def __init__(self): + self.openai_client = None + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + # init logging config + print("logging a transcript kwargs: ", kwargs) + print("openai client=", kwargs.get("client")) + self.openai_client = kwargs.get("client") + + except Exception: + pass + + +proxy_handler_instance = MyCustomHandler() + + +# Set litellm.callbacks = [proxy_handler_instance] on the proxy +# need to set litellm.callbacks = [proxy_handler_instance] # on the proxy @pytest.mark.asyncio async def test_transcription_on_router(): litellm.set_verbose = True + litellm.callbacks = [proxy_handler_instance] print("\n Testing async transcription on router\n") try: model_list = [ @@ -108,11 +134,29 @@ async def test_transcription_on_router(): ] router = Router(model_list=model_list) + + router_level_clients = [] + for deployment in router.model_list: + _deployment_openai_client = router._get_client( + deployment=deployment, + kwargs={"model": "whisper-1"}, + client_type="async", + ) + + router_level_clients.append(str(_deployment_openai_client)) + response = await router.atranscription( model="whisper", file=audio_file, ) print(response) + + # PROD Test + # Ensure we ONLY use OpenAI/Azure client initialized on the router level + await asyncio.sleep(5) + print("OpenAI Client used= ", proxy_handler_instance.openai_client) + print("all router level clients= ", router_level_clients) + assert proxy_handler_instance.openai_client in router_level_clients except Exception as e: traceback.print_exc() pytest.fail(f"Error occurred: {e}")
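The Vertex AI PaLM entries above (text-bison, chat-bison, code-bison, codechat-bison) add `input_cost_per_character` and `output_cost_per_character` alongside the existing per-token rates. Below is a minimal sketch of what character-based cost estimation looks like under those prices; the `MODEL_PRICES` dict and the `estimate_character_cost` helper are illustrative stand-ins, not LiteLLM's internal cost-tracking code.

```python
# Illustrative sketch only: character-based pricing as encoded in the
# text-bison@002 entry above. Not LiteLLM's internal cost-tracking logic.
MODEL_PRICES = {
    "text-bison@002": {
        "input_cost_per_character": 0.00000025,
        "output_cost_per_character": 0.0000005,
    }
}


def estimate_character_cost(model: str, prompt: str, completion: str) -> float:
    """Estimate spend by multiplying character counts by per-character rates."""
    prices = MODEL_PRICES[model]
    return (
        len(prompt) * prices["input_cost_per_character"]
        + len(completion) * prices["output_cost_per_character"]
    )


if __name__ == "__main__":
    cost = estimate_character_cost(
        "text-bison@002",
        prompt="Summarize the release notes in one sentence.",
        completion="This release adds character-based pricing for Vertex AI models.",
    )
    print(f"estimated cost: ${cost:.8f}")
```

The per-character rates come straight from the `text-bison@002` entry in this diff; real usage would read them from `model_prices_and_context_window.json` rather than hard-coding them.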
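Separately, `ijson` is swapped for `jsonschema` in the CI config, `pyproject.toml` and `requirements.txt`, and the Gemini 1.5 Pro entries gain `supports_response_schema`. Below is a hedged sketch of the kind of client-side check `jsonschema` makes possible, validating a model's JSON output against the schema that was requested; the schema and payload are invented for illustration and this is not the proxy's internal validation path.

```python
import json

from jsonschema import ValidationError, validate

# Hypothetical response schema a caller might request from a Gemini 1.5 model
# whose entry above advertises supports_response_schema.
RESPONSE_SCHEMA = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "tags": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["title", "tags"],
}

# Stand-in for the JSON content returned by the model.
raw_response = '{"title": "Release notes", "tags": ["vertex_ai", "pricing"]}'

try:
    validate(instance=json.loads(raw_response), schema=RESPONSE_SCHEMA)
    print("response matches the requested schema")
except ValidationError as err:
    print(f"schema violation: {err.message}")
```

If the payload were missing `tags`, `validate` would raise `jsonschema.ValidationError`, whose `message` attribute carries a human-readable description of the mismatch.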