diff --git a/README.md b/README.md index bfdba2fa3..5d3efe355 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,36 @@ Step 4: Submit a PR with your changes! 🚀 - push your fork to your GitHub repo - submit a PR from there +### Building LiteLLM Docker Image + +Follow these instructions if you want to build and run the LiteLLM Docker Image yourself. + +Step 1: Clone the repo + +``` +git clone https://github.com/BerriAI/litellm.git +``` + +Step 2: Build the Docker Image + +Build using Dockerfile.non_root +``` +docker build -f docker/Dockerfile.non_root -t litellm_test_image . +``` + +Step 3: Run the Docker Image + +Make sure proxy_config.yaml is present in the root directory. This is your litellm proxy config file; the command below mounts it into the container as /app/config.yaml. +``` +docker run \ + -v $(pwd)/proxy_config.yaml:/app/config.yaml \ + -e DATABASE_URL="postgresql://xxxxxxxx" \ + -e LITELLM_MASTER_KEY="sk-1234" \ + -p 4000:4000 \ + litellm_test_image \ + --config /app/config.yaml --detailed_debug +``` + # Enterprise For companies that need better security, user management and professional support diff --git a/deploy/charts/litellm-helm/templates/migrations-job.yaml b/deploy/charts/litellm-helm/templates/migrations-job.yaml index fc1aacf16..010d2d1b5 100644 --- a/deploy/charts/litellm-helm/templates/migrations-job.yaml +++ b/deploy/charts/litellm-helm/templates/migrations-job.yaml @@ -13,18 +13,18 @@ spec: spec: containers: - name: prisma-migrations - image: "ghcr.io/berriai/litellm:main-stable" + image: ghcr.io/berriai/litellm-database:main-latest command: ["python", "litellm/proxy/prisma_migration.py"] workingDir: "/app" env: - {{- if .Values.db.deployStandalone }} - - name: DATABASE_URL - value: postgresql://{{ .Values.postgresql.auth.username }}:{{ .Values.postgresql.auth.password }}@{{ .Release.Name }}-postgresql/{{ .Values.postgresql.auth.database }} - {{- else if .Values.db.useExisting }} + {{- if .Values.db.useExisting }} - name: DATABASE_URL value: {{ .Values.db.url | quote }} + {{- else }} + - name: DATABASE_URL + value: postgresql://{{ .Values.postgresql.auth.username }}:{{ .Values.postgresql.auth.password }}@{{ .Release.Name }}-postgresql/{{ .Values.postgresql.auth.database }} {{- end }} - name: DISABLE_SCHEMA_UPDATE - value: "{{ .Values.migrationJob.disableSchemaUpdate }}" + value: "false" # always run the migration from the Helm PreSync hook, overriding the value set in values.yaml restartPolicy: OnFailure backoffLimit: {{ .Values.migrationJob.backoffLimit }} diff --git a/docs/my-website/docs/completion/json_mode.md b/docs/my-website/docs/completion/json_mode.md index a782bfb0a..51f76b7a6 100644 --- a/docs/my-website/docs/completion/json_mode.md +++ b/docs/my-website/docs/completion/json_mode.md @@ -75,6 +75,7 @@ Works for: - Google AI Studio - Gemini models - Vertex AI models (Gemini + Anthropic) - Bedrock Models +- Anthropic API Models diff --git a/docs/my-website/docs/completion/prefix.md b/docs/my-website/docs/completion/prefix.md index e3619a2a0..d413ad989 100644 --- a/docs/my-website/docs/completion/prefix.md +++ b/docs/my-website/docs/completion/prefix.md @@ -93,7 +93,7 @@ curl http://0.0.0.0:4000/v1/chat/completions \ ## Check Model Support -Call `litellm.get_model_info` to check if a model/provider supports `response_format`. +Call `litellm.get_model_info` to check if a model/provider supports `prefix`. 
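For reference, a minimal check for `prefix` support might look like the sketch below; the `supports_assistant_prefill` field name and the example model are assumptions about how the model-info payload surfaces prefix support, not part of this patch.

```python
# Minimal sketch (assumptions noted above): inspect the model-info payload
# returned by litellm.get_model_info for a prefix/assistant-prefill flag.
from litellm import get_model_info

info = get_model_info(model="deepseek/deepseek-chat")  # example model, assumed to support prefix
if info.get("supports_assistant_prefill"):
    print("prefix (assistant prefill) is supported for this model")
```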
@@ -116,4 +116,4 @@ curl -X GET 'http://0.0.0.0:4000/v1/model/info' \ -H 'Authorization: Bearer $LITELLM_KEY' \ ``` - \ No newline at end of file + diff --git a/docs/my-website/docs/providers/anthropic.md b/docs/my-website/docs/providers/anthropic.md index 290e094d0..d4660b807 100644 --- a/docs/my-website/docs/providers/anthropic.md +++ b/docs/my-website/docs/providers/anthropic.md @@ -957,3 +957,69 @@ curl http://0.0.0.0:4000/v1/chat/completions \ ``` + +## Usage - passing 'user_id' to Anthropic + +LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param. + + + + +```python +response = completion( + model="claude-3-5-sonnet-20240620", + messages=messages, + user="user_123", +) +``` + + + +1. Setup config.yaml + +```yaml +model_list: + - model_name: claude-3-5-sonnet-20240620 + litellm_params: + model: anthropic/claude-3-5-sonnet-20240620 + api_key: os.environ/ANTHROPIC_API_KEY +``` + +2. Start Proxy + +``` +litellm --config /path/to/config.yaml +``` + +3. Test it! + +```bash +curl http://0.0.0.0:4000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer " \ + -d '{ + "model": "claude-3-5-sonnet-20240620", + "messages": [{"role": "user", "content": "What is Anthropic?"}], + "user": "user_123" + }' +``` + + + + +## All Supported OpenAI Params + +``` +"stream", +"stop", +"temperature", +"top_p", +"max_tokens", +"max_completion_tokens", +"tools", +"tool_choice", +"extra_headers", +"parallel_tool_calls", +"response_format", +"user" +``` \ No newline at end of file diff --git a/docs/my-website/docs/providers/huggingface.md b/docs/my-website/docs/providers/huggingface.md index 4620a6c5d..5297a688b 100644 --- a/docs/my-website/docs/providers/huggingface.md +++ b/docs/my-website/docs/providers/huggingface.md @@ -37,7 +37,7 @@ os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key" messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}] # e.g. 
Call 'https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct' from Serverless Inference API -response = litellm.completion( +response = completion( model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct", messages=[{ "content": "Hello, how are you?","role": "user"}], stream=True @@ -165,14 +165,14 @@ Steps to use ```python import os -import litellm +from litellm import completion os.environ["HUGGINGFACE_API_KEY"] = "" # TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b # add the 'huggingface/' prefix to the model to set huggingface as the provider # set api base to your deployed api endpoint from hugging face -response = litellm.completion( +response = completion( model="huggingface/glaiveai/glaive-coder-7b", messages=[{ "content": "Hello, how are you?","role": "user"}], api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud" @@ -383,6 +383,8 @@ def default_pt(messages): #### Custom prompt templates ```python +import litellm + # Create your own custom prompt template works litellm.register_prompt_template( model="togethercomputer/LLaMA-2-7B-32K", diff --git a/docs/my-website/docs/providers/jina_ai.md b/docs/my-website/docs/providers/jina_ai.md index 499cf6709..6c13dbf1a 100644 --- a/docs/my-website/docs/providers/jina_ai.md +++ b/docs/my-website/docs/providers/jina_ai.md @@ -1,6 +1,13 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + # Jina AI https://jina.ai/embeddings/ +Supported endpoints: +- /embeddings +- /rerank + ## API Key ```python # env variable @@ -8,6 +15,10 @@ os.environ['JINA_AI_API_KEY'] ``` ## Sample Usage - Embedding + + + + ```python from litellm import embedding import os @@ -19,6 +30,142 @@ response = embedding( ) print(response) ``` + + + +1. Add to config.yaml +```yaml +model_list: + - model_name: embedding-model + litellm_params: + model: jina_ai/jina-embeddings-v3 + api_key: os.environ/JINA_AI_API_KEY +``` + +2. Start proxy + +```bash +litellm --config /path/to/config.yaml + +# RUNNING on http://0.0.0.0:4000/ +``` + +3. Test it! + +```bash +curl -L -X POST 'http://0.0.0.0:4000/embeddings' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{"input": ["hello world"], "model": "embedding-model"}' +``` + + + + +## Sample Usage - Rerank + + + + +```python +from litellm import rerank +import os + +os.environ["JINA_AI_API_KEY"] = "sk-..." + +query = "What is the capital of the United States?" +documents = [ + "Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Washington, D.C. is the capital of the United States.", + "Capital punishment has existed in the United States since before it was a country.", +] + +response = rerank( + model="jina_ai/jina-reranker-v2-base-multilingual", + query=query, + documents=documents, + top_n=3, +) +print(response) +``` + + + +1. Add to config.yaml +```yaml +model_list: + - model_name: rerank-model + litellm_params: + model: jina_ai/jina-reranker-v2-base-multilingual + api_key: os.environ/JINA_AI_API_KEY +``` + +2. Start proxy + +```bash +litellm --config /path/to/config.yaml +``` + +3. Test it! 
+ +```bash +curl -L -X POST 'http://0.0.0.0:4000/rerank' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "model": "rerank-model", + "query": "What is the capital of the United States?", + "documents": [ + "Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Washington, D.C. is the capital of the United States.", + "Capital punishment has existed in the United States since before it was a country." + ], + "top_n": 3 +}' +``` + + + ## Supported Models All models listed here https://jina.ai/embeddings/ are supported + +## Supported Optional Rerank Parameters + +All cohere rerank parameters are supported. + +## Supported Optional Embeddings Parameters + +``` +dimensions +``` + +## Provider-specific parameters + +Pass any jina ai specific parameters as a keyword argument to the `embedding` or `rerank` function, e.g. + + + + +```python +response = embedding( + model="jina_ai/jina-embeddings-v3", + input=["good morning from litellm"], + dimensions=1536, + my_custom_param="my_custom_value", # any other jina ai specific parameters +) +``` + + + +```bash +curl -L -X POST 'http://0.0.0.0:4000/embeddings' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{"input": ["good morning from litellm"], "model": "jina_ai/jina-embeddings-v3", "dimensions": 1536, "my_custom_param": "my_custom_value"}' +``` + + + diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index b69e8ee56..921db9e73 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -1562,6 +1562,10 @@ curl http://0.0.0.0:4000/v1/chat/completions \ ## **Embedding Models** #### Usage - Embedding + + + + ```python import litellm from litellm import embedding @@ -1574,6 +1578,49 @@ response = embedding( ) print(response) ``` + + + + + +1. Add model to config.yaml +```yaml +model_list: + - model_name: snowflake-arctic-embed-m-long-1731622468876 + litellm_params: + model: vertex_ai/ + vertex_project: "adroit-crow-413218" + vertex_location: "us-central1" + vertex_credentials: adroit-crow-413218-a956eef1a2a8.json + +litellm_settings: + drop_params: True +``` + +2. Start Proxy + +``` +$ litellm --config /path/to/config.yaml +``` + +3. 
Make Request using OpenAI Python SDK, Langchain Python SDK + +```python +import openai + +client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000") + +response = client.embeddings.create( + model="snowflake-arctic-embed-m-long-1731622468876", + input = ["good morning from litellm", "this is another item"], +) + +print(response) +``` + + + + #### Supported Embedding Models All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported @@ -1589,6 +1636,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02 | textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` | | text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` | | text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` | +| Fine-tuned OR Custom Embedding models | `embedding(model="vertex_ai/", input)` | ### Supported OpenAI (Unified) Params diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index c6b9f2d45..888f424b4 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -791,9 +791,9 @@ general_settings: | store_model_in_db | boolean | If true, allows `/model/new` endpoint to store model information in db. Endpoint disabled by default. [Doc on `/model/new` endpoint](./model_management.md#create-a-new-model) | | max_request_size_mb | int | The maximum size for requests in MB. Requests above this size will be rejected. | | max_response_size_mb | int | The maximum size for responses in MB. LLM Responses above this size will not be sent. | -| proxy_budget_rescheduler_min_time | int | The minimum time (in seconds) to wait before checking db for budget resets. | -| proxy_budget_rescheduler_max_time | int | The maximum time (in seconds) to wait before checking db for budget resets. | -| proxy_batch_write_at | int | Time (in seconds) to wait before batch writing spend logs to the db. | +| proxy_budget_rescheduler_min_time | int | The minimum time (in seconds) to wait before checking db for budget resets. **Default is 597 seconds** | +| proxy_budget_rescheduler_max_time | int | The maximum time (in seconds) to wait before checking db for budget resets. **Default is 605 seconds** | +| proxy_batch_write_at | int | Time (in seconds) to wait before batch writing spend logs to the db. **Default is 10 seconds** | | alerting_args | dict | Args for Slack Alerting [Doc on Slack Alerting](./alerting.md) | | custom_key_generate | str | Custom function for key generation [Doc on custom key generation](./virtual_keys.md#custom--key-generate) | | allowed_ips | List[str] | List of IPs allowed to access the proxy. If not set, all IPs are allowed. | diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md index 5867a8f23..1bd1b6c4b 100644 --- a/docs/my-website/docs/proxy/logging.md +++ b/docs/my-website/docs/proxy/logging.md @@ -66,10 +66,16 @@ Removes any field with `user_api_key_*` from metadata. Found under `kwargs["standard_logging_object"]`. This is a standard payload, logged for every response. ```python + class StandardLoggingPayload(TypedDict): id: str + trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. 
fallbacks/retries) call_type: str response_cost: float + response_cost_failure_debug_info: Optional[ + StandardLoggingModelCostFailureDebugInformation + ] + status: StandardLoggingPayloadStatus total_tokens: int prompt_tokens: int completion_tokens: int @@ -84,13 +90,13 @@ class StandardLoggingPayload(TypedDict): metadata: StandardLoggingMetadata cache_hit: Optional[bool] cache_key: Optional[str] - saved_cache_cost: Optional[float] - request_tags: list + saved_cache_cost: float + request_tags: list end_user: Optional[str] - requester_ip_address: Optional[str] # IP address of requester - requester_metadata: Optional[dict] # metadata passed in request in the "metadata" field + requester_ip_address: Optional[str] messages: Optional[Union[str, list, dict]] response: Optional[Union[str, list, dict]] + error_str: Optional[str] model_parameters: dict hidden_params: StandardLoggingHiddenParams @@ -99,12 +105,47 @@ class StandardLoggingHiddenParams(TypedDict): cache_key: Optional[str] api_base: Optional[str] response_cost: Optional[str] - additional_headers: Optional[dict] + additional_headers: Optional[StandardLoggingAdditionalHeaders] +class StandardLoggingAdditionalHeaders(TypedDict, total=False): + x_ratelimit_limit_requests: int + x_ratelimit_limit_tokens: int + x_ratelimit_remaining_requests: int + x_ratelimit_remaining_tokens: int + +class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata): + """ + Specific metadata k,v pairs logged to integration for easier cost tracking + """ + + spend_logs_metadata: Optional[ + dict + ] # special param to log k,v pairs to spendlogs for a call + requester_ip_address: Optional[str] + requester_metadata: Optional[dict] class StandardLoggingModelInformation(TypedDict): model_map_key: str model_map_value: Optional[ModelInfo] + + +StandardLoggingPayloadStatus = Literal["success", "failure"] + +class StandardLoggingModelCostFailureDebugInformation(TypedDict, total=False): + """ + Debug information, if cost tracking fails. + + Avoid logging sensitive information like response or optional params + """ + + error_str: Required[str] + traceback_str: Required[str] + model: str + cache_hit: Optional[bool] + custom_llm_provider: Optional[str] + base_model: Optional[str] + call_type: str + custom_pricing: Optional[bool] ``` ## Langfuse diff --git a/docs/my-website/docs/proxy/prod.md b/docs/my-website/docs/proxy/prod.md index 66c719e5d..32a6fceee 100644 --- a/docs/my-website/docs/proxy/prod.md +++ b/docs/my-website/docs/proxy/prod.md @@ -1,5 +1,6 @@ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import Image from '@theme/IdealImage'; # ⚡ Best Practices for Production @@ -112,7 +113,35 @@ general_settings: disable_spend_logs: True ``` -## 7. Set LiteLLM Salt Key +## 7. Use Helm PreSync Hook for Database Migrations [BETA] + +To ensure only one service manages database migrations, use our [Helm PreSync hook for Database Migrations](https://github.com/BerriAI/litellm/blob/main/deploy/charts/litellm-helm/templates/migrations-job.yaml). This ensures migrations are handled during `helm upgrade` or `helm install`, while LiteLLM pods explicitly disable migrations. + + +1. **Helm PreSync Hook**: + - The Helm PreSync hook is configured in the chart to run database migrations during deployments. + - The hook always sets `DISABLE_SCHEMA_UPDATE=false`, ensuring migrations are executed reliably. 
+ + Reference Settings to set on ArgoCD for `values.yaml` + + ```yaml + db: + useExisting: true # use existing Postgres DB + url: postgresql://ishaanjaffer0324:3rnwpOBau6hT@ep-withered-mud-a5dkdpke.us-east-2.aws.neon.tech/test-argo-cd?sslmode=require # url of existing Postgres DB + ``` + +2. **LiteLLM Pods**: + - Set `DISABLE_SCHEMA_UPDATE=true` in LiteLLM pod configurations to prevent them from running migrations. + + Example configuration for LiteLLM pod: + ```yaml + env: + - name: DISABLE_SCHEMA_UPDATE + value: "true" + ``` + + +## 8. Set LiteLLM Salt Key If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB. diff --git a/docs/my-website/docs/proxy/reliability.md b/docs/my-website/docs/proxy/reliability.md index 9a3ba4ec6..73f25f817 100644 --- a/docs/my-website/docs/proxy/reliability.md +++ b/docs/my-website/docs/proxy/reliability.md @@ -748,4 +748,19 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \ "max_tokens": 300, "mock_testing_fallbacks": true }' +``` + +### Disable Fallbacks per key + +You can disable fallbacks per key by setting `disable_fallbacks: true` in your key metadata. + +```bash +curl -L -X POST 'http://0.0.0.0:4000/key/generate' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "metadata": { + "disable_fallbacks": true + } +}' ``` \ No newline at end of file diff --git a/docs/my-website/docs/rerank.md b/docs/my-website/docs/rerank.md index 8179e6b81..d25b552fb 100644 --- a/docs/my-website/docs/rerank.md +++ b/docs/my-website/docs/rerank.md @@ -113,4 +113,5 @@ curl http://0.0.0.0:4000/rerank \ |-------------|--------------------| | Cohere | [Usage](#quick-start) | | Together AI| [Usage](../docs/providers/togetherai) | -| Azure AI| [Usage](../docs/providers/azure_ai) | \ No newline at end of file +| Azure AI| [Usage](../docs/providers/azure_ai) | +| Jina AI| [Usage](../docs/providers/jina_ai) | \ No newline at end of file diff --git a/docs/my-website/docs/secret.md b/docs/my-website/docs/secret.md index db5ec6910..15480ea3d 100644 --- a/docs/my-website/docs/secret.md +++ b/docs/my-website/docs/secret.md @@ -1,3 +1,6 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + # Secret Manager LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager @@ -59,14 +62,35 @@ os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2 ``` 2. Enable AWS Secret Manager in config. + + + + ```yaml general_settings: master_key: os.environ/litellm_master_key key_management_system: "aws_secret_manager" # 👈 KEY CHANGE key_management_settings: hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS + ``` + + + + +This will only store virtual keys in AWS Secret Manager. No keys will be read from AWS Secret Manager. + +```yaml +general_settings: + key_management_system: "aws_secret_manager" # 👈 KEY CHANGE + key_management_settings: + store_virtual_keys: true + access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"] +``` + + + 3. Run proxy ```bash @@ -181,16 +205,14 @@ litellm --config /path/to/config.yaml Use encrypted keys from Google KMS on the proxy -### Usage with LiteLLM Proxy Server - -## Step 1. Add keys to env +Step 1. Add keys to env ``` export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json" export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*" export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...' 
``` -## Step 2: Update Config +Step 2: Update Config ```yaml general_settings: @@ -199,7 +221,7 @@ general_settings: master_key: sk-1234 ``` -## Step 3: Start + test proxy +Step 3: Start + test proxy ``` $ litellm --config /path/to/config.yaml @@ -215,3 +237,17 @@ $ litellm --test + + +## All Secret Manager Settings + +All settings related to secret management + +```yaml +general_settings: + key_management_system: "aws_secret_manager" # REQUIRED + key_management_settings: + store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager + access_mode: "write_only" # OPTIONAL. Literal["read_only", "write_only", "read_and_write"]. Defaults to "read_only" + hosted_keys: ["litellm_master_key"] # OPTIONAL. Specify which env keys you stored on AWS +``` \ No newline at end of file diff --git a/litellm/__init__.py b/litellm/__init__.py index e54117e11..edfe1a336 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -305,7 +305,7 @@ secret_manager_client: Optional[Any] = ( ) _google_kms_resource_name: Optional[str] = None _key_management_system: Optional[KeyManagementSystem] = None -_key_management_settings: Optional[KeyManagementSettings] = None +_key_management_settings: KeyManagementSettings = KeyManagementSettings() #### PII MASKING #### output_parse_pii: bool = False ############################################# @@ -962,6 +962,8 @@ from .utils import ( supports_response_schema, supports_parallel_function_calling, supports_vision, + supports_audio_input, + supports_audio_output, supports_system_messages, get_litellm_params, acreate, diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 0aa8a8e36..50bed6fe9 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -46,6 +46,9 @@ from litellm.llms.OpenAI.cost_calculation import ( from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token from litellm.llms.OpenAI.cost_calculation import cost_router as openai_cost_router from litellm.llms.together_ai.cost_calculator import get_model_params_and_category +from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.cost_calculator import ( + cost_calculator as vertex_ai_image_cost_calculator, +) from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.rerank import RerankResponse from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS @@ -667,9 +670,11 @@ def completion_cost( # noqa: PLR0915 ): ### IMAGE GENERATION COST CALCULATION ### if custom_llm_provider == "vertex_ai": - # https://cloud.google.com/vertex-ai/generative-ai/pricing - # Vertex Charges Flat $0.20 per image - return 0.020 + if isinstance(completion_response, ImageResponse): + return vertex_ai_image_cost_calculator( + model=model, + image_response=completion_response, + ) elif custom_llm_provider == "bedrock": if isinstance(completion_response, ImageResponse): return bedrock_image_cost_calculator( diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py index ca1de75be..3fb276611 100644 --- a/litellm/litellm_core_utils/exception_mapping_utils.py +++ b/litellm/litellm_core_utils/exception_mapping_utils.py @@ -239,7 +239,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ContextWindowExceededError: {exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), 
litellm_debug_info=extra_information, ) elif ( @@ -251,7 +251,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif "A timeout occurred" in error_str: @@ -271,7 +271,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ContentPolicyViolationError: {exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif ( @@ -283,7 +283,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif "Web server is returning an unknown error" in error_str: @@ -299,7 +299,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"RateLimitError: {exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif ( @@ -311,7 +311,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AuthenticationError: {exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif "Mistral API raised a streaming error" in error_str: @@ -335,7 +335,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"{exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 401: @@ -344,7 +344,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AuthenticationError: {exception_provider} - {message}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 404: @@ -353,7 +353,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"NotFoundError: {exception_provider} - {message}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 408: @@ -516,7 +516,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ReplicateException - {error_str}", llm_provider="replicate", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "input is too long" in error_str: exception_mapping_worked = True @@ -524,7 +524,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ReplicateException - {error_str}", model=model, llm_provider="replicate", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif exception_type == "ModelError": exception_mapping_worked = True @@ -532,7 +532,7 @@ def exception_type( # 
type: ignore # noqa: PLR0915 message=f"ReplicateException - {error_str}", model=model, llm_provider="replicate", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "Request was throttled" in error_str: exception_mapping_worked = True @@ -540,7 +540,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ReplicateException - {error_str}", llm_provider="replicate", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -549,7 +549,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( original_exception.status_code == 400 @@ -560,7 +560,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ReplicateException - {original_exception.message}", model=model, llm_provider="replicate", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 422: exception_mapping_worked = True @@ -568,7 +568,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ReplicateException - {original_exception.message}", model=model, llm_provider="replicate", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -583,7 +583,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -591,7 +591,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -599,7 +599,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) exception_mapping_worked = True raise APIError( @@ -631,7 +631,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"{custom_llm_provider}Exception: Authentication Error - {error_str}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif "token_quota_reached" in error_str: @@ -640,7 +640,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"{custom_llm_provider}Exception: Rate Limit Errror - {error_str}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( "The server received an invalid response from an upstream server." @@ -750,7 +750,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException - {error_str}\n. 
Enable 'litellm.modify_params=True' (for PROXY do: `litellm_settings::modify_params: True`) to insert a dummy assistant message and fix this error.", model=model, llm_provider="bedrock", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "Malformed input request" in error_str: exception_mapping_worked = True @@ -758,7 +758,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException - {error_str}", model=model, llm_provider="bedrock", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "A conversation must start with a user message." in error_str: exception_mapping_worked = True @@ -766,7 +766,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`", model=model, llm_provider="bedrock", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( "Unable to locate credentials" in error_str @@ -778,7 +778,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException Invalid Authentication - {error_str}", model=model, llm_provider="bedrock", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "AccessDeniedException" in error_str: exception_mapping_worked = True @@ -786,7 +786,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException PermissionDeniedError - {error_str}", model=model, llm_provider="bedrock", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( "throttlingException" in error_str @@ -797,7 +797,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException: Rate Limit Error - {error_str}", model=model, llm_provider="bedrock", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( "Connect timeout on endpoint URL" in error_str @@ -836,7 +836,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 400: exception_mapping_worked = True @@ -844,7 +844,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 404: exception_mapping_worked = True @@ -852,7 +852,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -868,7 +868,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException - {original_exception.message}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), 
litellm_debug_info=extra_information, ) elif original_exception.status_code == 429: @@ -877,7 +877,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException - {original_exception.message}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 503: @@ -886,7 +886,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BedrockException - {original_exception.message}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 504: # gateway timeout error @@ -907,7 +907,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"litellm.BadRequestError: SagemakerException - {error_str}", model=model, llm_provider="sagemaker", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( "Input validation error: `best_of` must be > 0 and <= 2" @@ -918,7 +918,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message="SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", model=model, llm_provider="sagemaker", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( "`inputs` tokens + `max_new_tokens` must be <=" in error_str @@ -929,7 +929,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"SagemakerException - {error_str}", model=model, llm_provider="sagemaker", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif hasattr(original_exception, "status_code"): if original_exception.status_code == 500: @@ -951,7 +951,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"SagemakerException - {original_exception.message}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 400: exception_mapping_worked = True @@ -959,7 +959,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"SagemakerException - {original_exception.message}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 404: exception_mapping_worked = True @@ -967,7 +967,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"SagemakerException - {original_exception.message}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -986,7 +986,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"SagemakerException - {original_exception.message}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 429: @@ -995,7 +995,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"SagemakerException - {original_exception.message}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + 
response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 503: @@ -1004,7 +1004,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"SagemakerException - {original_exception.message}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 504: # gateway timeout error @@ -1217,7 +1217,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message="GeminiException - Invalid api key", model=model, llm_provider="palm", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) if ( "504 Deadline expired before operation could complete." in error_str @@ -1235,7 +1235,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"GeminiException - {error_str}", model=model, llm_provider="palm", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) if ( "500 An internal error has occurred." in error_str @@ -1262,7 +1262,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"GeminiException - {error_str}", model=model, llm_provider="palm", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) # Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes elif custom_llm_provider == "cloudflare": @@ -1272,7 +1272,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"Cloudflare Exception - {original_exception.message}", llm_provider="cloudflare", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) if "must have required property" in error_str: exception_mapping_worked = True @@ -1280,7 +1280,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"Cloudflare Exception - {original_exception.message}", llm_provider="cloudflare", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat" @@ -1294,7 +1294,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "too many tokens" in error_str: exception_mapping_worked = True @@ -1302,7 +1302,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"CohereException - {original_exception.message}", model=model, llm_provider="cohere", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif hasattr(original_exception, "status_code"): if ( @@ -1314,7 +1314,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -1329,7 +1329,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( "CohereConnectionError" in 
exception_type @@ -1339,7 +1339,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "invalid type:" in error_str: exception_mapping_worked = True @@ -1347,7 +1347,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "Unexpected server error" in error_str: exception_mapping_worked = True @@ -1355,7 +1355,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) else: if hasattr(original_exception, "status_code"): @@ -1375,7 +1375,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=error_str, model=model, llm_provider="huggingface", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "A valid user token is required" in error_str: exception_mapping_worked = True @@ -1383,7 +1383,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=error_str, llm_provider="huggingface", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "Rate limit reached" in error_str: exception_mapping_worked = True @@ -1391,7 +1391,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=error_str, llm_provider="huggingface", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -1400,7 +1400,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 400: exception_mapping_worked = True @@ -1408,7 +1408,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"HuggingfaceException - {original_exception.message}", model=model, llm_provider="huggingface", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -1423,7 +1423,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 503: exception_mapping_worked = True @@ -1431,7 +1431,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) else: exception_mapping_worked = True @@ -1450,7 +1450,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response, + 
response=getattr(original_exception, "response", None), ) if "Bad or missing API token." in original_exception.message: exception_mapping_worked = True @@ -1458,7 +1458,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -1467,7 +1467,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -1482,7 +1482,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -1490,7 +1490,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) else: exception_mapping_worked = True @@ -1509,7 +1509,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "value is not a valid" in error_str: exception_mapping_worked = True @@ -1517,7 +1517,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) else: exception_mapping_worked = True @@ -1542,7 +1542,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( original_exception.status_code == 401 @@ -1553,7 +1553,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( original_exception.status_code == 522 @@ -1574,7 +1574,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( original_exception.status_code == 500 @@ -1597,7 +1597,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"NLPCloudException - {original_exception.message}", model=model, llm_provider="nlp_cloud", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) else: exception_mapping_worked = True @@ -1623,7 +1623,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - 
response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( "error" in error_response @@ -1634,7 +1634,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"TogetherAIException - {error_response['error']}", llm_provider="together_ai", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( "error" in error_response @@ -1645,7 +1645,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "A timeout occurred" in error_str: exception_mapping_worked = True @@ -1664,7 +1664,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif ( "error_type" in error_response @@ -1675,7 +1675,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 408: @@ -1691,7 +1691,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -1699,7 +1699,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"TogetherAIException - {original_exception.message}", llm_provider="together_ai", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 524: exception_mapping_worked = True @@ -1727,7 +1727,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "InvalidToken" in error_str or "No token provided" in error_str: exception_mapping_worked = True @@ -1735,7 +1735,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif hasattr(original_exception, "status_code"): verbose_logger.debug( @@ -1754,7 +1754,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -1762,7 +1762,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code 
== 500: exception_mapping_worked = True @@ -1770,7 +1770,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) raise original_exception raise original_exception @@ -1787,7 +1787,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}", model=model, llm_provider="ollama", - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "Failed to establish a new connection" in error_str: exception_mapping_worked = True @@ -1795,7 +1795,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"OllamaException: {original_exception}", llm_provider="ollama", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "Invalid response object from API" in error_str: exception_mapping_worked = True @@ -1803,7 +1803,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"OllamaException: {original_exception}", llm_provider="ollama", model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), ) elif "Read timed out" in error_str: exception_mapping_worked = True @@ -1837,6 +1837,7 @@ def exception_type( # type: ignore # noqa: PLR0915 llm_provider="azure", model=model, litellm_debug_info=extra_information, + response=getattr(original_exception, "response", None), ) elif "This model's maximum context length is" in error_str: exception_mapping_worked = True @@ -1845,6 +1846,7 @@ def exception_type( # type: ignore # noqa: PLR0915 llm_provider="azure", model=model, litellm_debug_info=extra_information, + response=getattr(original_exception, "response", None), ) elif "DeploymentNotFound" in error_str: exception_mapping_worked = True @@ -1853,6 +1855,7 @@ def exception_type( # type: ignore # noqa: PLR0915 llm_provider="azure", model=model, litellm_debug_info=extra_information, + response=getattr(original_exception, "response", None), ) elif ( ( @@ -1873,6 +1876,7 @@ def exception_type( # type: ignore # noqa: PLR0915 llm_provider="azure", model=model, litellm_debug_info=extra_information, + response=getattr(original_exception, "response", None), ) elif "invalid_request_error" in error_str: exception_mapping_worked = True @@ -1881,6 +1885,7 @@ def exception_type( # type: ignore # noqa: PLR0915 llm_provider="azure", model=model, litellm_debug_info=extra_information, + response=getattr(original_exception, "response", None), ) elif ( "The api_key client option must be set either by passing api_key to the client or by setting" @@ -1892,6 +1897,7 @@ def exception_type( # type: ignore # noqa: PLR0915 llm_provider=custom_llm_provider, model=model, litellm_debug_info=extra_information, + response=getattr(original_exception, "response", None), ) elif "Connection error" in error_str: exception_mapping_worked = True @@ -1910,6 +1916,7 @@ def exception_type( # type: ignore # noqa: PLR0915 llm_provider="azure", model=model, litellm_debug_info=extra_information, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 401: exception_mapping_worked = True @@ -1918,6 +1925,7 @@ def exception_type( # type: ignore # noqa: PLR0915 llm_provider="azure", model=model, litellm_debug_info=extra_information, + response=getattr(original_exception, 
"response", None), ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -1934,6 +1942,7 @@ def exception_type( # type: ignore # noqa: PLR0915 model=model, llm_provider="azure", litellm_debug_info=extra_information, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -1942,6 +1951,7 @@ def exception_type( # type: ignore # noqa: PLR0915 model=model, llm_provider="azure", litellm_debug_info=extra_information, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 503: exception_mapping_worked = True @@ -1950,6 +1960,7 @@ def exception_type( # type: ignore # noqa: PLR0915 model=model, llm_provider="azure", litellm_debug_info=extra_information, + response=getattr(original_exception, "response", None), ) elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True @@ -1989,7 +2000,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"{exception_provider} - {error_str}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 401: @@ -1998,7 +2009,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"AuthenticationError: {exception_provider} - {error_str}", llm_provider=custom_llm_provider, model=model, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 404: @@ -2007,7 +2018,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"NotFoundError: {exception_provider} - {error_str}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 408: @@ -2024,7 +2035,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"BadRequestError: {exception_provider} - {error_str}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 429: @@ -2033,7 +2044,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"RateLimitError: {exception_provider} - {error_str}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 503: @@ -2042,7 +2053,7 @@ def exception_type( # type: ignore # noqa: PLR0915 message=f"ServiceUnavailableError: {exception_provider} - {error_str}", model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=getattr(original_exception, "response", None), litellm_debug_info=extra_information, ) elif original_exception.status_code == 504: # gateway timeout error diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 66f91abf1..69d6adca4 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -202,6 +202,7 @@ class Logging: start_time, litellm_call_id: str, function_id: str, + litellm_trace_id: Optional[str] = None, 
dynamic_input_callbacks: Optional[ List[Union[str, Callable, CustomLogger]] ] = None, @@ -239,6 +240,7 @@ class Logging: self.start_time = start_time # log the call start time self.call_type = call_type self.litellm_call_id = litellm_call_id + self.litellm_trace_id = litellm_trace_id self.function_id = function_id self.streaming_chunks: List[Any] = [] # for generating complete stream response self.sync_streaming_chunks: List[Any] = ( @@ -275,6 +277,11 @@ class Logging: self.completion_start_time: Optional[datetime.datetime] = None self._llm_caching_handler: Optional[LLMCachingHandler] = None + self.model_call_details = { + "litellm_trace_id": litellm_trace_id, + "litellm_call_id": litellm_call_id, + } + def process_dynamic_callbacks(self): """ Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks @@ -382,21 +389,23 @@ class Logging: self.logger_fn = litellm_params.get("logger_fn", None) verbose_logger.debug(f"self.optional_params: {self.optional_params}") - self.model_call_details = { - "model": self.model, - "messages": self.messages, - "optional_params": self.optional_params, - "litellm_params": self.litellm_params, - "start_time": self.start_time, - "stream": self.stream, - "user": user, - "call_type": str(self.call_type), - "litellm_call_id": self.litellm_call_id, - "completion_start_time": self.completion_start_time, - "standard_callback_dynamic_params": self.standard_callback_dynamic_params, - **self.optional_params, - **additional_params, - } + self.model_call_details.update( + { + "model": self.model, + "messages": self.messages, + "optional_params": self.optional_params, + "litellm_params": self.litellm_params, + "start_time": self.start_time, + "stream": self.stream, + "user": user, + "call_type": str(self.call_type), + "litellm_call_id": self.litellm_call_id, + "completion_start_time": self.completion_start_time, + "standard_callback_dynamic_params": self.standard_callback_dynamic_params, + **self.optional_params, + **additional_params, + } + ) ## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation if "stream_options" in additional_params: @@ -2823,6 +2832,7 @@ def get_standard_logging_object_payload( payload: StandardLoggingPayload = StandardLoggingPayload( id=str(id), + trace_id=kwargs.get("litellm_trace_id"), # type: ignore call_type=call_type or "", cache_hit=cache_hit, status=status, diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 2d119a28f..2952d54d5 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -44,7 +44,9 @@ from litellm.types.llms.openai import ( ChatCompletionToolCallFunctionChunk, ChatCompletionUsageBlock, ) -from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper +from litellm.types.utils import GenericStreamingChunk +from litellm.types.utils import Message as LitellmMessage +from litellm.types.utils import PromptTokensDetailsWrapper from litellm.utils import CustomStreamWrapper, ModelResponse, Usage from ...base import BaseLLM @@ -94,6 +96,7 @@ async def make_call( messages: list, logging_obj, timeout: Optional[Union[float, httpx.Timeout]], + json_mode: bool, ) -> Tuple[Any, httpx.Headers]: if client is None: client = litellm.module_level_aclient @@ -119,7 +122,9 @@ async def make_call( raise AnthropicError(status_code=500, message=str(e)) completion_stream = ModelResponseIterator( - streaming_response=response.aiter_lines(), sync_stream=False + 
streaming_response=response.aiter_lines(), + sync_stream=False, + json_mode=json_mode, ) # LOGGING @@ -142,6 +147,7 @@ def make_sync_call( messages: list, logging_obj, timeout: Optional[Union[float, httpx.Timeout]], + json_mode: bool, ) -> Tuple[Any, httpx.Headers]: if client is None: client = litellm.module_level_client # re-use a module level client @@ -175,7 +181,7 @@ def make_sync_call( ) completion_stream = ModelResponseIterator( - streaming_response=response.iter_lines(), sync_stream=True + streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode ) # LOGGING @@ -270,11 +276,12 @@ class AnthropicChatCompletion(BaseLLM): "arguments" ) if json_mode_content_str is not None: - args = json.loads(json_mode_content_str) - values: Optional[dict] = args.get("values") - if values is not None: - _message = litellm.Message(content=json.dumps(values)) + _converted_message = self._convert_tool_response_to_message( + tool_calls=tool_calls, + ) + if _converted_message is not None: completion_response["stop_reason"] = "stop" + _message = _converted_message model_response.choices[0].message = _message # type: ignore model_response._hidden_params["original_response"] = completion_response[ "content" @@ -318,6 +325,37 @@ class AnthropicChatCompletion(BaseLLM): model_response._hidden_params = _hidden_params return model_response + @staticmethod + def _convert_tool_response_to_message( + tool_calls: List[ChatCompletionToolCallChunk], + ) -> Optional[LitellmMessage]: + """ + In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format + + """ + ## HANDLE JSON MODE - anthropic returns single function call + json_mode_content_str: Optional[str] = tool_calls[0]["function"].get( + "arguments" + ) + try: + if json_mode_content_str is not None: + args = json.loads(json_mode_content_str) + if ( + isinstance(args, dict) + and (values := args.get("values")) is not None + ): + _message = litellm.Message(content=json.dumps(values)) + return _message + else: + # a lot of the times the `values` key is not present in the tool response + # relevant issue: https://github.com/BerriAI/litellm/issues/6741 + _message = litellm.Message(content=json.dumps(args)) + return _message + except json.JSONDecodeError: + # json decode error does occur, return the original tool response str + return litellm.Message(content=json_mode_content_str) + return None + async def acompletion_stream_function( self, model: str, @@ -334,6 +372,7 @@ class AnthropicChatCompletion(BaseLLM): stream, _is_function_call, data: dict, + json_mode: bool, optional_params=None, litellm_params=None, logger_fn=None, @@ -350,6 +389,7 @@ class AnthropicChatCompletion(BaseLLM): messages=messages, logging_obj=logging_obj, timeout=timeout, + json_mode=json_mode, ) streamwrapper = CustomStreamWrapper( completion_stream=completion_stream, @@ -440,8 +480,8 @@ class AnthropicChatCompletion(BaseLLM): logging_obj, optional_params: dict, timeout: Union[float, httpx.Timeout], + litellm_params: dict, acompletion=None, - litellm_params=None, logger_fn=None, headers={}, client=None, @@ -464,6 +504,7 @@ class AnthropicChatCompletion(BaseLLM): model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, headers=headers, _is_function_call=_is_function_call, is_vertex_request=is_vertex_request, @@ -500,6 +541,7 @@ class AnthropicChatCompletion(BaseLLM): optional_params=optional_params, stream=stream, _is_function_call=_is_function_call, + json_mode=json_mode, 
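Reviewer note: the `_convert_tool_response_to_message` helper and the `json_mode` flag threaded through the streaming path are what let an OpenAI-style JSON-mode request come back as plain message content. A hedged end-to-end sketch, assuming `ANTHROPIC_API_KEY` is exported; the exact tool payload Anthropic returns stays an internal detail:

```python
import json

import litellm

# Assumes ANTHROPIC_API_KEY is set in the environment; the model name matches the
# one used elsewhere in this PR's config examples.
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[
        {
            "role": "user",
            "content": "Return a JSON object with keys 'city' and 'country' for Paris.",
        }
    ],
    response_format={"type": "json_object"},  # mapped to a forced tool call on the Anthropic side
)

# With the conversion above, the JSON lands in message.content even though
# Anthropic delivered it as a tool call under the hood.
print(json.loads(response.choices[0].message.content))
```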
litellm_params=litellm_params, logger_fn=logger_fn, headers=headers, @@ -547,6 +589,7 @@ class AnthropicChatCompletion(BaseLLM): messages=messages, logging_obj=logging_obj, timeout=timeout, + json_mode=json_mode, ) return CustomStreamWrapper( completion_stream=completion_stream, @@ -605,11 +648,14 @@ class AnthropicChatCompletion(BaseLLM): class ModelResponseIterator: - def __init__(self, streaming_response, sync_stream: bool): + def __init__( + self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False + ): self.streaming_response = streaming_response self.response_iterator = self.streaming_response self.content_blocks: List[ContentBlockDelta] = [] self.tool_index = -1 + self.json_mode = json_mode def check_empty_tool_call_args(self) -> bool: """ @@ -771,6 +817,8 @@ class ModelResponseIterator: status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500 ) + text, tool_use = self._handle_json_mode_chunk(text=text, tool_use=tool_use) + returned_chunk = GenericStreamingChunk( text=text, tool_use=tool_use, @@ -785,6 +833,34 @@ class ModelResponseIterator: except json.JSONDecodeError: raise ValueError(f"Failed to decode JSON from chunk: {chunk}") + def _handle_json_mode_chunk( + self, text: str, tool_use: Optional[ChatCompletionToolCallChunk] + ) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]: + """ + If JSON mode is enabled, convert the tool call to a message. + + Anthropic returns the JSON schema as part of the tool call + OpenAI returns the JSON schema as part of the content, this handles placing it in the content + + Args: + text: str + tool_use: Optional[ChatCompletionToolCallChunk] + Returns: + Tuple[str, Optional[ChatCompletionToolCallChunk]] + + text: The text to use in the content + tool_use: The ChatCompletionToolCallChunk to use in the chunk response + """ + if self.json_mode is True and tool_use is not None: + message = AnthropicChatCompletion._convert_tool_response_to_message( + tool_calls=[tool_use] + ) + if message is not None: + text = message.content or "" + tool_use = None + + return text, tool_use + # Sync iterator def __iter__(self): return self diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index e222d8721..28bd8d86f 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -91,6 +91,7 @@ class AnthropicConfig: "extra_headers", "parallel_tool_calls", "response_format", + "user", ] def get_cache_control_headers(self) -> dict: @@ -246,6 +247,28 @@ class AnthropicConfig: anthropic_tools.append(new_tool) return anthropic_tools + def _map_stop_sequences( + self, stop: Optional[Union[str, List[str]]] + ) -> Optional[List[str]]: + new_stop: Optional[List[str]] = None + if isinstance(stop, str): + if ( + stop == "\n" + ) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences + return new_stop + new_stop = [stop] + elif isinstance(stop, list): + new_v = [] + for v in stop: + if ( + v == "\n" + ) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences + continue + new_v.append(v) + if len(new_v) > 0: + new_stop = new_v + return new_stop + def map_openai_params( self, non_default_params: dict, @@ -271,26 +294,10 @@ class AnthropicConfig: optional_params["tool_choice"] = _tool_choice if param == "stream" and value is True: optional_params["stream"] = value - if param == "stop": - if isinstance(value, str): - 
if ( - value == "\n" - ) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences - continue - value = [value] - elif isinstance(value, list): - new_v = [] - for v in value: - if ( - v == "\n" - ) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences - continue - new_v.append(v) - if len(new_v) > 0: - value = new_v - else: - continue - optional_params["stop_sequences"] = value + if param == "stop" and (isinstance(value, str) or isinstance(value, list)): + _value = self._map_stop_sequences(value) + if _value is not None: + optional_params["stop_sequences"] = _value if param == "temperature": optional_params["temperature"] = value if param == "top_p": @@ -314,7 +321,8 @@ class AnthropicConfig: optional_params["tools"] = [_tool] optional_params["tool_choice"] = _tool_choice optional_params["json_mode"] = True - + if param == "user": + optional_params["metadata"] = {"user_id": value} ## VALIDATE REQUEST """ Anthropic doesn't support tool calling without `tools=` param specified. @@ -465,6 +473,7 @@ class AnthropicConfig: model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, headers: dict, _is_function_call: bool, is_vertex_request: bool, @@ -502,6 +511,15 @@ class AnthropicConfig: if "tools" in optional_params: _is_function_call = True + ## Handle user_id in metadata + _litellm_metadata = litellm_params.get("metadata", None) + if ( + _litellm_metadata + and isinstance(_litellm_metadata, dict) + and "user_id" in _litellm_metadata + ): + optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]} + data = { "messages": anthropic_messages, **optional_params, diff --git a/litellm/llms/bedrock/image/amazon_stability3_transformation.py b/litellm/llms/bedrock/image/amazon_stability3_transformation.py index 784e86b04..2c90b3a12 100644 --- a/litellm/llms/bedrock/image/amazon_stability3_transformation.py +++ b/litellm/llms/bedrock/image/amazon_stability3_transformation.py @@ -53,9 +53,15 @@ class AmazonStability3Config: sd3-medium sd3.5-large sd3.5-large-turbo + + Stability ultra models + stable-image-ultra-v1 """ - if model and ("sd3" in model or "sd3.5" in model): - return True + if model: + if "sd3" in model or "sd3.5" in model: + return True + if "stable-image-ultra-v1" in model: + return True return False @classmethod diff --git a/litellm/llms/custom_httpx/types.py b/litellm/llms/custom_httpx/types.py index dc0958118..8e6ad0eda 100644 --- a/litellm/llms/custom_httpx/types.py +++ b/litellm/llms/custom_httpx/types.py @@ -8,3 +8,4 @@ class httpxSpecialProvider(str, Enum): GuardrailCallback = "guardrail_callback" Caching = "caching" Oauth2Check = "oauth2_check" + SecretManager = "secret_manager" diff --git a/litellm/llms/jina_ai/embedding/transformation.py b/litellm/llms/jina_ai/embedding/transformation.py index 26ff58878..97b7b2cfa 100644 --- a/litellm/llms/jina_ai/embedding/transformation.py +++ b/litellm/llms/jina_ai/embedding/transformation.py @@ -76,4 +76,4 @@ class JinaAIEmbeddingConfig: or get_secret_str("JINA_AI_API_KEY") or get_secret_str("JINA_AI_TOKEN") ) - return LlmProviders.OPENAI_LIKE.value, api_base, dynamic_api_key + return LlmProviders.JINA_AI.value, api_base, dynamic_api_key diff --git a/litellm/llms/jina_ai/rerank/handler.py b/litellm/llms/jina_ai/rerank/handler.py new file mode 100644 index 000000000..a2cfdd49e --- /dev/null +++ b/litellm/llms/jina_ai/rerank/handler.py @@ -0,0 +1,96 @@ +""" +Re rank api + +LiteLLM supports the re rank API format, 
no paramter transformation occurs +""" + +import uuid +from typing import Any, Dict, List, Optional, Union + +import httpx +from pydantic import BaseModel + +import litellm +from litellm.llms.base import BaseLLM +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, +) +from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig +from litellm.types.rerank import RerankRequest, RerankResponse + + +class JinaAIRerank(BaseLLM): + def rerank( + self, + model: str, + api_key: str, + query: str, + documents: List[Union[str, Dict[str, Any]]], + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + _is_async: Optional[bool] = False, + ) -> RerankResponse: + client = _get_httpx_client() + + request_data = RerankRequest( + model=model, + query=query, + top_n=top_n, + documents=documents, + rank_fields=rank_fields, + return_documents=return_documents, + ) + + # exclude None values from request_data + request_data_dict = request_data.dict(exclude_none=True) + + if _is_async: + return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method + + response = client.post( + "https://api.jina.ai/v1/rerank", + headers={ + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {api_key}", + }, + json=request_data_dict, + ) + + if response.status_code != 200: + raise Exception(response.text) + + _json_response = response.json() + + return JinaAIRerankConfig()._transform_response(_json_response) + + async def async_rerank( # New async method + self, + request_data_dict: Dict[str, Any], + api_key: str, + ) -> RerankResponse: + client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.JINA_AI + ) # Use async client + + response = await client.post( + "https://api.jina.ai/v1/rerank", + headers={ + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {api_key}", + }, + json=request_data_dict, + ) + + if response.status_code != 200: + raise Exception(response.text) + + _json_response = response.json() + + return JinaAIRerankConfig()._transform_response(_json_response) + + pass diff --git a/litellm/llms/jina_ai/rerank/transformation.py b/litellm/llms/jina_ai/rerank/transformation.py new file mode 100644 index 000000000..82039a15b --- /dev/null +++ b/litellm/llms/jina_ai/rerank/transformation.py @@ -0,0 +1,36 @@ +""" +Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format. + +Why separate file? 
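Reviewer note: a minimal sketch of how this new handler is reached from the public API, assuming `JINA_AI_API_KEY` is set; the model name is the one used in the example proxy config later in this diff:

```python
import os

import litellm

# Assumes JINA_AI_API_KEY is exported; the jina_ai branch added to
# litellm/rerank_api/main.py (further down in this diff) resolves the key.
os.environ.setdefault("JINA_AI_API_KEY", "jina_...")

result = litellm.rerank(
    model="jina_ai/jina-reranker-v2-base-multilingual",
    query="What is the capital of France?",
    documents=["Berlin is in Germany.", "Paris is the capital of France."],
    top_n=1,
)

# RerankResponse carries id, results, and meta (billed_units / tokens), as built
# by JinaAIRerankConfig._transform_response.
print(result.results)
```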
Make it easy to see how transformation works + +Docs - https://jina.ai/reranker +""" + +import uuid +from typing import List, Optional + +from litellm.types.rerank import ( + RerankBilledUnits, + RerankResponse, + RerankResponseMeta, + RerankTokens, +) + + +class JinaAIRerankConfig: + def _transform_response(self, response: dict) -> RerankResponse: + + _billed_units = RerankBilledUnits(**response.get("usage", {})) + _tokens = RerankTokens(**response.get("usage", {})) + rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens) + + _results: Optional[List[dict]] = response.get("results") + + if _results is None: + raise ValueError(f"No results found in the response={response}") + + return RerankResponse( + id=response.get("id") or str(uuid.uuid4()), + results=_results, + meta=rerank_meta, + ) # Return response diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index 845d0e2dd..842d946c6 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -185,6 +185,8 @@ class OllamaConfig: "name": "mistral" }' """ + if model.startswith("ollama/") or model.startswith("ollama_chat/"): + model = model.split("/", 1)[1] api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434" try: diff --git a/litellm/llms/together_ai/rerank.py b/litellm/llms/together_ai/rerank/handler.py similarity index 84% rename from litellm/llms/together_ai/rerank.py rename to litellm/llms/together_ai/rerank/handler.py index 1be73af2d..3e6d5d667 100644 --- a/litellm/llms/together_ai/rerank.py +++ b/litellm/llms/together_ai/rerank/handler.py @@ -15,7 +15,14 @@ from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, get_async_httpx_client, ) -from litellm.types.rerank import RerankRequest, RerankResponse +from litellm.llms.together_ai.rerank.transformation import TogetherAIRerankConfig +from litellm.types.rerank import ( + RerankBilledUnits, + RerankRequest, + RerankResponse, + RerankResponseMeta, + RerankTokens, +) class TogetherAIRerank(BaseLLM): @@ -65,13 +72,7 @@ class TogetherAIRerank(BaseLLM): _json_response = response.json() - response = RerankResponse( - id=_json_response.get("id"), - results=_json_response.get("results"), - meta=_json_response.get("meta") or {}, - ) - - return response + return TogetherAIRerankConfig()._transform_response(_json_response) async def async_rerank( # New async method self, @@ -97,10 +98,4 @@ class TogetherAIRerank(BaseLLM): _json_response = response.json() - return RerankResponse( - id=_json_response.get("id"), - results=_json_response.get("results"), - meta=_json_response.get("meta") or {}, - ) # Return response - - pass + return TogetherAIRerankConfig()._transform_response(_json_response) diff --git a/litellm/llms/together_ai/rerank/transformation.py b/litellm/llms/together_ai/rerank/transformation.py new file mode 100644 index 000000000..b2024b5cd --- /dev/null +++ b/litellm/llms/together_ai/rerank/transformation.py @@ -0,0 +1,34 @@ +""" +Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format. + +Why separate file? 
Make it easy to see how transformation works +""" + +import uuid +from typing import List, Optional + +from litellm.types.rerank import ( + RerankBilledUnits, + RerankResponse, + RerankResponseMeta, + RerankTokens, +) + + +class TogetherAIRerankConfig: + def _transform_response(self, response: dict) -> RerankResponse: + + _billed_units = RerankBilledUnits(**response.get("usage", {})) + _tokens = RerankTokens(**response.get("usage", {})) + rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens) + + _results: Optional[List[dict]] = response.get("results") + + if _results is None: + raise ValueError(f"No results found in the response={response}") + + return RerankResponse( + id=response.get("id") or str(uuid.uuid4()), + results=_results, + meta=rerank_meta, + ) # Return response diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py index 0f95b222c..74bab0b26 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py @@ -89,6 +89,9 @@ def _get_vertex_url( elif mode == "embedding": endpoint = "predict" url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" + if model.isdigit(): + # https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict + url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}" if not url or not endpoint: raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}") diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/cost_calculator.py b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/cost_calculator.py new file mode 100644 index 000000000..2d7fa37f7 --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/cost_calculator.py @@ -0,0 +1,25 @@ +""" +Vertex AI Image Generation Cost Calculator +""" + +from typing import Optional + +import litellm +from litellm.types.utils import ImageResponse + + +def cost_calculator( + model: str, + image_response: ImageResponse, +) -> float: + """ + Vertex AI Image Generation Cost Calculator + """ + _model_info = litellm.get_model_info( + model=model, + custom_llm_provider="vertex_ai", + ) + + output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0 + num_images: int = len(image_response.data) + return output_cost_per_image * num_images diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py index 0cde5c3b5..26741ff4f 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py @@ -96,7 +96,7 @@ class VertexEmbedding(VertexBase): headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) vertex_request: VertexEmbeddingRequest = ( litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( - input=input, optional_params=optional_params + input=input, optional_params=optional_params, model=model ) ) @@ -188,7 +188,7 @@ class VertexEmbedding(VertexBase): headers = self.set_headers(auth_header=auth_header, 
extra_headers=extra_headers) vertex_request: VertexEmbeddingRequest = ( litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( - input=input, optional_params=optional_params + input=input, optional_params=optional_params, model=model ) ) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py index 1ca405392..6f4b25cef 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py @@ -101,11 +101,16 @@ class VertexAITextEmbeddingConfig(BaseModel): return optional_params def transform_openai_request_to_vertex_embedding_request( - self, input: Union[list, str], optional_params: dict + self, input: Union[list, str], optional_params: dict, model: str ) -> VertexEmbeddingRequest: """ Transforms an openai request to a vertex embedding request. """ + if model.isdigit(): + return self._transform_openai_request_to_fine_tuned_embedding_request( + input, optional_params, model + ) + vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest() vertex_text_embedding_input_list: List[TextEmbeddingInput] = [] task_type: Optional[TaskType] = optional_params.get("task_type") @@ -125,6 +130,47 @@ class VertexAITextEmbeddingConfig(BaseModel): return vertex_request + def _transform_openai_request_to_fine_tuned_embedding_request( + self, input: Union[list, str], optional_params: dict, model: str + ) -> VertexEmbeddingRequest: + """ + Transforms an openai request to a vertex fine-tuned embedding request. + + Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22)) + Sample Request: + + ```json + { + "instances" : [ + { + "inputs": "How would the Future of AI in 10 Years look?", + "parameters": { + "max_new_tokens": 128, + "temperature": 1.0, + "top_p": 0.9, + "top_k": 10 + } + } + ] + } + ``` + """ + vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest() + vertex_text_embedding_input_list: List[TextEmbeddingFineTunedInput] = [] + if isinstance(input, str): + input = [input] # Convert single string to list for uniform processing + + for text in input: + embedding_input = TextEmbeddingFineTunedInput(inputs=text) + vertex_text_embedding_input_list.append(embedding_input) + + vertex_request["instances"] = vertex_text_embedding_input_list + vertex_request["parameters"] = TextEmbeddingFineTunedParameters( + **optional_params + ) + + return vertex_request + def create_embedding_input( self, content: str, @@ -157,6 +203,11 @@ class VertexAITextEmbeddingConfig(BaseModel): """ Transforms a vertex embedding response to an openai response. """ + if model.isdigit(): + return self._transform_vertex_response_to_openai_for_fine_tuned_models( + response, model, model_response + ) + _predictions = response["predictions"] embedding_response = [] @@ -181,3 +232,35 @@ class VertexAITextEmbeddingConfig(BaseModel): ) setattr(model_response, "usage", usage) return model_response + + def _transform_vertex_response_to_openai_for_fine_tuned_models( + self, response: dict, model: str, model_response: litellm.EmbeddingResponse + ) -> litellm.EmbeddingResponse: + """ + Transforms a vertex fine-tuned model embedding response to an openai response format. 
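Reviewer note: taken together, the `model.isdigit()` branches route purely numeric model IDs to the Vertex endpoint-prediction URL and the fine-tuned request shape. A hedged caller-side sketch; the endpoint ID, project, and location are placeholders:

```python
import litellm

# Placeholders: substitute a real fine-tuned Vertex endpoint ID, project, and location.
# Because the model segment is purely numeric, the new isdigit() branches send the request
# to .../endpoints/{model}:predict with an {"instances": [{"inputs": ...}]} body.
response = litellm.embedding(
    model="vertex_ai/1234567890123456789",
    input=["How would the future of AI in 10 years look?"],
    vertex_project="my-gcp-project",
    vertex_location="us-central1",
)

# Fine-tuned responses are transformed by
# _transform_vertex_response_to_openai_for_fine_tuned_models above.
print(len(response.data), response.data[0])
```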
+ """ + _predictions = response["predictions"] + + embedding_response = [] + # For fine-tuned models, we don't get token counts in the response + input_tokens = 0 + + for idx, embedding_values in enumerate(_predictions): + embedding_response.append( + { + "object": "embedding", + "index": idx, + "embedding": embedding_values[ + 0 + ], # The embedding values are nested one level deeper + } + ) + + model_response.object = "list" + model_response.data = embedding_response + model_response.model = model + usage = Usage( + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + ) + setattr(model_response, "usage", usage) + return model_response diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py index 311809c82..433305516 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py @@ -23,14 +23,27 @@ class TextEmbeddingInput(TypedDict, total=False): title: Optional[str] +# Fine-tuned models require a different input format +# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22)) +class TextEmbeddingFineTunedInput(TypedDict, total=False): + inputs: str + + +class TextEmbeddingFineTunedParameters(TypedDict, total=False): + max_new_tokens: Optional[int] + temperature: Optional[float] + top_p: Optional[float] + top_k: Optional[int] + + class EmbeddingParameters(TypedDict, total=False): auto_truncate: Optional[bool] output_dimensionality: Optional[int] class VertexEmbeddingRequest(TypedDict, total=False): - instances: List[TextEmbeddingInput] - parameters: Optional[EmbeddingParameters] + instances: Union[List[TextEmbeddingInput], List[TextEmbeddingFineTunedInput]] + parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]] # Example usage: diff --git a/litellm/main.py b/litellm/main.py index afb46c698..543a93eea 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1066,6 +1066,7 @@ def completion( # type: ignore # noqa: PLR0915 azure_ad_token_provider=kwargs.get("azure_ad_token_provider"), user_continue_message=kwargs.get("user_continue_message"), base_model=base_model, + litellm_trace_id=kwargs.get("litellm_trace_id"), ) logging.update_environment_variables( model=model, @@ -3455,7 +3456,7 @@ def embedding( # noqa: PLR0915 client=client, aembedding=aembedding, ) - elif custom_llm_provider == "openai_like": + elif custom_llm_provider == "openai_like" or custom_llm_provider == "jina_ai": api_base = ( api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE") ) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index fb8fb105c..cae3bee12 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -2986,19 +2986,19 @@ "supports_function_calling": true }, "vertex_ai/imagegeneration@006": { - "cost_per_image": 0.020, + "output_cost_per_image": 0.020, "litellm_provider": "vertex_ai-image-models", "mode": "image_generation", "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" }, "vertex_ai/imagen-3.0-generate-001": { - "cost_per_image": 0.04, + "output_cost_per_image": 0.04, "litellm_provider": "vertex_ai-image-models", "mode": "image_generation", "source": 
"https://cloud.google.com/vertex-ai/generative-ai/pricing" }, "vertex_ai/imagen-3.0-fast-generate-001": { - "cost_per_image": 0.02, + "output_cost_per_image": 0.02, "litellm_provider": "vertex_ai-image-models", "mode": "image_generation", "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" @@ -5620,6 +5620,13 @@ "litellm_provider": "bedrock", "mode": "image_generation" }, + "stability.stable-image-ultra-v1:0": { + "max_tokens": 77, + "max_input_tokens": 77, + "output_cost_per_image": 0.14, + "litellm_provider": "bedrock", + "mode": "image_generation" + }, "sagemaker/meta-textgeneration-llama-2-7b": { "max_tokens": 4096, "max_input_tokens": 4096, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 911f15b86..b06a9e667 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,122 +1,15 @@ model_list: - - model_name: "*" - litellm_params: - model: claude-3-5-sonnet-20240620 - api_key: os.environ/ANTHROPIC_API_KEY - - model_name: claude-3-5-sonnet-aihubmix - litellm_params: - model: openai/claude-3-5-sonnet-20240620 - input_cost_per_token: 0.000003 # 3$/M - output_cost_per_token: 0.000015 # 15$/M - api_base: "https://exampleopenaiendpoint-production.up.railway.app" - api_key: my-fake-key - - model_name: fake-openai-endpoint-2 - litellm_params: - model: openai/my-fake-model - api_key: my-fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - stream_timeout: 0.001 - timeout: 1 - rpm: 1 - - model_name: fake-openai-endpoint - litellm_params: - model: openai/my-fake-model - api_key: my-fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - ## bedrock chat completions - - model_name: "*anthropic.claude*" - litellm_params: - model: bedrock/*anthropic.claude* - aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID - aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY - aws_region_name: os.environ/AWS_REGION_NAME - guardrailConfig: - "guardrailIdentifier": "h4dsqwhp6j66" - "guardrailVersion": "2" - "trace": "enabled" - -## bedrock embeddings - - model_name: "*amazon.titan-embed-*" - litellm_params: - model: bedrock/amazon.titan-embed-* - aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID - aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY - aws_region_name: os.environ/AWS_REGION_NAME - - model_name: "*cohere.embed-*" - litellm_params: - model: bedrock/cohere.embed-* - aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID - aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY - aws_region_name: os.environ/AWS_REGION_NAME - - - model_name: "bedrock/*" - litellm_params: - model: bedrock/* - aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID - aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY - aws_region_name: os.environ/AWS_REGION_NAME - + # GPT-4 Turbo Models - model_name: gpt-4 litellm_params: - model: azure/chatgpt-v-2 - api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ - api_version: "2023-05-15" - api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. 
See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - rpm: 480 - timeout: 300 - stream_timeout: 60 - -litellm_settings: - fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }] - # callbacks: ["otel", "prometheus"] - default_redis_batch_cache_expiry: 10 - # default_team_settings: - # - team_id: "dbe2f686-a686-4896-864a-4c3924458709" - # success_callback: ["langfuse"] - # langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1 - # langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1 - -# litellm_settings: -# cache: True -# cache_params: -# type: redis - -# # disable caching on the actual API call -# supported_call_types: [] - -# # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url -# host: os.environ/REDIS_HOST -# port: os.environ/REDIS_PORT -# password: os.environ/REDIS_PASSWORD - -# # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests -# # see https://docs.litellm.ai/docs/proxy/prometheus -# callbacks: ['otel'] + model: gpt-4 + - model_name: rerank-model + litellm_params: + model: jina_ai/jina-reranker-v2-base-multilingual -# # router_settings: -# # routing_strategy: latency-based-routing -# # routing_strategy_args: -# # # only assign 40% of traffic to the fastest deployment to avoid overloading it -# # lowest_latency_buffer: 0.4 - -# # # consider last five minutes of calls for latency calculation -# # ttl: 300 -# # redis_host: os.environ/REDIS_HOST -# # redis_port: os.environ/REDIS_PORT -# # redis_password: os.environ/REDIS_PASSWORD - -# # # see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml -# # general_settings: -# # master_key: os.environ/LITELLM_MASTER_KEY -# # database_url: os.environ/DATABASE_URL -# # disable_master_key_return: true -# # # alerting: ['slack', 'email'] -# # alerting: ['email'] - -# # # Batch write spend updates every 60s -# # proxy_batch_write_at: 60 - -# # # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl -# # # our api keys rarely change -# # user_api_key_cache_ttl: 3600 +router_settings: + model_group_alias: + "gpt-4-turbo": # Aliased model name + model: "gpt-4" # Actual model name in 'model_list' + hidden: true \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 2d869af85..4baf13b61 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1128,7 +1128,16 @@ class KeyManagementSystem(enum.Enum): class KeyManagementSettings(LiteLLMBase): - hosted_keys: List + hosted_keys: Optional[List] = None + store_virtual_keys: Optional[bool] = False + """ + If True, virtual keys created by litellm will be stored in the secret manager + """ + + access_mode: Literal["read_only", "write_only", "read_and_write"] = "read_only" + """ + Access mode for the secret manager, when write_only will only use for writing secrets + """ class TeamDefaultSettings(LiteLLMBase): diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 12b6ec372..8d3afa33f 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -8,6 +8,7 @@ Run checks for: 2. If user is in budget 3. 
If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget """ + import time import traceback from datetime import datetime diff --git a/litellm/proxy/hooks/key_management_event_hooks.py b/litellm/proxy/hooks/key_management_event_hooks.py new file mode 100644 index 000000000..08645a468 --- /dev/null +++ b/litellm/proxy/hooks/key_management_event_hooks.py @@ -0,0 +1,267 @@ +import asyncio +import json +import uuid +from datetime import datetime, timezone +from re import A +from typing import Any, List, Optional + +from fastapi import status + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + GenerateKeyRequest, + KeyManagementSystem, + KeyRequest, + LiteLLM_AuditLogs, + LiteLLM_VerificationToken, + LitellmTableNames, + ProxyErrorTypes, + ProxyException, + UpdateKeyRequest, + UserAPIKeyAuth, + WebhookEvent, +) + + +class KeyManagementEventHooks: + + @staticmethod + async def async_key_generated_hook( + data: GenerateKeyRequest, + response: dict, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Hook that runs after a successful /key/generate request + + Handles the following: + - Sending Email with Key Details + - Storing Audit Logs for key generation + - Storing Generated Key in DB + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import ( + general_settings, + litellm_proxy_admin_name, + proxy_logging_obj, + ) + + if data.send_invite_email is True: + await KeyManagementEventHooks._send_key_created_email(response) + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = json.dumps(response, default=str) + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=response.get("token_id", ""), + action="created", + updated_values=_updated_values, + before_value=None, + ) + ) + ) + # store the generated key in the secret manager + await KeyManagementEventHooks._store_virtual_key_in_secret_manager( + secret_name=data.key_alias or f"virtual-key-{uuid.uuid4()}", + secret_token=response.get("token", ""), + ) + + @staticmethod + async def async_key_updated_hook( + data: UpdateKeyRequest, + existing_key_row: Any, + response: Any, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Post /key/update processing hook + + Handles the following: + - Storing Audit Logs for key update + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import litellm_proxy_admin_name + + # Enterprise Feature - Audit Logging. 
Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = json.dumps(data.json(exclude_none=True), default=str) + + _before_value = existing_key_row.json(exclude_none=True) + _before_value = json.dumps(_before_value, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=data.key, + action="updated", + updated_values=_updated_values, + before_value=_before_value, + ) + ) + ) + pass + + @staticmethod + async def async_key_deleted_hook( + data: KeyRequest, + keys_being_deleted: List[LiteLLM_VerificationToken], + response: dict, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Post /key/delete processing hook + + Handles the following: + - Storing Audit Logs for key deletion + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes + if litellm.store_audit_logs is True: + # make an audit log for each team deleted + for key in data.keys: + key_row = await prisma_client.get_data( # type: ignore + token=key, table_name="key", query_type="find_unique" + ) + + if key_row is None: + raise ProxyException( + message=f"Key {key} not found", + type=ProxyErrorTypes.bad_request_error, + param="key", + code=status.HTTP_404_NOT_FOUND, + ) + + key_row = key_row.json(exclude_none=True) + _key_row = json.dumps(key_row, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=key, + action="deleted", + updated_values="{}", + before_value=_key_row, + ) + ) + ) + # delete the keys from the secret manager + await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager( + keys_being_deleted=keys_being_deleted + ) + pass + + @staticmethod + async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str): + """ + Store a virtual key in the secret manager + + Args: + secret_name: Name of the virtual key + secret_token: Value of the virtual key (example: sk-1234) + """ + if litellm._key_management_settings is not None: + if litellm._key_management_settings.store_virtual_keys is True: + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + + # store the key in the secret manager + if ( + litellm._key_management_system + == KeyManagementSystem.AWS_SECRET_MANAGER + and isinstance(litellm.secret_manager_client, AWSSecretsManagerV2) + ): + await litellm.secret_manager_client.async_write_secret( + secret_name=secret_name, + secret_value=secret_token, + ) + + @staticmethod + async def _delete_virtual_keys_from_secret_manager( + keys_being_deleted: List[LiteLLM_VerificationToken], + ): + """ + Deletes virtual keys from the 
secret manager + + Args: + keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation + """ + if litellm._key_management_settings is not None: + if litellm._key_management_settings.store_virtual_keys is True: + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + + if isinstance(litellm.secret_manager_client, AWSSecretsManagerV2): + for key in keys_being_deleted: + if key.key_alias is not None: + await litellm.secret_manager_client.async_delete_secret( + secret_name=key.key_alias + ) + else: + verbose_proxy_logger.warning( + f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager." + ) + + @staticmethod + async def _send_key_created_email(response: dict): + from litellm.proxy.proxy_server import general_settings, proxy_logging_obj + + if "email" not in general_settings.get("alerting", []): + raise ValueError( + "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`" + ) + event = WebhookEvent( + event="key_created", + event_group="key", + event_message="API Key Created", + token=response.get("token", ""), + spend=response.get("spend", 0.0), + max_budget=response.get("max_budget", 0.0), + user_id=response.get("user_id", None), + team_id=response.get("team_id", "Default Team"), + key_alias=response.get("key_alias", None), + ) + + # If user configured email alerting - send an Email letting their end-user know the key was created + asyncio.create_task( + proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email( + webhook_event=event, + ) + ) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 789e79f37..3d1d3b491 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -274,6 +274,51 @@ class LiteLLMProxyRequestSetup: ) return user_api_key_logged_metadata + @staticmethod + def add_key_level_controls( + key_metadata: dict, data: dict, _metadata_variable_name: str + ): + data = data.copy() + if "cache" in key_metadata: + data["cache"] = {} + if isinstance(key_metadata["cache"], dict): + for k, v in key_metadata["cache"].items(): + if k in SupportedCacheControls: + data["cache"][k] = v + + ## KEY-LEVEL SPEND LOGS / TAGS + if "tags" in key_metadata and key_metadata["tags"] is not None: + if "tags" in data[_metadata_variable_name] and isinstance( + data[_metadata_variable_name]["tags"], list + ): + data[_metadata_variable_name]["tags"].extend(key_metadata["tags"]) + else: + data[_metadata_variable_name]["tags"] = key_metadata["tags"] + if "spend_logs_metadata" in key_metadata and isinstance( + key_metadata["spend_logs_metadata"], dict + ): + if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance( + data[_metadata_variable_name]["spend_logs_metadata"], dict + ): + for key, value in key_metadata["spend_logs_metadata"].items(): + if ( + key not in data[_metadata_variable_name]["spend_logs_metadata"] + ): # don't override k-v pair sent by request (user request) + data[_metadata_variable_name]["spend_logs_metadata"][ + key + ] = value + else: + data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[ + "spend_logs_metadata" + ] + + ## KEY-LEVEL DISABLE FALLBACKS + if "disable_fallbacks" in key_metadata and isinstance( + key_metadata["disable_fallbacks"], bool + ): + data["disable_fallbacks"] = key_metadata["disable_fallbacks"] 
+ return data + async def add_litellm_data_to_request( # noqa: PLR0915 data: dict, @@ -389,37 +434,11 @@ async def add_litellm_data_to_request( # noqa: PLR0915 ### KEY-LEVEL Controls key_metadata = user_api_key_dict.metadata - if "cache" in key_metadata: - data["cache"] = {} - if isinstance(key_metadata["cache"], dict): - for k, v in key_metadata["cache"].items(): - if k in SupportedCacheControls: - data["cache"][k] = v - - ## KEY-LEVEL SPEND LOGS / TAGS - if "tags" in key_metadata and key_metadata["tags"] is not None: - if "tags" in data[_metadata_variable_name] and isinstance( - data[_metadata_variable_name]["tags"], list - ): - data[_metadata_variable_name]["tags"].extend(key_metadata["tags"]) - else: - data[_metadata_variable_name]["tags"] = key_metadata["tags"] - if "spend_logs_metadata" in key_metadata and isinstance( - key_metadata["spend_logs_metadata"], dict - ): - if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance( - data[_metadata_variable_name]["spend_logs_metadata"], dict - ): - for key, value in key_metadata["spend_logs_metadata"].items(): - if ( - key not in data[_metadata_variable_name]["spend_logs_metadata"] - ): # don't override k-v pair sent by request (user request) - data[_metadata_variable_name]["spend_logs_metadata"][key] = value - else: - data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[ - "spend_logs_metadata" - ] - + data = LiteLLMProxyRequestSetup.add_key_level_controls( + key_metadata=key_metadata, + data=data, + _metadata_variable_name=_metadata_variable_name, + ) ## TEAM-LEVEL SPEND LOGS/TAGS team_metadata = user_api_key_dict.team_metadata or {} if "tags" in team_metadata and team_metadata["tags"] is not None: diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 01baa5a43..e38236e9b 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -17,7 +17,7 @@ import secrets import traceback import uuid from datetime import datetime, timedelta, timezone -from typing import List, Optional +from typing import List, Optional, Tuple import fastapi from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status @@ -31,6 +31,7 @@ from litellm.proxy.auth.auth_checks import ( get_key_object, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed from litellm.secret_managers.main import get_secret @@ -234,50 +235,14 @@ async def generate_key_fn( # noqa: PLR0915 data.soft_budget ) # include the user-input soft budget in the response - if data.send_invite_email is True: - if "email" not in general_settings.get("alerting", []): - raise ValueError( - "Email alerting not setup on config.yaml. Please set `alerting=['email']. 
\nDocs: https://docs.litellm.ai/docs/proxy/email`" - ) - event = WebhookEvent( - event="key_created", - event_group="key", - event_message="API Key Created", - token=response.get("token", ""), - spend=response.get("spend", 0.0), - max_budget=response.get("max_budget", 0.0), - user_id=response.get("user_id", None), - team_id=response.get("team_id", "Default Team"), - key_alias=response.get("key_alias", None), - ) - - # If user configured email alerting - send an Email letting their end-user know the key was created - asyncio.create_task( - proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email( - webhook_event=event, - ) - ) - - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _updated_values = json.dumps(response, default=str) - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=response.get("token_id", ""), - action="created", - updated_values=_updated_values, - before_value=None, - ) - ) + asyncio.create_task( + KeyManagementEventHooks.async_key_generated_hook( + data=data, + response=response, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, ) + ) return GenerateKeyResponse(**response) except Exception as e: @@ -407,30 +372,15 @@ async def update_key_fn( proxy_logging_obj=proxy_logging_obj, ) - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _updated_values = json.dumps(data_json, default=str) - - _before_value = existing_key_row.json(exclude_none=True) - _before_value = json.dumps(_before_value, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=data.key, - action="updated", - updated_values=_updated_values, - before_value=_before_value, - ) - ) + asyncio.create_task( + KeyManagementEventHooks.async_key_updated_hook( + data=data, + existing_key_row=existing_key_row, + response=response, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, ) + ) if response is None: raise ValueError("Failed to update key got response = None") @@ -496,6 +446,9 @@ async def delete_key_fn( user_custom_key_generate, ) + if prisma_client is None: + raise Exception("Not connected to DB!") + keys = data.keys if len(keys) == 0: raise ProxyException( @@ -516,45 +469,7 @@ async def delete_key_fn( ): user_id = None # unless they're admin - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - # we do this after the first for loop, since first for loop is for validation. 
we only want this inserted after validation passes - if litellm.store_audit_logs is True: - # make an audit log for each team deleted - for key in data.keys: - key_row = await prisma_client.get_data( # type: ignore - token=key, table_name="key", query_type="find_unique" - ) - - if key_row is None: - raise ProxyException( - message=f"Key {key} not found", - type=ProxyErrorTypes.bad_request_error, - param="key", - code=status.HTTP_404_NOT_FOUND, - ) - - key_row = key_row.json(exclude_none=True) - _key_row = json.dumps(key_row, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=key, - action="deleted", - updated_values="{}", - before_value=_key_row, - ) - ) - ) - - number_deleted_keys = await delete_verification_token( + number_deleted_keys, _keys_being_deleted = await delete_verification_token( tokens=keys, user_id=user_id ) if number_deleted_keys is None: @@ -588,6 +503,16 @@ async def delete_key_fn( f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}" ) + asyncio.create_task( + KeyManagementEventHooks.async_key_deleted_hook( + data=data, + keys_being_deleted=_keys_being_deleted, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, + response=number_deleted_keys, + ) + ) + return {"deleted_keys": keys} except Exception as e: if isinstance(e, HTTPException): @@ -1026,11 +951,35 @@ async def generate_key_helper_fn( # noqa: PLR0915 return key_data -async def delete_verification_token(tokens: List, user_id: Optional[str] = None): +async def delete_verification_token( + tokens: List, user_id: Optional[str] = None +) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: + """ + Helper that deletes the list of tokens from the database + + Args: + tokens: List of tokens to delete + user_id: Optional user_id to filter by + + Returns: + Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: + Optional[Dict]: + - Number of deleted tokens + List[LiteLLM_VerificationToken]: + - List of keys being deleted, this contains information about the key_alias, token, and user_id being deleted, + this is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs + """ from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client try: if prisma_client: + tokens = [_hash_token_if_needed(token=key) for key in tokens] + _keys_being_deleted = ( + await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} + ) + ) + # Assuming 'db' is your Prisma Client instance # check if admin making request - don't filter by user-id if user_id == litellm_proxy_admin_name: @@ -1060,7 +1009,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None) ) verbose_proxy_logger.debug(traceback.format_exc()) raise e - return deleted_tokens + return deleted_tokens, _keys_being_deleted @router.post( diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index f9f8276c7..094828de1 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915 ProxyConfig, app, load_aws_kms, - load_aws_secret_manager, load_from_azure_key_vault, load_google_kms, save_worker_config, @@ 
-278,7 +277,6 @@ def run_server( # noqa: PLR0915 ProxyConfig, app, load_aws_kms, - load_aws_secret_manager, load_from_azure_key_vault, load_google_kms, save_worker_config, @@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915 ProxyConfig, app, load_aws_kms, - load_aws_secret_manager, load_from_azure_key_vault, load_google_kms, save_worker_config, @@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915 key_management_system == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 ): + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + ### LOAD FROM AWS SECRET MANAGER ### - load_aws_secret_manager(use_aws_secret_manager=True) + AWSSecretsManagerV2.load_aws_secret_manager( + use_aws_secret_manager=True + ) elif key_management_system == KeyManagementSystem.AWS_KMS.value: load_aws_kms(use_aws_kms=True) elif ( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 29d14c910..71e3dee0e 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -7,6 +7,8 @@ model_list: -litellm_settings: - callbacks: ["gcs_bucket"] - +general_settings: + key_management_system: "aws_secret_manager" + key_management_settings: + store_virtual_keys: true + access_mode: "write_only" diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index bbf4b0b93..92ca32e52 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -245,10 +245,7 @@ from litellm.router import ( from litellm.router import ModelInfo as RouterModelInfo from litellm.router import updateDeployment from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler -from litellm.secret_managers.aws_secret_manager import ( - load_aws_kms, - load_aws_secret_manager, -) +from litellm.secret_managers.aws_secret_manager import load_aws_kms from litellm.secret_managers.google_kms import load_google_kms from litellm.secret_managers.main import ( get_secret, @@ -1825,8 +1822,13 @@ class ProxyConfig: key_management_system == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 ): - ### LOAD FROM AWS SECRET MANAGER ### - load_aws_secret_manager(use_aws_secret_manager=True) + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + + AWSSecretsManagerV2.load_aws_secret_manager( + use_aws_secret_manager=True + ) elif key_management_system == KeyManagementSystem.AWS_KMS.value: load_aws_kms(use_aws_kms=True) elif ( diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index a06aff135..9cc8a8c1d 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -8,7 +8,8 @@ from litellm._logging import verbose_logger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.azure_ai.rerank import AzureAIRerank from litellm.llms.cohere.rerank import CohereRerank -from litellm.llms.together_ai.rerank import TogetherAIRerank +from litellm.llms.jina_ai.rerank.handler import JinaAIRerank +from litellm.llms.together_ai.rerank.handler import TogetherAIRerank from litellm.secret_managers.main import get_secret from litellm.types.rerank import RerankRequest, RerankResponse from litellm.types.router import * @@ -19,6 +20,7 @@ from litellm.utils import client, exception_type, supports_httpx_timeout cohere_rerank = CohereRerank() together_rerank = TogetherAIRerank() azure_ai_rerank = AzureAIRerank() +jina_ai_rerank = JinaAIRerank() ################################################# @@ -247,7 +249,23 @@ def rerank( api_key=api_key, 
_is_async=_is_async, ) + elif _custom_llm_provider == "jina_ai": + if dynamic_api_key is None: + raise ValueError( + "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment" + ) + response = jina_ai_rerank.rerank( + model=model, + api_key=dynamic_api_key, + query=query, + documents=documents, + top_n=top_n, + rank_fields=rank_fields, + return_documents=return_documents, + max_chunks_per_doc=max_chunks_per_doc, + _is_async=_is_async, + ) else: raise ValueError(f"Unsupported provider: {_custom_llm_provider}") diff --git a/litellm/router.py b/litellm/router.py index 4735d422b..97065bc85 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -679,9 +679,8 @@ class Router: kwargs["model"] = model kwargs["messages"] = messages kwargs["original_function"] = self._completion - kwargs.get("request_timeout", self.timeout) - kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + response = self.function_with_fallbacks(**kwargs) return response except Exception as e: @@ -783,8 +782,7 @@ class Router: kwargs["stream"] = stream kwargs["original_function"] = self._acompletion kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) request_priority = kwargs.get("priority") or self.default_priority @@ -948,6 +946,17 @@ class Router: self.fail_calls[model_name] += 1 raise e + def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None: + """ + Adds/updates to kwargs: + - num_retries + - litellm_trace_id + - metadata + """ + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + kwargs.setdefault("litellm_trace_id", str(uuid.uuid4())) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None: """ Adds default litellm params to kwargs, if set. 
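
The new `_update_kwargs_before_fallbacks` helper above consolidates the per-call setup (`num_retries`, `metadata["model_group"]`) and seeds a `litellm_trace_id`, so every retry and fallback attempt for one logical request can be correlated in logging callbacks. A minimal sketch (not part of this patch) of how that id surfaces to a callback; the `TraceLogger` class and the placeholder API key are illustrative assumptions:

```python
import litellm
from litellm import Router
from litellm.integrations.custom_logger import CustomLogger


class TraceLogger(CustomLogger):
    """Illustrative callback: print the trace id shared by all attempts of one request."""

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # the standard logging payload carries the new trace_id field
        payload = kwargs.get("standard_logging_object") or {}
        print("attempt failed, trace_id =", payload.get("trace_id"))


litellm.callbacks = [TraceLogger()]

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "openai/gpt-3.5-turbo",
                "api_key": "test-api-key",  # placeholder
            },
        }
    ]
)

try:
    router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        num_retries=1,
        # same trick the new test_standard_logging_retries test uses to force a retry
        mock_response="litellm.RateLimitError",
    )
except litellm.RateLimitError:
    pass
```

With `num_retries=1` the failure callback fires twice, and both invocations should report the same `trace_id`, which is what `test_standard_logging_retries` later in this diff asserts.
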
@@ -1511,9 +1520,7 @@ class Router: kwargs["model"] = model kwargs["file"] = file kwargs["original_function"] = self._atranscription - kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - kwargs.get("request_timeout", self.timeout) - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) response = await self.async_function_with_fallbacks(**kwargs) return response @@ -1688,9 +1695,7 @@ class Router: kwargs["model"] = model kwargs["input"] = input kwargs["original_function"] = self._arerank - kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - kwargs.get("request_timeout", self.timeout) - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) response = await self.async_function_with_fallbacks(**kwargs) @@ -1839,9 +1844,7 @@ class Router: kwargs["model"] = model kwargs["prompt"] = prompt kwargs["original_function"] = self._atext_completion - kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - kwargs.get("request_timeout", self.timeout) - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) response = await self.async_function_with_fallbacks(**kwargs) return response @@ -2112,9 +2115,7 @@ class Router: kwargs["model"] = model kwargs["input"] = input kwargs["original_function"] = self._aembedding - kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - kwargs.get("request_timeout", self.timeout) - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) response = await self.async_function_with_fallbacks(**kwargs) return response except Exception as e: @@ -2609,6 +2610,7 @@ class Router: If it fails after num_retries, fall back to another model group """ model_group: Optional[str] = kwargs.get("model") + disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False) fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks) context_window_fallbacks: Optional[List] = kwargs.get( "context_window_fallbacks", self.context_window_fallbacks @@ -2616,6 +2618,7 @@ class Router: content_policy_fallbacks: Optional[List] = kwargs.get( "content_policy_fallbacks", self.content_policy_fallbacks ) + try: self._handle_mock_testing_fallbacks( kwargs=kwargs, @@ -2635,7 +2638,7 @@ class Router: original_model_group: Optional[str] = kwargs.get("model") # type: ignore fallback_failure_exception_str = "" - if original_model_group is None: + if disable_fallbacks is True or original_model_group is None: raise e input_kwargs = { diff --git a/litellm/secret_managers/aws_secret_manager.py b/litellm/secret_managers/aws_secret_manager.py index f0e510fa8..fbe951e64 100644 --- a/litellm/secret_managers/aws_secret_manager.py +++ b/litellm/secret_managers/aws_secret_manager.py @@ -23,28 +23,6 @@ def validate_environment(): raise ValueError("Missing required environment variable - AWS_REGION_NAME") -def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]): - if use_aws_secret_manager is None or use_aws_secret_manager is False: - return - try: - import boto3 - from botocore.exceptions import ClientError - - validate_environment() - - # Create a Secrets Manager client - session = boto3.session.Session() # type: ignore - client = session.client( - service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME") - ) - - 
litellm.secret_manager_client = client - litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER - - except Exception as e: - raise e - - def load_aws_kms(use_aws_kms: Optional[bool]): if use_aws_kms is None or use_aws_kms is False: return diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py new file mode 100644 index 000000000..69add6f23 --- /dev/null +++ b/litellm/secret_managers/aws_secret_manager_v2.py @@ -0,0 +1,310 @@ +""" +This is a file for the AWS Secret Manager Integration + +Handles Async Operations for: +- Read Secret +- Write Secret +- Delete Secret + +Relevant issue: https://github.com/BerriAI/litellm/issues/1883 + +Requires: +* `os.environ["AWS_REGION_NAME"], +* `pip install boto3>=1.28.57` +""" + +import ast +import asyncio +import base64 +import json +import os +import re +import sys +from typing import Any, Dict, Optional, Union + +import httpx + +import litellm +from litellm._logging import verbose_logger +from litellm.llms.base_aws_llm import BaseAWSLLM +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, +) +from litellm.llms.custom_httpx.types import httpxSpecialProvider +from litellm.proxy._types import KeyManagementSystem + + +class AWSSecretsManagerV2(BaseAWSLLM): + @classmethod + def validate_environment(cls): + if "AWS_REGION_NAME" not in os.environ: + raise ValueError("Missing required environment variable - AWS_REGION_NAME") + + @classmethod + def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]): + """ + Initialize AWSSecretsManagerV2 and sets litellm.secret_manager_client = AWSSecretsManagerV2() and litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER + """ + if use_aws_secret_manager is None or use_aws_secret_manager is False: + return + try: + import boto3 + + cls.validate_environment() + litellm.secret_manager_client = cls() + litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER + + except Exception as e: + raise e + + async def async_read_secret( + self, + secret_name: str, + optional_params: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> Optional[str]: + """ + Async function to read a secret from AWS Secrets Manager + + Returns: + str: Secret value + Raises: + ValueError: If the secret is not found or an HTTP error occurs + """ + endpoint_url, headers, body = self._prepare_request( + action="GetSecretValue", + secret_name=secret_name, + optional_params=optional_params, + ) + + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SecretManager, + params={"timeout": timeout}, + ) + + try: + response = await async_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return response.json()["SecretString"] + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + except Exception as e: + verbose_logger.exception( + "Error reading secret from AWS Secrets Manager: %s", str(e) + ) + return None + + def sync_read_secret( + self, + secret_name: str, + optional_params: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> Optional[str]: + """ + Sync function to read a secret from AWS Secrets Manager + + Done for backwards compatibility with existing codebase, since get_secret is a sync function + """ + + # self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. 
If we do we'd get stuck in an infinite loop + if secret_name in [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION_NAME", + "AWS_REGION", + "AWS_BEDROCK_RUNTIME_ENDPOINT", + ]: + return os.getenv(secret_name) + + endpoint_url, headers, body = self._prepare_request( + action="GetSecretValue", + secret_name=secret_name, + optional_params=optional_params, + ) + + sync_client = _get_httpx_client( + params={"timeout": timeout}, + ) + + try: + response = sync_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return response.json()["SecretString"] + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + except Exception as e: + verbose_logger.exception( + "Error reading secret from AWS Secrets Manager: %s", str(e) + ) + return None + + async def async_write_secret( + self, + secret_name: str, + secret_value: str, + description: Optional[str] = None, + client_request_token: Optional[str] = None, + optional_params: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> dict: + """ + Async function to write a secret to AWS Secrets Manager + + Args: + secret_name: Name of the secret + secret_value: Value to store (can be a JSON string) + description: Optional description for the secret + client_request_token: Optional unique identifier to ensure idempotency + optional_params: Additional AWS parameters + timeout: Request timeout + """ + import uuid + + # Prepare the request data + data = {"Name": secret_name, "SecretString": secret_value} + if description: + data["Description"] = description + + data["ClientRequestToken"] = str(uuid.uuid4()) + + endpoint_url, headers, body = self._prepare_request( + action="CreateSecret", + secret_name=secret_name, + secret_value=secret_value, + optional_params=optional_params, + request_data=data, # Pass the complete request data + ) + + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SecretManager, + params={"timeout": timeout}, + ) + + try: + response = await async_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as err: + raise ValueError(f"HTTP error occurred: {err.response.text}") + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + + async def async_delete_secret( + self, + secret_name: str, + recovery_window_in_days: Optional[int] = 7, + optional_params: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> dict: + """ + Async function to delete a secret from AWS Secrets Manager + + Args: + secret_name: Name of the secret to delete + recovery_window_in_days: Number of days before permanent deletion (default: 7) + optional_params: Additional AWS parameters + timeout: Request timeout + + Returns: + dict: Response from AWS Secrets Manager containing deletion details + """ + # Prepare the request data + data = { + "SecretId": secret_name, + "RecoveryWindowInDays": recovery_window_in_days, + } + + endpoint_url, headers, body = self._prepare_request( + action="DeleteSecret", + secret_name=secret_name, + optional_params=optional_params, + request_data=data, + ) + + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SecretManager, + params={"timeout": timeout}, + ) + + try: + response = await async_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return 
response.json() + except httpx.HTTPStatusError as err: + raise ValueError(f"HTTP error occurred: {err.response.text}") + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + + def _prepare_request( + self, + action: str, # "GetSecretValue" or "PutSecretValue" + secret_name: str, + secret_value: Optional[str] = None, + optional_params: Optional[dict] = None, + request_data: Optional[dict] = None, + ) -> tuple[str, Any, bytes]: + """Prepare the AWS Secrets Manager request""" + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + optional_params = optional_params or {} + boto3_credentials_info = self._get_boto_credentials_from_optional_params( + optional_params + ) + + # Get endpoint + _, endpoint_url = self.get_runtime_endpoint( + api_base=None, + aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint, + aws_region_name=boto3_credentials_info.aws_region_name, + ) + endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager") + + # Use provided request_data if available, otherwise build default data + if request_data: + data = request_data + else: + data = {"SecretId": secret_name} + if secret_value and action == "PutSecretValue": + data["SecretString"] = secret_value + + body = json.dumps(data).encode("utf-8") + headers = { + "Content-Type": "application/x-amz-json-1.1", + "X-Amz-Target": f"secretsmanager.{action}", + } + + # Sign request + request = AWSRequest( + method="POST", url=endpoint_url, data=body, headers=headers + ) + SigV4Auth( + boto3_credentials_info.credentials, + "secretsmanager", + boto3_credentials_info.aws_region_name, + ).add_auth(request) + prepped = request.prepare() + + return endpoint_url, prepped.headers, body + + +# if __name__ == "__main__": +# print("loading aws secret manager v2") +# aws_secret_manager_v2 = AWSSecretsManagerV2() + +# print("writing secret to aws secret manager v2") +# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2")) +# print("reading secret from aws secret manager v2") diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py index f3d6d420a..ce6d30755 100644 --- a/litellm/secret_managers/main.py +++ b/litellm/secret_managers/main.py @@ -5,7 +5,7 @@ import json import os import sys import traceback -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import httpx from dotenv import load_dotenv @@ -198,7 +198,10 @@ def get_secret( # noqa: PLR0915 raise ValueError("Unsupported OIDC provider") try: - if litellm.secret_manager_client is not None: + if ( + _should_read_secret_from_secret_manager() + and litellm.secret_manager_client is not None + ): try: client = litellm.secret_manager_client key_manager = "local" @@ -207,7 +210,8 @@ def get_secret( # noqa: PLR0915 if key_management_settings is not None: if ( - secret_name not in key_management_settings.hosted_keys + key_management_settings.hosted_keys is not None + and secret_name not in key_management_settings.hosted_keys ): # allow user to specify which keys to check in hosted key manager key_manager = "local" @@ -268,25 +272,13 @@ def get_secret( # noqa: PLR0915 if isinstance(secret, str): secret = secret.strip() elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: - try: - get_secret_value_response = 
client.get_secret_value( - SecretId=secret_name - ) - print_verbose( - f"get_secret_value_response: {get_secret_value_response}" - ) - except Exception as e: - print_verbose(f"An error occurred - {str(e)}") - # For a list of exceptions thrown, see - # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html - raise e + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) - # assume there is 1 secret per secret_name - secret_dict = json.loads(get_secret_value_response["SecretString"]) - print_verbose(f"secret_dict: {secret_dict}") - for k, v in secret_dict.items(): - secret = v - print_verbose(f"secret: {secret}") + if isinstance(client, AWSSecretsManagerV2): + secret = client.sync_read_secret(secret_name=secret_name) + print_verbose(f"get_secret_value_response: {secret}") elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value: try: secret = client.get_secret_from_google_secret_manager( @@ -332,3 +324,21 @@ def get_secret( # noqa: PLR0915 return default_value else: raise e + + +def _should_read_secret_from_secret_manager() -> bool: + """ + Returns True if the secret manager should be used to read the secret, False otherwise + + - If the secret manager client is not set, return False + - If the `_key_management_settings` access mode is "read_only" or "read_and_write", return True + - Otherwise, return False + """ + if litellm.secret_manager_client is not None: + if litellm._key_management_settings is not None: + if ( + litellm._key_management_settings.access_mode == "read_only" + or litellm._key_management_settings.access_mode == "read_and_write" + ): + return True + return False diff --git a/litellm/types/rerank.py b/litellm/types/rerank.py index d016021fb..00b07ba13 100644 --- a/litellm/types/rerank.py +++ b/litellm/types/rerank.py @@ -7,6 +7,7 @@ https://docs.cohere.com/reference/rerank from typing import List, Optional, Union from pydantic import BaseModel, PrivateAttr +from typing_extensions import TypedDict class RerankRequest(BaseModel): @@ -19,10 +20,26 @@ class RerankRequest(BaseModel): max_chunks_per_doc: Optional[int] = None +class RerankBilledUnits(TypedDict, total=False): + search_units: int + total_tokens: int + + +class RerankTokens(TypedDict, total=False): + input_tokens: int + output_tokens: int + + +class RerankResponseMeta(TypedDict, total=False): + api_version: dict + billed_units: RerankBilledUnits + tokens: RerankTokens + + class RerankResponse(BaseModel): id: str results: List[dict] # Contains index and relevance_score - meta: Optional[dict] = None # Contains api_version and billed_units + meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units # Define private attributes using PrivateAttr _hidden_params: dict = PrivateAttr(default_factory=dict) diff --git a/litellm/types/router.py b/litellm/types/router.py index 6119ca4b7..bb93aaa63 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -150,6 +150,8 @@ class GenericLiteLLMParams(BaseModel): max_retries: Optional[int] = None organization: Optional[str] = None # for openai orgs configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None + ## LOGGING PARAMS ## + litellm_trace_id: Optional[str] = None ## UNIFIED PROJECT/REGION ## region_name: Optional[str] = None ## VERTEX AI ## @@ -186,6 +188,8 @@ class GenericLiteLLMParams(BaseModel): None # timeout when making stream=True calls, if str, pass in as os.environ/ ), organization: Optional[str] = None, # for openai orgs + ## LOGGING 
PARAMS ## + litellm_trace_id: Optional[str] = None, ## UNIFIED PROJECT/REGION ## region_name: Optional[str] = None, ## VERTEX AI ## diff --git a/litellm/types/utils.py b/litellm/types/utils.py index e3df357be..d02129681 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1334,6 +1334,7 @@ class ResponseFormatChunk(TypedDict, total=False): all_litellm_params = [ "metadata", + "litellm_trace_id", "tags", "acompletion", "aimg_generation", @@ -1523,6 +1524,7 @@ StandardLoggingPayloadStatus = Literal["success", "failure"] class StandardLoggingPayload(TypedDict): id: str + trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries) call_type: str response_cost: float response_cost_failure_debug_info: Optional[ diff --git a/litellm/utils.py b/litellm/utils.py index a0f544312..f4f31e6cf 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -527,6 +527,7 @@ def function_setup( # noqa: PLR0915 messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], + litellm_trace_id=kwargs.get("litellm_trace_id"), function_id=function_id or "", call_type=call_type, start_time=start_time, @@ -2056,6 +2057,7 @@ def get_litellm_params( azure_ad_token_provider=None, user_continue_message=None, base_model=None, + litellm_trace_id=None, ): litellm_params = { "acompletion": acompletion, @@ -2084,6 +2086,7 @@ def get_litellm_params( "user_continue_message": user_continue_message, "base_model": base_model or _get_base_model_from_litellm_call_metadata(metadata=metadata), + "litellm_trace_id": litellm_trace_id, } return litellm_params diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index fb8fb105c..cae3bee12 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -2986,19 +2986,19 @@ "supports_function_calling": true }, "vertex_ai/imagegeneration@006": { - "cost_per_image": 0.020, + "output_cost_per_image": 0.020, "litellm_provider": "vertex_ai-image-models", "mode": "image_generation", "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" }, "vertex_ai/imagen-3.0-generate-001": { - "cost_per_image": 0.04, + "output_cost_per_image": 0.04, "litellm_provider": "vertex_ai-image-models", "mode": "image_generation", "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" }, "vertex_ai/imagen-3.0-fast-generate-001": { - "cost_per_image": 0.02, + "output_cost_per_image": 0.02, "litellm_provider": "vertex_ai-image-models", "mode": "image_generation", "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" @@ -5620,6 +5620,13 @@ "litellm_provider": "bedrock", "mode": "image_generation" }, + "stability.stable-image-ultra-v1:0": { + "max_tokens": 77, + "max_input_tokens": 77, + "output_cost_per_image": 0.14, + "litellm_provider": "bedrock", + "mode": "image_generation" + }, "sagemaker/meta-textgeneration-llama-2-7b": { "max_tokens": 4096, "max_input_tokens": 4096, diff --git a/pyproject.toml b/pyproject.toml index 17d37c0ce..fedfebc4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.52.6" +version = "1.52.9" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.52.6" +version = "1.52.9" version_files = [ "pyproject.toml:^version" ] diff --git a/tests/llm_translation/base_llm_unit_tests.py 
b/tests/llm_translation/base_llm_unit_tests.py index acb764ba1..955eed957 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -13,8 +13,11 @@ sys.path.insert( import litellm from litellm.exceptions import BadRequestError from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.utils import CustomStreamWrapper - +from litellm.utils import ( + CustomStreamWrapper, + get_supported_openai_params, + get_optional_params, +) # test_example.py from abc import ABC, abstractmethod @@ -45,6 +48,9 @@ class BaseLLMChatTest(ABC): ) assert response is not None + # for OpenAI the content contains the JSON schema, so we need to assert that the content is not None + assert response.choices[0].message.content is not None + def test_message_with_name(self): base_completion_call_args = self.get_base_completion_call_args() messages = [ @@ -79,6 +85,49 @@ class BaseLLMChatTest(ABC): print(response) + # OpenAI guarantees that the JSON schema is returned in the content + # relevant issue: https://github.com/BerriAI/litellm/issues/6741 + assert response.choices[0].message.content is not None + + def test_json_response_format_stream(self): + """ + Test that the JSON response format with streaming is supported by the LLM API + """ + base_completion_call_args = self.get_base_completion_call_args() + litellm.set_verbose = True + + messages = [ + { + "role": "system", + "content": "Your output should be a JSON object with no additional properties. ", + }, + { + "role": "user", + "content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60", + }, + ] + + response = litellm.completion( + **base_completion_call_args, + messages=messages, + response_format={"type": "json_object"}, + stream=True, + ) + + print(response) + + content = "" + for chunk in response: + content += chunk.choices[0].delta.content or "" + + print("content=", content) + + # OpenAI guarantees that the JSON schema is returned in the content + # relevant issue: https://github.com/BerriAI/litellm/issues/6741 + # we need to assert that the JSON schema was returned in the content, (for Anthropic we were returning it as part of the tool call) + assert content is not None + assert len(content) > 0 + @pytest.fixture def pdf_messages(self): import base64 diff --git a/tests/llm_translation/base_rerank_unit_tests.py b/tests/llm_translation/base_rerank_unit_tests.py new file mode 100644 index 000000000..2a8b80194 --- /dev/null +++ b/tests/llm_translation/base_rerank_unit_tests.py @@ -0,0 +1,115 @@ +import asyncio +import httpx +import json +import pytest +import sys +from typing import Any, Dict, List +from unittest.mock import MagicMock, Mock, patch +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from litellm.exceptions import BadRequestError +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.utils import ( + CustomStreamWrapper, + get_supported_openai_params, + get_optional_params, +) + +# test_example.py +from abc import ABC, abstractmethod + + +def assert_response_shape(response, custom_llm_provider): + expected_response_shape = {"id": str, "results": list, "meta": dict} + + expected_results_shape = {"index": int, "relevance_score": float} + + expected_meta_shape = {"api_version": dict, "billed_units": dict} + + expected_api_version_shape = {"version": str} + + expected_billed_units_shape = {"search_units": int} 
+ + assert isinstance(response.id, expected_response_shape["id"]) + assert isinstance(response.results, expected_response_shape["results"]) + for result in response.results: + assert isinstance(result["index"], expected_results_shape["index"]) + assert isinstance( + result["relevance_score"], expected_results_shape["relevance_score"] + ) + assert isinstance(response.meta, expected_response_shape["meta"]) + + if custom_llm_provider == "cohere": + + assert isinstance( + response.meta["api_version"], expected_meta_shape["api_version"] + ) + assert isinstance( + response.meta["api_version"]["version"], + expected_api_version_shape["version"], + ) + assert isinstance( + response.meta["billed_units"], expected_meta_shape["billed_units"] + ) + assert isinstance( + response.meta["billed_units"]["search_units"], + expected_billed_units_shape["search_units"], + ) + + +class BaseLLMRerankTest(ABC): + """ + Abstract base test class that enforces a common test across all test classes. + """ + + @abstractmethod + def get_base_rerank_call_args(self) -> dict: + """Must return the base rerank call args""" + pass + + @abstractmethod + def get_custom_llm_provider(self) -> litellm.LlmProviders: + """Must return the custom llm provider""" + pass + + @pytest.mark.asyncio() + @pytest.mark.parametrize("sync_mode", [True, False]) + async def test_basic_rerank(self, sync_mode): + rerank_call_args = self.get_base_rerank_call_args() + custom_llm_provider = self.get_custom_llm_provider() + if sync_mode is True: + response = litellm.rerank( + **rerank_call_args, + query="hello", + documents=["hello", "world"], + top_n=3, + ) + + print("re rank response: ", response) + + assert response.id is not None + assert response.results is not None + + assert_response_shape( + response=response, custom_llm_provider=custom_llm_provider.value + ) + else: + response = await litellm.arerank( + **rerank_call_args, + query="hello", + documents=["hello", "world"], + top_n=3, + ) + + print("async re rank response: ", response) + + assert response.id is not None + assert response.results is not None + + assert_response_shape( + response=response, custom_llm_provider=custom_llm_provider.value + ) diff --git a/tests/llm_translation/test_anthropic_completion.py b/tests/llm_translation/test_anthropic_completion.py index c399c3a47..8a788e0fb 100644 --- a/tests/llm_translation/test_anthropic_completion.py +++ b/tests/llm_translation/test_anthropic_completion.py @@ -33,8 +33,10 @@ from litellm import ( ) from litellm.adapters.anthropic_adapter import anthropic_adapter from litellm.types.llms.anthropic import AnthropicResponse - +from litellm.types.utils import GenericStreamingChunk, ChatCompletionToolCallChunk +from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk from litellm.llms.anthropic.common_utils import process_anthropic_headers +from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion from httpx import Headers from base_llm_unit_tests import BaseLLMChatTest @@ -694,3 +696,91 @@ class TestAnthropicCompletion(BaseLLMChatTest): assert _document_validation["type"] == "document" assert _document_validation["source"]["media_type"] == "application/pdf" assert _document_validation["source"]["type"] == "base64" + + +def test_convert_tool_response_to_message_with_values(): + """Test converting a tool response with 'values' key to a message""" + tool_calls = [ + ChatCompletionToolCallChunk( + id="test_id", + type="function", + function=ChatCompletionToolCallFunctionChunk( + name="json_tool_call", + 
arguments='{"values": {"name": "John", "age": 30}}', + ), + index=0, + ) + ] + + message = AnthropicChatCompletion._convert_tool_response_to_message( + tool_calls=tool_calls + ) + + assert message is not None + assert message.content == '{"name": "John", "age": 30}' + + +def test_convert_tool_response_to_message_without_values(): + """ + Test converting a tool response without 'values' key to a message + + Anthropic API returns the JSON schema in the tool call, OpenAI Spec expects it in the message. This test ensures that the tool call is converted to a message correctly. + + Relevant issue: https://github.com/BerriAI/litellm/issues/6741 + """ + tool_calls = [ + ChatCompletionToolCallChunk( + id="test_id", + type="function", + function=ChatCompletionToolCallFunctionChunk( + name="json_tool_call", arguments='{"name": "John", "age": 30}' + ), + index=0, + ) + ] + + message = AnthropicChatCompletion._convert_tool_response_to_message( + tool_calls=tool_calls + ) + + assert message is not None + assert message.content == '{"name": "John", "age": 30}' + + +def test_convert_tool_response_to_message_invalid_json(): + """Test converting a tool response with invalid JSON""" + tool_calls = [ + ChatCompletionToolCallChunk( + id="test_id", + type="function", + function=ChatCompletionToolCallFunctionChunk( + name="json_tool_call", arguments="invalid json" + ), + index=0, + ) + ] + + message = AnthropicChatCompletion._convert_tool_response_to_message( + tool_calls=tool_calls + ) + + assert message is not None + assert message.content == "invalid json" + + +def test_convert_tool_response_to_message_no_arguments(): + """Test converting a tool response with no arguments""" + tool_calls = [ + ChatCompletionToolCallChunk( + id="test_id", + type="function", + function=ChatCompletionToolCallFunctionChunk(name="json_tool_call"), + index=0, + ) + ] + + message = AnthropicChatCompletion._convert_tool_response_to_message( + tool_calls=tool_calls + ) + + assert message is None diff --git a/tests/llm_translation/test_jina_ai.py b/tests/llm_translation/test_jina_ai.py new file mode 100644 index 000000000..c169b5587 --- /dev/null +++ b/tests/llm_translation/test_jina_ai.py @@ -0,0 +1,23 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + + +from base_rerank_unit_tests import BaseLLMRerankTest +import litellm + + +class TestJinaAI(BaseLLMRerankTest): + def get_custom_llm_provider(self) -> litellm.LlmProviders: + return litellm.LlmProviders.JINA_AI + + def get_base_rerank_call_args(self) -> dict: + return { + "model": "jina_ai/jina-reranker-v2-base-multilingual", + } diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py index 8677d6b73..c9527c830 100644 --- a/tests/llm_translation/test_optional_params.py +++ b/tests/llm_translation/test_optional_params.py @@ -923,9 +923,22 @@ def test_watsonx_text_top_k(): assert optional_params["top_k"] == 10 + def test_together_ai_model_params(): optional_params = get_optional_params( model="together_ai", custom_llm_provider="together_ai", logprobs=1 ) print(optional_params) assert optional_params["logprobs"] == 1 + +def test_forward_user_param(): + from litellm.utils import get_supported_openai_params, get_optional_params + + model = "claude-3-5-sonnet-20240620" + optional_params = get_optional_params( + model=model, + user="test_user", + custom_llm_provider="anthropic", + ) + + 
assert optional_params["metadata"]["user_id"] == "test_user" diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py index a06179a49..73960020d 100644 --- a/tests/llm_translation/test_vertex.py +++ b/tests/llm_translation/test_vertex.py @@ -16,6 +16,7 @@ import pytest import litellm from litellm import get_optional_params from litellm.llms.custom_httpx.http_handler import HTTPHandler +import httpx def test_completion_pydantic_obj_2(): @@ -1317,3 +1318,39 @@ def test_image_completion_request(image_url): mock_post.assert_called_once() print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"]) assert mock_post.call_args.kwargs["json"] == expected_request_body + + +@pytest.mark.parametrize( + "model, expected_url", + [ + ( + "textembedding-gecko@001", + "https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict", + ), + ( + "123456789", + "https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/endpoints/123456789:predict", + ), + ], +) +def test_vertex_embedding_url(model, expected_url): + """ + Test URL generation for embedding models, including numeric model IDs (fine-tuned models + + Relevant issue: https://github.com/BerriAI/litellm/issues/6482 + + When a fine-tuned embedding model is used, the URL is different from the standard one. + """ + from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import _get_vertex_url + + url, endpoint = _get_vertex_url( + mode="embedding", + model=model, + stream=False, + vertex_project="project-id", + vertex_location="us-central1", + vertex_api_version="v1", + ) + + assert url == expected_url + assert endpoint == "predict" diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 2de53696f..5a07d17b7 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -18,6 +18,8 @@ import json import os import tempfile from unittest.mock import AsyncMock, MagicMock, patch +from respx import MockRouter +import httpx import pytest @@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode): data = { "model": model, "messages": messages, + "timeout": 10, } if sync_mode: response = litellm.completion(**data) @@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode): assert isinstance(response._hidden_params["response_cost"], float) except litellm.RateLimitError as e: pass + except litellm.Timeout as e: + pass except litellm.InternalServerError as e: pass except Exception as e: @@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base): assert url == api_base + ":" else: assert url == test_endpoint + + +@pytest.mark.asyncio +@pytest.mark.respx +async def test_vertexai_embedding_finetuned(respx_mock: MockRouter): + """ + Tests that: + - Request URL and body are correctly formatted for Vertex AI embeddings + - Response is properly parsed into litellm's embedding response format + """ + load_vertex_ai_credentials() + litellm.set_verbose = True + + # Test input + input_text = ["good morning from litellm", "this is another item"] + + # Expected request/response + expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict" + expected_request = { + "instances": [ + {"inputs": "good morning from litellm"}, + {"inputs": "this is another 
item"}, + ], + "parameters": {}, + } + + mock_response = { + "predictions": [ + [[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector + [[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector + ], + "deployedModelId": "2275167734310371328", + "model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876", + "modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876", + "modelVersionId": "1", + } + + # Setup mock request + mock_request = respx_mock.post(expected_url).mock( + return_value=httpx.Response(200, json=mock_response) + ) + + # Make request + response = await litellm.aembedding( + vertex_project="633608382793", + model="vertex_ai/1004708436694269952", + input=input_text, + ) + + # Assert request was made correctly + assert mock_request.called + request_body = json.loads(mock_request.calls[0].request.content) + print("\n\nrequest_body", request_body) + print("\n\nexpected_request", expected_request) + assert request_body == expected_request + + # Assert response structure + assert response is not None + assert hasattr(response, "data") + assert len(response.data) == len(input_text) + + # Assert embedding structure + for embedding in response.data: + assert "embedding" in embedding + assert isinstance(embedding["embedding"], list) + assert len(embedding["embedding"]) > 0 + assert all(isinstance(x, float) for x in embedding["embedding"]) diff --git a/tests/local_testing/test_aws_secret_manager.py b/tests/local_testing/test_aws_secret_manager.py new file mode 100644 index 000000000..f2e2319cc --- /dev/null +++ b/tests/local_testing/test_aws_secret_manager.py @@ -0,0 +1,139 @@ +# What is this? + +import asyncio +import os +import sys +import traceback + +from dotenv import load_dotenv + +import litellm.types +import litellm.types.utils + + +load_dotenv() +import io + +import sys +import os + +# Ensure the project root is in the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) + +print("Python Path:", sys.path) +print("Current Working Directory:", os.getcwd()) + + +from typing import Optional +from unittest.mock import MagicMock, patch + +import pytest +import uuid +import json +from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2 + + +def check_aws_credentials(): + """Helper function to check if AWS credentials are set""" + required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"] + missing_vars = [var for var in required_vars if not os.getenv(var)] + if missing_vars: + pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}") + + +@pytest.mark.asyncio +async def test_write_and_read_simple_secret(): + """Test writing and reading a simple string secret""" + check_aws_credentials() + + secret_manager = AWSSecretsManagerV2() + test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}" + test_secret_value = "test_value_123" + + try: + # Write secret + write_response = await secret_manager.async_write_secret( + secret_name=test_secret_name, + secret_value=test_secret_value, + description="LiteLLM Test Secret", + ) + + print("Write Response:", write_response) + + assert write_response is not None + assert "ARN" in write_response + assert "Name" in write_response + assert write_response["Name"] == test_secret_name + + # Read secret back + read_value = await secret_manager.async_read_secret( + secret_name=test_secret_name + ) + + print("Read Value:", read_value) + + assert read_value == 
test_secret_value + finally: + # Cleanup: Delete the secret + delete_response = await secret_manager.async_delete_secret( + secret_name=test_secret_name + ) + print("Delete Response:", delete_response) + assert delete_response is not None + + +@pytest.mark.asyncio +async def test_write_and_read_json_secret(): + """Test writing and reading a JSON structured secret""" + check_aws_credentials() + + secret_manager = AWSSecretsManagerV2() + test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json" + test_secret_value = { + "api_key": "test_key", + "model": "gpt-4", + "temperature": 0.7, + "metadata": {"team": "ml", "project": "litellm"}, + } + + try: + # Write JSON secret + write_response = await secret_manager.async_write_secret( + secret_name=test_secret_name, + secret_value=json.dumps(test_secret_value), + description="LiteLLM JSON Test Secret", + ) + + print("Write Response:", write_response) + + # Read and parse JSON secret + read_value = await secret_manager.async_read_secret( + secret_name=test_secret_name + ) + parsed_value = json.loads(read_value) + + print("Read Value:", read_value) + + assert parsed_value == test_secret_value + assert parsed_value["api_key"] == "test_key" + assert parsed_value["metadata"]["team"] == "ml" + finally: + # Cleanup: Delete the secret + delete_response = await secret_manager.async_delete_secret( + secret_name=test_secret_name + ) + print("Delete Response:", delete_response) + assert delete_response is not None + + +@pytest.mark.asyncio +async def test_read_nonexistent_secret(): + """Test reading a secret that doesn't exist""" + check_aws_credentials() + + secret_manager = AWSSecretsManagerV2() + nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}" + + response = await secret_manager.async_read_secret(secret_name=nonexistent_secret) + + assert response is None diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 43bcfc882..3ce4cb7d7 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.prompt_templates.factory import anthropic_messages_pt -# litellm.num_retries = 3 +# litellm.num_retries=3 litellm.cache = None litellm.success_callback = [] diff --git a/tests/local_testing/test_cost_calc.py b/tests/local_testing/test_cost_calc.py index ecead0679..1831c2a45 100644 --- a/tests/local_testing/test_cost_calc.py +++ b/tests/local_testing/test_cost_calc.py @@ -10,7 +10,7 @@ import os sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system-path from typing import Literal import pytest diff --git a/tests/local_testing/test_custom_callback_input.py b/tests/local_testing/test_custom_callback_input.py index 1744d3891..9b7b6d532 100644 --- a/tests/local_testing/test_custom_callback_input.py +++ b/tests/local_testing/test_custom_callback_input.py @@ -1624,3 +1624,55 @@ async def test_standard_logging_payload_stream_usage(sync_mode): print(f"standard_logging_object usage: {built_response.usage}") except litellm.InternalServerError: pass + + +def test_standard_logging_retries(): + """ + know if a request was retried. 
+ """ + from litellm.types.utils import StandardLoggingPayload + from litellm.router import Router + + customHandler = CompletionCustomHandler() + litellm.callbacks = [customHandler] + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "openai/gpt-3.5-turbo", + "api_key": "test-api-key", + }, + } + ] + ) + + with patch.object( + customHandler, "log_failure_event", new=MagicMock() + ) as mock_client: + try: + router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + num_retries=1, + mock_response="litellm.RateLimitError", + ) + except litellm.RateLimitError: + pass + + assert mock_client.call_count == 2 + assert ( + mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][ + "trace_id" + ] + is not None + ) + assert ( + mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][ + "trace_id" + ] + == mock_client.call_args_list[1].kwargs["kwargs"][ + "standard_logging_object" + ]["trace_id"] + ) diff --git a/tests/local_testing/test_exceptions.py b/tests/local_testing/test_exceptions.py index d5f67cecf..67c36928f 100644 --- a/tests/local_testing/test_exceptions.py +++ b/tests/local_testing/test_exceptions.py @@ -58,6 +58,7 @@ async def test_content_policy_exception_azure(): except litellm.ContentPolicyViolationError as e: print("caught a content policy violation error! Passed") print("exception", e) + assert e.response is not None assert e.litellm_debug_info is not None assert isinstance(e.litellm_debug_info, str) assert len(e.litellm_debug_info) > 0 @@ -1152,3 +1153,24 @@ async def test_exception_with_headers_httpx( if exception_raised is False: print(resp) assert exception_raised + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model", ["azure/chatgpt-v-2", "openai/gpt-3.5-turbo"]) +async def test_bad_request_error_contains_httpx_response(model): + """ + Test that the BadRequestError contains the httpx response + + Relevant issue: https://github.com/BerriAI/litellm/issues/6732 + """ + try: + await litellm.acompletion( + model=model, + messages=[{"role": "user", "content": "Hello world"}], + bad_arg="bad_arg", + ) + pytest.fail("Expected to raise BadRequestError") + except litellm.BadRequestError as e: + print("e.response", e.response) + print("vars(e.response)", vars(e.response)) + assert e.response is not None diff --git a/tests/local_testing/test_get_llm_provider.py b/tests/local_testing/test_get_llm_provider.py index 6654c10c2..423ffe2fd 100644 --- a/tests/local_testing/test_get_llm_provider.py +++ b/tests/local_testing/test_get_llm_provider.py @@ -157,7 +157,7 @@ def test_get_llm_provider_jina_ai(): model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( model="jina_ai/jina-embeddings-v3", ) - assert custom_llm_provider == "openai_like" + assert custom_llm_provider == "jina_ai" assert api_base == "https://api.jina.ai/v1" assert model == "jina-embeddings-v3" diff --git a/tests/local_testing/test_get_model_info.py b/tests/local_testing/test_get_model_info.py index 82ce9c465..11506ed3d 100644 --- a/tests/local_testing/test_get_model_info.py +++ b/tests/local_testing/test_get_model_info.py @@ -89,11 +89,16 @@ def test_get_model_info_ollama_chat(): "template": "tools", } ), - ): + ) as mock_client: info = OllamaConfig().get_model_info("mistral") - print("info", info) assert info["supports_function_calling"] is True info = get_model_info("ollama/mistral") - print("info", info) + assert info["supports_function_calling"] is 
True + + mock_client.assert_called() + + print(mock_client.call_args.kwargs) + + assert mock_client.call_args.kwargs["json"]["name"] == "mistral" diff --git a/tests/local_testing/test_router_fallbacks.py b/tests/local_testing/test_router_fallbacks.py index cad640a54..3c9750691 100644 --- a/tests/local_testing/test_router_fallbacks.py +++ b/tests/local_testing/test_router_fallbacks.py @@ -1138,9 +1138,9 @@ async def test_router_content_policy_fallbacks( router = Router( model_list=[ { - "model_name": "claude-2", + "model_name": "claude-2.1", "litellm_params": { - "model": "claude-2", + "model": "claude-2.1", "api_key": "", "mock_response": mock_response, }, @@ -1164,7 +1164,7 @@ async def test_router_content_policy_fallbacks( { "model_name": "my-general-model", "litellm_params": { - "model": "claude-2", + "model": "claude-2.1", "api_key": "", "mock_response": Exception("Should not have called this."), }, @@ -1172,14 +1172,14 @@ async def test_router_content_policy_fallbacks( { "model_name": "my-context-window-model", "litellm_params": { - "model": "claude-2", + "model": "claude-2.1", "api_key": "", "mock_response": Exception("Should not have called this."), }, }, ], content_policy_fallbacks=( - [{"claude-2": ["my-fallback-model"]}] + [{"claude-2.1": ["my-fallback-model"]}] if fallback_type == "model-specific" else None ), @@ -1190,12 +1190,12 @@ async def test_router_content_policy_fallbacks( if sync_mode is True: response = router.completion( - model="claude-2", + model="claude-2.1", messages=[{"role": "user", "content": "Hey, how's it going?"}], ) else: response = await router.acompletion( - model="claude-2", + model="claude-2.1", messages=[{"role": "user", "content": "Hey, how's it going?"}], ) @@ -1455,3 +1455,46 @@ async def test_router_fallbacks_default_and_model_specific_fallbacks(sync_mode): assert isinstance( exc_info.value, litellm.AuthenticationError ), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}" + + +@pytest.mark.asyncio +async def test_router_disable_fallbacks_dynamically(): + from litellm.router import run_async_fallback + + router = Router( + model_list=[ + { + "model_name": "bad-model", + "litellm_params": { + "model": "openai/my-bad-model", + "api_key": "my-bad-api-key", + }, + }, + { + "model_name": "good-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + }, + ], + fallbacks=[{"bad-model": ["good-model"]}], + default_fallbacks=["good-model"], + ) + + with patch.object( + router, + "log_retry", + new=MagicMock(return_value=None), + ) as mock_client: + try: + resp = await router.acompletion( + model="bad-model", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + disable_fallbacks=True, + ) + print(resp) + except Exception as e: + print(e) + + mock_client.assert_not_called() diff --git a/tests/local_testing/test_router_utils.py b/tests/local_testing/test_router_utils.py index 538ab4d0b..d266cfbd9 100644 --- a/tests/local_testing/test_router_utils.py +++ b/tests/local_testing/test_router_utils.py @@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo from concurrent.futures import ThreadPoolExecutor from collections import defaultdict from dotenv import load_dotenv +from unittest.mock import patch, MagicMock, AsyncMock load_dotenv() @@ -83,3 +84,93 @@ def test_returned_settings(): except Exception: print(traceback.format_exc()) pytest.fail("An error occurred - " + traceback.format_exc()) + + +from litellm.types.utils import CallTypes + + +def 
test_update_kwargs_before_fallbacks_unit_test():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": "bad-key",
+                    "api_version": os.getenv("AZURE_API_VERSION"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                },
+            }
+        ],
+    )
+
+    kwargs = {"messages": [{"role": "user", "content": "write 1 sentence poem"}]}
+
+    router._update_kwargs_before_fallbacks(
+        model="gpt-3.5-turbo",
+        kwargs=kwargs,
+    )
+
+    assert kwargs["litellm_trace_id"] is not None
+
+
+@pytest.mark.parametrize(
+    "call_type",
+    [
+        CallTypes.acompletion,
+        CallTypes.atext_completion,
+        CallTypes.aembedding,
+        CallTypes.arerank,
+        CallTypes.atranscription,
+    ],
+)
+@pytest.mark.asyncio
+async def test_update_kwargs_before_fallbacks(call_type):
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": "bad-key",
+                    "api_version": os.getenv("AZURE_API_VERSION"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                },
+            }
+        ],
+    )
+
+    if call_type.value.startswith("a"):
+        with patch.object(router, "async_function_with_fallbacks") as mock_client:
+            if call_type.value == "acompletion":
+                input_kwarg = {
+                    "messages": [{"role": "user", "content": "Hello, how are you?"}],
+                }
+            elif (
+                call_type.value == "atext_completion"
+                or call_type.value == "aimage_generation"
+            ):
+                input_kwarg = {
+                    "prompt": "Hello, how are you?",
+                }
+            elif call_type.value == "aembedding" or call_type.value == "arerank":
+                input_kwarg = {
+                    "input": "Hello, how are you?",
+                }
+            elif call_type.value == "atranscription":
+                input_kwarg = {
+                    "file": "path/to/file",
+                }
+            else:
+                input_kwarg = {}
+
+            await getattr(router, call_type.value)(
+                model="gpt-3.5-turbo",
+                **input_kwarg,
+            )
+
+            mock_client.assert_called_once()
+
+            print(mock_client.call_args.kwargs)
+            assert mock_client.call_args.kwargs["litellm_trace_id"] is not None
diff --git a/tests/local_testing/test_secret_manager.py b/tests/local_testing/test_secret_manager.py
index 397128ecb..1b95119a3 100644
--- a/tests/local_testing/test_secret_manager.py
+++ b/tests/local_testing/test_secret_manager.py
@@ -15,22 +15,29 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest
-
+import litellm
 from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
 from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
-from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
-from litellm.secret_managers.main import get_secret
+from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
+from litellm.secret_managers.main import (
+    get_secret,
+    _should_read_secret_from_secret_manager,
+)


-@pytest.mark.skip(reason="AWS Suspended Account")
 def test_aws_secret_manager():
-    load_aws_secret_manager(use_aws_secret_manager=True)
+    import json
+
+    AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
     secret_val = get_secret("litellm_master_key")

     print(f"secret_val: {secret_val}")

-    assert secret_val == "sk-1234"
+    # cast json to dict
+    secret_val = json.loads(secret_val)
+
+    assert secret_val["litellm_master_key"] == "sk-1234"


 def redact_oidc_signature(secret_val):
@@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory():
     )
     print("secret_val: {}".format(secret_val))
     assert secret_val == "lite-llm"
+
+
+def test_should_read_secret_from_secret_manager():
+    """
+    Test that _should_read_secret_from_secret_manager returns correct values based on access mode
+    """
+    from litellm.proxy._types import KeyManagementSettings
+
+    # Test when secret manager client is None
+    litellm.secret_manager_client = None
+    litellm._key_management_settings = KeyManagementSettings()
+    assert _should_read_secret_from_secret_manager() is False
+
+    # Test with secret manager client and read_only access
+    litellm.secret_manager_client = "dummy_client"
+    litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
+    assert _should_read_secret_from_secret_manager() is True
+
+    # Test with secret manager client and read_and_write access
+    litellm._key_management_settings = KeyManagementSettings(
+        access_mode="read_and_write"
+    )
+    assert _should_read_secret_from_secret_manager() is True
+
+    # Test with secret manager client and write_only access
+    litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
+    assert _should_read_secret_from_secret_manager() is False
+
+    # Reset global variables
+    litellm.secret_manager_client = None
+    litellm._key_management_settings = KeyManagementSettings()
+
+
+def test_get_secret_with_access_mode():
+    """
+    Test that get_secret respects access mode settings
+    """
+    from litellm.proxy._types import KeyManagementSettings
+
+    # Set up test environment
+    test_secret_name = "TEST_SECRET_KEY"
+    test_secret_value = "test_secret_value"
+    os.environ[test_secret_name] = test_secret_value
+
+    # Test with write_only access (should read from os.environ)
+    litellm.secret_manager_client = "dummy_client"
+    litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
+    assert get_secret(test_secret_name) == test_secret_value
+
+    # Test with no KeyManagementSettings but secret_manager_client set
+    litellm.secret_manager_client = "dummy_client"
+    litellm._key_management_settings = KeyManagementSettings()
+    assert _should_read_secret_from_secret_manager() is True
+
+    # Test with read_only access
+    litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
+    assert _should_read_secret_from_secret_manager() is True
+
+    # Test with read_and_write access
+    litellm._key_management_settings = KeyManagementSettings(
+        access_mode="read_and_write"
+    )
+    assert _should_read_secret_from_secret_manager() is True
+
+    # Reset global variables
+    litellm.secret_manager_client = None
+    litellm._key_management_settings = KeyManagementSettings()
+    del os.environ[test_secret_name]
diff --git a/tests/local_testing/test_stream_chunk_builder.py b/tests/local_testing/test_stream_chunk_builder.py
index 2548abdb7..4fb44299d 100644
--- a/tests/local_testing/test_stream_chunk_builder.py
+++ b/tests/local_testing/test_stream_chunk_builder.py
@@ -184,12 +184,11 @@ def test_stream_chunk_builder_litellm_usage_chunks():
         {"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
         {"role": "user", "content": "\nI am waiting...\n\n...\n"},
     ]
-    # make a regular gemini call

     usage: litellm.Usage = Usage(
-        completion_tokens=64,
+        completion_tokens=27,
         prompt_tokens=55,
-        total_tokens=119,
+        total_tokens=82,
         completion_tokens_details=None,
         prompt_tokens_details=None,
     )
diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py
index bc4827d92..0bc6953f9 100644
--- a/tests/local_testing/test_streaming.py
+++ b/tests/local_testing/test_streaming.py
@@ -718,7 +718,7 @@ async def test_acompletion_claude_2_stream():
     try:
         litellm.set_verbose = True
         response = await litellm.acompletion(
-            model="claude-2",
+            model="claude-2.1",
             messages=[{"role": "user", "content": "hello from litellm"}],
             stream=True,
         )
@@ -3274,7 +3274,7 @@ def test_completion_claude_3_function_call_with_streaming():
     ],  # "claude-3-opus-20240229"
 )
 # @pytest.mark.asyncio
-async def test_acompletion_claude_3_function_call_with_streaming(model):
+async def test_acompletion_function_call_with_streaming(model):
     litellm.set_verbose = True
     tools = [
         {
@@ -3335,6 +3335,8 @@ async def test_acompletion_claude_3_function_call_with_streaming(model):
         # raise Exception("it worked! ")
     except litellm.InternalServerError as e:
         pytest.skip(f"InternalServerError - {str(e)}")
+    except litellm.ServiceUnavailableError:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py
index 78b558cd2..b97ab3514 100644
--- a/tests/proxy_unit_tests/test_key_generate_prisma.py
+++ b/tests/proxy_unit_tests/test_key_generate_prisma.py
@@ -3451,3 +3451,90 @@ async def test_user_api_key_auth_db_unavailable_not_allowed():
             request=request,
             api_key="Bearer sk-123456789",
         )
+
+
+## E2E Virtual Key + Secret Manager Tests #########################################
+
+
+@pytest.mark.asyncio
+async def test_key_generate_with_secret_manager_call(prisma_client):
+    """
+    Generate a key
+    assert it exists in the secret manager
+
+    delete the key
+    assert it is deleted from the secret manager
+    """
+    from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
+    from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
+
+    litellm.set_verbose = True
+
+    #### Test Setup ############################################################
+    aws_secret_manager_client = AWSSecretsManagerV2()
+    litellm.secret_manager_client = aws_secret_manager_client
+    litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
+    litellm._key_management_settings = KeyManagementSettings(
+        store_virtual_keys=True,
+    )
+    general_settings = {
+        "key_management_system": "aws_secret_manager",
+        "key_management_settings": {
+            "store_virtual_keys": True,
+        },
+    }
+
+    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    ############################################################################
+
+    # generate new key
+    key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}"
+    spend = 100
+    max_budget = 400
+    models = ["fake-openai-endpoint"]
+    new_key = await generate_key_fn(
+        data=GenerateKeyRequest(
+            key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
+        ),
+        user_api_key_dict=UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN,
+            api_key="sk-1234",
+            user_id="1234",
+        ),
+    )
+
+    generated_key = new_key.key
+    print(generated_key)
+
+    await asyncio.sleep(2)
+
+    # read from the secret manager
+    result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
+
+    # Assert the correct key is stored in the secret manager
+    print("response from AWS Secret Manager")
+    print(result)
+    assert result == generated_key
+
+    # delete the key
+    await delete_key_fn(
+        data=KeyRequest(keys=[generated_key]),
+        user_api_key_dict=UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
+        ),
+    )
+
+    await asyncio.sleep(2)
+
+    # Assert the key is deleted from the secret manager
+    result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
+    assert result is None
+
+    # cleanup
+    setattr(litellm.proxy.proxy_server, "general_settings", {})
+
+
+################################################################################
diff --git a/tests/proxy_unit_tests/test_proxy_server.py b/tests/proxy_unit_tests/test_proxy_server.py
index 5588d0414..b1c00ce75 100644
--- a/tests/proxy_unit_tests/test_proxy_server.py
+++ b/tests/proxy_unit_tests/test_proxy_server.py
@@ -1500,6 +1500,31 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
     assert new_data["failure_callback"] == expected_failure_callbacks


+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "disable_fallbacks_set",
+    [
+        True,
+        False,
+    ],
+)
+async def test_disable_fallbacks_by_key(disable_fallbacks_set):
+    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
+
+    key_metadata = {"disable_fallbacks": disable_fallbacks_set}
+    existing_data = {
+        "model": "azure/chatgpt-v-2",
+        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
+    }
+    data = LiteLLMProxyRequestSetup.add_key_level_controls(
+        key_metadata=key_metadata,
+        data=existing_data,
+        _metadata_variable_name="metadata",
+    )
+
+    assert data["disable_fallbacks"] == disable_fallbacks_set
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "callback_type, expected_success_callbacks, expected_failure_callbacks",
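
The secret-manager tests above pin down a small piece of gating logic: reads only go to the secret manager when a client is configured and the access mode permits reads. Below is a minimal standalone sketch of that behavior for readers skimming the tests; it is illustrative only, not LiteLLM's actual implementation, and the `KeyManagementSettingsStub` class, the `should_read_secret_from_secret_manager` signature, and the default access mode are assumptions inferred from the assertions above.

```python
# Illustrative re-statement of the behavior asserted in
# test_should_read_secret_from_secret_manager -- not library code.
from dataclasses import dataclass
from typing import Optional


@dataclass
class KeyManagementSettingsStub:
    # "read_only" | "read_and_write" | "write_only"; default inferred from the tests,
    # where a default settings object with a client set permits reads.
    access_mode: str = "read_only"


def should_read_secret_from_secret_manager(
    secret_manager_client: Optional[object],
    settings: KeyManagementSettingsStub,
) -> bool:
    """Reads hit the secret manager only if a client exists and the mode allows reads."""
    if secret_manager_client is None:
        return False
    return settings.access_mode in ("read_only", "read_and_write")


# Mirrors the assertions in the test above:
assert should_read_secret_from_secret_manager(None, KeyManagementSettingsStub()) is False
assert should_read_secret_from_secret_manager("client", KeyManagementSettingsStub("read_only")) is True
assert should_read_secret_from_secret_manager("client", KeyManagementSettingsStub("read_and_write")) is True
assert should_read_secret_from_secret_manager("client", KeyManagementSettingsStub("write_only")) is False
```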
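Likewise, `test_disable_fallbacks_by_key` only asserts that a key-level `disable_fallbacks` flag in key metadata is copied onto the outgoing request payload. A rough stand-in for that merge step is sketched below; `merge_key_level_controls` is a hypothetical helper written for this example and the real `add_key_level_controls` handles more than this one flag.

```python
# Hypothetical sketch of the key-level control merge exercised by
# test_disable_fallbacks_by_key; names and scope are illustrative only.
from typing import Any, Dict


def merge_key_level_controls(
    key_metadata: Dict[str, Any], data: Dict[str, Any]
) -> Dict[str, Any]:
    """Copy supported per-key controls from key metadata onto the request payload."""
    merged = dict(data)
    if "disable_fallbacks" in key_metadata:
        merged["disable_fallbacks"] = key_metadata["disable_fallbacks"]
    return merged


request_data = {
    "model": "azure/chatgpt-v-2",
    "messages": [{"role": "user", "content": "write 1 sentence poem"}],
}
out = merge_key_level_controls({"disable_fallbacks": True}, request_data)
assert out["disable_fallbacks"] is True
```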