forked from phoenix/litellm-mirror

Merge branch 'main' into litellm_support_dynamic_rpm_limiting

Commit 21d3a28e51
27 changed files with 1067 additions and 133 deletions
@@ -20,6 +20,8 @@ This covers:
 - **Spend Tracking**
     - ✅ [Tracking Spend for Custom Tags](./proxy/enterprise#tracking-spend-for-custom-tags)
     - ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](./proxy/cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend)
+- **Advanced Metrics**
+    - ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](./proxy/prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens)
 - **Guardrails, PII Masking, Content Moderation**
     - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](./proxy/enterprise#content-moderation)
     - ✅ [Prompt Injection Detection (with LakeraAI API)](./proxy/enterprise#prompt-injection-detection---lakeraai)
@@ -89,3 +89,30 @@ Expected Output:
 ```bash
 # no info statements
 ```
+
+## Common Errors
+
+1. "No available deployments..."
+
+```
+No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a.
+```
+
+This can be caused due to all your models hitting rate limit errors, causing the cooldown to kick in.
+
+How to control this?
+- Adjust the cooldown time
+
+```yaml
+router_settings:
+  cooldown_time: 0 # 👈 KEY CHANGE
+```
+
+- Disable Cooldowns [NOT RECOMMENDED]
+
+```yaml
+router_settings:
+  disable_cooldowns: True
+```
+
+This is not recommended, as it will lead to requests being routed to deployments over their tpm/rpm limit.
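For the SDK side of the same knobs, a minimal sketch (it assumes `cooldown_time` is accepted as a Router keyword argument, mirroring the `router_settings` keys above; the deployment entry is a placeholder):

```python
from litellm import Router

# Placeholder deployment list, for illustration only
model_list = [
    {
        "model_name": "claude-3-5-sonnet",
        "litellm_params": {"model": "openai/gpt-3.5-turbo"},
    }
]

# cooldown_time=0 keeps rate-limited deployments in rotation;
# disable_cooldowns=True (shown in the proxy yaml above) switches cooldowns off entirely.
router = Router(model_list=model_list, cooldown_time=0)
```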
@@ -23,6 +23,8 @@ Features:
 - **Spend Tracking**
     - ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags)
     - ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend)
+- **Advanced Metrics**
+    - ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens)
 - **Guardrails, PII Masking, Content Moderation**
     - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
     - ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
@@ -1,3 +1,6 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
 # 📈 Prometheus metrics [BETA]
 
 LiteLLM Exposes a `/metrics` endpoint for Prometheus to Poll
@@ -61,6 +64,56 @@ http://localhost:4000/metrics
 | `litellm_remaining_api_key_budget_metric` | Remaining Budget for API Key (A key Created on LiteLLM)|
 
 
+### ✨ (Enterprise) LLM Remaining Requests and Remaining Tokens
+Set this on your config.yaml to allow you to track how close you are to hitting your TPM / RPM limits on each model group
+
+```yaml
+litellm_settings:
+  success_callback: ["prometheus"]
+  failure_callback: ["prometheus"]
+  return_response_headers: true # ensures the LLM API calls track the response headers
+```
+
+| Metric Name | Description |
+|----------------------|--------------------------------------|
+| `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment |
+| `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` return from LLM API Deployment |
+
+Example Metric
+<Tabs>
+
+<TabItem value="Remaining Requests" label="Remaining Requests">
+
+```shell
+litellm_remaining_requests
+{
+  api_base="https://api.openai.com/v1",
+  api_provider="openai",
+  litellm_model_name="gpt-3.5-turbo",
+  model_group="gpt-3.5-turbo"
+}
+8998.0
+```
+
+</TabItem>
+
+<TabItem value="Requests" label="Remaining Tokens">
+
+```shell
+litellm_remaining_tokens
+{
+  api_base="https://api.openai.com/v1",
+  api_provider="openai",
+  litellm_model_name="gpt-3.5-turbo",
+  model_group="gpt-3.5-turbo"
+}
+999981.0
+```
+
+</TabItem>
+
+</Tabs>
+
 ## Monitor System Health
 
 To monitor the health of litellm adjacent services (redis / postgres), do:
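As a quick check that the new gauges are actually being exported, a small sketch that scrapes the `/metrics` endpoint shown above (it assumes the proxy is running locally on port 4000):

```python
import requests

resp = requests.get("http://localhost:4000/metrics", timeout=10)
resp.raise_for_status()

# Print only the enterprise rate-limit gauges documented above
for line in resp.text.splitlines():
    if line.startswith(("litellm_remaining_requests", "litellm_remaining_tokens")):
        print(line)
```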
@@ -152,6 +152,27 @@ response = chat(messages)
 print(response)
 ```
 
+</TabItem>
+<TabItem value="langchain js" label="Langchain JS">
+
+```js
+import { ChatOpenAI } from "@langchain/openai";
+
+
+const model = new ChatOpenAI({
+  modelName: "gpt-4",
+  openAIApiKey: "sk-1234",
+  modelKwargs: {"metadata": "hello world"} // 👈 PASS Additional params here
+}, {
+  basePath: "http://0.0.0.0:4000",
+});
+
+const message = await model.invoke("Hi there!");
+
+console.log(message);
+
+```
+
 </TabItem>
 </Tabs>
 
 
@@ -815,6 +815,35 @@ model_list:
 </TabItem>
 </Tabs>
 
+**Expected Response**
+
+```
+No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a.
+```
+
+#### **Disable cooldowns**
+
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import Router
+
+
+router = Router(..., disable_cooldowns=True)
+```
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+```yaml
+router_settings:
+  disable_cooldowns: True
+```
+
+</TabItem>
+</Tabs>
+
 ### Retries
 
 For both async + sync functions, we support retrying failed requests.
@@ -125,6 +125,9 @@ llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
 ##################
 ### PREVIEW FEATURES ###
 enable_preview_features: bool = False
+return_response_headers: bool = (
+    False  # get response headers from LLM Api providers - example x-remaining-requests,
+)
 ##################
 logging: bool = True
 caching: bool = (
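For context on the new flag, a hedged sketch of how it is meant to be consumed from the SDK: with `litellm.return_response_headers = True`, the helpers added later in this diff call `with_raw_response` and attach the provider headers to the response's hidden params (the exact key layout here is an assumption drawn from this diff, not a documented API):

```python
import asyncio
import litellm

litellm.return_response_headers = True  # flag introduced in this commit

async def show_ratelimit_headers():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )
    # Assumed location of the captured headers, based on hidden_params={"headers": ...} below
    headers = getattr(response, "_hidden_params", {}).get("headers", {})
    print(headers.get("x-ratelimit-remaining-requests"))

asyncio.run(show_ratelimit_headers())
```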
@@ -2,14 +2,20 @@
 #### What this does ####
 # On success, log events to Prometheus
 
-import dotenv, os
-import requests  # type: ignore
+import datetime
+import os
+import subprocess
+import sys
 import traceback
-import datetime, subprocess, sys
-import litellm, uuid
-from litellm._logging import print_verbose, verbose_logger
+import uuid
 from typing import Optional, Union
+
+import dotenv
+import requests  # type: ignore
+
+import litellm
+from litellm._logging import print_verbose, verbose_logger
+
 
 class PrometheusLogger:
     # Class variables or attributes
@@ -20,6 +26,8 @@ class PrometheusLogger:
         try:
             from prometheus_client import Counter, Gauge
 
+            from litellm.proxy.proxy_server import premium_user
+
             self.litellm_llm_api_failed_requests_metric = Counter(
                 name="litellm_llm_api_failed_requests_metric",
                 documentation="Total number of failed LLM API calls via litellm",
@@ -88,6 +96,31 @@ class PrometheusLogger:
                 labelnames=["hashed_api_key", "api_key_alias"],
             )
 
+            # Litellm-Enterprise Metrics
+            if premium_user is True:
+                # Remaining Rate Limit for model
+                self.litellm_remaining_requests_metric = Gauge(
+                    "litellm_remaining_requests",
+                    "remaining requests for model, returned from LLM API Provider",
+                    labelnames=[
+                        "model_group",
+                        "api_provider",
+                        "api_base",
+                        "litellm_model_name",
+                    ],
+                )
+
+                self.litellm_remaining_tokens_metric = Gauge(
+                    "litellm_remaining_tokens",
+                    "remaining tokens for model, returned from LLM API Provider",
+                    labelnames=[
+                        "model_group",
+                        "api_provider",
+                        "api_base",
+                        "litellm_model_name",
+                    ],
+                )
+
         except Exception as e:
             print_verbose(f"Got exception on init prometheus client {str(e)}")
             raise e
@@ -104,6 +137,8 @@ class PrometheusLogger:
     ):
         try:
             # Define prometheus client
+            from litellm.proxy.proxy_server import premium_user
+
             verbose_logger.debug(
                 f"prometheus Logging - Enters logging function for model {kwargs}"
             )
@@ -199,6 +234,10 @@ class PrometheusLogger:
                 user_api_key, user_api_key_alias
             ).set(_remaining_api_key_budget)
 
+            # set x-ratelimit headers
+            if premium_user is True:
+                self.set_remaining_tokens_requests_metric(kwargs)
+
             ### FAILURE INCREMENT ###
             if "exception" in kwargs:
                 self.litellm_llm_api_failed_requests_metric.labels(
@@ -216,6 +255,58 @@ class PrometheusLogger:
             verbose_logger.debug(traceback.format_exc())
             pass
 
+    def set_remaining_tokens_requests_metric(self, request_kwargs: dict):
+        try:
+            verbose_logger.debug("setting remaining tokens requests metric")
+            _response_headers = request_kwargs.get("response_headers")
+            _litellm_params = request_kwargs.get("litellm_params", {}) or {}
+            _metadata = _litellm_params.get("metadata", {})
+            litellm_model_name = request_kwargs.get("model", None)
+            model_group = _metadata.get("model_group", None)
+            api_base = _metadata.get("api_base", None)
+            llm_provider = _litellm_params.get("custom_llm_provider", None)
+
+            remaining_requests = None
+            remaining_tokens = None
+            # OpenAI / OpenAI Compatible headers
+            if (
+                _response_headers
+                and "x-ratelimit-remaining-requests" in _response_headers
+            ):
+                remaining_requests = _response_headers["x-ratelimit-remaining-requests"]
+            if (
+                _response_headers
+                and "x-ratelimit-remaining-tokens" in _response_headers
+            ):
+                remaining_tokens = _response_headers["x-ratelimit-remaining-tokens"]
+            verbose_logger.debug(
+                f"remaining requests: {remaining_requests}, remaining tokens: {remaining_tokens}"
+            )
+
+            if remaining_requests:
+                """
+                "model_group",
+                "api_provider",
+                "api_base",
+                "litellm_model_name"
+                """
+                self.litellm_remaining_requests_metric.labels(
+                    model_group, llm_provider, api_base, litellm_model_name
+                ).set(remaining_requests)
+
+            if remaining_tokens:
+                self.litellm_remaining_tokens_metric.labels(
+                    model_group, llm_provider, api_base, litellm_model_name
+                ).set(remaining_tokens)
+
+        except Exception as e:
+            verbose_logger.error(
+                "Prometheus Error: set_remaining_tokens_requests_metric. Exception occured - {}".format(
+                    str(e)
+                )
+            )
+            return
+
+
 def safe_get_remaining_budget(
     max_budget: Optional[float], spend: Optional[float]
@@ -606,6 +606,13 @@ class SlackAlerting(CustomLogger):
             and request_data.get("litellm_status", "") != "success"
             and request_data.get("litellm_status", "") != "fail"
         ):
+            ## CHECK IF CACHE IS UPDATED
+            litellm_call_id = request_data.get("litellm_call_id", "")
+            status: Optional[str] = await self.internal_usage_cache.async_get_cache(
+                key="request_status:{}".format(litellm_call_id), local_only=True
+            )
+            if status is not None and (status == "success" or status == "fail"):
+                return
             if request_data.get("deployment", None) is not None and isinstance(
                 request_data["deployment"], dict
             ):
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+import time
 import types
 import uuid
 from typing import (
@@ -21,8 +22,10 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
 from typing_extensions import overload
 
 import litellm
-from litellm import OpenAIConfig
+from litellm import ImageResponse, OpenAIConfig
 from litellm.caching import DualCache
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import (
     Choices,
     CustomStreamWrapper,
@@ -32,6 +35,7 @@ from litellm.utils import (
     UnsupportedParamsError,
     convert_to_model_response_object,
     get_secret,
+    modify_url,
 )
 
 from ..types.llms.openai import (
@@ -458,6 +462,36 @@ class AzureChatCompletion(BaseLLM):
 
         return azure_client
 
+    async def make_azure_openai_chat_completion_request(
+        self,
+        azure_client: AsyncAzureOpenAI,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+    ):
+        """
+        Helper to:
+        - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
+        - call chat.completions.create by default
+        """
+        try:
+            if litellm.return_response_headers is True:
+                raw_response = (
+                    await azure_client.chat.completions.with_raw_response.create(
+                        **data, timeout=timeout
+                    )
+                )
+
+                headers = dict(raw_response.headers)
+                response = raw_response.parse()
+                return headers, response
+            else:
+                response = await azure_client.chat.completions.create(
+                    **data, timeout=timeout
+                )
+                return None, response
+        except Exception as e:
+            raise e
+
     def completion(
         self,
         model: str,
@@ -470,7 +504,7 @@ class AzureChatCompletion(BaseLLM):
         azure_ad_token: str,
         print_verbose: Callable,
         timeout: Union[float, httpx.Timeout],
-        logging_obj,
+        logging_obj: LiteLLMLoggingObj,
         optional_params,
         litellm_params,
         logger_fn,
@@ -649,9 +683,9 @@ class AzureChatCompletion(BaseLLM):
         data: dict,
         timeout: Any,
         model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
         azure_ad_token: Optional[str] = None,
         client=None,  # this is the AsyncAzureOpenAI
-        logging_obj=None,
     ):
         response = None
         try:
@@ -701,9 +735,13 @@ class AzureChatCompletion(BaseLLM):
                     "complete_input_dict": data,
                 },
             )
-            response = await azure_client.chat.completions.create(
-                **data, timeout=timeout
+            headers, response = await self.make_azure_openai_chat_completion_request(
+                azure_client=azure_client,
+                data=data,
+                timeout=timeout,
             )
+            logging_obj.model_call_details["response_headers"] = headers
 
             stringified_response = response.model_dump()
             logging_obj.post_call(
@@ -812,7 +850,7 @@ class AzureChatCompletion(BaseLLM):
 
     async def async_streaming(
         self,
-        logging_obj,
+        logging_obj: LiteLLMLoggingObj,
         api_base: str,
         api_key: str,
         api_version: str,
@@ -861,9 +899,14 @@ class AzureChatCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = await azure_client.chat.completions.create(
-            **data, timeout=timeout
+        headers, response = await self.make_azure_openai_chat_completion_request(
+            azure_client=azure_client,
+            data=data,
+            timeout=timeout,
         )
+        logging_obj.model_call_details["response_headers"] = headers
+
         # return response
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
@@ -1011,6 +1054,234 @@ class AzureChatCompletion(BaseLLM):
         else:
             raise AzureOpenAIError(status_code=500, message=str(e))
 
+    async def make_async_azure_httpx_request(
+        self,
+        client: Optional[AsyncHTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: str,
+        api_version: str,
+        api_key: str,
+        data: dict,
+    ) -> httpx.Response:
+        """
+        Implemented for azure dall-e-2 image gen calls
+
+        Alternative to needing a custom transport implementation
+        """
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            async_handler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            async_handler = client  # type: ignore
+
+        if (
+            "images/generations" in api_base
+            and api_version
+            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+                "2023-06-01-preview",
+                "2023-07-01-preview",
+                "2023-08-01-preview",
+                "2023-09-01-preview",
+                "2023-10-01-preview",
+            ]
+        ):  # CREATE + POLL for azure dall-e-2 calls
+
+            api_base = modify_url(
+                original_url=api_base, new_path="/openai/images/generations:submit"
+            )
+
+            data.pop(
+                "model", None
+            )  # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+            response = await async_handler.post(
+                url=api_base,
+                data=json.dumps(data),
+                headers={
+                    "Content-Type": "application/json",
+                    "api-key": api_key,
+                },
+            )
+            operation_location_url = response.headers["operation-location"]
+            response = await async_handler.get(
+                url=operation_location_url,
+                headers={
+                    "api-key": api_key,
+                },
+            )
+
+            await response.aread()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            if "status" not in response.json():
+                raise Exception(
+                    "Expected 'status' in response. Got={}".format(response.json())
+                )
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    timeout_msg = {
+                        "error": {
+                            "code": "Timeout",
+                            "message": "Operation polling timed out.",
+                        }
+                    }
+
+                    raise AzureOpenAIError(
+                        status_code=408, message="Operation polling timed out."
+                    )
+
+                await asyncio.sleep(int(response.headers.get("retry-after") or 10))
+                response = await async_handler.get(
+                    url=operation_location_url,
+                    headers={
+                        "api-key": api_key,
+                    },
+                )
+                await response.aread()
+
+                if response.json()["status"] == "failed":
+                    error_data = response.json()
+                    raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+            return response
+        return await async_handler.post(
+            url=api_base,
+            json=data,
+            headers={
+                "Content-Type": "application/json;",
+                "api-key": api_key,
+            },
+        )
+
+    def make_sync_azure_httpx_request(
+        self,
+        client: Optional[HTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: str,
+        api_version: str,
+        api_key: str,
+        data: dict,
+    ) -> httpx.Response:
+        """
+        Implemented for azure dall-e-2 image gen calls
+
+        Alternative to needing a custom transport implementation
+        """
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        if (
+            "images/generations" in api_base
+            and api_version
+            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+                "2023-06-01-preview",
+                "2023-07-01-preview",
+                "2023-08-01-preview",
+                "2023-09-01-preview",
+                "2023-10-01-preview",
+            ]
+        ):  # CREATE + POLL for azure dall-e-2 calls
+
+            api_base = modify_url(
+                original_url=api_base, new_path="/openai/images/generations:submit"
+            )
+
+            data.pop(
+                "model", None
+            )  # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+            response = sync_handler.post(
+                url=api_base,
+                data=json.dumps(data),
+                headers={
+                    "Content-Type": "application/json",
+                    "api-key": api_key,
+                },
+            )
+            operation_location_url = response.headers["operation-location"]
+            response = sync_handler.get(
+                url=operation_location_url,
+                headers={
+                    "api-key": api_key,
+                },
+            )
+
+            response.read()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            if "status" not in response.json():
+                raise Exception(
+                    "Expected 'status' in response. Got={}".format(response.json())
+                )
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    raise AzureOpenAIError(
+                        status_code=408, message="Operation polling timed out."
+                    )
+
+                time.sleep(int(response.headers.get("retry-after") or 10))
+                response = sync_handler.get(
+                    url=operation_location_url,
+                    headers={
+                        "api-key": api_key,
+                    },
+                )
+                response.read()
+
+                if response.json()["status"] == "failed":
+                    error_data = response.json()
+                    raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+            return response
+        return sync_handler.post(
+            url=api_base,
+            json=data,
+            headers={
+                "Content-Type": "application/json;",
+                "api-key": api_key,
+            },
+        )
+
+    def create_azure_base_url(
+        self, azure_client_params: dict, model: Optional[str]
+    ) -> str:
+
+        api_base: str = azure_client_params.get(
+            "azure_endpoint", ""
+        )  # "https://example-endpoint.openai.azure.com"
+        if api_base.endswith("/"):
+            api_base = api_base.rstrip("/")
+        api_version: str = azure_client_params.get("api_version", "")
+        if model is None:
+            model = ""
+        new_api_base = (
+            api_base
+            + "/openai/deployments/"
+            + model
+            + "/images/generations"
+            + "?api-version="
+            + api_version
+        )
+
+        return new_api_base
+
     async def aimage_generation(
         self,
         data: dict,
@@ -1022,30 +1293,40 @@ class AzureChatCompletion(BaseLLM):
         logging_obj=None,
         timeout=None,
     ):
-        response = None
+        response: Optional[dict] = None
         try:
-            if client is None:
-                client_session = litellm.aclient_session or httpx.AsyncClient(
-                    transport=AsyncCustomHTTPTransport(),
-                )
-                azure_client = AsyncAzureOpenAI(
-                    http_client=client_session, **azure_client_params
-                )
-            else:
-                azure_client = client
+            # response = await azure_client.images.generate(**data, timeout=timeout)
+            api_base: str = azure_client_params.get(
+                "api_base", ""
+            )  # "https://example-endpoint.openai.azure.com"
+            if api_base.endswith("/"):
+                api_base = api_base.rstrip("/")
+            api_version: str = azure_client_params.get("api_version", "")
+            img_gen_api_base = self.create_azure_base_url(
+                azure_client_params=azure_client_params, model=data.get("model", "")
+            )
             ## LOGGING
             logging_obj.pre_call(
                 input=data["prompt"],
-                api_key=azure_client.api_key,
+                api_key=api_key,
                 additional_args={
-                    "headers": {"api_key": azure_client.api_key},
-                    "api_base": azure_client._base_url._uri_reference,
-                    "acompletion": True,
                     "complete_input_dict": data,
+                    "api_base": img_gen_api_base,
+                    "headers": {"api_key": api_key},
                 },
             )
-            response = await azure_client.images.generate(**data, timeout=timeout)
-            stringified_response = response.model_dump()
+            httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
+                client=None,
+                timeout=timeout,
+                api_base=img_gen_api_base,
+                api_version=api_version,
+                api_key=api_key,
+                data=data,
+            )
+            response = httpx_response.json()["result"]
+
+            stringified_response = response
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -1128,28 +1409,30 @@ class AzureChatCompletion(BaseLLM):
             response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout)  # type: ignore
             return response
 
-        if client is None:
-            client_session = litellm.client_session or httpx.Client(
-                transport=CustomHTTPTransport(),
-            )
-            azure_client = AzureOpenAI(http_client=client_session, **azure_client_params)  # type: ignore
-        else:
-            azure_client = client
+        img_gen_api_base = self.create_azure_base_url(
+            azure_client_params=azure_client_params, model=data.get("model", "")
+        )
 
         ## LOGGING
         logging_obj.pre_call(
-            input=prompt,
-            api_key=azure_client.api_key,
+            input=data["prompt"],
+            api_key=api_key,
             additional_args={
-                "headers": {"api_key": azure_client.api_key},
-                "api_base": azure_client._base_url._uri_reference,
-                "acompletion": False,
                 "complete_input_dict": data,
+                "api_base": img_gen_api_base,
+                "headers": {"api_key": api_key},
            },
         )
+        httpx_response: httpx.Response = self.make_sync_azure_httpx_request(
+            client=None,
+            timeout=timeout,
+            api_base=img_gen_api_base,
+            api_version=api_version or "",
+            api_key=api_key or "",
+            data=data,
+        )
+        response = httpx_response.json()["result"]
 
-        ## COMPLETION CALL
-        response = azure_client.images.generate(**data, timeout=timeout)  # type: ignore
         ## LOGGING
         logging_obj.post_call(
             input=prompt,
@@ -1158,7 +1441,7 @@ class AzureChatCompletion(BaseLLM):
             original_response=response,
         )
         # return response
-        return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation")  # type: ignore
+        return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
@@ -21,6 +21,7 @@ from pydantic import BaseModel
 from typing_extensions import overload, override
 
 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.types.utils import ProviderField
 from litellm.utils import (
     Choices,
@@ -652,6 +653,36 @@ class OpenAIChatCompletion(BaseLLM):
         else:
             return client
 
+    async def make_openai_chat_completion_request(
+        self,
+        openai_aclient: AsyncOpenAI,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+    ):
+        """
+        Helper to:
+        - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
+        - call chat.completions.create by default
+        """
+        try:
+            if litellm.return_response_headers is True:
+                raw_response = (
+                    await openai_aclient.chat.completions.with_raw_response.create(
+                        **data, timeout=timeout
+                    )
+                )
+
+                headers = dict(raw_response.headers)
+                response = raw_response.parse()
+                return headers, response
+            else:
+                response = await openai_aclient.chat.completions.create(
+                    **data, timeout=timeout
+                )
+                return None, response
+        except Exception as e:
+            raise e
+
     def completion(
         self,
         model_response: ModelResponse,
@@ -836,13 +867,13 @@ class OpenAIChatCompletion(BaseLLM):
         self,
         data: dict,
         model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
         timeout: Union[float, httpx.Timeout],
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         organization: Optional[str] = None,
         client=None,
         max_retries=None,
-        logging_obj=None,
         headers=None,
     ):
         response = None
@@ -869,8 +900,8 @@ class OpenAIChatCompletion(BaseLLM):
                 },
             )
 
-            response = await openai_aclient.chat.completions.create(
-                **data, timeout=timeout
+            headers, response = await self.make_openai_chat_completion_request(
+                openai_aclient=openai_aclient, data=data, timeout=timeout
             )
             stringified_response = response.model_dump()
             logging_obj.post_call(
@@ -879,9 +910,11 @@ class OpenAIChatCompletion(BaseLLM):
                 original_response=stringified_response,
                 additional_args={"complete_input_dict": data},
             )
+            logging_obj.model_call_details["response_headers"] = headers
             return convert_to_model_response_object(
                 response_object=stringified_response,
                 model_response_object=model_response,
+                hidden_params={"headers": headers},
             )
         except Exception as e:
             raise e
@@ -931,10 +964,10 @@ class OpenAIChatCompletion(BaseLLM):
 
     async def async_streaming(
         self,
-        logging_obj,
         timeout: Union[float, httpx.Timeout],
         data: dict,
         model: str,
+        logging_obj: LiteLLMLoggingObj,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         organization: Optional[str] = None,
@@ -965,9 +998,10 @@ class OpenAIChatCompletion(BaseLLM):
             },
         )
 
-        response = await openai_aclient.chat.completions.create(
-            **data, timeout=timeout
+        headers, response = await self.make_openai_chat_completion_request(
+            openai_aclient=openai_aclient, data=data, timeout=timeout
         )
+        logging_obj.model_call_details["response_headers"] = headers
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -992,17 +1026,43 @@ class OpenAIChatCompletion(BaseLLM):
         else:
             raise OpenAIError(status_code=500, message=f"{str(e)}")
 
+    # Embedding
+    async def make_openai_embedding_request(
+        self,
+        openai_aclient: AsyncOpenAI,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+    ):
+        """
+        Helper to:
+        - call embeddings.create.with_raw_response when litellm.return_response_headers is True
+        - call embeddings.create by default
+        """
+        try:
+            if litellm.return_response_headers is True:
+                raw_response = await openai_aclient.embeddings.with_raw_response.create(
+                    **data, timeout=timeout
+                )  # type: ignore
+                headers = dict(raw_response.headers)
+                response = raw_response.parse()
+                return headers, response
+            else:
+                response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
+                return None, response
+        except Exception as e:
+            raise e
+
     async def aembedding(
         self,
         input: list,
         data: dict,
         model_response: litellm.utils.EmbeddingResponse,
         timeout: float,
+        logging_obj: LiteLLMLoggingObj,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         client: Optional[AsyncOpenAI] = None,
         max_retries=None,
-        logging_obj=None,
     ):
         response = None
         try:
@@ -1014,7 +1074,10 @@ class OpenAIChatCompletion(BaseLLM):
                 max_retries=max_retries,
                 client=client,
             )
-            response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
+            headers, response = await self.make_openai_embedding_request(
+                openai_aclient=openai_aclient, data=data, timeout=timeout
+            )
+            logging_obj.model_call_details["response_headers"] = headers
             stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
@@ -1229,6 +1292,34 @@ class OpenAIChatCompletion(BaseLLM):
         else:
             raise OpenAIError(status_code=500, message=str(e))
 
+    # Audio Transcriptions
+    async def make_openai_audio_transcriptions_request(
+        self,
+        openai_aclient: AsyncOpenAI,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+    ):
+        """
+        Helper to:
+        - call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
+        - call openai_aclient.audio.transcriptions.create by default
+        """
+        try:
+            if litellm.return_response_headers is True:
+                raw_response = (
+                    await openai_aclient.audio.transcriptions.with_raw_response.create(
+                        **data, timeout=timeout
+                    )
+                )  # type: ignore
+                headers = dict(raw_response.headers)
+                response = raw_response.parse()
+                return headers, response
+            else:
+                response = await openai_aclient.audio.transcriptions.create(**data, timeout=timeout)  # type: ignore
+                return None, response
+        except Exception as e:
+            raise e
+
     def audio_transcriptions(
         self,
         model: str,
@@ -1286,11 +1377,11 @@ class OpenAIChatCompletion(BaseLLM):
         data: dict,
         model_response: TranscriptionResponse,
         timeout: float,
+        logging_obj: LiteLLMLoggingObj,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         client=None,
         max_retries=None,
-        logging_obj=None,
     ):
         try:
             openai_aclient = self._get_openai_client(
@@ -1302,9 +1393,12 @@ class OpenAIChatCompletion(BaseLLM):
                 client=client,
             )
 
-            response = await openai_aclient.audio.transcriptions.create(
-                **data, timeout=timeout
-            )  # type: ignore
+            headers, response = await self.make_openai_audio_transcriptions_request(
+                openai_aclient=openai_aclient,
+                data=data,
+                timeout=timeout,
+            )
+            logging_obj.model_call_details["response_headers"] = headers
             stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
@@ -1497,9 +1591,9 @@ class OpenAITextCompletion(BaseLLM):
         model: str,
         messages: list,
         timeout: float,
+        logging_obj: LiteLLMLoggingObj,
         print_verbose: Optional[Callable] = None,
         api_base: Optional[str] = None,
-        logging_obj=None,
         acompletion: bool = False,
         optional_params=None,
         litellm_params=None,
@@ -1035,6 +1035,9 @@ class VertexLLM(BaseLLM):
         safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
             "safety_settings", None
         )  # type: ignore
+        cached_content: Optional[str] = optional_params.pop(
+            "cached_content", None
+        )
         generation_config: Optional[GenerationConfig] = GenerationConfig(
             **optional_params
         )
@@ -1050,6 +1053,8 @@
         data["safetySettings"] = safety_settings
         if generation_config is not None:
             data["generationConfig"] = generation_config
+        if cached_content is not None:
+            data["cachedContent"] = cached_content
 
         headers = {
             "Content-Type": "application/json",
@@ -48,6 +48,7 @@ from litellm import (  # type: ignore
     get_litellm_params,
     get_optional_params,
 )
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.utils import (
     CustomStreamWrapper,
     Usage,
@@ -476,6 +477,15 @@ def mock_completion(
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )
+    elif (
+        isinstance(mock_response, str) and mock_response == "litellm.RateLimitError"
+    ):
+        raise litellm.RateLimitError(
+            message="this is a mock rate limit error",
+            status_code=getattr(mock_response, "status_code", 429),  # type: ignore
+            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+            model=model,
+        )
     time_delay = kwargs.get("mock_delay", None)
     if time_delay is not None:
         time.sleep(time_delay)
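The new `litellm.RateLimitError` branch above is aimed at exercising retry and cooldown handling without hitting a provider. A minimal usage sketch (the model name is arbitrary, since the call is mocked):

```python
import litellm

try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response="litellm.RateLimitError",  # triggers the branch added above
    )
except litellm.RateLimitError as err:
    print("mocked rate limit:", err)
```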
@@ -2203,15 +2213,26 @@
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
 
         if "aws_bedrock_client" in optional_params:
+            verbose_logger.warning(
+                "'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication."
+            )
             # Extract credentials for legacy boto3 client and pass thru to httpx
             aws_bedrock_client = optional_params.pop("aws_bedrock_client")
             creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
 
             if creds.access_key:
                 optional_params["aws_access_key_id"] = creds.access_key
             if creds.secret_key:
                 optional_params["aws_secret_access_key"] = creds.secret_key
             if creds.token:
                 optional_params["aws_session_token"] = creds.token
+            if (
+                "aws_region_name" not in optional_params
+                or optional_params["aws_region_name"] is None
+            ):
+                optional_params["aws_region_name"] = (
+                    aws_bedrock_client.meta.region_name
+                )
 
         if model in litellm.BEDROCK_CONVERSE_MODELS:
             response = bedrock_converse_chat_completion.completion(
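Since `aws_bedrock_client` is now logged as deprecated, the direction suggested by the warning is to pass credentials directly. A hedged sketch (the credential values and the model id are placeholders; the parameter names mirror the ones extracted from the boto3 client above):

```python
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder model id
    messages=[{"role": "user", "content": "hi"}],
    aws_access_key_id="AKIA...",     # placeholder
    aws_secret_access_key="...",     # placeholder
    aws_session_token="...",         # optional, placeholder
    aws_region_name="us-west-2",     # the legacy path above now falls back to the boto3 client's region
)
```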
@@ -4242,7 +4263,7 @@ def transcription(
     api_base: Optional[str] = None,
     api_version: Optional[str] = None,
     max_retries: Optional[int] = None,
-    litellm_logging_obj=None,
+    litellm_logging_obj: Optional[LiteLLMLoggingObj] = None,
     custom_llm_provider=None,
     **kwargs,
 ):
@@ -4257,6 +4278,18 @@
     proxy_server_request = kwargs.get("proxy_server_request", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", {})
+    client: Optional[
+        Union[
+            openai.AsyncOpenAI,
+            openai.OpenAI,
+            openai.AzureOpenAI,
+            openai.AsyncAzureOpenAI,
+        ]
+    ] = kwargs.pop("client", None)
+
+    if litellm_logging_obj:
+        litellm_logging_obj.model_call_details["client"] = str(client)
+
     if max_retries is None:
         max_retries = openai.DEFAULT_MAX_RETRIES
 
@@ -4296,6 +4329,7 @@
             optional_params=optional_params,
             model_response=model_response,
             atranscription=atranscription,
+            client=client,
             timeout=timeout,
             logging_obj=litellm_logging_obj,
             api_base=api_base,
@@ -4329,6 +4363,7 @@
             optional_params=optional_params,
             model_response=model_response,
             atranscription=atranscription,
+            client=client,
             timeout=timeout,
             logging_obj=litellm_logging_obj,
             max_retries=max_retries,
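The `client` kwarg now popped and forwarded by `transcription()` lets callers supply a pre-configured OpenAI client instead of having litellm construct one. A hedged sketch of the intended call shape (the API key and file path are placeholders):

```python
import asyncio
import openai
import litellm

my_client = openai.AsyncOpenAI(api_key="sk-...", max_retries=0)  # placeholder key

async def transcribe(path: str):
    with open(path, "rb") as audio_file:
        return await litellm.atranscription(
            model="whisper-1",
            file=audio_file,
            client=my_client,  # forwarded via the kwargs.pop("client") added above
        )

asyncio.run(transcribe("sample.wav"))
```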
@@ -1,16 +1,9 @@
 model_list:
-  - model_name: gpt-3.5-turbo # all requests where model not in your config go to this deployment
+  - model_name: claude-3-5-sonnet # all requests where model not in your config go to this deployment
     litellm_params:
-      model: "gpt-3.5-turbo"
-      rpm: 100
+      model: "openai/*"
+      mock_response: "Hello world!"
 
-litellm_settings:
-  callbacks: ["dynamic_rate_limiter"]
-  priority_reservation: {"dev": 0, "prod": 1}
-# success_callback: ["s3"]
-# s3_callback_params:
-#   s3_bucket_name: my-test-bucket-22-litellm # AWS Bucket Name for S3
-#   s3_region_name: us-west-2 # AWS Region Name for S3
-#   s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
-#   s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
-#   s3_path: my-test-path
+general_settings:
+  alerting: ["slack"]
+  alerting_threshold: 10
@@ -36,6 +36,7 @@ general_settings:
       LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY"
 
 litellm_settings:
+  return_response_headers: true
   success_callback: ["prometheus"]
   callbacks: ["otel", "hide_secrets"]
   failure_callback: ["prometheus"]
@ -1182,9 +1182,13 @@ async def _run_background_health_check():
|
||||||
Update health_check_results, based on this.
|
Update health_check_results, based on this.
|
||||||
"""
|
"""
|
||||||
global health_check_results, llm_model_list, health_check_interval
|
global health_check_results, llm_model_list, health_check_interval
|
||||||
|
|
||||||
|
# make 1 deep copy of llm_model_list -> use this for all background health checks
|
||||||
|
_llm_model_list = copy.deepcopy(llm_model_list)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
healthy_endpoints, unhealthy_endpoints = await perform_health_check(
|
healthy_endpoints, unhealthy_endpoints = await perform_health_check(
|
||||||
model_list=llm_model_list
|
model_list=_llm_model_list
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the global variable with the health check results
|
# Update the global variable with the health check results
|
||||||
|
@ -3066,8 +3070,11 @@ async def chat_completion(
|
||||||
# Post Call Processing
|
# Post Call Processing
|
||||||
if llm_router is not None:
|
if llm_router is not None:
|
||||||
data["deployment"] = llm_router.get_deployment(model_id=model_id)
|
data["deployment"] = llm_router.get_deployment(model_id=model_id)
|
||||||
data["litellm_status"] = "success" # used for alerting
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.update_request_status(
|
||||||
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||||
|
)
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
"stream" in data and data["stream"] == True
|
"stream" in data and data["stream"] == True
|
||||||
): # use generate_responses to stream responses
|
): # use generate_responses to stream responses
|
||||||
|
@@ -3117,7 +3124,6 @@ async def chat_completion(
         return response
     except RejectedRequestError as e:
         _data = e.request_data
-        _data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict,
             original_exception=e,

@@ -3150,7 +3156,6 @@ async def chat_completion(
             _chat_response.usage = _usage  # type: ignore
             return _chat_response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         verbose_proxy_logger.error(
             "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format(
                 get_error_message_str(e=e), traceback.format_exc()

@@ -3306,7 +3311,11 @@ async def completion(
         response_cost = hidden_params.get("response_cost", None) or ""
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         verbose_proxy_logger.debug("final response: %s", response)
         if (

@@ -3345,7 +3354,6 @@ async def completion(
         return response
     except RejectedRequestError as e:
         _data = e.request_data
-        _data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict,
             original_exception=e,

@@ -3384,7 +3392,6 @@ async def completion(
             _response.choices[0].text = e.message
             return _response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -3536,7 +3543,11 @@ async def embeddings(
         )
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -3559,7 +3570,6 @@ async def embeddings(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -3687,8 +3697,11 @@ async def image_generation(
         )
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""

@@ -3710,7 +3723,6 @@ async def image_generation(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -3825,7 +3837,11 @@ async def audio_speech(
         )
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -3991,7 +4007,11 @@ async def audio_transcriptions(
             os.remove(file.filename)  # Delete the saved file
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -4014,7 +4034,6 @@ async def audio_transcriptions(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -4093,7 +4112,11 @@ async def get_assistants(
         response = await llm_router.aget_assistants(**data)
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -4114,7 +4137,6 @@ async def get_assistants(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -4185,7 +4207,11 @@ async def create_threads(
         response = await llm_router.acreate_thread(**data)
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -4206,7 +4232,6 @@ async def create_threads(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -4276,7 +4301,11 @@ async def get_thread(
         response = await llm_router.aget_thread(thread_id=thread_id, **data)
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -4297,7 +4326,6 @@ async def get_thread(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -4370,7 +4398,11 @@ async def add_messages(
         response = await llm_router.a_add_message(thread_id=thread_id, **data)
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -4391,7 +4423,6 @@ async def add_messages(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -4460,7 +4491,11 @@ async def get_messages(
         response = await llm_router.aget_messages(thread_id=thread_id, **data)
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -4481,7 +4516,6 @@ async def get_messages(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -4564,7 +4598,11 @@ async def run_thread(
         )
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -4585,7 +4623,6 @@ async def run_thread(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -4675,7 +4712,11 @@ async def create_batch(
         )
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -4696,7 +4737,6 @@ async def create_batch(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -4781,7 +4821,11 @@ async def retrieve_batch(
         )
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -4802,7 +4846,6 @@ async def retrieve_batch(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -4897,7 +4940,11 @@ async def create_file(
         )
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -4918,7 +4965,6 @@ async def create_file(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -5041,7 +5087,11 @@ async def moderations(
         response = await litellm.amoderation(**data)
 
         ### ALERTING ###
-        data["litellm_status"] = "success"  # used for alerting
+        asyncio.create_task(
+            proxy_logging_obj.update_request_status(
+                litellm_call_id=data.get("litellm_call_id", ""), status="success"
+            )
+        )
 
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}

@@ -5062,7 +5112,6 @@ async def moderations(
 
         return response
     except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )

@@ -153,7 +153,7 @@ def decrypt_env_var() -> Dict[str, Any]:
         ) or (v is not None and isinstance(v, str) and v.startswith("aws_kms/")):
             decrypted_value = aws_kms.decrypt_value(secret_name=k)
             # reset env var
-            k = re.sub("litellm_secret_aws_kms", "", k, flags=re.IGNORECASE)
+            k = re.sub("litellm_secret_aws_kms_", "", k, flags=re.IGNORECASE)
             new_values[k] = decrypted_value
 
     return new_values
@@ -272,6 +272,16 @@ class ProxyLogging:
             callback_list=callback_list
         )
 
+    async def update_request_status(
+        self, litellm_call_id: str, status: Literal["success", "fail"]
+    ):
+        await self.internal_usage_cache.async_set_cache(
+            key="request_status:{}".format(litellm_call_id),
+            value=status,
+            local_only=True,
+            ttl=3600,
+        )
+
     # The actual implementation of the function
     async def pre_call_hook(
         self,
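The new helper writes `"success"` or `"fail"` under `request_status:<litellm_call_id>` in the local cache with a one-hour TTL. A sketch of how a consumer might read that status back — the cache class below is a toy stand-in, and the "skip the hanging-request alert if the call already resolved" read side is an assumption for illustration, not the proxy's actual alerting code:

```python
import asyncio
import time
from typing import Literal, Optional


class InMemoryTTLCache:
    """Toy stand-in for the proxy's internal usage cache (illustrative only)."""

    def __init__(self) -> None:
        self._store: dict = {}

    async def async_set_cache(self, key: str, value: str, ttl: int = 3600) -> None:
        self._store[key] = (value, time.monotonic() + ttl)

    async def async_get_cache(self, key: str) -> Optional[str]:
        entry = self._store.get(key)
        if entry is None or entry[1] < time.monotonic():
            return None
        return entry[0]


async def update_request_status(
    cache: InMemoryTTLCache, litellm_call_id: str, status: Literal["success", "fail"]
) -> None:
    # same key format and TTL as the helper added in this commit
    await cache.async_set_cache(key="request_status:{}".format(litellm_call_id), value=status, ttl=3600)


async def should_alert_on_hanging_request(cache: InMemoryTTLCache, litellm_call_id: str) -> bool:
    # hypothetical consumer: only alert if no terminal status was ever recorded
    return await cache.async_get_cache("request_status:{}".format(litellm_call_id)) is None


async def main() -> None:
    cache = InMemoryTTLCache()
    await update_request_status(cache, "abc-123", "success")
    print(await should_alert_on_hanging_request(cache, "abc-123"))  # False
    print(await should_alert_on_hanging_request(cache, "missing"))  # True


if __name__ == "__main__":
    asyncio.run(main())
```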
@@ -560,6 +570,9 @@ class ProxyLogging:
         """
 
         ### ALERTING ###
+        await self.update_request_status(
+            litellm_call_id=request_data.get("litellm_call_id", ""), status="fail"
+        )
         if "llm_exceptions" in self.alert_types and not isinstance(
             original_exception, HTTPException
         ):

@@ -611,6 +624,7 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         """
 
         for callback in litellm.callbacks:
             try:
                 _callback: Optional[CustomLogger] = None

@@ -156,6 +156,7 @@ class Router:
         cooldown_time: Optional[
             float
         ] = None,  # (seconds) time to cooldown a deployment after failure
+        disable_cooldowns: Optional[bool] = None,
         routing_strategy: Literal[
             "simple-shuffle",
             "least-busy",

@@ -307,6 +308,7 @@ class Router:
 
         self.allowed_fails = allowed_fails or litellm.allowed_fails
         self.cooldown_time = cooldown_time or 60
+        self.disable_cooldowns = disable_cooldowns
         self.failed_calls = (
             InMemoryCache()
         )  # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown

@@ -2990,6 +2992,8 @@ class Router:
 
         the exception is not one that should be immediately retried (e.g. 401)
         """
+        if self.disable_cooldowns is True:
+            return
 
         if deployment is None:
             return
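A short usage sketch for the new flag, assuming a litellm version that includes this change. With `disable_cooldowns=True` the early return above skips cooldown bookkeeping entirely, so failing deployments keep receiving traffic; the deployment config below is illustrative:

```python
from litellm import Router

# Hypothetical single-deployment config; any OpenAI-compatible deployment works here.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "openai/gpt-3.5-turbo", "mock_response": "Hello world!"},
    }
]

# cooldown_time tunes how long a failing deployment is benched;
# disable_cooldowns=True (not recommended) skips benching altogether.
router = Router(model_list=model_list, cooldown_time=30, disable_cooldowns=True)
```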
@@ -3030,24 +3034,50 @@
                 exception_status = 500
             _should_retry = litellm._should_retry(status_code=exception_status)
 
-            if updated_fails > allowed_fails or _should_retry == False:
+            if updated_fails > allowed_fails or _should_retry is False:
                 # get the current cooldown list for that minute
                 cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
-                cached_value = self.cache.get_cache(key=cooldown_key)
+                cached_value = self.cache.get_cache(
+                    key=cooldown_key
+                )  # [(deployment_id, {last_error_str, last_error_status_code})]
+
+                cached_value_deployment_ids = []
+                if (
+                    cached_value is not None
+                    and isinstance(cached_value, list)
+                    and len(cached_value) > 0
+                    and isinstance(cached_value[0], tuple)
+                ):
+                    cached_value_deployment_ids = [cv[0] for cv in cached_value]
                 verbose_router_logger.debug(f"adding {deployment} to cooldown models")
                 # update value
-                try:
-                    if deployment in cached_value:
+                if cached_value is not None and len(cached_value_deployment_ids) > 0:
+                    if deployment in cached_value_deployment_ids:
                         pass
                     else:
-                        cached_value = cached_value + [deployment]
+                        cached_value = cached_value + [
+                            (
+                                deployment,
+                                {
+                                    "Exception Received": str(original_exception),
+                                    "Status Code": str(exception_status),
+                                },
+                            )
+                        ]
                         # save updated value
                         self.cache.set_cache(
                             value=cached_value, key=cooldown_key, ttl=cooldown_time
                         )
-                except:
-                    cached_value = [deployment]
+                else:
+                    cached_value = [
+                        (
+                            deployment,
+                            {
+                                "Exception Received": str(original_exception),
+                                "Status Code": str(exception_status),
+                            },
+                        )
+                    ]
                     # save updated value
                     self.cache.set_cache(
                         value=cached_value, key=cooldown_key, ttl=cooldown_time

@@ -3063,7 +3093,33 @@ class Router:
                 key=deployment, value=updated_fails, ttl=cooldown_time
             )
 
-    async def _async_get_cooldown_deployments(self):
+    async def _async_get_cooldown_deployments(self) -> List[str]:
+        """
+        Async implementation of '_get_cooldown_deployments'
+        """
+        dt = get_utc_datetime()
+        current_minute = dt.strftime("%H-%M")
+        # get the current cooldown list for that minute
+        cooldown_key = f"{current_minute}:cooldown_models"
+
+        # ----------------------
+        # Return cooldown models
+        # ----------------------
+        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+
+        cached_value_deployment_ids = []
+        if (
+            cooldown_models is not None
+            and isinstance(cooldown_models, list)
+            and len(cooldown_models) > 0
+            and isinstance(cooldown_models[0], tuple)
+        ):
+            cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
+
+        verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
+        return cached_value_deployment_ids
+
+    async def _async_get_cooldown_deployments_with_debug_info(self) -> List[tuple]:
         """
         Async implementation of '_get_cooldown_deployments'
         """

@@ -3080,7 +3136,7 @@ class Router:
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
 
-    def _get_cooldown_deployments(self):
+    def _get_cooldown_deployments(self) -> List[str]:
         """
         Get the list of models being cooled down for this minute
         """

@@ -3094,8 +3150,17 @@
         # ----------------------
         cooldown_models = self.cache.get_cache(key=cooldown_key) or []
 
+        cached_value_deployment_ids = []
+        if (
+            cooldown_models is not None
+            and isinstance(cooldown_models, list)
+            and len(cooldown_models) > 0
+            and isinstance(cooldown_models[0], tuple)
+        ):
+            cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
+
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
-        return cooldown_models
+        return cached_value_deployment_ids
 
     def _get_healthy_deployments(self, model: str):
         _all_deployments: list = []
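With this change the per-minute cooldown cache stores `(deployment_id, debug_info)` tuples instead of bare ids, while `_get_cooldown_deployments` / `_async_get_cooldown_deployments` still return plain id lists. A small standalone sketch of that shape and the id extraction (a plain list stands in for the router's cache; ids and error text are made up):

```python
from typing import Dict, List, Tuple

CooldownEntry = Tuple[str, Dict[str, str]]

# what a "<minute>:cooldown_models" cache value now looks like
cooldown_models: List[CooldownEntry] = [
    (
        "deployment-abc123",
        {"Exception Received": "RateLimitError: 429", "Status Code": "429"},
    )
]

# routing only needs the ids; the debug payload is kept for error messages
cooldown_deployment_ids: List[str] = []
if (
    cooldown_models
    and isinstance(cooldown_models, list)
    and isinstance(cooldown_models[0], tuple)
):
    cooldown_deployment_ids = [entry[0] for entry in cooldown_models]

print(cooldown_deployment_ids)  # ['deployment-abc123']
```

The debug-carrying variant (`_async_get_cooldown_deployments_with_debug_info`) is what the "No deployments available" error below now embeds, so the raised message shows which deployments are benched and why.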
@@ -4737,7 +4802,7 @@ class Router:
             if _allowed_model_region is None:
                 _allowed_model_region = "n/a"
             raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
+                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}"
             )
 
         if (

@@ -856,3 +856,56 @@ async def test_bedrock_custom_prompt_template():
         prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"]
         assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>"
         mock_client_post.assert_called_once()
+
+
+def test_completion_bedrock_external_client_region():
+    print("\ncalling bedrock claude external client auth")
+    import os
+
+    aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
+    aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
+    aws_region_name = "us-east-1"
+
+    os.environ.pop("AWS_ACCESS_KEY_ID", None)
+    os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
+
+    client = HTTPHandler()
+
+    try:
+        import boto3
+
+        litellm.set_verbose = True
+
+        bedrock = boto3.client(
+            service_name="bedrock-runtime",
+            region_name=aws_region_name,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com",
+        )
+        with patch.object(client, "post", new=Mock()) as mock_client_post:
+            try:
+                response = completion(
+                    model="bedrock/anthropic.claude-instant-v1",
+                    messages=messages,
+                    max_tokens=10,
+                    temperature=0.1,
+                    aws_bedrock_client=bedrock,
+                    client=client,
+                )
+                # Add any assertions here to check the response
+                print(response)
+            except Exception as e:
+                pass
+
+            print(f"mock_client_post.call_args: {mock_client_post.call_args}")
+            assert "us-east-1" in mock_client_post.call_args.kwargs["url"]
+
+            mock_client_post.assert_called_once()
+
+        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
+        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

@@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt
 
-# litellm.num_retries = 3
+# litellm.num_retries=3
 litellm.cache = None
 litellm.success_callback = []
 user_message = "Write a short poem about the sky"

@@ -249,6 +249,25 @@ def test_completion_azure_exception():
 # test_completion_azure_exception()
 
 
+def test_azure_embedding_exceptions():
+    try:
+
+        response = litellm.embedding(
+            model="azure/azure-embedding-model",
+            input="hello",
+            messages="hello",
+        )
+        pytest.fail(f"Bad request this should have failed but got {response}")
+
+    except Exception as e:
+        print(vars(e))
+        # CRUCIAL Test - Ensures our exceptions are readable and not overly complicated. some users have complained exceptions will randomly have another exception raised in our exception mapping
+        assert (
+            e.message
+            == "litellm.APIError: AzureException APIError - Embeddings.create() got an unexpected keyword argument 'messages'"
+        )
+
+
 async def asynctest_completion_azure_exception():
     try:
         import openai

@@ -1,20 +1,23 @@
 # What this tests?
 ## This tests the litellm support for the openai /generations endpoint
 
-import sys, os
-import traceback
-from dotenv import load_dotenv
 import logging
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
 logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
-import os
 import asyncio
+import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest
 
 import litellm

@@ -39,13 +42,25 @@ def test_image_generation_openai():
 # test_image_generation_openai()
 
 
-def test_image_generation_azure():
+@pytest.mark.parametrize(
+    "sync_mode",
+    [True, False],
+)  #
+@pytest.mark.asyncio
+async def test_image_generation_azure(sync_mode):
     try:
+        if sync_mode:
             response = litellm.image_generation(
                 prompt="A cute baby sea otter",
                 model="azure/",
                 api_version="2023-06-01-preview",
             )
+        else:
+            response = await litellm.aimage_generation(
+                prompt="A cute baby sea otter",
+                model="azure/",
+                api_version="2023-06-01-preview",
+            )
         print(f"response: {response}")
         assert len(response.data) > 0
     except litellm.RateLimitError as e:

@@ -155,6 +155,16 @@ class ToolConfig(TypedDict):
     functionCallingConfig: FunctionCallingConfig
 
 
+class TTL(TypedDict, total=False):
+    seconds: Required[float]
+    nano: float
+
+
+class CachedContent(TypedDict, total=False):
+    ttl: TTL
+    expire_time: str
+
+
 class RequestBody(TypedDict, total=False):
     contents: Required[List[ContentType]]
     system_instruction: SystemInstructions

@@ -162,6 +172,7 @@ class RequestBody(TypedDict, total=False):
     toolConfig: ToolConfig
     safetySettings: List[SafetSettingsConfig]
     generationConfig: GenerationConfig
+    cachedContent: str
 
 
 class SafetyRatings(TypedDict):
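A short sketch of how the new typed fields compose: a request body that points at a cached content resource, plus a standalone `CachedContent`-shaped payload carrying a `TTL`. Only the field names come from the TypedDicts above; the resource name, contents structure, and values are illustrative:

```python
# Plain dicts matching the TypedDict shapes above (values are made up).
ttl = {"seconds": 3600.0}            # TTL: seconds required, nano optional
cached_content = {"ttl": ttl}        # CachedContent: expire_time omitted (total=False)

request_body = {
    # assumed ContentType-style payload; see the real types for the exact shape
    "contents": [{"role": "user", "parts": [{"text": "Summarize the cached document."}]}],
    "cachedContent": "cachedContents/1234567890",  # hypothetical resource name
}

print(request_body["cachedContent"], cached_content["ttl"]["seconds"])
```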
@@ -4815,6 +4815,12 @@ def function_to_dict(input_function):  # noqa: C901
     return result
 
 
+def modify_url(original_url, new_path):
+    url = httpx.URL(original_url)
+    modified_url = url.copy_with(path=new_path)
+    return str(modified_url)
+
+
 def load_test_model(
     model: str,
     custom_llm_provider: str = "",
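A quick usage sketch for the new helper, which swaps only the path component via `httpx.URL.copy_with` and leaves host and query string intact (requires `httpx`; the helper is repeated here so the snippet is self-contained, and the example URL is arbitrary):

```python
import httpx


def modify_url(original_url, new_path):
    url = httpx.URL(original_url)
    modified_url = url.copy_with(path=new_path)
    return str(modified_url)


print(
    modify_url(
        "https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-02-01",
        "/v1/health",
    )
)
# -> https://example.openai.azure.com/v1/health?api-version=2024-02-01
```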
@@ -5810,6 +5816,18 @@ def exception_type(
                 _model_group = _metadata.get("model_group")
                 _deployment = _metadata.get("deployment")
                 extra_information = f"\nModel: {model}"
+
+                exception_provider = "Unknown"
+                if (
+                    isinstance(custom_llm_provider, str)
+                    and len(custom_llm_provider) > 0
+                ):
+                    exception_provider = (
+                        custom_llm_provider[0].upper()
+                        + custom_llm_provider[1:]
+                        + "Exception"
+                    )
+
                 if _api_base:
                     extra_information += f"\nAPI Base: `{_api_base}`"
                 if (
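The `exception_provider` block above just capitalizes the provider name and appends `"Exception"`, falling back to `"Unknown"` when no provider string is set. A tiny standalone check of that string logic (wrapped in a helper function purely for illustration):

```python
def exception_provider_name(custom_llm_provider):
    exception_provider = "Unknown"
    if isinstance(custom_llm_provider, str) and len(custom_llm_provider) > 0:
        exception_provider = custom_llm_provider[0].upper() + custom_llm_provider[1:] + "Exception"
    return exception_provider


print(exception_provider_name("azure"))      # AzureException
print(exception_provider_name("anthropic"))  # AnthropicException
print(exception_provider_name(""))           # Unknown
```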
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.41.2"
+version = "1.41.3"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"

@@ -90,7 +90,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.commitizen]
-version = "1.41.2"
+version = "1.41.3"
 version_files = [
     "pyproject.toml:^version"
 ]

@@ -8,6 +8,9 @@ from openai import AsyncOpenAI
 import sys, os, dotenv
 from typing import Optional
 from dotenv import load_dotenv
+from litellm.integrations.custom_logger import CustomLogger
+import litellm
+import logging
 
 # Get the current directory of the file being run
 pwd = os.path.dirname(os.path.realpath(__file__))

@@ -84,9 +87,32 @@ async def test_transcription_async_openai():
     assert isinstance(transcript.text, str)
 
 
+# This file includes the custom callbacks for LiteLLM Proxy
+# Once defined, these can be passed in proxy_config.yaml
+class MyCustomHandler(CustomLogger):
+    def __init__(self):
+        self.openai_client = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            # init logging config
+            print("logging a transcript kwargs: ", kwargs)
+            print("openai client=", kwargs.get("client"))
+            self.openai_client = kwargs.get("client")
+
+        except:
+            pass
+
+
+proxy_handler_instance = MyCustomHandler()
+
+
+# Set litellm.callbacks = [proxy_handler_instance] on the proxy
+# need to set litellm.callbacks = [proxy_handler_instance] # on the proxy
 @pytest.mark.asyncio
 async def test_transcription_on_router():
     litellm.set_verbose = True
+    litellm.callbacks = [proxy_handler_instance]
     print("\n Testing async transcription on router\n")
     try:
         model_list = [

@@ -108,11 +134,29 @@ async def test_transcription_on_router():
         ]
 
         router = Router(model_list=model_list)
 
+        router_level_clients = []
+        for deployment in router.model_list:
+            _deployment_openai_client = router._get_client(
+                deployment=deployment,
+                kwargs={"model": "whisper-1"},
+                client_type="async",
+            )
+
+            router_level_clients.append(str(_deployment_openai_client))
+
         response = await router.atranscription(
             model="whisper",
             file=audio_file,
         )
         print(response)
+
+        # PROD Test
+        # Ensure we ONLY use OpenAI/Azure client initialized on the router level
+        await asyncio.sleep(5)
+        print("OpenAI Client used= ", proxy_handler_instance.openai_client)
+        print("all router level clients= ", router_level_clients)
+        assert proxy_handler_instance.openai_client in router_level_clients
     except Exception as e:
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")