diff --git a/.circleci/config.yml b/.circleci/config.yml index 5dfeedcaa..40d498d6e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -66,7 +66,7 @@ jobs: pip install "pydantic==2.7.1" pip install "diskcache==5.6.1" pip install "Pillow==10.3.0" - pip install "ijson==3.2.3" + pip install "jsonschema==4.22.0" - save_cache: paths: - ./venv @@ -128,7 +128,7 @@ jobs: pip install jinja2 pip install tokenizers pip install openai - pip install ijson + pip install jsonschema - run: name: Run tests command: | @@ -183,7 +183,7 @@ jobs: pip install numpydoc pip install prisma pip install fastapi - pip install ijson + pip install jsonschema pip install "httpx==0.24.1" pip install "gunicorn==21.2.0" pip install "anyio==3.7.1" @@ -212,6 +212,7 @@ jobs: -e AWS_REGION_NAME=$AWS_REGION_NAME \ -e AUTO_INFER_REGION=True \ -e OPENAI_API_KEY=$OPENAI_API_KEY \ + -e LITELLM_LICENSE=$LITELLM_LICENSE \ -e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \ -e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \ -e LANGFUSE_PROJECT1_SECRET=$LANGFUSE_PROJECT1_SECRET \ diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md index db2931909..5e2bd6079 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -50,7 +50,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea |Huggingface| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |Openrouter| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | ✅ | | | | | |AI21| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | -|VertexAI| ✅ | ✅ | | ✅ | ✅ | | | | | | | | | | ✅ | | | +|VertexAI| ✅ | ✅ | | ✅ | ✅ | | | | | | | | | ✅ | ✅ | | | |Bedrock| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | ✅ (for anthropic) | | |Sagemaker| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | ✅ | diff --git a/docs/my-website/docs/enterprise.md b/docs/my-website/docs/enterprise.md index 875aec57f..5bd09ec15 100644 --- a/docs/my-website/docs/enterprise.md +++ b/docs/my-website/docs/enterprise.md @@ -2,26 +2,39 @@ For companies that need SSO, user management and professional support for LiteLLM Proxy :::info - +Interested in Enterprise? 
Schedule a meeting with us here 👉 [Talk to founders](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) ::: This covers: -- ✅ **Features under the [LiteLLM Commercial License (Content Mod, Custom Tags, etc.)](https://docs.litellm.ai/docs/proxy/enterprise)** -- ✅ [**Secure UI access with Single Sign-On**](../docs/proxy/ui.md#setup-ssoauth-for-ui) -- ✅ [**Audit Logs with retention policy**](../docs/proxy/enterprise.md#audit-logs) -- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md) -- ✅ [**Control available public, private routes**](../docs/proxy/enterprise.md#control-available-public-private-routes) -- ✅ [**Guardrails, Content Moderation, PII Masking, Secret/API Key Masking**](../docs/proxy/enterprise.md#prompt-injection-detection---lakeraai) -- ✅ [**Prompt Injection Detection**](../docs/proxy/enterprise.md#prompt-injection-detection---lakeraai) -- ✅ [**Invite Team Members to access `/spend` Routes**](../docs/proxy/cost_tracking#allowing-non-proxy-admins-to-access-spend-endpoints) +- **Enterprise Features** + - **Security** + - ✅ [SSO for Admin UI](./proxy/ui#✨-enterprise-features) + - ✅ [Audit Logs with retention policy](./proxy/enterprise#audit-logs) + - ✅ [JWT-Auth](../docs/proxy/token_auth.md) + - ✅ [Control available public, private routes](./proxy/enterprise#control-available-public-private-routes) + - ✅ [[BETA] AWS Key Manager v2 - Key Decryption](./proxy/enterprise#beta-aws-key-manager---key-decryption) + - ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](./proxy/pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints) + - ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](./proxy/enterprise#enforce-required-params-for-llm-requests) + - **Spend Tracking** + - ✅ [Tracking Spend for Custom Tags](./proxy/enterprise#tracking-spend-for-custom-tags) + - ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](./proxy/cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend) + - **Advanced Metrics** + - ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](./proxy/prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens) + - **Guardrails, PII Masking, Content Moderation** + - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](./proxy/enterprise#content-moderation) + - ✅ [Prompt Injection Detection (with LakeraAI API)](./proxy/enterprise#prompt-injection-detection---lakeraai) + - ✅ Reject calls from Blocked User list + - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. 
competitors) + - **Custom Branding** + - ✅ [Custom Branding + Routes on Swagger Docs](./proxy/enterprise#swagger-docs---custom-routes--branding) + - ✅ [Public Model Hub](../docs/proxy/enterprise.md#public-model-hub) + - ✅ [Custom Email Branding](../docs/proxy/email.md#customizing-email-branding) - ✅ **Feature Prioritization** - ✅ **Custom Integrations** - ✅ **Professional Support - Dedicated discord + slack** -- ✅ [**Custom Swagger**](../docs/proxy/enterprise.md#swagger-docs---custom-routes--branding) -- ✅ [**Public Model Hub**](../docs/proxy/enterprise.md#public-model-hub) -- ✅ [**Custom Email Branding**](../docs/proxy/email.md#customizing-email-branding) + diff --git a/docs/my-website/docs/providers/anthropic.md b/docs/my-website/docs/providers/anthropic.md index e7d3352f9..a662129d0 100644 --- a/docs/my-website/docs/providers/anthropic.md +++ b/docs/my-website/docs/providers/anthropic.md @@ -168,8 +168,12 @@ print(response) ## Supported Models +`Model Name` 👉 Human-friendly name. +`Function Call` 👉 How to call the model in LiteLLM. + | Model Name | Function Call | |------------------|--------------------------------------------| +| claude-3-5-sonnet | `completion('claude-3-5-sonnet-20240620', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-3-haiku | `completion('claude-3-haiku-20240307', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-3-opus | `completion('claude-3-opus-20240229', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-3-5-sonnet-20240620 | `completion('claude-3-5-sonnet-20240620', messages)` | `os.environ['ANTHROPIC_API_KEY']` | diff --git a/docs/my-website/docs/providers/azure_ai.md b/docs/my-website/docs/providers/azure_ai.md index 87b8041ef..26c965a0c 100644 --- a/docs/my-website/docs/providers/azure_ai.md +++ b/docs/my-website/docs/providers/azure_ai.md @@ -14,7 +14,7 @@ LiteLLM supports all models on Azure AI Studio ### ENV VAR ```python import os -os.environ["AZURE_API_API_KEY"] = "" +os.environ["AZURE_AI_API_KEY"] = "" os.environ["AZURE_AI_API_BASE"] = "" ``` @@ -24,7 +24,7 @@ os.environ["AZURE_AI_API_BASE"] = "" from litellm import completion import os ## set ENV variables -os.environ["AZURE_API_API_KEY"] = "azure ai key" +os.environ["AZURE_AI_API_KEY"] = "azure ai key" os.environ["AZURE_AI_API_BASE"] = "azure ai base url" # e.g.: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/ # predibase llama-3 call diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md index f380a6a50..b72dac10b 100644 --- a/docs/my-website/docs/providers/bedrock.md +++ b/docs/my-website/docs/providers/bedrock.md @@ -549,6 +549,10 @@ response = completion( This is a deprecated flow. Boto3 is not async. And boto3.client does not let us make the http call through httpx. Pass in your aws params through the method above 👆. [See Auth Code](https://github.com/BerriAI/litellm/blob/55a20c7cce99a93d36a82bf3ae90ba3baf9a7f89/litellm/llms/bedrock_httpx.py#L284) [Add new auth flow](https://github.com/BerriAI/litellm/issues) + +Experimental - 2024-Jun-23: + `aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token` will be extracted from boto3.client and be passed into the httpx client + ::: Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth. 
diff --git a/docs/my-website/docs/providers/openai_compatible.md b/docs/my-website/docs/providers/openai_compatible.md index f02149024..33ab8fb41 100644 --- a/docs/my-website/docs/providers/openai_compatible.md +++ b/docs/my-website/docs/providers/openai_compatible.md @@ -18,7 +18,7 @@ import litellm import os response = litellm.completion( - model="openai/mistral, # add `openai/` prefix to model so litellm knows to route to OpenAI + model="openai/mistral", # add `openai/` prefix to model so litellm knows to route to OpenAI api_key="sk-1234", # api key to your openai compatible endpoint api_base="http://0.0.0.0:4000", # set API Base of your Custom OpenAI Endpoint messages=[ diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index de1b5811f..ce9e73bab 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -123,6 +123,182 @@ print(completion(**data)) ### **JSON Schema** +From v`1.40.1+` LiteLLM supports sending `response_schema` as a param for Gemini-1.5-Pro on Vertex AI. For other models (e.g. `gemini-1.5-flash` or `claude-3-5-sonnet`), LiteLLM adds the schema to the message list with a user-controlled prompt. + +**Response Schema** + + + +```python +from litellm import completion +import json + +## SETUP ENVIRONMENT +# !gcloud auth application-default login - run this to add vertex credentials to your env + +messages = [ + { + "role": "user", + "content": "List 5 popular cookie recipes." + } +] + +response_schema = { + "type": "array", + "items": { + "type": "object", + "properties": { + "recipe_name": { + "type": "string", + }, + }, + "required": ["recipe_name"], + }, + } + + +completion( + model="vertex_ai_beta/gemini-1.5-pro", + messages=messages, + response_format={"type": "json_object", "response_schema": response_schema} # 👈 KEY CHANGE + ) + +print(json.loads(completion.choices[0].message.content)) +``` + + + + +1. Add model to config.yaml +```yaml +model_list: + - model_name: gemini-pro + litellm_params: + model: vertex_ai_beta/gemini-1.5-pro + vertex_project: "project-id" + vertex_location: "us-central1" + vertex_credentials: "/path/to/service_account.json" # [OPTIONAL] Do this OR `!gcloud auth application-default login` - run this to add vertex credentials to your env +``` + +2. Start Proxy + +``` +$ litellm --config /path/to/config.yaml +``` + +3. Make Request! + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-D '{ + "model": "gemini-pro", + "messages": [ + {"role": "user", "content": "List 5 popular cookie recipes."} + ], + "response_format": {"type": "json_object", "response_schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "recipe_name": { + "type": "string", + }, + }, + "required": ["recipe_name"], + }, + }} +} +' +``` + + + + +**Validate Schema** + +To validate the response_schema, set `enforce_validation: true`. + + + + +```python +from litellm import completion, JSONSchemaValidationError +try: + completion( + model="vertex_ai_beta/gemini-1.5-pro", + messages=messages, + response_format={ + "type": "json_object", + "response_schema": response_schema, + "enforce_validation": true # 👈 KEY CHANGE + } + ) +except JSONSchemaValidationError as e: + print("Raw Response: {}".format(e.raw_response)) + raise e +``` + + + +1. 
Add model to config.yaml +```yaml +model_list: + - model_name: gemini-pro + litellm_params: + model: vertex_ai_beta/gemini-1.5-pro + vertex_project: "project-id" + vertex_location: "us-central1" + vertex_credentials: "/path/to/service_account.json" # [OPTIONAL] Do this OR `!gcloud auth application-default login` - run this to add vertex credentials to your env +``` + +2. Start Proxy + +``` +$ litellm --config /path/to/config.yaml +``` + +3. Make Request! + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-D '{ + "model": "gemini-pro", + "messages": [ + {"role": "user", "content": "List 5 popular cookie recipes."} + ], + "response_format": {"type": "json_object", "response_schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "recipe_name": { + "type": "string", + }, + }, + "required": ["recipe_name"], + }, + }, + "enforce_validation": true + } +} +' +``` + + + + +LiteLLM will validate the response against the schema, and raise a `JSONSchemaValidationError` if the response does not match the schema. + +JSONSchemaValidationError inherits from `openai.APIError` + +Access the raw response with `e.raw_response` + +**Add to prompt yourself** + ```python from litellm import completion @@ -645,6 +821,86 @@ assert isinstance( ``` +## Usage - PDF / Videos / etc. Files + +Pass any file supported by Vertex AI, through LiteLLM. + + + + +```python +from litellm import completion + +response = completion( + model="vertex_ai/gemini-1.5-flash", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "You are a very professional document summarization specialist. Please summarize the given document."}, + { + "type": "image_url", + "image_url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf", + }, + ], + } + ], + max_tokens=300, +) + +print(response.choices[0]) + +``` + + + +1. Add model to config + +```yaml +- model_name: gemini-1.5-flash + litellm_params: + model: vertex_ai/gemini-1.5-flash + vertex_credentials: "/path/to/service_account.json" +``` + +2. Start Proxy + +``` +litellm --config /path/to/config.yaml +``` + +3. Test it! + +```bash +curl http://0.0.0.0:4000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer " \ + -d '{ + "model": "gemini-1.5-flash", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "You are a very professional document summarization specialist. Please summarize the given document" + }, + { + "type": "image_url", + "image_url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf", + }, + } + ] + } + ], + "max_tokens": 300 + }' + +``` + + + + ## Chat Models | Model Name | Function Call | |------------------|--------------------------------------| diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index 3ab644855..00457dbc4 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -277,6 +277,54 @@ curl --location 'http://0.0.0.0:4000/v1/model/info' \ --data '' ``` +## Wildcard Model Name (Add ALL MODELS from env) + +Dynamically call any model from any given provider without the need to predefine it in the config YAML file. As long as the relevant keys are in the environment (see [providers list](../providers/)), LiteLLM will make the call correctly. + + + +1. 
Setup config.yaml +``` +model_list: + - model_name: "*" # all requests where model not in your config go to this deployment + litellm_params: + model: "openai/*" # passes our validation check that a real provider is given +``` + +2. Start LiteLLM proxy + +``` +litellm --config /path/to/config.yaml +``` + +3. Try claude 3-5 sonnet from anthropic + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-D '{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + {"role": "user", "content": "Hey, how'\''s it going?"}, + { + "role": "assistant", + "content": "I'\''m doing well. Would like to hear the rest of the story?" + }, + {"role": "user", "content": "Na"}, + { + "role": "assistant", + "content": "No problem, is there anything else i can help you with today?" + }, + { + "role": "user", + "content": "I think you'\''re getting cut off sometimes" + } + ] +} +' +``` + ## Load Balancing :::info diff --git a/docs/my-website/docs/proxy/cost_tracking.md b/docs/my-website/docs/proxy/cost_tracking.md index 3ccf8f383..fe3a46250 100644 --- a/docs/my-website/docs/proxy/cost_tracking.md +++ b/docs/my-website/docs/proxy/cost_tracking.md @@ -117,6 +117,8 @@ That's IT. Now Verify your spend was tracked +Expect to see `x-litellm-response-cost` in the response headers with calculated cost + @@ -145,16 +147,16 @@ Navigate to the Usage Tab on the LiteLLM UI (found on https://your-proxy-endpoin -## API Endpoints to get Spend -#### Getting Spend Reports - To Charge Other Teams, Customers - -Use the `/global/spend/report` endpoint to get daily spend report per -- team -- customer [this is `user` passed to `/chat/completions` request](#how-to-track-spend-with-litellm) - +## ✨ (Enterprise) API Endpoints to get Spend +#### Getting Spend Reports - To Charge Other Teams, Customers + +Use the `/global/spend/report` endpoint to get daily spend report per +- Team +- Customer [this is `user` passed to `/chat/completions` request](#how-to-track-spend-with-litellm) +- [LiteLLM API key](virtual_keys.md) @@ -337,6 +339,61 @@ curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end ``` + + + + + +👉 Key Change: Specify `group_by=api_key` + + +```shell +curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30&group_by=api_key' \ + -H 'Authorization: Bearer sk-1234' +``` + +##### Example Response + + +```shell +[ + { + "api_key": "ad64768847d05d978d62f623d872bff0f9616cc14b9c1e651c84d14fe3b9f539", + "total_cost": 0.0002157, + "total_input_tokens": 45.0, + "total_output_tokens": 1375.0, + "model_details": [ + { + "model": "gpt-3.5-turbo", + "total_cost": 0.0001095, + "total_input_tokens": 9, + "total_output_tokens": 70 + }, + { + "model": "llama3-8b-8192", + "total_cost": 0.0001062, + "total_input_tokens": 36, + "total_output_tokens": 1305 + } + ] + }, + { + "api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b", + "total_cost": 0.00012924, + "total_input_tokens": 36.0, + "total_output_tokens": 1593.0, + "model_details": [ + { + "model": "llama3-8b-8192", + "total_cost": 0.00012924, + "total_input_tokens": 36, + "total_output_tokens": 1593 + } + ] + } +] +``` + diff --git a/docs/my-website/docs/proxy/debugging.md b/docs/my-website/docs/proxy/debugging.md index 571a97c0e..38680982a 100644 --- a/docs/my-website/docs/proxy/debugging.md +++ b/docs/my-website/docs/proxy/debugging.md @@ -88,4 +88,31 @@ Expected Output: ```bash # no info statements -``` \ No 
newline at end of file +``` + +## Common Errors + +1. "No available deployments..." + +``` +No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a. +``` + +This can be caused due to all your models hitting rate limit errors, causing the cooldown to kick in. + +How to control this? +- Adjust the cooldown time + +```yaml +router_settings: + cooldown_time: 0 # 👈 KEY CHANGE +``` + +- Disable Cooldowns [NOT RECOMMENDED] + +```yaml +router_settings: + disable_cooldowns: True +``` + +This is not recommended, as it will lead to requests being routed to deployments over their tpm/rpm limit. \ No newline at end of file diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md index 9fff879e5..5dabba5ed 100644 --- a/docs/my-website/docs/proxy/enterprise.md +++ b/docs/my-website/docs/proxy/enterprise.md @@ -6,21 +6,34 @@ import TabItem from '@theme/TabItem'; :::tip -Get in touch with us [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) +To get a license, get in touch with us [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) ::: Features: -- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features) -- ✅ [Audit Logs](#audit-logs) -- ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags) -- ✅ [Control available public, private routes](#control-available-public-private-routes) -- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation) -- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai) -- ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding) -- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests) -- ✅ Reject calls from Blocked User list -- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors) + +- **Security** + - ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features) + - ✅ [Audit Logs with retention policy](#audit-logs) + - ✅ [JWT-Auth](../docs/proxy/token_auth.md) + - ✅ [Control available public, private routes](#control-available-public-private-routes) + - ✅ [[BETA] AWS Key Manager v2 - Key Decryption](#beta-aws-key-manager---key-decryption) + - ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints) + - ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests) +- **Spend Tracking** + - ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags) + - ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend) +- **Advanced Metrics** + - ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens) +- **Guardrails, PII Masking, Content Moderation** + - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation) + - ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai) + - ✅ Reject calls from Blocked User list + - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. 
competitors) +- **Custom Branding** + - ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding) + - ✅ [Public Model Hub](../docs/proxy/enterprise.md#public-model-hub) + - ✅ [Custom Email Branding](../docs/proxy/email.md#customizing-email-branding) ## Audit Logs @@ -1019,4 +1032,35 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ Share a public page of available models for users - \ No newline at end of file + + + +## [BETA] AWS Key Manager - Key Decryption + +This is a beta feature, and subject to changes. + + +**Step 1.** Add `USE_AWS_KMS` to env + +```env +USE_AWS_KMS="True" +``` + +**Step 2.** Add `aws_kms/` to encrypted keys in env + +```env +DATABASE_URL="aws_kms/AQICAH.." +``` + +**Step 3.** Start proxy + +``` +$ litellm +``` + +How it works? +- Key Decryption runs before server starts up. [**Code**](https://github.com/BerriAI/litellm/blob/8571cb45e80cc561dc34bc6aa89611eb96b9fe3e/litellm/proxy/proxy_cli.py#L445) +- It adds the decrypted value to the `os.environ` for the python process. + +**Note:** Setting an environment variable within a Python script using os.environ will not make that variable accessible via SSH sessions or any other new processes that are started independently of the Python script. Environment variables set this way only affect the current process and its child processes. + diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md index f9ed5db3d..83bf8ee95 100644 --- a/docs/my-website/docs/proxy/logging.md +++ b/docs/my-website/docs/proxy/logging.md @@ -1188,6 +1188,7 @@ litellm_settings: s3_region_name: us-west-2 # AWS Region Name for S3 s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/ to pass environment variables. This is AWS Access Key ID for S3 s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3 + s3_path: my-test-path # [OPTIONAL] set path in bucket you want to write logs to s3_endpoint_url: https://s3.amazonaws.com # [OPTIONAL] S3 endpoint URL, if you want to use Backblaze/cloudflare s3 buckets ``` diff --git a/docs/my-website/docs/proxy/pass_through.md b/docs/my-website/docs/proxy/pass_through.md new file mode 100644 index 000000000..1348a2fc1 --- /dev/null +++ b/docs/my-website/docs/proxy/pass_through.md @@ -0,0 +1,220 @@ +import Image from '@theme/IdealImage'; + +# ➡️ Create Pass Through Endpoints + +Add pass through routes to LiteLLM Proxy + +**Example:** Add a route `/v1/rerank` that forwards requests to `https://api.cohere.com/v1/rerank` through LiteLLM Proxy + + +💡 This allows making the following Request to LiteLLM Proxy +```shell +curl --request POST \ + --url http://localhost:4000/v1/rerank \ + --header 'accept: application/json' \ + --header 'content-type: application/json' \ + --data '{ + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": ["Carson City is the capital city of the American state of Nevada."] + }' +``` + +## Tutorial - Pass through Cohere Re-Rank Endpoint + +**Step 1** Define pass through routes on [litellm config.yaml](configs.md) + +```yaml +general_settings: + master_key: sk-1234 + pass_through_endpoints: + - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server + target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to + headers: # headers to forward to this URL + Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint + content-type: 
application/json # (Optional) Extra Headers to pass to this endpoint + accept: application/json +``` + +**Step 2** Start Proxy Server in detailed_debug mode + +```shell +litellm --config config.yaml --detailed_debug +``` +**Step 3** Make Request to pass through endpoint + +Here `http://localhost:4000` is your litellm proxy endpoint + +```shell +curl --request POST \ + --url http://localhost:4000/v1/rerank \ + --header 'accept: application/json' \ + --header 'content-type: application/json' \ + --data '{ + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": ["Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", + "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."] + }' +``` + + +🎉 **Expected Response** + +This request got forwarded from LiteLLM Proxy -> Defined Target URL (with headers) + +```shell +{ + "id": "37103a5b-8cfb-48d3-87c7-da288bedd429", + "results": [ + { + "index": 2, + "relevance_score": 0.999071 + }, + { + "index": 4, + "relevance_score": 0.7867867 + }, + { + "index": 0, + "relevance_score": 0.32713068 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } +} +``` + +## Tutorial - Pass Through Langfuse Requests + + +**Step 1** Define pass through routes on [litellm config.yaml](configs.md) + +```yaml +general_settings: + master_key: sk-1234 + pass_through_endpoints: + - path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server + target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward + headers: + LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY" # your langfuse account public key + LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY" # your langfuse account secret key +``` + +**Step 2** Start Proxy Server in detailed_debug mode + +```shell +litellm --config config.yaml --detailed_debug +``` +**Step 3** Make Request to pass through endpoint + +Run this code to make a sample trace +```python +from langfuse import Langfuse + +langfuse = Langfuse( + host="http://localhost:4000", # your litellm proxy endpoint + public_key="anything", # no key required since this is a pass through + secret_key="anything", # no key required since this is a pass through +) + +print("sending langfuse trace request") +trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough") +print("flushing langfuse request") +langfuse.flush() + +print("flushed langfuse request") +``` + + +🎉 **Expected Response** + +On success +Expect to see the following Trace Generated on your Langfuse Dashboard + + + +You will see the following endpoint called on your litellm proxy server logs + +```shell +POST /api/public/ingestion HTTP/1.1" 207 Multi-Status +``` + + +## ✨ [Enterprise] - Use LiteLLM keys/authentication on Pass Through Endpoints + +Use this if you want the pass through endpoint to honour LiteLLM keys/authentication + +Usage 
- set `auth: true` on the config +```yaml +general_settings: + master_key: sk-1234 + pass_through_endpoints: + - path: "/v1/rerank" + target: "https://api.cohere.com/v1/rerank" + auth: true # 👈 Key change to use LiteLLM Auth / Keys + headers: + Authorization: "bearer os.environ/COHERE_API_KEY" + content-type: application/json + accept: application/json +``` + +Test Request with LiteLLM Key + +```shell +curl --request POST \ + --url http://localhost:4000/v1/rerank \ + --header 'accept: application/json' \ + --header 'Authorization: Bearer sk-1234'\ + --header 'content-type: application/json' \ + --data '{ + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": ["Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", + "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."] + }' +``` + +## `pass_through_endpoints` Spec on config.yaml + +All possible values for `pass_through_endpoints` and what they mean + +**Example config** +```yaml +general_settings: + pass_through_endpoints: + - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server + target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to + headers: # headers to forward to this URL + Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint + content-type: application/json # (Optional) Extra Headers to pass to this endpoint + accept: application/json +``` + +**Spec** + +* `pass_through_endpoints` *list*: A collection of endpoint configurations for request forwarding. + * `path` *string*: The route to be added to the LiteLLM Proxy Server. + * `target` *string*: The URL to which requests for this path should be forwarded. + * `headers` *object*: Key-value pairs of headers to be forwarded with the request. You can set any key value pair here and it will be forwarded to your target endpoint + * `Authorization` *string*: The authentication header for the target API. + * `content-type` *string*: The format specification for the request body. + * `accept` *string*: The expected response format from the server. + * `LANGFUSE_PUBLIC_KEY` *string*: Your Langfuse account public key - only set this when forwarding to Langfuse. + * `LANGFUSE_SECRET_KEY` *string*: Your Langfuse account secret key - only set this when forwarding to Langfuse. 
+ * `` *string*: Pass any custom header key/value pair \ No newline at end of file diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md index 2c7481f4c..6790b25b0 100644 --- a/docs/my-website/docs/proxy/prometheus.md +++ b/docs/my-website/docs/proxy/prometheus.md @@ -1,3 +1,6 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + # 📈 Prometheus metrics [BETA] LiteLLM Exposes a `/metrics` endpoint for Prometheus to Poll @@ -61,6 +64,56 @@ http://localhost:4000/metrics | `litellm_remaining_api_key_budget_metric` | Remaining Budget for API Key (A key Created on LiteLLM)| +### ✨ (Enterprise) LLM Remaining Requests and Remaining Tokens +Set this on your config.yaml to allow you to track how close you are to hitting your TPM / RPM limits on each model group + +```yaml +litellm_settings: + success_callback: ["prometheus"] + failure_callback: ["prometheus"] + return_response_headers: true # ensures the LLM API calls track the response headers +``` + +| Metric Name | Description | +|----------------------|--------------------------------------| +| `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment | +| `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` return from LLM API Deployment | + +Example Metric + + + + +```shell +litellm_remaining_requests +{ + api_base="https://api.openai.com/v1", + api_provider="openai", + litellm_model_name="gpt-3.5-turbo", + model_group="gpt-3.5-turbo" +} +8998.0 +``` + + + + + +```shell +litellm_remaining_tokens +{ + api_base="https://api.openai.com/v1", + api_provider="openai", + litellm_model_name="gpt-3.5-turbo", + model_group="gpt-3.5-turbo" +} +999981.0 +``` + + + + + ## Monitor System Health To monitor the health of litellm adjacent services (redis / postgres), do: diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md index 240e6c8e0..905954e97 100644 --- a/docs/my-website/docs/routing.md +++ b/docs/my-website/docs/routing.md @@ -815,6 +815,35 @@ model_list: +**Expected Response** + +``` +No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a. +``` + +#### **Disable cooldowns** + + + + + +```python +from litellm import Router + + +router = Router(..., disable_cooldowns=True) +``` + + + +```yaml +router_settings: + disable_cooldowns: True +``` + + + + ### Retries For both async + sync functions, we support retrying failed requests. diff --git a/docs/my-website/docs/secret.md b/docs/my-website/docs/secret.md index 08c2e89d1..91ae38368 100644 --- a/docs/my-website/docs/secret.md +++ b/docs/my-website/docs/secret.md @@ -8,7 +8,13 @@ LiteLLM supports reading secrets from Azure Key Vault and Infisical - [Infisical Secret Manager](#infisical-secret-manager) - [.env Files](#env-files) -## AWS Key Management Service +## AWS Key Management V1 + +:::tip + +[BETA] AWS Key Management v2 is on the enterprise tier. Go [here for docs](./proxy/enterprise.md#beta-aws-key-manager---key-decryption) + +::: Use AWS KMS to storing a hashed copy of your Proxy Master Key in the environment. 
diff --git a/docs/my-website/img/proxy_langfuse.png b/docs/my-website/img/proxy_langfuse.png new file mode 100644 index 000000000..4a3ca28ee Binary files /dev/null and b/docs/my-website/img/proxy_langfuse.png differ diff --git a/docs/my-website/img/response_cost_img.png b/docs/my-website/img/response_cost_img.png index 9f466b3fb..2fa9c2009 100644 Binary files a/docs/my-website/img/response_cost_img.png and b/docs/my-website/img/response_cost_img.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 31bc6abcb..82f4bd260 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -48,6 +48,7 @@ const sidebars = { "proxy/billing", "proxy/user_keys", "proxy/virtual_keys", + "proxy/token_auth", "proxy/alerting", { type: "category", @@ -56,11 +57,11 @@ const sidebars = { }, "proxy/ui", "proxy/prometheus", + "proxy/pass_through", "proxy/email", "proxy/multiple_admins", "proxy/team_based_routing", "proxy/customer_routing", - "proxy/token_auth", { type: "category", label: "Extra Load Balancing", diff --git a/enterprise/enterprise_hooks/lakera_ai.py b/enterprise/enterprise_hooks/lakera_ai.py index dd37ae2c1..2a4ad418b 100644 --- a/enterprise/enterprise_hooks/lakera_ai.py +++ b/enterprise/enterprise_hooks/lakera_ai.py @@ -114,7 +114,11 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger): if flagged == True: raise HTTPException( - status_code=400, detail={"error": "Violated content safety policy"} + status_code=400, + detail={ + "error": "Violated content safety policy", + "lakera_ai_response": _json_response, + }, ) pass diff --git a/entrypoint.sh b/entrypoint.sh index 80adf8d07..a028e5426 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,48 +1,13 @@ -#!/bin/sh +#!/bin/bash +echo $(pwd) -# Check if DATABASE_URL is not set -if [ -z "$DATABASE_URL" ]; then - # Check if all required variables are provided - if [ -n "$DATABASE_HOST" ] && [ -n "$DATABASE_USERNAME" ] && [ -n "$DATABASE_PASSWORD" ] && [ -n "$DATABASE_NAME" ]; then - # Construct DATABASE_URL from the provided variables - DATABASE_URL="postgresql://${DATABASE_USERNAME}:${DATABASE_PASSWORD}@${DATABASE_HOST}/${DATABASE_NAME}" - export DATABASE_URL - else - echo "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL." - exit 1 - fi -fi +# Run the Python migration script +python3 litellm/proxy/prisma_migration.py -# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations -if [ -z "$DIRECT_URL" ]; then - export DIRECT_URL=$DATABASE_URL -fi - -# Apply migrations -retry_count=0 -max_retries=3 -exit_code=1 - -until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ] -do - retry_count=$((retry_count+1)) - echo "Attempt $retry_count..." - - # Run the Prisma db push command - prisma db push --accept-data-loss - - exit_code=$? - - if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then - echo "Retrying in 10 seconds..." - sleep 10 - fi -done - -if [ $exit_code -ne 0 ]; then - echo "Unable to push database changes after $max_retries retries." +# Check if the Python script executed successfully +if [ $? -eq 0 ]; then + echo "Migration script ran successfully!" +else + echo "Migration script failed!" exit 1 fi - -echo "Database push successful!" 
- diff --git a/litellm/__init__.py b/litellm/__init__.py index 381c4c530..9cdfa15e3 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -125,6 +125,9 @@ llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" ################## ### PREVIEW FEATURES ### enable_preview_features: bool = False +return_response_headers: bool = ( + False # get response headers from LLM Api providers - example x-remaining-requests, +) ################## logging: bool = True caching: bool = ( @@ -749,6 +752,7 @@ from .utils import ( create_pretrained_tokenizer, create_tokenizer, supports_function_calling, + supports_response_schema, supports_parallel_function_calling, supports_vision, supports_system_messages, @@ -799,7 +803,11 @@ from .llms.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig from .llms.maritalk import MaritTalkConfig -from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig +from .llms.bedrock_httpx import ( + AmazonCohereChatConfig, + AmazonConverseConfig, + BEDROCK_CONVERSE_MODELS, +) from .llms.bedrock import ( AmazonTitanConfig, AmazonAI21Config, @@ -848,6 +856,7 @@ from .exceptions import ( APIResponseValidationError, UnprocessableEntityError, InternalServerError, + JSONSchemaValidationError, LITELLM_EXCEPTION_TYPES, ) from .budget_manager import BudgetManager diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 9a7df7ebe..062e98be9 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -1,6 +1,7 @@ # What is this? ## File for 'response_cost' calculation in Logging import time +import traceback from typing import List, Literal, Optional, Tuple, Union import litellm @@ -668,3 +669,10 @@ def response_cost_calculator( f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map." 
) return None + except Exception as e: + verbose_logger.error( + "litellm.cost_calculator.py::response_cost_calculator - Exception occurred - {}/n{}".format( + str(e), traceback.format_exc() + ) + ) + return None diff --git a/litellm/exceptions.py b/litellm/exceptions.py index 98b519278..d85510b1d 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -551,7 +551,7 @@ class APIError(openai.APIError): # type: ignore message, llm_provider, model, - request: httpx.Request, + request: Optional[httpx.Request] = None, litellm_debug_info: Optional[str] = None, max_retries: Optional[int] = None, num_retries: Optional[int] = None, @@ -563,6 +563,8 @@ class APIError(openai.APIError): # type: ignore self.litellm_debug_info = litellm_debug_info self.max_retries = max_retries self.num_retries = num_retries + if request is None: + request = httpx.Request(method="POST", url="https://api.openai.com/v1") super().__init__(self.message, request=request, body=None) # type: ignore def __str__(self): @@ -664,6 +666,22 @@ class OpenAIError(openai.OpenAIError): # type: ignore self.llm_provider = "openai" +class JSONSchemaValidationError(APIError): + def __init__( + self, model: str, llm_provider: str, raw_response: str, schema: str + ) -> None: + self.raw_response = raw_response + self.schema = schema + self.model = model + message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format( + model, raw_response, schema + ) + self.message = message + super().__init__( + model=model, message=message, llm_provider=llm_provider, status_code=500 + ) + + LITELLM_EXCEPTION_TYPES = [ AuthenticationError, NotFoundError, @@ -682,6 +700,7 @@ LITELLM_EXCEPTION_TYPES = [ APIResponseValidationError, OpenAIError, InternalServerError, + JSONSchemaValidationError, ] diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index eae8b8e22..983ec3942 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -311,22 +311,17 @@ class LangFuseLogger: try: tags = [] - try: - metadata = copy.deepcopy( - metadata - ) # Avoid modifying the original metadata - except: - new_metadata = {} - for key, value in metadata.items(): - if ( - isinstance(value, list) - or isinstance(value, dict) - or isinstance(value, str) - or isinstance(value, int) - or isinstance(value, float) - ): - new_metadata[key] = copy.deepcopy(value) - metadata = new_metadata + new_metadata = {} + for key, value in metadata.items(): + if ( + isinstance(value, list) + or isinstance(value, dict) + or isinstance(value, str) + or isinstance(value, int) + or isinstance(value, float) + ): + new_metadata[key] = copy.deepcopy(value) + metadata = new_metadata supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3") supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3") diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 4f0ffa387..6cd746907 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -2,14 +2,20 @@ #### What this does #### # On success, log events to Prometheus -import dotenv, os -import requests # type: ignore +import datetime +import os +import subprocess +import sys import traceback -import datetime, subprocess, sys -import litellm, uuid -from litellm._logging import print_verbose, verbose_logger +import uuid from typing import Optional, Union +import dotenv +import requests # type: ignore + +import 
litellm +from litellm._logging import print_verbose, verbose_logger + class PrometheusLogger: # Class variables or attributes @@ -20,6 +26,8 @@ class PrometheusLogger: try: from prometheus_client import Counter, Gauge + from litellm.proxy.proxy_server import premium_user + self.litellm_llm_api_failed_requests_metric = Counter( name="litellm_llm_api_failed_requests_metric", documentation="Total number of failed LLM API calls via litellm", @@ -88,6 +96,31 @@ class PrometheusLogger: labelnames=["hashed_api_key", "api_key_alias"], ) + # Litellm-Enterprise Metrics + if premium_user is True: + # Remaining Rate Limit for model + self.litellm_remaining_requests_metric = Gauge( + "litellm_remaining_requests", + "remaining requests for model, returned from LLM API Provider", + labelnames=[ + "model_group", + "api_provider", + "api_base", + "litellm_model_name", + ], + ) + + self.litellm_remaining_tokens_metric = Gauge( + "litellm_remaining_tokens", + "remaining tokens for model, returned from LLM API Provider", + labelnames=[ + "model_group", + "api_provider", + "api_base", + "litellm_model_name", + ], + ) + except Exception as e: print_verbose(f"Got exception on init prometheus client {str(e)}") raise e @@ -104,6 +137,8 @@ class PrometheusLogger: ): try: # Define prometheus client + from litellm.proxy.proxy_server import premium_user + verbose_logger.debug( f"prometheus Logging - Enters logging function for model {kwargs}" ) @@ -199,6 +234,10 @@ class PrometheusLogger: user_api_key, user_api_key_alias ).set(_remaining_api_key_budget) + # set x-ratelimit headers + if premium_user is True: + self.set_remaining_tokens_requests_metric(kwargs) + ### FAILURE INCREMENT ### if "exception" in kwargs: self.litellm_llm_api_failed_requests_metric.labels( @@ -216,6 +255,58 @@ class PrometheusLogger: verbose_logger.debug(traceback.format_exc()) pass + def set_remaining_tokens_requests_metric(self, request_kwargs: dict): + try: + verbose_logger.debug("setting remaining tokens requests metric") + _response_headers = request_kwargs.get("response_headers") + _litellm_params = request_kwargs.get("litellm_params", {}) or {} + _metadata = _litellm_params.get("metadata", {}) + litellm_model_name = request_kwargs.get("model", None) + model_group = _metadata.get("model_group", None) + api_base = _metadata.get("api_base", None) + llm_provider = _litellm_params.get("custom_llm_provider", None) + + remaining_requests = None + remaining_tokens = None + # OpenAI / OpenAI Compatible headers + if ( + _response_headers + and "x-ratelimit-remaining-requests" in _response_headers + ): + remaining_requests = _response_headers["x-ratelimit-remaining-requests"] + if ( + _response_headers + and "x-ratelimit-remaining-tokens" in _response_headers + ): + remaining_tokens = _response_headers["x-ratelimit-remaining-tokens"] + verbose_logger.debug( + f"remaining requests: {remaining_requests}, remaining tokens: {remaining_tokens}" + ) + + if remaining_requests: + """ + "model_group", + "api_provider", + "api_base", + "litellm_model_name" + """ + self.litellm_remaining_requests_metric.labels( + model_group, llm_provider, api_base, litellm_model_name + ).set(remaining_requests) + + if remaining_tokens: + self.litellm_remaining_tokens_metric.labels( + model_group, llm_provider, api_base, litellm_model_name + ).set(remaining_tokens) + + except Exception as e: + verbose_logger.error( + "Prometheus Error: set_remaining_tokens_requests_metric. 
Exception occured - {}".format( + str(e) + ) + ) + return + def safe_get_remaining_budget( max_budget: Optional[float], spend: Optional[float] diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py index 0796d1048..6e8c4a4e4 100644 --- a/litellm/integrations/s3.py +++ b/litellm/integrations/s3.py @@ -1,10 +1,14 @@ #### What this does #### # On success + failure, log events to Supabase +import datetime import os +import subprocess +import sys import traceback -import datetime, subprocess, sys -import litellm, uuid +import uuid + +import litellm from litellm._logging import print_verbose, verbose_logger @@ -54,6 +58,7 @@ class S3Logger: "s3_aws_session_token" ) s3_config = litellm.s3_callback_params.get("s3_config") + s3_path = litellm.s3_callback_params.get("s3_path") # done reading litellm.s3_callback_params self.bucket_name = s3_bucket_name diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index a325a6885..d8d551048 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -26,7 +26,7 @@ def map_finish_reason( finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP" ): # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',] return "stop" - elif finish_reason == "SAFETY": # vertex ai + elif finish_reason == "SAFETY" or finish_reason == "RECITATION": # vertex ai return "content_filter" elif finish_reason == "STOP": # vertex ai return "stop" diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py new file mode 100644 index 000000000..5ac26c7ae --- /dev/null +++ b/litellm/litellm_core_utils/exception_mapping_utils.py @@ -0,0 +1,40 @@ +import json +from typing import Optional + + +def get_error_message(error_obj) -> Optional[str]: + """ + OpenAI Returns Error message that is nested, this extract the message + + Example: + { + 'request': "", + 'message': "Error code: 400 - {\'error\': {\'message\': \"Invalid 'temperature': decimal above maximum value. Expected a value <= 2, but got 200 instead.\", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'decimal_above_max_value'}}", + 'body': { + 'message': "Invalid 'temperature': decimal above maximum value. Expected a value <= 2, but got 200 instead.", + 'type': 'invalid_request_error', + 'param': 'temperature', + 'code': 'decimal_above_max_value' + }, + 'code': 'decimal_above_max_value', + 'param': 'temperature', + 'type': 'invalid_request_error', + 'response': "", + 'status_code': 400, + 'request_id': 'req_f287898caa6364cd42bc01355f74dd2a' + } + """ + try: + # First, try to access the message directly from the 'body' key + if error_obj is None: + return None + + if hasattr(error_obj, "body"): + _error_obj_body = getattr(error_obj, "body") + if isinstance(_error_obj_body, dict): + return _error_obj_body.get("message") + + # If all else fails, return None + return None + except Exception as e: + return None diff --git a/litellm/litellm_core_utils/json_validation_rule.py b/litellm/litellm_core_utils/json_validation_rule.py new file mode 100644 index 000000000..f19144aaf --- /dev/null +++ b/litellm/litellm_core_utils/json_validation_rule.py @@ -0,0 +1,23 @@ +import json + + +def validate_schema(schema: dict, response: str): + """ + Validate if the returned json response follows the schema. 
+ + Params: + - schema - dict: JSON schema + - response - str: Received json response as string. + """ + from jsonschema import ValidationError, validate + + from litellm import JSONSchemaValidationError + + response_dict = json.loads(response) + + try: + validate(response_dict, schema=schema) + except ValidationError: + raise JSONSchemaValidationError( + model="", llm_provider="", raw_response=response, schema=json.dumps(schema) + ) diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index 808813c05..1051a56b7 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -1,23 +1,28 @@ -import os, types +import copy import json -from enum import Enum -import requests, copy # type: ignore +import os import time +import types +from enum import Enum from functools import partial -from typing import Callable, Optional, List, Union -import litellm.litellm_core_utils -from litellm.utils import ModelResponse, Usage, CustomStreamWrapper -from litellm.litellm_core_utils.core_helpers import map_finish_reason +from typing import Callable, List, Optional, Union + +import httpx # type: ignore +import requests # type: ignore + import litellm -from .prompt_templates.factory import prompt_factory, custom_prompt +import litellm.litellm_core_utils +from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, _get_async_httpx_client, _get_httpx_client, ) -from .base import BaseLLM -import httpx # type: ignore from litellm.types.llms.anthropic import AnthropicMessagesToolChoice +from litellm.utils import CustomStreamWrapper, ModelResponse, Usage + +from .base import BaseLLM +from .prompt_templates.factory import custom_prompt, prompt_factory class AnthropicConstants(Enum): @@ -179,10 +184,19 @@ async def make_call( if client is None: client = _get_async_httpx_client() # Create a new client if none provided - response = await client.post(api_base, headers=headers, data=data, stream=True) + try: + response = await client.post(api_base, headers=headers, data=data, stream=True) + except httpx.HTTPStatusError as e: + raise AnthropicError( + status_code=e.response.status_code, message=await e.response.aread() + ) + except Exception as e: + raise AnthropicError(status_code=500, message=str(e)) if response.status_code != 200: - raise AnthropicError(status_code=response.status_code, message=response.text) + raise AnthropicError( + status_code=response.status_code, message=await response.aread() + ) completion_stream = response.aiter_lines() diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 11bf5c3f0..000feed44 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -23,6 +23,7 @@ from typing_extensions import overload import litellm from litellm import OpenAIConfig from litellm.caching import DualCache +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.utils import ( Choices, CustomStreamWrapper, @@ -458,6 +459,36 @@ class AzureChatCompletion(BaseLLM): return azure_client + async def make_azure_openai_chat_completion_request( + self, + azure_client: AsyncAzureOpenAI, + data: dict, + timeout: Union[float, httpx.Timeout], + ): + """ + Helper to: + - call chat.completions.create.with_raw_response when litellm.return_response_headers is True + - call chat.completions.create by default + """ + try: + if litellm.return_response_headers is True: + raw_response = ( + await azure_client.chat.completions.with_raw_response.create( + **data, 
timeout=timeout + ) + ) + + headers = dict(raw_response.headers) + response = raw_response.parse() + return headers, response + else: + response = await azure_client.chat.completions.create( + **data, timeout=timeout + ) + return None, response + except Exception as e: + raise e + def completion( self, model: str, @@ -470,7 +501,7 @@ class AzureChatCompletion(BaseLLM): azure_ad_token: str, print_verbose: Callable, timeout: Union[float, httpx.Timeout], - logging_obj, + logging_obj: LiteLLMLoggingObj, optional_params, litellm_params, logger_fn, @@ -649,9 +680,9 @@ class AzureChatCompletion(BaseLLM): data: dict, timeout: Any, model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, azure_ad_token: Optional[str] = None, client=None, # this is the AsyncAzureOpenAI - logging_obj=None, ): response = None try: @@ -701,9 +732,13 @@ class AzureChatCompletion(BaseLLM): "complete_input_dict": data, }, ) - response = await azure_client.chat.completions.create( - **data, timeout=timeout + + headers, response = await self.make_azure_openai_chat_completion_request( + azure_client=azure_client, + data=data, + timeout=timeout, ) + logging_obj.model_call_details["response_headers"] = headers stringified_response = response.model_dump() logging_obj.post_call( @@ -717,11 +752,32 @@ class AzureChatCompletion(BaseLLM): model_response_object=model_response, ) except AzureOpenAIError as e: + ## LOGGING + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) exception_mapping_worked = True raise e except asyncio.CancelledError as e: + ## LOGGING + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) raise AzureOpenAIError(status_code=500, message=str(e)) except Exception as e: + ## LOGGING + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) if hasattr(e, "status_code"): raise e else: @@ -791,7 +847,7 @@ class AzureChatCompletion(BaseLLM): async def async_streaming( self, - logging_obj, + logging_obj: LiteLLMLoggingObj, api_base: str, api_key: str, api_version: str, @@ -840,9 +896,14 @@ class AzureChatCompletion(BaseLLM): "complete_input_dict": data, }, ) - response = await azure_client.chat.completions.create( - **data, timeout=timeout + + headers, response = await self.make_azure_openai_chat_completion_request( + azure_client=azure_client, + data=data, + timeout=timeout, ) + logging_obj.model_call_details["response_headers"] = headers + # return response streamwrapper = CustomStreamWrapper( completion_stream=response, diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index 14abec784..7b4628a76 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -60,6 +60,17 @@ from .prompt_templates.factory import ( prompt_factory, ) +BEDROCK_CONVERSE_MODELS = [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-opus-20240229-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0", + "anthropic.claude-v2", + "anthropic.claude-v2:1", + "anthropic.claude-v1", + "anthropic.claude-instant-v1", +] + iam_cache = DualCache() @@ -305,6 +316,7 @@ class BedrockLLM(BaseLLM): self, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, aws_region_name: Optional[str] = 
None, aws_session_name: Optional[str] = None, aws_profile_name: Optional[str] = None, @@ -320,6 +332,7 @@ class BedrockLLM(BaseLLM): params_to_check: List[Optional[str]] = [ aws_access_key_id, aws_secret_access_key, + aws_session_token, aws_region_name, aws_session_name, aws_profile_name, @@ -337,6 +350,7 @@ class BedrockLLM(BaseLLM): ( aws_access_key_id, aws_secret_access_key, + aws_session_token, aws_region_name, aws_session_name, aws_profile_name, @@ -430,6 +444,19 @@ class BedrockLLM(BaseLLM): client = boto3.Session(profile_name=aws_profile_name) return client.get_credentials() + elif ( + aws_access_key_id is not None + and aws_secret_access_key is not None + and aws_session_token is not None + ): ### CHECK FOR AWS SESSION TOKEN ### + from botocore.credentials import Credentials + + credentials = Credentials( + access_key=aws_access_key_id, + secret_key=aws_secret_access_key, + token=aws_session_token, + ) + return credentials else: session = boto3.Session( aws_access_key_id=aws_access_key_id, @@ -734,9 +761,10 @@ class BedrockLLM(BaseLLM): provider = model.split(".")[0] ## CREDENTIALS ## - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) aws_region_name = optional_params.pop("aws_region_name", None) aws_role_name = optional_params.pop("aws_role_name", None) aws_session_name = optional_params.pop("aws_session_name", None) @@ -768,6 +796,7 @@ class BedrockLLM(BaseLLM): credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, aws_region_name=aws_region_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, @@ -1422,6 +1451,7 @@ class BedrockConverseLLM(BaseLLM): self, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, aws_region_name: Optional[str] = None, aws_session_name: Optional[str] = None, aws_profile_name: Optional[str] = None, @@ -1437,6 +1467,7 @@ class BedrockConverseLLM(BaseLLM): params_to_check: List[Optional[str]] = [ aws_access_key_id, aws_secret_access_key, + aws_session_token, aws_region_name, aws_session_name, aws_profile_name, @@ -1454,6 +1485,7 @@ class BedrockConverseLLM(BaseLLM): ( aws_access_key_id, aws_secret_access_key, + aws_session_token, aws_region_name, aws_session_name, aws_profile_name, @@ -1547,6 +1579,19 @@ class BedrockConverseLLM(BaseLLM): client = boto3.Session(profile_name=aws_profile_name) return client.get_credentials() + elif ( + aws_access_key_id is not None + and aws_secret_access_key is not None + and aws_session_token is not None + ): ### CHECK FOR AWS SESSION TOKEN ### + from botocore.credentials import Credentials + + credentials = Credentials( + access_key=aws_access_key_id, + secret_key=aws_secret_access_key, + token=aws_session_token, + ) + return credentials else: session = boto3.Session( aws_access_key_id=aws_access_key_id, @@ -1682,6 +1727,7 @@ class BedrockConverseLLM(BaseLLM): # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = 
optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) aws_region_name = optional_params.pop("aws_region_name", None) aws_role_name = optional_params.pop("aws_role_name", None) aws_session_name = optional_params.pop("aws_session_name", None) @@ -1713,6 +1759,7 @@ class BedrockConverseLLM(BaseLLM): credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, aws_region_name=aws_region_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py index f361ede5b..a6726eb98 100644 --- a/litellm/llms/custom_httpx/azure_dall_e_2.py +++ b/litellm/llms/custom_httpx/azure_dall_e_2.py @@ -1,4 +1,8 @@ -import time, json, httpx, asyncio +import asyncio +import json +import time + +import httpx class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): @@ -7,15 +11,18 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): """ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: - if "images/generations" in request.url.path and request.url.params[ - "api-version" - ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict - "2023-06-01-preview", - "2023-07-01-preview", - "2023-08-01-preview", - "2023-09-01-preview", - "2023-10-01-preview", - ]: + _api_version = request.url.params.get("api-version", "") + if ( + "images/generations" in request.url.path + and _api_version + in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + "2023-06-01-preview", + "2023-07-01-preview", + "2023-08-01-preview", + "2023-09-01-preview", + "2023-10-01-preview", + ] + ): request.url = request.url.copy_with( path="/openai/images/generations:submit" ) @@ -77,15 +84,18 @@ class CustomHTTPTransport(httpx.HTTPTransport): self, request: httpx.Request, ) -> httpx.Response: - if "images/generations" in request.url.path and request.url.params[ - "api-version" - ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict - "2023-06-01-preview", - "2023-07-01-preview", - "2023-08-01-preview", - "2023-09-01-preview", - "2023-10-01-preview", - ]: + _api_version = request.url.params.get("api-version", "") + if ( + "images/generations" in request.url.path + and _api_version + in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + "2023-06-01-preview", + "2023-07-01-preview", + "2023-08-01-preview", + "2023-09-01-preview", + "2023-10-01-preview", + ] + ): request.url = request.url.copy_with( path="/openai/images/generations:submit" ) diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index a3c5865fa..9b01c96b1 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -1,6 +1,11 @@ +import asyncio +import os +import traceback +from typing import Any, Mapping, Optional, Union + +import httpx + import litellm -import httpx, asyncio, traceback, os -from typing import Optional, Union, Mapping, Any # https://www.python-httpx.org/advanced/timeouts _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0) @@ -93,7 +98,7 @@ class AsyncHTTPHandler: response = await self.client.send(req, stream=stream) response.raise_for_status() return 
response - except httpx.RemoteProtocolError: + except (httpx.RemoteProtocolError, httpx.ConnectError): # Retry the request with a new session if there is a connection error new_client = self.create_client(timeout=self.timeout, concurrent_limit=1) try: @@ -109,6 +114,11 @@ class AsyncHTTPHandler: finally: await new_client.aclose() except httpx.HTTPStatusError as e: + setattr(e, "status_code", e.response.status_code) + if stream is True: + setattr(e, "message", await e.response.aread()) + else: + setattr(e, "message", e.response.text) raise e except Exception as e: raise e @@ -208,6 +218,7 @@ class HTTPHandler: headers: Optional[dict] = None, stream: bool = False, ): + req = self.client.build_request( "POST", url, data=data, json=json, params=params, headers=headers # type: ignore ) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 7d14fa450..990ef2fae 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -21,6 +21,7 @@ from pydantic import BaseModel from typing_extensions import overload, override import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.utils import ProviderField from litellm.utils import ( Choices, @@ -652,6 +653,36 @@ class OpenAIChatCompletion(BaseLLM): else: return client + async def make_openai_chat_completion_request( + self, + openai_aclient: AsyncOpenAI, + data: dict, + timeout: Union[float, httpx.Timeout], + ): + """ + Helper to: + - call chat.completions.create.with_raw_response when litellm.return_response_headers is True + - call chat.completions.create by default + """ + try: + if litellm.return_response_headers is True: + raw_response = ( + await openai_aclient.chat.completions.with_raw_response.create( + **data, timeout=timeout + ) + ) + + headers = dict(raw_response.headers) + response = raw_response.parse() + return headers, response + else: + response = await openai_aclient.chat.completions.create( + **data, timeout=timeout + ) + return None, response + except Exception as e: + raise e + def completion( self, model_response: ModelResponse, @@ -678,17 +709,17 @@ class OpenAIChatCompletion(BaseLLM): if headers: optional_params["extra_headers"] = headers if model is None or messages is None: - raise OpenAIError(status_code=422, message=f"Missing model or messages") + raise OpenAIError(status_code=422, message="Missing model or messages") if not isinstance(timeout, float) and not isinstance( timeout, httpx.Timeout ): raise OpenAIError( status_code=422, - message=f"Timeout needs to be a float or httpx.Timeout", + message="Timeout needs to be a float or httpx.Timeout", ) - if custom_llm_provider != "openai": + if custom_llm_provider is not None and custom_llm_provider != "openai": model_response.model = f"{custom_llm_provider}/{model}" # process all OpenAI compatible provider logic here if custom_llm_provider == "mistral": @@ -836,13 +867,13 @@ class OpenAIChatCompletion(BaseLLM): self, data: dict, model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, timeout: Union[float, httpx.Timeout], api_key: Optional[str] = None, api_base: Optional[str] = None, organization: Optional[str] = None, client=None, max_retries=None, - logging_obj=None, headers=None, ): response = None @@ -869,8 +900,8 @@ class OpenAIChatCompletion(BaseLLM): }, ) - response = await openai_aclient.chat.completions.create( - **data, timeout=timeout + headers, response = await self.make_openai_chat_completion_request( + openai_aclient=openai_aclient, data=data, timeout=timeout ) 
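Illustrative usage sketch (not part of the patch): the helper above switches to the `*.with_raw_response.create` path when `litellm.return_response_headers` is set. The snippet below assumes that flag is a module-level attribute, as referenced in the helper, and that the captured header dict surfaces via the response's hidden params, as wired elsewhere in this diff.

import asyncio

import litellm

litellm.return_response_headers = True  # opt in to the with_raw_response code path


async def main():
    resp = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "ping"}],
    )
    # provider headers (e.g. x-ratelimit-remaining-requests) captured by the helper
    print(resp._hidden_params.get("headers"))


asyncio.run(main())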
stringified_response = response.model_dump() logging_obj.post_call( @@ -879,9 +910,11 @@ class OpenAIChatCompletion(BaseLLM): original_response=stringified_response, additional_args={"complete_input_dict": data}, ) + logging_obj.model_call_details["response_headers"] = headers return convert_to_model_response_object( response_object=stringified_response, model_response_object=model_response, + hidden_params={"headers": headers}, ) except Exception as e: raise e @@ -931,10 +964,10 @@ class OpenAIChatCompletion(BaseLLM): async def async_streaming( self, - logging_obj, timeout: Union[float, httpx.Timeout], data: dict, model: str, + logging_obj: LiteLLMLoggingObj, api_key: Optional[str] = None, api_base: Optional[str] = None, organization: Optional[str] = None, @@ -965,9 +998,10 @@ class OpenAIChatCompletion(BaseLLM): }, ) - response = await openai_aclient.chat.completions.create( - **data, timeout=timeout + headers, response = await self.make_openai_chat_completion_request( + openai_aclient=openai_aclient, data=data, timeout=timeout ) + logging_obj.model_call_details["response_headers"] = headers streamwrapper = CustomStreamWrapper( completion_stream=response, model=model, @@ -992,17 +1026,43 @@ class OpenAIChatCompletion(BaseLLM): else: raise OpenAIError(status_code=500, message=f"{str(e)}") + # Embedding + async def make_openai_embedding_request( + self, + openai_aclient: AsyncOpenAI, + data: dict, + timeout: Union[float, httpx.Timeout], + ): + """ + Helper to: + - call embeddings.create.with_raw_response when litellm.return_response_headers is True + - call embeddings.create by default + """ + try: + if litellm.return_response_headers is True: + raw_response = await openai_aclient.embeddings.with_raw_response.create( + **data, timeout=timeout + ) # type: ignore + headers = dict(raw_response.headers) + response = raw_response.parse() + return headers, response + else: + response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore + return None, response + except Exception as e: + raise e + async def aembedding( self, input: list, data: dict, model_response: litellm.utils.EmbeddingResponse, timeout: float, + logging_obj: LiteLLMLoggingObj, api_key: Optional[str] = None, api_base: Optional[str] = None, client: Optional[AsyncOpenAI] = None, max_retries=None, - logging_obj=None, ): response = None try: @@ -1014,7 +1074,10 @@ class OpenAIChatCompletion(BaseLLM): max_retries=max_retries, client=client, ) - response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore + headers, response = await self.make_openai_embedding_request( + openai_aclient=openai_aclient, data=data, timeout=timeout + ) + logging_obj.model_call_details["response_headers"] = headers stringified_response = response.model_dump() ## LOGGING logging_obj.post_call( @@ -1229,6 +1292,34 @@ class OpenAIChatCompletion(BaseLLM): else: raise OpenAIError(status_code=500, message=str(e)) + # Audio Transcriptions + async def make_openai_audio_transcriptions_request( + self, + openai_aclient: AsyncOpenAI, + data: dict, + timeout: Union[float, httpx.Timeout], + ): + """ + Helper to: + - call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True + - call openai_aclient.audio.transcriptions.create by default + """ + try: + if litellm.return_response_headers is True: + raw_response = ( + await openai_aclient.audio.transcriptions.with_raw_response.create( + **data, timeout=timeout + ) + ) # type: ignore + headers = dict(raw_response.headers) + 
response = raw_response.parse() + return headers, response + else: + response = await openai_aclient.audio.transcriptions.create(**data, timeout=timeout) # type: ignore + return None, response + except Exception as e: + raise e + def audio_transcriptions( self, model: str, @@ -1286,11 +1377,11 @@ class OpenAIChatCompletion(BaseLLM): data: dict, model_response: TranscriptionResponse, timeout: float, + logging_obj: LiteLLMLoggingObj, api_key: Optional[str] = None, api_base: Optional[str] = None, client=None, max_retries=None, - logging_obj=None, ): try: openai_aclient = self._get_openai_client( @@ -1302,9 +1393,12 @@ class OpenAIChatCompletion(BaseLLM): client=client, ) - response = await openai_aclient.audio.transcriptions.create( - **data, timeout=timeout - ) # type: ignore + headers, response = await self.make_openai_audio_transcriptions_request( + openai_aclient=openai_aclient, + data=data, + timeout=timeout, + ) + logging_obj.model_call_details["response_headers"] = headers stringified_response = response.model_dump() ## LOGGING logging_obj.post_call( @@ -1497,9 +1591,9 @@ class OpenAITextCompletion(BaseLLM): model: str, messages: list, timeout: float, + logging_obj: LiteLLMLoggingObj, print_verbose: Optional[Callable] = None, api_base: Optional[str] = None, - logging_obj=None, acompletion: bool = False, optional_params=None, litellm_params=None, diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index b35914584..87af2a6bd 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -2033,6 +2033,50 @@ def function_call_prompt(messages: list, functions: list): return messages +def response_schema_prompt(model: str, response_schema: dict) -> str: + """ + Decides if a user-defined custom prompt or default needs to be used + + Returns the prompt str that's passed to the model as a user message + """ + custom_prompt_details: Optional[dict] = None + response_schema_as_message = [ + {"role": "user", "content": "{}".format(response_schema)} + ] + if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict: + + custom_prompt_details = litellm.custom_prompt_dict[ + f"{model}/response_schema_prompt" + ] # allow user to define custom response schema prompt by model + elif "response_schema_prompt" in litellm.custom_prompt_dict: + custom_prompt_details = litellm.custom_prompt_dict["response_schema_prompt"] + + if custom_prompt_details is not None: + return custom_prompt( + role_dict=custom_prompt_details["roles"], + initial_prompt_value=custom_prompt_details["initial_prompt_value"], + final_prompt_value=custom_prompt_details["final_prompt_value"], + messages=response_schema_as_message, + ) + else: + return default_response_schema_prompt(response_schema=response_schema) + + +def default_response_schema_prompt(response_schema: dict) -> str: + """ + Used if provider/model doesn't support 'response_schema' param. + + This is the default prompt. Allow user to override this with a custom_prompt. 
+ """ + prompt_str = """Use this JSON schema: + ```json + {} + ```""".format( + response_schema + ) + return prompt_str + + # Custom prompt template def custom_prompt( role_dict: dict, diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 4a4abaef4..c1e628d17 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -12,6 +12,7 @@ import requests # type: ignore from pydantic import BaseModel import litellm +from litellm._logging import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.prompt_templates.factory import ( convert_to_anthropic_image_obj, @@ -328,80 +329,86 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]: contents: List[ContentType] = [] msg_i = 0 - while msg_i < len(messages): - user_content: List[PartType] = [] - init_msg_i = msg_i - ## MERGE CONSECUTIVE USER CONTENT ## - while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types: - if isinstance(messages[msg_i]["content"], list): - _parts: List[PartType] = [] - for element in messages[msg_i]["content"]: - if isinstance(element, dict): - if element["type"] == "text" and len(element["text"]) > 0: - _part = PartType(text=element["text"]) - _parts.append(_part) - elif element["type"] == "image_url": - image_url = element["image_url"]["url"] - _part = _process_gemini_image(image_url=image_url) - _parts.append(_part) # type: ignore - user_content.extend(_parts) - elif ( - isinstance(messages[msg_i]["content"], str) - and len(messages[msg_i]["content"]) > 0 + try: + while msg_i < len(messages): + user_content: List[PartType] = [] + init_msg_i = msg_i + ## MERGE CONSECUTIVE USER CONTENT ## + while ( + msg_i < len(messages) and messages[msg_i]["role"] in user_message_types ): - _part = PartType(text=messages[msg_i]["content"]) - user_content.append(_part) + if isinstance(messages[msg_i]["content"], list): + _parts: List[PartType] = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text" and len(element["text"]) > 0: + _part = PartType(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + _part = _process_gemini_image(image_url=image_url) + _parts.append(_part) # type: ignore + user_content.extend(_parts) + elif ( + isinstance(messages[msg_i]["content"], str) + and len(messages[msg_i]["content"]) > 0 + ): + _part = PartType(text=messages[msg_i]["content"]) + user_content.append(_part) - msg_i += 1 + msg_i += 1 - if user_content: - contents.append(ContentType(role="user", parts=user_content)) - assistant_content = [] - ## MERGE CONSECUTIVE ASSISTANT CONTENT ## - while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": - if isinstance(messages[msg_i]["content"], list): - _parts = [] - for element in messages[msg_i]["content"]: - if isinstance(element, dict): - if element["type"] == "text": - _part = PartType(text=element["text"]) - _parts.append(_part) - elif element["type"] == "image_url": - image_url = element["image_url"]["url"] - _part = _process_gemini_image(image_url=image_url) - _parts.append(_part) # type: ignore - assistant_content.extend(_parts) - elif messages[msg_i].get( - "tool_calls", [] - ): # support assistant tool invoke conversion - assistant_content.extend( - convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"]) + if user_content: + contents.append(ContentType(role="user", parts=user_content)) + assistant_content = [] + ## 
MERGE CONSECUTIVE ASSISTANT CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": + if isinstance(messages[msg_i]["content"], list): + _parts = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text": + _part = PartType(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + _part = _process_gemini_image(image_url=image_url) + _parts.append(_part) # type: ignore + assistant_content.extend(_parts) + elif messages[msg_i].get( + "tool_calls", [] + ): # support assistant tool invoke conversion + assistant_content.extend( + convert_to_gemini_tool_call_invoke( + messages[msg_i]["tool_calls"] + ) + ) + else: + assistant_text = ( + messages[msg_i].get("content") or "" + ) # either string or none + if assistant_text: + assistant_content.append(PartType(text=assistant_text)) + + msg_i += 1 + + if assistant_content: + contents.append(ContentType(role="model", parts=assistant_content)) + + ## APPEND TOOL CALL MESSAGES ## + if msg_i < len(messages) and messages[msg_i]["role"] == "tool": + _part = convert_to_gemini_tool_call_result(messages[msg_i]) + contents.append(ContentType(parts=[_part])) # type: ignore + msg_i += 1 + if msg_i == init_msg_i: # prevent infinite loops + raise Exception( + "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format( + messages[msg_i] + ) ) - else: - assistant_text = ( - messages[msg_i].get("content") or "" - ) # either string or none - if assistant_text: - assistant_content.append(PartType(text=assistant_text)) - - msg_i += 1 - - if assistant_content: - contents.append(ContentType(role="model", parts=assistant_content)) - - ## APPEND TOOL CALL MESSAGES ## - if msg_i < len(messages) and messages[msg_i]["role"] == "tool": - _part = convert_to_gemini_tool_call_result(messages[msg_i]) - contents.append(ContentType(parts=[_part])) # type: ignore - msg_i += 1 - if msg_i == init_msg_i: # prevent infinite loops - raise Exception( - "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format( - messages[msg_i] - ) - ) - - return contents + return contents + except Exception as e: + raise e def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str): diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py index ee6653afc..6b39716f1 100644 --- a/litellm/llms/vertex_ai_anthropic.py +++ b/litellm/llms/vertex_ai_anthropic.py @@ -1,24 +1,32 @@ # What is this? 
## Handler file for calling claude-3 on vertex ai -import os, types +import copy import json +import os +import time +import types +import uuid from enum import Enum -import requests, copy # type: ignore -import time, uuid -from typing import Callable, Optional, List -from litellm.utils import ModelResponse, Usage, CustomStreamWrapper -from litellm.litellm_core_utils.core_helpers import map_finish_reason +from typing import Any, Callable, List, Optional, Tuple + +import httpx # type: ignore +import requests # type: ignore + import litellm +from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.types.utils import ResponseFormatChunk +from litellm.utils import CustomStreamWrapper, ModelResponse, Usage + from .prompt_templates.factory import ( - contains_tag, - prompt_factory, - custom_prompt, construct_tool_use_system_prompt, + contains_tag, + custom_prompt, extract_between_tags, parse_xml_params, + prompt_factory, + response_schema_prompt, ) -import httpx # type: ignore class VertexAIError(Exception): @@ -104,6 +112,7 @@ class VertexAIAnthropicConfig: "stop", "temperature", "top_p", + "response_format", ] def map_openai_params(self, non_default_params: dict, optional_params: dict): @@ -120,6 +129,8 @@ class VertexAIAnthropicConfig: optional_params["temperature"] = value if param == "top_p": optional_params["top_p"] = value + if param == "response_format" and "response_schema" in value: + optional_params["response_format"] = ResponseFormatChunk(**value) # type: ignore return optional_params @@ -129,7 +140,6 @@ class VertexAIAnthropicConfig: """ -# makes headers for API call def refresh_auth( credentials, ) -> str: # used when user passes in credentials as json string @@ -144,6 +154,40 @@ def refresh_auth( return credentials.token +def get_vertex_client( + client: Any, + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_credentials: Optional[str], +) -> Tuple[Any, Optional[str]]: + args = locals() + from litellm.llms.vertex_httpx import VertexLLM + + try: + from anthropic import AnthropicVertex + except Exception: + raise VertexAIError( + status_code=400, + message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""", + ) + + access_token: Optional[str] = None + + if client is None: + _credentials, cred_project_id = VertexLLM().load_auth( + credentials=vertex_credentials, project_id=vertex_project + ) + vertex_ai_client = AnthropicVertex( + project_id=vertex_project or cred_project_id, + region=vertex_location or "us-central1", + access_token=_credentials.token, + ) + else: + vertex_ai_client = client + + return vertex_ai_client, access_token + + def completion( model: str, messages: list, @@ -151,10 +195,10 @@ def completion( print_verbose: Callable, encoding, logging_obj, + optional_params: dict, vertex_project=None, vertex_location=None, vertex_credentials=None, - optional_params=None, litellm_params=None, logger_fn=None, acompletion: bool = False, @@ -178,6 +222,13 @@ def completion( ) try: + vertex_ai_client, access_token = get_vertex_client( + client=client, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + ) + ## Load Config config = litellm.VertexAIAnthropicConfig.get_config() for k, v in config.items(): @@ -186,6 +237,7 @@ def completion( ## Format Prompt _is_function_call = False + _is_json_schema = False messages = copy.deepcopy(messages) 
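A minimal sketch (not part of the patch) of how a user could override the default response-schema prompt that the new response_schema_prompt() helper in prompt_templates/factory.py falls back to. The dict keys mirror what that helper reads; the role formatting values below are placeholder assumptions.

import litellm

# applies to every model; a "<model>/response_schema_prompt" key scopes it to one model
litellm.custom_prompt_dict["response_schema_prompt"] = {
    "roles": {
        "user": {
            "pre_message": "Return JSON that matches this schema:\n",
            "post_message": "\n",
        }
    },
    "initial_prompt_value": "",
    "final_prompt_value": "Respond with valid JSON only.",
}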
optional_params = copy.deepcopy(optional_params) # Separate system prompt from rest of message @@ -200,6 +252,29 @@ def completion( messages.pop(idx) if len(system_prompt) > 0: optional_params["system"] = system_prompt + # Checks for 'response_schema' support - if passed in + if "response_format" in optional_params: + response_format_chunk = ResponseFormatChunk( + **optional_params["response_format"] # type: ignore + ) + supports_response_schema = litellm.supports_response_schema( + model=model, custom_llm_provider="vertex_ai" + ) + if ( + supports_response_schema is False + and response_format_chunk["type"] == "json_object" + and "response_schema" in response_format_chunk + ): + _is_json_schema = True + user_response_schema_message = response_schema_prompt( + model=model, + response_schema=response_format_chunk["response_schema"], + ) + messages.append( + {"role": "user", "content": user_response_schema_message} + ) + messages.append({"role": "assistant", "content": "{"}) + optional_params.pop("response_format") # Format rest of message according to anthropic guidelines try: messages = prompt_factory( @@ -233,32 +308,6 @@ def completion( print_verbose( f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}" ) - access_token = None - if client is None: - if vertex_credentials is not None and isinstance(vertex_credentials, str): - import google.oauth2.service_account - - try: - json_obj = json.loads(vertex_credentials) - except json.JSONDecodeError: - json_obj = json.load(open(vertex_credentials)) - - creds = ( - google.oauth2.service_account.Credentials.from_service_account_info( - json_obj, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - ) - ### CHECK IF ACCESS - access_token = refresh_auth(credentials=creds) - - vertex_ai_client = AnthropicVertex( - project_id=vertex_project, - region=vertex_location, - access_token=access_token, - ) - else: - vertex_ai_client = client if acompletion == True: """ @@ -315,7 +364,16 @@ def completion( ) message = vertex_ai_client.messages.create(**data) # type: ignore - text_content = message.content[0].text + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=message, + additional_args={"complete_input_dict": data}, + ) + + text_content: str = message.content[0].text ## TOOL CALLING - OUTPUT PARSE if text_content is not None and contains_tag("invoke", text_content): function_name = extract_between_tags("tool_name", text_content)[0] @@ -339,7 +397,13 @@ def completion( ) model_response.choices[0].message = _message # type: ignore else: - model_response.choices[0].message.content = text_content # type: ignore + if ( + _is_json_schema + ): # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb + json_response = "{" + text_content[: text_content.rfind("}") + 1] + model_response.choices[0].message.content = json_response # type: ignore + else: + model_response.choices[0].message.content = text_content # type: ignore model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason) ## CALCULATING USAGE diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 95a5ed492..69294d7de 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -12,7 +12,6 @@ from functools import partial from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import httpx # type: ignore -import ijson import requests # type: ignore 
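Hypothetical call shape enabled by the response_format handling above (the model id, project, and schema are placeholders, not taken from the patch): when the model lacks native response_schema support, the schema is injected as a prompt and the assistant reply is coerced back into JSON, per the code in this diff.

import litellm

resp = litellm.completion(
    model="vertex_ai/claude-3-5-sonnet@20240620",  # placeholder Vertex model id
    messages=[{"role": "user", "content": "Name two primary colors."}],
    response_format={
        "type": "json_object",
        "response_schema": {
            "type": "object",
            "properties": {"colors": {"type": "array", "items": {"type": "string"}}},
            "required": ["colors"],
        },
    },
    vertex_project="my-gcp-project",  # placeholder project
    vertex_location="us-central1",
)
print(resp.choices[0].message.content)  # JSON text, e.g. {"colors": ["red", "blue"]}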
import litellm @@ -21,7 +20,10 @@ import litellm.litellm_core_utils.litellm_logging from litellm import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.llms.prompt_templates.factory import convert_url_to_base64 +from litellm.llms.prompt_templates.factory import ( + convert_url_to_base64, + response_schema_prompt, +) from litellm.llms.vertex_ai import _gemini_convert_messages_with_history from litellm.types.llms.openai import ( ChatCompletionResponseMessage, @@ -183,10 +185,17 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty if param == "tools" and isinstance(value, list): gtool_func_declarations = [] for tool in value: + _parameters = tool.get("function", {}).get("parameters", {}) + _properties = _parameters.get("properties", {}) + if isinstance(_properties, dict): + for _, _property in _properties.items(): + if "enum" in _property and "format" not in _property: + _property["format"] = "enum" + gtool_func_declaration = FunctionDeclaration( name=tool["function"]["name"], description=tool["function"].get("description", ""), - parameters=tool["function"].get("parameters", {}), + parameters=_parameters, ) gtool_func_declarations.append(gtool_func_declaration) optional_params["tools"] = [ @@ -349,6 +358,7 @@ class VertexGeminiConfig: model: str, non_default_params: dict, optional_params: dict, + drop_params: bool, ): for param, value in non_default_params.items(): if param == "temperature": @@ -368,8 +378,13 @@ class VertexGeminiConfig: optional_params["stop_sequences"] = value if param == "max_tokens": optional_params["max_output_tokens"] = value - if param == "response_format" and value["type"] == "json_object": # type: ignore - optional_params["response_mime_type"] = "application/json" + if param == "response_format" and isinstance(value, dict): # type: ignore + if value["type"] == "json_object": + optional_params["response_mime_type"] = "application/json" + elif value["type"] == "text": + optional_params["response_mime_type"] = "text/plain" + if "response_schema" in value: + optional_params["response_schema"] = value["response_schema"] if param == "frequency_penalty": optional_params["frequency_penalty"] = value if param == "presence_penalty": @@ -460,7 +475,7 @@ async def make_call( raise VertexAIError(status_code=response.status_code, message=response.text) completion_stream = ModelResponseIterator( - streaming_response=response.aiter_bytes(), sync_stream=False + streaming_response=response.aiter_lines(), sync_stream=False ) # LOGGING logging_obj.post_call( @@ -491,7 +506,7 @@ def make_sync_call( raise VertexAIError(status_code=response.status_code, message=response.read()) completion_stream = ModelResponseIterator( - streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True + streaming_response=response.iter_lines(), sync_stream=True ) # LOGGING @@ -813,12 +828,13 @@ class VertexLLM(BaseLLM): endpoint = "generateContent" if stream is True: endpoint = "streamGenerateContent" - - url = ( - "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( + url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format( + _gemini_model_name, endpoint, gemini_api_key + ) + else: + url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( _gemini_model_name, endpoint, gemini_api_key ) - ) else: auth_header, vertex_project = self._ensure_access_token( 
credentials=vertex_credentials, project_id=vertex_project @@ -829,7 +845,9 @@ class VertexLLM(BaseLLM): endpoint = "generateContent" if stream is True: endpoint = "streamGenerateContent" - url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse" + else: + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" if ( api_base is not None @@ -842,6 +860,9 @@ class VertexLLM(BaseLLM): else: url = "{}:{}".format(api_base, endpoint) + if stream is True: + url = url + "?alt=sse" + return auth_header, url async def async_streaming( @@ -994,35 +1015,53 @@ class VertexLLM(BaseLLM): if len(system_prompt_indices) > 0: for idx in reversed(system_prompt_indices): messages.pop(idx) - content = _gemini_convert_messages_with_history(messages=messages) - tools: Optional[Tools] = optional_params.pop("tools", None) - tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) - safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( - "safety_settings", None - ) # type: ignore - generation_config: Optional[GenerationConfig] = GenerationConfig( - **optional_params - ) - data = RequestBody(contents=content) - if len(system_content_blocks) > 0: - system_instructions = SystemInstructions(parts=system_content_blocks) - data["system_instruction"] = system_instructions - if tools is not None: - data["tools"] = tools - if tool_choice is not None: - data["toolConfig"] = tool_choice - if safety_settings is not None: - data["safetySettings"] = safety_settings - if generation_config is not None: - data["generationConfig"] = generation_config - headers = { - "Content-Type": "application/json; charset=utf-8", - } - if auth_header is not None: - headers["Authorization"] = f"Bearer {auth_header}" - if extra_headers is not None: - headers.update(extra_headers) + # Checks for 'response_schema' support - if passed in + if "response_schema" in optional_params: + supports_response_schema = litellm.supports_response_schema( + model=model, custom_llm_provider="vertex_ai" + ) + if supports_response_schema is False: + user_response_schema_message = response_schema_prompt( + model=model, response_schema=optional_params.get("response_schema") # type: ignore + ) + messages.append( + {"role": "user", "content": user_response_schema_message} + ) + optional_params.pop("response_schema") + + try: + content = _gemini_convert_messages_with_history(messages=messages) + tools: Optional[Tools] = optional_params.pop("tools", None) + tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) + safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( + "safety_settings", None + ) # type: ignore + generation_config: Optional[GenerationConfig] = GenerationConfig( + **optional_params + ) + data = RequestBody(contents=content) + if len(system_content_blocks) > 0: + system_instructions = SystemInstructions(parts=system_content_blocks) + data["system_instruction"] = system_instructions + if tools is not None: + data["tools"] = tools + if tool_choice is not None: + data["toolConfig"] = tool_choice + if safety_settings is not None: + data["safetySettings"] = safety_settings + if generation_config is not 
None: + data["generationConfig"] = generation_config + + headers = { + "Content-Type": "application/json", + } + if auth_header is not None: + headers["Authorization"] = f"Bearer {auth_header}" + if extra_headers is not None: + headers.update(extra_headers) + except Exception as e: + raise e ## LOGGING logging_obj.pre_call( @@ -1270,11 +1309,6 @@ class VertexLLM(BaseLLM): class ModelResponseIterator: def __init__(self, streaming_response, sync_stream: bool): self.streaming_response = streaming_response - if sync_stream: - self.response_iterator = iter(self.streaming_response) - - self.events = ijson.sendable_list() - self.coro = ijson.items_coro(self.events, "item") def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: try: @@ -1304,9 +1338,9 @@ class ModelResponseIterator: if "usageMetadata" in processed_chunk: usage = ChatCompletionUsageBlock( prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"], - completion_tokens=processed_chunk["usageMetadata"][ - "candidatesTokenCount" - ], + completion_tokens=processed_chunk["usageMetadata"].get( + "candidatesTokenCount", 0 + ), total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"], ) @@ -1324,31 +1358,36 @@ class ModelResponseIterator: # Sync iterator def __iter__(self): + self.response_iterator = self.streaming_response return self def __next__(self): try: chunk = self.response_iterator.__next__() - self.coro.send(chunk) - if self.events: - event = self.events.pop(0) - json_chunk = event - return self.chunk_parser(chunk=json_chunk) - return GenericStreamingChunk( - text="", - is_finished=False, - finish_reason="", - usage=None, - index=0, - tool_use=None, - ) except StopIteration: - if self.events: # flush the events - event = self.events.pop(0) # Remove the first event - return self.chunk_parser(chunk=event) raise StopIteration except ValueError as e: - raise RuntimeError(f"Error parsing chunk: {e}") + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + chunk = chunk.replace("data:", "") + chunk = chunk.strip() + if len(chunk) > 0: + json_chunk = json.loads(chunk) + return self.chunk_parser(chunk=json_chunk) + else: + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") # Async iterator def __aiter__(self): @@ -1358,23 +1397,27 @@ class ModelResponseIterator: async def __anext__(self): try: chunk = await self.async_response_iterator.__anext__() - self.coro.send(chunk) - if self.events: - event = self.events.pop(0) - json_chunk = event - return self.chunk_parser(chunk=json_chunk) - return GenericStreamingChunk( - text="", - is_finished=False, - finish_reason="", - usage=None, - index=0, - tool_use=None, - ) except StopAsyncIteration: - if self.events: # flush the events - event = self.events.pop(0) # Remove the first event - return self.chunk_parser(chunk=event) raise StopAsyncIteration except ValueError as e: - raise RuntimeError(f"Error parsing chunk: {e}") + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + chunk = chunk.replace("data:", "") + chunk = chunk.strip() + if len(chunk) > 0: + json_chunk = json.loads(chunk) + return self.chunk_parser(chunk=json_chunk) + else: + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) + except StopAsyncIteration: + raise 
StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") diff --git a/litellm/main.py b/litellm/main.py index 69ce61fab..55ac01935 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -476,6 +476,15 @@ def mock_completion( model=model, # type: ignore request=httpx.Request(method="POST", url="https://api.openai.com/v1/"), ) + elif ( + isinstance(mock_response, str) and mock_response == "litellm.RateLimitError" + ): + raise litellm.RateLimitError( + message="this is a mock rate limit error", + status_code=getattr(mock_response, "status_code", 429), # type: ignore + llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore + model=model, + ) time_delay = kwargs.get("mock_delay", None) if time_delay is not None: time.sleep(time_delay) @@ -676,6 +685,8 @@ def completion( client = kwargs.get("client", None) ### Admin Controls ### no_log = kwargs.get("no-log", False) + ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489 + messages = deepcopy(messages) ######## end of unpacking kwargs ########### openai_params = [ "functions", @@ -1828,6 +1839,7 @@ def completion( logging_obj=logging, acompletion=acompletion, timeout=timeout, # type: ignore + custom_llm_provider="openrouter", ) ## LOGGING logging.post_call( @@ -2199,13 +2211,33 @@ def completion( # boto3 reads keys from .env custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - if ( - "aws_bedrock_client" in optional_params - ): # use old bedrock flow for aws_bedrock_client users. - response = bedrock.completion( + if "aws_bedrock_client" in optional_params: + verbose_logger.warning( + "'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication." 
+ ) + # Extract credentials for legacy boto3 client and pass thru to httpx + aws_bedrock_client = optional_params.pop("aws_bedrock_client") + creds = aws_bedrock_client._get_credentials().get_frozen_credentials() + + if creds.access_key: + optional_params["aws_access_key_id"] = creds.access_key + if creds.secret_key: + optional_params["aws_secret_access_key"] = creds.secret_key + if creds.token: + optional_params["aws_session_token"] = creds.token + if ( + "aws_region_name" not in optional_params + or optional_params["aws_region_name"] is None + ): + optional_params["aws_region_name"] = ( + aws_bedrock_client.meta.region_name + ) + + if model in litellm.BEDROCK_CONVERSE_MODELS: + response = bedrock_converse_chat_completion.completion( model=model, messages=messages, - custom_prompt_dict=litellm.custom_prompt_dict, + custom_prompt_dict=custom_prompt_dict, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, @@ -2215,63 +2247,27 @@ def completion( logging_obj=logging, extra_headers=extra_headers, timeout=timeout, + acompletion=acompletion, + client=client, + ) + else: + response = bedrock_chat_completion.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + acompletion=acompletion, + client=client, ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and not isinstance(response, CustomStreamWrapper) - ): - # don't try to access stream object, - if "ai21" in model: - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) - else: - response = CustomStreamWrapper( - iter(response), - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) - else: - if model.startswith("anthropic"): - response = bedrock_converse_chat_completion.completion( - model=model, - messages=messages, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - extra_headers=extra_headers, - timeout=timeout, - acompletion=acompletion, - client=client, - ) - else: - response = bedrock_chat_completion.completion( - model=model, - messages=messages, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - extra_headers=extra_headers, - timeout=timeout, - acompletion=acompletion, - client=client, - ) if optional_params.get("stream", False): ## LOGGING logging.post_call( diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 92f97ebe5..7f08b9eb1 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -1486,6 +1486,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-001": { @@ -1511,6 +1512,7 @@ 
"supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0514": { @@ -1536,6 +1538,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0215": { @@ -1561,6 +1564,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0409": { @@ -1584,7 +1588,8 @@ "litellm_provider": "vertex_ai-language-models", "mode": "chat", "supports_function_calling": true, - "supports_tool_choice": true, + "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-flash": { @@ -2007,6 +2012,7 @@ "supports_function_calling": true, "supports_vision": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini/gemini-1.5-pro-latest": { @@ -2023,6 +2029,7 @@ "supports_function_calling": true, "supports_vision": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://ai.google.dev/models/gemini" }, "gemini/gemini-pro-vision": { diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 938e74b5e..c62fc9944 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,54 +1,5 @@ -# model_list: -# - model_name: my-fake-model -# litellm_params: -# model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0 -# api_key: my-fake-key -# aws_bedrock_runtime_endpoint: http://127.0.0.1:8000 -# mock_response: "Hello world 1" -# model_info: -# max_input_tokens: 0 # trigger context window fallback -# - model_name: my-fake-model -# litellm_params: -# model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0 -# api_key: my-fake-key -# aws_bedrock_runtime_endpoint: http://127.0.0.1:8000 -# mock_response: "Hello world 2" -# model_info: -# max_input_tokens: 0 - -# router_settings: -# enable_pre_call_checks: True - - -# litellm_settings: -# failure_callback: ["langfuse"] - model_list: - - model_name: summarize + - model_name: claude-3-5-sonnet # all requests where model not in your config go to this deployment litellm_params: - model: openai/gpt-4o - rpm: 10000 - tpm: 12000000 - api_key: os.environ/OPENAI_API_KEY - mock_response: Hello world 1 - - - model_name: summarize-l - litellm_params: - model: claude-3-5-sonnet-20240620 - rpm: 4000 - tpm: 400000 - api_key: os.environ/ANTHROPIC_API_KEY - mock_response: Hello world 2 - -litellm_settings: - num_retries: 3 - request_timeout: 120 - allowed_fails: 3 - # fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}] - # context_window_fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}] - - - -router_settings: - routing_strategy: simple-shuffle - enable_pre_call_checks: true. 
+ model: "openai/*" + mock_response: "litellm.RateLimitError" \ No newline at end of file diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index b8c26fd2a..ede853094 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -1,10 +1,11 @@ model_list: +- model_name: claude-3-5-sonnet + litellm_params: + model: anthropic/claude-3-5-sonnet - model_name: gemini-1.5-flash-gemini litellm_params: - model: gemini/gemini-1.5-flash -- model_name: gemini-1.5-flash-gemini - litellm_params: - model: gemini/gemini-1.5-flash + model: vertex_ai_beta/gemini-1.5-flash + api_base: https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/google-vertex-ai/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-flash - litellm_params: api_base: http://0.0.0.0:8080 api_key: '' diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 640c7695a..1f1aaf0ee 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1622,7 +1622,7 @@ class ProxyException(Exception): } -class CommonProxyErrors(enum.Enum): +class CommonProxyErrors(str, enum.Enum): db_not_connected_error = "DB not connected" no_llm_router = "No models configured on proxy" not_allowed_access = "Admin-only endpoint. Not allowed to access this." diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py new file mode 100644 index 000000000..218032e01 --- /dev/null +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -0,0 +1,173 @@ +import ast +import traceback +from base64 import b64encode + +import httpx +from fastapi import ( + APIRouter, + Depends, + FastAPI, + HTTPException, + Request, + Response, + status, +) +from fastapi.responses import StreamingResponse + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ProxyException +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth + +async_client = httpx.AsyncClient() + + +async def set_env_variables_in_header(custom_headers: dict): + """ + checks if nay headers on config.yaml are defined as os.environ/COHERE_API_KEY etc + + only runs for headers defined on config.yaml + + example header can be + + {"Authorization": "bearer os.environ/COHERE_API_KEY"} + """ + headers = {} + for key, value in custom_headers.items(): + # langfuse Api requires base64 encoded headers - it's simpleer to just ask litellm users to set their langfuse public and secret keys + # we can then get the b64 encoded keys here + if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY": + # langfuse requires b64 encoded headers - we construct that here + _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"] + _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"] + if isinstance( + _langfuse_public_key, str + ) and _langfuse_public_key.startswith("os.environ/"): + _langfuse_public_key = litellm.get_secret(_langfuse_public_key) + if isinstance( + _langfuse_secret_key, str + ) and _langfuse_secret_key.startswith("os.environ/"): + _langfuse_secret_key = litellm.get_secret(_langfuse_secret_key) + headers["Authorization"] = "Basic " + b64encode( + f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8") + ).decode("ascii") + else: + # for all other headers + headers[key] = value + if isinstance(value, str) and "os.environ/" in value: + verbose_proxy_logger.debug( + "pass through endpoint - 
looking up 'os.environ/' variable" + ) + # get string section that is os.environ/ + start_index = value.find("os.environ/") + _variable_name = value[start_index:] + + verbose_proxy_logger.debug( + "pass through endpoint - getting secret for variable name: %s", + _variable_name, + ) + _secret_value = litellm.get_secret(_variable_name) + new_value = value.replace(_variable_name, _secret_value) + headers[key] = new_value + return headers + + +async def pass_through_request(request: Request, target: str, custom_headers: dict): + try: + + url = httpx.URL(target) + headers = custom_headers + + request_body = await request.body() + _parsed_body = ast.literal_eval(request_body.decode("utf-8")) + + verbose_proxy_logger.debug( + "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( + url, headers, _parsed_body + ) + ) + + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=request.query_params, + json=_parsed_body, + ) + + if response.status_code != 200: + raise HTTPException(status_code=response.status_code, detail=response.text) + + content = await response.aread() + return Response( + content=content, + status_code=response.status_code, + headers=dict(response.headers), + ) + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.pass through endpoint(): Exception occurred - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +def create_pass_through_route(endpoint, target, custom_headers=None): + async def endpoint_func(request: Request): + return await pass_through_request(request, target, custom_headers) + + return endpoint_func + + +async def initialize_pass_through_endpoints(pass_through_endpoints: list): + + verbose_proxy_logger.debug("initializing pass through endpoints") + from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes + from litellm.proxy.proxy_server import app, premium_user + + for endpoint in pass_through_endpoints: + _target = endpoint.get("target", None) + _path = endpoint.get("path", None) + _custom_headers = endpoint.get("headers", None) + _custom_headers = await set_env_variables_in_header( + custom_headers=_custom_headers + ) + _auth = endpoint.get("auth", None) + _dependencies = None + if _auth is not None and str(_auth).lower() == "true": + if premium_user is not True: + raise ValueError( + f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}" + ) + _dependencies = [Depends(user_api_key_auth)] + LiteLLMRoutes.openai_routes.value.append(_path) + + if _target is None: + continue + + verbose_proxy_logger.debug("adding pass through endpoint: %s", _path) + + app.add_api_route( + path=_path, + endpoint=create_pass_through_route(_path, _target, _custom_headers), + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + dependencies=_dependencies, + ) + + verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path) diff --git a/litellm/proxy/prisma_migration.py b/litellm/proxy/prisma_migration.py new file mode 100644 index 
000000000..6ee09c22b --- /dev/null +++ b/litellm/proxy/prisma_migration.py @@ -0,0 +1,68 @@ +# What is this? +## Script to apply initial prisma migration on Docker setup + +import os +import subprocess +import sys +import time + +sys.path.insert( + 0, os.path.abspath("./") +) # Adds the parent directory to the system path +from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var + +if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True": + ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV + new_env_var = decrypt_env_var() + + for k, v in new_env_var.items(): + os.environ[k] = v + +# Check if DATABASE_URL is not set +database_url = os.getenv("DATABASE_URL") +if not database_url: + # Check if all required variables are provided + database_host = os.getenv("DATABASE_HOST") + database_username = os.getenv("DATABASE_USERNAME") + database_password = os.getenv("DATABASE_PASSWORD") + database_name = os.getenv("DATABASE_NAME") + + if database_host and database_username and database_password and database_name: + # Construct DATABASE_URL from the provided variables + database_url = f"postgresql://{database_username}:{database_password}@{database_host}/{database_name}" + os.environ["DATABASE_URL"] = database_url + else: + print( # noqa + "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL." # noqa + ) + exit(1) + +# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations +direct_url = os.getenv("DIRECT_URL") +if not direct_url: + os.environ["DIRECT_URL"] = database_url + +# Apply migrations +retry_count = 0 +max_retries = 3 +exit_code = 1 + +while retry_count < max_retries and exit_code != 0: + retry_count += 1 + print(f"Attempt {retry_count}...") # noqa + + # Run the Prisma db push command + result = subprocess.run( + ["prisma", "db", "push", "--accept-data-loss"], capture_output=True + ) + exit_code = result.returncode + + if exit_code != 0 and retry_count < max_retries: + print("Retrying in 10 seconds...") # noqa + time.sleep(10) + +if exit_code != 0: + print(f"Unable to push database changes after {max_retries} retries.") # noqa + exit(1) + +print("Database push successful!") # noqa diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 6e6d1f4a9..e98704642 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -442,6 +442,20 @@ def run_server( db_connection_pool_limit = 100 db_connection_timeout = 60 + ### DECRYPT ENV VAR ### + + from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var + + if ( + os.getenv("USE_AWS_KMS", None) is not None + and os.getenv("USE_AWS_KMS") == "True" + ): + ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV + new_env_var = decrypt_env_var() + + for k, v in new_env_var.items(): + os.environ[k] = v + if config is not None: """ Allow user to pass in db url via config @@ -459,6 +473,7 @@ def run_server( proxy_config = ProxyConfig() _config = asyncio.run(proxy_config.get_config(config_file_path=config)) + ### LITELLM SETTINGS ### litellm_settings = _config.get("litellm_settings", None) if ( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 0c0365f43..9f2324e51 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -20,11 +20,23 @@ model_list: general_settings: master_key: sk-1234 - alerting: ["slack", "email"] - public_routes: ["LiteLLMRoutes.public_routes", 
"/spend/calculate"] - + pass_through_endpoints: + - path: "/v1/rerank" + target: "https://api.cohere.com/v1/rerank" + auth: true # 👈 Key change to use LiteLLM Auth / Keys + headers: + Authorization: "bearer os.environ/COHERE_API_KEY" + content-type: application/json + accept: application/json + - path: "/api/public/ingestion" + target: "https://us.cloud.langfuse.com/api/public/ingestion" + auth: true + headers: + LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY" + LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY" litellm_settings: + return_response_headers: true success_callback: ["prometheus"] callbacks: ["otel", "hide_secrets"] failure_callback: ["prometheus"] @@ -34,6 +46,5 @@ litellm_settings: - user - metadata - metadata.generation_name - cache: True diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 553c8a48c..0577ec0a0 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -161,6 +161,9 @@ from litellm.proxy.management_endpoints.key_management_endpoints import ( router as key_management_router, ) from litellm.proxy.management_endpoints.team_endpoints import router as team_router +from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + initialize_pass_through_endpoints, +) from litellm.proxy.secret_managers.aws_secret_manager import ( load_aws_kms, load_aws_secret_manager, @@ -590,7 +593,7 @@ async def _PROXY_failure_handler( _model_id = _metadata.get("model_info", {}).get("id", "") _model_group = _metadata.get("model_group", "") api_base = litellm.get_api_base(model=_model, optional_params=_litellm_params) - _exception_string = str(_exception)[:500] + _exception_string = str(_exception) error_log = LiteLLM_ErrorLogs( request_id=str(uuid.uuid4()), @@ -1856,6 +1859,11 @@ class ProxyConfig: user_custom_key_generate = get_instance_fn( value=custom_key_generate, config_file_path=config_file_path ) + ## pass through endpoints + if general_settings.get("pass_through_endpoints", None) is not None: + await initialize_pass_through_endpoints( + pass_through_endpoints=general_settings["pass_through_endpoints"] + ) ## dynamodb database_type = general_settings.get("database_type", None) if database_type is not None and ( @@ -7503,7 +7511,9 @@ async def login(request: Request): # Non SSO -> If user is using UI_USERNAME and UI_PASSWORD they are Proxy admin user_role = LitellmUserRoles.PROXY_ADMIN user_id = username - key_user_id = user_id + + # we want the key created to have PROXY_ADMIN_PERMISSIONS + key_user_id = litellm_proxy_admin_name if ( os.getenv("PROXY_ADMIN_ID", None) is not None and os.environ["PROXY_ADMIN_ID"] == user_id @@ -7523,7 +7533,17 @@ async def login(request: Request): if os.getenv("DATABASE_URL") is not None: response = await generate_key_helper_fn( request_type="key", - **{"user_role": LitellmUserRoles.PROXY_ADMIN, "duration": "2hr", "key_max_budget": 5, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"}, # type: ignore + **{ + "user_role": LitellmUserRoles.PROXY_ADMIN, + "duration": "2hr", + "key_max_budget": 5, + "models": [], + "aliases": {}, + "config": {}, + "spend": 0, + "user_id": key_user_id, + "team_id": "litellm-dashboard", + }, # type: ignore ) else: raise ProxyException( diff --git a/litellm/proxy/secret_managers/aws_secret_manager.py b/litellm/proxy/secret_managers/aws_secret_manager.py index 8dd6772cf..c4afaedc2 100644 --- a/litellm/proxy/secret_managers/aws_secret_manager.py +++ 
b/litellm/proxy/secret_managers/aws_secret_manager.py @@ -8,9 +8,13 @@ Requires: * `pip install boto3>=1.28.57` """ -import litellm +import ast +import base64 import os -from typing import Optional +import re +from typing import Any, Dict, Optional + +import litellm from litellm.proxy._types import KeyManagementSystem @@ -57,3 +61,99 @@ def load_aws_kms(use_aws_kms: Optional[bool]): except Exception as e: raise e + + +class AWSKeyManagementService_V2: + """ + V2 Clean Class for decrypting keys from AWS KeyManagementService + """ + + def __init__(self) -> None: + self.validate_environment() + self.kms_client = self.load_aws_kms(use_aws_kms=True) + + def validate_environment( + self, + ): + if "AWS_REGION_NAME" not in os.environ: + raise ValueError("Missing required environment variable - AWS_REGION_NAME") + + ## CHECK IF LICENSE IN ENV ## - premium feature + if os.getenv("LITELLM_LICENSE", None) is None: + raise ValueError( + "AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your environment." + ) + + def load_aws_kms(self, use_aws_kms: Optional[bool]): + if use_aws_kms is None or use_aws_kms is False: + return + try: + import boto3 + + validate_environment() + + # Create a KMS client + kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME")) + + return kms_client + except Exception as e: + raise e + + def decrypt_value(self, secret_name: str) -> Any: + if self.kms_client is None: + raise ValueError("kms_client is None") + encrypted_value = os.getenv(secret_name, None) + if encrypted_value is None: + raise Exception( + "AWS KMS - Encrypted Value of Key={} is None".format(secret_name) + ) + if isinstance(encrypted_value, str) and encrypted_value.startswith("aws_kms/"): + encrypted_value = encrypted_value.replace("aws_kms/", "") + + # Decode the base64 encoded ciphertext + ciphertext_blob = base64.b64decode(encrypted_value) + + # Set up the parameters for the decrypt call + params = {"CiphertextBlob": ciphertext_blob} + # Perform the decryption + response = self.kms_client.decrypt(**params) + + # Extract and decode the plaintext + plaintext = response["Plaintext"] + secret = plaintext.decode("utf-8") + if isinstance(secret, str): + secret = secret.strip() + try: + secret_value_as_bool = ast.literal_eval(secret) + if isinstance(secret_value_as_bool, bool): + return secret_value_as_bool + except Exception: + pass + + return secret + + +""" +- look for all values in the env with `aws_kms/` +- decrypt keys +- rewrite env var with decrypted key (). Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist. 
+""" + + +def decrypt_env_var() -> Dict[str, Any]: + # setup client class + aws_kms = AWSKeyManagementService_V2() + # iterate through env - for `aws_kms/` + new_values = {} + for k, v in os.environ.items(): + if ( + k is not None + and isinstance(k, str) + and k.lower().startswith("litellm_secret_aws_kms") + ) or (v is not None and isinstance(v, str) and v.startswith("aws_kms/")): + decrypted_value = aws_kms.decrypt_value(secret_name=k) + # reset env var + k = re.sub("litellm_secret_aws_kms_", "", k, flags=re.IGNORECASE) + new_values[k] = decrypted_value + + return new_values diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index 1fbd95b3c..87bd85078 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -817,9 +817,9 @@ async def get_global_spend_report( default=None, description="Time till which to view spend", ), - group_by: Optional[Literal["team", "customer"]] = fastapi.Query( + group_by: Optional[Literal["team", "customer", "api_key"]] = fastapi.Query( default="team", - description="Group spend by internal team or customer", + description="Group spend by internal team or customer or api_key", ), ): """ @@ -860,7 +860,7 @@ async def get_global_spend_report( start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") - from litellm.proxy.proxy_server import prisma_client + from litellm.proxy.proxy_server import premium_user, prisma_client try: if prisma_client is None: @@ -868,6 +868,12 @@ async def get_global_spend_report( f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" ) + if premium_user is not True: + verbose_proxy_logger.debug("accessing /spend/report but not a premium user") + raise ValueError( + "/spend/report endpoint " + CommonProxyErrors.not_premium_user.value + ) + if group_by == "team": # first get data from spend logs -> SpendByModelApiKey # then read data from "SpendByModelApiKey" to format the response obj @@ -992,6 +998,48 @@ async def get_global_spend_report( return [] return db_response + elif group_by == "api_key": + sql_query = """ + WITH SpendByModelApiKey AS ( + SELECT + sl.api_key, + sl.model, + SUM(sl.spend) AS model_cost, + SUM(sl.prompt_tokens) AS model_input_tokens, + SUM(sl.completion_tokens) AS model_output_tokens + FROM + "LiteLLM_SpendLogs" sl + WHERE + sl."startTime" BETWEEN $1::date AND $2::date + GROUP BY + sl.api_key, + sl.model + ) + SELECT + api_key, + SUM(model_cost) AS total_cost, + SUM(model_input_tokens) AS total_input_tokens, + SUM(model_output_tokens) AS total_output_tokens, + jsonb_agg(jsonb_build_object( + 'model', model, + 'total_cost', model_cost, + 'total_input_tokens', model_input_tokens, + 'total_output_tokens', model_output_tokens + )) AS model_details + FROM + SpendByModelApiKey + GROUP BY + api_key + ORDER BY + total_cost DESC; + """ + db_response = await prisma_client.db.query_raw( + sql_query, start_date_obj, end_date_obj + ) + if db_response is None: + return [] + + return db_response except Exception as e: raise HTTPException( diff --git a/litellm/proxy/tests/test_pass_through_langfuse.py b/litellm/proxy/tests/test_pass_through_langfuse.py new file mode 100644 index 000000000..dfc91ee1b --- /dev/null +++ b/litellm/proxy/tests/test_pass_through_langfuse.py @@ -0,0 +1,14 @@ +from langfuse import Langfuse + +langfuse = Langfuse( 
+ host="http://localhost:4000", + public_key="anything", + secret_key="anything", +) + +print("sending langfuse trace request") +trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough") +print("flushing langfuse request") +langfuse.flush() + +print("flushed langfuse request") diff --git a/litellm/router.py b/litellm/router.py index e2f7ce8b2..0d082de5d 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -105,7 +105,9 @@ class Router: def __init__( self, - model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None, + model_list: Optional[ + Union[List[DeploymentTypedDict], List[Dict[str, Any]]] + ] = None, ## ASSISTANTS API ## assistants_config: Optional[AssistantsTypedDict] = None, ## CACHING ## @@ -154,6 +156,7 @@ class Router: cooldown_time: Optional[ float ] = None, # (seconds) time to cooldown a deployment after failure + disable_cooldowns: Optional[bool] = None, routing_strategy: Literal[ "simple-shuffle", "least-busy", @@ -305,6 +308,7 @@ class Router: self.allowed_fails = allowed_fails or litellm.allowed_fails self.cooldown_time = cooldown_time or 60 + self.disable_cooldowns = disable_cooldowns self.failed_calls = ( InMemoryCache() ) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown @@ -2988,6 +2992,8 @@ class Router: the exception is not one that should be immediately retried (e.g. 401) """ + if self.disable_cooldowns is True: + return if deployment is None: return @@ -3028,24 +3034,50 @@ class Router: exception_status = 500 _should_retry = litellm._should_retry(status_code=exception_status) - if updated_fails > allowed_fails or _should_retry == False: + if updated_fails > allowed_fails or _should_retry is False: # get the current cooldown list for that minute cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls - cached_value = self.cache.get_cache(key=cooldown_key) + cached_value = self.cache.get_cache( + key=cooldown_key + ) # [(deployment_id, {last_error_str, last_error_status_code})] + cached_value_deployment_ids = [] + if ( + cached_value is not None + and isinstance(cached_value, list) + and len(cached_value) > 0 + and isinstance(cached_value[0], tuple) + ): + cached_value_deployment_ids = [cv[0] for cv in cached_value] verbose_router_logger.debug(f"adding {deployment} to cooldown models") # update value - try: - if deployment in cached_value: + if cached_value is not None and len(cached_value_deployment_ids) > 0: + if deployment in cached_value_deployment_ids: pass else: - cached_value = cached_value + [deployment] + cached_value = cached_value + [ + ( + deployment, + { + "Exception Received": str(original_exception), + "Status Code": str(exception_status), + }, + ) + ] # save updated value self.cache.set_cache( value=cached_value, key=cooldown_key, ttl=cooldown_time ) - except: - cached_value = [deployment] + else: + cached_value = [ + ( + deployment, + { + "Exception Received": str(original_exception), + "Status Code": str(exception_status), + }, + ) + ] # save updated value self.cache.set_cache( value=cached_value, key=cooldown_key, ttl=cooldown_time @@ -3061,7 +3093,33 @@ class Router: key=deployment, value=updated_fails, ttl=cooldown_time ) - async def _async_get_cooldown_deployments(self): + async def _async_get_cooldown_deployments(self) -> List[str]: + """ + Async implementation of '_get_cooldown_deployments' + """ + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + # get the current 
cooldown list for that minute + cooldown_key = f"{current_minute}:cooldown_models" + + # ---------------------- + # Return cooldown models + # ---------------------- + cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or [] + + cached_value_deployment_ids = [] + if ( + cooldown_models is not None + and isinstance(cooldown_models, list) + and len(cooldown_models) > 0 + and isinstance(cooldown_models[0], tuple) + ): + cached_value_deployment_ids = [cv[0] for cv in cooldown_models] + + verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") + return cached_value_deployment_ids + + async def _async_get_cooldown_deployments_with_debug_info(self) -> List[tuple]: """ Async implementation of '_get_cooldown_deployments' """ @@ -3078,7 +3136,7 @@ class Router: verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") return cooldown_models - def _get_cooldown_deployments(self): + def _get_cooldown_deployments(self) -> List[str]: """ Get the list of models being cooled down for this minute """ @@ -3092,8 +3150,17 @@ class Router: # ---------------------- cooldown_models = self.cache.get_cache(key=cooldown_key) or [] + cached_value_deployment_ids = [] + if ( + cooldown_models is not None + and isinstance(cooldown_models, list) + and len(cooldown_models) > 0 + and isinstance(cooldown_models[0], tuple) + ): + cached_value_deployment_ids = [cv[0] for cv in cooldown_models] + verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") - return cooldown_models + return cached_value_deployment_ids def _get_healthy_deployments(self, model: str): _all_deployments: list = [] @@ -3970,16 +4037,36 @@ class Router: Augment litellm info with additional params set in `model_info`. + For azure models, ignore the `model:`. Only set max tokens, cost values if base_model is set. + Returns - ModelInfo - If found -> typed dict with max tokens, input cost, etc. + + Raises: + - ValueError -> If model is not mapped yet """ - ## SET MODEL NAME + ## GET BASE MODEL base_model = deployment.get("model_info", {}).get("base_model", None) if base_model is None: base_model = deployment.get("litellm_params", {}).get("base_model", None) - model = base_model or deployment.get("litellm_params", {}).get("model", None) - ## GET LITELLM MODEL INFO + model = base_model + + ## GET PROVIDER + _model, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=deployment.get("litellm_params", {}).get("model", ""), + litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})), + ) + + ## SET MODEL TO 'model=' - if base_model is None + not azure + if custom_llm_provider == "azure" and base_model is None: + verbose_router_logger.error( + "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models" + ) + elif custom_llm_provider != "azure": + model = _model + + ## GET LITELLM MODEL INFO - raises exception, if model is not mapped model_info = litellm.get_model_info(model=model) ## CHECK USER SET MODEL INFO @@ -4365,7 +4452,7 @@ class Router: """ Filter out model in model group, if: - - model context window < message length + - model context window < message length. For azure openai models, requires 'base_model' is set. 
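# Sketch of the new cooldown-cache shape used by the router code above: each cached
# entry is a (deployment_id, debug_info) tuple, and the id-only helpers simply
# project the first element. The values below are made up for illustration.
from typing import List, Tuple

cooldown_cache_value: List[Tuple[str, dict]] = [
    ("deployment-1", {"Exception Received": "AuthenticationError ...", "Status Code": "401"}),
    ("deployment-2", {"Exception Received": "Timeout ...", "Status Code": "408"}),
]

cooled_down_ids = [entry[0] for entry in cooldown_cache_value]
# -> ["deployment-1", "deployment-2"]; the *_with_debug_info variant returns the tuples unchanged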
- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models - filter models above rpm limits - if region given, filter out models not in that region / unknown region - [TODO] function call and model doesn't support function calling @@ -4382,6 +4469,11 @@ class Router: try: input_tokens = litellm.token_counter(messages=messages) except Exception as e: + verbose_router_logger.error( + "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format( + str(e) + ) + ) return _returned_deployments _context_window_error = False @@ -4425,7 +4517,7 @@ class Router: ) continue except Exception as e: - verbose_router_logger.debug("An error occurs - {}".format(str(e))) + verbose_router_logger.error("An error occurs - {}".format(str(e))) _litellm_params = deployment.get("litellm_params", {}) model_id = deployment.get("model_info", {}).get("id", "") @@ -4686,7 +4778,7 @@ class Router: if _allowed_model_region is None: _allowed_model_region = "n/a" raise ValueError( - f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}" + f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}" ) if ( diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 901d68ef3..c4705325b 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -880,6 +880,208 @@ Using this JSON schema: mock_call.assert_called_once() +def vertex_httpx_mock_post_valid_response(*args, **kwargs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": '[{"recipe_name": "Chocolate Chip Cookies"}, {"recipe_name": "Oatmeal Raisin Cookies"}, {"recipe_name": "Peanut Butter Cookies"}, {"recipe_name": "Sugar Cookies"}, {"recipe_name": "Snickerdoodles"}]\n' + } + ], + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.09790669, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.11736965, + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.1261379, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08601588, + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.083441176, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.0355444, + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.071981624, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08108212, + }, + ], + } + ], + "usageMetadata": { + "promptTokenCount": 60, + "candidatesTokenCount": 55, + "totalTokenCount": 115, + }, + } + return mock_response + + +def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = 
{"Content-Type": "application/json"} + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + {"text": '[{"recipe_world": "Chocolate Chip Cookies"}]\n'} + ], + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.09790669, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.11736965, + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.1261379, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08601588, + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.083441176, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.0355444, + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.071981624, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.08108212, + }, + ], + } + ], + "usageMetadata": { + "promptTokenCount": 60, + "candidatesTokenCount": 55, + "totalTokenCount": 115, + }, + } + return mock_response + + +@pytest.mark.parametrize( + "model, vertex_location, supports_response_schema", + [ + ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True), + ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False), + ], +) +@pytest.mark.parametrize( + "invalid_response", + [True, False], +) +@pytest.mark.parametrize( + "enforce_validation", + [True, False], +) +@pytest.mark.asyncio +async def test_gemini_pro_json_schema_args_sent_httpx( + model, + supports_response_schema, + vertex_location, + invalid_response, + enforce_validation, +): + load_vertex_ai_credentials() + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + + litellm.set_verbose = True + messages = [{"role": "user", "content": "List 5 cookie recipes"}] + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + response_schema = { + "type": "array", + "items": { + "type": "object", + "properties": { + "recipe_name": { + "type": "string", + }, + }, + "required": ["recipe_name"], + }, + } + + client = HTTPHandler() + + httpx_response = MagicMock() + if invalid_response is True: + httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response + else: + httpx_response.side_effect = vertex_httpx_mock_post_valid_response + with patch.object(client, "post", new=httpx_response) as mock_call: + try: + _ = completion( + model=model, + messages=messages, + response_format={ + "type": "json_object", + "response_schema": response_schema, + "enforce_validation": enforce_validation, + }, + vertex_location=vertex_location, + client=client, + ) + if invalid_response is True and enforce_validation is True: + pytest.fail("Expected this to fail") + except litellm.JSONSchemaValidationError as e: + if invalid_response is False and "claude-3" not in model: + pytest.fail("Expected this to pass. 
Got={}".format(e)) + + mock_call.assert_called_once() + print(mock_call.call_args.kwargs) + print(mock_call.call_args.kwargs["json"]["generationConfig"]) + + if supports_response_schema: + assert ( + "response_schema" + in mock_call.call_args.kwargs["json"]["generationConfig"] + ) + else: + assert ( + "response_schema" + not in mock_call.call_args.kwargs["json"]["generationConfig"] + ) + assert ( + "Use this JSON schema:" + in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"] + ) + + @pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", @pytest.mark.asyncio async def test_gemini_pro_httpx_custom_api_base(provider): diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 24eefceef..fb4ba7556 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -25,6 +25,7 @@ from litellm import ( completion_cost, embedding, ) +from litellm.llms.bedrock_httpx import BedrockLLM from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler # litellm.num_retries = 3 @@ -217,6 +218,234 @@ def test_completion_bedrock_claude_sts_client_auth(): pytest.fail(f"Error occurred: {e}") +@pytest.fixture() +def bedrock_session_token_creds(): + print("\ncalling oidc auto to get aws_session_token credentials") + import os + + aws_region_name = os.environ["AWS_REGION_NAME"] + aws_session_token = os.environ.get("AWS_SESSION_TOKEN") + + bllm = BedrockLLM() + if aws_session_token is not None: + # For local testing + creds = bllm.get_credentials( + aws_region_name=aws_region_name, + aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + aws_session_token=aws_session_token, + ) + else: + # For circle-ci testing + # aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"] + # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually + aws_role_name = ( + "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci" + ) + aws_web_identity_token = "oidc/circleci_v2/" + + creds = bllm.get_credentials( + aws_region_name=aws_region_name, + aws_web_identity_token=aws_web_identity_token, + aws_role_name=aws_role_name, + aws_session_name="my-test-session", + ) + return creds + + +def process_stream_response(res, messages): + import types + + if isinstance(res, litellm.utils.CustomStreamWrapper): + chunks = [] + for part in res: + chunks.append(part) + text = part.choices[0].delta.content or "" + print(text, end="") + res = litellm.stream_chunk_builder(chunks, messages=messages) + else: + raise ValueError("Response object is not a streaming response") + + return res + + +@pytest.mark.skipif( + os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None, + reason="Cannot run without being in CircleCI Runner", +) +def test_completion_bedrock_claude_aws_session_token(bedrock_session_token_creds): + print("\ncalling bedrock claude with aws_session_token auth") + + import os + + aws_region_name = os.environ["AWS_REGION_NAME"] + aws_access_key_id = bedrock_session_token_creds.access_key + aws_secret_access_key = bedrock_session_token_creds.secret_key + aws_session_token = bedrock_session_token_creds.token + + try: + litellm.set_verbose = True + + response_1 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=10, + temperature=0.1, + aws_region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + 
aws_session_token=aws_session_token, + ) + print(response_1) + assert len(response_1.choices) > 0 + assert len(response_1.choices[0].message.content) > 0 + + # This second call is to verify that the cache isn't breaking anything + response_2 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=5, + temperature=0.2, + aws_region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + ) + print(response_2) + assert len(response_2.choices) > 0 + assert len(response_2.choices[0].message.content) > 0 + + # This third call is to verify that the cache isn't used for a different region + response_3 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=6, + temperature=0.3, + aws_region_name="us-east-1", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + ) + print(response_3) + assert len(response_3.choices) > 0 + assert len(response_3.choices[0].message.content) > 0 + + # This fourth call is to verify streaming api works + response_4 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=6, + temperature=0.3, + aws_region_name="us-east-1", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + stream=True, + ) + response_4 = process_stream_response(response_4, messages) + print(response_4) + assert len(response_4.choices) > 0 + assert len(response_4.choices[0].message.content) > 0 + + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +@pytest.mark.skipif( + os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None, + reason="Cannot run without being in CircleCI Runner", +) +def test_completion_bedrock_claude_aws_bedrock_client(bedrock_session_token_creds): + print("\ncalling bedrock claude with aws_session_token auth") + + import os + + import boto3 + from botocore.client import Config + + aws_region_name = os.environ["AWS_REGION_NAME"] + aws_access_key_id = bedrock_session_token_creds.access_key + aws_secret_access_key = bedrock_session_token_creds.secret_key + aws_session_token = bedrock_session_token_creds.token + + aws_bedrock_client_west = boto3.client( + service_name="bedrock-runtime", + region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + config=Config(read_timeout=600), + ) + + try: + litellm.set_verbose = True + + response_1 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=10, + temperature=0.1, + aws_bedrock_client=aws_bedrock_client_west, + ) + print(response_1) + assert len(response_1.choices) > 0 + assert len(response_1.choices[0].message.content) > 0 + + # This second call is to verify that the cache isn't breaking anything + response_2 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=5, + temperature=0.2, + aws_bedrock_client=aws_bedrock_client_west, + ) + print(response_2) + assert len(response_2.choices) > 0 + assert len(response_2.choices[0].message.content) > 0 + + # This third call is to verify that the cache isn't used for a different region + aws_bedrock_client_east = boto3.client( + service_name="bedrock-runtime", + 
region_name="us-east-1", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + config=Config(read_timeout=600), + ) + + response_3 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=6, + temperature=0.3, + aws_bedrock_client=aws_bedrock_client_east, + ) + print(response_3) + assert len(response_3.choices) > 0 + assert len(response_3.choices[0].message.content) > 0 + + # This fourth call is to verify streaming api works + response_4 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=6, + temperature=0.3, + aws_bedrock_client=aws_bedrock_client_east, + stream=True, + ) + response_4 = process_stream_response(response_4, messages) + print(response_4) + assert len(response_4.choices) > 0 + assert len(response_4.choices[0].message.content) > 0 + + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + # test_completion_bedrock_claude_sts_client_auth() @@ -489,61 +718,6 @@ def test_completion_claude_3_base64(): pytest.fail(f"An exception occurred - {str(e)}") -def test_provisioned_throughput(): - try: - litellm.set_verbose = True - import io - import json - - import botocore - import botocore.session - from botocore.stub import Stubber - - bedrock_client = botocore.session.get_session().create_client( - "bedrock-runtime", region_name="us-east-1" - ) - - expected_params = { - "accept": "application/json", - "body": '{"prompt": "\\n\\nHuman: Hello, how are you?\\n\\nAssistant: ", ' - '"max_tokens_to_sample": 256}', - "contentType": "application/json", - "modelId": "provisioned-model-arn", - } - response_from_bedrock = { - "body": io.StringIO( - json.dumps( - { - "completion": " Here is a short poem about the sky:", - "stop_reason": "max_tokens", - "stop": None, - } - ) - ), - "contentType": "contentType", - "ResponseMetadata": {"HTTPStatusCode": 200}, - } - - with Stubber(bedrock_client) as stubber: - stubber.add_response( - "invoke_model", - service_response=response_from_bedrock, - expected_params=expected_params, - ) - response = litellm.completion( - model="bedrock/anthropic.claude-instant-v1", - model_id="provisioned-model-arn", - messages=[{"content": "Hello, how are you?", "role": "user"}], - aws_bedrock_client=bedrock_client, - ) - print("response stubbed", response) - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - -# test_provisioned_throughput() - - def test_completion_bedrock_mistral_completion_auth(): print("calling bedrock mistral completion params auth") import os @@ -682,3 +856,56 @@ async def test_bedrock_custom_prompt_template(): prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"] assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>" mock_client_post.assert_called_once() + + +def test_completion_bedrock_external_client_region(): + print("\ncalling bedrock claude external client auth") + import os + + aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"] + aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"] + aws_region_name = "us-east-1" + + os.environ.pop("AWS_ACCESS_KEY_ID", None) + os.environ.pop("AWS_SECRET_ACCESS_KEY", None) + + client = HTTPHandler() + + try: + import boto3 + + litellm.set_verbose = True + + bedrock = boto3.client( + service_name="bedrock-runtime", + region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + 
endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com", + ) + with patch.object(client, "post", new=Mock()) as mock_client_post: + try: + response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=messages, + max_tokens=10, + temperature=0.1, + aws_bedrock_client=bedrock, + client=client, + ) + # Add any assertions here to check the response + print(response) + except Exception as e: + pass + + print(f"mock_client_post.call_args: {mock_client_post.call_args}") + assert "us-east-1" in mock_client_post.call_args.kwargs["url"] + + mock_client_post.assert_called_once() + + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 5138e9b61..1c10ef461 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.prompt_templates.factory import anthropic_messages_pt -# litellm.num_retries = 3 +# litellm.num_retries=3 litellm.cache = None litellm.success_callback = [] user_message = "Write a short poem about the sky" diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py index 3d8cb3c2a..fb390bb48 100644 --- a/litellm/tests/test_exceptions.py +++ b/litellm/tests/test_exceptions.py @@ -249,6 +249,25 @@ def test_completion_azure_exception(): # test_completion_azure_exception() +def test_azure_embedding_exceptions(): + try: + + response = litellm.embedding( + model="azure/azure-embedding-model", + input="hello", + messages="hello", + ) + pytest.fail(f"Bad request this should have failed but got {response}") + + except Exception as e: + print(vars(e)) + # CRUCIAL Test - Ensures our exceptions are readable and not overly complicated. some users have complained exceptions will randomly have another exception raised in our exception mapping + assert ( + e.message + == "litellm.APIError: AzureException APIError - Embeddings.create() got an unexpected keyword argument 'messages'" + ) + + async def asynctest_completion_azure_exception(): try: import openai diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 72d4d7b1b..e287946ae 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -61,7 +61,6 @@ async def test_token_single_public_key(): import jwt jwt_handler = JWTHandler() - backend_keys = { "keys": [ { diff --git a/litellm/tests/test_lakera_ai_prompt_injection.py b/litellm/tests/test_lakera_ai_prompt_injection.py index 6227eabaa..3e328c824 100644 --- a/litellm/tests/test_lakera_ai_prompt_injection.py +++ b/litellm/tests/test_lakera_ai_prompt_injection.py @@ -1,10 +1,16 @@ # What is this? 
## This tests the Lakera AI integration -import sys, os, asyncio, time, random -from datetime import datetime +import asyncio +import os +import random +import sys +import time import traceback +from datetime import datetime + from dotenv import load_dotenv +from fastapi import HTTPException load_dotenv() import os @@ -12,17 +18,19 @@ import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +import logging + import pytest + import litellm +from litellm import Router, mock_completion +from litellm._logging import verbose_proxy_logger +from litellm.caching import DualCache +from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import ( _ENTERPRISE_lakeraAI_Moderation, ) -from litellm import Router, mock_completion from litellm.proxy.utils import ProxyLogging, hash_token -from litellm.proxy._types import UserAPIKeyAuth -from litellm.caching import DualCache -from litellm._logging import verbose_proxy_logger -import logging verbose_proxy_logger.setLevel(logging.DEBUG) @@ -55,10 +63,12 @@ async def test_lakera_prompt_injection_detection(): call_type="completion", ) pytest.fail(f"Should have failed") - except Exception as e: - print("Got exception: ", e) - assert "Violated content safety policy" in str(e) - pass + except HTTPException as http_exception: + print("http exception details=", http_exception.detail) + + # Assert that the laker ai response is in the exception raise + assert "lakera_ai_response" in http_exception.detail + assert "Violated content safety policy" in str(http_exception) @pytest.mark.asyncio diff --git a/litellm/tests/test_pass_through_endpoints.py b/litellm/tests/test_pass_through_endpoints.py new file mode 100644 index 000000000..0f234dfa8 --- /dev/null +++ b/litellm/tests/test_pass_through_endpoints.py @@ -0,0 +1,85 @@ +import os +import sys + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds-the parent directory to the system path + +import asyncio + +import httpx + +from litellm.proxy.proxy_server import app, initialize_pass_through_endpoints + + +# Mock the async_client used in the pass_through_request function +async def mock_request(*args, **kwargs): + return httpx.Response(200, json={"message": "Mocked response"}) + + +@pytest.fixture +def client(): + return TestClient(app) + + +@pytest.mark.asyncio +async def test_pass_through_endpoint(client, monkeypatch): + # Mock the httpx.AsyncClient.request method + monkeypatch.setattr("httpx.AsyncClient.request", mock_request) + + # Define a pass-through endpoint + pass_through_endpoints = [ + { + "path": "/test-endpoint", + "target": "https://api.example.com/v1/chat/completions", + "headers": {"Authorization": "Bearer test-token"}, + } + ] + + # Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + + # Make a request to the pass-through endpoint + response = client.post("/test-endpoint", json={"prompt": "Hello, world!"}) + + # Assert the response + assert response.status_code == 200 + assert response.json() == {"message": "Mocked response"} + + +@pytest.mark.asyncio +async def test_pass_through_endpoint_rerank(client): + _cohere_api_key = os.environ.get("COHERE_API_KEY") + + # Define a pass-through endpoint + pass_through_endpoints = [ + { + "path": "/v1/rerank", + "target": "https://api.cohere.com/v1/rerank", + "headers": {"Authorization": f"bearer {_cohere_api_key}"}, + } + ] + + 
# Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + + _json_data = { + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": [ + "Carson City is the capital city of the American state of Nevada." + ], + } + + # Make a request to the pass-through endpoint + response = client.post("/v1/rerank", json=_json_data) + + print("JSON response: ", _json_data) + + # Assert the response + assert response.status_code == 200 diff --git a/litellm/tests/test_proxy_exception_mapping.py b/litellm/tests/test_proxy_exception_mapping.py index 498842661..4fb1e7134 100644 --- a/litellm/tests/test_proxy_exception_mapping.py +++ b/litellm/tests/test_proxy_exception_mapping.py @@ -1,25 +1,31 @@ # test that the proxy actually does exception mapping to the OpenAI format -import sys, os -from unittest import mock import json +import os +import sys +from unittest import mock + from dotenv import load_dotenv load_dotenv() -import os, io, asyncio +import asyncio +import io +import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +import openai import pytest -import litellm, openai -from fastapi.testclient import TestClient from fastapi import Response -from litellm.proxy.proxy_server import ( +from fastapi.testclient import TestClient + +import litellm +from litellm.proxy.proxy_server import ( # Replace with the actual module where your FastAPI router is defined + initialize, router, save_worker_config, - initialize, -) # Replace with the actual module where your FastAPI router is defined +) invalid_authentication_error_response = Response( status_code=401, @@ -66,6 +72,12 @@ def test_chat_completion_exception(client): json_response = response.json() print("keys in json response", json_response.keys()) assert json_response.keys() == {"error"} + print("ERROR=", json_response["error"]) + assert isinstance(json_response["error"]["message"], str) + assert ( + json_response["error"]["message"] + == "litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys." 
+ ) # make an openai client to call _make_status_error_from_response openai_client = openai.OpenAI(api_key="anything") diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 3237c8084..7c59611d7 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -16,6 +16,7 @@ sys.path.insert( import os from collections import defaultdict from concurrent.futures import ThreadPoolExecutor +from unittest.mock import AsyncMock, MagicMock, patch import httpx from dotenv import load_dotenv @@ -811,6 +812,7 @@ def test_router_context_window_check_pre_call_check(): "base_model": "azure/gpt-35-turbo", "mock_response": "Hello world 1!", }, + "model_info": {"base_model": "azure/gpt-35-turbo"}, }, { "model_name": "gpt-3.5-turbo", # openai model name @@ -1884,3 +1886,106 @@ async def test_router_model_usage(mock_response): else: print(f"allowed_fails: {allowed_fails}") raise e + + +@pytest.mark.parametrize( + "model, base_model, llm_provider", + [ + ("azure/gpt-4", None, "azure"), + ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"), + ("gpt-4", None, "openai"), + ], +) +def test_router_get_model_info(model, base_model, llm_provider): + """ + Test if router get model info works based on provider + + For azure -> only if base model set + For openai -> use model= + """ + router = Router( + model_list=[ + { + "model_name": "gpt-4", + "litellm_params": { + "model": model, + "api_key": "my-fake-key", + "api_base": "my-fake-base", + }, + "model_info": {"base_model": base_model, "id": "1"}, + } + ] + ) + + deployment = router.get_deployment(model_id="1") + + assert deployment is not None + + if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"): + router.get_router_model_info(deployment=deployment.to_json()) + else: + try: + router.get_router_model_info(deployment=deployment.to_json()) + pytest.fail("Expected this to raise model not mapped error") + except Exception as e: + if "This model isn't mapped yet" in str(e): + pass + + +@pytest.mark.parametrize( + "model, base_model, llm_provider", + [ + ("azure/gpt-4", None, "azure"), + ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"), + ("gpt-4", None, "openai"), + ], +) +def test_router_context_window_pre_call_check(model, base_model, llm_provider): + """ + - For an azure model + - if no base model set + - don't enforce context window limits + """ + try: + model_list = [ + { + "model_name": "gpt-4", + "litellm_params": { + "model": model, + "api_key": "my-fake-key", + "api_base": "my-fake-base", + }, + "model_info": {"base_model": base_model, "id": "1"}, + } + ] + router = Router( + model_list=model_list, + set_verbose=True, + enable_pre_call_checks=True, + num_retries=0, + ) + + litellm.token_counter = MagicMock() + + def token_counter_side_effect(*args, **kwargs): + # Process args and kwargs if needed + return 1000000 + + litellm.token_counter.side_effect = token_counter_side_effect + try: + updated_list = router._pre_call_checks( + model="gpt-4", + healthy_deployments=model_list, + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + if llm_provider == "azure" and base_model is None: + assert len(updated_list) == 1 + else: + pytest.fail("Expected to raise an error. Got={}".format(updated_list)) + except Exception as e: + if ( + llm_provider == "azure" and base_model is not None + ) or llm_provider == "openai": + pass + except Exception as e: + pytest.fail(f"Got unexpected exception on router! 
- {str(e)}") diff --git a/litellm/tests/test_router_debug_logs.py b/litellm/tests/test_router_debug_logs.py index 09590e5ac..20eccf5de 100644 --- a/litellm/tests/test_router_debug_logs.py +++ b/litellm/tests/test_router_debug_logs.py @@ -1,16 +1,23 @@ -import sys, os, time -import traceback, asyncio +import asyncio +import os +import sys +import time +import traceback + import pytest sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -import litellm, asyncio, logging +import asyncio +import logging + +import litellm from litellm import Router # this tests debug logs from litellm router and litellm proxy server -from litellm._logging import verbose_router_logger, verbose_logger, verbose_proxy_logger +from litellm._logging import verbose_logger, verbose_proxy_logger, verbose_router_logger # this tests debug logs from litellm router and litellm proxy server @@ -81,7 +88,7 @@ def test_async_fallbacks(caplog): # Define the expected log messages # - error request, falling back notice, success notice expected_logs = [ - "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception litellm.AuthenticationError: AuthenticationError: OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m", + "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.\x1b[0m", "Falling back to model_group = azure/gpt-3.5-turbo", "litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m", "Successful fallback b/w models.", diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 3042e91b3..1f1b253a0 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -742,7 +742,10 @@ def test_completion_palm_stream(): # test_completion_palm_stream() -@pytest.mark.parametrize("sync_mode", [False]) # True, +@pytest.mark.parametrize( + "sync_mode", + [True, False], +) # , @pytest.mark.asyncio async def test_completion_gemini_stream(sync_mode): try: @@ -807,49 +810,6 @@ async def test_completion_gemini_stream(sync_mode): pytest.fail(f"Error occurred: {e}") -@pytest.mark.asyncio -async def test_acompletion_gemini_stream(): - try: - litellm.set_verbose = True - print("Streaming gemini response") - messages = [ - # {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": "What do you know?", - }, - ] - print("testing gemini streaming") - response = await acompletion( - model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True - ) - print(f"type of response at the top: {response}") - complete_response = "" - idx = 0 - # Add any assertions here to check, the response - async for chunk in response: - print(f"chunk in acompletion gemini: {chunk}") - print(chunk.choices[0].delta) - chunk, finished = streaming_format_tests(idx, chunk) - if finished: - break - print(f"chunk: {chunk}") - complete_response += chunk - idx += 1 - print(f"completion_response: {complete_response}") - if complete_response.strip() == "": - raise Exception("Empty response received") - except litellm.APIError as e: - pass - except litellm.RateLimitError as e: - pass - except Exception as e: - if "429 Resource has been exhausted" in 
str(e): - pass - else: - pytest.fail(f"Error occurred: {e}") - - # asyncio.run(test_acompletion_gemini_stream()) @@ -1071,7 +1031,7 @@ def test_completion_claude_stream_bad_key(): # test_completion_replicate_stream() -@pytest.mark.parametrize("provider", ["vertex_ai"]) # "vertex_ai_beta" +@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "" def test_vertex_ai_stream(provider): from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials @@ -1080,14 +1040,27 @@ def test_vertex_ai_stream(provider): litellm.vertex_project = "adroit-crow-413218" import random - test_models = ["gemini-1.0-pro"] + test_models = ["gemini-1.5-pro"] for model in test_models: try: print("making request", model) response = completion( model="{}/{}".format(provider, model), messages=[ - {"role": "user", "content": "write 10 line code code for saying hi"} + {"role": "user", "content": "Hey, how's it going?"}, + { + "role": "assistant", + "content": "I'm doing well. Would like to hear the rest of the story?", + }, + {"role": "user", "content": "Na"}, + { + "role": "assistant", + "content": "No problem, is there anything else i can help you with today?", + }, + { + "role": "user", + "content": "I think you're getting cut off sometimes", + }, ], stream=True, ) @@ -1104,6 +1077,8 @@ def test_vertex_ai_stream(provider): raise Exception("Empty response received") print(f"completion_response: {complete_response}") assert is_finished == True + + assert False except litellm.RateLimitError as e: pass except Exception as e: @@ -1251,6 +1226,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode): messages=messages, max_tokens=10, # type: ignore stream=True, + num_retries=3, ) complete_response = "" # Add any assertions here to check the response @@ -1272,6 +1248,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode): messages=messages, max_tokens=100, # type: ignore stream=True, + num_retries=3, ) complete_response = "" # Add any assertions here to check the response @@ -1290,6 +1267,8 @@ async def test_completion_replicate_llama3_streaming(sync_mode): raise Exception("finish reason not set") if complete_response.strip() == "": raise Exception("Empty response received") + except litellm.UnprocessableEntityError as e: + pass except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py index 09715e6c1..8225b309d 100644 --- a/litellm/tests/test_utils.py +++ b/litellm/tests/test_utils.py @@ -609,3 +609,83 @@ def test_logging_trace_id(langfuse_trace_id, langfuse_existing_trace_id): litellm_logging_obj._get_trace_id(service_name="langfuse") == litellm_call_id ) + + +def test_convert_model_response_object(): + """ + Unit test to ensure model response object correctly handles openrouter errors. 
+ """ + args = { + "response_object": { + "id": None, + "choices": None, + "created": None, + "model": None, + "object": None, + "service_tier": None, + "system_fingerprint": None, + "usage": None, + "error": { + "message": '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}', + "code": 400, + }, + }, + "model_response_object": litellm.ModelResponse( + id="chatcmpl-b88ce43a-7bfc-437c-b8cc-e90d59372cfb", + choices=[ + litellm.Choices( + finish_reason="stop", + index=0, + message=litellm.Message(content="default", role="assistant"), + ) + ], + created=1719376241, + model="openrouter/anthropic/claude-3.5-sonnet", + object="chat.completion", + system_fingerprint=None, + usage=litellm.Usage(), + ), + "response_type": "completion", + "stream": False, + "start_time": None, + "end_time": None, + "hidden_params": None, + } + + try: + litellm.convert_to_model_response_object(**args) + pytest.fail("Expected this to fail") + except Exception as e: + assert hasattr(e, "status_code") + assert e.status_code == 400 + assert hasattr(e, "message") + assert ( + e.message + == '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}' + ) + + +@pytest.mark.parametrize( + "model, expected_bool", + [ + ("vertex_ai/gemini-1.5-pro", True), + ("gemini/gemini-1.5-pro", True), + ("predibase/llama3-8b-instruct", True), + ("gpt-4o", False), + ], +) +def test_supports_response_schema(model, expected_bool): + """ + Unit tests for 'supports_response_schema' helper function. + + Should be true for gemini-1.5-pro on google ai studio / vertex ai AND predibase models + Should be false otherwise + """ + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + + from litellm.utils import supports_response_schema + + response = supports_response_schema(model=model, custom_llm_provider=None) + + assert expected_bool == response diff --git a/litellm/types/utils.py b/litellm/types/utils.py index a63e34738..d6b7bf744 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -71,6 +71,7 @@ class ModelInfo(TypedDict, total=False): ] supported_openai_params: Required[Optional[List[str]]] supports_system_messages: Optional[bool] + supports_response_schema: Optional[bool] class GenericStreamingChunk(TypedDict): @@ -994,3 +995,8 @@ class GenericImageParsingChunk(TypedDict): type: str media_type: str data: str + + +class ResponseFormatChunk(TypedDict, total=False): + type: Required[Literal["json_object", "text"]] + response_schema: dict diff --git a/litellm/utils.py b/litellm/utils.py index 91f390f09..e86a36012 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -48,8 +48,10 @@ from tokenizers import Tokenizer import litellm import litellm._service_logger # for storing API inputs, outputs, and metadata import litellm.litellm_core_utils +import litellm.litellm_core_utils.json_validation_rule from litellm.caching import DualCache from litellm.litellm_core_utils.core_helpers import map_finish_reason +from litellm.litellm_core_utils.exception_mapping_utils import get_error_message from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_logging, @@ -579,7 +581,7 @@ def client(original_function): else: return False - def post_call_processing(original_response, model): + def post_call_processing(original_response, model, optional_params: 
Optional[dict]): try: if original_response is None: pass @@ -594,11 +596,47 @@ def client(original_function): pass else: if isinstance(original_response, ModelResponse): - model_response = original_response.choices[ + model_response: Optional[str] = original_response.choices[ 0 - ].message.content - ### POST-CALL RULES ### - rules_obj.post_call_rules(input=model_response, model=model) + ].message.content # type: ignore + if model_response is not None: + ### POST-CALL RULES ### + rules_obj.post_call_rules( + input=model_response, model=model + ) + ### JSON SCHEMA VALIDATION ### + if ( + optional_params is not None + and "response_format" in optional_params + and isinstance( + optional_params["response_format"], dict + ) + and "type" in optional_params["response_format"] + and optional_params["response_format"]["type"] + == "json_object" + and "response_schema" + in optional_params["response_format"] + and isinstance( + optional_params["response_format"][ + "response_schema" + ], + dict, + ) + and "enforce_validation" + in optional_params["response_format"] + and optional_params["response_format"][ + "enforce_validation" + ] + is True + ): + # schema given, json response expected, and validation enforced + litellm.litellm_core_utils.json_validation_rule.validate_schema( + schema=optional_params["response_format"][ + "response_schema" + ], + response=model_response, + ) + except Exception as e: raise e @@ -867,7 +905,11 @@ def client(original_function): return result ### POST-CALL RULES ### - post_call_processing(original_response=result, model=model or None) + post_call_processing( + original_response=result, + model=model or None, + optional_params=kwargs, + ) # [OPTIONAL] ADD TO CACHE if ( @@ -1316,7 +1358,9 @@ def client(original_function): ).total_seconds() * 1000 # return response latency in ms like openai ### POST-CALL RULES ### - post_call_processing(original_response=result, model=model) + post_call_processing( + original_response=result, model=model, optional_params=kwargs + ) # [OPTIONAL] ADD TO CACHE if ( @@ -1847,9 +1891,10 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> Parameters: model (str): The model name to be checked. + custom_llm_provider (str): The provider to be checked. Returns: - bool: True if the model supports function calling, False otherwise. + bool: True if the model supports system messages, False otherwise. Raises: Exception: If the given model is not found in model_prices_and_context_window.json. @@ -1867,6 +1912,43 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> ) +def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> bool: + """ + Check if the given model + provider supports 'response_schema' as a param. + + Parameters: + model (str): The model name to be checked. + custom_llm_provider (str): The provider to be checked. + + Returns: + bool: True if the model supports response_schema, False otherwise. + + Does not raise error. Defaults to 'False'. Outputs logging.error. 
+ """ + try: + ## GET LLM PROVIDER ## + model, custom_llm_provider, _, _ = get_llm_provider( + model=model, custom_llm_provider=custom_llm_provider + ) + + if custom_llm_provider == "predibase": # predibase supports this globally + return True + + ## GET MODEL INFO + model_info = litellm.get_model_info( + model=model, custom_llm_provider=custom_llm_provider + ) + + if model_info.get("supports_response_schema", False) is True: + return True + return False + except Exception: + verbose_logger.error( + f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}." + ) + return False + + def supports_function_calling(model: str) -> bool: """ Check if the given model supports function calling and return a boolean value. @@ -2324,7 +2406,9 @@ def get_optional_params( elif k == "hf_model_name" and custom_llm_provider != "sagemaker": continue elif ( - k.startswith("vertex_") and custom_llm_provider != "vertex_ai" and custom_llm_provider != "vertex_ai_beta" + k.startswith("vertex_") + and custom_llm_provider != "vertex_ai" + and custom_llm_provider != "vertex_ai_beta" ): # allow dynamically setting vertex ai init logic continue passed_params[k] = v @@ -2756,6 +2840,11 @@ def get_optional_params( non_default_params=non_default_params, optional_params=optional_params, model=model, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), ) elif ( custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models @@ -2824,12 +2913,7 @@ def get_optional_params( optional_params=optional_params, ) ) - else: - optional_params = litellm.AmazonAnthropicConfig().map_openai_params( - non_default_params=non_default_params, - optional_params=optional_params, - ) - else: # bedrock httpx route + elif model in litellm.BEDROCK_CONVERSE_MODELS: optional_params = litellm.AmazonConverseConfig().map_openai_params( model=model, non_default_params=non_default_params, @@ -2840,6 +2924,11 @@ def get_optional_params( else False ), ) + else: + optional_params = litellm.AmazonAnthropicConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + ) elif "amazon" in model: # amazon titan llms _check_valid_arg(supported_params=supported_params) # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large @@ -3755,23 +3844,18 @@ def get_supported_openai_params( return litellm.AzureOpenAIConfig().get_supported_openai_params() elif custom_llm_provider == "openrouter": return [ - "functions", - "function_call", "temperature", "top_p", - "n", - "stream", - "stop", - "max_tokens", - "presence_penalty", "frequency_penalty", - "logit_bias", - "user", - "response_format", + "presence_penalty", + "repetition_penalty", "seed", - "tools", - "tool_choice", - "max_retries", + "max_tokens", + "logit_bias", + "logprobs", + "top_logprobs", + "response_format", + "stop", ] elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral": # mistal and codestral api have the exact same params @@ -3789,6 +3873,10 @@ def get_supported_openai_params( "top_p", "stop", "seed", + "tools", + "tool_choice", + "functions", + "function_call", ] elif custom_llm_provider == "huggingface": return litellm.HuggingfaceConfig().get_supported_openai_params() @@ -4434,8 +4522,7 @@ def get_max_tokens(model: str) -> Optional[int]: def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo: """ - Get a dict for the 
maximum tokens (context window), - input_cost_per_token, output_cost_per_token for a given model. + Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model. Parameters: - model (str): The name of the model. @@ -4520,6 +4607,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod mode="chat", supported_openai_params=supported_openai_params, supports_system_messages=None, + supports_response_schema=None, ) else: """ @@ -4541,36 +4629,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception - return ModelInfo( - max_tokens=_model_info.get("max_tokens", None), - max_input_tokens=_model_info.get("max_input_tokens", None), - max_output_tokens=_model_info.get("max_output_tokens", None), - input_cost_per_token=_model_info.get("input_cost_per_token", 0), - input_cost_per_character=_model_info.get( - "input_cost_per_character", None - ), - input_cost_per_token_above_128k_tokens=_model_info.get( - "input_cost_per_token_above_128k_tokens", None - ), - output_cost_per_token=_model_info.get("output_cost_per_token", 0), - output_cost_per_character=_model_info.get( - "output_cost_per_character", None - ), - output_cost_per_token_above_128k_tokens=_model_info.get( - "output_cost_per_token_above_128k_tokens", None - ), - output_cost_per_character_above_128k_tokens=_model_info.get( - "output_cost_per_character_above_128k_tokens", None - ), - litellm_provider=_model_info.get( - "litellm_provider", custom_llm_provider - ), - mode=_model_info.get("mode"), - supported_openai_params=supported_openai_params, - supports_system_messages=_model_info.get( - "supports_system_messages", None - ), - ) elif model in litellm.model_cost: _model_info = litellm.model_cost[model] _model_info["supported_openai_params"] = supported_openai_params @@ -4584,36 +4642,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception - return ModelInfo( - max_tokens=_model_info.get("max_tokens", None), - max_input_tokens=_model_info.get("max_input_tokens", None), - max_output_tokens=_model_info.get("max_output_tokens", None), - input_cost_per_token=_model_info.get("input_cost_per_token", 0), - input_cost_per_character=_model_info.get( - "input_cost_per_character", None - ), - input_cost_per_token_above_128k_tokens=_model_info.get( - "input_cost_per_token_above_128k_tokens", None - ), - output_cost_per_token=_model_info.get("output_cost_per_token", 0), - output_cost_per_character=_model_info.get( - "output_cost_per_character", None - ), - output_cost_per_token_above_128k_tokens=_model_info.get( - "output_cost_per_token_above_128k_tokens", None - ), - output_cost_per_character_above_128k_tokens=_model_info.get( - "output_cost_per_character_above_128k_tokens", None - ), - litellm_provider=_model_info.get( - "litellm_provider", custom_llm_provider - ), - mode=_model_info.get("mode"), - supported_openai_params=supported_openai_params, - supports_system_messages=_model_info.get( - "supports_system_messages", None - ), - ) elif split_model in litellm.model_cost: _model_info = litellm.model_cost[split_model] _model_info["supported_openai_params"] = supported_openai_params @@ -4627,40 +4655,48 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception - return ModelInfo( - max_tokens=_model_info.get("max_tokens", None), - max_input_tokens=_model_info.get("max_input_tokens", None), - 
max_output_tokens=_model_info.get("max_output_tokens", None), - input_cost_per_token=_model_info.get("input_cost_per_token", 0), - input_cost_per_character=_model_info.get( - "input_cost_per_character", None - ), - input_cost_per_token_above_128k_tokens=_model_info.get( - "input_cost_per_token_above_128k_tokens", None - ), - output_cost_per_token=_model_info.get("output_cost_per_token", 0), - output_cost_per_character=_model_info.get( - "output_cost_per_character", None - ), - output_cost_per_token_above_128k_tokens=_model_info.get( - "output_cost_per_token_above_128k_tokens", None - ), - output_cost_per_character_above_128k_tokens=_model_info.get( - "output_cost_per_character_above_128k_tokens", None - ), - litellm_provider=_model_info.get( - "litellm_provider", custom_llm_provider - ), - mode=_model_info.get("mode"), - supported_openai_params=supported_openai_params, - supports_system_messages=_model_info.get( - "supports_system_messages", None - ), - ) else: raise ValueError( "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" ) + + ## PROVIDER-SPECIFIC INFORMATION + if custom_llm_provider == "predibase": + _model_info["supports_response_schema"] = True + + return ModelInfo( + max_tokens=_model_info.get("max_tokens", None), + max_input_tokens=_model_info.get("max_input_tokens", None), + max_output_tokens=_model_info.get("max_output_tokens", None), + input_cost_per_token=_model_info.get("input_cost_per_token", 0), + input_cost_per_character=_model_info.get( + "input_cost_per_character", None + ), + input_cost_per_token_above_128k_tokens=_model_info.get( + "input_cost_per_token_above_128k_tokens", None + ), + output_cost_per_token=_model_info.get("output_cost_per_token", 0), + output_cost_per_character=_model_info.get( + "output_cost_per_character", None + ), + output_cost_per_token_above_128k_tokens=_model_info.get( + "output_cost_per_token_above_128k_tokens", None + ), + output_cost_per_character_above_128k_tokens=_model_info.get( + "output_cost_per_character_above_128k_tokens", None + ), + litellm_provider=_model_info.get( + "litellm_provider", custom_llm_provider + ), + mode=_model_info.get("mode"), + supported_openai_params=supported_openai_params, + supports_system_messages=_model_info.get( + "supports_system_messages", None + ), + supports_response_schema=_model_info.get( + "supports_response_schema", None + ), + ) except Exception: raise Exception( "This model isn't mapped yet. 
Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" @@ -5278,6 +5314,27 @@ def convert_to_model_response_object( hidden_params: Optional[dict] = None, ): received_args = locals() + ### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary + if ( + response_object is not None + and "error" in response_object + and response_object["error"] is not None + ): + error_args = {"status_code": 422, "message": "Error in response object"} + if isinstance(response_object["error"], dict): + if "code" in response_object["error"]: + error_args["status_code"] = response_object["error"]["code"] + if "message" in response_object["error"]: + if isinstance(response_object["error"]["message"], dict): + message_str = json.dumps(response_object["error"]["message"]) + else: + message_str = str(response_object["error"]["message"]) + error_args["message"] = message_str + raised_exception = Exception() + setattr(raised_exception, "status_code", error_args["status_code"]) + setattr(raised_exception, "message", error_args["message"]) + raise raised_exception + try: if response_type == "completion" and ( model_response_object is None @@ -5733,7 +5790,10 @@ def exception_type( print() # noqa try: if model: - error_str = str(original_exception) + if hasattr(original_exception, "message"): + error_str = str(original_exception.message) + else: + error_str = str(original_exception) if isinstance(original_exception, BaseException): exception_type = type(original_exception).__name__ else: @@ -5755,6 +5815,18 @@ def exception_type( _model_group = _metadata.get("model_group") _deployment = _metadata.get("deployment") extra_information = f"\nModel: {model}" + + exception_provider = "Unknown" + if ( + isinstance(custom_llm_provider, str) + and len(custom_llm_provider) > 0 + ): + exception_provider = ( + custom_llm_provider[0].upper() + + custom_llm_provider[1:] + + "Exception" + ) + if _api_base: extra_information += f"\nAPI Base: `{_api_base}`" if ( @@ -5805,10 +5877,13 @@ def exception_type( or custom_llm_provider in litellm.openai_compatible_providers ): # custom_llm_provider is openai, make it OpenAI - if hasattr(original_exception, "message"): - message = original_exception.message - else: - message = str(original_exception) + message = get_error_message(error_obj=original_exception) + if message is None: + if hasattr(original_exception, "message"): + message = original_exception.message + else: + message = str(original_exception) + if message is not None and isinstance(message, str): message = message.replace("OPENAI", custom_llm_provider.upper()) message = message.replace("openai", custom_llm_provider) @@ -6141,7 +6216,6 @@ def exception_type( ) elif ( original_exception.status_code == 400 - or original_exception.status_code == 422 or original_exception.status_code == 413 ): exception_mapping_worked = True @@ -6151,6 +6225,14 @@ def exception_type( llm_provider="replicate", response=original_exception.response, ) + elif original_exception.status_code == 422: + exception_mapping_worked = True + raise UnprocessableEntityError( + message=f"ReplicateException - {original_exception.message}", + model=model, + llm_provider="replicate", + response=original_exception.response, + ) elif original_exception.status_code == 408: exception_mapping_worked = True raise Timeout( @@ -7254,10 +7336,17 @@ def exception_type( request=original_exception.request, ) elif custom_llm_provider == "azure": + message = get_error_message(error_obj=original_exception) + if message is 
None: + if hasattr(original_exception, "message"): + message = original_exception.message + else: + message = str(original_exception) + if "Internal server error" in error_str: exception_mapping_worked = True raise litellm.InternalServerError( - message=f"AzureException Internal server error - {original_exception.message}", + message=f"AzureException Internal server error - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7270,7 +7359,7 @@ def exception_type( elif "This model's maximum context length is" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( - message=f"AzureException ContextWindowExceededError - {original_exception.message}", + message=f"AzureException ContextWindowExceededError - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7279,7 +7368,7 @@ def exception_type( elif "DeploymentNotFound" in error_str: exception_mapping_worked = True raise NotFoundError( - message=f"AzureException NotFoundError - {original_exception.message}", + message=f"AzureException NotFoundError - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7299,7 +7388,7 @@ def exception_type( ): exception_mapping_worked = True raise ContentPolicyViolationError( - message=f"litellm.ContentPolicyViolationError: AzureException - {original_exception.message}", + message=f"litellm.ContentPolicyViolationError: AzureException - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7308,7 +7397,7 @@ def exception_type( elif "invalid_request_error" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"AzureException BadRequestError - {original_exception.message}", + message=f"AzureException BadRequestError - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7320,7 +7409,7 @@ def exception_type( ): exception_mapping_worked = True raise AuthenticationError( - message=f"{exception_provider} AuthenticationError - {original_exception.message}", + message=f"{exception_provider} AuthenticationError - {message}", llm_provider=custom_llm_provider, model=model, litellm_debug_info=extra_information, @@ -7331,7 +7420,7 @@ def exception_type( if original_exception.status_code == 400: exception_mapping_worked = True raise BadRequestError( - message=f"AzureException - {original_exception.message}", + message=f"AzureException - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7340,7 +7429,7 @@ def exception_type( elif original_exception.status_code == 401: exception_mapping_worked = True raise AuthenticationError( - message=f"AzureException AuthenticationError - {original_exception.message}", + message=f"AzureException AuthenticationError - {message}", llm_provider="azure", model=model, litellm_debug_info=extra_information, @@ -7349,7 +7438,7 @@ def exception_type( elif original_exception.status_code == 408: exception_mapping_worked = True raise Timeout( - message=f"AzureException Timeout - {original_exception.message}", + message=f"AzureException Timeout - {message}", model=model, litellm_debug_info=extra_information, llm_provider="azure", @@ -7357,7 +7446,7 @@ def exception_type( elif original_exception.status_code == 422: exception_mapping_worked = True raise BadRequestError( - message=f"AzureException BadRequestError - {original_exception.message}", + message=f"AzureException BadRequestError - {message}", model=model, llm_provider="azure", 
litellm_debug_info=extra_information, @@ -7366,7 +7455,7 @@ def exception_type( elif original_exception.status_code == 429: exception_mapping_worked = True raise RateLimitError( - message=f"AzureException RateLimitError - {original_exception.message}", + message=f"AzureException RateLimitError - {message}", model=model, llm_provider="azure", litellm_debug_info=extra_information, @@ -7375,7 +7464,7 @@ def exception_type( elif original_exception.status_code == 503: exception_mapping_worked = True raise ServiceUnavailableError( - message=f"AzureException ServiceUnavailableError - {original_exception.message}", + message=f"AzureException ServiceUnavailableError - {message}", model=model, llm_provider="azure", litellm_debug_info=extra_information, @@ -7384,7 +7473,7 @@ def exception_type( elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True raise Timeout( - message=f"AzureException Timeout - {original_exception.message}", + message=f"AzureException Timeout - {message}", model=model, litellm_debug_info=extra_information, llm_provider="azure", @@ -7393,7 +7482,7 @@ def exception_type( exception_mapping_worked = True raise APIError( status_code=original_exception.status_code, - message=f"AzureException APIError - {original_exception.message}", + message=f"AzureException APIError - {message}", llm_provider="azure", litellm_debug_info=extra_information, model=model, diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 92f97ebe5..7f08b9eb1 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -1486,6 +1486,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-001": { @@ -1511,6 +1512,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0514": { @@ -1536,6 +1538,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0215": { @@ -1561,6 +1564,7 @@ "supports_system_messages": true, "supports_function_calling": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-pro-preview-0409": { @@ -1584,7 +1588,8 @@ "litellm_provider": "vertex_ai-language-models", "mode": "chat", "supports_function_calling": true, - "supports_tool_choice": true, + "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-1.5-flash": { @@ -2007,6 +2012,7 @@ "supports_function_calling": true, "supports_vision": true, "supports_tool_choice": true, + "supports_response_schema": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini/gemini-1.5-pro-latest": { @@ -2023,6 +2029,7 @@ "supports_function_calling": true, "supports_vision": true, "supports_tool_choice": true, + "supports_response_schema": 
true, "source": "https://ai.google.dev/models/gemini" }, "gemini/gemini-pro-vision": { diff --git a/pyproject.toml b/pyproject.toml index 4aee21cd9..c698a18e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.40.31" +version = "1.41.3" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -27,7 +27,7 @@ jinja2 = "^3.1.2" aiohttp = "*" requests = "^2.31.0" pydantic = "^2.0.0" -ijson = "*" +jsonschema = "^4.22.0" uvicorn = {version = "^0.22.0", optional = true} gunicorn = {version = "^22.0.0", optional = true} @@ -90,7 +90,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.40.31" +version = "1.41.3" version_files = [ "pyproject.toml:^version" ] diff --git a/requirements.txt b/requirements.txt index 00d3802da..e71ab450b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,5 +46,5 @@ aiohttp==3.9.0 # for network calls aioboto3==12.3.0 # for async sagemaker calls tenacity==8.2.3 # for retrying requests, when litellm.num_retries set pydantic==2.7.1 # proxy + openai req. -ijson==3.2.3 # for google ai studio streaming +jsonschema==4.22.0 # validating json schema #### \ No newline at end of file diff --git a/tests/test_entrypoint.py b/tests/test_entrypoint.py new file mode 100644 index 000000000..803135e35 --- /dev/null +++ b/tests/test_entrypoint.py @@ -0,0 +1,59 @@ +# What is this? +## Unit tests for 'entrypoint.sh' + +import pytest +import sys +import os + +sys.path.insert( + 0, os.path.abspath("../") +) # Adds the parent directory to the system path +import litellm +import subprocess + + +@pytest.mark.skip(reason="local test") +def test_decrypt_and_reset_env(): + os.environ["DATABASE_URL"] = ( + "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La" + ) + from litellm.proxy.secret_managers.aws_secret_manager import ( + decrypt_and_reset_env_var, + ) + + decrypt_and_reset_env_var() + + assert os.environ["DATABASE_URL"] is not None + assert isinstance(os.environ["DATABASE_URL"], str) + assert not os.environ["DATABASE_URL"].startswith("aws_kms/") + + print("DATABASE_URL={}".format(os.environ["DATABASE_URL"])) + + +@pytest.mark.skip(reason="local test") +def test_entrypoint_decrypt_and_reset(): + os.environ["DATABASE_URL"] = ( + "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La" + ) + command = "./entrypoint.sh" + directory = ".." 
# Relative to the current directory + + # Run the command using subprocess + result = subprocess.run( + command, shell=True, cwd=directory, capture_output=True, text=True + ) + + # Print the output for debugging purposes + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + + # Assert the script ran successfully + assert result.returncode == 0, "The shell script did not execute successfully" + assert ( + "DECRYPTS VALUE" in result.stdout + ), "Expected output not found in script output" + assert ( + "Database push successful!" in result.stdout + ), "Expected output not found in script output"
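
Usage sketch for the changes in litellm/utils.py above: when a request's response_format is a dict with type "json_object", a dict-valued response_schema, and enforce_validation set to True, post_call_processing runs the completion text through litellm.litellm_core_utils.json_validation_rule.validate_schema (backed by the jsonschema dependency that replaces ijson), and the new supports_response_schema() helper reports whether a model/provider pair advertises response_schema support. The snippet below is a minimal, illustrative sketch only; the model name, message, and schema are placeholders, and actual enforcement still depends on the provider accepting the schema.

import litellm
from litellm.utils import supports_response_schema

# Placeholder JSON schema describing the expected response shape (illustrative only).
response_schema = {
    "type": "object",
    "properties": {"colors": {"type": "array", "items": {"type": "string"}}},
    "required": ["colors"],
}

# Capability check, mirroring test_supports_response_schema above.
if supports_response_schema(model="vertex_ai/gemini-1.5-pro", custom_llm_provider=None):
    resp = litellm.completion(
        model="vertex_ai/gemini-1.5-pro",
        messages=[{"role": "user", "content": "Return three colors as JSON."}],
        response_format={
            "type": "json_object",
            "response_schema": response_schema,
            # enforce_validation=True routes the completion text through
            # jsonschema validation in post_call_processing; a mismatch raises.
            "enforce_validation": True,
        },
    )
    print(resp.choices[0].message.content)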