diff --git a/docs/my-website/docs/enterprise.md b/docs/my-website/docs/enterprise.md
index e3758266a1..5bd09ec156 100644
--- a/docs/my-website/docs/enterprise.md
+++ b/docs/my-website/docs/enterprise.md
@@ -20,6 +20,8 @@ This covers:
- **Spend Tracking**
- ✅ [Tracking Spend for Custom Tags](./proxy/enterprise#tracking-spend-for-custom-tags)
- ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](./proxy/cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend)
+ - **Advanced Metrics**
+ - ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](./proxy/prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens)
- **Guardrails, PII Masking, Content Moderation**
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](./proxy/enterprise#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](./proxy/enterprise#prompt-injection-detection---lakeraai)
diff --git a/docs/my-website/docs/proxy/debugging.md b/docs/my-website/docs/proxy/debugging.md
index 571a97c0ec..38680982a3 100644
--- a/docs/my-website/docs/proxy/debugging.md
+++ b/docs/my-website/docs/proxy/debugging.md
@@ -88,4 +88,31 @@ Expected Output:
```bash
# no info statements
-```
\ No newline at end of file
+```
+
+## Common Errors
+
+1. "No available deployments..."
+
+```
+No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a.
+```
+
+This can happen when all the deployments for a model hit rate limit errors, triggering the cooldown.
+
+How to control this?
+- Adjust the cooldown time
+
+```yaml
+router_settings:
+ cooldown_time: 0 # 👈 KEY CHANGE
+```
+
+- Disable Cooldowns [NOT RECOMMENDED]
+
+```yaml
+router_settings:
+ disable_cooldowns: True
+```
+
+This is not recommended, as it will lead to requests being routed to deployments that are over their tpm/rpm limits.
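+
+If you use the Python SDK instead of the proxy, the same knobs are available on the `Router` constructor. A minimal sketch (the deployment entry below is illustrative):
+
+```python
+from litellm import Router
+
+router = Router(
+    model_list=[
+        {
+            "model_name": "claude-3-5-sonnet",
+            "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20240620"},  # illustrative deployment
+        }
+    ],
+    cooldown_time=0,          # seconds a deployment stays cooled down after failures
+    # disable_cooldowns=True, # NOT RECOMMENDED - disables cooldowns entirely
+)
+```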
\ No newline at end of file
diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md
index e061a917e2..5dabba5ed3 100644
--- a/docs/my-website/docs/proxy/enterprise.md
+++ b/docs/my-website/docs/proxy/enterprise.md
@@ -23,6 +23,8 @@ Features:
- **Spend Tracking**
- ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags)
- ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend)
+- **Advanced Metrics**
+ - ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens)
- **Guardrails, PII Masking, Content Moderation**
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md
index 2c7481f4c6..6790b25b02 100644
--- a/docs/my-website/docs/proxy/prometheus.md
+++ b/docs/my-website/docs/proxy/prometheus.md
@@ -1,3 +1,6 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
# 📈 Prometheus metrics [BETA]
LiteLLM Exposes a `/metrics` endpoint for Prometheus to Poll
@@ -61,6 +64,56 @@ http://localhost:4000/metrics
| `litellm_remaining_api_key_budget_metric` | Remaining Budget for API Key (A key Created on LiteLLM)|
+### ✨ (Enterprise) LLM Remaining Requests and Remaining Tokens
+Add this to your config.yaml to track how close you are to hitting the TPM / RPM limits on each model group.
+
+```yaml
+litellm_settings:
+ success_callback: ["prometheus"]
+ failure_callback: ["prometheus"]
+ return_response_headers: true # ensures the LLM API calls track the response headers
+```
+
+| Metric Name | Description |
+|----------------------|--------------------------------------|
+| `litellm_remaining_requests` | Track `x-ratelimit-remaining-requests` returned from the LLM API deployment |
+| `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` returned from the LLM API deployment |
+
+Example Metric
+
+<Tabs>
+
+<TabItem value="remaining_requests" label="Remaining Requests">
+
+```shell
+litellm_remaining_requests
+{
+ api_base="https://api.openai.com/v1",
+ api_provider="openai",
+ litellm_model_name="gpt-3.5-turbo",
+ model_group="gpt-3.5-turbo"
+}
+8998.0
+```
+
+</TabItem>
+
+<TabItem value="remaining_tokens" label="Remaining Tokens">
+
+```shell
+litellm_remaining_tokens
+{
+ api_base="https://api.openai.com/v1",
+ api_provider="openai",
+ litellm_model_name="gpt-3.5-turbo",
+ model_group="gpt-3.5-turbo"
+}
+999981.0
+```
+
+</TabItem>
+
+</Tabs>
+
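+To consume these gauges outside of Prometheus, you can scrape the proxy's `/metrics` endpoint directly. A minimal sketch, assuming the proxy runs on `localhost:4000` and `prometheus_client` is installed for parsing:
+
+```python
+import requests
+from prometheus_client.parser import text_string_to_metric_families
+
+# scrape the LiteLLM proxy metrics endpoint
+metrics_text = requests.get("http://localhost:4000/metrics").text
+
+for family in text_string_to_metric_families(metrics_text):
+    if family.name in ("litellm_remaining_requests", "litellm_remaining_tokens"):
+        for sample in family.samples:
+            # labels: model_group, api_provider, api_base, litellm_model_name
+            print(sample.name, sample.labels, sample.value)
+```
+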
## Monitor System Health
To monitor the health of litellm adjacent services (redis / postgres), do:
diff --git a/docs/my-website/docs/proxy/user_keys.md b/docs/my-website/docs/proxy/user_keys.md
index cda3a46af9..cc1d5fe821 100644
--- a/docs/my-website/docs/proxy/user_keys.md
+++ b/docs/my-website/docs/proxy/user_keys.md
@@ -152,6 +152,27 @@ response = chat(messages)
print(response)
```
+</TabItem>
+
+<TabItem value="langchain js" label="Langchain JS">
+
+```js
+import { ChatOpenAI } from "@langchain/openai";
+
+
+const model = new ChatOpenAI({
+ modelName: "gpt-4",
+ openAIApiKey: "sk-1234",
+ modelKwargs: {"metadata": "hello world"} // 👈 PASS Additional params here
+}, {
+ basePath: "http://0.0.0.0:4000",
+});
+
+const message = await model.invoke("Hi there!");
+
+console.log(message);
+
+```
+
diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md
index 240e6c8e04..905954e979 100644
--- a/docs/my-website/docs/routing.md
+++ b/docs/my-website/docs/routing.md
@@ -815,6 +815,35 @@ model_list:
+**Expected Response**
+
+```
+No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a.
+```
+
+#### **Disable cooldowns**
+
+<Tabs>
+
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import Router
+
+
+router = Router(..., disable_cooldowns=True)
+```
+
+</TabItem>
+
+<TabItem value="proxy" label="PROXY">
+
+```yaml
+router_settings:
+ disable_cooldowns: True
+```
+
+</TabItem>
+
+</Tabs>
+
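+To see cooldowns end to end, you can force rate limit errors with a mock deployment. A rough sketch, assuming `mock_response="litellm.RateLimitError"` support in your litellm version:
+
+```python
+import asyncio
+
+from litellm import Router
+
+router = Router(
+    model_list=[
+        {
+            "model_name": "claude-3-5-sonnet",
+            "litellm_params": {
+                "model": "anthropic/claude-3-5-sonnet-20240620",  # illustrative deployment
+                "mock_response": "litellm.RateLimitError",  # every call raises a mock 429
+            },
+        }
+    ],
+    allowed_fails=0,  # cool the deployment down after the first failure
+    num_retries=0,
+)
+
+
+async def main():
+    for _ in range(2):
+        try:
+            await router.acompletion(
+                model="claude-3-5-sonnet",
+                messages=[{"role": "user", "content": "hi"}],
+            )
+        except Exception as e:
+            # the second call should surface "No deployments available ..."
+            # while the deployment is cooling down
+            print(type(e).__name__, str(e)[:120])
+
+
+asyncio.run(main())
+```
+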
### Retries
For both async + sync functions, we support retrying failed requests.
diff --git a/litellm/__init__.py b/litellm/__init__.py
index cc31ea9990..29b5bc360a 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -125,6 +125,9 @@ llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
##################
### PREVIEW FEATURES ###
enable_preview_features: bool = False
+return_response_headers: bool = (
+    False  # get response headers from LLM API providers - example x-ratelimit-remaining-requests
+)
##################
logging: bool = True
caching: bool = (
diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py
index 4f0ffa387e..6cd7469079 100644
--- a/litellm/integrations/prometheus.py
+++ b/litellm/integrations/prometheus.py
@@ -2,14 +2,20 @@
#### What this does ####
# On success, log events to Prometheus
-import dotenv, os
-import requests # type: ignore
+import datetime
+import os
+import subprocess
+import sys
import traceback
-import datetime, subprocess, sys
-import litellm, uuid
-from litellm._logging import print_verbose, verbose_logger
+import uuid
from typing import Optional, Union
+import dotenv
+import requests # type: ignore
+
+import litellm
+from litellm._logging import print_verbose, verbose_logger
+
class PrometheusLogger:
# Class variables or attributes
@@ -20,6 +26,8 @@ class PrometheusLogger:
try:
from prometheus_client import Counter, Gauge
+ from litellm.proxy.proxy_server import premium_user
+
self.litellm_llm_api_failed_requests_metric = Counter(
name="litellm_llm_api_failed_requests_metric",
documentation="Total number of failed LLM API calls via litellm",
@@ -88,6 +96,31 @@ class PrometheusLogger:
labelnames=["hashed_api_key", "api_key_alias"],
)
+ # Litellm-Enterprise Metrics
+ if premium_user is True:
+ # Remaining Rate Limit for model
+ self.litellm_remaining_requests_metric = Gauge(
+ "litellm_remaining_requests",
+ "remaining requests for model, returned from LLM API Provider",
+ labelnames=[
+ "model_group",
+ "api_provider",
+ "api_base",
+ "litellm_model_name",
+ ],
+ )
+
+ self.litellm_remaining_tokens_metric = Gauge(
+ "litellm_remaining_tokens",
+ "remaining tokens for model, returned from LLM API Provider",
+ labelnames=[
+ "model_group",
+ "api_provider",
+ "api_base",
+ "litellm_model_name",
+ ],
+ )
+
except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}")
raise e
@@ -104,6 +137,8 @@ class PrometheusLogger:
):
try:
# Define prometheus client
+ from litellm.proxy.proxy_server import premium_user
+
verbose_logger.debug(
f"prometheus Logging - Enters logging function for model {kwargs}"
)
@@ -199,6 +234,10 @@ class PrometheusLogger:
user_api_key, user_api_key_alias
).set(_remaining_api_key_budget)
+ # set x-ratelimit headers
+ if premium_user is True:
+ self.set_remaining_tokens_requests_metric(kwargs)
+
### FAILURE INCREMENT ###
if "exception" in kwargs:
self.litellm_llm_api_failed_requests_metric.labels(
@@ -216,6 +255,58 @@ class PrometheusLogger:
verbose_logger.debug(traceback.format_exc())
pass
+ def set_remaining_tokens_requests_metric(self, request_kwargs: dict):
+ try:
+ verbose_logger.debug("setting remaining tokens requests metric")
+ _response_headers = request_kwargs.get("response_headers")
+ _litellm_params = request_kwargs.get("litellm_params", {}) or {}
+ _metadata = _litellm_params.get("metadata", {})
+ litellm_model_name = request_kwargs.get("model", None)
+ model_group = _metadata.get("model_group", None)
+ api_base = _metadata.get("api_base", None)
+ llm_provider = _litellm_params.get("custom_llm_provider", None)
+
+ remaining_requests = None
+ remaining_tokens = None
+ # OpenAI / OpenAI Compatible headers
+ if (
+ _response_headers
+ and "x-ratelimit-remaining-requests" in _response_headers
+ ):
+ remaining_requests = _response_headers["x-ratelimit-remaining-requests"]
+ if (
+ _response_headers
+ and "x-ratelimit-remaining-tokens" in _response_headers
+ ):
+ remaining_tokens = _response_headers["x-ratelimit-remaining-tokens"]
+ verbose_logger.debug(
+ f"remaining requests: {remaining_requests}, remaining tokens: {remaining_tokens}"
+ )
+
+ if remaining_requests:
+                # label order: model_group, api_provider, api_base, litellm_model_name
+ self.litellm_remaining_requests_metric.labels(
+ model_group, llm_provider, api_base, litellm_model_name
+ ).set(remaining_requests)
+
+ if remaining_tokens:
+ self.litellm_remaining_tokens_metric.labels(
+ model_group, llm_provider, api_base, litellm_model_name
+ ).set(remaining_tokens)
+
+ except Exception as e:
+ verbose_logger.error(
+                "Prometheus Error: set_remaining_tokens_requests_metric. Exception occurred - {}".format(
+ str(e)
+ )
+ )
+ return
+
def safe_get_remaining_budget(
max_budget: Optional[float], spend: Optional[float]
diff --git a/litellm/integrations/slack_alerting.py b/litellm/integrations/slack_alerting.py
index bce0fef8cd..04195705a0 100644
--- a/litellm/integrations/slack_alerting.py
+++ b/litellm/integrations/slack_alerting.py
@@ -606,6 +606,13 @@ class SlackAlerting(CustomLogger):
and request_data.get("litellm_status", "") != "success"
and request_data.get("litellm_status", "") != "fail"
):
+ ## CHECK IF CACHE IS UPDATED
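+            # request status is also recorded in the internal usage cache
+            # (see ProxyLogging.update_request_status) - skip the hanging-request
+            # alert if this call has already completed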
+ litellm_call_id = request_data.get("litellm_call_id", "")
+ status: Optional[str] = await self.internal_usage_cache.async_get_cache(
+ key="request_status:{}".format(litellm_call_id), local_only=True
+ )
+ if status is not None and (status == "success" or status == "fail"):
+ return
if request_data.get("deployment", None) is not None and isinstance(
request_data["deployment"], dict
):
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index e127ecea6a..8932e44941 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -1,6 +1,7 @@
import asyncio
import json
import os
+import time
import types
import uuid
from typing import (
@@ -21,8 +22,10 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
from typing_extensions import overload
import litellm
-from litellm import OpenAIConfig
+from litellm import ImageResponse, OpenAIConfig
from litellm.caching import DualCache
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import (
Choices,
CustomStreamWrapper,
@@ -32,6 +35,7 @@ from litellm.utils import (
UnsupportedParamsError,
convert_to_model_response_object,
get_secret,
+ modify_url,
)
from ..types.llms.openai import (
@@ -458,6 +462,36 @@ class AzureChatCompletion(BaseLLM):
return azure_client
+ async def make_azure_openai_chat_completion_request(
+ self,
+ azure_client: AsyncAzureOpenAI,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ ):
+ """
+ Helper to:
+ - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
+ - call chat.completions.create by default
+ """
+ try:
+ if litellm.return_response_headers is True:
+ raw_response = (
+ await azure_client.chat.completions.with_raw_response.create(
+ **data, timeout=timeout
+ )
+ )
+
+ headers = dict(raw_response.headers)
+ response = raw_response.parse()
+ return headers, response
+ else:
+ response = await azure_client.chat.completions.create(
+ **data, timeout=timeout
+ )
+ return None, response
+ except Exception as e:
+ raise e
+
def completion(
self,
model: str,
@@ -470,7 +504,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: str,
print_verbose: Callable,
timeout: Union[float, httpx.Timeout],
- logging_obj,
+ logging_obj: LiteLLMLoggingObj,
optional_params,
litellm_params,
logger_fn,
@@ -649,9 +683,9 @@ class AzureChatCompletion(BaseLLM):
data: dict,
timeout: Any,
model_response: ModelResponse,
+ logging_obj: LiteLLMLoggingObj,
azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI
- logging_obj=None,
):
response = None
try:
@@ -701,9 +735,13 @@ class AzureChatCompletion(BaseLLM):
"complete_input_dict": data,
},
)
- response = await azure_client.chat.completions.create(
- **data, timeout=timeout
+
+ headers, response = await self.make_azure_openai_chat_completion_request(
+ azure_client=azure_client,
+ data=data,
+ timeout=timeout,
)
+ logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
logging_obj.post_call(
@@ -812,7 +850,7 @@ class AzureChatCompletion(BaseLLM):
async def async_streaming(
self,
- logging_obj,
+ logging_obj: LiteLLMLoggingObj,
api_base: str,
api_key: str,
api_version: str,
@@ -861,9 +899,14 @@ class AzureChatCompletion(BaseLLM):
"complete_input_dict": data,
},
)
- response = await azure_client.chat.completions.create(
- **data, timeout=timeout
+
+ headers, response = await self.make_azure_openai_chat_completion_request(
+ azure_client=azure_client,
+ data=data,
+ timeout=timeout,
)
+ logging_obj.model_call_details["response_headers"] = headers
+
# return response
streamwrapper = CustomStreamWrapper(
completion_stream=response,
@@ -1011,6 +1054,234 @@ class AzureChatCompletion(BaseLLM):
else:
raise AzureOpenAIError(status_code=500, message=str(e))
+ async def make_async_azure_httpx_request(
+ self,
+ client: Optional[AsyncHTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ api_base: str,
+ api_version: str,
+ api_key: str,
+ data: dict,
+ ) -> httpx.Response:
+ """
+ Implemented for azure dall-e-2 image gen calls
+
+ Alternative to needing a custom transport implementation
+ """
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ async_handler = AsyncHTTPHandler(**_params) # type: ignore
+ else:
+ async_handler = client # type: ignore
+
+ if (
+ "images/generations" in api_base
+ and api_version
+ in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+ "2023-06-01-preview",
+ "2023-07-01-preview",
+ "2023-08-01-preview",
+ "2023-09-01-preview",
+ "2023-10-01-preview",
+ ]
+ ): # CREATE + POLL for azure dall-e-2 calls
+
+ api_base = modify_url(
+ original_url=api_base, new_path="/openai/images/generations:submit"
+ )
+
+ data.pop(
+ "model", None
+ ) # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+ response = await async_handler.post(
+ url=api_base,
+ data=json.dumps(data),
+ headers={
+ "Content-Type": "application/json",
+ "api-key": api_key,
+ },
+ )
+ operation_location_url = response.headers["operation-location"]
+ response = await async_handler.get(
+ url=operation_location_url,
+ headers={
+ "api-key": api_key,
+ },
+ )
+
+ await response.aread()
+
+ timeout_secs: int = 120
+ start_time = time.time()
+ if "status" not in response.json():
+ raise Exception(
+ "Expected 'status' in response. Got={}".format(response.json())
+ )
+ while response.json()["status"] not in ["succeeded", "failed"]:
+ if time.time() - start_time > timeout_secs:
+ raise AzureOpenAIError(
+ status_code=408, message="Operation polling timed out."
+ )
+
+ await asyncio.sleep(int(response.headers.get("retry-after") or 10))
+ response = await async_handler.get(
+ url=operation_location_url,
+ headers={
+ "api-key": api_key,
+ },
+ )
+ await response.aread()
+
+ if response.json()["status"] == "failed":
+ error_data = response.json()
+ raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+ return response
+ return await async_handler.post(
+ url=api_base,
+ json=data,
+ headers={
+ "Content-Type": "application/json;",
+ "api-key": api_key,
+ },
+ )
+
+ def make_sync_azure_httpx_request(
+ self,
+ client: Optional[HTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ api_base: str,
+ api_version: str,
+ api_key: str,
+ data: dict,
+ ) -> httpx.Response:
+ """
+ Implemented for azure dall-e-2 image gen calls
+
+ Alternative to needing a custom transport implementation
+ """
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ sync_handler = HTTPHandler(**_params) # type: ignore
+ else:
+ sync_handler = client # type: ignore
+
+ if (
+ "images/generations" in api_base
+ and api_version
+ in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+ "2023-06-01-preview",
+ "2023-07-01-preview",
+ "2023-08-01-preview",
+ "2023-09-01-preview",
+ "2023-10-01-preview",
+ ]
+ ): # CREATE + POLL for azure dall-e-2 calls
+
+ api_base = modify_url(
+ original_url=api_base, new_path="/openai/images/generations:submit"
+ )
+
+ data.pop(
+ "model", None
+ ) # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+ response = sync_handler.post(
+ url=api_base,
+ data=json.dumps(data),
+ headers={
+ "Content-Type": "application/json",
+ "api-key": api_key,
+ },
+ )
+ operation_location_url = response.headers["operation-location"]
+ response = sync_handler.get(
+ url=operation_location_url,
+ headers={
+ "api-key": api_key,
+ },
+ )
+
+ response.read()
+
+ timeout_secs: int = 120
+ start_time = time.time()
+ if "status" not in response.json():
+ raise Exception(
+ "Expected 'status' in response. Got={}".format(response.json())
+ )
+ while response.json()["status"] not in ["succeeded", "failed"]:
+ if time.time() - start_time > timeout_secs:
+ raise AzureOpenAIError(
+ status_code=408, message="Operation polling timed out."
+ )
+
+ time.sleep(int(response.headers.get("retry-after") or 10))
+ response = sync_handler.get(
+ url=operation_location_url,
+ headers={
+ "api-key": api_key,
+ },
+ )
+ response.read()
+
+ if response.json()["status"] == "failed":
+ error_data = response.json()
+ raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+ return response
+ return sync_handler.post(
+ url=api_base,
+ json=data,
+ headers={
+ "Content-Type": "application/json;",
+ "api-key": api_key,
+ },
+ )
+
+ def create_azure_base_url(
+ self, azure_client_params: dict, model: Optional[str]
+ ) -> str:
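+        """
+        Construct the Azure image generation endpoint:
+
+        {api_base}/openai/deployments/{model}/images/generations?api-version={api_version}
+        """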
+
+ api_base: str = azure_client_params.get(
+ "azure_endpoint", ""
+ ) # "https://example-endpoint.openai.azure.com"
+ if api_base.endswith("/"):
+ api_base = api_base.rstrip("/")
+ api_version: str = azure_client_params.get("api_version", "")
+ if model is None:
+ model = ""
+ new_api_base = (
+ api_base
+ + "/openai/deployments/"
+ + model
+ + "/images/generations"
+ + "?api-version="
+ + api_version
+ )
+
+ return new_api_base
+
async def aimage_generation(
self,
data: dict,
@@ -1022,30 +1293,40 @@ class AzureChatCompletion(BaseLLM):
logging_obj=None,
timeout=None,
):
- response = None
+ response: Optional[dict] = None
try:
- if client is None:
- client_session = litellm.aclient_session or httpx.AsyncClient(
- transport=AsyncCustomHTTPTransport(),
- )
- azure_client = AsyncAzureOpenAI(
- http_client=client_session, **azure_client_params
- )
- else:
- azure_client = client
+ # response = await azure_client.images.generate(**data, timeout=timeout)
+ api_base: str = azure_client_params.get(
+ "api_base", ""
+ ) # "https://example-endpoint.openai.azure.com"
+ if api_base.endswith("/"):
+ api_base = api_base.rstrip("/")
+ api_version: str = azure_client_params.get("api_version", "")
+ img_gen_api_base = self.create_azure_base_url(
+ azure_client_params=azure_client_params, model=data.get("model", "")
+ )
+
## LOGGING
logging_obj.pre_call(
input=data["prompt"],
- api_key=azure_client.api_key,
+ api_key=api_key,
additional_args={
- "headers": {"api_key": azure_client.api_key},
- "api_base": azure_client._base_url._uri_reference,
- "acompletion": True,
"complete_input_dict": data,
+ "api_base": img_gen_api_base,
+ "headers": {"api_key": api_key},
},
)
- response = await azure_client.images.generate(**data, timeout=timeout)
- stringified_response = response.model_dump()
+ httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
+ client=None,
+ timeout=timeout,
+ api_base=img_gen_api_base,
+ api_version=api_version,
+ api_key=api_key,
+ data=data,
+ )
+ response = httpx_response.json()["result"]
+
+ stringified_response = response
## LOGGING
logging_obj.post_call(
input=input,
@@ -1128,28 +1409,30 @@ class AzureChatCompletion(BaseLLM):
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
return response
- if client is None:
- client_session = litellm.client_session or httpx.Client(
- transport=CustomHTTPTransport(),
- )
- azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
- else:
- azure_client = client
+ img_gen_api_base = self.create_azure_base_url(
+ azure_client_params=azure_client_params, model=data.get("model", "")
+ )
## LOGGING
logging_obj.pre_call(
- input=prompt,
- api_key=azure_client.api_key,
+ input=data["prompt"],
+ api_key=api_key,
additional_args={
- "headers": {"api_key": azure_client.api_key},
- "api_base": azure_client._base_url._uri_reference,
- "acompletion": False,
"complete_input_dict": data,
+ "api_base": img_gen_api_base,
+ "headers": {"api_key": api_key},
},
)
+ httpx_response: httpx.Response = self.make_sync_azure_httpx_request(
+ client=None,
+ timeout=timeout,
+ api_base=img_gen_api_base,
+ api_version=api_version or "",
+ api_key=api_key or "",
+ data=data,
+ )
+ response = httpx_response.json()["result"]
- ## COMPLETION CALL
- response = azure_client.images.generate(**data, timeout=timeout) # type: ignore
## LOGGING
logging_obj.post_call(
input=prompt,
@@ -1158,7 +1441,7 @@ class AzureChatCompletion(BaseLLM):
original_response=response,
)
# return response
- return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation") # type: ignore
+ return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 32e63b9576..990ef2faeb 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -21,6 +21,7 @@ from pydantic import BaseModel
from typing_extensions import overload, override
import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import ProviderField
from litellm.utils import (
Choices,
@@ -652,6 +653,36 @@ class OpenAIChatCompletion(BaseLLM):
else:
return client
+ async def make_openai_chat_completion_request(
+ self,
+ openai_aclient: AsyncOpenAI,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ ):
+ """
+ Helper to:
+ - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
+ - call chat.completions.create by default
+ """
+ try:
+ if litellm.return_response_headers is True:
+ raw_response = (
+ await openai_aclient.chat.completions.with_raw_response.create(
+ **data, timeout=timeout
+ )
+ )
+
+ headers = dict(raw_response.headers)
+ response = raw_response.parse()
+ return headers, response
+ else:
+ response = await openai_aclient.chat.completions.create(
+ **data, timeout=timeout
+ )
+ return None, response
+ except Exception as e:
+ raise e
+
def completion(
self,
model_response: ModelResponse,
@@ -836,13 +867,13 @@ class OpenAIChatCompletion(BaseLLM):
self,
data: dict,
model_response: ModelResponse,
+ logging_obj: LiteLLMLoggingObj,
timeout: Union[float, httpx.Timeout],
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
- logging_obj=None,
headers=None,
):
response = None
@@ -869,8 +900,8 @@ class OpenAIChatCompletion(BaseLLM):
},
)
- response = await openai_aclient.chat.completions.create(
- **data, timeout=timeout
+ headers, response = await self.make_openai_chat_completion_request(
+ openai_aclient=openai_aclient, data=data, timeout=timeout
)
stringified_response = response.model_dump()
logging_obj.post_call(
@@ -879,9 +910,11 @@ class OpenAIChatCompletion(BaseLLM):
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
+ logging_obj.model_call_details["response_headers"] = headers
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
+ hidden_params={"headers": headers},
)
except Exception as e:
raise e
@@ -931,10 +964,10 @@ class OpenAIChatCompletion(BaseLLM):
async def async_streaming(
self,
- logging_obj,
timeout: Union[float, httpx.Timeout],
data: dict,
model: str,
+ logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
@@ -965,9 +998,10 @@ class OpenAIChatCompletion(BaseLLM):
},
)
- response = await openai_aclient.chat.completions.create(
- **data, timeout=timeout
+ headers, response = await self.make_openai_chat_completion_request(
+ openai_aclient=openai_aclient, data=data, timeout=timeout
)
+ logging_obj.model_call_details["response_headers"] = headers
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
@@ -992,17 +1026,43 @@ class OpenAIChatCompletion(BaseLLM):
else:
raise OpenAIError(status_code=500, message=f"{str(e)}")
+ # Embedding
+ async def make_openai_embedding_request(
+ self,
+ openai_aclient: AsyncOpenAI,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ ):
+ """
+ Helper to:
+ - call embeddings.create.with_raw_response when litellm.return_response_headers is True
+ - call embeddings.create by default
+ """
+ try:
+ if litellm.return_response_headers is True:
+ raw_response = await openai_aclient.embeddings.with_raw_response.create(
+ **data, timeout=timeout
+ ) # type: ignore
+ headers = dict(raw_response.headers)
+ response = raw_response.parse()
+ return headers, response
+ else:
+ response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
+ return None, response
+ except Exception as e:
+ raise e
+
async def aembedding(
self,
input: list,
data: dict,
model_response: litellm.utils.EmbeddingResponse,
timeout: float,
+ logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client: Optional[AsyncOpenAI] = None,
max_retries=None,
- logging_obj=None,
):
response = None
try:
@@ -1014,7 +1074,10 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=max_retries,
client=client,
)
- response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
+ headers, response = await self.make_openai_embedding_request(
+ openai_aclient=openai_aclient, data=data, timeout=timeout
+ )
+ logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
@@ -1229,6 +1292,34 @@ class OpenAIChatCompletion(BaseLLM):
else:
raise OpenAIError(status_code=500, message=str(e))
+ # Audio Transcriptions
+ async def make_openai_audio_transcriptions_request(
+ self,
+ openai_aclient: AsyncOpenAI,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ ):
+ """
+ Helper to:
+ - call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
+ - call openai_aclient.audio.transcriptions.create by default
+ """
+ try:
+ if litellm.return_response_headers is True:
+ raw_response = (
+ await openai_aclient.audio.transcriptions.with_raw_response.create(
+ **data, timeout=timeout
+ )
+ ) # type: ignore
+ headers = dict(raw_response.headers)
+ response = raw_response.parse()
+ return headers, response
+ else:
+ response = await openai_aclient.audio.transcriptions.create(**data, timeout=timeout) # type: ignore
+ return None, response
+ except Exception as e:
+ raise e
+
def audio_transcriptions(
self,
model: str,
@@ -1286,11 +1377,11 @@ class OpenAIChatCompletion(BaseLLM):
data: dict,
model_response: TranscriptionResponse,
timeout: float,
+ logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
- logging_obj=None,
):
try:
openai_aclient = self._get_openai_client(
@@ -1302,9 +1393,12 @@ class OpenAIChatCompletion(BaseLLM):
client=client,
)
- response = await openai_aclient.audio.transcriptions.create(
- **data, timeout=timeout
- ) # type: ignore
+ headers, response = await self.make_openai_audio_transcriptions_request(
+ openai_aclient=openai_aclient,
+ data=data,
+ timeout=timeout,
+ )
+ logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
@@ -1497,9 +1591,9 @@ class OpenAITextCompletion(BaseLLM):
model: str,
messages: list,
timeout: float,
+ logging_obj: LiteLLMLoggingObj,
print_verbose: Optional[Callable] = None,
api_base: Optional[str] = None,
- logging_obj=None,
acompletion: bool = False,
optional_params=None,
litellm_params=None,
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 9e361d3cc5..2ea0e199e8 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -1035,6 +1035,9 @@ class VertexLLM(BaseLLM):
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
"safety_settings", None
) # type: ignore
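+        # Gemini context caching - forwarded to the API as "cachedContent"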
+ cached_content: Optional[str] = optional_params.pop(
+ "cached_content", None
+ )
generation_config: Optional[GenerationConfig] = GenerationConfig(
**optional_params
)
@@ -1050,6 +1053,8 @@ class VertexLLM(BaseLLM):
data["safetySettings"] = safety_settings
if generation_config is not None:
data["generationConfig"] = generation_config
+ if cached_content is not None:
+ data["cachedContent"] = cached_content
headers = {
"Content-Type": "application/json",
diff --git a/litellm/main.py b/litellm/main.py
index 48d430d524..d6819b5ec0 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -48,6 +48,7 @@ from litellm import ( # type: ignore
get_litellm_params,
get_optional_params,
)
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.utils import (
CustomStreamWrapper,
Usage,
@@ -476,6 +477,15 @@ def mock_completion(
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
+ elif (
+ isinstance(mock_response, str) and mock_response == "litellm.RateLimitError"
+ ):
+ raise litellm.RateLimitError(
+ message="this is a mock rate limit error",
+ status_code=getattr(mock_response, "status_code", 429), # type: ignore
+ llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
+ model=model,
+ )
time_delay = kwargs.get("mock_delay", None)
if time_delay is not None:
time.sleep(time_delay)
@@ -2203,15 +2213,26 @@ def completion(
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
if "aws_bedrock_client" in optional_params:
+ verbose_logger.warning(
+ "'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication."
+ )
# Extract credentials for legacy boto3 client and pass thru to httpx
aws_bedrock_client = optional_params.pop("aws_bedrock_client")
creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
+
if creds.access_key:
optional_params["aws_access_key_id"] = creds.access_key
if creds.secret_key:
optional_params["aws_secret_access_key"] = creds.secret_key
if creds.token:
optional_params["aws_session_token"] = creds.token
+ if (
+ "aws_region_name" not in optional_params
+ or optional_params["aws_region_name"] is None
+ ):
+ optional_params["aws_region_name"] = (
+ aws_bedrock_client.meta.region_name
+ )
if model in litellm.BEDROCK_CONVERSE_MODELS:
response = bedrock_converse_chat_completion.completion(
@@ -4242,7 +4263,7 @@ def transcription(
api_base: Optional[str] = None,
api_version: Optional[str] = None,
max_retries: Optional[int] = None,
- litellm_logging_obj=None,
+ litellm_logging_obj: Optional[LiteLLMLoggingObj] = None,
custom_llm_provider=None,
**kwargs,
):
@@ -4257,6 +4278,18 @@ def transcription(
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
+ client: Optional[
+ Union[
+ openai.AsyncOpenAI,
+ openai.OpenAI,
+ openai.AzureOpenAI,
+ openai.AsyncAzureOpenAI,
+ ]
+ ] = kwargs.pop("client", None)
+
+ if litellm_logging_obj:
+ litellm_logging_obj.model_call_details["client"] = str(client)
+
if max_retries is None:
max_retries = openai.DEFAULT_MAX_RETRIES
@@ -4296,6 +4329,7 @@ def transcription(
optional_params=optional_params,
model_response=model_response,
atranscription=atranscription,
+ client=client,
timeout=timeout,
logging_obj=litellm_logging_obj,
api_base=api_base,
@@ -4329,6 +4363,7 @@ def transcription(
optional_params=optional_params,
model_response=model_response,
atranscription=atranscription,
+ client=client,
timeout=timeout,
logging_obj=litellm_logging_obj,
max_retries=max_retries,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index f263378918..f135922ea1 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,16 +1,9 @@
model_list:
- - model_name: gpt-3.5-turbo # all requests where model not in your config go to this deployment
+ - model_name: claude-3-5-sonnet # all requests where model not in your config go to this deployment
litellm_params:
- model: "gpt-3.5-turbo"
- rpm: 100
+ model: "openai/*"
+ mock_response: "Hello world!"
-litellm_settings:
- callbacks: ["dynamic_rate_limiter"]
- priority_reservation: {"dev": 0, "prod": 1}
-# success_callback: ["s3"]
-# s3_callback_params:
-# s3_bucket_name: my-test-bucket-22-litellm # AWS Bucket Name for S3
-# s3_region_name: us-west-2 # AWS Region Name for S3
-# s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/ to pass environment variables. This is AWS Access Key ID for S3
-# s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
-# s3_path: my-test-path
+general_settings:
+ alerting: ["slack"]
+ alerting_threshold: 10
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 88b778a6d4..9f2324e51c 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -36,6 +36,7 @@ general_settings:
LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY"
litellm_settings:
+ return_response_headers: true
success_callback: ["prometheus"]
callbacks: ["otel", "hide_secrets"]
failure_callback: ["prometheus"]
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 0577ec0a04..1ca1807223 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1182,9 +1182,13 @@ async def _run_background_health_check():
Update health_check_results, based on this.
"""
global health_check_results, llm_model_list, health_check_interval
+
+ # make 1 deep copy of llm_model_list -> use this for all background health checks
+ _llm_model_list = copy.deepcopy(llm_model_list)
+
while True:
healthy_endpoints, unhealthy_endpoints = await perform_health_check(
- model_list=llm_model_list
+ model_list=_llm_model_list
)
# Update the global variable with the health check results
@@ -3066,8 +3070,11 @@ async def chat_completion(
# Post Call Processing
if llm_router is not None:
data["deployment"] = llm_router.get_deployment(model_id=model_id)
- data["litellm_status"] = "success" # used for alerting
-
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
if (
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
@@ -3117,7 +3124,6 @@ async def chat_completion(
return response
except RejectedRequestError as e:
_data = e.request_data
- _data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
@@ -3150,7 +3156,6 @@ async def chat_completion(
_chat_response.usage = _usage # type: ignore
return _chat_response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format(
get_error_message_str(e=e), traceback.format_exc()
@@ -3306,7 +3311,11 @@ async def completion(
response_cost = hidden_params.get("response_cost", None) or ""
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
verbose_proxy_logger.debug("final response: %s", response)
if (
@@ -3345,7 +3354,6 @@ async def completion(
return response
except RejectedRequestError as e:
_data = e.request_data
- _data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
@@ -3384,7 +3392,6 @@ async def completion(
_response.choices[0].text = e.message
return _response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -3536,7 +3543,11 @@ async def embeddings(
)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -3559,7 +3570,6 @@ async def embeddings(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -3687,8 +3697,11 @@ async def image_generation(
)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
-
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
@@ -3710,7 +3723,6 @@ async def image_generation(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -3825,7 +3837,11 @@ async def audio_speech(
)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -3991,7 +4007,11 @@ async def audio_transcriptions(
os.remove(file.filename) # Delete the saved file
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -4014,7 +4034,6 @@ async def audio_transcriptions(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -4093,7 +4112,11 @@ async def get_assistants(
response = await llm_router.aget_assistants(**data)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -4114,7 +4137,6 @@ async def get_assistants(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -4185,7 +4207,11 @@ async def create_threads(
response = await llm_router.acreate_thread(**data)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -4206,7 +4232,6 @@ async def create_threads(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -4276,7 +4301,11 @@ async def get_thread(
response = await llm_router.aget_thread(thread_id=thread_id, **data)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -4297,7 +4326,6 @@ async def get_thread(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -4370,7 +4398,11 @@ async def add_messages(
response = await llm_router.a_add_message(thread_id=thread_id, **data)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -4391,7 +4423,6 @@ async def add_messages(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -4460,7 +4491,11 @@ async def get_messages(
response = await llm_router.aget_messages(thread_id=thread_id, **data)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -4481,7 +4516,6 @@ async def get_messages(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -4564,7 +4598,11 @@ async def run_thread(
)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -4585,7 +4623,6 @@ async def run_thread(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -4675,7 +4712,11 @@ async def create_batch(
)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -4696,7 +4737,6 @@ async def create_batch(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -4781,7 +4821,11 @@ async def retrieve_batch(
)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -4802,7 +4846,6 @@ async def retrieve_batch(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -4897,7 +4940,11 @@ async def create_file(
)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -4918,7 +4965,6 @@ async def create_file(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
@@ -5041,7 +5087,11 @@ async def moderations(
response = await litellm.amoderation(**data)
### ALERTING ###
- data["litellm_status"] = "success" # used for alerting
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -5062,7 +5112,6 @@ async def moderations(
return response
except Exception as e:
- data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
diff --git a/litellm/proxy/secret_managers/aws_secret_manager.py b/litellm/proxy/secret_managers/aws_secret_manager.py
index 8895717c61..c4afaedc21 100644
--- a/litellm/proxy/secret_managers/aws_secret_manager.py
+++ b/litellm/proxy/secret_managers/aws_secret_manager.py
@@ -153,7 +153,7 @@ def decrypt_env_var() -> Dict[str, Any]:
) or (v is not None and isinstance(v, str) and v.startswith("aws_kms/")):
decrypted_value = aws_kms.decrypt_value(secret_name=k)
# reset env var
- k = re.sub("litellm_secret_aws_kms", "", k, flags=re.IGNORECASE)
+ k = re.sub("litellm_secret_aws_kms_", "", k, flags=re.IGNORECASE)
new_values[k] = decrypted_value
return new_values
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 96aeb4a816..179d094667 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -272,6 +272,16 @@ class ProxyLogging:
callback_list=callback_list
)
+ async def update_request_status(
+ self, litellm_call_id: str, status: Literal["success", "fail"]
+ ):
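+        """
+        Mark a request as succeeded / failed in the internal usage cache.
+
+        Read by the hanging-request alerting in SlackAlerting to skip requests
+        that have already completed. Entries expire after 1 hour (ttl=3600).
+        """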
+ await self.internal_usage_cache.async_set_cache(
+ key="request_status:{}".format(litellm_call_id),
+ value=status,
+ local_only=True,
+ ttl=3600,
+ )
+
# The actual implementation of the function
async def pre_call_hook(
self,
@@ -560,6 +570,9 @@ class ProxyLogging:
"""
### ALERTING ###
+ await self.update_request_status(
+ litellm_call_id=request_data.get("litellm_call_id", ""), status="fail"
+ )
if "llm_exceptions" in self.alert_types and not isinstance(
original_exception, HTTPException
):
@@ -611,6 +624,7 @@ class ProxyLogging:
Covers:
1. /chat/completions
"""
+
for callback in litellm.callbacks:
try:
_callback: Optional[CustomLogger] = None
diff --git a/litellm/router.py b/litellm/router.py
index 39cc92ab19..ac61ec729f 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -156,6 +156,7 @@ class Router:
cooldown_time: Optional[
float
] = None, # (seconds) time to cooldown a deployment after failure
+ disable_cooldowns: Optional[bool] = None,
routing_strategy: Literal[
"simple-shuffle",
"least-busy",
@@ -307,6 +308,7 @@ class Router:
self.allowed_fails = allowed_fails or litellm.allowed_fails
self.cooldown_time = cooldown_time or 60
+ self.disable_cooldowns = disable_cooldowns
self.failed_calls = (
InMemoryCache()
) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
@@ -2990,6 +2992,8 @@ class Router:
the exception is not one that should be immediately retried (e.g. 401)
"""
+ if self.disable_cooldowns is True:
+ return
if deployment is None:
return
@@ -3030,24 +3034,50 @@ class Router:
exception_status = 500
_should_retry = litellm._should_retry(status_code=exception_status)
- if updated_fails > allowed_fails or _should_retry == False:
+ if updated_fails > allowed_fails or _should_retry is False:
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
- cached_value = self.cache.get_cache(key=cooldown_key)
+ cached_value = self.cache.get_cache(
+ key=cooldown_key
+ ) # [(deployment_id, {last_error_str, last_error_status_code})]
+ cached_value_deployment_ids = []
+ if (
+ cached_value is not None
+ and isinstance(cached_value, list)
+ and len(cached_value) > 0
+ and isinstance(cached_value[0], tuple)
+ ):
+ cached_value_deployment_ids = [cv[0] for cv in cached_value]
verbose_router_logger.debug(f"adding {deployment} to cooldown models")
# update value
- try:
- if deployment in cached_value:
+ if cached_value is not None and len(cached_value_deployment_ids) > 0:
+ if deployment in cached_value_deployment_ids:
pass
else:
- cached_value = cached_value + [deployment]
+ cached_value = cached_value + [
+ (
+ deployment,
+ {
+ "Exception Received": str(original_exception),
+ "Status Code": str(exception_status),
+ },
+ )
+ ]
# save updated value
self.cache.set_cache(
value=cached_value, key=cooldown_key, ttl=cooldown_time
)
- except:
- cached_value = [deployment]
+ else:
+ cached_value = [
+ (
+ deployment,
+ {
+ "Exception Received": str(original_exception),
+ "Status Code": str(exception_status),
+ },
+ )
+ ]
# save updated value
self.cache.set_cache(
value=cached_value, key=cooldown_key, ttl=cooldown_time
@@ -3063,7 +3093,33 @@ class Router:
key=deployment, value=updated_fails, ttl=cooldown_time
)
- async def _async_get_cooldown_deployments(self):
+ async def _async_get_cooldown_deployments(self) -> List[str]:
+ """
+ Async implementation of '_get_cooldown_deployments'
+ """
+ dt = get_utc_datetime()
+ current_minute = dt.strftime("%H-%M")
+ # get the current cooldown list for that minute
+ cooldown_key = f"{current_minute}:cooldown_models"
+
+ # ----------------------
+ # Return cooldown models
+ # ----------------------
+ cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+
+ cached_value_deployment_ids = []
+ if (
+ cooldown_models is not None
+ and isinstance(cooldown_models, list)
+ and len(cooldown_models) > 0
+ and isinstance(cooldown_models[0], tuple)
+ ):
+ cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
+
+ verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
+ return cached_value_deployment_ids
+
+ async def _async_get_cooldown_deployments_with_debug_info(self) -> List[tuple]:
"""
Async implementation of '_get_cooldown_deployments'
"""
@@ -3080,7 +3136,7 @@ class Router:
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
- def _get_cooldown_deployments(self):
+ def _get_cooldown_deployments(self) -> List[str]:
"""
Get the list of models being cooled down for this minute
"""
@@ -3094,8 +3150,17 @@ class Router:
# ----------------------
cooldown_models = self.cache.get_cache(key=cooldown_key) or []
+ cached_value_deployment_ids = []
+ if (
+ cooldown_models is not None
+ and isinstance(cooldown_models, list)
+ and len(cooldown_models) > 0
+ and isinstance(cooldown_models[0], tuple)
+ ):
+ cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
+
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
- return cooldown_models
+ return cached_value_deployment_ids
def _get_healthy_deployments(self, model: str):
_all_deployments: list = []
@@ -4737,7 +4802,7 @@ class Router:
if _allowed_model_region is None:
_allowed_model_region = "n/a"
raise ValueError(
- f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
+ f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}"
)
if (
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index 6e39c30b36..fb4ba7556b 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -856,3 +856,56 @@ async def test_bedrock_custom_prompt_template():
prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"]
assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>"
mock_client_post.assert_called_once()
+
+
+def test_completion_bedrock_external_client_region():
+ print("\ncalling bedrock claude external client auth")
+ import os
+
+ aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
+ aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
+ aws_region_name = "us-east-1"
+
+ os.environ.pop("AWS_ACCESS_KEY_ID", None)
+ os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
+
+ client = HTTPHandler()
+
+ try:
+ import boto3
+
+ litellm.set_verbose = True
+
+ bedrock = boto3.client(
+ service_name="bedrock-runtime",
+ region_name=aws_region_name,
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com",
+ )
+ with patch.object(client, "post", new=Mock()) as mock_client_post:
+ try:
+ response = completion(
+ model="bedrock/anthropic.claude-instant-v1",
+ messages=messages,
+ max_tokens=10,
+ temperature=0.1,
+ aws_bedrock_client=bedrock,
+ client=client,
+ )
+ # Add any assertions here to check the response
+ print(response)
+ except Exception as e:
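+ # the post call is mocked and returns no real response, so an exception here is expected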
+ pass
+
+ print(f"mock_client_post.call_args: {mock_client_post.call_args}")
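+ # the request URL should target the region configured on the external boto3 client (us-east-1)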
+ assert "us-east-1" in mock_client_post.call_args.kwargs["url"]
+
+ mock_client_post.assert_called_once()
+
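+ # restore the original AWS credentials for subsequent tests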
+ os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
+ os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
+ except RateLimitError:
+ pass
+ except Exception as e:
+ pytest.fail(f"Error occurred: {e}")
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 5138e9b61b..1c10ef461e 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
-# litellm.num_retries = 3
+# litellm.num_retries=3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py
index 3d8cb3c2a3..fb390bb488 100644
--- a/litellm/tests/test_exceptions.py
+++ b/litellm/tests/test_exceptions.py
@@ -249,6 +249,25 @@ def test_completion_azure_exception():
# test_completion_azure_exception()
+def test_azure_embedding_exceptions():
+ try:
+
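+ # 'messages' is not a valid embedding argument - the call should fail with a clean, readable error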
+ response = litellm.embedding(
+ model="azure/azure-embedding-model",
+ input="hello",
+ messages="hello",
+ )
+ pytest.fail(f"Bad request this should have failed but got {response}")
+
+ except Exception as e:
+ print(vars(e))
+ # CRUCIAL test - ensures our exceptions stay readable and not overly complicated. Some users have reported that our exception mapping occasionally raises another exception inside the original one.
+ assert (
+ e.message
+ == "litellm.APIError: AzureException APIError - Embeddings.create() got an unexpected keyword argument 'messages'"
+ )
+
+
async def asynctest_completion_azure_exception():
try:
import openai
diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py
index 49ec18f24c..67857b8c86 100644
--- a/litellm/tests/test_image_generation.py
+++ b/litellm/tests/test_image_generation.py
@@ -1,20 +1,23 @@
# What this tests?
## This tests the litellm support for the openai /generations endpoint
-import sys, os
-import traceback
-from dotenv import load_dotenv
import logging
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
logging.basicConfig(level=logging.DEBUG)
load_dotenv()
-import os
import asyncio
+import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
+
import litellm
@@ -39,13 +42,25 @@ def test_image_generation_openai():
# test_image_generation_openai()
-def test_image_generation_azure():
+@pytest.mark.parametrize(
+ "sync_mode",
+ [True, False],
+)  # run both the sync and async image generation paths
+@pytest.mark.asyncio
+async def test_image_generation_azure(sync_mode):
try:
- response = litellm.image_generation(
- prompt="A cute baby sea otter",
- model="azure/",
- api_version="2023-06-01-preview",
- )
+ if sync_mode:
+ response = litellm.image_generation(
+ prompt="A cute baby sea otter",
+ model="azure/",
+ api_version="2023-06-01-preview",
+ )
+ else:
+ response = await litellm.aimage_generation(
+ prompt="A cute baby sea otter",
+ model="azure/",
+ api_version="2023-06-01-preview",
+ )
print(f"response: {response}")
assert len(response.data) > 0
except litellm.RateLimitError as e:
diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py
index 2dda57c2e9..17fc26d60e 100644
--- a/litellm/types/llms/vertex_ai.py
+++ b/litellm/types/llms/vertex_ai.py
@@ -155,6 +155,16 @@ class ToolConfig(TypedDict):
functionCallingConfig: FunctionCallingConfig
+class TTL(TypedDict, total=False):
+ seconds: Required[float]
+ nano: float
+
+
+class CachedContent(TypedDict, total=False):
+ ttl: TTL
+ expire_time: str
+
+
class RequestBody(TypedDict, total=False):
contents: Required[List[ContentType]]
system_instruction: SystemInstructions
@@ -162,6 +172,7 @@ class RequestBody(TypedDict, total=False):
toolConfig: ToolConfig
safetySettings: List[SafetSettingsConfig]
generationConfig: GenerationConfig
+ cachedContent: str
class SafetyRatings(TypedDict):
diff --git a/litellm/utils.py b/litellm/utils.py
index 103f854b68..82e3ca1712 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4815,6 +4815,12 @@ def function_to_dict(input_function): # noqa: C901
return result
+def modify_url(original_url, new_path):
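+ # return the URL with its path swapped out; the other components are left as-is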
+ url = httpx.URL(original_url)
+ modified_url = url.copy_with(path=new_path)
+ return str(modified_url)
+
+
def load_test_model(
model: str,
custom_llm_provider: str = "",
@@ -5810,6 +5816,18 @@ def exception_type(
_model_group = _metadata.get("model_group")
_deployment = _metadata.get("deployment")
extra_information = f"\nModel: {model}"
+
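+ # e.g. custom_llm_provider="azure" -> exception_provider="AzureException"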
+ exception_provider = "Unknown"
+ if (
+ isinstance(custom_llm_provider, str)
+ and len(custom_llm_provider) > 0
+ ):
+ exception_provider = (
+ custom_llm_provider[0].upper()
+ + custom_llm_provider[1:]
+ + "Exception"
+ )
+
if _api_base:
extra_information += f"\nAPI Base: `{_api_base}`"
if (
diff --git a/pyproject.toml b/pyproject.toml
index 2519c167f5..c698a18e16 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "1.41.2"
+version = "1.41.3"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@@ -90,7 +90,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
-version = "1.41.2"
+version = "1.41.3"
version_files = [
"pyproject.toml:^version"
]
diff --git a/tests/test_whisper.py b/tests/test_whisper.py
index 1debbbc1db..09819f796c 100644
--- a/tests/test_whisper.py
+++ b/tests/test_whisper.py
@@ -8,6 +8,9 @@ from openai import AsyncOpenAI
import sys, os, dotenv
from typing import Optional
from dotenv import load_dotenv
+from litellm.integrations.custom_logger import CustomLogger
+import litellm
+import logging
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
@@ -84,9 +87,32 @@ async def test_transcription_async_openai():
assert isinstance(transcript.text, str)
+# Custom callback handler, modeled on the custom callbacks used for LiteLLM Proxy
+# (on the proxy, these are defined once and passed in via proxy_config.yaml)
+class MyCustomHandler(CustomLogger):
+ def __init__(self):
+ self.openai_client = None
+
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ # capture the client that litellm used for this request
+ print("logging a transcript kwargs: ", kwargs)
+ print("openai client=", kwargs.get("client"))
+ self.openai_client = kwargs.get("client")
+
+ except Exception:
+ pass
+
+
+proxy_handler_instance = MyCustomHandler()
+
+
+# Set litellm.callbacks = [proxy_handler_instance] on the proxy
@pytest.mark.asyncio
async def test_transcription_on_router():
litellm.set_verbose = True
+ litellm.callbacks = [proxy_handler_instance]
print("\n Testing async transcription on router\n")
try:
model_list = [
@@ -108,11 +134,29 @@ async def test_transcription_on_router():
]
router = Router(model_list=model_list)
+
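+ # collect the async clients the router initialized for each deployment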
+ router_level_clients = []
+ for deployment in router.model_list:
+ _deployment_openai_client = router._get_client(
+ deployment=deployment,
+ kwargs={"model": "whisper-1"},
+ client_type="async",
+ )
+
+ router_level_clients.append(str(_deployment_openai_client))
+
response = await router.atranscription(
model="whisper",
file=audio_file,
)
print(response)
+
+ # PROD test
+ # ensure we ONLY use an OpenAI/Azure client that was initialized at the router level
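+ # give the async success callback time to fire before checking which client was used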
+ await asyncio.sleep(5)
+ print("OpenAI Client used= ", proxy_handler_instance.openai_client)
+ print("all router level clients= ", router_level_clients)
+ assert proxy_handler_instance.openai_client in router_level_clients
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")