commit 30d1fe7fe3
Author: Nandesh Guru
Date: 2024-04-24 13:46:46 -07:00
131 changed files with 6119 additions and 1879 deletions

View file

@ -77,6 +77,9 @@ if __name__ == "__main__":
new_release_body = (
existing_release_body
+ "\n\n"
+ "### Don't want to maintain your internal proxy? get in touch 🎉"
+ "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
+ "\n\n"
+ "## Load Test LiteLLM Proxy Results"
+ "\n\n"
+ markdown_table
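The hunk only shows how `new_release_body` is assembled. For context, a body like this would then be written back to the release via the GitHub Releases REST API; a hypothetical sketch (the `release_id` placeholder and the token handling are assumptions, not taken from this script):
```python
import os
import requests

release_id = "123456"  # placeholder - the script resolves the real release id elsewhere
response = requests.patch(
    f"https://api.github.com/repos/BerriAI/litellm/releases/{release_id}",
    headers={
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
        "Accept": "application/vnd.github+json",
    },
    json={"body": new_release_body},  # the body assembled above
)
response.raise_for_status()
```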

View file

@ -5,7 +5,7 @@
<p align="center">Call all LLM APIs using the OpenAI format [Bedrock, Huggingface, VertexAI, TogetherAI, Azure, OpenAI, etc.]
<br>
</p>
<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">OpenAI Proxy Server</a> | <a href="https://docs.litellm.ai/docs/enterprise"target="_blank">Enterprise Tier</a></h4>
<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">OpenAI Proxy Server</a> | <a href="https://docs.litellm.ai/docs/hosted" target="_blank"> Hosted Proxy (Preview)</a> | <a href="https://docs.litellm.ai/docs/enterprise"target="_blank">Enterprise Tier</a></h4>
<h4 align="center">
<a href="https://pypi.org/project/litellm/" target="_blank">
<img src="https://img.shields.io/pypi/v/litellm.svg" alt="PyPI Version">
@ -128,7 +128,9 @@ response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content
# OpenAI Proxy - ([Docs](https://docs.litellm.ai/docs/simple_proxy))
Set Budgets & Rate limits across multiple projects
Track spend + Load Balance across multiple projects
[Hosted Proxy (Preview)](https://docs.litellm.ai/docs/hosted)
The proxy provides:

View file

@ -8,12 +8,13 @@ For companies that need SSO, user management and professional support for LiteLL
:::
This covers:
- ✅ **Features under the [LiteLLM Commercial License](https://docs.litellm.ai/docs/proxy/enterprise):**
- ✅ **Features under the [LiteLLM Commercial License (Content Mod, Custom Tags, etc.)](https://docs.litellm.ai/docs/proxy/enterprise)**
- ✅ **Feature Prioritization**
- ✅ **Custom Integrations**
- ✅ **Professional Support - Dedicated discord + slack**
- ✅ **Custom SLAs**
- ✅ **Secure access with Single Sign-On**
- ✅ [**Secure UI access with Single Sign-On**](../docs/proxy/ui.md#setup-ssoauth-for-ui)
- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
## Frequently Asked Questions

View file

@ -0,0 +1,49 @@
import Image from '@theme/IdealImage';
# Hosted LiteLLM Proxy
LiteLLM maintains the proxy, so you can focus on your core products.
## [**Get Onboarded**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
This is in alpha. Schedule a call with us, and we'll give you a hosted proxy within 30 minutes.
[**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
### **Status**: Alpha
Our proxy is already used in production by customers.
See our status page for [**live reliability**](https://status.litellm.ai/).
### **Benefits**
- **No Maintenance, No Infra**: We'll maintain the proxy, and spin up any additional infrastructure (e.g.: separate server for spend logs) to make sure you can load balance + track spend across multiple LLM projects.
- **Reliable**: Our hosted proxy is tested on 1k requests per second, making it reliable for high load.
- **Secure**: LiteLLM is currently undergoing SOC-2 compliance, to make sure your data is as secure as possible.
### Pricing
Pricing is based on usage. We can figure out a price that works for your team, on the call.
[**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
## **Screenshots**
### 1. Create keys
<Image img={require('../img/litellm_hosted_ui_create_key.png')} />
### 2. Add Models
<Image img={require('../img/litellm_hosted_ui_add_models.png')}/>
### 3. Track spend
<Image img={require('../img/litellm_hosted_usage_dashboard.png')} />
### 4. Configure load balancing
<Image img={require('../img/litellm_hosted_ui_router.png')} />
#### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)

View file

@ -57,7 +57,7 @@ os.environ["LANGSMITH_API_KEY"] = ""
os.environ['OPENAI_API_KEY']=""
# set langfuse as a callback, litellm will send the data to langfuse
litellm.success_callback = ["langfuse"]
litellm.success_callback = ["langsmith"]
response = litellm.completion(
model="gpt-3.5-turbo",

View file

@ -224,6 +224,91 @@ assert isinstance(
```
### Parallel Function Calling
Here's how to pass the result of a function call back to an Anthropic model:
```python
from litellm import completion
import litellm
import os
os.environ["ANTHROPIC_API_KEY"] = "sk-ant.."
litellm.set_verbose = True
### 1ST FUNCTION CALL ###
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
try:
# test without max tokens
response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
messages.append(
response.choices[0].message.model_dump()
) # Add assistant tool invokes
tool_result = (
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
)
# Add user submitted tool results in the OpenAI format
messages.append(
{
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"role": "tool",
"name": response.choices[0].message.tool_calls[0].function.name,
"content": tool_result,
}
)
### 2ND FUNCTION CALL ###
# In the second response, Claude should deduce answer from tool results
second_response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice="auto",
)
print(second_response)
except Exception as e:
print(f"An error occurred - {str(e)}")
```
s/o @[Shekhar Patnaik](https://www.linkedin.com/in/patnaikshekhar) for requesting this!
## Usage - Vision
```python

View file

@ -3,8 +3,6 @@ import TabItem from '@theme/TabItem';
# Azure AI Studio
## Sample Usage
**Ensure the following:**
1. The API Base passed ends in `/v1/`
Example:
@ -14,8 +12,11 @@ import TabItem from '@theme/TabItem';
2. The `model` passed is listed in [supported models](#supported-models). You **DO NOT** need to pass your deployment name to litellm. Example: `model=azure/Mistral-large-nmefg`
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
**Quick Start**
```python
import litellm
response = litellm.completion(
@ -26,6 +27,9 @@ response = litellm.completion(
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
## Sample Usage - LiteLLM Proxy
1. Add models to your config.yaml
@ -99,6 +103,107 @@ response = litellm.completion(
</Tabs>
</TabItem>
</Tabs>
## Function Calling
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
# set env
os.environ["AZURE_MISTRAL_API_KEY"] = "your-api-key"
os.environ["AZURE_MISTRAL_API_BASE"] = "your-api-base"
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion(
model="azure/mistral-large-latest",
api_base=os.getenv("AZURE_MISTRAL_API_BASE"),
api_key=os.getenv("AZURE_MISTRAL_API_KEY"),
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $YOUR_API_KEY" \
-d '{
"model": "mistral",
"messages": [
{
"role": "user",
"content": "What'\''s the weather like in Boston today?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}
],
"tool_choice": "auto"
}'
```
</TabItem>
</Tabs>
## Supported Models
| Model Name | Function Call |

View file

@ -23,7 +23,7 @@ In certain use-cases you may need to make calls to the models and pass [safety s
```python
response = completion(
model="gemini/gemini-pro",
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
safety_settings=[
{
"category": "HARM_CATEGORY_HARASSMENT",

View file

@ -48,6 +48,8 @@ We support ALL Groq models, just set `groq/` as a prefix when sending completion
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| llama3-8b-8192 | `completion(model="groq/llama3-8b-8192", messages)` |
| llama3-70b-8192 | `completion(model="groq/llama3-70b-8192", messages)` |
| llama2-70b-4096 | `completion(model="groq/llama2-70b-4096", messages)` |
| mixtral-8x7b-32768 | `completion(model="groq/mixtral-8x7b-32768", messages)` |
| gemma-7b-it | `completion(model="groq/gemma-7b-it", messages)` |
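For reference, a minimal sketch calling the newly listed `llama3-70b-8192` model (the API key value is a placeholder):
```python
import os
from litellm import completion

os.environ["GROQ_API_KEY"] = ""  # placeholder

response = completion(
    model="groq/llama3-70b-8192",
    messages=[{"role": "user", "content": "hello from litellm"}],
)
print(response.choices[0].message.content)
```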

View file

@ -50,6 +50,7 @@ All models listed here https://docs.mistral.ai/platform/endpoints are supported.
| mistral-small | `completion(model="mistral/mistral-small", messages)` |
| mistral-medium | `completion(model="mistral/mistral-medium", messages)` |
| mistral-large-latest | `completion(model="mistral/mistral-large-latest", messages)` |
| open-mixtral-8x22b | `completion(model="mistral/open-mixtral-8x22b", messages)` |
## Sample Usage - Embedding

View file

@ -163,6 +163,7 @@ os.environ["OPENAI_API_BASE"] = "openaiai-api-base" # OPTIONAL
| Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------|
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
| gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
| gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
| gpt-4-1106-preview | `response = completion(model="gpt-4-1106-preview", messages=messages)` |
@ -185,6 +186,7 @@ These also support the `OPENAI_API_BASE` environment variable, which can be used
## OpenAI Vision Models
| Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------|
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
| gpt-4-vision-preview | `response = completion(model="gpt-4-vision-preview", messages=messages)` |
#### Usage

View file

@ -253,6 +253,7 @@ litellm.vertex_location = "us-central1 # Your Location
## Anthropic
| Model Name | Function Call |
|------------------|--------------------------------------|
| claude-3-opus@20240229 | `completion('vertex_ai/claude-3-opus@20240229', messages)` |
| claude-3-sonnet@20240229 | `completion('vertex_ai/claude-3-sonnet@20240229', messages)` |
| claude-3-haiku@20240307 | `completion('vertex_ai/claude-3-haiku@20240307', messages)` |

View file

@ -61,6 +61,22 @@ litellm_settings:
ttl: 600 # will be cached on redis for 600s
```
## SSL
Just set `REDIS_SSL="True"` in your .env, and LiteLLM will pick this up.
```env
REDIS_SSL="True"
```
For quick testing, you can also use `REDIS_URL`, e.g.:
```
REDIS_URL="rediss://.."
```
However, we **don't** recommend using `REDIS_URL` in production. We've noticed a performance difference between using it and setting `redis_host`, `redis_port`, etc. individually.
#### Step 2: Add Redis Credentials to .env
Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching.
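In the SDK, the same environment variables are picked up when the Redis cache is enabled; a minimal sketch (host and password values are placeholders):
```python
import os
import litellm
from litellm import Cache, completion

os.environ["REDIS_HOST"] = "my-redis-host"  # placeholder
os.environ["REDIS_PORT"] = "6379"
os.environ["REDIS_PASSWORD"] = ""           # placeholder
os.environ["REDIS_SSL"] = "True"            # per the SSL note above
os.environ["OPENAI_API_KEY"] = ""           # placeholder

litellm.cache = Cache(type="redis")  # reads the REDIS_* variables above

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    caching=True,
)
```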

View file

@ -600,6 +600,7 @@ general_settings:
"general_settings": {
"completion_model": "string",
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
"disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
"enforce_user_param": "boolean", # requires all openai endpoint requests to have a 'user' param

View file

@ -16,7 +16,7 @@ Expected Performance in Production
| `/chat/completions` Requests/hour | `126K` |
## 1. Switch of Debug Logging
## 1. Switch off Debug Logging
Remove `set_verbose: True` from your config.yaml
```yaml
@ -40,7 +40,7 @@ Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
```
## 2. Batch write spend updates every 60s
## 3. Batch write spend updates every 60s
The default proxy batch write is 10s. This is to make it easy to see spend when debugging locally.
@ -49,11 +49,35 @@ In production, we recommend using a longer interval period of 60s. This reduces
```yaml
general_settings:
master_key: sk-1234
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
```
## 4. Use Redis 'port', 'host', 'password'. NOT 'redis_url'
## 3. Move spend logs to separate server
When connecting to Redis, use the redis port, host, and password params, not 'redis_url'. We've seen an 80 RPS difference between these two approaches when using the async redis client.
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
We recommend doing this in production:
```yaml
router_settings:
routing_strategy: usage-based-routing-v2
# redis_url: "os.environ/REDIS_URL"
redis_host: os.environ/REDIS_HOST
redis_port: os.environ/REDIS_PORT
redis_password: os.environ/REDIS_PASSWORD
```
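The same preference applies when constructing a Router directly in Python; a minimal sketch (the single OpenAI deployment is a placeholder):
```python
import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY")},
        }
    ],
    routing_strategy="usage-based-routing-v2",
    # pass host/port/password instead of redis_url (see the note above)
    redis_host=os.getenv("REDIS_HOST"),
    redis_port=int(os.getenv("REDIS_PORT", "6379")),
    redis_password=os.getenv("REDIS_PASSWORD"),
)
```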
## 5. Switch off resetting budgets
Add this to your config.yaml. (Spend per Key, User and Team will still be tracked, but budgets will no longer be reset on a schedule.)
```yaml
general_settings:
disable_reset_budget: true
```
## 6. Move spend logs to separate server (BETA)
Writing each spend log to the db can slow down your proxy. In testing, we saw a 70% improvement in median response time by moving spend log writes to a separate server.
@ -141,24 +165,6 @@ A t2.micro should be sufficient to handle 1k logs / minute on this server.
This consumes at most 120MB and <0.1 vCPU.
## 4. Switch off resetting budgets
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database)
```yaml
general_settings:
disable_spend_logs: true
disable_reset_budget: true
```
## 5. Switch of `litellm.telemetry`
Switch of all telemetry tracking done by litellm
```yaml
litellm_settings:
telemetry: False
```
## Machine Specifications to Deploy LiteLLM
| Service | Spec | CPUs | Memory | Architecture | Version|

View file

@ -14,6 +14,7 @@ model_list:
model: gpt-3.5-turbo
litellm_settings:
success_callback: ["prometheus"]
failure_callback: ["prometheus"]
```
Start the proxy
@ -48,9 +49,10 @@ http://localhost:4000/metrics
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_requests_metric` | Number of requests made, per `"user", "key", "model"` |
| `litellm_spend_metric` | Total Spend, per `"user", "key", "model"` |
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model"` |
| `litellm_requests_metric` | Number of requests made, per `"user", "key", "model", "team", "end-user"` |
| `litellm_spend_metric` | Total Spend, per `"user", "key", "model", "team", "end-user"` |
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
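Once the proxy is running, a quick way to sanity-check these counters; a minimal sketch using the `/metrics` endpoint shown above:
```python
import requests

metrics_text = requests.get("http://localhost:4000/metrics").text
for line in metrics_text.splitlines():
    # e.g. litellm_requests_metric, litellm_spend_metric, litellm_total_tokens
    if line.startswith("litellm_"):
        print(line)
```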
## Monitor System Health
@ -69,3 +71,4 @@ litellm_settings:
|----------------------|--------------------------------------|
| `litellm_redis_latency` | histogram latency for redis calls |
| `litellm_redis_fails` | Number of failed redis calls |
| `litellm_self_latency` | Histogram latency for successful litellm api call |

View file

@ -348,6 +348,29 @@ query_result = embeddings.embed_query(text)
print(f"TITAN EMBEDDINGS")
print(query_result[:5])
```
</TabItem>
<TabItem value="litellm" label="LiteLLM SDK">
This is **not recommended**. There is duplicate logic, as the proxy also uses the SDK, which might lead to unexpected errors.
```python
from litellm import completion
response = completion(
model="openai/gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
api_key="anything",
base_url="http://0.0.0.0:4000"
)
print(response)
```
</TabItem>
</Tabs>

View file

@ -121,6 +121,9 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
import os
os.environ["OPENAI_API_KEY"] = "anything"
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000",

View file

@ -279,7 +279,7 @@ router_settings:
```
</TabItem>
<TabItem value="simple-shuffle" label="(Default) Weighted Pick">
<TabItem value="simple-shuffle" label="(Default) Weighted Pick (Async)">
**Default** Picks a deployment based on the provided **Requests per minute (rpm) or Tokens per minute (tpm)**

View file

@ -105,6 +105,12 @@ const config = {
label: 'Enterprise',
to: "docs/enterprise"
},
{
sidebarId: 'tutorialSidebar',
position: 'left',
label: '🚀 Hosted',
to: "docs/hosted"
},
{
href: 'https://github.com/BerriAI/litellm',
label: 'GitHub',

Binary file not shown. (new image, 398 KiB)

Binary file not shown. (new image, 496 KiB)

Binary file not shown. (new image, 348 KiB)

Binary file not shown. (new image, 460 KiB)

View file

@ -63,7 +63,7 @@ const sidebars = {
label: "Logging, Alerting",
items: ["proxy/logging", "proxy/alerting", "proxy/streaming_logging"],
},
"proxy/grafana_metrics",
"proxy/prometheus",
"proxy/call_hooks",
"proxy/rules",
"proxy/cli",

View file

@ -16,11 +16,24 @@ dotenv.load_dotenv()
if set_verbose == True:
_turn_on_debug()
#############################################
### Callbacks /Logging / Success / Failure Handlers ###
input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_langfuse_default_tags: Optional[
List[
Literal[
"user_api_key_alias",
"user_api_key_user_id",
"user_api_key_user_email",
"user_api_key_team_alias",
"semantic-similarity",
"proxy_base_url",
]
]
] = None
_async_input_callback: List[Callable] = (
[]
) # internal variable - async custom callbacks are routed here.
@ -32,6 +45,8 @@ _async_failure_callback: List[Callable] = (
) # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
## end of callbacks #############
email: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
@ -51,6 +66,7 @@ replicate_key: Optional[str] = None
cohere_key: Optional[str] = None
maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None
ollama_key: Optional[str] = None
openrouter_key: Optional[str] = None
huggingface_key: Optional[str] = None
vertex_project: Optional[str] = None

View file

@ -32,6 +32,25 @@ def _get_redis_kwargs():
return available_args
def _get_redis_url_kwargs(client=None):
if client is None:
client = redis.Redis.from_url
arg_spec = inspect.getfullargspec(redis.Redis.from_url)
# Only allow primitive arguments
exclude_args = {
"self",
"connection_pool",
"retry",
}
include_args = ["url"]
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
return available_args
def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"
@ -91,27 +110,39 @@ def _get_redis_client_logic(**env_overrides):
redis_kwargs.pop("password", None)
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs
def get_redis_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop(
"connection_pool", None
) # redis.from_url doesn't support setting your own connection pool
return redis.Redis.from_url(**redis_kwargs)
args = _get_redis_url_kwargs()
url_kwargs = {}
for arg in redis_kwargs:
if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
return redis.Redis.from_url(**url_kwargs)
return redis.Redis(**redis_kwargs)
def get_redis_async_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop(
"connection_pool", None
) # redis.from_url doesn't support setting your own connection pool
return async_redis.Redis.from_url(**redis_kwargs)
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
url_kwargs = {}
for arg in redis_kwargs:
if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
else:
litellm.print_verbose(
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
arg
)
)
return async_redis.Redis.from_url(**url_kwargs)
return async_redis.Redis(
socket_timeout=5,
**redis_kwargs,
@ -124,4 +155,9 @@ def get_redis_connection_pool(**env_overrides):
return async_redis.BlockingConnectionPool.from_url(
timeout=5, url=redis_kwargs["url"]
)
connection_class = async_redis.Connection
if "ssl" in redis_kwargs and redis_kwargs["ssl"] is not None:
connection_class = async_redis.SSLConnection
redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
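A sketch of how these helpers resolve connection settings from the `REDIS_` environment prefix mapping above (host value is a placeholder):
```python
import os
from litellm._redis import get_redis_client, get_redis_connection_pool

os.environ["REDIS_HOST"] = "my-redis-host"  # placeholder
os.environ["REDIS_PORT"] = "6379"
os.environ["REDIS_SSL"] = "True"            # routes the async pool through SSLConnection above

sync_client = get_redis_client()            # sync redis.Redis built from REDIS_* env vars
async_pool = get_redis_connection_pool()    # async pool; uses async_redis.SSLConnection
sync_client.ping()
```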

View file

@ -1,9 +1,13 @@
import litellm
import litellm, traceback
from litellm.proxy._types import UserAPIKeyAuth
from .types.services import ServiceTypes, ServiceLoggerPayload
from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.custom_logger import CustomLogger
from datetime import timedelta
from typing import Union
class ServiceLogging:
class ServiceLogging(CustomLogger):
"""
Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
"""
@ -14,11 +18,12 @@ class ServiceLogging:
self.mock_testing_async_success_hook = 0
self.mock_testing_sync_failure_hook = 0
self.mock_testing_async_failure_hook = 0
if "prometheus_system" in litellm.service_callback:
self.prometheusServicesLogger = PrometheusServicesLogger()
def service_success_hook(self, service: ServiceTypes, duration: float):
def service_success_hook(
self, service: ServiceTypes, duration: float, call_type: str
):
"""
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
"""
@ -26,7 +31,7 @@ class ServiceLogging:
self.mock_testing_sync_success_hook += 1
def service_failure_hook(
self, service: ServiceTypes, duration: float, error: Exception
self, service: ServiceTypes, duration: float, error: Exception, call_type: str
):
"""
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
@ -34,7 +39,9 @@ class ServiceLogging:
if self.mock_testing:
self.mock_testing_sync_failure_hook += 1
async def async_service_success_hook(self, service: ServiceTypes, duration: float):
async def async_service_success_hook(
self, service: ServiceTypes, duration: float, call_type: str
):
"""
- For counting if the redis, postgres call is successful
"""
@ -42,7 +49,11 @@ class ServiceLogging:
self.mock_testing_async_success_hook += 1
payload = ServiceLoggerPayload(
is_error=False, error=None, service=service, duration=duration
is_error=False,
error=None,
service=service,
duration=duration,
call_type=call_type,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
@ -51,7 +62,11 @@ class ServiceLogging:
)
async def async_service_failure_hook(
self, service: ServiceTypes, duration: float, error: Exception
self,
service: ServiceTypes,
duration: float,
error: Union[str, Exception],
call_type: str,
):
"""
- For counting if the redis, postgres call is unsuccessful
@ -59,8 +74,18 @@ class ServiceLogging:
if self.mock_testing:
self.mock_testing_async_failure_hook += 1
error_message = ""
if isinstance(error, Exception):
error_message = str(error)
elif isinstance(error, str):
error_message = error
payload = ServiceLoggerPayload(
is_error=True, error=str(error), service=service, duration=duration
is_error=True,
error=error_message,
service=service,
duration=duration,
call_type=call_type,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
@ -69,3 +94,37 @@ class ServiceLogging:
await self.prometheusServicesLogger.async_service_failure_hook(
payload=payload
)
async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
"""
Hook to track failed litellm-service calls
"""
return await super().async_post_call_failure_hook(
original_exception, user_api_key_dict
)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Hook to track latency for litellm proxy llm api calls
"""
try:
_duration = end_time - start_time
if isinstance(_duration, timedelta):
_duration = _duration.total_seconds()
elif isinstance(_duration, float):
pass
else:
raise Exception(
"Duration={} is not a float or timedelta object. type={}".format(
_duration, type(_duration)
)
) # invalid _duration value
await self.async_service_success_hook(
service=ServiceTypes.LITELLM,
duration=_duration,
call_type=kwargs["call_type"],
)
except Exception as e:
raise e

View file

@ -13,7 +13,6 @@ import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any, BinaryIO
from openai._models import BaseModel as OpenAIObject
from litellm._logging import verbose_logger
from litellm._service_logger import ServiceLogging
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
import traceback
@ -90,6 +89,13 @@ class InMemoryCache(BaseCache):
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
@ -132,6 +138,7 @@ class RedisCache(BaseCache):
**kwargs,
):
from ._redis import get_redis_client, get_redis_connection_pool
from litellm._service_logger import ServiceLogging
import redis
redis_kwargs = {}
@ -142,18 +149,19 @@ class RedisCache(BaseCache):
if password is not None:
redis_kwargs["password"] = password
### HEALTH MONITORING OBJECT ###
if kwargs.get("service_logger_obj", None) is not None and isinstance(
kwargs["service_logger_obj"], ServiceLogging
):
self.service_logger_obj = kwargs.pop("service_logger_obj")
else:
self.service_logger_obj = ServiceLogging()
redis_kwargs.update(kwargs)
self.redis_client = get_redis_client(**redis_kwargs)
self.redis_kwargs = redis_kwargs
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
parsed_kwargs = redis.connection.parse_url(redis_kwargs["url"])
redis_kwargs.update(parsed_kwargs)
self.redis_kwargs.update(parsed_kwargs)
# pop url
self.redis_kwargs.pop("url")
# redis namespaces
self.namespace = namespace
# for high traffic, we store the redis results in memory and then batch write to redis
@ -165,8 +173,15 @@ class RedisCache(BaseCache):
except Exception as e:
pass
### HEALTH MONITORING OBJECT ###
self.service_logger_obj = ServiceLogging()
### ASYNC HEALTH PING ###
try:
# asyncio.get_running_loop().create_task(self.ping())
result = asyncio.get_running_loop().create_task(self.ping())
except Exception:
pass
### SYNC HEALTH PING ###
self.redis_client.ping()
def init_async_client(self):
from ._redis import get_redis_async_client
@ -198,6 +213,42 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
)
def increment_cache(self, key, value: int, **kwargs) -> int:
_redis_client = self.redis_client
start_time = time.time()
try:
result = _redis_client.incr(name=key, amount=value)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="increment_cache",
)
)
return result
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="increment_cache",
)
)
verbose_logger.error(
"LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
raise e
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
start_time = time.time()
try:
@ -216,7 +267,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_scan_iter",
)
) # DO NOT SLOW DOWN CALL B/C OF THIS
return keys
@ -227,7 +280,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_scan_iter",
)
)
raise e
@ -267,7 +323,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache",
)
)
except Exception as e:
@ -275,7 +333,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache",
)
)
# NON blocking - notify users Redis is throwing an exception
@ -292,6 +353,10 @@ class RedisCache(BaseCache):
"""
_redis_client = self.init_async_client()
start_time = time.time()
print_verbose(
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
)
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
@ -316,7 +381,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache_pipeline",
)
)
return results
@ -326,7 +393,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache_pipeline",
)
)
@ -359,6 +429,7 @@ class RedisCache(BaseCache):
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_increment",
)
)
return result
@ -368,7 +439,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_increment",
)
)
verbose_logger.error(
@ -459,7 +533,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_get_cache",
)
)
return response
@ -469,7 +545,10 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_get_cache",
)
)
# NON blocking - notify users Redis is throwing an exception
@ -497,7 +576,9 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS, duration=_duration
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_batch_get_cache",
)
)
@ -519,21 +600,81 @@ class RedisCache(BaseCache):
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_batch_get_cache",
)
)
print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict
async def ping(self):
def sync_ping(self) -> bool:
"""
Tests if the sync redis client is correctly setup.
"""
print_verbose(f"Pinging Sync Redis Cache")
start_time = time.time()
try:
response = self.redis_client.ping()
print_verbose(f"Redis Cache PING: {response}")
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
self.service_logger_obj.service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="sync_ping",
)
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
self.service_logger_obj.service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="sync_ping",
)
print_verbose(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
)
traceback.print_exc()
raise e
async def ping(self) -> bool:
_redis_client = self.init_async_client()
start_time = time.time()
async with _redis_client as redis_client:
print_verbose(f"Pinging Async Redis Cache")
try:
response = await redis_client.ping()
print_verbose(f"Redis Cache PING: {response}")
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_ping",
)
)
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_ping",
)
)
print_verbose(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
)
@ -1064,6 +1205,30 @@ class DualCache(BaseCache):
except Exception as e:
print_verbose(e)
def increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only == False:
result = self.redis_cache.increment_cache(key, value, **kwargs)
return result
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
raise e
def get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first
try:
@ -1116,7 +1281,7 @@ class DualCache(BaseCache):
self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
for key, value in redis_result.items():
result[sublist_keys.index(key)] = value
result[keys.index(key)] = value
print_verbose(f"async batch get cache: cache result: {result}")
return result
@ -1166,10 +1331,8 @@ class DualCache(BaseCache):
keys, **kwargs
)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
@ -1185,22 +1348,23 @@ class DualCache(BaseCache):
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
for key, value in redis_result.items():
if value is not None:
await self.in_memory_cache.async_set_cache(
key, redis_result[key], **kwargs
)
for key, value in redis_result.items():
index = keys.index(key)
result[index] = value
sublist_dict = dict(zip(sublist_keys, redis_result))
for key, value in sublist_dict.items():
result[sublist_keys.index(key)] = value
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
print_verbose(
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
)
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache(key, value, **kwargs)

View file

@ -6,7 +6,7 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union
from typing import Literal, Union, Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
@ -46,6 +46,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
#### PRE-CALL CHECKS - router/proxy only ####
"""
Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
"""
async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
pass
def pre_call_check(self, deployment: dict) -> Optional[dict]:
pass
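A minimal sketch of a handler that implements this new hook; the deployment-dict access follows the router's usual `litellm_params` layout, and registering via `litellm.callbacks` is the standard CustomLogger pattern (the check itself is a made-up illustration):
```python
from typing import Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger


class MyPreCallChecker(CustomLogger):
    async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
        # runs inside the picked deployment's semaphore, so reads here are concurrency-safe
        model = deployment.get("litellm_params", {}).get("model")
        print(f"pre-call check for deployment: {model}")
        return deployment


litellm.callbacks = [MyPreCallChecker()]
```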
#### CALL HOOKS - proxy only ####
"""
Control / modify the incoming and outgoing data before calling the model

View file

@ -34,6 +34,14 @@ class LangFuseLogger:
flush_interval=1, # flush interval in seconds
)
# set the current langfuse project id in the environ
# this is used by Alerting to link to the correct project
try:
project_id = self.Langfuse.client.projects.get().data[0].id
os.environ["LANGFUSE_PROJECT_ID"] = project_id
except:
project_id = None
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
self.upstream_langfuse_secret_key = os.getenv(
"UPSTREAM_LANGFUSE_SECRET_KEY"
@ -133,6 +141,7 @@ class LangFuseLogger:
self._log_langfuse_v2(
user_id,
metadata,
litellm_params,
output,
start_time,
end_time,
@ -224,6 +233,7 @@ class LangFuseLogger:
self,
user_id,
metadata,
litellm_params,
output,
start_time,
end_time,
@ -278,13 +288,13 @@ class LangFuseLogger:
clean_metadata = {}
if isinstance(metadata, dict):
for key, value in metadata.items():
# generate langfuse tags
if key in [
"user_api_key",
"user_api_key_user_id",
"user_api_key_team_id",
"semantic-similarity",
]:
# generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
if (
litellm._langfuse_default_tags is not None
and isinstance(litellm._langfuse_default_tags, list)
and key in litellm._langfuse_default_tags
):
tags.append(f"{key}:{value}")
# clean litellm metadata before logging
@ -298,13 +308,53 @@ class LangFuseLogger:
else:
clean_metadata[key] = value
if (
litellm._langfuse_default_tags is not None
and isinstance(litellm._langfuse_default_tags, list)
and "proxy_base_url" in litellm._langfuse_default_tags
):
proxy_base_url = os.environ.get("PROXY_BASE_URL", None)
if proxy_base_url is not None:
tags.append(f"proxy_base_url:{proxy_base_url}")
api_base = litellm_params.get("api_base", None)
if api_base:
clean_metadata["api_base"] = api_base
vertex_location = kwargs.get("vertex_location", None)
if vertex_location:
clean_metadata["vertex_location"] = vertex_location
aws_region_name = kwargs.get("aws_region_name", None)
if aws_region_name:
clean_metadata["aws_region_name"] = aws_region_name
if supports_tags:
if "cache_hit" in kwargs:
if kwargs["cache_hit"] is None:
kwargs["cache_hit"] = False
tags.append(f"cache_hit:{kwargs['cache_hit']}")
clean_metadata["cache_hit"] = kwargs["cache_hit"]
trace_params.update({"tags": tags})
proxy_server_request = litellm_params.get("proxy_server_request", None)
if proxy_server_request:
method = proxy_server_request.get("method", None)
url = proxy_server_request.get("url", None)
headers = proxy_server_request.get("headers", None)
clean_headers = {}
if headers:
for key, value in headers.items():
# these headers can leak our API keys and/or JWT tokens
if key.lower() not in ["authorization", "cookie", "referer"]:
clean_headers[key] = value
clean_metadata["request"] = {
"method": method,
"url": url,
"headers": clean_headers,
}
print_verbose(f"trace_params: {trace_params}")
trace = self.Langfuse.trace(**trace_params)

View file

@ -7,6 +7,19 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import asyncio
import types
from pydantic import BaseModel
def is_serializable(value):
non_serializable_types = (
types.CoroutineType,
types.FunctionType,
types.GeneratorType,
BaseModel,
)
return not isinstance(value, non_serializable_types)
class LangsmithLogger:
@ -21,7 +34,9 @@ class LangsmithLogger:
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
# Method definition
# inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb
metadata = kwargs.get('litellm_params', {}).get("metadata", {}) or {} # if metadata is None
metadata = (
kwargs.get("litellm_params", {}).get("metadata", {}) or {}
) # if metadata is None
# set project name and run_name for langsmith logging
# users can pass project_name and run name to litellm.completion()
@ -51,26 +66,46 @@ class LangsmithLogger:
new_kwargs = {}
for key in kwargs:
value = kwargs[key]
if key == "start_time" or key == "end_time":
if key == "start_time" or key == "end_time" or value is None:
pass
elif type(value) == datetime.datetime:
new_kwargs[key] = value.isoformat()
elif type(value) != dict:
elif type(value) != dict and is_serializable(value=value):
new_kwargs[key] = value
requests.post(
"https://api.smith.langchain.com/runs",
json={
print(f"type of response: {type(response_obj)}")
for k, v in new_kwargs.items():
print(f"key={k}, type of arg: {type(v)}, value={v}")
if isinstance(response_obj, BaseModel):
try:
response_obj = response_obj.model_dump()
except:
response_obj = response_obj.dict() # type: ignore
print(f"response_obj: {response_obj}")
data = {
"name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": {**new_kwargs},
"outputs": response_obj.json(),
"inputs": new_kwargs,
"outputs": response_obj,
"session_name": project_name,
"start_time": start_time,
"end_time": end_time,
},
}
print(f"data: {data}")
response = requests.post(
"https://api.smith.langchain.com/runs",
json=data,
headers={"x-api-key": self.langsmith_api_key},
)
if response.status_code >= 300:
print_verbose(f"Error: {response.status_code}")
else:
print_verbose("Run successfully created")
print_verbose(
f"Langsmith Layer Logging - final response object: {response_obj}"
)

View file

@ -19,27 +19,33 @@ class PrometheusLogger:
**kwargs,
):
try:
verbose_logger.debug(f"in init prometheus metrics")
print(f"in init prometheus metrics")
from prometheus_client import Counter
self.litellm_llm_api_failed_requests_metric = Counter(
name="litellm_llm_api_failed_requests_metric",
documentation="Total number of failed LLM API calls via litellm",
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
)
self.litellm_requests_metric = Counter(
name="litellm_requests_metric",
documentation="Total number of LLM calls to litellm",
labelnames=["end_user", "key", "model", "team"],
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
)
# Counter for spend
self.litellm_spend_metric = Counter(
"litellm_spend_metric",
"Total spend on LLM requests",
labelnames=["end_user", "key", "model", "team"],
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
)
# Counter for total_output_tokens
self.litellm_tokens_metric = Counter(
"litellm_total_tokens",
"Total number of input + output tokens from LLM requests",
labelnames=["end_user", "key", "model", "team"],
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
)
except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}")
@ -61,29 +67,50 @@ class PrometheusLogger:
# unpack kwargs
model = kwargs.get("model", "")
response_cost = kwargs.get("response_cost", 0.0)
response_cost = kwargs.get("response_cost", 0.0) or 0
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = litellm_params.get("metadata", {}).get(
"user_api_key_user_id", None
)
user_api_key = litellm_params.get("metadata", {}).get("user_api_key", None)
user_api_team = litellm_params.get("metadata", {}).get(
"user_api_key_team_id", None
)
if response_obj is not None:
tokens_used = response_obj.get("usage", {}).get("total_tokens", 0)
else:
tokens_used = 0
print_verbose(
f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}"
)
if (
user_api_key is not None
and isinstance(user_api_key, str)
and user_api_key.startswith("sk-")
):
from litellm.proxy.utils import hash_token
user_api_key = hash_token(user_api_key)
self.litellm_requests_metric.labels(
end_user_id, user_api_key, model, user_api_team
end_user_id, user_api_key, model, user_api_team, user_id
).inc()
self.litellm_spend_metric.labels(
end_user_id, user_api_key, model, user_api_team
end_user_id, user_api_key, model, user_api_team, user_id
).inc(response_cost)
self.litellm_tokens_metric.labels(
end_user_id, user_api_key, model, user_api_team
end_user_id, user_api_key, model, user_api_team, user_id
).inc(tokens_used)
### FAILURE INCREMENT ###
if "exception" in kwargs:
self.litellm_llm_api_failed_requests_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id
).inc()
except Exception as e:
traceback.print_exc()
verbose_logger.debug(

View file

@ -44,9 +44,18 @@ class PrometheusServicesLogger:
) # store the prometheus histogram/counter we need to call for each field in payload
for service in self.services:
histogram = self.create_histogram(service)
counter = self.create_counter(service)
self.payload_to_prometheus_map[service] = [histogram, counter]
histogram = self.create_histogram(service, type_of_request="latency")
counter_failed_request = self.create_counter(
service, type_of_request="failed_requests"
)
counter_total_requests = self.create_counter(
service, type_of_request="total_requests"
)
self.payload_to_prometheus_map[service] = [
histogram,
counter_failed_request,
counter_total_requests,
]
self.prometheus_to_amount_map: dict = (
{}
@ -74,26 +83,26 @@ class PrometheusServicesLogger:
return metric
return None
def create_histogram(self, label: str):
metric_name = "litellm_{}_latency".format(label)
def create_histogram(self, service: str, type_of_request: str):
metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
return self.get_metric(metric_name)
return self.Histogram(
metric_name,
"Latency for {} service".format(label),
labelnames=[label],
"Latency for {} service".format(service),
labelnames=[service],
)
def create_counter(self, label: str):
metric_name = "litellm_{}_failed_requests".format(label)
def create_counter(self, service: str, type_of_request: str):
metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
return self.get_metric(metric_name)
return self.Counter(
metric_name,
"Total failed requests for {} service".format(label),
labelnames=[label],
"Total {} for {} service".format(type_of_request, service),
labelnames=[service],
)
def observe_histogram(
@ -129,6 +138,12 @@ class PrometheusServicesLogger:
labels=payload.service.value,
amount=payload.duration,
)
elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
)
def service_failure_hook(self, payload: ServiceLoggerPayload):
if self.mock_testing:
@ -141,7 +156,7 @@ class PrometheusServicesLogger:
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG ERROR COUNT TO PROMETHEUS
amount=1, # LOG ERROR COUNT / TOTAL REQUESTS TO PROMETHEUS
)
async def async_service_success_hook(self, payload: ServiceLoggerPayload):
@ -160,6 +175,12 @@ class PrometheusServicesLogger:
labels=payload.service.value,
amount=payload.duration,
)
elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
)
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
print(f"received error payload: {payload.error}")

View file

@ -0,0 +1,422 @@
#### What this does ####
# Class for sending Slack Alerts #
import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import copy
import traceback
from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm
from typing import List, Literal, Any, Union, Optional
from litellm.caching import DualCache
import asyncio
import aiohttp
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
class SlackAlerting:
# Class variables or attributes
def __init__(
self,
alerting_threshold: float = 300,
alerting: Optional[List] = [],
alert_types: Optional[
List[
Literal[
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
]
]
] = [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
],
):
self.alerting_threshold = alerting_threshold
self.alerting = alerting
self.alert_types = alert_types
self.internal_usage_cache = DualCache()
self.async_http_handler = AsyncHTTPHandler()
pass
def update_values(
self,
alerting: Optional[List] = None,
alerting_threshold: Optional[float] = None,
alert_types: Optional[List] = None,
):
if alerting is not None:
self.alerting = alerting
if alerting_threshold is not None:
self.alerting_threshold = alerting_threshold
if alert_types is not None:
self.alert_types = alert_types
async def deployment_in_cooldown(self):
pass
async def deployment_removed_from_cooldown(self):
pass
def _all_possible_alert_types(self):
# used by the UI to show all supported alert types
# Note: This is not the alerts the user has configured, instead it's all possible alert types a user can select
return [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
]
def _add_langfuse_trace_id_to_alert(
self,
request_info: str,
request_data: Optional[dict] = None,
kwargs: Optional[dict] = None,
):
import uuid
if request_data is not None:
trace_id = request_data.get("metadata", {}).get(
"trace_id", None
) # get langfuse trace id
if trace_id is None:
trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
request_data["metadata"]["trace_id"] = trace_id
elif kwargs is not None:
_litellm_params = kwargs.get("litellm_params", {})
trace_id = _litellm_params.get("metadata", {}).get(
"trace_id", None
) # get langfuse trace id
if trace_id is None:
trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
_litellm_params["metadata"]["trace_id"] = trace_id
_langfuse_host = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com")
_langfuse_project_id = os.environ.get("LANGFUSE_PROJECT_ID")
# langfuse urls look like: https://us.cloud.langfuse.com/project/************/traces/litellm-alert-trace-ididi9dk-09292-************
_langfuse_url = (
f"{_langfuse_host}/project/{_langfuse_project_id}/traces/{trace_id}"
)
request_info += f"\n🪢 Langfuse Trace: {_langfuse_url}"
return request_info
def _response_taking_too_long_callback(
self,
kwargs, # kwargs to completion
start_time,
end_time, # start/end time
):
try:
time_difference = end_time - start_time
# Convert the timedelta to float (in seconds)
time_difference_float = time_difference.total_seconds()
litellm_params = kwargs.get("litellm_params", {})
model = kwargs.get("model", "")
api_base = litellm.get_api_base(model=model, optional_params=litellm_params)
messages = kwargs.get("messages", None)
# if messages does not exist fallback to "input"
if messages is None:
messages = kwargs.get("input", None)
# only use first 100 chars for alerting
_messages = str(messages)[:100]
return time_difference_float, model, api_base, _messages
except Exception as e:
raise e
async def response_taking_too_long_callback(
self,
kwargs, # kwargs to completion
completion_response, # response from completion
start_time,
end_time, # start/end time
):
if self.alerting is None or self.alert_types is None:
return
if "llm_too_slow" not in self.alert_types:
return
time_difference_float, model, api_base, messages = (
self._response_taking_too_long_callback(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
)
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold:
if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert(
request_info=request_info, kwargs=kwargs
)
await self.send_alert(
message=slow_message + request_info,
level="Low",
)
async def log_failure_event(self, original_exception: Exception):
pass
async def response_taking_too_long(
self,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
type: Literal["hanging_request", "slow_response"] = "hanging_request",
request_data: Optional[dict] = None,
):
if self.alerting is None or self.alert_types is None:
return
if request_data is not None:
model = request_data.get("model", "")
messages = request_data.get("messages", None)
if messages is None:
# if messages does not exist fallback to "input"
messages = request_data.get("input", None)
# try casting messages to str and get the first 100 characters, else mark as None
try:
messages = str(messages)
messages = messages[:100]
except:
messages = ""
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert(
request_info=request_info, request_data=request_data
)
else:
request_info = ""
if type == "hanging_request":
# Simulate a long-running operation that could take more than 5 minutes
if "llm_requests_hanging" not in self.alert_types:
return
await asyncio.sleep(
self.alerting_threshold
) # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
if (
request_data is not None
and request_data.get("litellm_status", "") != "success"
and request_data.get("litellm_status", "") != "fail"
):
if request_data.get("deployment", None) is not None and isinstance(
request_data["deployment"], dict
):
_api_base = litellm.get_api_base(
model=model,
optional_params=request_data["deployment"].get(
"litellm_params", {}
),
)
if _api_base is None:
_api_base = ""
request_info += f"\nAPI Base: {_api_base}"
elif request_data.get("metadata", None) is not None and isinstance(
request_data["metadata"], dict
):
                    # For hanging requests, the call sometimes has not reached the point where the deployment is added to `request_data`;
                    # in that case, fall back to the api base set in the request metadata
_metadata = request_data["metadata"]
_api_base = _metadata.get("api_base", "")
if _api_base is None:
_api_base = ""
request_info += f"\nAPI Base: `{_api_base}`"
# only alert hanging responses if they have not been marked as success
alerting_message = (
f"`Requests are hanging - {self.alerting_threshold}s+ request time`"
)
await self.send_alert(
message=alerting_message + request_info,
level="Medium",
)
async def budget_alerts(
self,
type: Literal[
"token_budget",
"user_budget",
"user_and_proxy_budget",
"failed_budgets",
"failed_tracking",
"projected_limit_exceeded",
],
user_max_budget: float,
user_current_spend: float,
user_info=None,
error_message="",
):
if self.alerting is None or self.alert_types is None:
# do nothing if alerting is not switched on
return
if "budget_alerts" not in self.alert_types:
return
_id: str = "default_id" # used for caching
if type == "user_and_proxy_budget":
user_info = dict(user_info)
user_id = user_info["user_id"]
_id = user_id
max_budget = user_info["max_budget"]
spend = user_info["spend"]
user_email = user_info["user_email"]
user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
elif type == "token_budget":
token_info = dict(user_info)
token = token_info["token"]
_id = token
spend = token_info["spend"]
max_budget = token_info["max_budget"]
user_id = token_info["user_id"]
user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
elif type == "failed_tracking":
user_id = str(user_info)
_id = user_id
user_info = f"\nUser ID: {user_id}\n Error {error_message}"
message = "Failed Tracking Cost for" + user_info
await self.send_alert(
message=message,
level="High",
)
return
elif type == "projected_limit_exceeded" and user_info is not None:
"""
Input variables:
user_info = {
"key_alias": key_alias,
"projected_spend": projected_spend,
"projected_exceeded_date": projected_exceeded_date,
}
user_max_budget=soft_limit,
user_current_spend=new_spend
"""
message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` {user_info["key_alias"]} \n`Expected Day of Error`: {user_info["projected_exceeded_date"]} \n`Current Spend`: {user_current_spend} \n`Projected Spend at end of month`: {user_info["projected_spend"]} \n`Soft Limit`: {user_max_budget}"""
await self.send_alert(
message=message,
level="High",
)
return
else:
user_info = str(user_info)
# percent of max_budget left to spend
if user_max_budget > 0:
percent_left = (user_max_budget - user_current_spend) / user_max_budget
else:
percent_left = 0
verbose_proxy_logger.debug(
f"Budget Alerts: Percent left: {percent_left} for {user_info}"
)
        ## PREVENTATIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
# - Alert once within 28d period
# - Cache this information
# - Don't re-alert, if alert already sent
_cache: DualCache = self.internal_usage_cache
# check if crossed budget
if user_current_spend >= user_max_budget:
verbose_proxy_logger.debug("Budget Crossed for %s", user_info)
message = "Budget Crossed for" + user_info
result = await _cache.async_get_cache(key=message)
if result is None:
await self.send_alert(
message=message,
level="High",
)
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
return
# check if 5% of max budget is left
if percent_left <= 0.05:
message = "5% budget left for" + user_info
cache_key = "alerting:{}".format(_id)
result = await _cache.async_get_cache(key=cache_key)
if result is None:
await self.send_alert(
message=message,
level="Medium",
)
await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)
return
# check if 15% of max budget is left
if percent_left <= 0.15:
message = "15% budget left for" + user_info
result = await _cache.async_get_cache(key=message)
if result is None:
await self.send_alert(
message=message,
level="Low",
)
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
return
return
async def send_alert(self, message: str, level: Literal["Low", "Medium", "High"]):
"""
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
- Responses taking too long
- Requests are hanging
- Calls are failing
- DB Read/Writes are failing
- Proxy Close to max budget
- Key Close to max budget
Parameters:
            level: str - Low|Medium|High - alert severity; e.g. slow responses and 15%-budget warnings are sent as 'Low', hanging requests and 5%-budget warnings as 'Medium', crossed budgets and failed cost tracking as 'High'.
message: str - what is the alert about
"""
print(
"inside send alert for slack, message: ",
message,
"self.alerting: ",
self.alerting,
)
if self.alerting is None:
return
from datetime import datetime
import json
# Get the current timestamp
current_time = datetime.now().strftime("%H:%M:%S")
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
formatted_message = (
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
)
if _proxy_base_url is not None:
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
if slack_webhook_url is None:
raise Exception("Missing SLACK_WEBHOOK_URL from environment")
payload = {"text": formatted_message}
headers = {"Content-type": "application/json"}
response = await self.async_http_handler.post(
url=slack_webhook_url,
headers=headers,
data=json.dumps(payload),
)
if response.status_code == 200:
pass
else:
print("Error sending slack alert. Error=", response.text) # noqa
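
For orientation, here is a minimal, self-contained sketch of the alerting pattern implemented above: threshold check, once-per-28-days dedup, and a Slack webhook post. It is illustrative only; the in-process dict stands in for the proxy's DualCache, the helper name alert_once is invented here, and only the SLACK_WEBHOOK_URL env var and the TTL mirror the code.

import json
import os
import time

import httpx

_sent: dict = {}  # stand-in for the proxy's DualCache

def alert_once(key: str, message: str, level: str, ttl: int = 2419200) -> None:
    """Post a Slack alert at most once per `ttl` seconds (28 days by default)."""
    now = time.time()
    expires_at = _sent.get(key)
    if expires_at is not None and expires_at > now:
        return  # already alerted within the TTL window
    webhook_url = os.environ["SLACK_WEBHOOK_URL"]
    payload = {"text": f"Level: `{level}`\n\nMessage: {message}"}
    httpx.post(
        webhook_url,
        headers={"Content-type": "application/json"},
        data=json.dumps(payload),
    )
    _sent[key] = now + ttl

# e.g. a budget check: 4% left crosses the 5% threshold, so a "Medium" alert fires once
max_budget, current_spend = 100.0, 96.0
percent_left = (max_budget - current_spend) / max_budget
if percent_left <= 0.05:
    alert_once(key="alerting:user-123", message="5% budget left for user-123", level="Medium")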

View file

@ -258,8 +258,9 @@ class AnthropicChatCompletion(BaseLLM):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
data["stream"] = True
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
api_base, headers=headers, data=json.dumps(data), stream=True
)
if response.status_code != 200:

View file

@ -43,6 +43,7 @@ class CohereChatConfig:
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""
preamble: Optional[str] = None
@ -62,6 +63,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__(
self,
@ -82,6 +84,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():

View file

@ -41,13 +41,12 @@ class AsyncHTTPHandler:
data: Optional[Union[dict, str]] = None, # type: ignore
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
response = await self.client.post(
url,
data=data, # type: ignore
params=params,
headers=headers,
req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
)
response = await self.client.send(req, stream=stream)
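        # Building the request explicitly and sending it with stream=stream lets callers
        # (e.g. the Anthropic streaming path above) consume the response body incrementally,
        # while non-streaming callers keep the previous eager-read behavior.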
return response
def __del__(self) -> None:

View file

@ -228,7 +228,7 @@ def get_ollama_response(
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt))) # type: ignore
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
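    # NOTE (hedged): assuming `encoding` is a tiktoken encoding, disallowed_special=()
    # lets prompts that contain literal special-token strings (e.g. "<|endoftext|>") be
    # token-counted as plain text instead of raising a ValueError in this fallback path.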
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
@ -330,7 +330,7 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"]))) # type: ignore
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,

View file

@ -148,7 +148,7 @@ class OllamaChatConfig:
if param == "top_p":
optional_params["top_p"] = value
if param == "frequency_penalty":
optional_params["repeat_penalty"] = param
optional_params["repeat_penalty"] = value
if param == "stop":
optional_params["stop"] = value
if param == "response_format" and value["type"] == "json_object":
@ -184,6 +184,7 @@ class OllamaChatConfig:
# ollama implementation
def get_ollama_response(
api_base="http://localhost:11434",
api_key: Optional[str] = None,
model="llama2",
messages=None,
optional_params=None,
@ -236,6 +237,7 @@ def get_ollama_response(
if stream == True:
response = ollama_async_streaming(
url=url,
api_key=api_key,
data=data,
model_response=model_response,
encoding=encoding,
@ -244,6 +246,7 @@ def get_ollama_response(
else:
response = ollama_acompletion(
url=url,
api_key=api_key,
data=data,
model_response=model_response,
encoding=encoding,
@ -252,12 +255,17 @@ def get_ollama_response(
)
return response
elif stream == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
response = requests.post(
url=f"{url}",
json=data,
return ollama_completion_stream(
url=url, api_key=api_key, data=data, logging_obj=logging_obj
)
_request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
response = requests.post(**_request) # type: ignore
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
@ -307,10 +315,16 @@ def get_ollama_response(
return model_response
def ollama_completion_stream(url, data, logging_obj):
with httpx.stream(
url=url, json=data, method="POST", timeout=litellm.request_timeout
) as response:
def ollama_completion_stream(url, api_key, data, logging_obj):
_request = {
"url": f"{url}",
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
with httpx.stream(**_request) as response:
try:
if response.status_code != 200:
raise OllamaError(
@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
async def ollama_async_streaming(
url, api_key, data, model_response, encoding, logging_obj
):
try:
client = httpx.AsyncClient()
async with client.stream(
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
) as response:
_request = {
"url": f"{url}",
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
            _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
async with client.stream(**_request) as response:
if response.status_code != 200:
raise OllamaError(
status_code=response.status_code, message=response.text
@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
async def ollama_acompletion(
url, data, model_response, encoding, logging_obj, function_name
url,
api_key: Optional[str],
data,
model_response,
encoding,
logging_obj,
function_name,
):
data["stream"] = False
try:
timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
async with aiohttp.ClientSession(timeout=timeout) as session:
resp = await session.post(url, json=data)
_request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
                _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
resp = await session.post(**_request)
if resp.status != 200:
text = await resp.text()

View file

@ -145,6 +145,12 @@ def mistral_api_pt(messages):
elif isinstance(m["content"], str):
texts = m["content"]
new_m = {"role": m["role"], "content": texts}
if new_m["role"] == "tool" and m.get("name"):
new_m["name"] = m["name"]
if m.get("tool_calls"):
new_m["tool_calls"] = m["tool_calls"]
new_messages.append(new_m)
return new_messages
@ -218,6 +224,18 @@ def phind_codellama_pt(messages):
return prompt
known_tokenizer_config = {
"mistralai/Mistral-7B-Instruct-v0.1": {
"tokenizer": {
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"bos_token": "<s>",
"eos_token": "</s>",
},
"status": "success",
}
}
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] = None):
# Define Jinja2 environment
env = ImmutableSandboxedEnvironment()
@ -246,6 +264,9 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
else:
return {"status": "failure"}
if model in known_tokenizer_config:
tokenizer_config = known_tokenizer_config[model]
else:
tokenizer_config = _get_tokenizer_config(model)
if (
tokenizer_config["status"] == "failure"
@ -253,13 +274,13 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
):
raise Exception("No chat template found")
## read the bos token, eos token and chat template from the json
tokenizer_config = tokenizer_config["tokenizer"]
bos_token = tokenizer_config["bos_token"]
eos_token = tokenizer_config["eos_token"]
chat_template = tokenizer_config["chat_template"]
tokenizer_config = tokenizer_config["tokenizer"] # type: ignore
bos_token = tokenizer_config["bos_token"] # type: ignore
eos_token = tokenizer_config["eos_token"] # type: ignore
chat_template = tokenizer_config["chat_template"] # type: ignore
try:
template = env.from_string(chat_template)
template = env.from_string(chat_template) # type: ignore
except Exception as e:
raise e
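
For reference, a hedged sketch of how the pinned Mistral template above renders under Jinja2's sandboxed environment. The template string is copied verbatim from known_tokenizer_config; jinja2 is assumed to be installed (hf_chat_template already depends on it), and the example messages are made up.

from jinja2.sandbox import ImmutableSandboxedEnvironment

mistral_chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"

env = ImmutableSandboxedEnvironment()
template = env.from_string(mistral_chat_template)
# raise_exception is only referenced in branches that never execute for a valid,
# alternating user/assistant conversation, so it does not need to be defined here.
prompt = template.render(
    messages=[
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
    bos_token="<s>",
    eos_token="</s>",
)
print(prompt)  # -> "<s>[INST] Hi [/INST]Hello!</s> "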
@ -466,10 +487,11 @@ def construct_tool_use_system_prompt(
): # from https://github.com/anthropics/anthropic-cookbook/blob/main/function_calling/function_calling.ipynb
tool_str_list = []
for tool in tools:
tool_function = get_attribute_or_key(tool, "function")
tool_str = construct_format_tool_for_claude_prompt(
tool["function"]["name"],
tool["function"].get("description", ""),
tool["function"].get("parameters", {}),
get_attribute_or_key(tool_function, "name"),
get_attribute_or_key(tool_function, "description", ""),
get_attribute_or_key(tool_function, "parameters", {}),
)
tool_str_list.append(tool_str)
tool_use_system_prompt = (
@ -593,7 +615,8 @@ def convert_to_anthropic_tool_result_xml(message: dict) -> str:
</function_results>
"""
name = message.get("name")
content = message.get("content")
content = message.get("content", "")
    content = content.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template
@ -614,13 +637,15 @@ def convert_to_anthropic_tool_result_xml(message: dict) -> str:
def convert_to_anthropic_tool_invoke_xml(tool_calls: list) -> str:
invokes = ""
for tool in tool_calls:
if tool["type"] != "function":
if get_attribute_or_key(tool, "type") != "function":
continue
tool_name = tool["function"]["name"]
tool_function = get_attribute_or_key(tool,"function")
tool_name = get_attribute_or_key(tool_function, "name")
tool_arguments = get_attribute_or_key(tool_function, "arguments")
parameters = "".join(
f"<{param}>{val}</{param}>\n"
for param, val in json.loads(tool["function"]["arguments"]).items()
for param, val in json.loads(tool_arguments).items()
)
invokes += (
"<invoke>\n"
@ -674,7 +699,7 @@ def anthropic_messages_pt_xml(messages: list):
{
"type": "text",
"text": (
convert_to_anthropic_tool_result(messages[msg_i])
convert_to_anthropic_tool_result_xml(messages[msg_i])
if messages[msg_i]["role"] == "tool"
else messages[msg_i]["content"]
),
@ -695,7 +720,7 @@ def anthropic_messages_pt_xml(messages: list):
if messages[msg_i].get(
"tool_calls", []
            ): # support assistant tool invoke conversion
assistant_text += convert_to_anthropic_tool_invoke( # type: ignore
assistant_text += convert_to_anthropic_tool_invoke_xml( # type: ignore
messages[msg_i]["tool_calls"]
)
@ -807,12 +832,12 @@ def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
anthropic_tool_invoke = [
{
"type": "tool_use",
"id": tool["id"],
"name": tool["function"]["name"],
"input": json.loads(tool["function"]["arguments"]),
"id": get_attribute_or_key(tool, "id"),
"name": get_attribute_or_key(get_attribute_or_key(tool, "function"), "name"),
"input": json.loads(get_attribute_or_key(get_attribute_or_key(tool, "function"), "arguments")),
}
for tool in tool_calls
if tool["type"] == "function"
if get_attribute_or_key(tool, "type") == "function"
]
return anthropic_tool_invoke
@ -1033,7 +1058,8 @@ def cohere_message_pt(messages: list):
tool_result = convert_openai_message_to_cohere_tool_result(message)
tool_results.append(tool_result)
else:
prompt += message["content"]
prompt += message["content"] + "\n\n"
prompt = prompt.rstrip()
return prompt, tool_results
@ -1107,12 +1133,6 @@ def _gemini_vision_convert_messages(messages: list):
Returns:
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
"""
try:
from PIL import Image
except:
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
try:
# given messages for gpt-4 vision, convert them for gemini
@ -1139,6 +1159,12 @@ def _gemini_vision_convert_messages(messages: list):
image = _load_image_from_url(img)
processed_images.append(image)
else:
try:
from PIL import Image
except:
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
# Case 2: Image filepath (e.g. temp.jpeg) given
image = Image.open(img)
processed_images.append(image)
@ -1355,3 +1381,8 @@ def prompt_factory(
return default_pt(
messages=messages
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
def get_attribute_or_key(tool_or_function, attribute, default=None):
if hasattr(tool_or_function, attribute):
return getattr(tool_or_function, attribute)
return tool_or_function.get(attribute, default)
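# get_attribute_or_key lets the converters above accept either pydantic-style tool-call
# objects (attribute access) or plain OpenAI dicts (key access), e.g.
# get_attribute_or_key({"type": "function"}, "type") returns "function".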

View file

@ -22,6 +22,35 @@ class VertexAIError(Exception):
) # Call the base class constructor with the parameters it needs
class ExtendedGenerationConfig(dict):
"""Extended parameters for the generation."""
def __init__(
self,
*,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
candidate_count: Optional[int] = None,
max_output_tokens: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
response_mime_type: Optional[str] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
):
super().__init__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
candidate_count=candidate_count,
max_output_tokens=max_output_tokens,
stop_sequences=stop_sequences,
response_mime_type=response_mime_type,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
class VertexAIConfig:
"""
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
@ -43,6 +72,10 @@ class VertexAIConfig:
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
- `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0.
- `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0.
Note: Please make sure to modify the default parameters as required for your use case.
"""
@ -53,6 +86,8 @@ class VertexAIConfig:
response_mime_type: Optional[str] = None
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
def __init__(
self,
@ -63,6 +98,8 @@ class VertexAIConfig:
response_mime_type: Optional[str] = None,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@ -87,6 +124,64 @@ class VertexAIConfig:
and v is not None
}
def get_supported_openai_params(self):
return [
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
"response_format",
"n",
"stop",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stream":
optional_params["stream"] = value
if param == "n":
optional_params["candidate_count"] = value
if param == "stop":
if isinstance(value, str):
optional_params["stop_sequences"] = [value]
elif isinstance(value, list):
optional_params["stop_sequences"] = value
if param == "max_tokens":
optional_params["max_output_tokens"] = value
if param == "response_format" and value["type"] == "json_object":
optional_params["response_mime_type"] = "application/json"
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "tools" and isinstance(value, list):
from vertexai.preview import generative_models
gtool_func_declarations = []
for tool in value:
gtool_func_declaration = generative_models.FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
)
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [
generative_models.Tool(
function_declarations=gtool_func_declarations
)
]
if param == "tool_choice" and (
isinstance(value, str) or isinstance(value, dict)
):
pass
return optional_params
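
A hedged usage sketch of the parameter mapping above. It assumes VertexAIConfig is exported at litellm's top level (as other provider configs are) and sticks to params that avoid the optional vertexai import.

import litellm

config = litellm.VertexAIConfig()
mapped = config.map_openai_params(
    non_default_params={
        "max_tokens": 256,
        "stop": "###",
        "response_format": {"type": "json_object"},
    },
    optional_params={},
)
print(mapped)
# -> {'max_output_tokens': 256, 'stop_sequences': ['###'], 'response_mime_type': 'application/json'}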
import asyncio
@ -130,8 +225,7 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
image_bytes = response.content
return image_bytes
except requests.exceptions.RequestException as e:
# Handle any request exceptions (e.g., connection error, timeout)
return b"" # Return an empty bytes object or handle the error as needed
        raise Exception(f"An exception occurred while fetching this image - {str(e)}")
def _load_image_from_url(image_url: str):
@ -152,7 +246,8 @@ def _load_image_from_url(image_url: str):
)
image_bytes = _get_image_bytes_from_url(image_url)
return Image.from_bytes(image_bytes)
return Image.from_bytes(data=image_bytes)
def _gemini_vision_convert_messages(messages: list):
@ -309,47 +404,20 @@ def completion(
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore
import google.auth # type: ignore
class ExtendedGenerationConfig(GenerationConfig):
"""Extended parameters for the generation."""
def __init__(
self,
*,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
candidate_count: Optional[int] = None,
max_output_tokens: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
response_mime_type: Optional[str] = None,
):
args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
if "response_mime_type" in args_spec.args:
self._raw_generation_config = gapic_content_types.GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
candidate_count=candidate_count,
max_output_tokens=max_output_tokens,
stop_sequences=stop_sequences,
response_mime_type=response_mime_type,
)
else:
self._raw_generation_config = gapic_content_types.GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
candidate_count=candidate_count,
max_output_tokens=max_output_tokens,
stop_sequences=stop_sequences,
)
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
json_obj = json.loads(vertex_credentials)
creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
creds, _ = google.auth.default(quota_project_id=vertex_project)
print_verbose(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
@ -487,12 +555,12 @@ def completion(
model_response = llm_model.generate_content(
contents=content,
generation_config=ExtendedGenerationConfig(**optional_params),
generation_config=optional_params,
safety_settings=safety_settings,
stream=True,
tools=tools,
)
optional_params["stream"] = True
return model_response
request_str += f"response = llm_model.generate_content({content})\n"
@ -509,7 +577,7 @@ def completion(
## LLM Call
response = llm_model.generate_content(
contents=content,
generation_config=ExtendedGenerationConfig(**optional_params),
generation_config=optional_params,
safety_settings=safety_settings,
tools=tools,
)
@ -564,7 +632,7 @@ def completion(
},
)
model_response = chat.send_message_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
@ -596,7 +664,7 @@ def completion(
},
)
model_response = llm_model.predict_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
@ -748,47 +816,8 @@ async def async_completion(
Add support for acompletion calls for gemini-pro
"""
try:
from vertexai.preview.generative_models import GenerationConfig
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore
class ExtendedGenerationConfig(GenerationConfig):
"""Extended parameters for the generation."""
def __init__(
self,
*,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
candidate_count: Optional[int] = None,
max_output_tokens: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
response_mime_type: Optional[str] = None,
):
args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
if "response_mime_type" in args_spec.args:
self._raw_generation_config = gapic_content_types.GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
candidate_count=candidate_count,
max_output_tokens=max_output_tokens,
stop_sequences=stop_sequences,
response_mime_type=response_mime_type,
)
else:
self._raw_generation_config = gapic_content_types.GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
candidate_count=candidate_count,
max_output_tokens=max_output_tokens,
stop_sequences=stop_sequences,
)
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
@ -807,14 +836,15 @@ async def async_completion(
)
## LLM Call
# print(f"final content: {content}")
response = await llm_model._generate_content_async(
contents=content,
generation_config=ExtendedGenerationConfig(**optional_params),
generation_config=optional_params,
tools=tools,
)
if tools is not None and hasattr(
response.candidates[0].content.parts[0], "function_call"
if tools is not None and bool(
getattr(response.candidates[0].content.parts[0], "function_call", None)
):
function_call = response.candidates[0].content.parts[0].function_call
args_dict = {}
@ -993,45 +1023,6 @@ async def async_streaming(
"""
Add support for async streaming calls for gemini-pro
"""
from vertexai.preview.generative_models import GenerationConfig
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore
class ExtendedGenerationConfig(GenerationConfig):
"""Extended parameters for the generation."""
def __init__(
self,
*,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
candidate_count: Optional[int] = None,
max_output_tokens: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
response_mime_type: Optional[str] = None,
):
args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
if "response_mime_type" in args_spec.args:
self._raw_generation_config = gapic_content_types.GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
candidate_count=candidate_count,
max_output_tokens=max_output_tokens,
stop_sequences=stop_sequences,
response_mime_type=response_mime_type,
)
else:
self._raw_generation_config = gapic_content_types.GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
candidate_count=candidate_count,
max_output_tokens=max_output_tokens,
stop_sequences=stop_sequences,
)
if mode == "vision":
stream = optional_params.pop("stream")
tools = optional_params.pop("tools", None)
@ -1052,11 +1043,10 @@ async def async_streaming(
response = await llm_model._generate_content_streaming_async(
contents=content,
generation_config=ExtendedGenerationConfig(**optional_params),
generation_config=optional_params,
tools=tools,
)
optional_params["stream"] = True
optional_params["tools"] = tools
elif mode == "chat":
chat = llm_model.start_chat()
optional_params.pop(
@ -1075,7 +1065,7 @@ async def async_streaming(
},
)
response = chat.send_message_streaming_async(prompt, **optional_params)
optional_params["stream"] = True
elif mode == "text":
optional_params.pop(
"stream", None
@ -1171,6 +1161,7 @@ def embedding(
encoding=None,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=False,
print_verbose=None,
):
@ -1191,6 +1182,16 @@ def embedding(
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
json_obj = json.loads(vertex_credentials)
creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
creds, _ = google.auth.default(quota_project_id=vertex_project)
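    # vertex_credentials may be a JSON string holding a service-account key; when it is
    # not provided, Application Default Credentials are used (e.g. GOOGLE_APPLICATION_CREDENTIALS).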
print_verbose(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"

View file

@ -12,7 +12,6 @@ from typing import Any, Literal, Union, BinaryIO
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
import litellm
from ._logging import verbose_logger
@ -342,6 +341,7 @@ async def acompletion(
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=completion_kwargs,
extra_kwargs=kwargs,
)
@ -608,6 +608,7 @@ def completion(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"input_cost_per_second",
@ -1682,13 +1683,14 @@ def completion(
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
new_params = deepcopy(optional_params)
if "claude-3" in model:
model_response = vertex_ai_anthropic.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
@ -1704,12 +1706,13 @@ def completion(
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
)
@ -1939,9 +1942,16 @@ def completion(
or "http://localhost:11434"
)
api_key = (
api_key
or litellm.ollama_key
or os.environ.get("OLLAMA_API_KEY")
or litellm.api_key
)
## LOGGING
generator = ollama_chat.get_ollama_response(
api_base,
api_key,
model,
messages,
optional_params,
@ -2137,6 +2147,7 @@ def completion(
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
@ -2498,6 +2509,7 @@ async def aembedding(*args, **kwargs):
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
@ -2549,6 +2561,7 @@ def embedding(
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
encoding_format = kwargs.get("encoding_format", None)
@ -2606,6 +2619,7 @@ def embedding(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"input_cost_per_second",
@ -2807,6 +2821,11 @@ def embedding(
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
response = vertex_ai.embedding(
model=model,
@ -2817,6 +2836,7 @@ def embedding(
model_response=EmbeddingResponse(),
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
aembedding=aembedding,
print_verbose=print_verbose,
)
@ -2933,7 +2953,10 @@ def embedding(
)
## Map to OpenAI Exception
raise exception_type(
model=model, original_exception=e, custom_llm_provider=custom_llm_provider
model=model,
original_exception=e,
custom_llm_provider=custom_llm_provider,
extra_kwargs=kwargs,
)
@ -3027,6 +3050,7 @@ async def atext_completion(*args, **kwargs):
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
@ -3364,6 +3388,7 @@ async def aimage_generation(*args, **kwargs):
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
@ -3454,6 +3479,7 @@ def image_generation(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"hf_model_name",
@ -3563,6 +3589,7 @@ def image_generation(
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=locals(),
extra_kwargs=kwargs,
)
@ -3612,6 +3639,7 @@ async def atranscription(*args, **kwargs):
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)

View file

@ -75,7 +75,8 @@
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true
"supports_parallel_function_calling": true,
"supports_vision": true
},
"gpt-4-turbo-2024-04-09": {
"max_tokens": 4096,
@ -86,7 +87,8 @@
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true
"supports_parallel_function_calling": true,
"supports_vision": true
},
"gpt-4-1106-preview": {
"max_tokens": 4096,
@ -648,6 +650,7 @@
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006,
"litellm_provider": "mistral",
"supports_function_calling": true,
"mode": "chat"
},
"mistral/mistral-small-latest": {
@ -657,6 +660,7 @@
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006,
"litellm_provider": "mistral",
"supports_function_calling": true,
"mode": "chat"
},
"mistral/mistral-medium": {
@ -706,6 +710,16 @@
"mode": "chat",
"supports_function_calling": true
},
"mistral/open-mixtral-8x7b": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006,
"litellm_provider": "mistral",
"mode": "chat",
"supports_function_calling": true
},
"mistral/mistral-embed": {
"max_tokens": 8192,
"max_input_tokens": 8192,
@ -723,6 +737,26 @@
"mode": "chat",
"supports_function_calling": true
},
"groq/llama3-8b-8192": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000010,
"litellm_provider": "groq",
"mode": "chat",
"supports_function_calling": true
},
"groq/llama3-70b-8192": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000064,
"output_cost_per_token": 0.00000080,
"litellm_provider": "groq",
"mode": "chat",
"supports_function_calling": true
},
"groq/mixtral-8x7b-32768": {
"max_tokens": 32768,
"max_input_tokens": 32768,
@ -777,7 +811,9 @@
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "anthropic",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true,
"tool_use_system_prompt_tokens": 264
},
"claude-3-opus-20240229": {
"max_tokens": 4096,
@ -786,7 +822,9 @@
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "anthropic",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true,
"tool_use_system_prompt_tokens": 395
},
"claude-3-sonnet-20240229": {
"max_tokens": 4096,
@ -795,7 +833,9 @@
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "anthropic",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true,
"tool_use_system_prompt_tokens": 159
},
"text-bison": {
"max_tokens": 1024,
@ -1010,6 +1050,7 @@
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0215": {
@ -1021,6 +1062,7 @@
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0409": {
@ -1032,6 +1074,7 @@
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-experimental": {
@ -1043,6 +1086,7 @@
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": false,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-pro-vision": {
@ -1097,7 +1141,8 @@
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"vertex_ai/claude-3-haiku@20240307": {
"max_tokens": 4096,
@ -1106,7 +1151,8 @@
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"vertex_ai/claude-3-opus@20240229": {
"max_tokens": 4096,
@ -1115,7 +1161,8 @@
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.0000075,
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"textembedding-gecko": {
"max_tokens": 3072,
@ -1268,8 +1315,23 @@
"litellm_provider": "gemini",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-1.5-pro-latest": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"litellm_provider": "gemini",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_tool_choice": true,
"source": "https://ai.google.dev/models/gemini"
},
"gemini/gemini-pro-vision": {
"max_tokens": 2048,
"max_input_tokens": 30720,
@ -1484,6 +1546,13 @@
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-70b-instruct": {
"max_tokens": 8192,
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000008,
"litellm_provider": "openrouter",
"mode": "chat"
},
"j2-ultra": {
"max_tokens": 8192,
"max_input_tokens": 8192,
@ -1731,7 +1800,8 @@
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"max_tokens": 4096,
@ -1740,7 +1810,8 @@
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
@ -1749,7 +1820,8 @@
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"anthropic.claude-v1": {
"max_tokens": 8191,

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[185],{11837:function(n,e,t){Promise.resolve().then(t.t.bind(t,99646,23)),Promise.resolve().then(t.t.bind(t,63385,23))},63385:function(){},99646:function(n){n.exports={style:{fontFamily:"'__Inter_c23dc8', '__Inter_Fallback_c23dc8'",fontStyle:"normal"},className:"__className_c23dc8"}}},function(n){n.O(0,[971,69,744],function(){return n(n.s=11837)}),_N_E=n.O()}]);
(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[185],{11837:function(n,e,t){Promise.resolve().then(t.t.bind(t,99646,23)),Promise.resolve().then(t.t.bind(t,63385,23))},63385:function(){},99646:function(n){n.exports={style:{fontFamily:"'__Inter_12bbc4', '__Inter_Fallback_12bbc4'",fontStyle:"normal"},className:"__className_12bbc4"}}},function(n){n.O(0,[971,69,744],function(){return n(n.s=11837)}),_N_E=n.O()}]);

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
!function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/11cfce8bfdf6e8f1.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var 
n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}();
!function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/889eb79902810cea.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var 
n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}();

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-59f93936973f5f5a.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-bcf69420342937de.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-442a9c01c3fd20f9.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-096338c8e1915716.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-59f93936973f5f5a.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/11cfce8bfdf6e8f1.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[8251,[\"289\",\"static/chunks/289-04be6cb9636840d2.js\",\"931\",\"static/chunks/app/page-15d0c6c10d700825.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/11cfce8bfdf6e8f1.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"fcTpSzljtxsSagYnqnMB2\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 
0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-06c4978d6b66bb10.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-dafd44dfa2da140c.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-e49705773ae41779.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-096338c8e1915716.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-06c4978d6b66bb10.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/889eb79902810cea.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[67392,[\"127\",\"static/chunks/127-efd0436630e294eb.js\",\"931\",\"static/chunks/app/page-f1971f791bb7ca83.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/889eb79902810cea.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"bWtcV5WstBNX-ygMm1ejg\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_12bbc4\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 
0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>

View file

@ -1,7 +1,7 @@
2:I[77831,[],""]
3:I[8251,["289","static/chunks/289-04be6cb9636840d2.js","931","static/chunks/app/page-15d0c6c10d700825.js"],""]
3:I[67392,["127","static/chunks/127-efd0436630e294eb.js","931","static/chunks/app/page-f1971f791bb7ca83.js"],""]
4:I[5613,[],""]
5:I[31778,[],""]
0:["fcTpSzljtxsSagYnqnMB2",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/11cfce8bfdf6e8f1.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
0:["bWtcV5WstBNX-ygMm1ejg",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_12bbc4","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/889eb79902810cea.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
1:null

View file

@ -1,51 +1,61 @@
environment_variables:
LANGFUSE_PUBLIC_KEY: Q6K8MQN6L7sPYSJiFKM9eNrETOx6V/FxVPup4FqdKsZK1hyR4gyanlQ2KHLg5D5afng99uIt0JCEQ2jiKF9UxFvtnb4BbJ4qpeceH+iK8v/bdg==
LANGFUSE_SECRET_KEY: 5xQ7KMa6YMLsm+H/Pf1VmlqWq1NON5IoCxABhkUBeSck7ftsj2CmpkL2ZwrxwrktgiTUBH+3gJYBX+XBk7lqOOUpvmiLjol/E5lCqq0M1CqLWA==
SLACK_WEBHOOK_URL: RJjhS0Hhz0/s07sCIf1OTXmTGodpK9L2K9p953Z+fOX0l2SkPFT6mB9+yIrLufmlwEaku5NNEBKy//+AG01yOd+7wV1GhK65vfj3B/gTN8t5cuVnR4vFxKY5Rx4eSGLtzyAs+aIBTp4GoNXDIjroCqfCjPkItEZWCg==
general_settings:
alerting:
- slack
alerting_threshold: 300
database_connection_pool_limit: 100
database_connection_timeout: 60
disable_master_key_return: true
health_check_interval: 300
proxy_batch_write_at: 60
ui_access_mode: all
# master_key: sk-1234
litellm_settings:
allowed_fails: 3
failure_callback:
- prometheus
num_retries: 3
service_callback:
- prometheus_system
success_callback:
- langfuse
- prometheus
- langsmith
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
stream_timeout: 0.001
rpm: 10
- litellm_params:
model: azure/chatgpt-v-2
model: gpt-3.5-turbo
model_name: gpt-3.5-turbo
- litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_key: my-fake-key
model: openai/my-fake-model
stream_timeout: 0.001
model_name: fake-openai-endpoint
- litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_key: my-fake-key
model: openai/my-fake-model-2
stream_timeout: 0.001
model_name: fake-openai-endpoint
- litellm_params:
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
api_version: 2023-07-01-preview
model: azure/chatgpt-v-2
stream_timeout: 0.001
model_name: azure-gpt-3.5
# - model_name: text-embedding-ada-002
# litellm_params:
# model: text-embedding-ada-002
# api_key: os.environ/OPENAI_API_KEY
- model_name: gpt-instruct
litellm_params:
- litellm_params:
api_key: os.environ/OPENAI_API_KEY
model: text-embedding-ada-002
model_name: text-embedding-ada-002
- litellm_params:
model: text-completion-openai/gpt-3.5-turbo-instruct
# api_key: my-fake-key
# api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings:
success_callback: ["prometheus"]
service_callback: ["prometheus_system"]
upperbound_key_generate_params:
max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
model_name: gpt-instruct
router_settings:
routing_strategy: usage-based-routing-v2
enable_pre_call_checks: true
redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
enable_pre_call_checks: True
general_settings:
master_key: sk-1234
allow_user_auth: true
alerting: ["slack"]
store_model_in_db: True # set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
enable_jwt_auth: True
alerting: ["slack"]
litellm_jwtauth:
admin_jwt_scope: "litellm_proxy_admin"
public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL
user_id_jwt_field: "sub"
org_id_jwt_field: "azp"
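As a quick illustration of the router_settings block above, here is a minimal sketch of the equivalent litellm.Router construction (assumption: the public Router constructor accepts these keyword arguments; the single model entry and the Redis values are placeholders mirroring the config):
import os

import litellm

# Sketch only - mirrors router_settings above; model entry and Redis values are placeholders.
router = litellm.Router(
    model_list=[
        {
            "model_name": "fake-openai-endpoint",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "my-fake-key",
                "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
            },
        }
    ],
    routing_strategy="usage-based-routing-v2",  # same strategy as the config
    enable_pre_call_checks=True,
    redis_host=os.environ.get("REDIS_HOST"),
    redis_password=os.environ.get("REDIS_PASSWORD"),
    redis_port=int(os.environ.get("REDIS_PORT", "6379")),
)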

View file

@ -51,7 +51,8 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
class LiteLLMRoutes(enum.Enum):
openai_routes: List = [ # chat completions
openai_routes: List = [
# chat completions
"/openai/deployments/{model}/chat/completions",
"/chat/completions",
"/v1/chat/completions",
@ -77,7 +78,22 @@ class LiteLLMRoutes(enum.Enum):
"/v1/models",
]
info_routes: List = ["/key/info", "/team/info", "/user/info", "/model/info"]
info_routes: List = [
"/key/info",
"/team/info",
"/user/info",
"/model/info",
"/v2/model/info",
"/v2/key/info",
]
sso_only_routes: List = [
"/key/generate",
"/key/update",
"/key/delete",
"/global/spend/logs",
"/global/predict/spend/logs",
]
management_routes: List = [ # key
"/key/generate",
@ -689,6 +705,21 @@ class ConfigGeneralSettings(LiteLLMBase):
None,
description="List of alerting integrations. Today, just slack - `alerting: ['slack']`",
)
alert_types: Optional[
List[
Literal[
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
]
]
] = Field(
None,
description="List of alerting types. By default it is all alerts",
)
alerting_threshold: Optional[int] = Field(
None,
description="sends alerts if requests hang for 5min+",
@ -719,6 +750,10 @@ class ConfigYAML(LiteLLMBase):
description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache",
)
general_settings: Optional[ConfigGeneralSettings] = None
router_settings: Optional[dict] = Field(
None,
description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5",
)
class Config:
protected_namespaces = ()
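A hypothetical payload exercising the two new fields (alert_types on general settings, router_settings on the config object); the values are examples only, and this sketch assumes the remaining fields on both models stay optional:
from litellm.proxy._types import ConfigGeneralSettings, ConfigYAML

# Example values only - not defaults.
payload = ConfigYAML(
    general_settings=ConfigGeneralSettings(
        alerting=["slack"],
        alerting_threshold=300,
        alert_types=["llm_exceptions", "llm_too_slow", "budget_alerts"],
    ),
    router_settings={"num_retries": 3, "timeout": 30},
)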
@ -765,6 +800,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
"""
team_spend: Optional[float] = None
team_alias: Optional[str] = None
team_tpm_limit: Optional[int] = None
team_rpm_limit: Optional[int] = None
team_max_budget: Optional[float] = None
@ -788,6 +824,10 @@ class UserAPIKeyAuth(
def check_api_key(cls, values):
if values.get("api_key") is not None:
values.update({"token": hash_token(values.get("api_key"))})
if isinstance(values.get("api_key"), str) and values.get(
"api_key"
).startswith("sk-"):
values.update({"api_key": hash_token(values.get("api_key"))})
return values
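Illustrative effect of the validator above (assuming the model's other fields keep their defaults; the key below is a throwaway example value):
from litellm.proxy._types import UserAPIKeyAuth

auth = UserAPIKeyAuth(api_key="sk-example-1234")
# The raw "sk-" key is replaced by hash_token(...), so the plaintext never
# lives on the auth object; token holds the same hash.
assert auth.api_key != "sk-example-1234"
assert auth.api_key == auth.token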

View file

@ -10,17 +10,11 @@ model_list:
api_key: os.environ/OPENAI_API_KEY
litellm_settings:
default_team_settings:
- team_id: team-1
success_callback: ["langfuse"]
langfuse_public_key: os.environ/LANGFUSE_PROJECT1_PUBLIC # Project 1
langfuse_secret: os.environ/LANGFUSE_PROJECT1_SECRET # Project 1
- team_id: team-2
success_callback: ["langfuse"]
langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2
langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2
general_settings:
store_model_in_db: true
master_key: sk-1234
alerting: ["slack"]
litellm_settings:
success_callback: ["langfuse"]
_langfuse_default_tags: ["user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"]

View file

@ -1010,8 +1010,12 @@ async def user_api_key_auth(
db=custom_db_client,
)
)
if route in LiteLLMRoutes.info_routes.value and (
not _is_user_proxy_admin(user_id_information)
if not _is_user_proxy_admin(user_id_information): # if non-admin
if route in LiteLLMRoutes.openai_routes.value:
pass
elif (
route in LiteLLMRoutes.info_routes.value
): # check if user allowed to call an info route
if route == "/key/info":
# check if user can access this route
@ -1049,9 +1053,14 @@ async def user_api_key_auth(
status_code=status.HTTP_403_FORBIDDEN,
detail="key not allowed to access this team's info",
)
elif (
_has_user_setup_sso()
and route in LiteLLMRoutes.sso_only_routes.value
):
pass
else:
raise Exception(
f"Only master key can be used to generate, delete, update info for new keys/users."
f"Only master key can be used to generate, delete, update info for new keys/users/teams. Route={route}"
)
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
@ -1098,6 +1107,13 @@ async def user_api_key_auth(
return UserAPIKeyAuth(
api_key=api_key, user_role="proxy_admin", **valid_token_dict
)
elif (
_has_user_setup_sso()
and route in LiteLLMRoutes.sso_only_routes.value
):
return UserAPIKeyAuth(
api_key=api_key, user_role="app_owner", **valid_token_dict
)
else:
raise Exception(
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
@ -2201,9 +2217,9 @@ class ProxyConfig:
# these are litellm callbacks - "langfuse", "sentry", "wandb"
else:
litellm.failure_callback.append(callback)
verbose_proxy_logger.debug(
f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}"
)
print( # noqa
f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
) # noqa
elif key == "cache_params":
# this is set in the cache branch
# see usage here: https://docs.litellm.ai/docs/proxy/caching
@ -2279,6 +2295,7 @@ class ProxyConfig:
proxy_logging_obj.update_values(
alerting=general_settings.get("alerting", None),
alerting_threshold=general_settings.get("alerting_threshold", 600),
alert_types=general_settings.get("alert_types", None),
redis_cache=redis_usage_cache,
)
### CONNECT TO DATABASE ###
@ -2295,7 +2312,7 @@ class ProxyConfig:
master_key = litellm.get_secret(master_key)
if master_key is not None and isinstance(master_key, str):
litellm_master_key_hash = master_key
litellm_master_key_hash = hash_token(master_key)
### STORE MODEL IN DB ### feature flag for `/model/new`
store_model_in_db = general_settings.get("store_model_in_db", False)
if store_model_in_db is None:
@ -2406,27 +2423,44 @@ class ProxyConfig:
router = litellm.Router(**router_params, semaphore=semaphore) # type:ignore
return router, model_list, general_settings
async def _delete_deployment(self, db_models: list):
def get_model_info_with_id(self, model) -> RouterModelInfo:
"""
Common logic across add + delete router models
Parameters:
- deployment
Return model info w/ id
"""
if model.model_info is not None and isinstance(model.model_info, dict):
if "id" not in model.model_info:
model.model_info["id"] = model.model_id
_model_info = RouterModelInfo(**model.model_info)
else:
_model_info = RouterModelInfo(id=model.model_id)
return _model_info
async def _delete_deployment(self, db_models: list) -> int:
"""
(Helper function of add deployment) -> combined to reduce prisma db calls
- Create all up list of model id's (db + config)
- Compare all up list to router model id's
- Remove any that are missing
Return:
- int - returns number of deleted deployments
"""
global user_config_file_path, llm_router
combined_id_list = []
if llm_router is None:
return
return 0
## DB MODELS ##
for m in db_models:
if m.model_info is not None and isinstance(m.model_info, dict):
if "id" not in m.model_info:
m.model_info["id"] = m.model_id
combined_id_list.append(m.model_id)
else:
combined_id_list.append(m.model_id)
model_info = self.get_model_info_with_id(model=m)
if model_info.id is not None:
combined_id_list.append(model_info.id)
## CONFIG MODELS ##
config = await self.get_config(config_file_path=user_config_file_path)
model_list = config.get("model_list", None)
@ -2436,46 +2470,89 @@ class ProxyConfig:
for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v)
litellm_model_name = model["litellm_params"]["model"]
litellm_model_api_base = model["litellm_params"].get("api_base", None)
model_id = litellm.Router()._generate_model_id(
## check if they have model-id's ##
model_id = model.get("model_info", {}).get("id", None)
if model_id is None:
## else - generate stable id's ##
model_id = llm_router._generate_model_id(
model_group=model["model_name"],
litellm_params=model["litellm_params"],
)
combined_id_list.append(model_id) # ADD CONFIG MODEL TO COMBINED LIST
router_model_ids = llm_router.get_model_ids()
# Check for model IDs in llm_router not present in combined_id_list and delete them
deleted_deployments = 0
for model_id in router_model_ids:
if model_id not in combined_id_list:
llm_router.delete_deployment(id=model_id)
is_deleted = llm_router.delete_deployment(id=model_id)
if is_deleted is not None:
deleted_deployments += 1
return deleted_deployments
async def add_deployment(
self,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
):
def _add_deployment(self, db_models: list) -> int:
"""
- Check db for new models (last 10 most recently updated)
- Check if model id's in router already
- If not, add to router
"""
global llm_router, llm_model_list, master_key, general_settings
Iterate through db models
for any not in router - add them.
Return - number of deployments added
"""
import base64
try:
if master_key is None or not isinstance(master_key, str):
raise Exception(
f"Master key is not initialized or formatted. master_key={master_key}"
)
verbose_proxy_logger.debug(f"llm_router: {llm_router}")
if llm_router is None:
new_models = (
await prisma_client.db.litellm_proxymodeltable.find_many()
) # get all models in db
return 0
added_models = 0
## ADD MODEL LOGIC
for m in db_models:
_litellm_params = m.litellm_params
if isinstance(_litellm_params, dict):
# decrypt values
for k, v in _litellm_params.items():
if isinstance(v, str):
# decode base64
decoded_b64 = base64.b64decode(v)
# decrypt value
_litellm_params[k] = decrypt_value(
value=decoded_b64, master_key=master_key
)
_litellm_params = LiteLLM_Params(**_litellm_params)
else:
verbose_proxy_logger.error(
f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
)
continue # skip to next model
_model_info = self.get_model_info_with_id(model=m)
added = llm_router.add_deployment(
deployment=Deployment(
model_name=m.model_name,
litellm_params=_litellm_params,
model_info=_model_info,
)
)
if added is not None:
added_models += 1
return added_models
async def _update_llm_router(
self,
new_models: list,
proxy_logging_obj: ProxyLogging,
):
global llm_router, llm_model_list, master_key, general_settings
import base64
if llm_router is None and master_key is not None:
verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
_model_list: list = []
@ -2489,7 +2566,7 @@ class ProxyConfig:
decoded_b64 = base64.b64decode(v)
# decrypt value
_litellm_params[k] = decrypt_value(
value=decoded_b64, master_key=master_key
value=decoded_b64, master_key=master_key # type: ignore
)
_litellm_params = LiteLLM_Params(**_litellm_params)
else:
@ -2498,13 +2575,7 @@ class ProxyConfig:
)
continue # skip to next model
if m.model_info is not None and isinstance(m.model_info, dict):
if "id" not in m.model_info:
m.model_info["id"] = m.model_id
_model_info = RouterModelInfo(**m.model_info)
else:
_model_info = RouterModelInfo(id=m.model_id)
_model_info = self.get_model_info_with_id(model=m)
_model_list.append(
Deployment(
model_name=m.model_name,
@ -2512,50 +2583,19 @@ class ProxyConfig:
model_info=_model_info,
).to_json(exclude_none=True)
)
if len(_model_list) > 0:
verbose_proxy_logger.debug(f"_model_list: {_model_list}")
llm_router = litellm.Router(model_list=_model_list)
verbose_proxy_logger.debug(f"updated llm_router: {llm_router}")
else:
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
## DELETE MODEL LOGIC
await self._delete_deployment(db_models=new_models)
## ADD MODEL LOGIC
for m in new_models:
_litellm_params = m.litellm_params
if isinstance(_litellm_params, dict):
# decrypt values
for k, v in _litellm_params.items():
if isinstance(v, str):
# decode base64
decoded_b64 = base64.b64decode(v)
# decrypt value
_litellm_params[k] = decrypt_value(
value=decoded_b64, master_key=master_key
)
_litellm_params = LiteLLM_Params(**_litellm_params)
else:
verbose_proxy_logger.error(
f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
)
continue # skip to next model
if m.model_info is not None and isinstance(m.model_info, dict):
if "id" not in m.model_info:
m.model_info["id"] = m.model_id
_model_info = RouterModelInfo(**m.model_info)
else:
_model_info = RouterModelInfo(id=m.model_id)
llm_router.add_deployment(
deployment=Deployment(
model_name=m.model_name,
litellm_params=_litellm_params,
model_info=_model_info,
)
)
self._add_deployment(db_models=new_models)
if llm_router is not None:
llm_model_list = llm_router.get_model_list()
# check if user set any callbacks in Config Table
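For context on the id bookkeeping in get_model_info_with_id / _delete_deployment above, a small sketch (the model dict is a placeholder; _generate_model_id is the same private Router helper the code above calls):
import litellm

config_model = {
    "model_name": "azure-gpt-3.5",
    "litellm_params": {"model": "azure/chatgpt-v-2"},
}
router = litellm.Router()
stable_id = router._generate_model_id(
    model_group=config_model["model_name"],
    litellm_params=config_model["litellm_params"],
)
# Same inputs -> same id, so config-defined deployments are not wiped by the
# "delete anything not in db + config" pass.
assert stable_id == router._generate_model_id(
    model_group=config_model["model_name"],
    litellm_params=config_model["litellm_params"],
)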
@ -2572,7 +2612,7 @@ class ProxyConfig:
for k, v in environment_variables.items():
try:
decoded_b64 = base64.b64decode(v)
value = decrypt_value(value=decoded_b64, master_key=master_key)
value = decrypt_value(value=decoded_b64, master_key=master_key) # type: ignore
os.environ[k] = value
except Exception as e:
verbose_proxy_logger.error(
@ -2584,7 +2624,44 @@ class ProxyConfig:
if "alerting" in _general_settings:
general_settings["alerting"] = _general_settings["alerting"]
proxy_logging_obj.alerting = general_settings["alerting"]
proxy_logging_obj.slack_alerting_instance.alerting = general_settings[
"alerting"
]
if "alert_types" in _general_settings:
general_settings["alert_types"] = _general_settings["alert_types"]
proxy_logging_obj.alert_types = general_settings["alert_types"]
proxy_logging_obj.slack_alerting_instance.alert_types = general_settings[
"alert_types"
]
# router settings
if llm_router is not None:
_router_settings = config_data.get("router_settings", {})
llm_router.update_settings(**_router_settings)
async def add_deployment(
self,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
):
"""
- Check db for new models (last 10 most recently updated)
- Check if model id's in router already
- If not, add to router
"""
global llm_router, llm_model_list, master_key, general_settings
try:
if master_key is None or not isinstance(master_key, str):
raise Exception(
f"Master key is not initialized or formatted. master_key={master_key}"
)
verbose_proxy_logger.debug(f"llm_router: {llm_router}")
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
await self._update_llm_router(
new_models=new_models, proxy_logging_obj=proxy_logging_obj
)
except Exception as e:
verbose_proxy_logger.error(
"{}\nTraceback:{}".format(str(e), traceback.format_exc())
@ -2727,10 +2804,12 @@ async def generate_key_helper_fn(
"model_max_budget": model_max_budget_json,
"budget_id": budget_id,
}
if (
general_settings.get("allow_user_auth", False) == True
or _has_user_setup_sso() == True
):
litellm.get_secret("DISABLE_KEY_NAME", False) == True
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
pass
else:
key_data["key_name"] = f"sk-...{token[-4:]}"
saved_token = copy.deepcopy(key_data)
if isinstance(saved_token["aliases"], str):
@ -3216,7 +3295,7 @@ async def startup_event():
scheduler.add_job(
proxy_config.add_deployment,
"interval",
seconds=30,
seconds=10,
args=[prisma_client, proxy_logging_obj],
)
@ -3317,6 +3396,9 @@ async def completion(
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
_headers = dict(request.headers)
_headers.pop(
"authorization", None
@ -3377,7 +3459,10 @@ async def completion(
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
detail={
"error": "Invalid model name passed in model="
+ data.get("model", "")
},
)
if hasattr(response, "_hidden_params"):
@ -3409,6 +3494,7 @@ async def completion(
fastapi_response.headers["x-litellm-model-id"] = model_id
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
verbose_proxy_logger.debug("EXCEPTION RAISED IN PROXY MAIN.PY")
verbose_proxy_logger.debug(
"\033[1;31mAn error occurred: %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
@ -3515,6 +3601,9 @@ async def chat_completion(
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
@ -3608,7 +3697,10 @@ async def chat_completion(
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
detail={
"error": "Invalid model name passed in model="
+ data.get("model", "")
},
)
# wait for call to end
@ -3652,6 +3744,7 @@ async def chat_completion(
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
traceback.print_exc()
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e
@ -3743,6 +3836,9 @@ async def embeddings(
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
@ -3832,7 +3928,10 @@ async def embeddings(
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
detail={
"error": "Invalid model name passed in model="
+ data.get("model", "")
},
)
### ALERTING ###
@ -3840,6 +3939,7 @@ async def embeddings(
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e
)
@ -3918,6 +4018,9 @@ async def image_generation(
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
@ -3981,7 +4084,10 @@ async def image_generation(
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
detail={
"error": "Invalid model name passed in model="
+ data.get("model", "")
},
)
### ALERTING ###
@ -3989,6 +4095,7 @@ async def image_generation(
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e
)
@ -4071,6 +4178,9 @@ async def audio_transcriptions(
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["endpoint"] = str(request.url)
data["metadata"]["file_name"] = file.filename
@ -4095,6 +4205,14 @@ async def audio_transcriptions(
file.filename is not None
) # make sure filename passed in (needed for type)
_original_filename = file.filename
file_extension = os.path.splitext(file.filename)[1]
# rename the file to a random hash file name -> we eventually remove the file and don't want to delete any local files
file.filename = f"tmp-request" + str(uuid.uuid4()) + file_extension
# IMP - assert that the upload was renamed, since we later run os.remove(file.filename) and must not delete the user's original local file
assert file.filename != _original_filename
with open(file.filename, "wb+") as f:
f.write(await file.read())
try:
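The file handling above boils down to this pattern (stand-alone sketch; the filename is a placeholder):
import os
import uuid

original = "meeting.wav"
ext = os.path.splitext(original)[1]
tmp_name = "tmp-request" + str(uuid.uuid4()) + ext
# Writing to (and later os.remove()-ing) tmp_name can never touch the
# caller's original local file.
assert tmp_name != original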
@ -4141,7 +4259,10 @@ async def audio_transcriptions(
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
detail={
"error": "Invalid model name passed in model="
+ data.get("model", "")
},
)
except Exception as e:
@ -4153,6 +4274,7 @@ async def audio_transcriptions(
data["litellm_status"] = "success" # used for alerting
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e
)
@ -4243,6 +4365,9 @@ async def moderations(
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
@ -4300,7 +4425,10 @@ async def moderations(
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Invalid model name passed in"},
detail={
"error": "Invalid model name passed in model="
+ data.get("model", "")
},
)
### ALERTING ###
@ -4308,6 +4436,7 @@ async def moderations(
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e
)
@ -5452,10 +5581,12 @@ async def global_spend_per_tea():
# get the team_id for this entry
# get the spend for this entry
spend = row["total_spend"]
spend = round(spend, 2)
current_date_entries = spend_by_date[row_date]
current_date_entries[team_alias] = spend
else:
spend = row["total_spend"]
spend = round(spend, 2)
spend_by_date[row_date] = {team_alias: spend}
if team_alias in total_spend_per_team:
@ -5633,6 +5764,20 @@ async def new_user(data: NewUserRequest):
"user" # only create a user, don't create key if 'auto_create_key' set to False
)
response = await generate_key_helper_fn(**data_json)
# Admin UI Logic
# if team_id passed add this user to the team
if data_json.get("team_id", None) is not None:
await team_member_add(
data=TeamMemberAddRequest(
team_id=data_json.get("team_id", None),
member=Member(
user_id=data_json.get("user_id", None),
role="user",
user_email=data_json.get("user_email", None),
),
)
)
return NewUserResponse(
key=response.get("token", ""),
expires=response.get("expires", None),
@ -5795,6 +5940,13 @@ async def user_info(
user_id=user_api_key_dict.user_id
)
# *NEW* get all teams in user 'teams' field
if getattr(caller_user_info, "user_role", None) == "proxy_admin":
teams_2 = await prisma_client.get_data(
table_name="team",
query_type="find_all",
team_id_list=None,
)
else:
teams_2 = await prisma_client.get_data(
team_id_list=caller_user_info.teams,
table_name="team",
@ -5825,6 +5977,13 @@ async def user_info(
## REMOVE HASHED TOKEN INFO before returning ##
returned_keys = []
for key in keys:
if (
key.token == litellm_master_key_hash
and general_settings.get("disable_master_key_return", False)
== True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
):
continue
try:
key = key.model_dump() # noqa
except:
@ -6438,13 +6597,20 @@ async def team_member_add(
existing_team_row = await prisma_client.get_data( # type: ignore
team_id=data.team_id, table_name="team", query_type="find_unique"
)
if existing_team_row is None:
raise HTTPException(
status_code=404,
detail={
"error": f"Team not found for team_id={getattr(data, 'team_id', None)}"
},
)
new_member = data.member
existing_team_row.members_with_roles.append(new_member)
complete_team_data = LiteLLM_TeamTable(
**existing_team_row.model_dump(),
**_get_pydantic_json_dict(existing_team_row),
)
team_row = await prisma_client.update_data(
@ -7159,12 +7325,16 @@ async def model_info_v2(
"/model/metrics",
description="View number of requests & avg latency per model on config.yaml",
tags=["model management"],
include_in_schema=False,
dependencies=[Depends(user_api_key_auth)],
)
async def model_metrics(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
_selected_model_group: Optional[str] = None,
startTime: Optional[datetime] = datetime.now() - timedelta(days=30),
endTime: Optional[datetime] = datetime.now(),
):
global prisma_client
global prisma_client, llm_router
if prisma_client is None:
raise ProxyException(
message="Prisma Client is not initialized",
@ -7172,6 +7342,33 @@ async def model_metrics(
param="None",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
if _selected_model_group and llm_router is not None:
_model_list = llm_router.get_model_list()
_relevant_api_bases = []
for model in _model_list:
if model["model_name"] == _selected_model_group:
_litellm_params = model["litellm_params"]
_api_base = _litellm_params.get("api_base", "")
_relevant_api_bases.append(_api_base)
_relevant_api_bases.append(_api_base + "/openai/")
sql_query = """
SELECT
CASE WHEN api_base = '' THEN model ELSE CONCAT(model, '-', api_base) END AS combined_model_api_base,
COUNT(*) AS num_requests,
AVG(EXTRACT(epoch FROM ("endTime" - "startTime"))) AS avg_latency_seconds
FROM "LiteLLM_SpendLogs"
WHERE "startTime" >= $1::timestamp AND "endTime" <= $2::timestamp
AND api_base = ANY($3)
GROUP BY CASE WHEN api_base = '' THEN model ELSE CONCAT(model, '-', api_base) END
ORDER BY num_requests DESC
LIMIT 50;
"""
db_response = await prisma_client.db.query_raw(
sql_query, startTime, endTime, _relevant_api_bases
)
else:
sql_query = """
SELECT
@ -7180,8 +7377,7 @@ async def model_metrics(
AVG(EXTRACT(epoch FROM ("endTime" - "startTime"))) AS avg_latency_seconds
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= NOW() - INTERVAL '10000 hours'
WHERE "startTime" >= $1::timestamp AND "endTime" <= $2::timestamp
GROUP BY
CASE WHEN api_base = '' THEN model ELSE CONCAT(model, '-', api_base) END
ORDER BY
@ -7189,7 +7385,7 @@ async def model_metrics(
LIMIT 50;
"""
db_response = await prisma_client.db.query_raw(query=sql_query)
db_response = await prisma_client.db.query_raw(sql_query, startTime, endTime)
response: List[dict] = []
if response is not None:
# loop through all models
@ -7751,7 +7947,7 @@ async def login(request: Request):
)
if os.getenv("DATABASE_URL") is not None:
response = await generate_key_helper_fn(
**{"user_role": "proxy_admin", "duration": "1hr", "key_max_budget": 5, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore
**{"user_role": "proxy_admin", "duration": "2hr", "key_max_budget": 5, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore
)
else:
raise ProxyException(
@ -8003,7 +8199,7 @@ async def auth_callback(request: Request):
# User might not be already created on first generation of key
# But if it is, we want their models preferences
default_ui_key_values = {
"duration": "1hr",
"duration": "2hr",
"key_max_budget": 0.01,
"aliases": {},
"config": {},
@ -8015,6 +8211,7 @@ async def auth_callback(request: Request):
"user_id": user_id,
"user_email": user_email,
}
_user_id_from_sso = user_id
try:
user_role = None
if prisma_client is not None:
@ -8031,7 +8228,6 @@ async def auth_callback(request: Request):
}
user_role = getattr(user_info, "user_role", None)
else:
## check if user-email in db ##
user_info = await prisma_client.db.litellm_usertable.find_first(
where={"user_email": user_email}
@ -8039,7 +8235,7 @@ async def auth_callback(request: Request):
if user_info is not None:
user_defined_values = {
"models": getattr(user_info, "models", user_id_models),
"user_id": getattr(user_info, "user_id", user_id),
"user_id": user_id,
"user_email": getattr(user_info, "user_id", user_email),
"user_role": getattr(user_info, "user_role", None),
}
@ -8053,9 +8249,7 @@ async def auth_callback(request: Request):
litellm.default_user_params, dict
):
user_defined_values = {
"models": litellm.default_user_params.get(
"models", user_id_models
),
"models": litellm.default_user_params.get("models", user_id_models),
"user_id": litellm.default_user_params.get("user_id", user_id),
"user_email": litellm.default_user_params.get(
"user_email", user_email
@ -8072,6 +8266,10 @@ async def auth_callback(request: Request):
)
key = response["token"] # type: ignore
user_id = response["user_id"] # type: ignore
# This should always be true
# User_id on SSO == user_id in the LiteLLM_VerificationToken Table
assert user_id == _user_id_from_sso
litellm_dashboard_ui = "/ui/"
user_role = user_role or "app_owner"
if (
@ -8137,10 +8335,12 @@ async def update_config(config_info: ConfigYAML):
updated_general_settings = config_info.general_settings.dict(
exclude_none=True
)
config["general_settings"] = {
**updated_general_settings,
**config["general_settings"],
}
_existing_settings = config["general_settings"]
for k, v in updated_general_settings.items():
# overwrite existing settings with updated values
_existing_settings[k] = v
config["general_settings"] = _existing_settings
if config_info.environment_variables is not None:
config.setdefault("environment_variables", {})
@ -8188,6 +8388,16 @@ async def update_config(config_info: ConfigYAML):
"success_callback"
] = combined_success_callback
# router settings
if config_info.router_settings is not None:
config.setdefault("router_settings", {})
_updated_router_settings = config_info.router_settings
config["router_settings"] = {
**config["router_settings"],
**_updated_router_settings,
}
# Save the updated config
await proxy_config.save_config(new_config=config)
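A minimal illustration of the merge order above: incoming general_settings values overwrite existing keys, and router_settings is a shallow dict merge (placeholder values):
existing_general = {"alerting_threshold": 300, "ui_access_mode": "all"}
updated_general = {"alerting_threshold": 120}
for k, v in updated_general.items():
    existing_general[k] = v  # updated values win
assert existing_general == {"alerting_threshold": 120, "ui_access_mode": "all"}

existing_router = {"routing_strategy": "usage-based-routing-v2"}
updated_router = {"num_retries": 3}
merged_router = {**existing_router, **updated_router}
assert merged_router["routing_strategy"] == "usage-based-routing-v2"
assert merged_router["num_retries"] == 3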
@ -8303,9 +8513,25 @@ async def get_config():
)
_slack_env_vars[_var] = _decrypted_value
_data_to_return.append({"name": "slack", "variables": _slack_env_vars})
_alerting_types = proxy_logging_obj.slack_alerting_instance.alert_types
_all_alert_types = (
proxy_logging_obj.slack_alerting_instance._all_possible_alert_types()
)
_data_to_return.append(
{
"name": "slack",
"variables": _slack_env_vars,
"alerting_types": _alerting_types,
"all_alert_types": _all_alert_types,
}
)
return {"status": "success", "data": _data_to_return}
_router_settings = llm_router.get_settings()
return {
"status": "success",
"data": _data_to_return,
"router_settings": _router_settings,
}
except Exception as e:
traceback.print_exc()
if isinstance(e, HTTPException):

View file

@ -1,5 +1,5 @@
from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx, time
import litellm, backoff
from litellm.proxy._types import (
UserAPIKeyAuth,
@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
)
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm import ModelResponse, EmbeddingResponse, ImageResponse
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
@ -30,6 +31,7 @@ import smtplib, re
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime, timedelta
from litellm.integrations.slack_alerting import SlackAlerting
def print_verbose(print_statement):
@ -64,27 +66,70 @@ class ProxyLogging:
self.cache_control_check = _PROXY_CacheControlCheck()
self.alerting: Optional[List] = None
self.alerting_threshold: float = 300 # default to 5 min. threshold
self.alert_types: List[
Literal[
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
]
] = [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
]
self.slack_alerting_instance = SlackAlerting(
alerting_threshold=self.alerting_threshold,
alerting=self.alerting,
alert_types=self.alert_types,
)
def update_values(
self,
alerting: Optional[List],
alerting_threshold: Optional[float],
redis_cache: Optional[RedisCache],
alert_types: Optional[
List[
Literal[
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
]
]
] = None,
):
self.alerting = alerting
if alerting_threshold is not None:
self.alerting_threshold = alerting_threshold
if alert_types is not None:
self.alert_types = alert_types
self.slack_alerting_instance.update_values(
alerting=self.alerting,
alerting_threshold=self.alerting_threshold,
alert_types=self.alert_types,
)
if redis_cache is not None:
self.internal_usage_cache.redis_cache = redis_cache
def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
self.service_logging_obj = ServiceLogging()
litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_tpm_rpm_limiter)
litellm.callbacks.append(self.max_budget_limiter)
litellm.callbacks.append(self.cache_control_check)
litellm.success_callback.append(self.response_taking_too_long_callback)
litellm.callbacks.append(self.service_logging_obj)
litellm.success_callback.append(
self.slack_alerting_instance.response_taking_too_long_callback
)
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
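An illustrative runtime call narrowing alerts to a subset of types; this assumes ProxyLogging only needs a user_api_key_cache to initialize, and the values are examples:
from litellm.caching import DualCache
from litellm.proxy.utils import ProxyLogging

proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
proxy_logging_obj.update_values(
    alerting=["slack"],
    alerting_threshold=120,
    redis_cache=None,
    alert_types=["llm_exceptions", "budget_alerts"],
)
# Hooks below skip anything not listed here (e.g. "llm_too_slow").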
@ -133,7 +178,9 @@ class ProxyLogging:
"""
print_verbose(f"Inside Proxy Logging Pre-call hook!")
### ALERTING ###
asyncio.create_task(self.response_taking_too_long(request_data=data))
asyncio.create_task(
self.slack_alerting_instance.response_taking_too_long(request_data=data)
)
try:
for callback in litellm.callbacks:
@ -182,110 +229,6 @@ class ProxyLogging:
raise e
return data
def _response_taking_too_long_callback(
self,
kwargs, # kwargs to completion
start_time,
end_time, # start/end time
):
try:
time_difference = end_time - start_time
# Convert the timedelta to float (in seconds)
time_difference_float = time_difference.total_seconds()
litellm_params = kwargs.get("litellm_params", {})
api_base = litellm_params.get("api_base", "")
model = kwargs.get("model", "")
messages = kwargs.get("messages", "")
return time_difference_float, model, api_base, messages
except Exception as e:
raise e
async def response_taking_too_long_callback(
self,
kwargs, # kwargs to completion
completion_response, # response from completion
start_time,
end_time, # start/end time
):
if self.alerting is None:
return
time_difference_float, model, api_base, messages = (
self._response_taking_too_long_callback(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
)
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold:
await self.alerting_handler(
message=slow_message + request_info,
level="Low",
)
async def response_taking_too_long(
self,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
type: Literal["hanging_request", "slow_response"] = "hanging_request",
request_data: Optional[dict] = None,
):
if request_data is not None:
model = request_data.get("model", "")
messages = request_data.get("messages", "")
trace_id = request_data.get("metadata", {}).get(
"trace_id", None
) # get langfuse trace id
if trace_id is not None:
messages = str(messages)
messages = messages[:100]
messages = f"{messages}\nLangfuse Trace Id: {trace_id}"
else:
# try casting messages to str and get the first 100 characters, else mark as None
try:
messages = str(messages)
messages = messages[:100]
except:
messages = None
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
else:
request_info = ""
if type == "hanging_request":
# Simulate a long-running operation that could take more than 5 minutes
await asyncio.sleep(
self.alerting_threshold
) # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
if (
request_data is not None
and request_data.get("litellm_status", "") != "success"
):
if request_data.get("deployment", None) is not None and isinstance(
request_data["deployment"], dict
):
_api_base = litellm.get_api_base(
model=model,
optional_params=request_data["deployment"].get(
"litellm_params", {}
),
)
if _api_base is None:
_api_base = ""
request_info += f"\nAPI Base: {_api_base}"
# only alert hanging responses if they have not been marked as success
alerting_message = (
f"`Requests are hanging - {self.alerting_threshold}s+ request time`"
)
await self.alerting_handler(
message=alerting_message + request_info,
level="Medium",
)
async def budget_alerts(
self,
type: Literal[
@ -304,106 +247,13 @@ class ProxyLogging:
if self.alerting is None:
# do nothing if alerting is not switched on
return
_id: str = "default_id" # used for caching
if type == "user_and_proxy_budget":
user_info = dict(user_info)
user_id = user_info["user_id"]
_id = user_id
max_budget = user_info["max_budget"]
spend = user_info["spend"]
user_email = user_info["user_email"]
user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
elif type == "token_budget":
token_info = dict(user_info)
token = token_info["token"]
_id = token
spend = token_info["spend"]
max_budget = token_info["max_budget"]
user_id = token_info["user_id"]
user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
elif type == "failed_tracking":
user_id = str(user_info)
_id = user_id
user_info = f"\nUser ID: {user_id}\n Error {error_message}"
message = "Failed Tracking Cost for" + user_info
await self.alerting_handler(
message=message,
level="High",
await self.slack_alerting_instance.budget_alerts(
type=type,
user_max_budget=user_max_budget,
user_current_spend=user_current_spend,
user_info=user_info,
error_message=error_message,
)
return
elif type == "projected_limit_exceeded" and user_info is not None:
"""
Input variables:
user_info = {
"key_alias": key_alias,
"projected_spend": projected_spend,
"projected_exceeded_date": projected_exceeded_date,
}
user_max_budget=soft_limit,
user_current_spend=new_spend
"""
message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` {user_info["key_alias"]} \n`Expected Day of Error`: {user_info["projected_exceeded_date"]} \n`Current Spend`: {user_current_spend} \n`Projected Spend at end of month`: {user_info["projected_spend"]} \n`Soft Limit`: {user_max_budget}"""
await self.alerting_handler(
message=message,
level="High",
)
return
else:
user_info = str(user_info)
# percent of max_budget left to spend
if user_max_budget > 0:
percent_left = (user_max_budget - user_current_spend) / user_max_budget
else:
percent_left = 0
verbose_proxy_logger.debug(
f"Budget Alerts: Percent left: {percent_left} for {user_info}"
)
# check if crossed budget
if user_current_spend >= user_max_budget:
verbose_proxy_logger.debug("Budget Crossed for %s", user_info)
message = "Budget Crossed for" + user_info
await self.alerting_handler(
message=message,
level="High",
)
return
## PREVENTITIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
# - Alert once within 28d period
# - Cache this information
# - Don't re-alert, if alert already sent
_cache: DualCache = self.internal_usage_cache
# check if 5% of max budget is left
if percent_left <= 0.05:
message = "5% budget left for" + user_info
cache_key = "alerting:{}".format(_id)
result = await _cache.async_get_cache(key=cache_key)
if result is None:
await self.alerting_handler(
message=message,
level="Medium",
)
await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)
return
# check if 15% of max budget is left
if percent_left <= 0.15:
message = "15% budget left for" + user_info
result = await _cache.async_get_cache(key=message)
if result is None:
await self.alerting_handler(
message=message,
level="Low",
)
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
return
return
async def alerting_handler(
self, message: str, level: Literal["Low", "Medium", "High"]
@ -422,44 +272,42 @@ class ProxyLogging:
level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
message: str - what is the alert about
"""
if self.alerting is None:
return
from datetime import datetime
# Get the current timestamp
current_time = datetime.now().strftime("%H:%M:%S")
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
formatted_message = (
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
)
if self.alerting is None:
return
if _proxy_base_url is not None:
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
for client in self.alerting:
if client == "slack":
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
if slack_webhook_url is None:
raise Exception("Missing SLACK_WEBHOOK_URL from environment")
payload = {"text": formatted_message}
headers = {"Content-type": "application/json"}
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(ssl=False)
) as session:
async with session.post(
slack_webhook_url, json=payload, headers=headers
) as response:
if response.status == 200:
pass
await self.slack_alerting_instance.send_alert(
message=message, level=level
)
elif client == "sentry":
if litellm.utils.sentry_sdk_instance is not None:
litellm.utils.sentry_sdk_instance.capture_message(formatted_message)
else:
raise Exception("Missing SENTRY_DSN from environment")
async def failure_handler(self, original_exception, traceback_str=""):
async def failure_handler(
self, original_exception, duration: float, call_type: str, traceback_str=""
):
"""
Log failed db read/writes
Currently only logs exceptions to sentry
"""
### ALERTING ###
if "db_exceptions" not in self.alert_types:
return
if isinstance(original_exception, HTTPException):
if isinstance(original_exception.detail, str):
error_message = original_exception.detail
@ -478,6 +326,14 @@ class ProxyLogging:
)
)
if hasattr(self, "service_logging_obj"):
self.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.DB,
duration=duration,
error=error_message,
call_type=call_type,
)
if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception)
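The duration/call_type plumbing added throughout the Prisma wrappers below follows one pattern; a stand-alone sketch (proxy_logging_obj and do_query are placeholders for the real objects):
import asyncio
import time

async def timed_db_call(proxy_logging_obj, do_query):
    # Times the db call and reports failures via the new failure_handler signature.
    start_time = time.time()
    try:
        return await do_query()
    except Exception as e:
        _duration = time.time() - start_time
        asyncio.create_task(  # fire-and-forget, like the wrappers below
            proxy_logging_obj.failure_handler(
                original_exception=e,
                duration=_duration,
                call_type="get_data",
                traceback_str="",
            )
        )
        raise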
@ -494,6 +350,8 @@ class ProxyLogging:
"""
### ALERTING ###
if "llm_exceptions" not in self.alert_types:
return
asyncio.create_task(
self.alerting_handler(
message=f"LLM API call failed: {str(original_exception)}", level="High"
@ -798,6 +656,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: get_generic_data: {key}, table_name: {table_name}"
)
start_time = time.time()
try:
if table_name == "users":
response = await self.db.litellm_usertable.find_first(
@ -822,11 +681,17 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
traceback_str=error_traceback,
call_type="get_generic_data",
)
)
raise e
@backoff.on_exception(
@ -864,6 +729,7 @@ class PrismaClient:
] = None, # pagination, number of rows to getch when find_all==True
):
args_passed_in = locals()
start_time = time.time()
verbose_proxy_logger.debug(
f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
)
@ -1011,9 +877,21 @@ class PrismaClient:
},
)
else:
response = await self.db.litellm_usertable.find_many( # type: ignore
order={"spend": "desc"}, take=limit, skip=offset
)
# return all users in the table, get their key aliases ordered by spend
sql_query = """
SELECT
u.*,
json_agg(v.key_alias) AS key_aliases
FROM
"LiteLLM_UserTable" u
LEFT JOIN "LiteLLM_VerificationToken" v ON u.user_id = v.user_id
GROUP BY
u.user_id
ORDER BY u.spend DESC
LIMIT $1
OFFSET $2
"""
response = await self.db.query_raw(sql_query, limit, offset)
return response
elif table_name == "spend":
verbose_proxy_logger.debug(
@ -1053,6 +931,8 @@ class PrismaClient:
response = await self.db.litellm_teamtable.find_many(
where={"team_id": {"in": team_id_list}}
)
elif query_type == "find_all" and team_id_list is None:
response = await self.db.litellm_teamtable.find_many(take=20)
return response
elif table_name == "user_notification":
if query_type == "find_unique":
@ -1088,6 +968,7 @@ class PrismaClient:
t.rpm_limit AS team_rpm_limit,
t.models AS team_models,
t.blocked AS team_blocked,
t.team_alias AS team_alias,
m.aliases as team_model_aliases
FROM "LiteLLM_VerificationToken" AS v
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
@ -1117,9 +998,15 @@ class PrismaClient:
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
verbose_proxy_logger.debug(error_traceback)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="get_data",
traceback_str=error_traceback,
)
)
raise e
@ -1142,6 +1029,7 @@ class PrismaClient:
"""
Add a key to the database. If it already exists, do nothing.
"""
start_time = time.time()
try:
verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data)
if table_name == "key":
@ -1259,9 +1147,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="insert_data",
traceback_str=error_traceback,
)
)
raise e
@ -1292,6 +1185,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: update_data, table_name: {table_name}"
)
start_time = time.time()
try:
db_data = self.jsonify_object(data=data)
if update_key_values is not None:
@ -1453,9 +1347,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_data",
traceback_str=error_traceback,
)
)
raise e
@ -1480,6 +1379,7 @@ class PrismaClient:
Ensure user owns that key, unless admin.
"""
start_time = time.time()
try:
if tokens is not None and isinstance(tokens, List):
hashed_tokens = []
@ -1527,9 +1427,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="delete_data",
traceback_str=error_traceback,
)
)
raise e
@ -1543,6 +1448,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def connect(self):
start_time = time.time()
try:
verbose_proxy_logger.debug(
"PrismaClient: connect() called Attempting to Connect to DB"
@ -1558,9 +1464,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="connect",
traceback_str=error_traceback,
)
)
raise e
@ -1574,6 +1485,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def disconnect(self):
start_time = time.time()
try:
await self.db.disconnect()
except Exception as e:
@ -1582,9 +1494,14 @@ class PrismaClient:
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="disconnect",
traceback_str=error_traceback,
)
)
raise e
@ -1593,6 +1510,8 @@ class PrismaClient:
"""
Health check endpoint for the prisma client
"""
start_time = time.time()
try:
sql_query = """
SELECT 1
FROM "LiteLLM_VerificationToken"
@ -1603,6 +1522,23 @@ class PrismaClient:
# The asterisk before `user_id_list` unpacks the list into separate arguments
response = await self.db.query_raw(sql_query)
return response
except Exception as e:
import traceback
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.proxy_logging_obj.failure_handler(
original_exception=e,
duration=_duration,
call_type="health_check",
traceback_str=error_traceback,
)
)
raise e
class DBClient:
@ -1978,6 +1914,7 @@ async def update_spend(
### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2008,9 +1945,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2018,6 +1960,7 @@ async def update_spend(
### UPDATE END-USER TABLE ###
if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2054,9 +1997,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2064,6 +2012,7 @@ async def update_spend(
### UPDATE KEY TABLE ###
if len(prisma_client.key_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2094,9 +2043,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2109,6 +2063,7 @@ async def update_spend(
)
if len(prisma_client.team_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2144,9 +2099,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2154,6 +2114,7 @@ async def update_spend(
### UPDATE ORG TABLE ###
if len(prisma_client.org_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
@ -2184,9 +2145,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
@ -2201,6 +2167,7 @@ async def update_spend(
if len(prisma_client.spend_log_transactions) > 0:
for _ in range(n_retry_times + 1):
start_time = time.time()
try:
base_url = os.getenv("SPEND_LOGS_URL", None)
## WRITE TO SEPARATE SERVER ##
@ -2266,9 +2233,14 @@ async def update_spend(
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
original_exception=e,
duration=_duration,
call_type="update_spend",
traceback_str=error_traceback,
)
)
raise e
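
Every hunk in this file follows the same pattern: capture start_time before the try block, compute the elapsed duration in the except block, and hand the failure to proxy_logging_obj.failure_handler on a background task so the caller still sees the exception immediately. A minimal standalone sketch of that pattern, with a stub class standing in for the real ProxyLogging:

import asyncio, time, traceback

class _StubProxyLogging:  # stand-in for litellm.proxy.utils.ProxyLogging (illustrative only)
    async def failure_handler(self, original_exception, duration, call_type, traceback_str):
        # e.g. fire an alert / write the failure to a log store
        print(f"{call_type} failed after {duration:.3f}s: {original_exception}")

async def timed_db_call(proxy_logging_obj, call_type: str, coro):
    start_time = time.time()
    try:
        return await coro
    except Exception as e:
        error_traceback = str(e) + "\n" + traceback.format_exc()
        _duration = time.time() - start_time
        # schedule the handler without awaiting it, so the exception propagates right away
        asyncio.create_task(
            proxy_logging_obj.failure_handler(
                original_exception=e,
                duration=_duration,
                call_type=call_type,
                traceback_str=error_traceback,
            )
        )
        raise e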

View file

@ -26,11 +26,17 @@ from litellm.llms.custom_httpx.azure_dall_e_2 import (
CustomHTTPTransport,
AsyncCustomHTTPTransport,
)
from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
from litellm.utils import (
ModelResponse,
CustomStreamWrapper,
get_utc_datetime,
calculate_max_parallel_requests,
)
import copy
from litellm._logging import verbose_router_logger
import logging
from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
from litellm.integrations.custom_logger import CustomLogger
class Router:
@ -60,6 +66,7 @@ class Router:
num_retries: int = 0,
timeout: Optional[float] = None,
default_litellm_params={}, # default params for Router.chat.completion.create
default_max_parallel_requests: Optional[int] = None,
set_verbose: bool = False,
debug_level: Literal["DEBUG", "INFO"] = "INFO",
fallbacks: List = [],
@ -197,13 +204,18 @@ class Router:
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
self.default_max_parallel_requests = default_max_parallel_requests
if model_list:
if model_list is not None:
model_list = copy.deepcopy(model_list)
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list
self.healthy_deployments: List = self.model_list # type: ignore
for m in model_list:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
else:
self.model_list: List = (
[]
) # initialize an empty list - to allow _add_deployment and delete_deployment to work
self.allowed_fails = allowed_fails or litellm.allowed_fails
self.cooldown_time = cooldown_time or 1
@ -212,6 +224,7 @@ class Router:
) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
self.num_retries = num_retries or litellm.num_retries or 0
self.timeout = timeout or litellm.request_timeout
self.retry_after = retry_after
self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks
@ -297,8 +310,9 @@ class Router:
else:
litellm.failure_callback = [self.deployment_callback_on_failure]
verbose_router_logger.info(
f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}"
f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
)
self.routing_strategy_args = routing_strategy_args
def print_deployment(self, deployment: dict):
"""
@ -350,6 +364,7 @@ class Router:
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"api_base": deployment.get("litellm_params", {}).get("api_base"),
"model_info": deployment.get("model_info", {}),
}
)
@ -377,6 +392,9 @@ class Router:
else:
model_client = potential_model_client
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.completion(
**{
**data,
@ -389,6 +407,7 @@ class Router:
verbose_router_logger.info(
f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
@ -437,6 +456,7 @@ class Router:
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
"api_base": deployment.get("litellm_params", {}).get("api_base"),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
@ -488,21 +508,25 @@ class Router:
)
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if (
rpm_semaphore is not None
and isinstance(rpm_semaphore, asyncio.Semaphore)
and self.routing_strategy == "usage-based-routing-v2"
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await _response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await _response
self.success_calls[model_name] += 1
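
The change above gates the awaited completion behind a deployment-level asyncio.Semaphore, so at most max_parallel_requests calls are in flight and the routing-strategy pre-call check runs inside the critical section. A stripped-down sketch of that control flow (call_with_limit, _demo and fake_llm_call are illustrative names, not router APIs):

import asyncio
from typing import Optional

async def call_with_limit(rpm_semaphore: Optional[asyncio.Semaphore], make_call):
    # acquire the per-deployment semaphore (if one was initialized) before awaiting the call
    if rpm_semaphore is not None and isinstance(rpm_semaphore, asyncio.Semaphore):
        async with rpm_semaphore:
            # pre-call checks (rpm increment, limit enforcement) would run here, inside the semaphore
            return await make_call()
    return await make_call()

async def _demo():
    sem = asyncio.Semaphore(2)  # at most 2 in-flight calls for this deployment

    async def fake_llm_call():
        await asyncio.sleep(0.1)
        return "ok"

    print(await asyncio.gather(*[call_with_limit(sem, fake_llm_call) for _ in range(5)]))

asyncio.run(_demo())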
@ -577,6 +601,10 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.image_generation(
**{
**data,
@ -655,7 +683,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.aimage_generation(
response = litellm.aimage_generation(
**{
**data,
"prompt": prompt,
@ -664,6 +692,30 @@ class Router:
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m"
@ -755,7 +807,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.atranscription(
response = litellm.atranscription(
**{
**data,
"file": file,
@ -764,6 +816,30 @@ class Router:
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
@ -950,6 +1026,7 @@ class Router:
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
"api_base": deployment.get("litellm_params", {}).get("api_base"),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
@ -977,7 +1054,8 @@ class Router:
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.atext_completion(
response = litellm.atext_completion(
**{
**data,
"prompt": prompt,
@ -987,6 +1065,29 @@ class Router:
**kwargs,
}
)
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m"
@ -1061,6 +1162,10 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.embedding(
**{
**data,
@ -1117,6 +1222,7 @@ class Router:
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
"api_base": deployment.get("litellm_params", {}).get("api_base"),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
@ -1145,7 +1251,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.aembedding(
response = litellm.aembedding(
**{
**data,
"input": input,
@ -1154,6 +1260,30 @@ class Router:
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m"
@ -1711,6 +1841,38 @@ class Router:
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def routing_strategy_pre_call_checks(self, deployment: dict):
"""
Mimics 'async_routing_strategy_pre_call_checks'
Ensures consistent update rpm implementation for 'usage-based-routing-v2'
Returns:
- None
Raises:
- Rate Limit Exception - If the deployment is over its tpm/rpm limits
"""
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
response = _callback.pre_call_check(deployment)
async def async_routing_strategy_pre_call_checks(self, deployment: dict):
"""
For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
-> makes the calls concurrency-safe, when rpm limits are set for a deployment
Returns:
- None
Raises:
- Rate Limit Exception - If the deployment is over its tpm/rpm limits
"""
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
response = await _callback.async_pre_call_check(deployment)
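
These two hooks dispatch to any CustomLogger registered in litellm.callbacks. A hedged sketch of what such a callback could look like; the method names mirror the calls above, while the class name and bodies are purely illustrative:

from litellm.integrations.custom_logger import CustomLogger

class MyPreCallGuard(CustomLogger):  # illustrative subclass, not part of this diff
    def pre_call_check(self, deployment: dict):
        # inspect deployment["litellm_params"] and raise litellm.RateLimitError to veto the call
        return deployment

    async def async_pre_call_check(self, deployment: dict):
        return deployment

# litellm.callbacks.append(MyPreCallGuard())  # the router then invokes these hooks on every call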
def set_client(self, model: dict):
"""
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
@ -1722,17 +1884,23 @@ class Router:
model_id = model["model_info"]["id"]
# ### IF RPM SET - initialize a semaphore ###
rpm = litellm_params.get("rpm", None)
if rpm:
semaphore = asyncio.Semaphore(rpm)
cache_key = f"{model_id}_rpm_client"
tpm = litellm_params.get("tpm", None)
max_parallel_requests = litellm_params.get("max_parallel_requests", None)
calculated_max_parallel_requests = calculate_max_parallel_requests(
rpm=rpm,
max_parallel_requests=max_parallel_requests,
tpm=tpm,
default_max_parallel_requests=self.default_max_parallel_requests,
)
if calculated_max_parallel_requests:
semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
cache_key = f"{model_id}_max_parallel_requests_client"
self.cache.set_cache(
key=cache_key,
value=semaphore,
local_only=True,
)
# print("STORES SEMAPHORE IN CACHE")
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
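
set_client now sizes the semaphore with calculate_max_parallel_requests (imported from litellm.utils) instead of using rpm directly. The helper itself isn't shown in this diff; the following is one plausible precedence, written as a sketch under that assumption:

from typing import Optional

def calculate_max_parallel_requests_sketch(
    max_parallel_requests: Optional[int],
    rpm: Optional[int],
    tpm: Optional[int],
    default_max_parallel_requests: Optional[int],
) -> Optional[int]:
    # assumed precedence -- the real helper lives in litellm.utils and may differ
    if max_parallel_requests is not None:
        return max_parallel_requests  # explicit per-deployment cap wins
    if rpm is not None:
        return rpm                    # previous behaviour: one slot per allowed request/minute
    if tpm is not None:
        return max(int(tpm / 1000), 1)  # rough guess derived from tokens-per-minute
    return default_max_parallel_requests  # router-level fallback (new __init__ arg above)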
@ -2271,11 +2439,19 @@ class Router:
return deployment
def add_deployment(self, deployment: Deployment):
def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
"""
Parameters:
- deployment: Deployment - the deployment to be added to the Router
Returns:
- The added deployment
- OR None (if deployment already exists)
"""
# check if deployment already exists
if deployment.model_info.id in self.get_model_ids():
return
return None
# add to model list
_deployment = deployment.to_json(exclude_none=True)
@ -2286,7 +2462,7 @@ class Router:
# add to model names
self.model_names.append(deployment.model_name)
return
return deployment
def delete_deployment(self, id: str) -> Optional[Deployment]:
"""
@ -2334,6 +2510,61 @@ class Router:
return self.model_list
return None
def get_settings(self):
"""
Get router settings method, returns a dictionary of the settings and their values.
For example get the set values for routing_strategy_args, routing_strategy, allowed_fails, cooldown_time, num_retries, timeout, max_retries, retry_after
"""
_all_vars = vars(self)
_settings_to_return = {}
vars_to_include = [
"routing_strategy_args",
"routing_strategy",
"allowed_fails",
"cooldown_time",
"num_retries",
"timeout",
"max_retries",
"retry_after",
]
for var in vars_to_include:
if var in _all_vars:
_settings_to_return[var] = _all_vars[var]
return _settings_to_return
def update_settings(self, **kwargs):
# only the following settings are allowed to be configured
_allowed_settings = [
"routing_strategy_args",
"routing_strategy",
"allowed_fails",
"cooldown_time",
"num_retries",
"timeout",
"max_retries",
"retry_after",
]
_int_settings = [
"timeout",
"num_retries",
"retry_after",
"allowed_fails",
"cooldown_time",
]
for var in kwargs:
if var in _allowed_settings:
if var in _int_settings:
_casted_value = int(kwargs[var])
setattr(self, var, _casted_value)
else:
setattr(self, var, kwargs[var])
else:
verbose_router_logger.debug("Setting {} is not allowed".format(var))
verbose_router_logger.debug(f"Updated Router settings: {self.get_settings()}")
def _get_client(self, deployment, kwargs, client_type=None):
"""
Returns the appropriate client based on the given deployment, kwargs, and client_type.
@ -2347,8 +2578,8 @@ class Router:
The appropriate client based on the given client_type and kwargs.
"""
model_id = deployment["model_info"]["id"]
if client_type == "rpm_client":
cache_key = "{}_rpm_client".format(model_id)
if client_type == "max_parallel_requests":
cache_key = "{}_max_parallel_requests_client".format(model_id)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
elif client_type == "async":
@ -2588,6 +2819,7 @@ class Router:
"""
if (
self.routing_strategy != "usage-based-routing-v2"
and self.routing_strategy != "simple-shuffle"
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
return self.get_available_deployment(
model=model,
@ -2638,7 +2870,46 @@ class Router:
messages=messages,
input=input,
)
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if we can do a RPM/TPM based weighted pick #################
rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
if rpm is not None:
# use weight-random pick if rpms provided
rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments]
verbose_router_logger.debug(f"\nrpms {rpms}")
total_rpm = sum(rpms)
weights = [rpm / total_rpm for rpm in rpms]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(range(len(rpms)), weights=weights)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## Check if we can do a RPM/TPM based weighted pick #################
tpm = healthy_deployments[0].get("litellm_params").get("tpm", None)
if tpm is not None:
# use weight-random pick if rpms provided
tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments]
verbose_router_logger.debug(f"\ntpms {tpms}")
total_tpm = sum(tpms)
weights = [tpm / total_tpm for tpm in tpms]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(range(len(tpms)), weights=weights)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## No RPM/TPM passed, we do a random pick #################
item = random.choice(healthy_deployments)
return item or item[0]
if deployment is None:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
@ -2649,6 +2920,7 @@ class Router:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
)
return deployment
def get_available_deployment(

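
The simple-shuffle branch ported above does a weight-random pick whenever rpm (or tpm) values are configured on the deployments. A condensed, standalone sketch of the same selection (function name and sample values are illustrative):

import random

def weighted_shuffle_pick(healthy_deployments, key="rpm"):
    weights = [d["litellm_params"].get(key, 0) for d in healthy_deployments]
    if sum(weights) == 0:
        return random.choice(healthy_deployments)  # no rpm/tpm set -> plain random pick
    # random.choices accepts unnormalized weights, so dividing by the total (as above) is optional
    idx = random.choices(range(len(healthy_deployments)), weights=weights)[0]
    return healthy_deployments[idx]

deployments = [
    {"litellm_params": {"model": "azure/gpt-35-a", "rpm": 900}},
    {"litellm_params": {"model": "azure/gpt-35-b", "rpm": 100}},
]
print(weighted_shuffle_pick(deployments))  # picks the 900-rpm deployment ~90% of the time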
View file

@ -39,7 +39,81 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
self.router_cache = router_cache
self.model_list = model_list
async def pre_call_rpm_check(self, deployment: dict) -> dict:
def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
"""
Pre-call check + update model rpm
Returns - deployment
Raises - RateLimitError if deployment over defined RPM limit
"""
try:
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
model_id = deployment.get("model_info", {}).get("id")
rpm_key = f"{model_id}:rpm:{current_minute}"
local_result = self.router_cache.get_cache(
key=rpm_key, local_only=True
) # check local result first
deployment_rpm = None
if deployment_rpm is None:
deployment_rpm = deployment.get("rpm")
if deployment_rpm is None:
deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = deployment.get("model_info", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = float("inf")
if local_result is not None and local_result >= deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, local_result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
local_result,
),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
else:
# if local result below limit, check redis ## prevent unnecessary redis checks
result = self.router_cache.increment_cache(key=rpm_key, value=1)
if result is not None and result > deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
result,
),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return deployment
except Exception as e:
if isinstance(e, litellm.RateLimitError):
raise e
return deployment # don't fail calls if eg. redis fails to connect
async def async_pre_call_check(self, deployment: Dict) -> Optional[Dict]:
"""
Pre-call check + update model rpm
- Used inside semaphore
@ -58,8 +132,8 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
model_group = deployment.get("model_name", "")
rpm_key = f"{model_group}:rpm:{current_minute}"
model_id = deployment.get("model_info", {}).get("id")
rpm_key = f"{model_id}:rpm:{current_minute}"
local_result = await self.router_cache.async_get_cache(
key=rpm_key, local_only=True
) # check local result first
@ -113,6 +187,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return deployment
except Exception as e:
if isinstance(e, litellm.RateLimitError):
@ -143,26 +218,18 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_key = f"{model_group}:tpm:{current_minute}"
rpm_key = f"{model_group}:rpm:{current_minute}"
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"{id}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
## TPM
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
## RPM
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
self.router_cache.increment_cache(key=tpm_key, value=total_tokens)
### TESTING ###
if self.test_flag:
self.logged_success += 1
@ -254,21 +321,26 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
for deployment in healthy_deployments:
tpm_dict[deployment["model_info"]["id"]] = 0
else:
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in tpm_dict:
tpm_dict[d["model_info"]["id"]] = 0
tpm_key = f"{d['model_info']['id']}:tpm:{current_minute}"
if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
tpm_dict[tpm_key] = 0
all_deployments = tpm_dict
deployment = None
for item, item_tpm in all_deployments.items():
## get the item from model list
_deployment = None
item = item.split(":")[0]
for m in healthy_deployments:
if item == m["model_info"]["id"]:
_deployment = m
if _deployment is None:
continue # skip to next one
@ -291,7 +363,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
_deployment_rpm = _deployment.get("model_info", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = float("inf")
if item_tpm + input_tokens > _deployment_tpm:
continue
elif (rpm_dict is not None and item in rpm_dict) and (
@ -336,13 +407,15 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
tpm_keys.append(tpm_key)
rpm_keys.append(rpm_key)
tpm_values = await self.router_cache.async_batch_get_cache(
keys=tpm_keys
) # [1, 2, None, ..]
rpm_values = await self.router_cache.async_batch_get_cache(
keys=rpm_keys
combined_tpm_rpm_keys = tpm_keys + rpm_keys
combined_tpm_rpm_values = await self.router_cache.async_batch_get_cache(
keys=combined_tpm_rpm_keys
) # [1, 2, None, ..]
tpm_values = combined_tpm_rpm_values[: len(tpm_keys)]
rpm_values = combined_tpm_rpm_values[len(tpm_keys) :]
return self._common_checks_available_deployment(
model_group=model_group,
healthy_deployments=healthy_deployments,

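
The limits in this file now key on the deployment id plus the current minute, and an increment is compared against the configured rpm. A toy in-memory version of that check (a plain dict stands in for the DualCache; the real code also re-checks the shared Redis value before allowing the call):

import time
from typing import Dict

def rpm_key_for(model_id: str) -> str:
    # keys are scoped to the current UTC minute, so counters roll over every minute
    current_minute = time.strftime("%H-%M", time.gmtime())
    return f"{model_id}:rpm:{current_minute}"

def allow_request(cache: Dict[str, int], model_id: str, deployment_rpm: float) -> bool:
    key = rpm_key_for(model_id)
    if cache.get(key, 0) >= deployment_rpm:
        return False  # over the per-deployment limit -> caller raises litellm.RateLimitError
    cache[key] = cache.get(key, 0) + 1  # increment before the call, like increment_cache above
    return True

local_cache: Dict[str, int] = {}
print(allow_request(local_cache, "model-123", deployment_rpm=2))  # True
print(allow_request(local_cache, "model-123", deployment_rpm=2))  # True
print(allow_request(local_cache, "model-123", deployment_rpm=2))  # False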
View file

@ -8,4 +8,4 @@ litellm_settings:
cache_params:
type: "redis"
supported_call_types: ["embedding", "aembedding"]
host: "localhost"
host: "os.environ/REDIS_HOST"

View file

@ -4,7 +4,7 @@
import sys
import os
import io, asyncio
from datetime import datetime
from datetime import datetime, timedelta
# import logging
# logging.basicConfig(level=logging.DEBUG)
@ -13,6 +13,10 @@ from litellm.proxy.utils import ProxyLogging
from litellm.caching import DualCache
import litellm
import pytest
import asyncio
from unittest.mock import patch, MagicMock
from litellm.caching import DualCache
from litellm.integrations.slack_alerting import SlackAlerting
@pytest.mark.asyncio
@ -43,7 +47,7 @@ async def test_get_api_base():
end_time = datetime.now()
time_difference_float, model, api_base, messages = (
_pl._response_taking_too_long_callback(
_pl.slack_alerting_instance._response_taking_too_long_callback(
kwargs={
"model": model,
"messages": messages,
@ -65,3 +69,27 @@ async def test_get_api_base():
message=slow_message + request_info,
level="Low",
)
print("passed test_get_api_base")
# Create a mock environment for testing
@pytest.fixture
def mock_env(monkeypatch):
monkeypatch.setenv("SLACK_WEBHOOK_URL", "https://example.com/webhook")
monkeypatch.setenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
monkeypatch.setenv("LANGFUSE_PROJECT_ID", "test-project-id")
# Test the __init__ method
def test_init():
slack_alerting = SlackAlerting(
alerting_threshold=32, alerting=["slack"], alert_types=["llm_exceptions"]
)
assert slack_alerting.alerting_threshold == 32
assert slack_alerting.alerting == ["slack"]
assert slack_alerting.alert_types == ["llm_exceptions"]
slack_no_alerting = SlackAlerting()
assert slack_no_alerting.alerting == []
print("passed testing slack alerting init")

View file

@ -90,7 +90,7 @@ def load_vertex_ai_credentials():
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary file
# Write the updated content to the temporary files
json.dump(service_account_key_data, temp_file, indent=2)
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
@ -145,6 +145,7 @@ def test_vertex_ai_anthropic():
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
# )
def test_vertex_ai_anthropic_streaming():
try:
# load_vertex_ai_credentials()
# litellm.set_verbose = True
@ -169,6 +170,10 @@ def test_vertex_ai_anthropic_streaming():
print(f"chunk: {chunk}")
# raise Exception("it worked!")
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_vertex_ai_anthropic_streaming()
@ -180,6 +185,7 @@ def test_vertex_ai_anthropic_streaming():
@pytest.mark.asyncio
async def test_vertex_ai_anthropic_async():
# load_vertex_ai_credentials()
try:
model = "claude-3-sonnet@20240229"
@ -197,6 +203,10 @@ async def test_vertex_ai_anthropic_async():
vertex_credentials=vertex_credentials,
)
print(f"Model Response: {response}")
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# asyncio.run(test_vertex_ai_anthropic_async())
@ -208,6 +218,7 @@ async def test_vertex_ai_anthropic_async():
@pytest.mark.asyncio
async def test_vertex_ai_anthropic_async_streaming():
# load_vertex_ai_credentials()
try:
litellm.set_verbose = True
model = "claude-3-sonnet@20240229"
@ -228,6 +239,10 @@ async def test_vertex_ai_anthropic_async_streaming():
async for chunk in response:
print(f"chunk: {chunk}")
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# asyncio.run(test_vertex_ai_anthropic_async_streaming())
@ -553,12 +568,19 @@ def test_gemini_pro_function_calling():
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
completion = litellm.completion(
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
)
print(f"completion: {completion}")
assert completion.choices[0].message.content is None
if hasattr(completion.choices[0].message, "tool_calls") and isinstance(
completion.choices[0].message.tool_calls, list
):
assert len(completion.choices[0].message.tool_calls) == 1
try:
load_vertex_ai_credentials()
@ -586,14 +608,22 @@ def test_gemini_pro_function_calling():
}
]
messages = [
{"role": "user", "content": "What's the weather like in Boston today?"}
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
completion = litellm.completion(
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
)
print(f"completion: {completion}")
assert completion.choices[0].message.content is None
# assert completion.choices[0].message.content is None ## GEMINI PRO is very chatty.
if hasattr(completion.choices[0].message, "tool_calls") and isinstance(
completion.choices[0].message.tool_calls, list
):
assert len(completion.choices[0].message.tool_calls) == 1
except litellm.APIError as e:
pass
except litellm.RateLimitError as e:
pass
except Exception as e:
@ -629,7 +659,12 @@ def test_gemini_pro_function_calling_streaming():
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
try:
completion = litellm.completion(
model="gemini-pro",
@ -643,6 +678,8 @@ def test_gemini_pro_function_calling_streaming():
# assert len(completion.choices[0].message.tool_calls) == 1
for chunk in completion:
print(f"chunk: {chunk}")
except litellm.APIError as e:
pass
except litellm.RateLimitError as e:
pass
@ -675,7 +712,10 @@ async def test_gemini_pro_async_function_calling():
}
]
messages = [
{"role": "user", "content": "What's the weather like in Boston today?"}
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
completion = await litellm.acompletion(
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
@ -683,6 +723,8 @@ async def test_gemini_pro_async_function_calling():
print(f"completion: {completion}")
assert completion.choices[0].message.content is None
assert len(completion.choices[0].message.tool_calls) == 1
except litellm.APIError as e:
pass
except litellm.RateLimitError as e:
pass
except Exception as e:

View file

@ -19,7 +19,7 @@ from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
_ENTERPRISE_BannedKeywords,
)
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
@ -36,6 +36,7 @@ async def test_banned_keywords_check():
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()

View file

@ -252,7 +252,10 @@ def test_bedrock_claude_3_tool_calling():
}
]
messages = [
{"role": "user", "content": "What's the weather like in Boston today?"}
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
response: ModelResponse = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
@ -266,6 +269,30 @@ def test_bedrock_claude_3_tool_calling():
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
messages.append(
response.choices[0].message.model_dump()
) # Add assistant tool invokes
tool_result = (
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
)
# Add user submitted tool results in the OpenAI format
messages.append(
{
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"role": "tool",
"name": response.choices[0].message.tool_calls[0].function.name,
"content": tool_result,
}
)
# In the second response, Claude should deduce answer from tool results
second_response = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
messages=messages,
tools=tools,
tool_choice="auto",
)
print(f"second response: {second_response}")
assert isinstance(second_response.choices[0].message.content, str)
except RateLimitError:
pass
except Exception as e:

View file

@ -20,7 +20,7 @@ from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
_ENTERPRISE_BlockedUserList,
)
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
@ -106,6 +106,7 @@ async def test_block_user_check(prisma_client):
)
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()

View file

@ -33,6 +33,51 @@ def generate_random_word(length=4):
messages = [{"role": "user", "content": "who is ishaan 5222"}]
@pytest.mark.asyncio
async def test_dual_cache_async_batch_get_cache():
"""
Unit testing for Dual Cache async_batch_get_cache()
- 2 item query
- in_memory result has a partial hit (1/2)
- hit redis for the other -> expect to return None
- expect result = [in_memory_result, None]
"""
from litellm.caching import DualCache, InMemoryCache, RedisCache
in_memory_cache = InMemoryCache()
redis_cache = RedisCache() # get credentials from environment
dual_cache = DualCache(in_memory_cache=in_memory_cache, redis_cache=redis_cache)
in_memory_cache.set_cache(key="test_value", value="hello world")
result = await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
assert result[0] == "hello world"
assert result[1] == None
def test_dual_cache_batch_get_cache():
"""
Unit testing for Dual Cache batch_get_cache()
- 2 item query
- in_memory result has a partial hit (1/2)
- hit redis for the other -> expect to return None
- expect result = [in_memory_result, None]
"""
from litellm.caching import DualCache, InMemoryCache, RedisCache
in_memory_cache = InMemoryCache()
redis_cache = RedisCache() # get credentials from environment
dual_cache = DualCache(in_memory_cache=in_memory_cache, redis_cache=redis_cache)
in_memory_cache.set_cache(key="test_value", value="hello world")
result = dual_cache.batch_get_cache(keys=["test_value", "test_value_2"])
assert result[0] == "hello world"
assert result[1] == None
# @pytest.mark.skip(reason="")
def test_caching_dynamic_args(): # test in memory cache
try:
@ -133,11 +178,17 @@ def test_caching_with_default_ttl():
pytest.fail(f"Error occurred: {e}")
def test_caching_with_cache_controls():
@pytest.mark.parametrize(
"sync_flag",
[True, False],
)
@pytest.mark.asyncio
async def test_caching_with_cache_controls(sync_flag):
try:
litellm.set_verbose = True
litellm.cache = Cache()
message = [{"role": "user", "content": f"Hey, how's it going? {uuid.uuid4()}"}]
if sync_flag:
## TTL = 0
response1 = completion(
model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
@ -145,11 +196,23 @@ def test_caching_with_cache_controls():
response2 = completion(
model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
)
print(f"response1: {response1}")
print(f"response2: {response2}")
assert response2["id"] != response1["id"]
else:
## TTL = 0
response1 = await litellm.acompletion(
model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
)
await asyncio.sleep(10)
response2 = await litellm.acompletion(
model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
)
assert response2["id"] != response1["id"]
message = [{"role": "user", "content": f"Hey, how's it going? {uuid.uuid4()}"}]
## TTL = 5
if sync_flag:
response1 = completion(
model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
)
@ -159,6 +222,17 @@ def test_caching_with_cache_controls():
print(f"response1: {response1}")
print(f"response2: {response2}")
assert response2["id"] == response1["id"]
else:
response1 = await litellm.acompletion(
model="gpt-3.5-turbo", messages=messages, cache={"ttl": 25}
)
await asyncio.sleep(10)
response2 = await litellm.acompletion(
model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 25}
)
print(f"response1: {response1}")
print(f"response2: {response2}")
assert response2["id"] == response1["id"]
except Exception as e:
print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}")
@ -390,6 +464,7 @@ async def test_embedding_caching_azure_individual_items_reordered():
@pytest.mark.asyncio
async def test_embedding_caching_base_64():
""" """
litellm.set_verbose = True
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
@ -408,6 +483,8 @@ async def test_embedding_caching_base_64():
caching=True,
encoding_format="base64",
)
await asyncio.sleep(5)
print("\n\nCALL2\n\n")
embedding_val_2 = await aembedding(
model="azure/azure-embedding-model",
input=inputs,
@ -1063,6 +1140,7 @@ async def test_cache_control_overrides():
"content": "hello who are you" + unique_num,
}
],
caching=True,
)
print(response1)
@ -1077,6 +1155,55 @@ async def test_cache_control_overrides():
"content": "hello who are you" + unique_num,
}
],
caching=True,
cache={"no-cache": True},
)
print(response2)
assert response1.id != response2.id
def test_sync_cache_control_overrides():
# we use the cache controls to ensure there is no cache hit on this test
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("Testing cache override")
litellm.set_verbose = True
import uuid
unique_num = str(uuid.uuid4())
start_time = time.time()
response1 = litellm.completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": "hello who are you" + unique_num,
}
],
caching=True,
)
print(response1)
time.sleep(2)
response2 = litellm.completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": "hello who are you" + unique_num,
}
],
caching=True,
cache={"no-cache": True},
)
@ -1094,10 +1221,6 @@ def test_custom_redis_cache_params():
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
db=0,
ssl=True,
ssl_certfile="./redis_user.crt",
ssl_keyfile="./redis_user_private.key",
ssl_ca_certs="./redis_ca.pem",
)
print(litellm.cache.cache.redis_client)
@ -1105,7 +1228,7 @@ def test_custom_redis_cache_params():
litellm.success_callback = []
litellm._async_success_callback = []
except Exception as e:
pytest.fail(f"Error occurred:", e)
pytest.fail(f"Error occurred: {str(e)}")
def test_get_cache_key():

View file

@ -33,6 +33,22 @@ def reset_callbacks():
litellm.callbacks = []
@pytest.mark.skip(reason="Local test")
def test_response_model_none():
"""
Addresses - https://github.com/BerriAI/litellm/issues/2972
"""
x = completion(
model="mymodel",
custom_llm_provider="openai",
messages=[{"role": "user", "content": "Hello!"}],
api_base="http://0.0.0.0:8080",
api_key="my-api-key",
)
print(f"x: {x}")
assert isinstance(x, litellm.ModelResponse)
def test_completion_custom_provider_model_name():
try:
litellm.cache = None
@ -167,7 +183,12 @@ def test_completion_claude_3_function_call():
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
try:
# test without max tokens
response = completion(
@ -376,7 +397,12 @@ def test_completion_claude_3_function_plus_image():
]
tool_choice = {"type": "function", "function": {"name": "get_current_weather"}}
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
response = completion(
model="claude-3-sonnet-20240229",
@ -389,6 +415,51 @@ def test_completion_claude_3_function_plus_image():
print(response)
def test_completion_azure_mistral_large_function_calling():
"""
This primarily tests if the 'Function()' pydantic object correctly handles argument param passed in as a dict vs. string
"""
litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
response = completion(
model="azure/mistral-large-latest",
api_base=os.getenv("AZURE_MISTRAL_API_BASE"),
api_key=os.getenv("AZURE_MISTRAL_API_KEY"),
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(response.choices[0].message.tool_calls[0].function.arguments, str)
def test_completion_mistral_api():
try:
litellm.set_verbose = True
@ -413,6 +484,76 @@ def test_completion_mistral_api():
pytest.fail(f"Error occurred: {e}")
def test_completion_mistral_api_mistral_large_function_call():
litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
try:
# test without max tokens
response = completion(
model="mistral/mistral-large-latest",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
messages.append(
response.choices[0].message.model_dump()
) # Add assistant tool invokes
tool_result = (
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
)
# Add user submitted tool results in the OpenAI format
messages.append(
{
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"role": "tool",
"name": response.choices[0].message.tool_calls[0].function.name,
"content": tool_result,
}
)
# In the second response, Mistral should deduce answer from tool results
second_response = completion(
model="mistral/mistral-large-latest",
messages=messages,
tools=tools,
tool_choice="auto",
)
print(second_response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(
reason="Since we already test mistral/mistral-tiny in test_completion_mistral_api. This is only for locally verifying azure mistral works"
)
@ -2418,7 +2559,7 @@ def test_completion_deep_infra_mistral():
# Gemini tests
def test_completion_gemini():
litellm.set_verbose = True
model_name = "gemini/gemini-pro"
model_name = "gemini/gemini-1.5-pro-latest"
messages = [{"role": "user", "content": "Hey, how's it going?"}]
try:
response = completion(model=model_name, messages=messages)

View file

@ -0,0 +1,279 @@
# What is this?
## Unit tests for ProxyConfig class
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os, io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, litellm
from pydantic import BaseModel
from litellm.proxy.proxy_server import ProxyConfig
from litellm.proxy.utils import encrypt_value, ProxyLogging, DualCache
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
from typing import Literal
class DBModel(BaseModel):
model_id: str
model_name: str
model_info: dict
litellm_params: dict
@pytest.mark.asyncio
async def test_delete_deployment():
"""
- Ensure the global llm router is not being reset
- Ensure invalid model is deleted
- Check if model id != model_info["id"], the model_info["id"] is picked
"""
import base64
litellm_params = LiteLLM_Params(
model="azure/chatgpt-v-2",
api_key=os.getenv("AZURE_API_KEY"),
api_base=os.getenv("AZURE_API_BASE"),
api_version=os.getenv("AZURE_API_VERSION"),
)
encrypted_litellm_params = litellm_params.dict(exclude_none=True)
master_key = "sk-1234"
setattr(litellm.proxy.proxy_server, "master_key", master_key)
for k, v in encrypted_litellm_params.items():
if isinstance(v, str):
encrypted_value = encrypt_value(v, master_key)
encrypted_litellm_params[k] = base64.b64encode(encrypted_value).decode(
"utf-8"
)
deployment = Deployment(model_name="gpt-3.5-turbo", litellm_params=litellm_params)
deployment_2 = Deployment(
model_name="gpt-3.5-turbo-2", litellm_params=litellm_params
)
llm_router = litellm.Router(
model_list=[
deployment.to_json(exclude_none=True),
deployment_2.to_json(exclude_none=True),
]
)
setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
print(f"llm_router: {llm_router}")
pc = ProxyConfig()
db_model = DBModel(
model_id=deployment.model_info.id,
model_name="gpt-3.5-turbo",
litellm_params=encrypted_litellm_params,
model_info={"id": deployment.model_info.id},
)
db_models = [db_model]
deleted_deployments = await pc._delete_deployment(db_models=db_models)
assert deleted_deployments == 1
assert len(llm_router.model_list) == 1
"""
Scenario 2 - if model id != model_info["id"]
"""
llm_router = litellm.Router(
model_list=[
deployment.to_json(exclude_none=True),
deployment_2.to_json(exclude_none=True),
]
)
print(f"llm_router: {llm_router}")
setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
pc = ProxyConfig()
db_model = DBModel(
model_id="12340523",
model_name="gpt-3.5-turbo",
litellm_params=encrypted_litellm_params,
model_info={"id": deployment.model_info.id},
)
db_models = [db_model]
deleted_deployments = await pc._delete_deployment(db_models=db_models)
assert deleted_deployments == 1
assert len(llm_router.model_list) == 1
@pytest.mark.asyncio
async def test_add_existing_deployment():
"""
- Only add new models
- don't re-add existing models
"""
import base64
litellm_params = LiteLLM_Params(
model="gpt-3.5-turbo",
api_key=os.getenv("AZURE_API_KEY"),
api_base=os.getenv("AZURE_API_BASE"),
api_version=os.getenv("AZURE_API_VERSION"),
)
deployment = Deployment(model_name="gpt-3.5-turbo", litellm_params=litellm_params)
deployment_2 = Deployment(
model_name="gpt-3.5-turbo-2", litellm_params=litellm_params
)
llm_router = litellm.Router(
model_list=[
deployment.to_json(exclude_none=True),
deployment_2.to_json(exclude_none=True),
]
)
print(f"llm_router: {llm_router}")
master_key = "sk-1234"
setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
setattr(litellm.proxy.proxy_server, "master_key", master_key)
pc = ProxyConfig()
encrypted_litellm_params = litellm_params.dict(exclude_none=True)
for k, v in encrypted_litellm_params.items():
if isinstance(v, str):
encrypted_value = encrypt_value(v, master_key)
encrypted_litellm_params[k] = base64.b64encode(encrypted_value).decode(
"utf-8"
)
db_model = DBModel(
model_id=deployment.model_info.id,
model_name="gpt-3.5-turbo",
litellm_params=encrypted_litellm_params,
model_info={"id": deployment.model_info.id},
)
db_models = [db_model]
num_added = pc._add_deployment(db_models=db_models)
assert num_added == 0
litellm_params = LiteLLM_Params(
model="azure/chatgpt-v-2",
api_key=os.getenv("AZURE_API_KEY"),
api_base=os.getenv("AZURE_API_BASE"),
api_version=os.getenv("AZURE_API_VERSION"),
)
deployment = Deployment(model_name="gpt-3.5-turbo", litellm_params=litellm_params)
deployment_2 = Deployment(model_name="gpt-3.5-turbo-2", litellm_params=litellm_params)
def _create_model_list(flag_value: Literal[0, 1], master_key: str):
"""
0 - empty list
1 - list with an element
"""
import base64
new_litellm_params = LiteLLM_Params(
model="azure/chatgpt-v-2-3",
api_key=os.getenv("AZURE_API_KEY"),
api_base=os.getenv("AZURE_API_BASE"),
api_version=os.getenv("AZURE_API_VERSION"),
)
encrypted_litellm_params = new_litellm_params.dict(exclude_none=True)
for k, v in encrypted_litellm_params.items():
if isinstance(v, str):
encrypted_value = encrypt_value(v, master_key)
encrypted_litellm_params[k] = base64.b64encode(encrypted_value).decode(
"utf-8"
)
db_model = DBModel(
model_id="12345",
model_name="gpt-3.5-turbo",
litellm_params=encrypted_litellm_params,
model_info={"id": "12345"},
)
db_models = [db_model]
if flag_value == 0:
return []
elif flag_value == 1:
return db_models
@pytest.mark.parametrize(
"llm_router",
[
None,
litellm.Router(),
litellm.Router(
model_list=[
deployment.to_json(exclude_none=True),
deployment_2.to_json(exclude_none=True),
]
),
],
)
@pytest.mark.parametrize(
"model_list_flag_value",
[0, 1],
)
@pytest.mark.asyncio
async def test_add_and_delete_deployments(llm_router, model_list_flag_value):
"""
Test add + delete logic in 3 scenarios
- when router is none
- when router is init but empty
- when router is init and not empty
"""
master_key = "sk-1234"
setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
setattr(litellm.proxy.proxy_server, "master_key", master_key)
pc = ProxyConfig()
pl = ProxyLogging(DualCache())
async def _monkey_patch_get_config(*args, **kwargs):
print(f"ENTERS MP GET CONFIG")
if llm_router is None:
return {}
else:
print(f"llm_router.model_list: {llm_router.model_list}")
return {"model_list": llm_router.model_list}
pc.get_config = _monkey_patch_get_config
model_list = _create_model_list(
flag_value=model_list_flag_value, master_key=master_key
)
if llm_router is None:
prev_llm_router_val = None
else:
prev_llm_router_val = len(llm_router.model_list)
await pc._update_llm_router(new_models=model_list, proxy_logging_obj=pl)
llm_router = getattr(litellm.proxy.proxy_server, "llm_router")
if model_list_flag_value == 0:
if prev_llm_router_val is None:
assert prev_llm_router_val == llm_router
else:
assert prev_llm_router_val == len(llm_router.model_list)
else:
if prev_llm_router_val is None:
assert len(llm_router.model_list) == len(model_list)
else:
assert len(llm_router.model_list) == len(model_list) + prev_llm_router_val
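
The encrypt-then-base64 loop appears three times in this new test file; a small helper capturing the shared pattern (a convenience sketch, not something the tests actually import):

import base64
from litellm.proxy.utils import encrypt_value

def encrypt_litellm_params(litellm_params: dict, master_key: str) -> dict:
    """Encrypt every string field so the params can be stored in the DB as text."""
    encrypted = dict(litellm_params)
    for k, v in encrypted.items():
        if isinstance(v, str):
            # encrypt_value returns bytes; base64-encode so the ciphertext fits a text column
            encrypted[k] = base64.b64encode(encrypt_value(v, master_key)).decode("utf-8")
    return encrypted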

View file

@ -412,7 +412,7 @@ async def test_cost_tracking_with_caching():
"""
from litellm import Cache
litellm.set_verbose = False
litellm.set_verbose = True
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],

View file

@ -536,6 +536,55 @@ def test_completion_openai_api_key_exception():
# tesy_async_acompletion()
def test_router_completion_vertex_exception():
try:
import litellm
litellm.set_verbose = True
router = litellm.Router(
model_list=[
{
"model_name": "vertex-gemini-pro",
"litellm_params": {
"model": "vertex_ai/gemini-pro",
"api_key": "good-morning",
},
},
]
)
response = router.completion(
model="vertex-gemini-pro",
messages=[{"role": "user", "content": "hello"}],
vertex_project="bad-project",
)
pytest.fail("Request should have failed - bad api key")
except Exception as e:
print("exception: ", e)
assert "model: vertex_ai/gemini-pro" in str(e)
assert "model_group: vertex-gemini-pro" in str(e)
assert "deployment: vertex_ai/gemini-pro" in str(e)
def test_litellm_completion_vertex_exception():
try:
import litellm
litellm.set_verbose = True
response = completion(
model="vertex_ai/gemini-pro",
api_key="good-morning",
messages=[{"role": "user", "content": "hello"}],
vertex_project="bad-project",
)
pytest.fail("Request should have failed - bad api key")
except Exception as e:
print("exception: ", e)
assert "model: vertex_ai/gemini-pro" in str(e)
assert "model_group" not in str(e)
assert "deployment" not in str(e)
# # test_invalid_request_error(model="command-nightly")
# # Test 3: Rate Limit Errors
# def test_model_call(model):

View file

@ -221,6 +221,9 @@ def test_parallel_function_call_stream():
# test_parallel_function_call_stream()
@pytest.mark.skip(
reason="Flaky test. Groq function calling is not reliable for ci/cd testing."
)
def test_groq_parallel_function_call():
litellm.set_verbose = True
try:
@ -266,9 +269,12 @@ def test_groq_parallel_function_call():
)
print("Response\n", response)
response_message = response.choices[0].message
if hasattr(response_message, "tool_calls"):
tool_calls = response_message.tool_calls
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.name, str
)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)

View file

@ -0,0 +1,33 @@
# What is this?
## Unit tests for the 'function_setup()' function
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os, io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, uuid
from litellm.utils import function_setup, Rules
from datetime import datetime
def test_empty_content():
"""
Make a chat completions request with empty content -> expect this to work
"""
rules_obj = Rules()
def completion():
pass
function_setup(
original_function=completion,
rules_obj=rules_obj,
start_time=datetime.now(),
messages=[],
litellm_call_id=str(uuid.uuid4()),
)

View file

@ -0,0 +1,25 @@
# What is this?
## Unit testing for the 'get_model_info()' function
import os, sys, traceback
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import get_model_info
def test_get_model_info_simple_model_name():
"""
tests that, if a plain model name is given and the model exists in the model info map, the object is returned
"""
model = "claude-3-opus-20240229"
litellm.get_model_info(model)
def test_get_model_info_custom_llm_with_model_name():
"""
Tests that, if a {custom_llm_provider}/{model_name} name is given and the model exists in the model info map, the object is returned
"""
model = "anthropic/claude-3-opus-20240229"
litellm.get_model_info(model)
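# A hypothetical follow-on check sketching what the lookup returns. The field values
# come from the claude-3-opus-20240229 entry in model_prices_and_context_window.json
# further down in this diff; the exact shape of the returned mapping is assumed here:
def test_get_model_info_returns_cost_fields():
    info = litellm.get_model_info("claude-3-opus-20240229")
    assert info["max_tokens"] == 4096
    assert info["mode"] == "chat"
    assert "input_cost_per_token" in info and "output_cost_per_token" in info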

View file

@ -120,6 +120,15 @@ async def test_new_user_response(prisma_client):
await litellm.proxy.proxy_server.prisma_client.connect()
from litellm.proxy.proxy_server import user_api_key_cache
await new_team(
NewTeamRequest(
team_id="ishaan-special-team",
),
user_api_key_dict=UserAPIKeyAuth(
user_role="proxy_admin", api_key="sk-1234", user_id="1234"
),
)
_response = await new_user(
data=NewUserRequest(
models=["azure-gpt-3.5"],
@ -999,10 +1008,32 @@ def test_generate_and_update_key(prisma_client):
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
# create team "litellm-core-infra@gmail.com""
print("creating team litellm-core-infra@gmail.com")
await new_team(
NewTeamRequest(
team_id="litellm-core-infra@gmail.com",
),
user_api_key_dict=UserAPIKeyAuth(
user_role="proxy_admin", api_key="sk-1234", user_id="1234"
),
)
await new_team(
NewTeamRequest(
team_id="ishaan-special-team",
),
user_api_key_dict=UserAPIKeyAuth(
user_role="proxy_admin", api_key="sk-1234", user_id="1234"
),
)
request = NewUserRequest(
metadata={"team": "litellm-team3", "project": "litellm-project3"},
metadata={"project": "litellm-project3"},
team_id="litellm-core-infra@gmail.com",
)
key = await new_user(request)
print(key)
@ -1015,7 +1046,6 @@ def test_generate_and_update_key(prisma_client):
print("\n info for key=", result["info"])
assert result["info"]["max_parallel_requests"] == None
assert result["info"]["metadata"] == {
"team": "litellm-team3",
"project": "litellm-project3",
}
assert result["info"]["team_id"] == "litellm-core-infra@gmail.com"
@ -1037,7 +1067,7 @@ def test_generate_and_update_key(prisma_client):
# update the team id
response2 = await update_key_fn(
request=Request,
data=UpdateKeyRequest(key=generated_key, team_id="ishaan"),
data=UpdateKeyRequest(key=generated_key, team_id="ishaan-special-team"),
)
print("response2=", response2)
@ -1048,11 +1078,10 @@ def test_generate_and_update_key(prisma_client):
print("\n info for key=", result["info"])
assert result["info"]["max_parallel_requests"] == None
assert result["info"]["metadata"] == {
"team": "litellm-team3",
"project": "litellm-project3",
}
assert result["info"]["models"] == ["ada", "babbage", "curie", "davinci"]
assert result["info"]["team_id"] == "ishaan"
assert result["info"]["team_id"] == "ishaan-special-team"
# cleanup - delete key
delete_key_request = KeyRequest(keys=[generated_key])
@ -1554,7 +1583,7 @@ async def test_view_spend_per_user(prisma_client):
first_user = user_by_spend[0]
print("\nfirst_user=", first_user)
assert first_user.spend > 0
assert first_user["spend"] > 0
except Exception as e:
print("Got Exception", e)
pytest.fail(f"Got exception {e}")
@ -1587,11 +1616,12 @@ async def test_key_name_null(prisma_client):
"""
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": False})
os.environ["DISABLE_KEY_NAME"] = "True"
await litellm.proxy.proxy_server.prisma_client.connect()
try:
request = GenerateKeyRequest()
key = await generate_key_fn(request)
print("generated key=", key)
generated_key = key.key
result = await info_key_fn(key=generated_key)
print("result from info_key_fn", result)
@ -1599,6 +1629,8 @@ async def test_key_name_null(prisma_client):
except Exception as e:
print("Got Exception", e)
pytest.fail(f"Got exception {e}")
finally:
os.environ["DISABLE_KEY_NAME"] = "False"
@pytest.mark.asyncio()
@ -1922,3 +1954,55 @@ async def test_proxy_load_test_db(prisma_client):
raise Exception(f"it worked! key={key.key}")
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.asyncio()
async def test_master_key_hashing(prisma_client):
try:
print("prisma client=", prisma_client)
master_key = "sk-1234"
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", master_key)
await litellm.proxy.proxy_server.prisma_client.connect()
from litellm.proxy.proxy_server import user_api_key_cache
await new_team(
NewTeamRequest(
team_id="ishaans-special-team",
),
user_api_key_dict=UserAPIKeyAuth(
user_role="proxy_admin", api_key="sk-1234", user_id="1234"
),
)
_response = await new_user(
data=NewUserRequest(
models=["azure-gpt-3.5"],
team_id="ishaans-special-team",
tpm_limit=20,
)
)
print(_response)
assert _response.models == ["azure-gpt-3.5"]
assert _response.team_id == "ishaans-special-team"
assert _response.tpm_limit == 20
bearer_token = "Bearer " + master_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# use generated key to auth in
result: UserAPIKeyAuth = await user_api_key_auth(
request=request, api_key=bearer_token
)
assert result.api_key == hash_token(master_key)
except Exception as e:
print("Got Exception", e)
pytest.fail(f"Got exception {e}")

View file

@ -18,7 +18,7 @@ import pytest
import litellm
from litellm.proxy.enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
@ -40,6 +40,7 @@ async def test_llm_guard_valid_response():
)
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()
@ -76,6 +77,7 @@ async def test_llm_guard_error_raising():
)
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()

View file

@ -16,7 +16,7 @@ sys.path.insert(
import pytest
import litellm
from litellm import Router
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache, RedisCache
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
@ -29,7 +29,7 @@ async def test_pre_call_hook_rpm_limits():
Test if error raised on hitting rpm limits
"""
litellm.set_verbose = True
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1)
local_cache = DualCache()
# redis_usage_cache = RedisCache()
@ -87,6 +87,7 @@ async def test_pre_call_hook_team_rpm_limits(
"team_id": _team_id,
}
user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict) # type: ignore
_api_key = hash_token(_api_key)
local_cache = DualCache()
local_cache.set_cache(key=_api_key, value=_user_api_key_dict)
internal_cache = DualCache(redis_cache=_redis_usage_cache)

View file

@ -15,7 +15,7 @@ sys.path.insert(
import pytest
import litellm
from litellm import Router
from litellm.proxy.utils import ProxyLogging
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import (
@ -34,6 +34,7 @@ async def test_pre_call_hook():
Test if cache updated on call being received
"""
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
@ -248,6 +249,7 @@ async def test_success_call_hook():
Test if on success, cache correctly decremented
"""
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
@ -289,6 +291,7 @@ async def test_failure_call_hook():
Test if on failure, cache correctly decremented
"""
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
@ -366,6 +369,7 @@ async def test_normal_router_call():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
pl = ProxyLogging(user_api_key_cache=local_cache)
@ -443,6 +447,7 @@ async def test_normal_router_tpm_limit():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=10, tpm_limit=10
)
@ -524,6 +529,7 @@ async def test_streaming_router_call():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
pl = ProxyLogging(user_api_key_cache=local_cache)
@ -599,6 +605,7 @@ async def test_streaming_router_tpm_limit():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=10, tpm_limit=10
)
@ -677,6 +684,7 @@ async def test_bad_router_call():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
local_cache = DualCache()
pl = ProxyLogging(user_api_key_cache=local_cache)
@ -750,6 +758,7 @@ async def test_bad_router_tpm_limit():
) # type: ignore
_api_key = "sk-12345"
_api_key = hash_token(_api_key)
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, max_parallel_requests=10, tpm_limit=10
)

View file

@ -65,23 +65,18 @@ async def test_completion_with_caching_bad_call():
- Assert failure callback gets called
"""
litellm.set_verbose = True
sl = ServiceLogging(mock_testing=True)
try:
litellm.cache = Cache(type="redis", host="hello-world")
from litellm.caching import RedisCache
litellm.service_callback = ["prometheus_system"]
sl = ServiceLogging(mock_testing=True)
litellm.cache.cache.service_logger_obj = sl
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response1 = await acompletion(
model="gpt-3.5-turbo", messages=messages, caching=True
)
response1 = await acompletion(
model="gpt-3.5-turbo", messages=messages, caching=True
)
RedisCache(host="hello-world", service_logger_obj=sl)
except Exception as e:
pass
print(f"Receives exception = {str(e)}")
await asyncio.sleep(5)
assert sl.mock_testing_async_failure_hook > 0
assert sl.mock_testing_async_success_hook == 0
assert sl.mock_testing_sync_success_hook == 0
@ -144,64 +139,3 @@ async def test_router_with_caching():
except Exception as e:
pytest.fail(f"An exception occured - {str(e)}")
@pytest.mark.asyncio
async def test_router_with_caching_bad_call():
"""
- Run completion with caching (incorrect credentials)
- Assert failure callback gets called
"""
try:
def get_azure_params(deployment_name: str):
params = {
"model": f"azure/{deployment_name}",
"api_key": os.environ["AZURE_API_KEY"],
"api_version": os.environ["AZURE_API_VERSION"],
"api_base": os.environ["AZURE_API_BASE"],
}
return params
model_list = [
{
"model_name": "azure/gpt-4",
"litellm_params": get_azure_params("chatgpt-v-2"),
"tpm": 100,
},
{
"model_name": "azure/gpt-4",
"litellm_params": get_azure_params("chatgpt-v-2"),
"tpm": 1000,
},
]
router = litellm.Router(
model_list=model_list,
set_verbose=True,
debug_level="DEBUG",
routing_strategy="usage-based-routing-v2",
redis_host="hello world",
redis_port=os.environ["REDIS_PORT"],
redis_password=os.environ["REDIS_PASSWORD"],
)
litellm.service_callback = ["prometheus_system"]
sl = ServiceLogging(mock_testing=True)
sl.prometheusServicesLogger.mock_testing = True
router.cache.redis_cache.service_logger_obj = sl
messages = [{"role": "user", "content": "Hey, how's it going?"}]
try:
response1 = await router.acompletion(model="azure/gpt-4", messages=messages)
response1 = await router.acompletion(model="azure/gpt-4", messages=messages)
except Exception as e:
pass
assert sl.mock_testing_async_failure_hook > 0
assert sl.mock_testing_async_success_hook == 0
assert sl.mock_testing_sync_success_hook == 0
except Exception as e:
pytest.fail(f"An exception occured - {str(e)}")

View file

@ -167,8 +167,9 @@ def test_chat_completion_exception_any_model(client):
openai_exception = openai_client._make_status_error_from_response(
response=response
)
print("Exception raised=", openai_exception)
assert isinstance(openai_exception, openai.BadRequestError)
_error_message = openai_exception.message
assert "Invalid model name passed in model=Lite-GPT-12" in str(_error_message)
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
@ -195,6 +196,8 @@ def test_embedding_exception_any_model(client):
)
print("Exception raised=", openai_exception)
assert isinstance(openai_exception, openai.BadRequestError)
_error_message = openai_exception.message
assert "Invalid model name passed in model=Lite-GPT-12" in str(_error_message)
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

View file

@ -362,7 +362,9 @@ def test_load_router_config():
] # init with all call types
except Exception as e:
pytest.fail("Proxy: Got exception reading config", e)
pytest.fail(
f"Proxy: Got exception reading config: {str(e)}\n{traceback.format_exc()}"
)
# test_load_router_config()

View file

@ -15,6 +15,61 @@ from litellm import Router
## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group
@pytest.mark.asyncio
async def test_router_async_caching_with_ssl_url():
"""
Tests that caching is correctly set up when a redis url is passed to the router
"""
try:
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 10000,
},
],
redis_url=os.getenv("REDIS_SSL_URL"),
)
response = await router.cache.redis_cache.ping()
print(f"response: {response}")
assert response == True
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
def test_router_sync_caching_with_ssl_url():
"""
Tests that caching is correctly set up when a redis url is passed to the router
"""
try:
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 10000,
},
],
redis_url=os.getenv("REDIS_SSL_URL"),
)
response = router.cache.redis_cache.sync_ping()
print(f"response: {response}")
assert response == True
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.asyncio
async def test_acompletion_caching_on_router():
# tests acompletion + caching on router

View file

@ -81,7 +81,7 @@ def test_async_fallbacks(caplog):
# Define the expected log messages
# - error request, falling back notice, success notice
expected_logs = [
"Intialized router with Routing strategy: simple-shuffle\n\nRouting fallbacks: [{'gpt-3.5-turbo': ['azure/gpt-3.5-turbo']}]\n\nRouting context window fallbacks: None",
"Intialized router with Routing strategy: simple-shuffle\n\nRouting fallbacks: [{'gpt-3.5-turbo': ['azure/gpt-3.5-turbo']}]\n\nRouting context window fallbacks: None\n\nRouter Redis Caching=None",
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m",
"Falling back to model_group = azure/gpt-3.5-turbo",
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",

View file

@ -512,3 +512,76 @@ async def test_wildcard_openai_routing():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
"""
Test async router get deployment (Simple-shuffle)
"""
rpm_list = [[None, None], [6, 1440]]
tpm_list = [[None, None], [6, 1440]]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"rpm_list, tpm_list",
[(rpm, tpm) for rpm in rpm_list for tpm in tpm_list],
)
async def test_weighted_selection_router_async(rpm_list, tpm_list):
# this tests if load balancing works based on the provided rpms in the router
# it's a fast test, only tests get_available_deployment
# users can pass rpms as a litellm_param
try:
litellm.set_verbose = False
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
"rpm": rpm_list[0],
"tpm": tpm_list[0],
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION"),
"rpm": rpm_list[1],
"tpm": tpm_list[1],
},
},
]
router = Router(
model_list=model_list,
)
selection_counts = defaultdict(int)
# call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time
for _ in range(1000):
selected_model = await router.async_get_available_deployment(
"gpt-3.5-turbo"
)
selected_model_id = selected_model["litellm_params"]["model"]
selected_model_name = selected_model_id
selection_counts[selected_model_name] += 1
print(selection_counts)
total_requests = sum(selection_counts.values())
if rpm_list[0] is not None or tpm_list[0] is not None:
# Assert that 'azure/chatgpt-v-2' has about 90% of the total requests
assert (
selection_counts["azure/chatgpt-v-2"] / total_requests > 0.89
), f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. Selection counts {selection_counts}"
else:
# Assert both are used
assert selection_counts["azure/chatgpt-v-2"] > 0
assert selection_counts["gpt-3.5-turbo-0613"] > 0
router.reset()
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")

View file

@ -0,0 +1,115 @@
# What is this?
## Unit tests for the max_parallel_requests feature on Router
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm.utils import calculate_max_parallel_requests
from typing import Optional
"""
- only rpm
- only tpm
- only max_parallel_requests
- max_parallel_requests + rpm
- max_parallel_requests + tpm
- max_parallel_requests + tpm + rpm
"""
max_parallel_requests_values = [None, 10]
tpm_values = [None, 20, 300000]
rpm_values = [None, 30]
default_max_parallel_requests = [None, 40]
@pytest.mark.parametrize(
"max_parallel_requests, tpm, rpm, default_max_parallel_requests",
[
(mp, tp, rp, dmp)
for mp in max_parallel_requests_values
for tp in tpm_values
for rp in rpm_values
for dmp in default_max_parallel_requests
],
)
def test_scenario(max_parallel_requests, tpm, rpm, default_max_parallel_requests):
calculated_max_parallel_requests = calculate_max_parallel_requests(
max_parallel_requests=max_parallel_requests,
rpm=rpm,
tpm=tpm,
default_max_parallel_requests=default_max_parallel_requests,
)
if max_parallel_requests is not None:
assert max_parallel_requests == calculated_max_parallel_requests
elif rpm is not None:
assert rpm == calculated_max_parallel_requests
elif tpm is not None:
calculated_rpm = int(tpm / 1000 / 6)
if calculated_rpm == 0:
calculated_rpm = 1
print(
f"test calculated_rpm: {calculated_rpm}, calculated_max_parallel_requests={calculated_max_parallel_requests}"
)
assert calculated_rpm == calculated_max_parallel_requests
elif default_max_parallel_requests is not None:
assert calculated_max_parallel_requests == default_max_parallel_requests
else:
assert calculated_max_parallel_requests is None
@pytest.mark.parametrize(
"max_parallel_requests, tpm, rpm, default_max_parallel_requests",
[
(mp, tp, rp, dmp)
for mp in max_parallel_requests_values
for tp in tpm_values
for rp in rpm_values
for dmp in default_max_parallel_requests
],
)
def test_setting_mpr_limits_per_model(
max_parallel_requests, tpm, rpm, default_max_parallel_requests
):
deployment = {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"max_parallel_requests": max_parallel_requests,
"tpm": tpm,
"rpm": rpm,
},
"model_info": {"id": "my-unique-id"},
}
router = litellm.Router(
model_list=[deployment],
default_max_parallel_requests=default_max_parallel_requests,
)
mpr_client: Optional[asyncio.Semaphore] = router._get_client(
deployment=deployment,
kwargs={},
client_type="max_parallel_requests",
)
if max_parallel_requests is not None:
assert max_parallel_requests == mpr_client._value
elif rpm is not None:
assert rpm == mpr_client._value
elif tpm is not None:
calculated_rpm = int(tpm / 1000 / 6)
if calculated_rpm == 0:
calculated_rpm = 1
print(
f"test calculated_rpm: {calculated_rpm}, calculated_max_parallel_requests={mpr_client._value}"
)
assert calculated_rpm == mpr_client._value
elif default_max_parallel_requests is not None:
assert mpr_client._value == default_max_parallel_requests
else:
assert mpr_client is None
# raise Exception("it worked!")

View file

@ -0,0 +1,85 @@
#### What this tests ####
# This tests utils used by llm router -> like llmrouter.get_settings()
import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
load_dotenv()
def test_returned_settings():
# this tests whether the router returns the settings it was initialized with
# both deployments use placeholder keys - no completion call is made in this test
litellm.set_verbose = True
import openai
try:
print("testing if router raises an exception")
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { #
"model": "gpt-3.5-turbo",
"api_key": "bad-key",
},
"tpm": 240000,
"rpm": 1800,
},
]
router = Router(
model_list=model_list,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=int(os.getenv("REDIS_PORT")),
routing_strategy="latency-based-routing",
routing_strategy_args={"ttl": 10},
set_verbose=False,
num_retries=3,
retry_after=5,
allowed_fails=1,
cooldown_time=30,
) # type: ignore
settings = router.get_settings()
print(settings)
"""
routing_strategy: "simple-shuffle"
routing_strategy_args: {"ttl": 10} # Average the last 10 calls to compute avg latency per model
allowed_fails: 1
num_retries: 3
retry_after: 5 # seconds to wait before retrying a failed request
cooldown_time: 30 # seconds to cooldown a deployment after failure
"""
assert settings["routing_strategy"] == "latency-based-routing"
assert settings["routing_strategy_args"]["ttl"] == 10
assert settings["allowed_fails"] == 1
assert settings["num_retries"] == 3
assert settings["retry_after"] == 5
assert settings["cooldown_time"] == 30
except:
print(traceback.format_exc())
pytest.fail("An error occurred - " + traceback.format_exc())

View file

@ -0,0 +1,53 @@
# What is this?
## unit tests for 'simple-shuffle'
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm import Router
"""
Test random shuffle
- async
- sync
"""
async def test_simple_shuffle():
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing-v2",
set_verbose=False,
num_retries=3,
) # type: ignore

View file

@ -220,6 +220,17 @@ tools_schema = [
# test_completion_cohere_stream()
def test_completion_azure_stream_special_char():
litellm.set_verbose = True
messages = [{"role": "user", "content": "hi. respond with the <xml> tag only"}]
response = completion(model="azure/chatgpt-v-2", messages=messages, stream=True)
response_str = ""
for part in response:
response_str += part.choices[0].delta.content or ""
print(f"response_str: {response_str}")
assert len(response_str) > 0
def test_completion_cohere_stream_bad_key():
try:
litellm.cache = None
@ -578,6 +589,64 @@ def test_completion_mistral_api_stream():
pytest.fail(f"Error occurred: {e}")
def test_completion_mistral_api_mistral_large_function_call_with_streaming():
litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
try:
# test without max tokens
response = completion(
model="mistral/mistral-large-latest",
messages=messages,
tools=tools,
tool_choice="auto",
stream=True,
)
idx = 0
for chunk in response:
print(f"chunk in response: {chunk}")
if idx == 0:
assert (
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
)
assert isinstance(
chunk.choices[0].delta.tool_calls[0].function.arguments, str
)
validate_first_streaming_function_calling_chunk(chunk=chunk)
elif idx == 1 and chunk.choices[0].finish_reason is None:
validate_second_streaming_function_calling_chunk(chunk=chunk)
elif chunk.choices[0].finish_reason is not None: # last chunk
validate_final_streaming_function_calling_chunk(chunk=chunk)
idx += 1
# raise Exception("it worked!")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_mistral_api_stream()
@ -2252,7 +2321,12 @@ def test_completion_claude_3_function_call_with_streaming():
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
try:
# test without max tokens
response = completion(
@ -2306,7 +2380,12 @@ async def test_acompletion_claude_3_function_call_with_streaming():
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
try:
# test without max tokens
response = await acompletion(

View file

@ -1,5 +1,5 @@
#### What this tests ####
# This tests the router's ability to pick deployment with lowest tpm
# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2'
import sys, os, asyncio, time, random
from datetime import datetime
@ -15,11 +15,18 @@ sys.path.insert(
import pytest
from litellm import Router
import litellm
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import (
LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler,
)
from litellm.utils import get_utc_datetime
from litellm.caching import DualCache
### UNIT TESTS FOR TPM/RPM ROUTING ###
"""
- Given 2 deployments, make sure it's shuffling deployments correctly.
"""
def test_tpm_rpm_updated():
test_cache = DualCache()
@ -41,20 +48,23 @@ def test_tpm_rpm_updated():
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
end_time = time.time()
lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
lowest_tpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
current_minute = datetime.now().strftime("%H-%M")
tpm_count_api_key = f"{model_group}:tpm:{current_minute}"
rpm_count_api_key = f"{model_group}:rpm:{current_minute}"
assert (
response_obj["usage"]["total_tokens"]
== test_cache.get_cache(key=tpm_count_api_key)[deployment_id]
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}"
rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}"
print(f"tpm_count_api_key={tpm_count_api_key}")
assert response_obj["usage"]["total_tokens"] == test_cache.get_cache(
key=tpm_count_api_key
)
assert 1 == test_cache.get_cache(key=rpm_count_api_key)[deployment_id]
assert 1 == test_cache.get_cache(key=rpm_count_api_key)
# test_tpm_rpm_updated()
@ -120,13 +130,6 @@ def test_get_available_deployments():
)
## CHECK WHAT'S SELECTED ##
print(
lowest_tpm_logger.get_available_deployments(
model_group=model_group,
healthy_deployments=model_list,
input=["Hello world"],
)
)
assert (
lowest_tpm_logger.get_available_deployments(
model_group=model_group,
@ -168,7 +171,7 @@ def test_router_get_available_deployments():
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing",
routing_strategy="usage-based-routing-v2",
set_verbose=False,
num_retries=3,
) # type: ignore
@ -187,7 +190,7 @@ def test_router_get_available_deployments():
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
end_time = time.time()
router.lowesttpm_logger.log_success_event(
router.lowesttpm_logger_v2.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
@ -206,7 +209,7 @@ def test_router_get_available_deployments():
start_time = time.time()
response_obj = {"usage": {"total_tokens": 20}}
end_time = time.time()
router.lowesttpm_logger.log_success_event(
router.lowesttpm_logger_v2.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
@ -214,7 +217,7 @@ def test_router_get_available_deployments():
)
## CHECK WHAT'S SELECTED ##
# print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
# print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model"))
assert (
router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2"
)
@ -242,7 +245,7 @@ def test_router_skip_rate_limited_deployments():
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing",
routing_strategy="usage-based-routing-v2",
set_verbose=False,
num_retries=3,
) # type: ignore
@ -260,7 +263,7 @@ def test_router_skip_rate_limited_deployments():
start_time = time.time()
response_obj = {"usage": {"total_tokens": 1439}}
end_time = time.time()
router.lowesttpm_logger.log_success_event(
router.lowesttpm_logger_v2.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
@ -268,7 +271,7 @@ def test_router_skip_rate_limited_deployments():
)
## CHECK WHAT'S SELECTED ##
# print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
# print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model"))
try:
router.get_available_deployment(
model="azure-model",
@ -297,7 +300,7 @@ def test_single_deployment_tpm_zero():
router = litellm.Router(
model_list=model_list,
routing_strategy="usage-based-routing",
routing_strategy="usage-based-routing-v2",
cache_responses=True,
)
@ -343,7 +346,7 @@ async def test_router_completion_streaming():
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing",
routing_strategy="usage-based-routing-v2",
set_verbose=False,
) # type: ignore
@ -360,8 +363,9 @@ async def test_router_completion_streaming():
if response is not None:
## CALL 3
await asyncio.sleep(1) # let the token update happen
current_minute = datetime.now().strftime("%H-%M")
picked_deployment = router.lowesttpm_logger.get_available_deployments(
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
picked_deployment = router.lowesttpm_logger_v2.get_available_deployments(
model_group=model,
healthy_deployments=router.healthy_deployments,
messages=messages,
@ -383,3 +387,8 @@ async def test_router_completion_streaming():
# asyncio.run(test_router_completion_streaming())
"""
- Unit test for sync 'pre_call_checks'
- Unit test for async 'async_pre_call_checks'
"""

View file

@ -173,6 +173,22 @@ def test_trimming_should_not_change_original_messages():
assert messages == messages_copy
@pytest.mark.parametrize("model", ["gpt-4-0125-preview", "claude-3-opus-20240229"])
def test_trimming_with_model_cost_max_input_tokens(model):
messages = [
{"role": "system", "content": "This is a normal system message"},
{
"role": "user",
"content": "This is a sentence" * 100000,
},
]
trimmed_messages = trim_messages(messages, model=model)
assert (
get_token_count(trimmed_messages, model=model)
< litellm.model_cost[model]["max_input_tokens"]
)
def test_get_valid_models():
old_environ = os.environ
os.environ = {"OPENAI_API_KEY": "temp"} # mock set only openai key in environ

View file

@ -101,12 +101,39 @@ class LiteLLM_Params(BaseModel):
aws_secret_access_key: Optional[str] = None
aws_region_name: Optional[str] = None
def __init__(self, max_retries: Optional[Union[int, str]] = None, **params):
def __init__(
self,
model: str,
max_retries: Optional[Union[int, str]] = None,
tpm: Optional[int] = None,
rpm: Optional[int] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
timeout: Optional[Union[float, str]] = None, # if str, pass in as os.environ/
stream_timeout: Optional[Union[float, str]] = (
None # timeout when making stream=True calls, if str, pass in as os.environ/
),
organization: Optional[str] = None, # for openai orgs
## VERTEX AI ##
vertex_project: Optional[str] = None,
vertex_location: Optional[str] = None,
## AWS BEDROCK / SAGEMAKER ##
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
**params
):
args = locals()
args.pop("max_retries", None)
args.pop("self", None)
args.pop("params", None)
args.pop("__class__", None)
if max_retries is None:
max_retries = 2
elif isinstance(max_retries, str):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **params)
super().__init__(max_retries=max_retries, **args, **params)
class Config:
extra = "allow"
@ -133,12 +160,23 @@ class Deployment(BaseModel):
litellm_params: LiteLLM_Params
model_info: ModelInfo
def __init__(self, model_info: Optional[Union[ModelInfo, dict]] = None, **params):
def __init__(
self,
model_name: str,
litellm_params: LiteLLM_Params,
model_info: Optional[Union[ModelInfo, dict]] = None,
**params
):
if model_info is None:
model_info = ModelInfo()
elif isinstance(model_info, dict):
model_info = ModelInfo(**model_info)
super().__init__(model_info=model_info, **params)
super().__init__(
model_info=model_info,
model_name=model_name,
litellm_params=litellm_params,
**params
)
def to_json(self, **kwargs):
try:

View file

@ -5,11 +5,12 @@ from typing import Optional
class ServiceTypes(enum.Enum):
"""
Enum for litellm-adjacent services (redis/postgres/etc.)
Enum for litellm + litellm-adjacent services (redis/postgres/etc.)
"""
REDIS = "redis"
DB = "postgres"
LITELLM = "self"
class ServiceLoggerPayload(BaseModel):
@ -21,6 +22,7 @@ class ServiceLoggerPayload(BaseModel):
error: Optional[str] = Field(None, description="what was the error")
service: ServiceTypes = Field(description="who is this for? - postgres/redis")
duration: float = Field(description="How long did the request take?")
call_type: str = Field(description="The type of call being made to the service")
def to_json(self, **kwargs):
try:

View file

@ -228,6 +228,24 @@ class Function(OpenAIObject):
arguments: str
name: Optional[str] = None
def __init__(
self,
arguments: Union[Dict, str],
name: Optional[str] = None,
**params,
):
if isinstance(arguments, Dict):
arguments = json.dumps(arguments)
else:
arguments = arguments
name = name
# Build a dictionary with the structure your BaseModel expects
data = {"arguments": arguments, "name": name, **params}
super(Function, self).__init__(**data)
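# Sketch of the new dict handling: tool-call arguments passed as a dict are serialized
# with json.dumps. Function is assumed importable from litellm.utils, where this class
# is defined:
import json
from litellm.utils import Function

f = Function(
    arguments={"location": "Boston, MA", "unit": "fahrenheit"},
    name="get_current_weather",
)
assert isinstance(f.arguments, str)
assert json.loads(f.arguments)["location"] == "Boston, MA"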
class ChatCompletionDeltaToolCall(OpenAIObject):
id: Optional[str] = None
@ -2260,6 +2278,24 @@ class Logging:
level="ERROR",
kwargs=self.model_call_details,
)
elif callback == "prometheus":
global prometheusLogger
verbose_logger.debug("reaches prometheus for success logging!")
kwargs = {}
for k, v in self.model_call_details.items():
if (
k != "original_response"
): # copy.deepcopy raises errors as this could be a coroutine
kwargs[k] = v
kwargs["exception"] = str(exception)
prometheusLogger.log_event(
kwargs=kwargs,
response_obj=result,
start_time=start_time,
end_time=end_time,
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
except Exception as e:
print_verbose(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging with integrations {str(e)}"
@ -2392,25 +2428,12 @@ class Rules:
####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function):
global liteDebuggerClient, get_all_keys
rules_obj = Rules()
def function_setup(
start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
def function_setup(
original_function, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
function_id = kwargs["id"] if "id" in kwargs else None
if litellm.use_client or (
"use_client" in kwargs and kwargs["use_client"] == True
):
if "lite_debugger" not in litellm.input_callback:
litellm.input_callback.append("lite_debugger")
if "lite_debugger" not in litellm.success_callback:
litellm.success_callback.append("lite_debugger")
if "lite_debugger" not in litellm.failure_callback:
litellm.failure_callback.append("lite_debugger")
if len(litellm.callbacks) > 0:
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
@ -2533,7 +2556,7 @@ def client(original_function):
input="".join(
m.get("content", "")
for m in messages
if isinstance(m["content"], str)
if "content" in m and isinstance(m["content"], str)
),
model=model,
)
@ -2596,6 +2619,11 @@ def client(original_function):
)
raise e
def client(original_function):
global liteDebuggerClient, get_all_keys
rules_obj = Rules()
def check_coroutine(value) -> bool:
if inspect.iscoroutine(value):
return True
@ -2688,7 +2716,9 @@ def client(original_function):
try:
if logging_obj is None:
logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
logging_obj, kwargs = function_setup(
original_function, rules_obj, start_time, *args, **kwargs
)
kwargs["litellm_logging_obj"] = logging_obj
# CHECK FOR 'os.environ/' in kwargs
@ -2715,23 +2745,22 @@ def client(original_function):
# [OPTIONAL] CHECK CACHE
print_verbose(
f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
f"SYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache')['no-cache']: {kwargs.get('cache', {}).get('no-cache', False)}"
)
# if caching is false or cache["no-cache"]==True, don't run this
if (
(
(
(
kwargs.get("caching", None) is None
and kwargs.get("cache", None) is None
and litellm.cache is not None
)
or kwargs.get("caching", False) == True
or (
kwargs.get("cache", None) is not None
)
and kwargs.get("cache", {}).get("no-cache", False) != True
)
)
and kwargs.get("aembedding", False) != True
and kwargs.get("atext_completion", False) != True
and kwargs.get("acompletion", False) != True
and kwargs.get("aimg_generation", False) != True
and kwargs.get("atranscription", False) != True
@ -2996,7 +3025,9 @@ def client(original_function):
try:
if logging_obj is None:
logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
logging_obj, kwargs = function_setup(
original_function, rules_obj, start_time, *args, **kwargs
)
kwargs["litellm_logging_obj"] = logging_obj
# [OPTIONAL] CHECK BUDGET
@ -3008,24 +3039,17 @@ def client(original_function):
)
# [OPTIONAL] CHECK CACHE
print_verbose(f"litellm.cache: {litellm.cache}")
print_verbose(
f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
f"ASYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache'): {kwargs.get('cache', None)}"
)
# if caching is false, don't run this
final_embedding_cached_response = None
if (
(
kwargs.get("caching", None) is None
and kwargs.get("cache", None) is None
and litellm.cache is not None
)
(kwargs.get("caching", None) is None and litellm.cache is not None)
or kwargs.get("caching", False) == True
or (
kwargs.get("cache", None) is not None
and kwargs.get("cache").get("no-cache", False) != True
)
) and (
kwargs.get("cache", {}).get("no-cache", False) != True
): # allow users to control returning cached responses from the completion function
# checking cache
print_verbose("INSIDE CHECKING CACHE")
@ -3071,7 +3095,6 @@ def client(original_function):
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result is not None and not isinstance(
cached_result, list
):
@ -4204,9 +4227,7 @@ def supports_vision(model: str):
return True
return False
else:
raise Exception(
f"Model not in model_prices_and_context_window.json. You passed model={model}."
)
return False
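# With this change an unknown model no longer raises; a rough sketch of the expected
# behavior, with the known-model name taken from the cost-map entries later in this diff:
from litellm.utils import supports_vision

assert supports_vision(model="gpt-4-turbo-2024-04-09") is True
assert supports_vision(model="some-unknown-model") is False  # previously raised an Exception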
def supports_parallel_function_calling(model: str):
@ -4736,6 +4757,8 @@ def get_optional_params(
optional_params["stop_sequences"] = stop
if tools is not None:
optional_params["tools"] = tools
if seed is not None:
optional_params["seed"] = seed
elif custom_llm_provider == "maritalk":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
@ -4907,37 +4930,11 @@ def get_optional_params(
)
_check_valid_arg(supported_params=supported_params)
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if stream:
optional_params["stream"] = stream
if n is not None:
optional_params["candidate_count"] = n
if stop is not None:
if isinstance(stop, str):
optional_params["stop_sequences"] = [stop]
elif isinstance(stop, list):
optional_params["stop_sequences"] = stop
if max_tokens is not None:
optional_params["max_output_tokens"] = max_tokens
if response_format is not None and response_format["type"] == "json_object":
optional_params["response_mime_type"] = "application/json"
if tools is not None and isinstance(tools, list):
from vertexai.preview import generative_models
gtool_func_declarations = []
for tool in tools:
gtool_func_declaration = generative_models.FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
optional_params = litellm.VertexAIConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [
generative_models.Tool(function_declarations=gtool_func_declarations)
]
print_verbose(
f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
)
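# The vertex branch now delegates to VertexAIConfig. A rough sketch of the mapping,
# assuming it performs the same openai -> vertex renames as the removed block above
# (e.g. max_tokens -> max_output_tokens):
import litellm

mapped = litellm.VertexAIConfig().map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.2},
    optional_params={},
)
# expected to yield vertex-style keys, e.g. {"max_output_tokens": 256, "temperature": 0.2}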
@ -5297,7 +5294,9 @@ def get_optional_params(
if tool_choice is not None:
optional_params["tool_choice"] = tool_choice
if response_format is not None:
optional_params["response_format"] = tool_choice
optional_params["response_format"] = response_format
if seed is not None:
optional_params["seed"] = seed
elif custom_llm_provider == "openrouter":
supported_params = get_supported_openai_params(
@ -5418,6 +5417,49 @@ def get_optional_params(
return optional_params
def calculate_max_parallel_requests(
max_parallel_requests: Optional[int],
rpm: Optional[int],
tpm: Optional[int],
default_max_parallel_requests: Optional[int],
) -> Optional[int]:
"""
Returns the max parallel requests to send to a deployment.
Used in semaphore for async requests on router.
Parameters:
- max_parallel_requests - Optional[int] - max_parallel_requests allowed for that deployment
- rpm - Optional[int] - requests per minute allowed for that deployment
- tpm - Optional[int] - tokens per minute allowed for that deployment
- default_max_parallel_requests - Optional[int] - default_max_parallel_requests allowed for any deployment
Returns:
- int or None (if all params are None)
Order:
max_parallel_requests > rpm > tpm / 6 (azure formula) > default max_parallel_requests
Azure RPM formula:
6 rpm per 1000 TPM
https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits
"""
if max_parallel_requests is not None:
return max_parallel_requests
elif rpm is not None:
return rpm
elif tpm is not None:
calculated_rpm = int(tpm / 1000 / 6)
if calculated_rpm == 0:
calculated_rpm = 1
return calculated_rpm
elif default_max_parallel_requests is not None:
return default_max_parallel_requests
return None
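# A quick worked example of the precedence documented above, with the helper imported
# the same way the new router unit tests earlier in this diff import it:
from litellm.utils import calculate_max_parallel_requests

assert calculate_max_parallel_requests(10, rpm=30, tpm=300000, default_max_parallel_requests=40) == 10
assert calculate_max_parallel_requests(None, rpm=30, tpm=300000, default_max_parallel_requests=40) == 30
assert calculate_max_parallel_requests(None, rpm=None, tpm=300000, default_max_parallel_requests=40) == 50  # int(300000 / 1000 / 6)
assert calculate_max_parallel_requests(None, rpm=None, tpm=20, default_max_parallel_requests=40) == 1       # floor of 0 is bumped to 1
assert calculate_max_parallel_requests(None, rpm=None, tpm=None, default_max_parallel_requests=40) == 40
assert calculate_max_parallel_requests(None, rpm=None, tpm=None, default_max_parallel_requests=None) is None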
def get_api_base(model: str, optional_params: dict) -> Optional[str]:
"""
Returns the api base used for calling the model.
@ -5436,7 +5478,9 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
get_api_base(model="gemini/gemini-pro")
```
"""
_optional_params = LiteLLM_Params(**optional_params) # convert to pydantic object
_optional_params = LiteLLM_Params(
model=model, **optional_params
) # convert to pydantic object
# get llm provider
try:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
@ -5513,6 +5557,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"tools",
"tool_choice",
"response_format",
"seed",
]
elif custom_llm_provider == "cohere":
return [
@ -5538,6 +5583,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"n",
"tools",
"tool_choice",
"seed",
]
elif custom_llm_provider == "maritalk":
return [
@ -5637,17 +5683,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
elif custom_llm_provider == "vertex_ai":
return [
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
"response_format",
"n",
"stop",
]
return litellm.VertexAIConfig().get_supported_openai_params()
elif custom_llm_provider == "sagemaker":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
elif custom_llm_provider == "aleph_alpha":
@ -5698,6 +5734,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"frequency_penalty",
"logit_bias",
"user",
"response_format",
]
elif custom_llm_provider == "perplexity":
return [
@ -5922,6 +5959,7 @@ def get_llm_provider(
or model in litellm.vertex_code_text_models
or model in litellm.vertex_language_models
or model in litellm.vertex_embedding_models
or model in litellm.vertex_vision_models
):
custom_llm_provider = "vertex_ai"
## ai21
@ -5971,6 +6009,9 @@ def get_llm_provider(
if isinstance(e, litellm.exceptions.BadRequestError):
raise e
else:
error_str = (
f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}"
)
raise litellm.exceptions.BadRequestError( # type: ignore
message=f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}",
model=model,
@ -6160,6 +6201,12 @@ def get_model_info(model: str):
"litellm_provider": "huggingface",
"mode": "chat",
}
else:
"""
Check if model in model cost map
"""
if model in litellm.model_cost:
return litellm.model_cost[model]
else:
raise Exception()
except:
@ -7348,6 +7395,7 @@ def exception_type(
original_exception,
custom_llm_provider,
completion_kwargs={},
extra_kwargs={},
):
global user_logger_fn, liteDebuggerClient
exception_mapping_worked = False
@ -7842,6 +7890,26 @@ def exception_type(
response=original_exception.response,
)
elif custom_llm_provider == "vertex_ai":
if completion_kwargs is not None:
# add model, deployment and model_group to the exception message
_model = completion_kwargs.get("model")
error_str += f"\nmodel: {_model}\n"
if extra_kwargs is not None:
_vertex_project = extra_kwargs.get("vertex_project")
_vertex_location = extra_kwargs.get("vertex_location")
_metadata = extra_kwargs.get("metadata", {}) or {}
_model_group = _metadata.get("model_group")
_deployment = _metadata.get("deployment")
if _model_group is not None:
error_str += f"model_group: {_model_group}\n"
if _deployment is not None:
error_str += f"deployment: {_deployment}\n"
if _vertex_project is not None:
error_str += f"vertex_project: {_vertex_project}\n"
if _vertex_location is not None:
error_str += f"vertex_location: {_vertex_location}\n"
if (
"Vertex AI API has not been used in project" in error_str
or "Unable to find your project" in error_str
@ -7853,6 +7921,15 @@ def exception_type(
llm_provider="vertex_ai",
response=original_exception.response,
)
elif "None Unknown Error." in error_str:
exception_mapping_worked = True
raise APIError(
message=f"VertexAIException - {error_str}",
status_code=500,
model=model,
llm_provider="vertex_ai",
request=original_exception.request,
)
elif "403" in error_str:
exception_mapping_worked = True
raise BadRequestError(
@ -7878,6 +7955,8 @@ def exception_type(
elif (
"429 Quota exceeded" in error_str
or "IndexError: list index out of range" in error_str
or "429 Unable to submit request because the service is temporarily out of capacity."
in error_str
):
exception_mapping_worked = True
raise RateLimitError(
@ -8867,7 +8946,16 @@ class CustomStreamWrapper:
raise e
def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
"""
Parse out <s> / </s> special tokens for sagemaker + hf streaming output.
"""
hold = False
if (
self.custom_llm_provider != "huggingface"
and self.custom_llm_provider != "sagemaker"
):
return hold, chunk
if finish_reason:
for token in self.special_tokens:
if token in chunk:
@ -8883,6 +8971,7 @@ class CustomStreamWrapper:
for token in self.special_tokens:
if len(curr_chunk) < len(token) and curr_chunk in token:
hold = True
self.holding_chunk = curr_chunk
elif len(curr_chunk) >= len(token):
if token in curr_chunk:
self.holding_chunk = curr_chunk.replace(token, "")
@ -9944,6 +10033,22 @@ class CustomStreamWrapper:
t.function.arguments = ""
_json_delta = delta.model_dump()
print_verbose(f"_json_delta: {_json_delta}")
if "role" not in _json_delta or _json_delta["role"] is None:
_json_delta["role"] = (
"assistant" # mistral's api returns role as None
)
if "tool_calls" in _json_delta and isinstance(
_json_delta["tool_calls"], list
):
for tool in _json_delta["tool_calls"]:
if (
isinstance(tool, dict)
and "function" in tool
and isinstance(tool["function"], dict)
and ("type" not in tool or tool["type"] is None)
):
# if function returned but type set to None - mistral's api returns type: None
tool["type"] = "function"
model_response.choices[0].delta = Delta(**_json_delta)
except Exception as e:
traceback.print_exc()
@ -9964,6 +10069,7 @@ class CustomStreamWrapper:
f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"
)
print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
## RETURN ARG
if (
"content" in completion_obj
@ -10036,7 +10142,6 @@ class CustomStreamWrapper:
elif self.received_finish_reason is not None:
if self.sent_last_chunk == True:
raise StopIteration
# flush any remaining holding chunk
if len(self.holding_chunk) > 0:
if model_response.choices[0].delta.content is None:
@ -10609,16 +10714,18 @@ def trim_messages(
messages = copy.deepcopy(messages)
try:
print_verbose(f"trimming messages")
if max_tokens == None:
if max_tokens is None:
# Check if model is valid
if model in litellm.model_cost:
max_tokens_for_model = litellm.model_cost[model]["max_tokens"]
max_tokens_for_model = litellm.model_cost[model].get(
"max_input_tokens", litellm.model_cost[model]["max_tokens"]
)
max_tokens = int(max_tokens_for_model * trim_ratio)
else:
# if user did not specify max tokens
# if user did not specify max (input) tokens
# or passed an llm litellm does not know
# do nothing, just return messages
return
return messages
system_message = ""
for message in messages:

View file

@ -75,7 +75,8 @@
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true
"supports_parallel_function_calling": true,
"supports_vision": true
},
"gpt-4-turbo-2024-04-09": {
"max_tokens": 4096,
@ -86,7 +87,8 @@
"litellm_provider": "openai",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true
"supports_parallel_function_calling": true,
"supports_vision": true
},
"gpt-4-1106-preview": {
"max_tokens": 4096,
@ -648,6 +650,7 @@
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006,
"litellm_provider": "mistral",
"supports_function_calling": true,
"mode": "chat"
},
"mistral/mistral-small-latest": {
@ -657,6 +660,7 @@
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006,
"litellm_provider": "mistral",
"supports_function_calling": true,
"mode": "chat"
},
"mistral/mistral-medium": {
@ -706,6 +710,16 @@
"mode": "chat",
"supports_function_calling": true
},
"mistral/open-mixtral-8x7b": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006,
"litellm_provider": "mistral",
"mode": "chat",
"supports_function_calling": true
},
"mistral/mistral-embed": {
"max_tokens": 8192,
"max_input_tokens": 8192,
@ -723,6 +737,26 @@
"mode": "chat",
"supports_function_calling": true
},
"groq/llama3-8b-8192": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000010,
"litellm_provider": "groq",
"mode": "chat",
"supports_function_calling": true
},
"groq/llama3-70b-8192": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000064,
"output_cost_per_token": 0.00000080,
"litellm_provider": "groq",
"mode": "chat",
"supports_function_calling": true
},
"groq/mixtral-8x7b-32768": {
"max_tokens": 32768,
"max_input_tokens": 32768,
@ -777,7 +811,9 @@
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "anthropic",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true,
"tool_use_system_prompt_tokens": 264
},
"claude-3-opus-20240229": {
"max_tokens": 4096,
@ -786,7 +822,9 @@
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "anthropic",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true,
"tool_use_system_prompt_tokens": 395
},
"claude-3-sonnet-20240229": {
"max_tokens": 4096,
@ -795,7 +833,9 @@
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "anthropic",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true,
"tool_use_system_prompt_tokens": 159
},
"text-bison": {
"max_tokens": 1024,
@ -1010,6 +1050,7 @@
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0215": {
@@ -1021,6 +1062,7 @@
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0409": {
@@ -1032,6 +1074,7 @@
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-experimental": {
@@ -1043,6 +1086,7 @@
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": false,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-pro-vision": {
@@ -1097,7 +1141,8 @@
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"vertex_ai/claude-3-haiku@20240307": {
"max_tokens": 4096,
@@ -1106,7 +1151,8 @@
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"vertex_ai/claude-3-opus@20240229": {
"max_tokens": 4096,
@@ -1115,7 +1161,8 @@
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.0000075,
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"textembedding-gecko": {
"max_tokens": 3072,
@@ -1268,8 +1315,23 @@
"litellm_provider": "gemini",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-1.5-pro-latest": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"litellm_provider": "gemini",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_tool_choice": true,
"source": "https://ai.google.dev/models/gemini"
},
"gemini/gemini-pro-vision": {
"max_tokens": 2048,
"max_input_tokens": 30720,
@@ -1484,6 +1546,13 @@
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/meta-llama/llama-3-70b-instruct": {
"max_tokens": 8192,
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000008,
"litellm_provider": "openrouter",
"mode": "chat"
},
"j2-ultra": {
"max_tokens": 8192,
"max_input_tokens": 8192,
@@ -1731,7 +1800,8 @@
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"max_tokens": 4096,
@@ -1740,7 +1810,8 @@
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
@@ -1749,7 +1820,8 @@
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_function_calling": true
},
"anthropic.claude-v1": {
"max_tokens": 8191,

Some files were not shown because too many files have changed in this diff.
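The JSON hunks above mostly extend the model map: capability flags (`supports_function_calling`, `supports_vision`, `supports_tool_choice`), Anthropic tool-use prompt overheads (`tool_use_system_prompt_tokens`), and new Groq, Mistral, Gemini and OpenRouter entries. A hedged sketch of reading those fields at runtime through `litellm.model_cost`, which is loaded from this file; the per-token cost arithmetic is illustrative, not litellm's own spend-tracking code:

```python
# Assumes a litellm install that bundles the updated model map shown above.
import litellm

entry = litellm.model_cost.get("groq/llama3-70b-8192", {})
print(entry.get("supports_function_calling"))  # True, per the hunk above

# Rough cost estimate from the per-token prices (illustrative only).
prompt_tokens, completion_tokens = 1_000, 200
cost = (prompt_tokens * entry.get("input_cost_per_token", 0.0)
        + completion_tokens * entry.get("output_cost_per_token", 0.0))
print(f"~${cost:.6f}")  # ~$0.000800 at the prices listed above

# Anthropic entries now also record the extra prompt tokens their tool-use
# system prompt adds, e.g. 395 for claude-3-opus-20240229.
opus = litellm.model_cost.get("claude-3-opus-20240229", {})
print(opus.get("tool_use_system_prompt_tokens"))
```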