Merge branch 'main' into feature/watsonx-integration

This commit is contained in:
Simon Sanchez Viloria 2024-05-06 17:27:14 +02:00
commit 9a95fa9348
144 changed files with 8872 additions and 2296 deletions

View file

@ -40,7 +40,7 @@ jobs:
pip install "aioboto3==12.3.0" pip install "aioboto3==12.3.0"
pip install langchain pip install langchain
pip install lunary==0.2.5 pip install lunary==0.2.5
pip install "langfuse==2.7.3" pip install "langfuse==2.27.1"
pip install numpydoc pip install numpydoc
pip install traceloop-sdk==0.0.69 pip install traceloop-sdk==0.0.69
pip install openai pip install openai

47
.github/pull_request_template.md vendored Normal file
View file

@ -0,0 +1,47 @@
<!-- This is just examples. You can remove all items if you want. -->
<!-- Please remove all comments. -->
## Title
<!-- e.g. "Implement user authentication feature" -->
## Relevant issues
<!-- e.g. "Fixes #000" -->
## Type
<!-- Select the type of Pull Request -->
<!-- Keep only the necessary ones -->
🆕 New Feature
🐛 Bug Fix
🧹 Refactoring
📖 Documentation
💻 Development Environment
🚄 Infrastructure
✅ Test
## Changes
<!-- List of changes -->
## Testing
<!-- Test procedure -->
## Notes
<!-- Test results -->
<!-- Points to note for the reviewer, consultation content, concerns -->
## Pre-Submission Checklist (optional but appreciated):
- [ ] I have included relevant documentation updates (stored in /docs/my-website)
## OS Tests (optional but appreciated):
- [ ] Tested on Windows
- [ ] Tested on MacOS
- [ ] Tested on Linux

1
.gitignore vendored
View file

@ -51,3 +51,4 @@ loadtest_kub.yaml
litellm/proxy/_new_secret_config.yaml litellm/proxy/_new_secret_config.yaml
litellm/proxy/_new_secret_config.yaml litellm/proxy/_new_secret_config.yaml
litellm/proxy/_super_secret_config.yaml litellm/proxy/_super_secret_config.yaml
litellm/proxy/_super_secret_config.yaml

View file

@ -7,7 +7,7 @@ repos:
rev: 7.0.0 # The version of flake8 to use rev: 7.0.0 # The version of flake8 to use
hooks: hooks:
- id: flake8 - id: flake8
exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/integrations/|^litellm/proxy/tests/ exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
additional_dependencies: [flake8-print] additional_dependencies: [flake8-print]
files: litellm/.*\.py files: litellm/.*\.py
- repo: local - repo: local

View file

@ -248,7 +248,7 @@ Step 2: Navigate into the project, and install dependencies:
``` ```
cd litellm cd litellm
poetry install poetry install -E extra_proxy -E proxy
``` ```
Step 3: Test your change: Step 3: Test your change:

View file

@ -84,7 +84,7 @@ def completion(
n: Optional[int] = None, n: Optional[int] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
stop=None, stop=None,
max_tokens: Optional[float] = None, max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None, logit_bias: Optional[dict] = None,

View file

@ -1,7 +1,7 @@
# Completion Token Usage & Cost # Completion Token Usage & Cost
By default LiteLLM returns token usage in all completion requests ([See here](https://litellm.readthedocs.io/en/latest/output/)) By default LiteLLM returns token usage in all completion requests ([See here](https://litellm.readthedocs.io/en/latest/output/))
However, we also expose 5 helper functions + **[NEW]** an API to calculate token usage across providers: However, we also expose some helper functions + **[NEW]** an API to calculate token usage across providers:
- `encode`: This encodes the text passed in, using the model-specific tokenizer. [**Jump to code**](#1-encode) - `encode`: This encodes the text passed in, using the model-specific tokenizer. [**Jump to code**](#1-encode)
@ -9,17 +9,19 @@ However, we also expose 5 helper functions + **[NEW]** an API to calculate token
- `token_counter`: This returns the number of tokens for a given input - it uses the tokenizer based on the model, and defaults to tiktoken if no model-specific tokenizer is available. [**Jump to code**](#3-token_counter) - `token_counter`: This returns the number of tokens for a given input - it uses the tokenizer based on the model, and defaults to tiktoken if no model-specific tokenizer is available. [**Jump to code**](#3-token_counter)
- `cost_per_token`: This returns the cost (in USD) for prompt (input) and completion (output) tokens. Uses the live list from `api.litellm.ai`. [**Jump to code**](#4-cost_per_token) - `create_pretrained_tokenizer` and `create_tokenizer`: LiteLLM provides default tokenizer support for OpenAI, Cohere, Anthropic, Llama2, and Llama3 models. If you are using a different model, you can create a custom tokenizer and pass it as `custom_tokenizer` to the `encode`, `decode`, and `token_counter` methods. [**Jump to code**](#4-create_pretrained_tokenizer-and-create_tokenizer)
- `completion_cost`: This returns the overall cost (in USD) for a given LLM API Call. It combines `token_counter` and `cost_per_token` to return the cost for that query (counting both cost of input and output). [**Jump to code**](#5-completion_cost) - `cost_per_token`: This returns the cost (in USD) for prompt (input) and completion (output) tokens. Uses the live list from `api.litellm.ai`. [**Jump to code**](#5-cost_per_token)
- `get_max_tokens`: This returns the maximum number of tokens allowed for the given model. [**Jump to code**](#6-get_max_tokens) - `completion_cost`: This returns the overall cost (in USD) for a given LLM API Call. It combines `token_counter` and `cost_per_token` to return the cost for that query (counting both cost of input and output). [**Jump to code**](#6-completion_cost)
- `model_cost`: This returns a dictionary for all models, with their max_tokens, input_cost_per_token and output_cost_per_token. It uses the `api.litellm.ai` call shown below. [**Jump to code**](#7-model_cost) - `get_max_tokens`: This returns the maximum number of tokens allowed for the given model. [**Jump to code**](#7-get_max_tokens)
- `register_model`: This registers new / overrides existing models (and their pricing details) in the model cost dictionary. [**Jump to code**](#8-register_model) - `model_cost`: This returns a dictionary for all models, with their max_tokens, input_cost_per_token and output_cost_per_token. It uses the `api.litellm.ai` call shown below. [**Jump to code**](#8-model_cost)
- `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#9-apilitellmai) - `register_model`: This registers new / overrides existing models (and their pricing details) in the model cost dictionary. [**Jump to code**](#9-register_model)
- `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#10-apilitellmai)
📣 This is a community maintained list. Contributions are welcome! ❤️ 📣 This is a community maintained list. Contributions are welcome! ❤️
@ -60,7 +62,24 @@ messages = [{"user": "role", "content": "Hey, how's it going"}]
print(token_counter(model="gpt-3.5-turbo", messages=messages)) print(token_counter(model="gpt-3.5-turbo", messages=messages))
``` ```
### 4. `cost_per_token` ### 4. `create_pretrained_tokenizer` and `create_tokenizer`
```python
from litellm import create_pretrained_tokenizer, create_tokenizer
# get tokenizer from huggingface repo
custom_tokenizer_1 = create_pretrained_tokenizer("Xenova/llama-3-tokenizer")
# use tokenizer from json file
with open("tokenizer.json") as f:
json_data = json.load(f)
json_str = json.dumps(json_data)
custom_tokenizer_2 = create_tokenizer(json_str)
```
### 5. `cost_per_token`
```python ```python
from litellm import cost_per_token from litellm import cost_per_token
@ -72,7 +91,7 @@ prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_toke
print(prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar) print(prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar)
``` ```
### 5. `completion_cost` ### 6. `completion_cost`
* Input: Accepts a `litellm.completion()` response **OR** prompt + completion strings * Input: Accepts a `litellm.completion()` response **OR** prompt + completion strings
* Output: Returns a `float` of cost for the `completion` call * Output: Returns a `float` of cost for the `completion` call
@ -99,7 +118,7 @@ cost = completion_cost(model="bedrock/anthropic.claude-v2", prompt="Hey!", compl
formatted_string = f"${float(cost):.10f}" formatted_string = f"${float(cost):.10f}"
print(formatted_string) print(formatted_string)
``` ```
### 6. `get_max_tokens` ### 7. `get_max_tokens`
Input: Accepts a model name - e.g., gpt-3.5-turbo (to get a complete list, call litellm.model_list). Input: Accepts a model name - e.g., gpt-3.5-turbo (to get a complete list, call litellm.model_list).
Output: Returns the maximum number of tokens allowed for the given model Output: Returns the maximum number of tokens allowed for the given model
@ -112,7 +131,7 @@ model = "gpt-3.5-turbo"
print(get_max_tokens(model)) # Output: 4097 print(get_max_tokens(model)) # Output: 4097
``` ```
### 7. `model_cost` ### 8. `model_cost`
* Output: Returns a dict object containing the max_tokens, input_cost_per_token, output_cost_per_token for all models on [community-maintained list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) * Output: Returns a dict object containing the max_tokens, input_cost_per_token, output_cost_per_token for all models on [community-maintained list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
@ -122,7 +141,7 @@ from litellm import model_cost
print(model_cost) # {'gpt-3.5-turbo': {'max_tokens': 4000, 'input_cost_per_token': 1.5e-06, 'output_cost_per_token': 2e-06}, ...} print(model_cost) # {'gpt-3.5-turbo': {'max_tokens': 4000, 'input_cost_per_token': 1.5e-06, 'output_cost_per_token': 2e-06}, ...}
``` ```
### 8. `register_model` ### 9. `register_model`
* Input: Provide EITHER a model cost dictionary or a url to a hosted json blob * Input: Provide EITHER a model cost dictionary or a url to a hosted json blob
* Output: Returns updated model_cost dictionary + updates litellm.model_cost with model details. * Output: Returns updated model_cost dictionary + updates litellm.model_cost with model details.
@ -157,5 +176,3 @@ export LITELLM_LOCAL_MODEL_COST_MAP="True"
``` ```
Note: this means you will need to upgrade to get updated pricing, and newer models. Note: this means you will need to upgrade to get updated pricing, and newer models.

View file

@ -23,6 +23,14 @@ response = completion(model="gpt-3.5-turbo", messages=messages)
response = completion("command-nightly", messages) response = completion("command-nightly", messages)
``` ```
## JSON Logs
If you need to store the logs as JSON, just set the `litellm.json_logs = True`.
We currently just log the raw POST request from litellm as a JSON - [**See Code**].
[Share feedback here](https://github.com/BerriAI/litellm/issues)
## Logger Function ## Logger Function
But sometimes all you care about is seeing exactly what's getting sent to your api call and what's being returned - e.g. if the api call is failing, why is that happening? what are the exact params being set? But sometimes all you care about is seeing exactly what's getting sent to your api call and what's being returned - e.g. if the api call is failing, why is that happening? what are the exact params being set?

View file

@ -320,8 +320,6 @@ from litellm import embedding
litellm.vertex_project = "hardy-device-38811" # Your Project ID litellm.vertex_project = "hardy-device-38811" # Your Project ID
litellm.vertex_location = "us-central1" # proj location litellm.vertex_location = "us-central1" # proj location
os.environ['VOYAGE_API_KEY'] = ""
response = embedding( response = embedding(
model="vertex_ai/textembedding-gecko", model="vertex_ai/textembedding-gecko",
input=["good morning from litellm"], input=["good morning from litellm"],

View file

@ -13,7 +13,7 @@ LiteLLM maps exceptions across all providers to their OpenAI counterparts.
| >=500 | InternalServerError | | >=500 | InternalServerError |
| N/A | ContextWindowExceededError| | N/A | ContextWindowExceededError|
| 400 | ContentPolicyViolationError| | 400 | ContentPolicyViolationError|
| N/A | APIConnectionError | | 500 | APIConnectionError |
Base case we return APIConnectionError Base case we return APIConnectionError
@ -74,6 +74,28 @@ except Exception as e:
``` ```
## Usage - Should you retry exception?
```
import litellm
import openai
try:
response = litellm.completion(
model="gpt-4",
messages=[
{
"role": "user",
"content": "hello, write a 20 pageg essay"
}
],
timeout=0.01, # this will raise a timeout exception
)
except openai.APITimeoutError as e:
should_retry = litellm._should_retry(e.status_code)
print(f"should_retry: {should_retry}")
```
## Details ## Details
To see how it's implemented - [check out the code](https://github.com/BerriAI/litellm/blob/a42c197e5a6de56ea576c73715e6c7c6b19fa249/litellm/utils.py#L1217) To see how it's implemented - [check out the code](https://github.com/BerriAI/litellm/blob/a42c197e5a6de56ea576c73715e6c7c6b19fa249/litellm/utils.py#L1217)
@ -86,21 +108,34 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li
Base case - we return the original exception. Base case - we return the original exception.
| | ContextWindowExceededError | AuthenticationError | InvalidRequestError | RateLimitError | ServiceUnavailableError | | custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
|---------------|----------------------------|---------------------|---------------------|---------------|-------------------------| |----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
| Anthropic | ✅ | ✅ | ✅ | ✅ | | | openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| OpenAI | ✅ | ✅ |✅ |✅ |✅| | text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| Azure OpenAI | ✅ | ✅ |✅ |✅ |✅| | custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| Cohere | ✅ | ✅ | ✅ | ✅ | ✅ | | anthropic | ✓ | ✓ | ✓ | ✓ | | ✓ | | | ✓ | ✓ | |
| Huggingface | ✅ | ✅ | ✅ | ✅ | | | replicate | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | ✓ | | |
| Openrouter | ✅ | ✅ | ✅ | ✅ | | | bedrock | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | ✓ | ✓ | |
| AI21 | ✅ | ✅ | ✅ | ✅ | | | sagemaker | | ✓ | ✓ | | | | | | | | |
| VertexAI | | |✅ | | | | vertex_ai | ✓ | | ✓ | | | | ✓ | | | | ✓ |
| Bedrock | | |✅ | | | | palm | ✓ | ✓ | | | | | ✓ | | | | |
| Sagemaker | | |✅ | | | | gemini | ✓ | ✓ | | | | | ✓ | | | | |
| TogetherAI | ✅ | ✅ | ✅ | ✅ | | | cloudflare | | | ✓ | | | ✓ | | | | | |
| AlephAlpha | ✅ | ✅ | ✅ | ✅ | ✅ | | cohere | | ✓ | ✓ | | | ✓ | | | ✓ | | |
| cohere_chat | | ✓ | ✓ | | | ✓ | | | ✓ | | |
| huggingface | ✓ | ✓ | ✓ | | | ✓ | | ✓ | ✓ | | |
| ai21 | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | | | |
| nlp_cloud | ✓ | ✓ | ✓ | | | ✓ | ✓ | ✓ | ✓ | | |
| together_ai | ✓ | ✓ | ✓ | | | ✓ | | | | | |
| aleph_alpha | | | ✓ | | | ✓ | | | | | |
| ollama | ✓ | | ✓ | | | | | | ✓ | | |
| ollama_chat | ✓ | | ✓ | | | | | | ✓ | | |
| vllm | | | | | | ✓ | ✓ | | | | |
| azure | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | | ✓ | | |
- "✓" indicates that the specified `custom_llm_provider` can raise the corresponding exception.
- Empty cells indicate the lack of association or that the provider does not raise that particular exception type as indicated by the function.
> For a deeper understanding of these exceptions, you can check out [this](https://github.com/BerriAI/litellm/blob/d7e58d13bf9ba9edbab2ab2f096f3de7547f35fa/litellm/utils.py#L1544) implementation for additional insights. > For a deeper understanding of these exceptions, you can check out [this](https://github.com/BerriAI/litellm/blob/d7e58d13bf9ba9edbab2ab2f096f3de7547f35fa/litellm/utils.py#L1544) implementation for additional insights.

View file

@ -213,3 +213,349 @@ asyncio.run(loadtest_fn())
``` ```
## Multi-Instance TPM/RPM Load Test (Router)
Test if your defined tpm/rpm limits are respected across multiple instances of the Router object.
In our test:
- Max RPM per deployment is = 100 requests per minute
- Max Throughput / min on router = 200 requests per minute (2 deployments)
- Load we'll send through router = 600 requests per minute
:::info
If you don't want to call a real LLM API endpoint, you can setup a fake openai server. [See code](#extra---setup-fake-openai-server)
:::
### Code
Let's hit the router with 600 requests per minute.
Copy this script 👇. Save it as `test_loadtest_router.py` AND run it with `python3 test_loadtest_router.py`
```python
from litellm import Router
import litellm
litellm.suppress_debug_info = True
litellm.set_verbose = False
import logging
logging.basicConfig(level=logging.CRITICAL)
import os, random, uuid, time, asyncio
# Model list for OpenAI and Anthropic models
model_list = [
{
"model_name": "fake-openai-endpoint",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "my-fake-key",
"api_base": "http://0.0.0.0:8080",
"rpm": 100
},
},
{
"model_name": "fake-openai-endpoint",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "my-fake-key",
"api_base": "http://0.0.0.0:8081",
"rpm": 100
},
},
]
router_1 = Router(model_list=model_list, num_retries=0, enable_pre_call_checks=True, routing_strategy="usage-based-routing-v2", redis_host=os.getenv("REDIS_HOST"), redis_port=os.getenv("REDIS_PORT"), redis_password=os.getenv("REDIS_PASSWORD"))
router_2 = Router(model_list=model_list, num_retries=0, routing_strategy="usage-based-routing-v2", enable_pre_call_checks=True, redis_host=os.getenv("REDIS_HOST"), redis_port=os.getenv("REDIS_PORT"), redis_password=os.getenv("REDIS_PASSWORD"))
async def router_completion_non_streaming():
try:
client: Router = random.sample([router_1, router_2], 1)[0] # randomly pick b/w clients
# print(f"client={client}")
response = await client.acompletion(
model="fake-openai-endpoint", # [CHANGE THIS] (if you call it something else on your proxy)
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
)
return response
except Exception as e:
# print(e)
return None
async def loadtest_fn():
start = time.time()
n = 600 # Number of concurrent tasks
tasks = [router_completion_non_streaming() for _ in range(n)]
chat_completions = await asyncio.gather(*tasks)
successful_completions = [c for c in chat_completions if c is not None]
print(n, time.time() - start, len(successful_completions))
def get_utc_datetime():
import datetime as dt
from datetime import datetime
if hasattr(dt, "UTC"):
return datetime.now(dt.UTC) # type: ignore
else:
return datetime.utcnow() # type: ignore
# Run the event loop to execute the async function
async def parent_fn():
for _ in range(10):
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
print(f"triggered new batch - {current_minute}")
await loadtest_fn()
await asyncio.sleep(10)
asyncio.run(parent_fn())
```
## Multi-Instance TPM/RPM Load Test (Proxy)
Test if your defined tpm/rpm limits are respected across multiple instances.
The quickest way to do this is by testing the [proxy](./proxy/quick_start.md). The proxy uses the [router](./routing.md) under the hood, so if you're using either of them, this test should work for you.
In our test:
- Max RPM per deployment is = 100 requests per minute
- Max Throughput / min on proxy = 200 requests per minute (2 deployments)
- Load we'll send to proxy = 600 requests per minute
So we'll send 600 requests per minute, but expect only 200 requests per minute to succeed.
:::info
If you don't want to call a real LLM API endpoint, you can setup a fake openai server. [See code](#extra---setup-fake-openai-server)
:::
### 1. Setup config
```yaml
model_list:
- litellm_params:
api_base: http://0.0.0.0:8080
api_key: my-fake-key
model: openai/my-fake-model
rpm: 100
model_name: fake-openai-endpoint
- litellm_params:
api_base: http://0.0.0.0:8081
api_key: my-fake-key
model: openai/my-fake-model-2
rpm: 100
model_name: fake-openai-endpoint
router_settings:
num_retries: 0
enable_pre_call_checks: true
redis_host: os.environ/REDIS_HOST ## 👈 IMPORTANT! Setup the proxy w/ redis
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
routing_strategy: usage-based-routing-v2
```
### 2. Start proxy 2 instances
**Instance 1**
```bash
litellm --config /path/to/config.yaml --port 4000
## RUNNING on http://0.0.0.0:4000
```
**Instance 2**
```bash
litellm --config /path/to/config.yaml --port 4001
## RUNNING on http://0.0.0.0:4001
```
### 3. Run Test
Let's hit the proxy with 600 requests per minute.
Copy this script 👇. Save it as `test_loadtest_proxy.py` AND run it with `python3 test_loadtest_proxy.py`
```python
from openai import AsyncOpenAI, AsyncAzureOpenAI
import random, uuid
import time, asyncio, litellm
# import logging
# logging.basicConfig(level=logging.DEBUG)
#### LITELLM PROXY ####
litellm_client = AsyncOpenAI(
api_key="sk-1234", # [CHANGE THIS]
base_url="http://0.0.0.0:4000"
)
litellm_client_2 = AsyncOpenAI(
api_key="sk-1234", # [CHANGE THIS]
base_url="http://0.0.0.0:4001"
)
async def proxy_completion_non_streaming():
try:
client = random.sample([litellm_client, litellm_client_2], 1)[0] # randomly pick b/w clients
# print(f"client={client}")
response = await client.chat.completions.create(
model="fake-openai-endpoint", # [CHANGE THIS] (if you call it something else on your proxy)
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
)
return response
except Exception as e:
# print(e)
return None
async def loadtest_fn():
start = time.time()
n = 600 # Number of concurrent tasks
tasks = [proxy_completion_non_streaming() for _ in range(n)]
chat_completions = await asyncio.gather(*tasks)
successful_completions = [c for c in chat_completions if c is not None]
print(n, time.time() - start, len(successful_completions))
def get_utc_datetime():
import datetime as dt
from datetime import datetime
if hasattr(dt, "UTC"):
return datetime.now(dt.UTC) # type: ignore
else:
return datetime.utcnow() # type: ignore
# Run the event loop to execute the async function
async def parent_fn():
for _ in range(10):
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
print(f"triggered new batch - {current_minute}")
await loadtest_fn()
await asyncio.sleep(10)
asyncio.run(parent_fn())
```
### Extra - Setup Fake OpenAI Server
Let's setup a fake openai server with a RPM limit of 100.
Let's call our file `fake_openai_server.py`.
```
# import sys, os
# sys.path.insert(
# 0, os.path.abspath("../")
# ) # Adds the parent directory to the system path
from fastapi import FastAPI, Request, status, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi import FastAPI, Request, HTTPException, UploadFile, File
import httpx, os, json
from openai import AsyncOpenAI
from typing import Optional
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import PlainTextResponse
class ProxyException(Exception):
# NOTE: DO NOT MODIFY THIS
# This is used to map exactly to OPENAI Exceptions
def __init__(
self,
message: str,
type: str,
param: Optional[str],
code: Optional[int],
):
self.message = message
self.type = type
self.param = param
self.code = code
def to_dict(self) -> dict:
"""Converts the ProxyException instance to a dictionary."""
return {
"message": self.message,
"type": self.type,
"param": self.param,
"code": self.code,
}
limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
return JSONResponse(status_code=429,
content={"detail": "Rate Limited!"})
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# for completion
@app.post("/chat/completions")
@app.post("/v1/chat/completions")
@limiter.limit("100/minute")
async def completion(request: Request):
# raise HTTPException(status_code=429, detail="Rate Limited!")
return {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": None,
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "\n\nHello there, how may I assist you today?",
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}
if __name__ == "__main__":
import socket
import uvicorn
port = 8080
while True:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(('0.0.0.0', port))
if result != 0:
print(f"Port {port} is available, starting server...")
break
else:
port += 1
uvicorn.run(app, host="0.0.0.0", port=port)
```
```bash
python3 fake_openai_server.py
```

View file

@ -331,49 +331,25 @@ response = litellm.completion(model="gpt-3.5-turbo", messages=messages, metadata
## Examples ## Examples
### Custom Callback to track costs for Streaming + Non-Streaming ### Custom Callback to track costs for Streaming + Non-Streaming
By default, the response cost is accessible in the logging object via `kwargs["response_cost"]` on success (sync + async)
```python ```python
# Step 1. Write your custom callback function
def track_cost_callback( def track_cost_callback(
kwargs, # kwargs to completion kwargs, # kwargs to completion
completion_response, # response from completion completion_response, # response from completion
start_time, end_time # start/end time start_time, end_time # start/end time
): ):
try: try:
# init logging config response_cost = kwargs["response_cost"] # litellm calculates response cost for you
logging.basicConfig( print("regular response_cost", response_cost)
filename='cost.log',
level=logging.INFO,
format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# check if it has collected an entire stream response
if "complete_streaming_response" in kwargs:
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
completion_response=kwargs["complete_streaming_response"]
input_text = kwargs["messages"]
output_text = completion_response["choices"][0]["message"]["content"]
response_cost = litellm.completion_cost(
model = kwargs["model"],
messages = input_text,
completion=output_text
)
print("streaming response_cost", response_cost)
logging.info(f"Model {kwargs['model']} Cost: ${response_cost:.8f}")
# for non streaming responses
else:
# we pass the completion_response obj
if kwargs["stream"] != True:
response_cost = litellm.completion_cost(completion_response=completion_response)
print("regular response_cost", response_cost)
logging.info(f"Model {completion_response.model} Cost: ${response_cost:.8f}")
except: except:
pass pass
# Assign the custom callback function # Step 2. Assign the custom callback function
litellm.success_callback = [track_cost_callback] litellm.success_callback = [track_cost_callback]
# Step 3. Make litellm.completion call
response = completion( response = completion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[ messages=[

View file

@ -1,4 +1,4 @@
# Greenscale Tutorial # Greenscale - Track LLM Spend and Responsible Usage
[Greenscale](https://greenscale.ai/) is a production monitoring platform for your LLM-powered app that provides you granular key insights into your GenAI spending and responsible usage. Greenscale only captures metadata to minimize the exposure risk of personally identifiable information (PII). [Greenscale](https://greenscale.ai/) is a production monitoring platform for your LLM-powered app that provides you granular key insights into your GenAI spending and responsible usage. Greenscale only captures metadata to minimize the exposure risk of personally identifiable information (PII).

View file

@ -121,10 +121,12 @@ response = completion(
metadata={ metadata={
"generation_name": "ishaan-test-generation", # set langfuse Generation Name "generation_name": "ishaan-test-generation", # set langfuse Generation Name
"generation_id": "gen-id22", # set langfuse Generation ID "generation_id": "gen-id22", # set langfuse Generation ID
"trace_id": "trace-id22", # set langfuse Trace ID
"trace_user_id": "user-id2", # set langfuse Trace User ID "trace_user_id": "user-id2", # set langfuse Trace User ID
"session_id": "session-1", # set langfuse Session ID "session_id": "session-1", # set langfuse Session ID
"tags": ["tag1", "tag2"] # set langfuse Tags "tags": ["tag1", "tag2"] # set langfuse Tags
"trace_id": "trace-id22", # set langfuse Trace ID
### OR ###
"existing_trace_id": "trace-id22", # if generation is continuation of past trace. This prevents default behaviour of setting a trace name
}, },
) )
@ -167,6 +169,9 @@ messages = [
chat(messages) chat(messages)
``` ```
## Redacting Messages, Response Content from Langfuse Logging
Set `litellm.turn_off_message_logging=True` This will prevent the messages and responses from being logged to langfuse, but request metadata will still be logged.
## Troubleshooting & Errors ## Troubleshooting & Errors
### Data not getting logged to Langfuse ? ### Data not getting logged to Langfuse ?

View file

@ -0,0 +1,97 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# OpenMeter - Usage-Based Billing
[OpenMeter](https://openmeter.io/) is an Open Source Usage-Based Billing solution for AI/Cloud applications. It integrates with Stripe for easy billing.
<Image img={require('../../img/openmeter.png')} />
:::info
We want to learn how we can make the callbacks better! Meet the LiteLLM [founders](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version) or
join our [discord](https://discord.gg/wuPM9dRgDw)
:::
## Quick Start
Use just 2 lines of code, to instantly log your responses **across all providers** with OpenMeter
Get your OpenMeter API Key from https://openmeter.cloud/meters
```python
litellm.success_callback = ["openmeter"] # logs cost + usage of successful calls to openmeter
```
<Tabs>
<TabItem value="sdk" label="SDK">
```python
# pip install langfuse
import litellm
import os
# from https://openmeter.cloud
os.environ["OPENMETER_API_ENDPOINT"] = ""
os.environ["OPENMETER_API_KEY"] = ""
# LLM API Keys
os.environ['OPENAI_API_KEY']=""
# set langfuse as a callback, litellm will send the data to langfuse
litellm.success_callback = ["openmeter"]
# openai call
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "Hi 👋 - i'm openai"}
]
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add to Config.yaml
```yaml
model_list:
- litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_key: my-fake-key
model: openai/my-fake-model
model_name: fake-openai-endpoint
litellm_settings:
success_callback: ["openmeter"] # 👈 KEY CHANGE
```
2. Start Proxy
```
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "fake-openai-endpoint",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
</TabItem>
</Tabs>
<Image img={require('../../img/openmeter_img_2.png')} />

View file

@ -40,5 +40,9 @@ response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content
print(response) print(response)
``` ```
## Redacting Messages, Response Content from Sentry Logging
Set `litellm.turn_off_message_logging=True` This will prevent the messages and responses from being logged to sentry, but request metadata will still be logged.
[Let us know](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeature%5D%3A+) if you need any additional options from Sentry. [Let us know](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeature%5D%3A+) if you need any additional options from Sentry.

View file

@ -535,7 +535,8 @@ print(response)
| Model Name | Function Call | | Model Name | Function Call |
|----------------------|---------------------------------------------| |----------------------|---------------------------------------------|
| Titan Embeddings - G1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` | | Titan Embeddings V2 | `embedding(model="bedrock/amazon.titan-embed-text-v2:0", input=input)` |
| Titan Embeddings - V1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` |
| Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` | | Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` |
| Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` | | Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` |

View file

@ -477,6 +477,36 @@ print(response)
| code-gecko@latest| `completion('code-gecko@latest', messages)` | | code-gecko@latest| `completion('code-gecko@latest', messages)` |
## Embedding Models
#### Usage - Embedding
```python
import litellm
from litellm import embedding
litellm.vertex_project = "hardy-device-38811" # Your Project ID
litellm.vertex_location = "us-central1" # proj location
response = embedding(
model="vertex_ai/textembedding-gecko",
input=["good morning from litellm"],
)
print(response)
```
#### Supported Embedding Models
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| textembedding-gecko | `embedding(model="vertex_ai/textembedding-gecko", input)` |
| textembedding-gecko-multilingual | `embedding(model="vertex_ai/textembedding-gecko-multilingual", input)` |
| textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
| textembedding-gecko@001 | `embedding(model="vertex_ai/textembedding-gecko@001", input)` |
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
## Extra ## Extra
### Using `GOOGLE_APPLICATION_CREDENTIALS` ### Using `GOOGLE_APPLICATION_CREDENTIALS`
@ -520,6 +550,12 @@ def load_vertex_ai_credentials():
### Using GCP Service Account ### Using GCP Service Account
:::info
Trying to deploy LiteLLM on Google Cloud Run? Tutorial [here](https://docs.litellm.ai/docs/proxy/deploy#deploy-on-google-cloud-run)
:::
1. Figure out the Service Account bound to the Google Cloud Run service 1. Figure out the Service Account bound to the Google Cloud Run service
<Image img={require('../../img/gcp_acc_1.png')} /> <Image img={require('../../img/gcp_acc_1.png')} />

View file

@ -4,6 +4,13 @@ LiteLLM supports all models on VLLM.
🚀[Code Tutorial](https://github.com/BerriAI/litellm/blob/main/cookbook/VLLM_Model_Testing.ipynb) 🚀[Code Tutorial](https://github.com/BerriAI/litellm/blob/main/cookbook/VLLM_Model_Testing.ipynb)
:::info
To call a HOSTED VLLM Endpoint use [these docs](./openai_compatible.md)
:::
### Quick Start ### Quick Start
``` ```
pip install litellm vllm pip install litellm vllm

View file

@ -1,13 +1,13 @@
# Slack Alerting # 🚨 Alerting
Get alerts for: Get alerts for:
- hanging LLM api calls - Hanging LLM api calls
- failed LLM api calls - Failed LLM api calls
- slow LLM api calls - Slow LLM api calls
- budget Tracking per key/user: - Budget Tracking per key/user:
- When a User/Key crosses their Budget - When a User/Key crosses their Budget
- When a User/Key is 15% away from crossing their Budget - When a User/Key is 15% away from crossing their Budget
- failed db read/writes - Failed db read/writes
## Quick Start ## Quick Start

View file

@ -62,9 +62,11 @@ model_list:
litellm_settings: # module level litellm settings - https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py litellm_settings: # module level litellm settings - https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py
drop_params: True drop_params: True
success_callback: ["langfuse"] # OPTIONAL - if you want to start sending LLM Logs to Langfuse. Make sure to set `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` in your env
general_settings: general_settings:
master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234) master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
alerting: ["slack"] # [OPTIONAL] If you want Slack Alerts for Hanging LLM requests, Slow llm responses, Budget Alerts. Make sure to set `SLACK_WEBHOOK_URL` in your env
``` ```
:::info :::info

View file

@ -11,40 +11,37 @@ You can find the Dockerfile to build litellm proxy [here](https://github.com/Ber
<TabItem value="basic" label="Basic"> <TabItem value="basic" label="Basic">
**Step 1. Create a file called `litellm_config.yaml`** ### Step 1. CREATE config.yaml
Example `litellm_config.yaml` (the `os.environ/` prefix means litellm will read `AZURE_API_BASE` from the env) Example `litellm_config.yaml`
```yaml
model_list:
- model_name: azure-gpt-3.5
litellm_params:
model: azure/<your-azure-model-deployment>
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
```
**Step 2. Run litellm docker image** ```yaml
model_list:
- model_name: azure-gpt-3.5
litellm_params:
model: azure/<your-azure-model-deployment>
api_base: os.environ/AZURE_API_BASE # runs os.getenv("AZURE_API_BASE")
api_key: os.environ/AZURE_API_KEY # runs os.getenv("AZURE_API_KEY")
api_version: "2023-07-01-preview"
```
See the latest available ghcr docker image here:
https://github.com/berriai/litellm/pkgs/container/litellm
Your litellm config.yaml should be called `litellm_config.yaml` in the directory you run this command.
The `-v` command will mount that file
Pass `AZURE_API_KEY` and `AZURE_API_BASE` since we set them in step 1 ### Step 2. RUN Docker Image
```shell ```shell
docker run \ docker run \
-v $(pwd)/litellm_config.yaml:/app/config.yaml \ -v $(pwd)/litellm_config.yaml:/app/config.yaml \
-e AZURE_API_KEY=d6*********** \ -e AZURE_API_KEY=d6*********** \
-e AZURE_API_BASE=https://openai-***********/ \ -e AZURE_API_BASE=https://openai-***********/ \
-p 4000:4000 \ -p 4000:4000 \
ghcr.io/berriai/litellm:main-latest \ ghcr.io/berriai/litellm:main-latest \
--config /app/config.yaml --detailed_debug --config /app/config.yaml --detailed_debug
``` ```
**Step 3. Send a Test Request** Get Latest Image 👉 [here](https://github.com/berriai/litellm/pkgs/container/litellm)
### Step 3. TEST Request
Pass `model=azure-gpt-3.5` this was set on step 1 Pass `model=azure-gpt-3.5` this was set on step 1
@ -272,26 +269,63 @@ Your OpenAI proxy server is now running on `http://0.0.0.0:4000`.
#### Step 1. Create deployment.yaml #### Step 1. Create deployment.yaml
```yaml ```yaml
apiVersion: apps/v1 apiVersion: apps/v1
kind: Deployment kind: Deployment
metadata: metadata:
name: litellm-deployment name: litellm-deployment
spec: spec:
replicas: 1 replicas: 3
selector: selector:
matchLabels: matchLabels:
app: litellm app: litellm
template: template:
metadata: metadata:
labels: labels:
app: litellm app: litellm
spec: spec:
containers: containers:
- name: litellm-container - name: litellm-container
image: ghcr.io/berriai/litellm-database:main-latest image: ghcr.io/berriai/litellm:main-latest
env: imagePullPolicy: Always
- name: DATABASE_URL env:
value: postgresql://<user>:<password>@<host>:<port>/<dbname> - name: AZURE_API_KEY
value: "d6******"
- name: AZURE_API_BASE
value: "https://ope******"
- name: LITELLM_MASTER_KEY
value: "sk-1234"
- name: DATABASE_URL
value: "po**********"
args:
- "--config"
- "/app/proxy_config.yaml" # Update the path to mount the config file
volumeMounts: # Define volume mount for proxy_config.yaml
- name: config-volume
mountPath: /app
readOnly: true
livenessProbe:
httpGet:
path: /health/liveliness
port: 4000
initialDelaySeconds: 120
periodSeconds: 15
successThreshold: 1
failureThreshold: 3
timeoutSeconds: 10
readinessProbe:
httpGet:
path: /health/readiness
port: 4000
initialDelaySeconds: 120
periodSeconds: 15
successThreshold: 1
failureThreshold: 3
timeoutSeconds: 10
volumes: # Define volume to mount proxy_config.yaml
- name: config-volume
configMap:
name: litellm-config
``` ```
```bash ```bash

View file

@ -10,6 +10,7 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
- [Async Custom Callbacks](#custom-callback-class-async) - [Async Custom Callbacks](#custom-callback-class-async)
- [Async Custom Callback APIs](#custom-callback-apis-async) - [Async Custom Callback APIs](#custom-callback-apis-async)
- [Logging to Langfuse](#logging-proxy-inputoutput---langfuse) - [Logging to Langfuse](#logging-proxy-inputoutput---langfuse)
- [Logging to OpenMeter](#logging-proxy-inputoutput---langfuse)
- [Logging to s3 Buckets](#logging-proxy-inputoutput---s3-buckets) - [Logging to s3 Buckets](#logging-proxy-inputoutput---s3-buckets)
- [Logging to DataDog](#logging-proxy-inputoutput---datadog) - [Logging to DataDog](#logging-proxy-inputoutput---datadog)
- [Logging to DynamoDB](#logging-proxy-inputoutput---dynamodb) - [Logging to DynamoDB](#logging-proxy-inputoutput---dynamodb)
@ -401,7 +402,7 @@ litellm_settings:
Start the LiteLLM Proxy and make a test request to verify the logs reached your callback API Start the LiteLLM Proxy and make a test request to verify the logs reached your callback API
## Logging Proxy Input/Output - Langfuse ## Logging Proxy Input/Output - Langfuse
We will use the `--config` to set `litellm.success_callback = ["langfuse"]` this will log all successfull LLM calls to langfuse We will use the `--config` to set `litellm.success_callback = ["langfuse"]` this will log all successfull LLM calls to langfuse. Make sure to set `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` in your environment
**Step 1** Install langfuse **Step 1** Install langfuse
@ -419,7 +420,13 @@ litellm_settings:
success_callback: ["langfuse"] success_callback: ["langfuse"]
``` ```
**Step 3**: Start the proxy, make a test request **Step 3**: Set required env variables for logging to langfuse
```shell
export LANGFUSE_PUBLIC_KEY="pk_kk"
export LANGFUSE_SECRET_KEY="sk_ss
```
**Step 4**: Start the proxy, make a test request
Start proxy Start proxy
```shell ```shell
@ -569,6 +576,75 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
All requests made with these keys will log data to their team-specific logging. All requests made with these keys will log data to their team-specific logging.
### Redacting Messages, Response Content from Langfuse Logging
Set `litellm.turn_off_message_logging=True` This will prevent the messages and responses from being logged to langfuse, but request metadata will still be logged.
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
success_callback: ["langfuse"]
turn_off_message_logging: True
```
## Logging Proxy Cost + Usage - OpenMeter
Bill customers according to their LLM API usage with [OpenMeter](../observability/openmeter.md)
**Required Env Variables**
```bash
# from https://openmeter.cloud
export OPENMETER_API_ENDPOINT="" # defaults to https://openmeter.cloud
export OPENMETER_API_KEY=""
```
### Quick Start
1. Add to Config.yaml
```yaml
model_list:
- litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_key: my-fake-key
model: openai/my-fake-model
model_name: fake-openai-endpoint
litellm_settings:
success_callback: ["openmeter"] # 👈 KEY CHANGE
```
2. Start Proxy
```
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "fake-openai-endpoint",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
<Image img={require('../../img/openmeter_img_2.png')} />
## Logging Proxy Input/Output - DataDog ## Logging Proxy Input/Output - DataDog
We will use the `--config` to set `litellm.success_callback = ["datadog"]` this will log all successfull LLM calls to DataDog We will use the `--config` to set `litellm.success_callback = ["datadog"]` this will log all successfull LLM calls to DataDog
@ -838,39 +914,72 @@ Test Request
litellm --test litellm --test
``` ```
## Logging Proxy Input/Output Traceloop (OpenTelemetry) ## Logging Proxy Input/Output in OpenTelemetry format using Traceloop's OpenLLMetry
Traceloop allows you to log LLM Input/Output in the OpenTelemetry format [OpenLLMetry](https://github.com/traceloop/openllmetry) _(built and maintained by Traceloop)_ is a set of extensions
built on top of [OpenTelemetry](https://opentelemetry.io/) that gives you complete observability over your LLM
application. Because it uses OpenTelemetry under the
hood, [it can be connected to various observability solutions](https://www.traceloop.com/docs/openllmetry/integrations/introduction)
like:
We will use the `--config` to set `litellm.success_callback = ["traceloop"]` this will log all successfull LLM calls to traceloop * [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop)
* [Axiom](https://www.traceloop.com/docs/openllmetry/integrations/axiom)
* [Azure Application Insights](https://www.traceloop.com/docs/openllmetry/integrations/azure)
* [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog)
* [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace)
* [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana)
* [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb)
* [HyperDX](https://www.traceloop.com/docs/openllmetry/integrations/hyperdx)
* [Instana](https://www.traceloop.com/docs/openllmetry/integrations/instana)
* [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic)
* [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
* [Service Now Cloud Observability](https://www.traceloop.com/docs/openllmetry/integrations/service-now)
* [Sentry](https://www.traceloop.com/docs/openllmetry/integrations/sentry)
* [SigNoz](https://www.traceloop.com/docs/openllmetry/integrations/signoz)
* [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk)
**Step 1** Install traceloop-sdk and set Traceloop API key We will use the `--config` to set `litellm.success_callback = ["traceloop"]` to achieve this, steps are listed below.
**Step 1:** Install the SDK
```shell ```shell
pip install traceloop-sdk -U pip install traceloop-sdk
``` ```
Traceloop outputs standard OpenTelemetry data that can be connected to your observability stack. Send standard OpenTelemetry from LiteLLM Proxy to [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop), [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace), [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog) **Step 2:** Configure Environment Variable for trace exporting
, [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic), [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb), [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana), [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk), [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
You will need to configure where to export your traces. Environment variables will control this, example: For Traceloop
you should use `TRACELOOP_API_KEY`, whereas for Datadog you use `TRACELOOP_BASE_URL`. For more
visit [the Integrations Catalog](https://www.traceloop.com/docs/openllmetry/integrations/introduction).
If you are using Datadog as the observability solutions then you can set `TRACELOOP_BASE_URL` as:
```shell
TRACELOOP_BASE_URL=http://<datadog-agent-hostname>:4318
```
**Step 3**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
```yaml ```yaml
model_list: model_list:
- model_name: gpt-3.5-turbo - model_name: gpt-3.5-turbo
litellm_params: litellm_params:
model: gpt-3.5-turbo model: gpt-3.5-turbo
api_key: my-fake-key # replace api_key with actual key
litellm_settings: litellm_settings:
success_callback: ["traceloop"] success_callback: [ "traceloop" ]
``` ```
**Step 3**: Start the proxy, make a test request **Step 4**: Start the proxy, make a test request
Start proxy Start proxy
```shell ```shell
litellm --config config.yaml --debug litellm --config config.yaml --debug
``` ```
Test Request Test Request
``` ```
curl --location 'http://0.0.0.0:4000/chat/completions' \ curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \ --header 'Content-Type: application/json' \
@ -927,4 +1036,4 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
} }
] ]
}' }'
``` ```

View file

@ -3,34 +3,38 @@ import TabItem from '@theme/TabItem';
# ⚡ Best Practices for Production # ⚡ Best Practices for Production
Expected Performance in Production ## 1. Use this config.yaml
Use this config.yaml in production (with your own LLMs)
1 LiteLLM Uvicorn Worker on Kubernetes
| Description | Value |
|--------------|-------|
| Avg latency | `50ms` |
| Median latency | `51ms` |
| `/chat/completions` Requests/second | `35` |
| `/chat/completions` Requests/minute | `2100` |
| `/chat/completions` Requests/hour | `126K` |
## 1. Switch off Debug Logging
Remove `set_verbose: True` from your config.yaml
```yaml ```yaml
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
master_key: sk-1234 # enter your own master key, ensure it starts with 'sk-'
alerting: ["slack"] # Setup slack alerting - get alerts on LLM exceptions, Budget Alerts, Slow LLM Responses
proxy_batch_write_at: 60 # Batch write spend updates every 60s
litellm_settings: litellm_settings:
set_verbose: True set_verbose: False # Switch off Debug Logging, ensure your logs do not have any debugging on
``` ```
You should only see the following level of details in logs on the proxy server Set slack webhook url in your env
```shell ```shell
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK export SLACK_WEBHOOK_URL="https://hooks.slack.com/services/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"
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
``` ```
:::info
Need Help or want dedicated support ? Talk to a founder [here]: (https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
## 2. On Kubernetes - Use 1 Uvicorn worker [Suggested CMD] ## 2. On Kubernetes - Use 1 Uvicorn worker [Suggested CMD]
Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
@ -40,21 +44,12 @@ Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"] CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
``` ```
## 3. Batch write spend updates every 60s
The default proxy batch write is 10s. This is to make it easy to see spend when debugging locally. ## 3. Use Redis 'port','host', 'password'. NOT 'redis_url'
In production, we recommend using a longer interval period of 60s. This reduces the number of connections used to make DB writes. If you decide to use Redis, DO NOT use 'redis_url'. We recommend usig redis port, host, and password params.
```yaml `redis_url`is 80 RPS slower
general_settings:
master_key: sk-1234
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
```
## 4. use Redis 'port','host', 'password'. NOT 'redis_url'
When connecting to Redis use redis port, host, and password params. Not 'redis_url'. We've seen a 80 RPS difference between these 2 approaches when using the async redis client.
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188) This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
@ -69,103 +64,31 @@ router_settings:
redis_password: os.environ/REDIS_PASSWORD redis_password: os.environ/REDIS_PASSWORD
``` ```
## 5. Switch off resetting budgets ## Extras
### Expected Performance in Production
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database) 1 LiteLLM Uvicorn Worker on Kubernetes
```yaml
general_settings:
disable_reset_budget: true
```
## 6. Move spend logs to separate server (BETA) | Description | Value |
|--------------|-------|
Writing each spend log to the db can slow down your proxy. In testing we saw a 70% improvement in median response time, by moving writing spend logs to a separate server. | Avg latency | `50ms` |
| Median latency | `51ms` |
👉 [LiteLLM Spend Logs Server](https://github.com/BerriAI/litellm/tree/main/litellm-js/spend-logs) | `/chat/completions` Requests/second | `35` |
| `/chat/completions` Requests/minute | `2100` |
| `/chat/completions` Requests/hour | `126K` |
**Spend Logs** ### Verifying Debugging logs are off
This is a log of the key, tokens, model, and latency for each call on the proxy.
[**Full Payload**](https://github.com/BerriAI/litellm/blob/8c9623a6bc4ad9da0a2dac64249a60ed8da719e8/litellm/proxy/utils.py#L1769) You should only see the following level of details in logs on the proxy server
```shell
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
**1. Start the spend logs server** # INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
```bash
docker run -p 3000:3000 \
-e DATABASE_URL="postgres://.." \
ghcr.io/berriai/litellm-spend_logs:main-latest
# RUNNING on http://0.0.0.0:3000
```
**2. Connect to proxy**
Example litellm_config.yaml
```yaml
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
master_key: sk-1234
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
```
Add `SPEND_LOGS_URL` as an environment variable when starting the proxy
```bash
docker run \
-v $(pwd)/litellm_config.yaml:/app/config.yaml \
-e DATABASE_URL="postgresql://.." \
-e SPEND_LOGS_URL="http://host.docker.internal:3000" \ # 👈 KEY CHANGE
-p 4000:4000 \
ghcr.io/berriai/litellm:main-latest \
--config /app/config.yaml --detailed_debug
# Running on http://0.0.0.0:4000
```
**3. Test Proxy!**
```bash
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
"model": "fake-openai-endpoint",
"messages": [
{"role": "system", "content": "Be helpful"},
{"role": "user", "content": "What do you know?"}
]
}'
```
In your LiteLLM Spend Logs Server, you should see
**Expected Response**
```
Received and stored 1 logs. Total logs in memory: 1
...
Flushed 1 log to the DB.
``` ```
### Machine Specification ### Machine Specifications to Deploy LiteLLM
A t2.micro should be sufficient to handle 1k logs / minute on this server.
This consumes at max 120MB, and <0.1 vCPU.
## Machine Specifications to Deploy LiteLLM
| Service | Spec | CPUs | Memory | Architecture | Version| | Service | Spec | CPUs | Memory | Architecture | Version|
| --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- |
@ -173,7 +96,7 @@ This consumes at max 120MB, and <0.1 vCPU.
| Redis Cache | - | - | - | - | 7.0+ Redis Engine| | Redis Cache | - | - | - | - | 7.0+ Redis Engine|
## Reference Kubernetes Deployment YAML ### Reference Kubernetes Deployment YAML
Reference Kubernetes `deployment.yaml` that was load tested by us Reference Kubernetes `deployment.yaml` that was load tested by us

View file

@ -278,6 +278,36 @@ router_settings:
routing_strategy_args: {"ttl": 10} routing_strategy_args: {"ttl": 10}
``` ```
### Set Lowest Latency Buffer
Set a buffer within which deployments are candidates for making calls to.
E.g.
if you have 5 deployments
```
https://litellm-prod-1.openai.azure.com/: 0.07s
https://litellm-prod-2.openai.azure.com/: 0.1s
https://litellm-prod-3.openai.azure.com/: 0.1s
https://litellm-prod-4.openai.azure.com/: 0.1s
https://litellm-prod-5.openai.azure.com/: 4.66s
```
to prevent initially overloading `prod-1`, with all requests - we can set a buffer of 50%, to consider deployments `prod-2, prod-3, prod-4`.
**In Router**
```python
router = Router(..., routing_strategy_args={"lowest_latency_buffer": 0.5})
```
**In Proxy**
```yaml
router_settings:
routing_strategy_args: {"lowest_latency_buffer": 0.5}
```
</TabItem> </TabItem>
<TabItem value="simple-shuffle" label="(Default) Weighted Pick (Async)"> <TabItem value="simple-shuffle" label="(Default) Weighted Pick (Async)">
@ -443,6 +473,35 @@ asyncio.run(router_acompletion())
## Basic Reliability ## Basic Reliability
### Max Parallel Requests (ASYNC)
Used in semaphore for async requests on router. Limit the max concurrent calls made to a deployment. Useful in high-traffic scenarios.
If tpm/rpm is set, and no max parallel request limit given, we use the RPM or calculated RPM (tpm/1000/6) as the max parallel request limit.
```python
from litellm import Router
model_list = [{
"model_name": "gpt-4",
"litellm_params": {
"model": "azure/gpt-4",
...
"max_parallel_requests": 10 # 👈 SET PER DEPLOYMENT
}
}]
### OR ###
router = Router(model_list=model_list, default_max_parallel_requests=20) # 👈 SET DEFAULT MAX PARALLEL REQUESTS
# deployment max parallel requests > default max parallel requests
```
[**See Code**](https://github.com/BerriAI/litellm/blob/a978f2d8813c04dad34802cb95e0a0e35a3324bc/litellm/utils.py#L5605)
### Timeouts ### Timeouts
The timeout set in router is for the entire length of the call, and is passed down to the completion() call level as well. The timeout set in router is for the entire length of the call, and is passed down to the completion() call level as well.
@ -557,6 +616,57 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
print(f"response: {response}") print(f"response: {response}")
``` ```
#### Retries based on Error Type
Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved
Example:
- 4 retries for `ContentPolicyViolationError`
- 0 retries for `RateLimitErrors`
Example Usage
```python
from litellm.router import RetryPolicy
retry_policy = RetryPolicy(
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
BadRequestErrorRetries=1,
TimeoutErrorRetries=2,
RateLimitErrorRetries=3,
)
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "bad-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
],
retry_policy=retry_policy,
)
response = await router.acompletion(
model=model,
messages=messages,
)
```
### Fallbacks ### Fallbacks
If a call fails after num_retries, fall back to another model group. If a call fails after num_retries, fall back to another model group.

View file

@ -5,6 +5,9 @@ LiteLLM allows you to specify the following:
* API Base * API Base
* API Version * API Version
* API Type * API Type
* Project
* Location
* Token
Useful Helper functions: Useful Helper functions:
* [`check_valid_key()`](#check_valid_key) * [`check_valid_key()`](#check_valid_key)
@ -43,6 +46,24 @@ os.environ['AZURE_API_TYPE'] = "azure" # [OPTIONAL]
os.environ['OPENAI_API_BASE'] = "https://openai-gpt-4-test2-v-12.openai.azure.com/" os.environ['OPENAI_API_BASE'] = "https://openai-gpt-4-test2-v-12.openai.azure.com/"
``` ```
### Setting Project, Location, Token
For cloud providers:
- Azure
- Bedrock
- GCP
- Watson AI
you might need to set additional parameters. LiteLLM provides a common set of params, that we map across all providers.
| | LiteLLM param | Watson | Vertex AI | Azure | Bedrock |
|------|--------------|--------------|--------------|--------------|--------------|
| Project | project | watsonx_project | vertex_project | n/a | n/a |
| Region | region_name | watsonx_region_name | vertex_location | n/a | aws_region_name |
| Token | token | watsonx_token or token | n/a | azure_ad_token | n/a |
If you want, you can call them by their provider-specific params as well.
## litellm variables ## litellm variables
### litellm.api_key ### litellm.api_key

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 533 KiB

View file

@ -43,6 +43,12 @@ const sidebars = {
"proxy/user_keys", "proxy/user_keys",
"proxy/enterprise", "proxy/enterprise",
"proxy/virtual_keys", "proxy/virtual_keys",
"proxy/alerting",
{
type: "category",
label: "Logging",
items: ["proxy/logging", "proxy/streaming_logging"],
},
"proxy/team_based_routing", "proxy/team_based_routing",
"proxy/ui", "proxy/ui",
"proxy/cost_tracking", "proxy/cost_tracking",
@ -58,11 +64,6 @@ const sidebars = {
"proxy/pii_masking", "proxy/pii_masking",
"proxy/prompt_injection", "proxy/prompt_injection",
"proxy/caching", "proxy/caching",
{
type: "category",
label: "Logging, Alerting",
items: ["proxy/logging", "proxy/alerting", "proxy/streaming_logging"],
},
"proxy/prometheus", "proxy/prometheus",
"proxy/call_hooks", "proxy/call_hooks",
"proxy/rules", "proxy/rules",
@ -169,6 +170,7 @@ const sidebars = {
"observability/custom_callback", "observability/custom_callback",
"observability/langfuse_integration", "observability/langfuse_integration",
"observability/sentry", "observability/sentry",
"observability/openmeter",
"observability/promptlayer_integration", "observability/promptlayer_integration",
"observability/wandb_integration", "observability/wandb_integration",
"observability/langsmith_integration", "observability/langsmith_integration",
@ -176,7 +178,7 @@ const sidebars = {
"observability/traceloop_integration", "observability/traceloop_integration",
"observability/athina_integration", "observability/athina_integration",
"observability/lunary_integration", "observability/lunary_integration",
"observability/athina_integration", "observability/greenscale_integration",
"observability/helicone_integration", "observability/helicone_integration",
"observability/supabase_integration", "observability/supabase_integration",
`observability/telemetry`, `observability/telemetry`,

View file

@ -5,7 +5,7 @@
"packages": { "packages": {
"": { "": {
"dependencies": { "dependencies": {
"@hono/node-server": "^1.9.0", "@hono/node-server": "^1.10.1",
"hono": "^4.2.7" "hono": "^4.2.7"
}, },
"devDependencies": { "devDependencies": {
@ -382,9 +382,9 @@
} }
}, },
"node_modules/@hono/node-server": { "node_modules/@hono/node-server": {
"version": "1.9.0", "version": "1.10.1",
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.9.0.tgz", "resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.10.1.tgz",
"integrity": "sha512-oJjk7WXBlENeHhWiMqSyxPIZ3Kmf5ZYxqdlcSIXyN8Rn50bNJsPl99G4POBS03Jxh56FdfRJ0SEnC8mAVIiavQ==", "integrity": "sha512-5BKW25JH5PQKPDkTcIgv3yNUPtOAbnnjFFgWvIxxAY/B/ZNeYjjWoAeDmqhIiCgOAJ3Tauuw+0G+VainhuZRYQ==",
"engines": { "engines": {
"node": ">=18.14.1" "node": ">=18.14.1"
} }

View file

@ -3,7 +3,7 @@
"dev": "tsx watch src/index.ts" "dev": "tsx watch src/index.ts"
}, },
"dependencies": { "dependencies": {
"@hono/node-server": "^1.9.0", "@hono/node-server": "^1.10.1",
"hono": "^4.2.7" "hono": "^4.2.7"
}, },
"devDependencies": { "devDependencies": {

View file

@ -2,7 +2,7 @@
import threading, requests, os import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal from typing import Callable, List, Optional, Dict, Union, Any, Literal
from litellm.caching import Cache from litellm.caching import Cache
from litellm._logging import set_verbose, _turn_on_debug, verbose_logger from litellm._logging import set_verbose, _turn_on_debug, verbose_logger, json_logs
from litellm.proxy._types import ( from litellm.proxy._types import (
KeyManagementSystem, KeyManagementSystem,
KeyManagementSettings, KeyManagementSettings,
@ -22,6 +22,7 @@ success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = [] service_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = [] callbacks: List[Callable] = []
_custom_logger_compatible_callbacks: list = ["openmeter"]
_langfuse_default_tags: Optional[ _langfuse_default_tags: Optional[
List[ List[
Literal[ Literal[
@ -45,6 +46,7 @@ _async_failure_callback: List[Callable] = (
) # internal variable - async custom callbacks are routed here. ) # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = [] pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = [] post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
## end of callbacks ############# ## end of callbacks #############
email: Optional[str] = ( email: Optional[str] = (
@ -58,6 +60,7 @@ max_tokens = 256 # OpenAI Defaults
drop_params = False drop_params = False
modify_params = False modify_params = False
retry = True retry = True
### AUTH ###
api_key: Optional[str] = None api_key: Optional[str] = None
openai_key: Optional[str] = None openai_key: Optional[str] = None
azure_key: Optional[str] = None azure_key: Optional[str] = None
@ -76,6 +79,10 @@ cloudflare_api_key: Optional[str] = None
baseten_key: Optional[str] = None baseten_key: Optional[str] = None
aleph_alpha_key: Optional[str] = None aleph_alpha_key: Optional[str] = None
nlp_cloud_key: Optional[str] = None nlp_cloud_key: Optional[str] = None
common_cloud_provider_auth_params: dict = {
"params": ["project", "region_name", "token"],
"providers": ["vertex_ai", "bedrock", "watsonx", "azure"],
}
use_client: bool = False use_client: bool = False
ssl_verify: bool = True ssl_verify: bool = True
disable_streaming_logging: bool = False disable_streaming_logging: bool = False
@ -535,7 +542,11 @@ models_by_provider: dict = {
"together_ai": together_ai_models, "together_ai": together_ai_models,
"baseten": baseten_models, "baseten": baseten_models,
"openrouter": openrouter_models, "openrouter": openrouter_models,
"vertex_ai": vertex_chat_models + vertex_text_models, "vertex_ai": vertex_chat_models
+ vertex_text_models
+ vertex_anthropic_models
+ vertex_vision_models
+ vertex_language_models,
"ai21": ai21_models, "ai21": ai21_models,
"bedrock": bedrock_models, "bedrock": bedrock_models,
"petals": petals_models, "petals": petals_models,
@ -594,7 +605,6 @@ all_embedding_models = (
####### IMAGE GENERATION MODELS ################### ####### IMAGE GENERATION MODELS ###################
openai_image_generation_models = ["dall-e-2", "dall-e-3"] openai_image_generation_models = ["dall-e-2", "dall-e-3"]
from .timeout import timeout from .timeout import timeout
from .utils import ( from .utils import (
client, client,
@ -602,6 +612,8 @@ from .utils import (
get_optional_params, get_optional_params,
modify_integration, modify_integration,
token_counter, token_counter,
create_pretrained_tokenizer,
create_tokenizer,
cost_per_token, cost_per_token,
completion_cost, completion_cost,
supports_function_calling, supports_function_calling,
@ -625,6 +637,7 @@ from .utils import (
get_secret, get_secret,
get_supported_openai_params, get_supported_openai_params,
get_api_base, get_api_base,
get_first_chars_messages,
) )
from .llms.huggingface_restapi import HuggingfaceConfig from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig from .llms.anthropic import AnthropicConfig
@ -654,6 +667,7 @@ from .llms.bedrock import (
AmazonLlamaConfig, AmazonLlamaConfig,
AmazonStabilityConfig, AmazonStabilityConfig,
AmazonMistralConfig, AmazonMistralConfig,
AmazonBedrockGlobalConfig,
) )
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
@ -680,3 +694,4 @@ from .exceptions import (
from .budget_manager import BudgetManager from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server from .proxy.proxy_cli import run_server
from .router import Router from .router import Router
from .assistants.main import *

View file

@ -1,7 +1,7 @@
import logging import logging
set_verbose = False set_verbose = False
json_logs = False
# Create a handler for the logger (you may need to adapt this based on your needs) # Create a handler for the logger (you may need to adapt this based on your needs)
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG) handler.setLevel(logging.DEBUG)

495
litellm/assistants/main.py Normal file
View file

@ -0,0 +1,495 @@
# What is this?
## Main file for assistants API logic
from typing import Iterable
import os
import litellm
from openai import OpenAI
from litellm import client
from litellm.utils import supports_httpx_timeout
from ..llms.openai import OpenAIAssistantsAPI
from ..types.llms.openai import *
from ..types.router import *
####### ENVIRONMENT VARIABLES ###################
openai_assistants_api = OpenAIAssistantsAPI()
### ASSISTANTS ###
def get_assistants(
custom_llm_provider: Literal["openai"],
client: Optional[OpenAI] = None,
**kwargs,
) -> SyncCursorPage[Assistant]:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[SyncCursorPage[Assistant]] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.get_assistants(
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
### THREADS ###
def create_thread(
custom_llm_provider: Literal["openai"],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
metadata: Optional[dict] = None,
tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None,
client: Optional[OpenAI] = None,
**kwargs,
) -> Thread:
"""
- get the llm provider
- if openai - route it there
- pass through relevant params
```
from litellm import create_thread
create_thread(
custom_llm_provider="openai",
### OPTIONAL ###
messages = {
"role": "user",
"content": "Hello, what is AI?"
},
{
"role": "user",
"content": "How does AI work? Explain it in simple terms."
}]
)
```
"""
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[Thread] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.create_thread(
messages=messages,
metadata=metadata,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
def get_thread(
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[OpenAI] = None,
**kwargs,
) -> Thread:
"""Get the thread object, given a thread_id"""
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[Thread] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.get_thread(
thread_id=thread_id,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
### MESSAGES ###
def add_message(
custom_llm_provider: Literal["openai"],
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
client: Optional[OpenAI] = None,
**kwargs,
) -> OpenAIMessage:
### COMMON OBJECTS ###
message_data = MessageData(
role=role, content=content, attachments=attachments, metadata=metadata
)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[OpenAIMessage] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.add_message(
thread_id=thread_id,
message_data=message_data,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
def get_messages(
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[OpenAI] = None,
**kwargs,
) -> SyncCursorPage[OpenAIMessage]:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[SyncCursorPage[OpenAIMessage]] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.get_messages(
thread_id=thread_id,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
### RUNS ###
def run_thread(
custom_llm_provider: Literal["openai"],
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[OpenAI] = None,
**kwargs,
) -> Run:
"""Run a given thread + assistant."""
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[Run] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.run_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response

View file

@ -177,11 +177,18 @@ class RedisCache(BaseCache):
try: try:
# asyncio.get_running_loop().create_task(self.ping()) # asyncio.get_running_loop().create_task(self.ping())
result = asyncio.get_running_loop().create_task(self.ping()) result = asyncio.get_running_loop().create_task(self.ping())
except Exception: except Exception as e:
pass verbose_logger.error(
"Error connecting to Async Redis client", extra={"error": str(e)}
)
### SYNC HEALTH PING ### ### SYNC HEALTH PING ###
self.redis_client.ping() try:
self.redis_client.ping()
except Exception as e:
verbose_logger.error(
"Error connecting to Sync Redis client", extra={"error": str(e)}
)
def init_async_client(self): def init_async_client(self):
from ._redis import get_redis_async_client from ._redis import get_redis_async_client

View file

@ -12,9 +12,12 @@ import litellm
class LangFuseLogger: class LangFuseLogger:
# Class variables or attributes # Class variables or attributes
def __init__(self, langfuse_public_key=None, langfuse_secret=None): def __init__(
self, langfuse_public_key=None, langfuse_secret=None, flush_interval=1
):
try: try:
from langfuse import Langfuse from langfuse import Langfuse
import langfuse
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m" f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
@ -25,14 +28,20 @@ class LangFuseLogger:
self.langfuse_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com") self.langfuse_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
self.langfuse_release = os.getenv("LANGFUSE_RELEASE") self.langfuse_release = os.getenv("LANGFUSE_RELEASE")
self.langfuse_debug = os.getenv("LANGFUSE_DEBUG") self.langfuse_debug = os.getenv("LANGFUSE_DEBUG")
self.Langfuse = Langfuse(
public_key=self.public_key, parameters = {
secret_key=self.secret_key, "public_key": self.public_key,
host=self.langfuse_host, "secret_key": self.secret_key,
release=self.langfuse_release, "host": self.langfuse_host,
debug=self.langfuse_debug, "release": self.langfuse_release,
flush_interval=1, # flush interval in seconds "debug": self.langfuse_debug,
) "flush_interval": flush_interval, # flush interval in seconds
}
if Version(langfuse.version.__version__) >= Version("2.6.0"):
parameters["sdk_integration"] = "litellm"
self.Langfuse = Langfuse(**parameters)
# set the current langfuse project id in the environ # set the current langfuse project id in the environ
# this is used by Alerting to link to the correct project # this is used by Alerting to link to the correct project
@ -77,7 +86,7 @@ class LangFuseLogger:
print_verbose, print_verbose,
level="DEFAULT", level="DEFAULT",
status_message=None, status_message=None,
): ) -> dict:
# Method definition # Method definition
try: try:
@ -138,8 +147,10 @@ class LangFuseLogger:
input = prompt input = prompt
output = response_obj["data"] output = response_obj["data"]
print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}") print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
trace_id = None
generation_id = None
if self._is_langfuse_v2(): if self._is_langfuse_v2():
self._log_langfuse_v2( trace_id, generation_id = self._log_langfuse_v2(
user_id, user_id,
metadata, metadata,
litellm_params, litellm_params,
@ -169,10 +180,12 @@ class LangFuseLogger:
f"Langfuse Layer Logging - final response object: {response_obj}" f"Langfuse Layer Logging - final response object: {response_obj}"
) )
verbose_logger.info(f"Langfuse Layer Logging - logging success") verbose_logger.info(f"Langfuse Layer Logging - logging success")
return {"trace_id": trace_id, "generation_id": generation_id}
except: except:
traceback.print_exc() traceback.print_exc()
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}") verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
pass return {"trace_id": None, "generation_id": None}
async def _async_log_event( async def _async_log_event(
self, kwargs, response_obj, start_time, end_time, user_id, print_verbose self, kwargs, response_obj, start_time, end_time, user_id, print_verbose
@ -244,7 +257,7 @@ class LangFuseLogger:
response_obj, response_obj,
level, level,
print_verbose, print_verbose,
): ) -> tuple:
import langfuse import langfuse
try: try:
@ -263,22 +276,28 @@ class LangFuseLogger:
tags = metadata_tags tags = metadata_tags
trace_name = metadata.get("trace_name", None) trace_name = metadata.get("trace_name", None)
if trace_name is None: trace_id = metadata.get("trace_id", None)
existing_trace_id = metadata.get("existing_trace_id", None)
if trace_name is None and existing_trace_id is None:
# just log `litellm-{call_type}` as the trace name # just log `litellm-{call_type}` as the trace name
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}" trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
trace_params = { if existing_trace_id is not None:
"name": trace_name, trace_params = {"id": existing_trace_id}
"input": input, else: # don't overwrite an existing trace
"user_id": metadata.get("trace_user_id", user_id), trace_params = {
"id": metadata.get("trace_id", None), "name": trace_name,
"session_id": metadata.get("session_id", None), "input": input,
} "user_id": metadata.get("trace_user_id", user_id),
"id": trace_id,
"session_id": metadata.get("session_id", None),
}
if level == "ERROR": if level == "ERROR":
trace_params["status_message"] = output trace_params["status_message"] = output
else: else:
trace_params["output"] = output trace_params["output"] = output
cost = kwargs.get("response_cost", None) cost = kwargs.get("response_cost", None)
print_verbose(f"trace: {cost}") print_verbose(f"trace: {cost}")
@ -336,7 +355,8 @@ class LangFuseLogger:
kwargs["cache_hit"] = False kwargs["cache_hit"] = False
tags.append(f"cache_hit:{kwargs['cache_hit']}") tags.append(f"cache_hit:{kwargs['cache_hit']}")
clean_metadata["cache_hit"] = kwargs["cache_hit"] clean_metadata["cache_hit"] = kwargs["cache_hit"]
trace_params.update({"tags": tags}) if existing_trace_id is None:
trace_params.update({"tags": tags})
proxy_server_request = litellm_params.get("proxy_server_request", None) proxy_server_request = litellm_params.get("proxy_server_request", None)
if proxy_server_request: if proxy_server_request:
@ -356,8 +376,6 @@ class LangFuseLogger:
"headers": clean_headers, "headers": clean_headers,
} }
print_verbose(f"trace_params: {trace_params}")
trace = self.Langfuse.trace(**trace_params) trace = self.Langfuse.trace(**trace_params)
generation_id = None generation_id = None
@ -407,8 +425,9 @@ class LangFuseLogger:
"completion_start_time", None "completion_start_time", None
) )
print_verbose(f"generation_params: {generation_params}") generation_client = trace.generation(**generation_params)
trace.generation(**generation_params) return generation_client.trace_id, generation_id
except Exception as e: except Exception as e:
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}") verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
return None, None

View file

@ -73,10 +73,6 @@ class LangsmithLogger:
elif type(value) != dict and is_serializable(value=value): elif type(value) != dict and is_serializable(value=value):
new_kwargs[key] = value new_kwargs[key] = value
print(f"type of response: {type(response_obj)}")
for k, v in new_kwargs.items():
print(f"key={k}, type of arg: {type(v)}, value={v}")
if isinstance(response_obj, BaseModel): if isinstance(response_obj, BaseModel):
try: try:
response_obj = response_obj.model_dump() response_obj = response_obj.model_dump()

View file

@ -0,0 +1,131 @@
# What is this?
## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
import dotenv, os, json
import requests
import litellm
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
import uuid
def get_utc_datetime():
import datetime as dt
from datetime import datetime
if hasattr(dt, "UTC"):
return datetime.now(dt.UTC) # type: ignore
else:
return datetime.utcnow() # type: ignore
class OpenMeterLogger(CustomLogger):
def __init__(self) -> None:
super().__init__()
self.validate_environment()
self.async_http_handler = AsyncHTTPHandler()
self.sync_http_handler = HTTPHandler()
def validate_environment(self):
"""
Expects
OPENMETER_API_ENDPOINT,
OPENMETER_API_KEY,
in the environment
"""
missing_keys = []
if os.getenv("OPENMETER_API_KEY", None) is None:
missing_keys.append("OPENMETER_API_KEY")
if len(missing_keys) > 0:
raise Exception("Missing keys={} in environment.".format(missing_keys))
def _common_logic(self, kwargs: dict, response_obj):
call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
dt = get_utc_datetime().isoformat()
cost = kwargs.get("response_cost", None)
model = kwargs.get("model")
usage = {}
if (
isinstance(response_obj, litellm.ModelResponse)
or isinstance(response_obj, litellm.EmbeddingResponse)
) and hasattr(response_obj, "usage"):
usage = {
"prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
"completion_tokens": response_obj["usage"].get("completion_tokens", 0),
"total_tokens": response_obj["usage"].get("total_tokens"),
}
subject = kwargs.get("user", None), # end-user passed in via 'user' param
if not subject:
raise Exception("OpenMeter: user is required")
return {
"specversion": "1.0",
"type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
"id": call_id,
"time": dt,
"subject": subject,
"source": "litellm-proxy",
"data": {"model": model, "cost": cost, **usage},
}
def log_success_event(self, kwargs, response_obj, start_time, end_time):
_url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
if _url.endswith("/"):
_url += "api/v1/events"
else:
_url += "/api/v1/events"
api_key = os.getenv("OPENMETER_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
_headers = {
"Content-Type": "application/cloudevents+json",
"Authorization": "Bearer {}".format(api_key),
}
try:
response = self.sync_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except Exception as e:
if hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
_url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
if _url.endswith("/"):
_url += "api/v1/events"
else:
_url += "/api/v1/events"
api_key = os.getenv("OPENMETER_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
_headers = {
"Content-Type": "application/cloudevents+json",
"Authorization": "Bearer {}".format(api_key),
}
try:
response = await self.async_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except Exception as e:
if hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e

View file

@ -12,6 +12,7 @@ from litellm.caching import DualCache
import asyncio import asyncio
import aiohttp import aiohttp
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
import datetime
class SlackAlerting: class SlackAlerting:
@ -47,7 +48,6 @@ class SlackAlerting:
self.internal_usage_cache = DualCache() self.internal_usage_cache = DualCache()
self.async_http_handler = AsyncHTTPHandler() self.async_http_handler = AsyncHTTPHandler()
self.alert_to_webhook_url = alert_to_webhook_url self.alert_to_webhook_url = alert_to_webhook_url
pass pass
def update_values( def update_values(
@ -93,39 +93,14 @@ class SlackAlerting:
request_info: str, request_info: str,
request_data: Optional[dict] = None, request_data: Optional[dict] = None,
kwargs: Optional[dict] = None, kwargs: Optional[dict] = None,
type: Literal["hanging_request", "slow_response"] = "hanging_request",
start_time: Optional[datetime.datetime] = None,
end_time: Optional[datetime.datetime] = None,
): ):
import uuid # do nothing for now
pass
# For now: do nothing as we're debugging why this is not working as expected
return request_info return request_info
# if request_data is not None:
# trace_id = request_data.get("metadata", {}).get(
# "trace_id", None
# ) # get langfuse trace id
# if trace_id is None:
# trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
# request_data["metadata"]["trace_id"] = trace_id
# elif kwargs is not None:
# _litellm_params = kwargs.get("litellm_params", {})
# trace_id = _litellm_params.get("metadata", {}).get(
# "trace_id", None
# ) # get langfuse trace id
# if trace_id is None:
# trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
# _litellm_params["metadata"]["trace_id"] = trace_id
# _langfuse_host = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com")
# _langfuse_project_id = os.environ.get("LANGFUSE_PROJECT_ID")
# # langfuse urls look like: https://us.cloud.langfuse.com/project/************/traces/litellm-alert-trace-ididi9dk-09292-************
# _langfuse_url = (
# f"{_langfuse_host}/project/{_langfuse_project_id}/traces/{trace_id}"
# )
# request_info += f"\n🪢 Langfuse Trace: {_langfuse_url}"
# return request_info
def _response_taking_too_long_callback( def _response_taking_too_long_callback(
self, self,
kwargs, # kwargs to completion kwargs, # kwargs to completion
@ -167,6 +142,14 @@ class SlackAlerting:
_deployment_latencies = metadata["_latency_per_deployment"] _deployment_latencies = metadata["_latency_per_deployment"]
if len(_deployment_latencies) == 0: if len(_deployment_latencies) == 0:
return None return None
try:
# try sorting deployments by latency
_deployment_latencies = sorted(
_deployment_latencies.items(), key=lambda x: x[1]
)
_deployment_latencies = dict(_deployment_latencies)
except:
pass
for api_base, latency in _deployment_latencies.items(): for api_base, latency in _deployment_latencies.items():
_message_to_send += f"\n{api_base}: {round(latency,2)}s" _message_to_send += f"\n{api_base}: {round(latency,2)}s"
_message_to_send = "```" + _message_to_send + "```" _message_to_send = "```" + _message_to_send + "```"
@ -192,10 +175,6 @@ class SlackAlerting:
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`" request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`" slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold: if time_difference_float > self.alerting_threshold:
if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert(
request_info=request_info, kwargs=kwargs
)
# add deployment latencies to alert # add deployment latencies to alert
if ( if (
kwargs is not None kwargs is not None
@ -222,8 +201,8 @@ class SlackAlerting:
async def response_taking_too_long( async def response_taking_too_long(
self, self,
start_time: Optional[float] = None, start_time: Optional[datetime.datetime] = None,
end_time: Optional[float] = None, end_time: Optional[datetime.datetime] = None,
type: Literal["hanging_request", "slow_response"] = "hanging_request", type: Literal["hanging_request", "slow_response"] = "hanging_request",
request_data: Optional[dict] = None, request_data: Optional[dict] = None,
): ):
@ -243,10 +222,6 @@ class SlackAlerting:
except: except:
messages = "" messages = ""
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`" request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert(
request_info=request_info, request_data=request_data
)
else: else:
request_info = "" request_info = ""
@ -288,6 +263,15 @@ class SlackAlerting:
f"`Requests are hanging - {self.alerting_threshold}s+ request time`" f"`Requests are hanging - {self.alerting_threshold}s+ request time`"
) )
if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert(
request_info=request_info,
request_data=request_data,
type="hanging_request",
start_time=start_time,
end_time=end_time,
)
# add deployment latencies to alert # add deployment latencies to alert
_deployment_latency_map = self._get_deployment_latencies_to_alert( _deployment_latency_map = self._get_deployment_latencies_to_alert(
metadata=request_data.get("metadata", {}) metadata=request_data.get("metadata", {})

View file

@ -84,6 +84,51 @@ class AnthropicConfig:
and v is not None and v is not None
} }
def get_supported_openai_params(self):
return [
"stream",
"stop",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
if (
value == "\n"
) and litellm.drop_params == True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
value = [value]
elif isinstance(value, list):
new_v = []
for v in value:
if (
v == "\n"
) and litellm.drop_params == True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
new_v.append(v)
if len(new_v) > 0:
value = new_v
else:
continue
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
return optional_params
# makes headers for API call # makes headers for API call
def validate_environment(api_key, user_headers): def validate_environment(api_key, user_headers):
@ -142,7 +187,7 @@ class AnthropicChatCompletion(BaseLLM):
elif len(completion_response["content"]) == 0: elif len(completion_response["content"]) == 0:
raise AnthropicError( raise AnthropicError(
message="No content in response", message="No content in response",
status_code=response.status_code, status_code=500,
) )
else: else:
text_content = "" text_content = ""

View file

@ -96,6 +96,15 @@ class AzureOpenAIConfig(OpenAIConfig):
top_p, top_p,
) )
def get_mapped_special_auth_params(self) -> dict:
return {"token": "azure_ad_token"}
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "token":
optional_params["azure_ad_token"] = value
return optional_params
def select_azure_base_url_or_endpoint(azure_client_params: dict): def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = { # azure_client_params = {
@ -142,7 +151,7 @@ class AzureChatCompletion(BaseLLM):
api_type: str, api_type: str,
azure_ad_token: str, azure_ad_token: str,
print_verbose: Callable, print_verbose: Callable,
timeout, timeout: Union[float, httpx.Timeout],
logging_obj, logging_obj,
optional_params, optional_params,
litellm_params, litellm_params,

View file

@ -4,7 +4,13 @@ from enum import Enum
import time, uuid import time, uuid
from typing import Callable, Optional, Any, Union, List from typing import Callable, Optional, Any, Union, List
import litellm import litellm
from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse from litellm.utils import (
ModelResponse,
get_secret,
Usage,
ImageResponse,
map_finish_reason,
)
from .prompt_templates.factory import ( from .prompt_templates.factory import (
prompt_factory, prompt_factory,
custom_prompt, custom_prompt,
@ -29,6 +35,24 @@ class BedrockError(Exception):
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class AmazonBedrockGlobalConfig:
def __init__(self):
pass
def get_mapped_special_auth_params(self) -> dict:
"""
Mapping of common auth params across bedrock/vertex/azure/watsonx
"""
return {"region_name": "aws_region_name"}
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params
class AmazonTitanConfig: class AmazonTitanConfig:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1 Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
@ -139,8 +163,10 @@ class AmazonAnthropicClaude3Config:
"stop", "stop",
"temperature", "temperature",
"top_p", "top_p",
"extra_headers"
] ]
def map_openai_params(self, non_default_params: dict, optional_params: dict): def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items(): for param, value in non_default_params.items():
if param == "max_tokens": if param == "max_tokens":
@ -506,6 +532,15 @@ class AmazonStabilityConfig:
} }
def add_custom_header(headers):
"""Closure to capture the headers and add them."""
def callback(request, **kwargs):
"""Actual callback function that Boto3 will call."""
for header_name, header_value in headers.items():
request.headers.add_header(header_name, header_value)
return callback
def init_bedrock_client( def init_bedrock_client(
region_name=None, region_name=None,
aws_access_key_id: Optional[str] = None, aws_access_key_id: Optional[str] = None,
@ -515,12 +550,12 @@ def init_bedrock_client(
aws_session_name: Optional[str] = None, aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None, aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None, aws_role_name: Optional[str] = None,
timeout: Optional[int] = None, extra_headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
): ):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
standard_aws_region_name = get_secret("AWS_REGION", None) standard_aws_region_name = get_secret("AWS_REGION", None)
## CHECK IS 'os.environ/' passed in ## CHECK IS 'os.environ/' passed in
# Define the list of parameters to check # Define the list of parameters to check
params_to_check = [ params_to_check = [
@ -574,7 +609,14 @@ def init_bedrock_client(
import boto3 import boto3
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout) if isinstance(timeout, float):
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
elif isinstance(timeout, httpx.Timeout):
config = boto3.session.Config(
connect_timeout=timeout.connect, read_timeout=timeout.read
)
else:
config = boto3.session.Config()
### CHECK STS ### ### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None: if aws_role_name is not None and aws_session_name is not None:
@ -629,6 +671,8 @@ def init_bedrock_client(
endpoint_url=endpoint_url, endpoint_url=endpoint_url,
config=config, config=config,
) )
if extra_headers:
client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers))
return client return client
@ -692,6 +736,7 @@ def completion(
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
timeout=None, timeout=None,
extra_headers: Optional[dict] = None,
): ):
exception_mapping_worked = False exception_mapping_worked = False
_is_function_call = False _is_function_call = False
@ -721,6 +766,7 @@ def completion(
aws_role_name=aws_role_name, aws_role_name=aws_role_name,
aws_session_name=aws_session_name, aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name, aws_profile_name=aws_profile_name,
extra_headers=extra_headers,
timeout=timeout, timeout=timeout,
) )
@ -934,7 +980,7 @@ def completion(
original_response=json.dumps(response_body), original_response=json.dumps(response_body),
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response}") print_verbose(f"raw model_response: {response_body}")
## RESPONSE OBJECT ## RESPONSE OBJECT
outputText = "default" outputText = "default"
if provider == "ai21": if provider == "ai21":
@ -1025,7 +1071,9 @@ def completion(
logging_obj=logging_obj, logging_obj=logging_obj,
) )
model_response["finish_reason"] = response_body["stop_reason"] model_response["finish_reason"] = map_finish_reason(
response_body["stop_reason"]
)
_usage = litellm.Usage( _usage = litellm.Usage(
prompt_tokens=response_body["usage"]["input_tokens"], prompt_tokens=response_body["usage"]["input_tokens"],
completion_tokens=response_body["usage"]["output_tokens"], completion_tokens=response_body["usage"]["output_tokens"],
@ -1047,6 +1095,7 @@ def completion(
outputText = response_body.get("results")[0].get("outputText") outputText = response_body.get("results")[0].get("outputText")
response_metadata = response.get("ResponseMetadata", {}) response_metadata = response.get("ResponseMetadata", {})
if response_metadata.get("HTTPStatusCode", 500) >= 400: if response_metadata.get("HTTPStatusCode", 500) >= 400:
raise BedrockError( raise BedrockError(
message=outputText, message=outputText,
@ -1082,11 +1131,13 @@ def completion(
prompt_tokens = response_metadata.get( prompt_tokens = response_metadata.get(
"x-amzn-bedrock-input-token-count", len(encoding.encode(prompt)) "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
) )
_text_response = model_response["choices"][0]["message"].get("content", "")
completion_tokens = response_metadata.get( completion_tokens = response_metadata.get(
"x-amzn-bedrock-output-token-count", "x-amzn-bedrock-output-token-count",
len( len(
encoding.encode( encoding.encode(
model_response["choices"][0]["message"].get("content", "") _text_response,
disallowed_special=(),
) )
), ),
) )

View file

@ -1,3 +1,4 @@
from itertools import chain
import requests, types, time import requests, types, time
import json, uuid import json, uuid
import traceback import traceback
@ -212,18 +213,20 @@ def get_ollama_response(
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response["choices"][0]["finish_reason"] = "stop"
if optional_params.get("format", "") == "json": if data.get("format", "") == "json":
function_call = json.loads(response_json["response"])
message = litellm.Message( message = litellm.Message(
content=None, content=None,
tool_calls=[ tool_calls=[
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": {"arguments": response_json["response"], "name": ""}, "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"type": "function", "type": "function",
} }
], ],
) )
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"]["content"] = response_json["response"] model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
@ -254,8 +257,34 @@ def ollama_completion_stream(url, data, logging_obj):
custom_llm_provider="ollama", custom_llm_provider="ollama",
logging_obj=logging_obj, logging_obj=logging_obj,
) )
for transformed_chunk in streamwrapper: # If format is JSON, this was a function call
yield transformed_chunk # Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = next(streamwrapper)
response_content = "".join(
chunk.choices[0].delta.content
for chunk in chain([first_chunk], streamwrapper)
if chunk.choices[0].delta.content
)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"type": "function",
}
],
)
model_response = first_chunk
model_response["choices"][0]["delta"] = delta
model_response["choices"][0]["finish_reason"] = "tool_calls"
yield model_response
else:
for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e: except Exception as e:
raise e raise e
@ -277,8 +306,36 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
custom_llm_provider="ollama", custom_llm_provider="ollama",
logging_obj=logging_obj, logging_obj=logging_obj,
) )
async for transformed_chunk in streamwrapper:
yield transformed_chunk # If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = await anext(streamwrapper)
first_chunk_content = first_chunk.choices[0].delta.content or ""
response_content = first_chunk_content + "".join(
[
chunk.choices[0].delta.content
async for chunk in streamwrapper
if chunk.choices[0].delta.content]
)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"type": "function",
}
],
)
model_response = first_chunk
model_response["choices"][0]["delta"] = delta
model_response["choices"][0]["finish_reason"] = "tool_calls"
yield model_response
else:
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise e raise e
@ -310,20 +367,19 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response["choices"][0]["finish_reason"] = "stop"
if data.get("format", "") == "json": if data.get("format", "") == "json":
function_call = json.loads(response_json["response"])
message = litellm.Message( message = litellm.Message(
content=None, content=None,
tool_calls=[ tool_calls=[
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": { "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"arguments": response_json["response"],
"name": "",
},
"type": "function", "type": "function",
} }
], ],
) )
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"]["content"] = response_json[ model_response["choices"][0]["message"]["content"] = response_json[
"response" "response"

View file

@ -1,3 +1,4 @@
from itertools import chain
import requests, types, time import requests, types, time
import json, uuid import json, uuid
import traceback import traceback
@ -285,20 +286,19 @@ def get_ollama_response(
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response["choices"][0]["finish_reason"] = "stop"
if data.get("format", "") == "json": if data.get("format", "") == "json":
function_call = json.loads(response_json["message"]["content"])
message = litellm.Message( message = litellm.Message(
content=None, content=None,
tool_calls=[ tool_calls=[
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": { "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"arguments": response_json["message"]["content"],
"name": "",
},
"type": "function", "type": "function",
} }
], ],
) )
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"] = response_json["message"] model_response["choices"][0]["message"] = response_json["message"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
@ -337,8 +337,35 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
custom_llm_provider="ollama_chat", custom_llm_provider="ollama_chat",
logging_obj=logging_obj, logging_obj=logging_obj,
) )
for transformed_chunk in streamwrapper:
yield transformed_chunk # If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = next(streamwrapper)
response_content = "".join(
chunk.choices[0].delta.content
for chunk in chain([first_chunk], streamwrapper)
if chunk.choices[0].delta.content
)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"type": "function",
}
],
)
model_response = first_chunk
model_response["choices"][0]["delta"] = delta
model_response["choices"][0]["finish_reason"] = "tool_calls"
yield model_response
else:
for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e: except Exception as e:
raise e raise e
@ -368,8 +395,36 @@ async def ollama_async_streaming(
custom_llm_provider="ollama_chat", custom_llm_provider="ollama_chat",
logging_obj=logging_obj, logging_obj=logging_obj,
) )
async for transformed_chunk in streamwrapper:
yield transformed_chunk # If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = await anext(streamwrapper)
first_chunk_content = first_chunk.choices[0].delta.content or ""
response_content = first_chunk_content + "".join(
[
chunk.choices[0].delta.content
async for chunk in streamwrapper
if chunk.choices[0].delta.content]
)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"type": "function",
}
],
)
model_response = first_chunk
model_response["choices"][0]["delta"] = delta
model_response["choices"][0]["finish_reason"] = "tool_calls"
yield model_response
else:
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
@ -415,20 +470,19 @@ async def ollama_acompletion(
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response["choices"][0]["finish_reason"] = "stop"
if data.get("format", "") == "json": if data.get("format", "") == "json":
function_call = json.loads(response_json["message"]["content"])
message = litellm.Message( message = litellm.Message(
content=None, content=None,
tool_calls=[ tool_calls=[
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": { "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"arguments": response_json["message"]["content"],
"name": function_name or "",
},
"type": "function", "type": "function",
} }
], ],
) )
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"] = response_json["message"] model_response["choices"][0]["message"] = response_json["message"]

View file

@ -1,4 +1,13 @@
from typing import Optional, Union, Any, BinaryIO from typing import (
Optional,
Union,
Any,
BinaryIO,
Literal,
Iterable,
)
from typing_extensions import override
from pydantic import BaseModel
import types, time, json, traceback import types, time, json, traceback
import httpx import httpx
from .base import BaseLLM from .base import BaseLLM
@ -17,6 +26,7 @@ import aiohttp, requests
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI from openai import OpenAI, AsyncOpenAI
from ..types.llms.openai import *
class OpenAIError(Exception): class OpenAIError(Exception):
@ -246,7 +256,7 @@ class OpenAIChatCompletion(BaseLLM):
def completion( def completion(
self, self,
model_response: ModelResponse, model_response: ModelResponse,
timeout: float, timeout: Union[float, httpx.Timeout],
model: Optional[str] = None, model: Optional[str] = None,
messages: Optional[list] = None, messages: Optional[list] = None,
print_verbose: Optional[Callable] = None, print_verbose: Optional[Callable] = None,
@ -271,9 +281,12 @@ class OpenAIChatCompletion(BaseLLM):
if model is None or messages is None: if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages") raise OpenAIError(status_code=422, message=f"Missing model or messages")
if not isinstance(timeout, float): if not isinstance(timeout, float) and not isinstance(
timeout, httpx.Timeout
):
raise OpenAIError( raise OpenAIError(
status_code=422, message=f"Timeout needs to be a float" status_code=422,
message=f"Timeout needs to be a float or httpx.Timeout",
) )
if custom_llm_provider != "openai": if custom_llm_provider != "openai":
@ -425,7 +438,7 @@ class OpenAIChatCompletion(BaseLLM):
self, self,
data: dict, data: dict,
model_response: ModelResponse, model_response: ModelResponse,
timeout: float, timeout: Union[float, httpx.Timeout],
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
organization: Optional[str] = None, organization: Optional[str] = None,
@ -447,6 +460,7 @@ class OpenAIChatCompletion(BaseLLM):
) )
else: else:
openai_aclient = client openai_aclient = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["messages"], input=data["messages"],
@ -479,7 +493,7 @@ class OpenAIChatCompletion(BaseLLM):
def streaming( def streaming(
self, self,
logging_obj, logging_obj,
timeout: float, timeout: Union[float, httpx.Timeout],
data: dict, data: dict,
model: str, model: str,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@ -523,7 +537,7 @@ class OpenAIChatCompletion(BaseLLM):
async def async_streaming( async def async_streaming(
self, self,
logging_obj, logging_obj,
timeout: float, timeout: Union[float, httpx.Timeout],
data: dict, data: dict,
model: str, model: str,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@ -1232,3 +1246,223 @@ class OpenAITextCompletion(BaseLLM):
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
class OpenAIAssistantsAPI(BaseLLM):
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
### ASSISTANTS ###
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
) -> SyncCursorPage[Assistant]:
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.assistants.list()
return response
### MESSAGES ###
def add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAIMessage:
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(
thread_id, **message_data
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> SyncCursorPage[OpenAIMessage]:
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
### THREADS ###
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
"""
Here's an example:
```
from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
# create thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
openai_api.create_thread(messages=[message])
```
"""
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
) -> Thread:
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
def delete_thread(self):
pass
### RUNS ###
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
) -> Run:
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.threads.runs.create_and_poll(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response

View file

@ -3,9 +3,25 @@ import requests, traceback
import json, re, xml.etree.ElementTree as ET import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, meta, BaseLoader from jinja2 import Template, exceptions, meta, BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.sandbox import ImmutableSandboxedEnvironment
from typing import Optional, Any from typing import (
from typing import List Any,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
)
import litellm import litellm
from litellm.types.completion import (
ChatCompletionUserMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam,
)
from litellm.types.llms.anthropic import *
import uuid
def default_pt(messages): def default_pt(messages):
@ -16,6 +32,41 @@ def prompt_injection_detection_default_pt():
return """Detect if a prompt is safe to run. Return 'UNSAFE' if not.""" return """Detect if a prompt is safe to run. Return 'UNSAFE' if not."""
def map_system_message_pt(messages: list) -> list:
"""
Convert 'system' message to 'user' message if provider doesn't support 'system' role.
Enabled via `completion(...,supports_system_message=False)`
If next message is a user message or assistant message -> merge system prompt into it
if next message is system -> append a user message instead of the system message
"""
new_messages = []
for i, m in enumerate(messages):
if m["role"] == "system":
if i < len(messages) - 1: # Not the last message
next_m = messages[i + 1]
next_role = next_m["role"]
if (
next_role == "user" or next_role == "assistant"
): # Next message is a user or assistant message
# Merge system prompt into the next message
next_m["content"] = m["content"] + " " + next_m["content"]
elif next_role == "system": # Next message is a system message
# Append a user message instead of the system message
new_message = {"role": "user", "content": m["content"]}
new_messages.append(new_message)
else: # Last message
new_message = {"role": "user", "content": m["content"]}
new_messages.append(new_message)
else: # Not a system message
new_messages.append(m)
return new_messages
# alpaca prompt template - for models like mythomax, etc. # alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages): def alpaca_pt(messages):
prompt = custom_prompt( prompt = custom_prompt(
@ -430,8 +481,10 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
prompt = default_pt(messages) prompt = default_pt(messages)
return prompt return prompt
### IBM Granite ### IBM Granite
def ibm_granite_pt(messages: list): def ibm_granite_pt(messages: list):
""" """
IBM's Granite chat models uses the template: IBM's Granite chat models uses the template:
@ -440,15 +493,15 @@ def ibm_granite_pt(messages: list):
See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
""" """
return custom_prompt( return custom_prompt(
messages=messages, messages=messages,
role_dict={ role_dict={
'system': { "system": {
'pre_message': '<|system|>\n', "pre_message": "<|system|>\n",
'post_message': '\n', "post_message": "\n",
}, },
'user': { "user": {
'pre_message': '<|user|>\n', "pre_message": "<|user|>\n",
'post_message': '\n', "post_message": "\n",
}, },
'assistant': { 'assistant': {
'pre_message': '<|assistant|>\n', 'pre_message': '<|assistant|>\n',
@ -458,6 +511,7 @@ def ibm_granite_pt(messages: list):
final_prompt_value='<|assistant|>\n', final_prompt_value='<|assistant|>\n',
) )
### ANTHROPIC ### ### ANTHROPIC ###
@ -749,15 +803,9 @@ def anthropic_messages_pt_xml(messages: list):
assistant_content = [] assistant_content = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ## ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
# Handle assistant messages as string, none, or list of text-content dictionaries. assistant_text = (
if isinstance(messages[msg_i].get("content"), list): messages[msg_i].get("content") or ""
assistant_text = '' ) # either string or none
for content in messages[msg_i]["content"]:
if content.get("type") == "text":
assistant_text += content["text"]
else:
assistant_text = messages[msg_i].get("content") or ""
if messages[msg_i].get( if messages[msg_i].get(
"tool_calls", [] "tool_calls", []
): # support assistant tool invoke convertion ): # support assistant tool invoke convertion
@ -803,6 +851,13 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
"name": "get_current_weather", "name": "get_current_weather",
"content": "function result goes here", "content": "function result goes here",
}, },
OpenAI message with a function call result looks like:
{
"role": "function",
"name": "get_current_weather",
"content": "function result goes here",
}
""" """
""" """
@ -819,18 +874,42 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
] ]
} }
""" """
tool_call_id = message.get("tool_call_id") if message["role"] == "tool":
content = message.get("content") tool_call_id = message.get("tool_call_id")
content = message.get("content")
# We can't determine from openai message format whether it's a successful or # We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template # error call result so default to the successful result template
anthropic_tool_result = { anthropic_tool_result = {
"type": "tool_result", "type": "tool_result",
"tool_use_id": tool_call_id, "tool_use_id": tool_call_id,
"content": content, "content": content,
} }
return anthropic_tool_result
elif message["role"] == "function":
content = message.get("content")
anthropic_tool_result = {
"type": "tool_result",
"tool_use_id": str(uuid.uuid4()),
"content": content,
}
return anthropic_tool_result
return {}
return anthropic_tool_result
def convert_function_to_anthropic_tool_invoke(function_call):
try:
anthropic_tool_invoke = [
{
"type": "tool_use",
"id": str(uuid.uuid4()),
"name": get_attribute_or_key(function_call, "name"),
"input": json.loads(get_attribute_or_key(function_call, "arguments")),
}
]
return anthropic_tool_invoke
except Exception as e:
raise e
def convert_to_anthropic_tool_invoke(tool_calls: list) -> list: def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
@ -893,7 +972,7 @@ def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
def anthropic_messages_pt(messages: list): def anthropic_messages_pt(messages: list):
""" """
format messages for anthropic format messages for anthropic
1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant) 1. Anthropic supports roles like "user" and "assistant" (system prompt sent separately)
2. The first message always needs to be of role "user" 2. The first message always needs to be of role "user"
3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm) 3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise) 4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
@ -901,12 +980,14 @@ def anthropic_messages_pt(messages: list):
6. Ensure we only accept role, content. (message.name is not supported) 6. Ensure we only accept role, content. (message.name is not supported)
""" """
# add role=tool support to allow function call result/error submission # add role=tool support to allow function call result/error submission
user_message_types = {"user", "tool"} user_message_types = {"user", "tool", "function"}
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, merge them. # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, merge them.
new_messages = [] new_messages = []
msg_i = 0 msg_i = 0
tool_use_param = False
while msg_i < len(messages): while msg_i < len(messages):
user_content = [] user_content = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ## ## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types: while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
if isinstance(messages[msg_i]["content"], list): if isinstance(messages[msg_i]["content"], list):
@ -922,7 +1003,10 @@ def anthropic_messages_pt(messages: list):
) )
elif m.get("type", "") == "text": elif m.get("type", "") == "text":
user_content.append({"type": "text", "text": m["text"]}) user_content.append({"type": "text", "text": m["text"]})
elif messages[msg_i]["role"] == "tool": elif (
messages[msg_i]["role"] == "tool"
or messages[msg_i]["role"] == "function"
):
# OpenAI's tool message content will always be a string # OpenAI's tool message content will always be a string
user_content.append(convert_to_anthropic_tool_result(messages[msg_i])) user_content.append(convert_to_anthropic_tool_result(messages[msg_i]))
else: else:
@ -951,11 +1035,24 @@ def anthropic_messages_pt(messages: list):
convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"]) convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"])
) )
if messages[msg_i].get("function_call"):
assistant_content.extend(
convert_function_to_anthropic_tool_invoke(
messages[msg_i]["function_call"]
)
)
msg_i += 1 msg_i += 1
if assistant_content: if assistant_content:
new_messages.append({"role": "assistant", "content": assistant_content}) new_messages.append({"role": "assistant", "content": assistant_content})
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
messages[msg_i]
)
)
if not new_messages or new_messages[0]["role"] != "user": if not new_messages or new_messages[0]["role"] != "user":
if litellm.modify_params: if litellm.modify_params:
new_messages.insert( new_messages.insert(
@ -967,11 +1064,14 @@ def anthropic_messages_pt(messages: list):
) )
if new_messages[-1]["role"] == "assistant": if new_messages[-1]["role"] == "assistant":
for content in new_messages[-1]["content"]: if isinstance(new_messages[-1]["content"], str):
if isinstance(content, dict) and content["type"] == "text": new_messages[-1]["content"] = new_messages[-1]["content"].rstrip()
content["text"] = content[ elif isinstance(new_messages[-1]["content"], list):
"text" for content in new_messages[-1]["content"]:
].rstrip() # no trailing whitespace for final assistant message if isinstance(content, dict) and content["type"] == "text":
content["text"] = content[
"text"
].rstrip() # no trailing whitespace for final assistant message
return new_messages return new_messages
@ -1050,6 +1150,30 @@ def get_system_prompt(messages):
return system_prompt, messages return system_prompt, messages
def convert_to_documents(
observations: Any,
) -> List[MutableMapping]:
"""Converts observations into a 'document' dict"""
documents: List[MutableMapping] = []
if isinstance(observations, str):
# strings are turned into a key/value pair and a key of 'output' is added.
observations = [{"output": observations}]
elif isinstance(observations, Mapping):
# single mappings are transformed into a list to simplify the rest of the code.
observations = [observations]
elif not isinstance(observations, Sequence):
# all other types are turned into a key/value pair within a list
observations = [{"output": observations}]
for doc in observations:
if not isinstance(doc, Mapping):
# types that aren't Mapping are turned into a key/value pair.
doc = {"output": doc}
documents.append(doc)
return documents
def convert_openai_message_to_cohere_tool_result(message): def convert_openai_message_to_cohere_tool_result(message):
""" """
OpenAI message with a tool result looks like: OpenAI message with a tool result looks like:
@ -1091,7 +1215,7 @@ def convert_openai_message_to_cohere_tool_result(message):
"parameters": {"location": "San Francisco, CA"}, "parameters": {"location": "San Francisco, CA"},
"generation_id": tool_call_id, "generation_id": tool_call_id,
}, },
"outputs": [content], "outputs": convert_to_documents(content),
} }
return cohere_tool_result return cohere_tool_result
@ -1104,7 +1228,7 @@ def cohere_message_pt(messages: list):
if message["role"] == "tool": if message["role"] == "tool":
tool_result = convert_openai_message_to_cohere_tool_result(message) tool_result = convert_openai_message_to_cohere_tool_result(message)
tool_results.append(tool_result) tool_results.append(tool_result)
else: elif message.get("content"):
prompt += message["content"] + "\n\n" prompt += message["content"] + "\n\n"
prompt = prompt.rstrip() prompt = prompt.rstrip()
return prompt, tool_results return prompt, tool_results

View file

@ -184,6 +184,20 @@ class VertexAIConfig:
pass pass
return optional_params return optional_params
def get_mapped_special_auth_params(self) -> dict:
"""
Common auth params across bedrock/vertex_ai/azure/watsonx
"""
return {"project": "vertex_project", "region_name": "vertex_location"}
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params
import asyncio import asyncio
@ -529,7 +543,7 @@ def completion(
"instances": instances, "instances": instances,
"vertex_location": vertex_location, "vertex_location": vertex_location,
"vertex_project": vertex_project, "vertex_project": vertex_project,
"safety_settings":safety_settings, "safety_settings": safety_settings,
**optional_params, **optional_params,
} }
if optional_params.get("stream", False) is True: if optional_params.get("stream", False) is True:
@ -1025,6 +1039,7 @@ async def async_streaming(
instances=None, instances=None,
vertex_project=None, vertex_project=None,
vertex_location=None, vertex_location=None,
safety_settings=None,
**optional_params, **optional_params,
): ):
""" """
@ -1051,6 +1066,7 @@ async def async_streaming(
response = await llm_model._generate_content_streaming_async( response = await llm_model._generate_content_streaming_async(
contents=content, contents=content,
generation_config=optional_params, generation_config=optional_params,
safety_settings=safety_settings,
tools=tools, tools=tools,
) )

View file

@ -14,7 +14,7 @@ from .prompt_templates import factory as ptf
class WatsonXAIError(Exception): class WatsonXAIError(Exception):
def __init__(self, status_code, message, url: str = None): def __init__(self, status_code, message, url: Optional[str] = None):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
url = url or "https://https://us-south.ml.cloud.ibm.com" url = url or "https://https://us-south.ml.cloud.ibm.com"
@ -74,7 +74,6 @@ class IBMWatsonXAIConfig:
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None
truncate_input_tokens: Optional[int] = None truncate_input_tokens: Optional[int] = None
include_stop_sequences: Optional[bool] = False include_stop_sequences: Optional[bool] = False
return_options: Optional[dict] = None
return_options: Optional[Dict[str, bool]] = None return_options: Optional[Dict[str, bool]] = None
random_seed: Optional[int] = None # e.g 42 random_seed: Optional[int] = None # e.g 42
moderations: Optional[dict] = None moderations: Optional[dict] = None
@ -133,6 +132,24 @@ class IBMWatsonXAIConfig:
"stream", # equivalent to stream "stream", # equivalent to stream
] ]
def get_mapped_special_auth_params(self) -> dict:
"""
Common auth params across bedrock/vertex_ai/azure/watsonx
"""
return {
"project": "watsonx_project",
"region_name": "watsonx_region_name",
"token": "watsonx_token",
}
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
# handle anthropic prompts and amazon titan prompts # handle anthropic prompts and amazon titan prompts
@ -162,6 +179,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
) )
return prompt return prompt
class WatsonXAIEndpoint(str, Enum): class WatsonXAIEndpoint(str, Enum):
TEXT_GENERATION = "/ml/v1/text/generation" TEXT_GENERATION = "/ml/v1/text/generation"
TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream" TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
@ -172,6 +190,7 @@ class WatsonXAIEndpoint(str, Enum):
EMBEDDINGS = "/ml/v1/text/embeddings" EMBEDDINGS = "/ml/v1/text/embeddings"
PROMPTS = "/ml/v1/prompts" PROMPTS = "/ml/v1/prompts"
class IBMWatsonXAI(BaseLLM): class IBMWatsonXAI(BaseLLM):
""" """
Class to interface with IBM watsonx.ai API for text generation and embeddings. Class to interface with IBM watsonx.ai API for text generation and embeddings.
@ -190,7 +209,7 @@ class IBMWatsonXAI(BaseLLM):
prompt: str, prompt: str,
stream: bool, stream: bool,
optional_params: dict, optional_params: dict,
print_verbose: Callable = None, print_verbose: Optional[Callable] = None,
) -> dict: ) -> dict:
""" """
Get the request parameters for text generation. Get the request parameters for text generation.
@ -224,9 +243,9 @@ class IBMWatsonXAI(BaseLLM):
) )
deployment_id = "/".join(model_id.split("/")[1:]) deployment_id = "/".join(model_id.split("/")[1:])
endpoint = ( endpoint = (
WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
if stream if stream
else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
) )
endpoint = endpoint.format(deployment_id=deployment_id) endpoint = endpoint.format(deployment_id=deployment_id)
else: else:
@ -242,23 +261,37 @@ class IBMWatsonXAI(BaseLLM):
method="POST", url=url, headers=headers, json=payload, params=request_params method="POST", url=url, headers=headers, json=payload, params=request_params
) )
def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict: def _get_api_params(
self, params: dict, print_verbose: Optional[Callable] = None
) -> dict:
""" """
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication. Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
""" """
# Load auth variables from params # Load auth variables from params
url = params.pop("url", None) url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
api_key = params.pop("apikey", None) api_key = params.pop("apikey", None)
token = params.pop("token", None) token = params.pop("token", None)
project_id = params.pop("project_id", None) # watsonx.ai project_id project_id = params.pop(
"project_id", params.pop("watsonx_project", None)
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
space_id = params.pop("space_id", None) # watsonx.ai deployment space_id space_id = params.pop("space_id", None) # watsonx.ai deployment space_id
region_name = params.pop("region_name", params.pop("region", None)) region_name = params.pop("region_name", params.pop("region", None))
wx_credentials = params.pop("wx_credentials", None) if region_name is None:
region_name = params.pop(
"watsonx_region_name", params.pop("watsonx_region", None)
) # consistent with how vertex ai + aws regions are accepted
wx_credentials = params.pop(
"wx_credentials",
params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
api_version = params.pop("api_version", IBMWatsonXAI.api_version) api_version = params.pop("api_version", IBMWatsonXAI.api_version)
# Load auth variables from environment variables # Load auth variables from environment variables
if url is None: if url is None:
url = ( url = (
get_secret("WATSONX_URL") get_secret("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret("WATSONX_URL")
or get_secret("WX_URL") or get_secret("WX_URL")
or get_secret("WML_URL") or get_secret("WML_URL")
) )
@ -296,7 +329,12 @@ class IBMWatsonXAI(BaseLLM):
api_key = wx_credentials.get( api_key = wx_credentials.get(
"apikey", wx_credentials.get("api_key", api_key) "apikey", wx_credentials.get("api_key", api_key)
) )
token = wx_credentials.get("token", token) token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", token
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
# verify that all required credentials are present # verify that all required credentials are present
if url is None: if url is None:
@ -345,7 +383,7 @@ class IBMWatsonXAI(BaseLLM):
acompletion: bool = None, acompletion: bool = None,
litellm_params: Optional[dict] = None, litellm_params: Optional[dict] = None,
logger_fn=None, logger_fn=None,
timeout: float = None, timeout: Optional[float] = None,
): ):
""" """
Send a text generation request to the IBM Watsonx.ai API. Send a text generation request to the IBM Watsonx.ai API.
@ -381,10 +419,14 @@ class IBMWatsonXAI(BaseLLM):
model_response["finish_reason"] = json_resp["results"][0]["stop_reason"] model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
model_response.usage = Usage( setattr(
prompt_tokens=prompt_tokens, model_response,
completion_tokens=completion_tokens, "usage",
total_tokens=prompt_tokens + completion_tokens, Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
) )
return model_response return model_response

View file

@ -12,6 +12,7 @@ from typing import Any, Literal, Union, BinaryIO
from functools import partial from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy from copy import deepcopy
import httpx import httpx
import litellm import litellm
from ._logging import verbose_logger from ._logging import verbose_logger
@ -33,9 +34,12 @@ from litellm.utils import (
async_mock_completion_streaming_obj, async_mock_completion_streaming_obj,
convert_to_model_response_object, convert_to_model_response_object,
token_counter, token_counter,
create_pretrained_tokenizer,
create_tokenizer,
Usage, Usage,
get_optional_params_embeddings, get_optional_params_embeddings,
get_optional_params_image_gen, get_optional_params_image_gen,
supports_httpx_timeout,
) )
from .llms import ( from .llms import (
anthropic_text, anthropic_text,
@ -75,6 +79,7 @@ from .llms.prompt_templates.factory import (
prompt_factory, prompt_factory,
custom_prompt, custom_prompt,
function_call_prompt, function_call_prompt,
map_system_message_pt,
) )
import tiktoken import tiktoken
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -363,7 +368,7 @@ def mock_completion(
model: str, model: str,
messages: List, messages: List,
stream: Optional[bool] = False, stream: Optional[bool] = False,
mock_response: str = "This is a mock request", mock_response: Union[str, Exception] = "This is a mock request",
logging=None, logging=None,
**kwargs, **kwargs,
): ):
@ -390,6 +395,20 @@ def mock_completion(
- If 'stream' is True, it returns a response that mimics the behavior of a streaming completion. - If 'stream' is True, it returns a response that mimics the behavior of a streaming completion.
""" """
try: try:
## LOGGING
if logging is not None:
logging.pre_call(
input=messages,
api_key="mock-key",
)
if isinstance(mock_response, Exception):
raise litellm.APIError(
status_code=500, # type: ignore
message=str(mock_response),
llm_provider="openai", # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
model_response = ModelResponse(stream=stream) model_response = ModelResponse(stream=stream)
if stream is True: if stream is True:
# don't try to access stream object, # don't try to access stream object,
@ -436,7 +455,7 @@ def completion(
model: str, model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [], messages: List = [],
timeout: Optional[Union[float, int]] = None, timeout: Optional[Union[float, str, httpx.Timeout]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
n: Optional[int] = None, n: Optional[int] = None,
@ -539,6 +558,7 @@ def completion(
eos_token = kwargs.get("eos_token", None) eos_token = kwargs.get("eos_token", None)
preset_cache_key = kwargs.get("preset_cache_key", None) preset_cache_key = kwargs.get("preset_cache_key", None)
hf_model_name = kwargs.get("hf_model_name", None) hf_model_name = kwargs.get("hf_model_name", None)
supports_system_message = kwargs.get("supports_system_message", None)
### TEXT COMPLETION CALLS ### ### TEXT COMPLETION CALLS ###
text_completion = kwargs.get("text_completion", False) text_completion = kwargs.get("text_completion", False)
atext_completion = kwargs.get("atext_completion", False) atext_completion = kwargs.get("atext_completion", False)
@ -604,6 +624,7 @@ def completion(
"model_list", "model_list",
"num_retries", "num_retries",
"context_window_fallback_dict", "context_window_fallback_dict",
"retry_policy",
"roles", "roles",
"final_prompt_value", "final_prompt_value",
"bos_token", "bos_token",
@ -629,16 +650,27 @@ def completion(
"no-log", "no-log",
"base_model", "base_model",
"stream_timeout", "stream_timeout",
"supports_system_message",
] ]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = { non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider } # model-specific params - pass them straight to the model/provider
if timeout is None:
timeout = ( ### TIMEOUT LOGIC ###
kwargs.get("request_timeout", None) or 600 timeout = timeout or kwargs.get("request_timeout", 600) or 600
) # set timeout for 10 minutes by default # set timeout for 10 minutes by default
timeout = float(timeout)
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
try: try:
if base_url is not None: if base_url is not None:
api_base = base_url api_base = base_url
@ -733,6 +765,13 @@ def completion(
custom_prompt_dict[model]["bos_token"] = bos_token custom_prompt_dict[model]["bos_token"] = bos_token
if eos_token: if eos_token:
custom_prompt_dict[model]["eos_token"] = eos_token custom_prompt_dict[model]["eos_token"] = eos_token
if (
supports_system_message is not None
and isinstance(supports_system_message, bool)
and supports_system_message == False
):
messages = map_system_message_pt(messages=messages)
model_api_key = get_api_key( model_api_key = get_api_key(
llm_provider=custom_llm_provider, dynamic_api_key=api_key llm_provider=custom_llm_provider, dynamic_api_key=api_key
) # get the api key from the environment if required for the model ) # get the api key from the environment if required for the model
@ -859,7 +898,7 @@ def completion(
logger_fn=logger_fn, logger_fn=logger_fn,
logging_obj=logging, logging_obj=logging,
acompletion=acompletion, acompletion=acompletion,
timeout=timeout, timeout=timeout, # type: ignore
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
) )
@ -1000,7 +1039,7 @@ def completion(
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params, litellm_params=litellm_params,
logger_fn=logger_fn, logger_fn=logger_fn,
timeout=timeout, timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client client=client, # pass AsyncOpenAI, OpenAI client
organization=organization, organization=organization,
@ -1085,7 +1124,7 @@ def completion(
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params, litellm_params=litellm_params,
logger_fn=logger_fn, logger_fn=logger_fn,
timeout=timeout, timeout=timeout, # type: ignore
) )
if ( if (
@ -1459,7 +1498,7 @@ def completion(
acompletion=acompletion, acompletion=acompletion,
logging_obj=logging, logging_obj=logging,
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
timeout=timeout, timeout=timeout, # type: ignore
) )
if ( if (
"stream" in optional_params "stream" in optional_params
@ -1552,7 +1591,7 @@ def completion(
logger_fn=logger_fn, logger_fn=logger_fn,
logging_obj=logging, logging_obj=logging,
acompletion=acompletion, acompletion=acompletion,
timeout=timeout, timeout=timeout, # type: ignore
) )
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
@ -1832,6 +1871,7 @@ def completion(
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout, timeout=timeout,
) )
@ -1875,7 +1915,7 @@ def completion(
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params, litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
@ -2261,7 +2301,7 @@ def batch_completion(
n: Optional[int] = None, n: Optional[int] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
stop=None, stop=None,
max_tokens: Optional[float] = None, max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None, logit_bias: Optional[dict] = None,
@ -2655,6 +2695,7 @@ def embedding(
"model_list", "model_list",
"num_retries", "num_retries",
"context_window_fallback_dict", "context_window_fallback_dict",
"retry_policy",
"roles", "roles",
"final_prompt_value", "final_prompt_value",
"bos_token", "bos_token",
@ -3525,6 +3566,7 @@ def image_generation(
"model_list", "model_list",
"num_retries", "num_retries",
"context_window_fallback_dict", "context_window_fallback_dict",
"retry_policy",
"roles", "roles",
"final_prompt_value", "final_prompt_value",
"bos_token", "bos_token",

View file

@ -338,6 +338,18 @@
"output_cost_per_second": 0.0001, "output_cost_per_second": 0.0001,
"litellm_provider": "azure" "litellm_provider": "azure"
}, },
"azure/gpt-4-turbo-2024-04-09": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"litellm_provider": "azure",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"azure/gpt-4-0125-preview": { "azure/gpt-4-0125-preview": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 128000, "max_input_tokens": 128000,
@ -813,6 +825,7 @@
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 264 "tool_use_system_prompt_tokens": 264
}, },
"claude-3-opus-20240229": { "claude-3-opus-20240229": {
@ -824,6 +837,7 @@
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 395 "tool_use_system_prompt_tokens": 395
}, },
"claude-3-sonnet-20240229": { "claude-3-sonnet-20240229": {
@ -835,6 +849,7 @@
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 159 "tool_use_system_prompt_tokens": 159
}, },
"text-bison": { "text-bison": {
@ -1142,7 +1157,8 @@
"output_cost_per_token": 0.000015, "output_cost_per_token": 0.000015,
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"vertex_ai/claude-3-haiku@20240307": { "vertex_ai/claude-3-haiku@20240307": {
"max_tokens": 4096, "max_tokens": 4096,
@ -1152,7 +1168,8 @@
"output_cost_per_token": 0.00000125, "output_cost_per_token": 0.00000125,
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"vertex_ai/claude-3-opus@20240229": { "vertex_ai/claude-3-opus@20240229": {
"max_tokens": 4096, "max_tokens": 4096,
@ -1162,7 +1179,8 @@
"output_cost_per_token": 0.0000075, "output_cost_per_token": 0.0000075,
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"textembedding-gecko": { "textembedding-gecko": {
"max_tokens": 3072, "max_tokens": 3072,
@ -1418,6 +1436,123 @@
"litellm_provider": "replicate", "litellm_provider": "replicate",
"mode": "chat" "mode": "chat"
}, },
"replicate/meta/llama-2-13b": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.0000005,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-2-13b-chat": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.0000005,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-2-70b": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000275,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-2-70b-chat": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000275,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-2-7b": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-2-7b-chat": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-3-70b": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000275,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-3-70b-instruct": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000065,
"output_cost_per_token": 0.00000275,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-3-8b": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/meta/llama-3-8b-instruct": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/mistralai/mistral-7b-v0.1": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/mistralai/mistral-7b-instruct-v0.2": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
"mode": "chat"
},
"replicate/mistralai/mixtral-8x7b-instruct-v0.1": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000003,
"output_cost_per_token": 0.000001,
"litellm_provider": "replicate",
"mode": "chat"
},
"openrouter/openai/gpt-3.5-turbo": { "openrouter/openai/gpt-3.5-turbo": {
"max_tokens": 4095, "max_tokens": 4095,
"input_cost_per_token": 0.0000015, "input_cost_per_token": 0.0000015,
@ -1455,6 +1590,18 @@
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
"mode": "chat" "mode": "chat"
}, },
"openrouter/anthropic/claude-3-opus": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 395
},
"openrouter/google/palm-2-chat-bison": { "openrouter/google/palm-2-chat-bison": {
"max_tokens": 8000, "max_tokens": 8000,
"input_cost_per_token": 0.0000005, "input_cost_per_token": 0.0000005,
@ -1685,6 +1832,15 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "embedding" "mode": "embedding"
}, },
"amazon.titan-embed-text-v2:0": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"output_vector_size": 1024,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0,
"litellm_provider": "bedrock",
"mode": "embedding"
},
"mistral.mistral-7b-instruct-v0:2": { "mistral.mistral-7b-instruct-v0:2": {
"max_tokens": 8191, "max_tokens": 8191,
"max_input_tokens": 32000, "max_input_tokens": 32000,
@ -1801,7 +1957,8 @@
"output_cost_per_token": 0.000015, "output_cost_per_token": 0.000015,
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"anthropic.claude-3-haiku-20240307-v1:0": { "anthropic.claude-3-haiku-20240307-v1:0": {
"max_tokens": 4096, "max_tokens": 4096,
@ -1811,7 +1968,8 @@
"output_cost_per_token": 0.00000125, "output_cost_per_token": 0.00000125,
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"anthropic.claude-3-opus-20240229-v1:0": { "anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096, "max_tokens": 4096,
@ -1821,7 +1979,8 @@
"output_cost_per_token": 0.000075, "output_cost_per_token": 0.000075,
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"anthropic.claude-v1": { "anthropic.claude-v1": {
"max_tokens": 8191, "max_tokens": 8191,
@ -2379,6 +2538,24 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat" "mode": "chat"
}, },
"meta.llama3-8b-instruct-v1:0": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0000004,
"output_cost_per_token": 0.0000006,
"litellm_provider": "bedrock",
"mode": "chat"
},
"meta.llama3-70b-instruct-v1:0": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000265,
"output_cost_per_token": 0.0000035,
"litellm_provider": "bedrock",
"mode": "chat"
},
"512-x-512/50-steps/stability.stable-diffusion-xl-v0": { "512-x-512/50-steps/stability.stable-diffusion-xl-v0": {
"max_tokens": 77, "max_tokens": 77,
"max_input_tokens": 77, "max_input_tokens": 77,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
!function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/5e699db73bf6f8c2.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}(); !function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/00c2ddbcd01819c0.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}();

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-ccae12a25017afa5.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-dafd44dfa2da140c.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-e49705773ae41779.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-ccae12a25017afa5.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/5e699db73bf6f8c2.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[27125,[\"447\",\"static/chunks/447-9f8d32190ff7d16d.js\",\"931\",\"static/chunks/app/page-781ca5f151d78d1d.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/5e699db73bf6f8c2.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"PtTtxXIYvdjQsvRgdITlk\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html> <!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-202e312607f242a1.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-dafd44dfa2da140c.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-e49705773ae41779.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-202e312607f242a1.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/00c2ddbcd01819c0.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[58854,[\"936\",\"static/chunks/2f6dbc85-17d29013b8ff3da5.js\",\"142\",\"static/chunks/142-11990a208bf93746.js\",\"931\",\"static/chunks/app/page-d9bdfedbff191985.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/00c2ddbcd01819c0.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"e55gTzpa2g2-9SwXgA9Uo\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>

View file

@ -1,7 +1,7 @@
2:I[77831,[],""] 2:I[77831,[],""]
3:I[27125,["447","static/chunks/447-9f8d32190ff7d16d.js","931","static/chunks/app/page-781ca5f151d78d1d.js"],""] 3:I[58854,["936","static/chunks/2f6dbc85-17d29013b8ff3da5.js","142","static/chunks/142-11990a208bf93746.js","931","static/chunks/app/page-d9bdfedbff191985.js"],""]
4:I[5613,[],""] 4:I[5613,[],""]
5:I[31778,[],""] 5:I[31778,[],""]
0:["PtTtxXIYvdjQsvRgdITlk",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/5e699db73bf6f8c2.css","precedence":"next","crossOrigin":""}]],"$L6"]]]] 0:["e55gTzpa2g2-9SwXgA9Uo",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/00c2ddbcd01819c0.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]] 6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
1:null 1:null

View file

@ -1,23 +1,22 @@
model_list: model_list:
- model_name: text-embedding-3-small
litellm_params:
model: text-embedding-3-small
- model_name: whisper
litellm_params:
model: azure/azure-whisper
api_version: 2024-02-15-preview
api_base: os.environ/AZURE_EUROPE_API_BASE
api_key: os.environ/AZURE_EUROPE_API_KEY
model_info:
mode: audio_transcription
- litellm_params: - litellm_params:
model: gpt-4 api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
model_name: gpt-4 api_key: my-fake-key
- model_name: azure-mistral model: openai/my-fake-model
litellm_params: model_name: fake-openai-endpoint
model: azure/mistral-large-latest router_settings:
api_base: https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com num_retries: 0
api_key: os.environ/AZURE_MISTRAL_API_KEY enable_pre_call_checks: true
redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
# litellm_settings: router_settings:
# cache: True routing_strategy: "latency-based-routing"
litellm_settings:
success_callback: ["openmeter"]
general_settings:
alerting: ["slack"]
alert_types: ["llm_exceptions"]

View file

@ -422,6 +422,9 @@ class LiteLLM_ModelTable(LiteLLMBase):
created_by: str created_by: str
updated_by: str updated_by: str
class Config:
protected_namespaces = ()
class NewUserRequest(GenerateKeyRequest): class NewUserRequest(GenerateKeyRequest):
max_budget: Optional[float] = None max_budget: Optional[float] = None
@ -485,6 +488,9 @@ class TeamBase(LiteLLMBase):
class NewTeamRequest(TeamBase): class NewTeamRequest(TeamBase):
model_aliases: Optional[dict] = None model_aliases: Optional[dict] = None
class Config:
protected_namespaces = ()
class GlobalEndUsersSpend(LiteLLMBase): class GlobalEndUsersSpend(LiteLLMBase):
api_key: Optional[str] = None api_key: Optional[str] = None
@ -534,6 +540,9 @@ class LiteLLM_TeamTable(TeamBase):
budget_reset_at: Optional[datetime] = None budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None model_id: Optional[int] = None
class Config:
protected_namespaces = ()
@root_validator(pre=True) @root_validator(pre=True)
def set_model_info(cls, values): def set_model_info(cls, values):
dict_fields = [ dict_fields = [
@ -570,6 +579,9 @@ class LiteLLM_BudgetTable(LiteLLMBase):
model_max_budget: Optional[dict] = None model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
class Config:
protected_namespaces = ()
class NewOrganizationRequest(LiteLLM_BudgetTable): class NewOrganizationRequest(LiteLLM_BudgetTable):
organization_id: Optional[str] = None organization_id: Optional[str] = None
@ -900,5 +912,19 @@ class LiteLLM_SpendLogs(LiteLLMBase):
request_tags: Optional[Json] = None request_tags: Optional[Json] = None
class LiteLLM_ErrorLogs(LiteLLMBase):
request_id: Optional[str] = str(uuid.uuid4())
api_base: Optional[str] = ""
model_group: Optional[str] = ""
litellm_model_name: Optional[str] = ""
model_id: Optional[str] = ""
request_kwargs: Optional[dict] = {}
exception_type: Optional[str] = ""
status_code: Optional[str] = ""
exception_string: Optional[str] = ""
startTime: Union[str, datetime, None]
endTime: Union[str, datetime, None]
class LiteLLM_SpendLogs_ResponseObject(LiteLLMBase): class LiteLLM_SpendLogs_ResponseObject(LiteLLMBase):
response: Optional[List[Union[LiteLLM_SpendLogs, Any]]] = None response: Optional[List[Union[LiteLLM_SpendLogs, Any]]] = None

View file

@ -95,7 +95,15 @@ def common_checks(
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}" f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
) )
# 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
if litellm.max_budget > 0 and global_proxy_spend is not None: if (
litellm.max_budget > 0
and global_proxy_spend is not None
# only run global budget checks for OpenAI routes
# Reason - the Admin UI should continue working if the proxy crosses it's global budget
and route in LiteLLMRoutes.openai_routes.value
and route != "/v1/models"
and route != "/models"
):
if global_proxy_spend > litellm.max_budget: if global_proxy_spend > litellm.max_budget:
raise Exception( raise Exception(
f"ExceededBudget: LiteLLM Proxy has exceeded its budget. Current spend: {global_proxy_spend}; Max Budget: {litellm.max_budget}" f"ExceededBudget: LiteLLM Proxy has exceeded its budget. Current spend: {global_proxy_spend}; Max Budget: {litellm.max_budget}"

File diff suppressed because it is too large Load diff

View file

@ -183,6 +183,21 @@ model LiteLLM_SpendLogs {
end_user String? end_user String?
} }
// View spend, model, api_key per request
model LiteLLM_ErrorLogs {
request_id String @id @default(uuid())
startTime DateTime // Assuming start_time is a DateTime field
endTime DateTime // Assuming end_time is a DateTime field
api_base String @default("")
model_group String @default("") // public model_name / model_group
litellm_model_name String @default("") // model passed to litellm
model_id String @default("") // ID of model in ProxyModelTable
request_kwargs Json @default("{}")
exception_type String @default("")
exception_string String @default("")
status_code String @default("")
}
// Beta - allow team members to request access to a model // Beta - allow team members to request access to a model
model LiteLLM_UserNotifications { model LiteLLM_UserNotifications {
request_id String @id request_id String @id

View file

@ -387,15 +387,21 @@ class ProxyLogging:
""" """
### ALERTING ### ### ALERTING ###
if "llm_exceptions" not in self.alert_types: if "llm_exceptions" in self.alert_types and not isinstance(
return original_exception, HTTPException
asyncio.create_task( ):
self.alerting_handler( """
message=f"LLM API call failed: {str(original_exception)}", Just alert on LLM API exceptions. Do not alert on user errors
level="High",
alert_type="llm_exceptions", Related issue - https://github.com/BerriAI/litellm/issues/3395
"""
asyncio.create_task(
self.alerting_handler(
message=f"LLM API call failed: {str(original_exception)}",
level="High",
alert_type="llm_exceptions",
)
) )
)
for callback in litellm.callbacks: for callback in litellm.callbacks:
try: try:
@ -679,8 +685,8 @@ class PrismaClient:
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
Exception, # base exception to catch for the backoff Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries max_tries=1, # maximum number of retries
max_time=10, # maximum total time to retry for max_time=2, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def get_generic_data( async def get_generic_data(
@ -718,7 +724,8 @@ class PrismaClient:
import traceback import traceback
error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}" error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}"
print_verbose(error_msg) verbose_proxy_logger.error(error_msg)
error_msg = error_msg + "\nException Type: {}".format(type(e))
error_traceback = error_msg + "\n" + traceback.format_exc() error_traceback = error_msg + "\n" + traceback.format_exc()
end_time = time.time() end_time = time.time()
_duration = end_time - start_time _duration = end_time - start_time
@ -1777,7 +1784,7 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
usage = response_obj["usage"] usage = response_obj["usage"]
if type(usage) == litellm.Usage: if type(usage) == litellm.Usage:
usage = dict(usage) usage = dict(usage)
id = response_obj.get("id", str(uuid.uuid4())) id = response_obj.get("id", kwargs.get("litellm_call_id"))
api_key = metadata.get("user_api_key", "") api_key = metadata.get("user_api_key", "")
if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"): if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
# hash the api_key # hash the api_key
@ -2049,6 +2056,11 @@ async def update_spend(
raise e raise e
### UPDATE KEY TABLE ### ### UPDATE KEY TABLE ###
verbose_proxy_logger.debug(
"KEY Spend transactions: {}".format(
len(prisma_client.key_list_transactons.keys())
)
)
if len(prisma_client.key_list_transactons.keys()) > 0: if len(prisma_client.key_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
start_time = time.time() start_time = time.time()

View file

@ -42,6 +42,7 @@ from litellm.types.router import (
RouterErrors, RouterErrors,
updateDeployment, updateDeployment,
updateLiteLLMParams, updateLiteLLMParams,
RetryPolicy,
) )
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -50,7 +51,6 @@ class Router:
model_names: List = [] model_names: List = []
cache_responses: Optional[bool] = False cache_responses: Optional[bool] = False
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
num_retries: int = 0
tenacity = None tenacity = None
leastbusy_logger: Optional[LeastBusyLoggingHandler] = None leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
lowesttpm_logger: Optional[LowestTPMLoggingHandler] = None lowesttpm_logger: Optional[LowestTPMLoggingHandler] = None
@ -70,9 +70,11 @@ class Router:
] = None, # if you want to cache across model groups ] = None, # if you want to cache across model groups
client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds
## RELIABILITY ## ## RELIABILITY ##
num_retries: int = 0, num_retries: Optional[int] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
default_litellm_params={}, # default params for Router.chat.completion.create default_litellm_params: Optional[
dict
] = None, # default params for Router.chat.completion.create
default_max_parallel_requests: Optional[int] = None, default_max_parallel_requests: Optional[int] = None,
set_verbose: bool = False, set_verbose: bool = False,
debug_level: Literal["DEBUG", "INFO"] = "INFO", debug_level: Literal["DEBUG", "INFO"] = "INFO",
@ -81,6 +83,12 @@ class Router:
model_group_alias: Optional[dict] = {}, model_group_alias: Optional[dict] = {},
enable_pre_call_checks: bool = False, enable_pre_call_checks: bool = False,
retry_after: int = 0, # min time to wait before retrying a failed request retry_after: int = 0, # min time to wait before retrying a failed request
retry_policy: Optional[
RetryPolicy
] = None, # set custom retries for different exceptions
model_group_retry_policy: Optional[
Dict[str, RetryPolicy]
] = {}, # set custom retry policies based on model group
allowed_fails: Optional[ allowed_fails: Optional[
int int
] = None, # Number of times a deployment can failbefore being added to cooldown ] = None, # Number of times a deployment can failbefore being added to cooldown
@ -158,6 +166,7 @@ class Router:
router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}]) router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
``` ```
""" """
if semaphore: if semaphore:
self.semaphore = semaphore self.semaphore = semaphore
self.set_verbose = set_verbose self.set_verbose = set_verbose
@ -229,7 +238,14 @@ class Router:
self.failed_calls = ( self.failed_calls = (
InMemoryCache() InMemoryCache()
) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown ) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
self.num_retries = num_retries or litellm.num_retries or 0
if num_retries is not None:
self.num_retries = num_retries
elif litellm.num_retries is not None:
self.num_retries = litellm.num_retries
else:
self.num_retries = openai.DEFAULT_MAX_RETRIES
self.timeout = timeout or litellm.request_timeout self.timeout = timeout or litellm.request_timeout
self.retry_after = retry_after self.retry_after = retry_after
@ -255,6 +271,7 @@ class Router:
) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group ) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
# make Router.chat.completions.create compatible for openai.chat.completions.create # make Router.chat.completions.create compatible for openai.chat.completions.create
default_litellm_params = default_litellm_params or {}
self.chat = litellm.Chat(params=default_litellm_params, router_obj=self) self.chat = litellm.Chat(params=default_litellm_params, router_obj=self)
# default litellm args # default litellm args
@ -280,6 +297,25 @@ class Router:
} }
""" """
### ROUTING SETUP ### ### ROUTING SETUP ###
self.routing_strategy_init(
routing_strategy=routing_strategy,
routing_strategy_args=routing_strategy_args,
)
## COOLDOWNS ##
if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure)
else:
litellm.failure_callback = [self.deployment_callback_on_failure]
print( # noqa
f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
) # noqa
self.routing_strategy_args = routing_strategy_args
self.retry_policy: Optional[RetryPolicy] = retry_policy
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
model_group_retry_policy
)
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
if routing_strategy == "least-busy": if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler( self.leastbusy_logger = LeastBusyLoggingHandler(
router_cache=self.cache, model_list=self.model_list router_cache=self.cache, model_list=self.model_list
@ -311,15 +347,6 @@ class Router:
) )
if isinstance(litellm.callbacks, list): if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowestlatency_logger) # type: ignore litellm.callbacks.append(self.lowestlatency_logger) # type: ignore
## COOLDOWNS ##
if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure)
else:
litellm.failure_callback = [self.deployment_callback_on_failure]
verbose_router_logger.info(
f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
)
self.routing_strategy_args = routing_strategy_args
def print_deployment(self, deployment: dict): def print_deployment(self, deployment: dict):
""" """
@ -359,7 +386,9 @@ class Router:
except Exception as e: except Exception as e:
raise e raise e
def _completion(self, model: str, messages: List[Dict[str, str]], **kwargs): def _completion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
model_name = None model_name = None
try: try:
# pick the one that is available (lowest TPM/RPM) # pick the one that is available (lowest TPM/RPM)
@ -422,12 +451,15 @@ class Router:
) )
raise e raise e
async def acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs): async def acompletion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["messages"] = messages kwargs["messages"] = messages
kwargs["original_function"] = self._acompletion kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout) timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model}) kwargs.setdefault("metadata", {}).update({"model_group": model})
@ -437,7 +469,9 @@ class Router:
except Exception as e: except Exception as e:
raise e raise e
async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs): async def _acompletion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
""" """
- Get an available deployment - Get an available deployment
- call it with a semaphore over the call - call it with a semaphore over the call
@ -469,6 +503,7 @@ class Router:
) )
kwargs["model_info"] = deployment.get("model_info", {}) kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
model_name = data["model"] model_name = data["model"]
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if ( if (
@ -1415,10 +1450,12 @@ class Router:
context_window_fallbacks = kwargs.pop( context_window_fallbacks = kwargs.pop(
"context_window_fallbacks", self.context_window_fallbacks "context_window_fallbacks", self.context_window_fallbacks
) )
verbose_router_logger.debug(
f"async function w/ retries: original_function - {original_function}"
)
num_retries = kwargs.pop("num_retries") num_retries = kwargs.pop("num_retries")
verbose_router_logger.debug(
f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
)
try: try:
# if the function call is successful, no exception will be raised and we'll break out of the loop # if the function call is successful, no exception will be raised and we'll break out of the loop
response = await original_function(*args, **kwargs) response = await original_function(*args, **kwargs)
@ -1435,38 +1472,24 @@ class Router:
): ):
raise original_exception raise original_exception
### RETRY ### RETRY
#### check if it should retry + back-off if required
if "No models available" in str(e):
timeout = litellm._calculate_retry_after(
remaining_retries=num_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
await asyncio.sleep(timeout)
elif RouterErrors.user_defined_ratelimit_error.value in str(e):
raise e # don't wait to retry if deployment hits user-defined rate-limit
elif hasattr(original_exception, "status_code") and litellm._should_retry(
status_code=original_exception.status_code
):
if hasattr(original_exception, "response") and hasattr(
original_exception.response, "headers"
):
timeout = litellm._calculate_retry_after(
remaining_retries=num_retries,
max_retries=num_retries,
response_headers=original_exception.response.headers,
min_timeout=self.retry_after,
)
else:
timeout = litellm._calculate_retry_after(
remaining_retries=num_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
await asyncio.sleep(timeout)
else:
raise original_exception
_timeout = self._router_should_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
)
await asyncio.sleep(_timeout)
if (
self.retry_policy is not None
or self.model_group_retry_policy is not None
):
# get num_retries from retry policy
_retry_policy_retries = self.get_num_retries_from_retry_policy(
exception=original_exception, model_group=kwargs.get("model")
)
if _retry_policy_retries is not None:
num_retries = _retry_policy_retries
## LOGGING ## LOGGING
if num_retries > 0: if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception) kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
@ -1488,34 +1511,16 @@ class Router:
## LOGGING ## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e) kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt remaining_retries = num_retries - current_attempt
if "No models available" in str(e): _timeout = self._router_should_retry(
timeout = litellm._calculate_retry_after( e=original_exception,
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
max_retries=num_retries, num_retries=num_retries,
min_timeout=self.retry_after, )
) await asyncio.sleep(_timeout)
await asyncio.sleep(timeout) try:
elif ( original_exception.message += f"\nNumber Retries = {current_attempt}"
hasattr(e, "status_code") except:
and hasattr(e, "response") pass
and litellm._should_retry(status_code=e.status_code)
):
if hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
response_headers=e.response.headers,
min_timeout=self.retry_after,
)
else:
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
await asyncio.sleep(timeout)
else:
raise e
raise original_exception raise original_exception
def function_with_fallbacks(self, *args, **kwargs): def function_with_fallbacks(self, *args, **kwargs):
@ -1606,6 +1611,27 @@ class Router:
raise e raise e
raise original_exception raise original_exception
def _router_should_retry(
self, e: Exception, remaining_retries: int, num_retries: int
) -> Union[int, float]:
"""
Calculate back-off, then retry
"""
if hasattr(e, "response") and hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
response_headers=e.response.headers,
min_timeout=self.retry_after,
)
else:
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
return timeout
def function_with_retries(self, *args, **kwargs): def function_with_retries(self, *args, **kwargs):
""" """
Try calling the model 3 times. Shuffle between available deployments. Try calling the model 3 times. Shuffle between available deployments.
@ -1619,15 +1645,13 @@ class Router:
context_window_fallbacks = kwargs.pop( context_window_fallbacks = kwargs.pop(
"context_window_fallbacks", self.context_window_fallbacks "context_window_fallbacks", self.context_window_fallbacks
) )
try: try:
# if the function call is successful, no exception will be raised and we'll break out of the loop # if the function call is successful, no exception will be raised and we'll break out of the loop
response = original_function(*args, **kwargs) response = original_function(*args, **kwargs)
return response return response
except Exception as e: except Exception as e:
original_exception = e original_exception = e
verbose_router_logger.debug(
f"num retries in function with retries: {num_retries}"
)
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
if ( if (
isinstance(original_exception, litellm.ContextWindowExceededError) isinstance(original_exception, litellm.ContextWindowExceededError)
@ -1641,6 +1665,12 @@ class Router:
if num_retries > 0: if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception) kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
### RETRY ### RETRY
_timeout = self._router_should_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
)
time.sleep(_timeout)
for current_attempt in range(num_retries): for current_attempt in range(num_retries):
verbose_router_logger.debug( verbose_router_logger.debug(
f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}" f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
@ -1654,34 +1684,12 @@ class Router:
## LOGGING ## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e) kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt remaining_retries = num_retries - current_attempt
if "No models available" in str(e): _timeout = self._router_should_retry(
timeout = litellm._calculate_retry_after( e=e,
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
max_retries=num_retries, num_retries=num_retries,
min_timeout=self.retry_after, )
) time.sleep(_timeout)
time.sleep(timeout)
elif (
hasattr(e, "status_code")
and hasattr(e, "response")
and litellm._should_retry(status_code=e.status_code)
):
if hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
response_headers=e.response.headers,
min_timeout=self.retry_after,
)
else:
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
time.sleep(timeout)
else:
raise e
raise original_exception raise original_exception
### HELPER FUNCTIONS ### HELPER FUNCTIONS
@ -1715,10 +1723,11 @@ class Router:
) # i.e. azure ) # i.e. azure
metadata = kwargs.get("litellm_params", {}).get("metadata", None) metadata = kwargs.get("litellm_params", {}).get("metadata", None)
_model_info = kwargs.get("litellm_params", {}).get("model_info", {}) _model_info = kwargs.get("litellm_params", {}).get("model_info", {})
if isinstance(_model_info, dict): if isinstance(_model_info, dict):
deployment_id = _model_info.get("id", None) deployment_id = _model_info.get("id", None)
self._set_cooldown_deployments( self._set_cooldown_deployments(
deployment_id exception_status=exception_status, deployment=deployment_id
) # setting deployment_id in cooldown deployments ) # setting deployment_id in cooldown deployments
if custom_llm_provider: if custom_llm_provider:
model_name = f"{custom_llm_provider}/{model_name}" model_name = f"{custom_llm_provider}/{model_name}"
@ -1778,9 +1787,15 @@ class Router:
key=rpm_key, value=request_count, local_only=True key=rpm_key, value=request_count, local_only=True
) # don't change existing ttl ) # don't change existing ttl
def _set_cooldown_deployments(self, deployment: Optional[str] = None): def _set_cooldown_deployments(
self, exception_status: Union[str, int], deployment: Optional[str] = None
):
""" """
Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
or
the exception is not one that should be immediately retried (e.g. 401)
""" """
if deployment is None: if deployment is None:
return return
@ -1797,7 +1812,20 @@ class Router:
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}" f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
) )
cooldown_time = self.cooldown_time or 1 cooldown_time = self.cooldown_time or 1
if updated_fails > self.allowed_fails:
if isinstance(exception_status, str):
try:
exception_status = int(exception_status)
except Exception as e:
verbose_router_logger.debug(
"Unable to cast exception status to int {}. Defaulting to status=500.".format(
exception_status
)
)
exception_status = 500
_should_retry = litellm._should_retry(status_code=exception_status)
if updated_fails > self.allowed_fails or _should_retry == False:
# get the current cooldown list for that minute # get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
cached_value = self.cache.get_cache(key=cooldown_key) cached_value = self.cache.get_cache(key=cooldown_key)
@ -1941,8 +1969,10 @@ class Router:
or "ft:gpt-3.5-turbo" in model_name or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models or model_name in litellm.open_ai_embedding_models
): ):
is_azure_ai_studio_model: bool = False
if custom_llm_provider == "azure": if custom_llm_provider == "azure":
if litellm.utils._is_non_openai_azure_model(model_name): if litellm.utils._is_non_openai_azure_model(model_name):
is_azure_ai_studio_model = True
custom_llm_provider = "openai" custom_llm_provider = "openai"
# remove azure prefx from model_name # remove azure prefx from model_name
model_name = model_name.replace("azure/", "") model_name = model_name.replace("azure/", "")
@ -1972,13 +2002,15 @@ class Router:
if not, add it - https://github.com/BerriAI/litellm/issues/2279 if not, add it - https://github.com/BerriAI/litellm/issues/2279
""" """
if ( if (
custom_llm_provider == "openai" is_azure_ai_studio_model == True
and api_base is not None and api_base is not None
and not api_base.endswith("/v1/") and not api_base.endswith("/v1/")
): ):
# check if it ends with a trailing slash # check if it ends with a trailing slash
if api_base.endswith("/"): if api_base.endswith("/"):
api_base += "v1/" api_base += "v1/"
elif api_base.endswith("/v1"):
api_base += "/"
else: else:
api_base += "/v1/" api_base += "/v1/"
@ -2004,7 +2036,9 @@ class Router:
stream_timeout = litellm.get_secret(stream_timeout_env_name) stream_timeout = litellm.get_secret(stream_timeout_env_name)
litellm_params["stream_timeout"] = stream_timeout litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop("max_retries", 2) max_retries = litellm_params.pop(
"max_retries", 0
) # router handles retry logic
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "") max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name) max_retries = litellm.get_secret(max_retries_env_name)
@ -2553,6 +2587,16 @@ class Router:
return model return model
return None return None
def get_model_info(self, id: str) -> Optional[dict]:
"""
For a given model id, return the model info
"""
for model in self.model_list:
if "model_info" in model and "id" in model["model_info"]:
if id == model["model_info"]["id"]:
return model
return None
def get_model_ids(self): def get_model_ids(self):
ids = [] ids = []
for model in self.model_list: for model in self.model_list:
@ -2592,6 +2636,11 @@ class Router:
for var in vars_to_include: for var in vars_to_include:
if var in _all_vars: if var in _all_vars:
_settings_to_return[var] = _all_vars[var] _settings_to_return[var] = _all_vars[var]
if (
var == "routing_strategy_args"
and self.routing_strategy == "latency-based-routing"
):
_settings_to_return[var] = self.lowestlatency_logger.routing_args.json()
return _settings_to_return return _settings_to_return
def update_settings(self, **kwargs): def update_settings(self, **kwargs):
@ -2617,12 +2666,24 @@ class Router:
"cooldown_time", "cooldown_time",
] ]
_existing_router_settings = self.get_settings()
for var in kwargs: for var in kwargs:
if var in _allowed_settings: if var in _allowed_settings:
if var in _int_settings: if var in _int_settings:
_casted_value = int(kwargs[var]) _casted_value = int(kwargs[var])
setattr(self, var, _casted_value) setattr(self, var, _casted_value)
else: else:
# only run routing strategy init if it has changed
if (
var == "routing_strategy"
and _existing_router_settings["routing_strategy"] != kwargs[var]
):
self.routing_strategy_init(
routing_strategy=kwargs[var],
routing_strategy_args=kwargs.get(
"routing_strategy_args", {}
),
)
setattr(self, var, kwargs[var]) setattr(self, var, kwargs[var])
else: else:
verbose_router_logger.debug("Setting {} is not allowed".format(var)) verbose_router_logger.debug("Setting {} is not allowed".format(var))
@ -2759,7 +2820,10 @@ class Router:
self.cache.get_cache(key=model_id, local_only=True) or 0 self.cache.get_cache(key=model_id, local_only=True) or 0
) )
### get usage based cache ### ### get usage based cache ###
if isinstance(model_group_cache, dict): if (
isinstance(model_group_cache, dict)
and self.routing_strategy != "usage-based-routing-v2"
):
model_group_cache[model_id] = model_group_cache.get(model_id, 0) model_group_cache[model_id] = model_group_cache.get(model_id, 0)
current_request = max( current_request = max(
@ -2787,7 +2851,7 @@ class Router:
if _rate_limit_error == True: # allow generic fallback logic to take place if _rate_limit_error == True: # allow generic fallback logic to take place
raise ValueError( raise ValueError(
f"No deployments available for selected model, passed model={model}" f"{RouterErrors.no_deployments_available.value}, passed model={model}"
) )
elif _context_window_error == True: elif _context_window_error == True:
raise litellm.ContextWindowExceededError( raise litellm.ContextWindowExceededError(
@ -2852,15 +2916,10 @@ class Router:
m for m in self.model_list if m["litellm_params"]["model"] == model m for m in self.model_list if m["litellm_params"]["model"] == model
] ]
verbose_router_logger.debug( litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
f"initial list of deployments: {healthy_deployments}"
)
verbose_router_logger.debug(
f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
)
if len(healthy_deployments) == 0: if len(healthy_deployments) == 0:
raise ValueError(f"No healthy deployment available, passed model={model}") raise ValueError(f"No healthy deployment available, passed model={model}. ")
if litellm.model_alias_map and model in litellm.model_alias_map: if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[ model = litellm.model_alias_map[
model model
@ -2925,6 +2984,11 @@ class Router:
model=model, healthy_deployments=healthy_deployments, messages=messages model=model, healthy_deployments=healthy_deployments, messages=messages
) )
if len(healthy_deployments) == 0:
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
)
if ( if (
self.routing_strategy == "usage-based-routing-v2" self.routing_strategy == "usage-based-routing-v2"
and self.lowesttpm_logger_v2 is not None and self.lowesttpm_logger_v2 is not None
@ -2980,7 +3044,7 @@ class Router:
f"get_available_deployment for model: {model}, No deployment available" f"get_available_deployment for model: {model}, No deployment available"
) )
raise ValueError( raise ValueError(
f"No deployments available for selected model, passed model={model}" f"{RouterErrors.no_deployments_available.value}, passed model={model}"
) )
verbose_router_logger.info( verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}" f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@ -3110,7 +3174,7 @@ class Router:
f"get_available_deployment for model: {model}, No deployment available" f"get_available_deployment for model: {model}, No deployment available"
) )
raise ValueError( raise ValueError(
f"No deployments available for selected model, passed model={model}" f"{RouterErrors.no_deployments_available.value}, passed model={model}"
) )
verbose_router_logger.info( verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}" f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@ -3181,6 +3245,53 @@ class Router:
except Exception as e: except Exception as e:
verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}") verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
def get_num_retries_from_retry_policy(
self, exception: Exception, model_group: Optional[str] = None
):
"""
BadRequestErrorRetries: Optional[int] = None
AuthenticationErrorRetries: Optional[int] = None
TimeoutErrorRetries: Optional[int] = None
RateLimitErrorRetries: Optional[int] = None
ContentPolicyViolationErrorRetries: Optional[int] = None
"""
# if we can find the exception then in the retry policy -> return the number of retries
retry_policy = self.retry_policy
if (
self.model_group_retry_policy is not None
and model_group is not None
and model_group in self.model_group_retry_policy
):
retry_policy = self.model_group_retry_policy.get(model_group, None)
if retry_policy is None:
return None
if (
isinstance(exception, litellm.BadRequestError)
and retry_policy.BadRequestErrorRetries is not None
):
return retry_policy.BadRequestErrorRetries
if (
isinstance(exception, litellm.AuthenticationError)
and retry_policy.AuthenticationErrorRetries is not None
):
return retry_policy.AuthenticationErrorRetries
if (
isinstance(exception, litellm.Timeout)
and retry_policy.TimeoutErrorRetries is not None
):
return retry_policy.TimeoutErrorRetries
if (
isinstance(exception, litellm.RateLimitError)
and retry_policy.RateLimitErrorRetries is not None
):
return retry_policy.RateLimitErrorRetries
if (
isinstance(exception, litellm.ContentPolicyViolationError)
and retry_policy.ContentPolicyViolationErrorRetries is not None
):
return retry_policy.ContentPolicyViolationErrorRetries
def flush_cache(self): def flush_cache(self):
litellm.cache = None litellm.cache = None
self.cache.flush_cache() self.cache.flush_cache()
@ -3191,4 +3302,5 @@ class Router:
litellm.__async_success_callback = [] litellm.__async_success_callback = []
litellm.failure_callback = [] litellm.failure_callback = []
litellm._async_failure_callback = [] litellm._async_failure_callback = []
self.retry_policy = None
self.flush_cache() self.flush_cache()

View file

@ -4,6 +4,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
import dotenv, os, requests, random import dotenv, os, requests, random
from typing import Optional, Union, List, Dict from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta from datetime import datetime, timedelta
import random
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -29,6 +30,8 @@ class LiteLLMBase(BaseModel):
class RoutingArgs(LiteLLMBase): class RoutingArgs(LiteLLMBase):
ttl: int = 1 * 60 * 60 # 1 hour ttl: int = 1 * 60 * 60 # 1 hour
lowest_latency_buffer: float = 0
max_latency_list_size: int = 10
class LowestLatencyLoggingHandler(CustomLogger): class LowestLatencyLoggingHandler(CustomLogger):
@ -101,7 +104,18 @@ class LowestLatencyLoggingHandler(CustomLogger):
request_count_dict[id] = {} request_count_dict[id] = {}
## Latency ## Latency
request_count_dict[id].setdefault("latency", []).append(final_value) if (
len(request_count_dict[id].get("latency", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault("latency", []).append(final_value)
else:
request_count_dict[id]["latency"] = request_count_dict[id][
"latency"
][: self.routing_args.max_latency_list_size - 1] + [final_value]
if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}
if precise_minute not in request_count_dict[id]: if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {} request_count_dict[id][precise_minute] = {}
@ -168,8 +182,17 @@ class LowestLatencyLoggingHandler(CustomLogger):
if id not in request_count_dict: if id not in request_count_dict:
request_count_dict[id] = {} request_count_dict[id] = {}
## Latency ## Latency - give 1000s penalty for failing
request_count_dict[id].setdefault("latency", []).append(1000.0) if (
len(request_count_dict[id].get("latency", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault("latency", []).append(1000.0)
else:
request_count_dict[id]["latency"] = request_count_dict[id][
"latency"
][: self.routing_args.max_latency_list_size - 1] + [1000.0]
self.router_cache.set_cache( self.router_cache.set_cache(
key=latency_key, key=latency_key,
value=request_count_dict, value=request_count_dict,
@ -240,7 +263,15 @@ class LowestLatencyLoggingHandler(CustomLogger):
request_count_dict[id] = {} request_count_dict[id] = {}
## Latency ## Latency
request_count_dict[id].setdefault("latency", []).append(final_value) if (
len(request_count_dict[id].get("latency", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault("latency", []).append(final_value)
else:
request_count_dict[id]["latency"] = request_count_dict[id][
"latency"
][: self.routing_args.max_latency_list_size - 1] + [final_value]
if precise_minute not in request_count_dict[id]: if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {} request_count_dict[id][precise_minute] = {}
@ -312,6 +343,14 @@ class LowestLatencyLoggingHandler(CustomLogger):
except: except:
input_tokens = 0 input_tokens = 0
# randomly sample from all_deployments, incase all deployments have latency=0.0
_items = all_deployments.items()
all_deployments = random.sample(list(_items), len(_items))
all_deployments = dict(all_deployments)
### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
potential_deployments = []
for item, item_map in all_deployments.items(): for item, item_map in all_deployments.items():
## get the item from model list ## get the item from model list
_deployment = None _deployment = None
@ -360,17 +399,33 @@ class LowestLatencyLoggingHandler(CustomLogger):
# End of Debugging Logic # End of Debugging Logic
# -------------- # # -------------- #
if item_latency == 0: if (
deployment = _deployment
break
elif (
item_tpm + input_tokens > _deployment_tpm item_tpm + input_tokens > _deployment_tpm
or item_rpm + 1 > _deployment_rpm or item_rpm + 1 > _deployment_rpm
): # if user passed in tpm / rpm in the model_list ): # if user passed in tpm / rpm in the model_list
continue continue
elif item_latency < lowest_latency: else:
lowest_latency = item_latency potential_deployments.append((_deployment, item_latency))
deployment = _deployment
if len(potential_deployments) == 0:
return None
# Sort potential deployments by latency
sorted_deployments = sorted(potential_deployments, key=lambda x: x[1])
# Find lowest latency deployment
lowest_latency = sorted_deployments[0][1]
# Find deployments within buffer of lowest latency
buffer = self.routing_args.lowest_latency_buffer * lowest_latency
valid_deployments = [
x for x in sorted_deployments if x[1] <= lowest_latency + buffer
]
# Pick a random deployment from valid deployments
random_valid_deployment = random.choice(valid_deployments)
deployment = random_valid_deployment[0]
if request_kwargs is not None and "metadata" in request_kwargs: if request_kwargs is not None and "metadata" in request_kwargs:
request_kwargs["metadata"][ request_kwargs["metadata"][
"_latency_per_deployment" "_latency_per_deployment"

View file

@ -206,7 +206,7 @@ class LowestTPMLoggingHandler(CustomLogger):
if item_tpm + input_tokens > _deployment_tpm: if item_tpm + input_tokens > _deployment_tpm:
continue continue
elif (rpm_dict is not None and item in rpm_dict) and ( elif (rpm_dict is not None and item in rpm_dict) and (
rpm_dict[item] + 1 > _deployment_rpm rpm_dict[item] + 1 >= _deployment_rpm
): ):
continue continue
elif item_tpm < lowest_tpm: elif item_tpm < lowest_tpm:

View file

@ -79,10 +79,12 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
model=deployment.get("litellm_params", {}).get("model"), model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response( response=httpx.Response(
status_code=429, status_code=429,
content="{} rpm limit={}. current usage={}".format( content="{} rpm limit={}. current usage={}. id={}, model_group={}. Get the model info by calling 'router.get_model_info(id)".format(
RouterErrors.user_defined_ratelimit_error.value, RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm, deployment_rpm,
local_result, local_result,
model_id,
deployment.get("model_name", ""),
), ),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
), ),
@ -333,7 +335,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
tpm_dict[tpm_key] = 0 tpm_dict[tpm_key] = 0
all_deployments = tpm_dict all_deployments = tpm_dict
deployment = None potential_deployments = [] # if multiple deployments have the same low value
for item, item_tpm in all_deployments.items(): for item, item_tpm in all_deployments.items():
## get the item from model list ## get the item from model list
_deployment = None _deployment = None
@ -343,6 +345,8 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
_deployment = m _deployment = m
if _deployment is None: if _deployment is None:
continue # skip to next one continue # skip to next one
elif item_tpm is None:
continue # skip if unhealthy deployment
_deployment_tpm = None _deployment_tpm = None
if _deployment_tpm is None: if _deployment_tpm is None:
@ -366,14 +370,20 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
if item_tpm + input_tokens > _deployment_tpm: if item_tpm + input_tokens > _deployment_tpm:
continue continue
elif (rpm_dict is not None and item in rpm_dict) and ( elif (rpm_dict is not None and item in rpm_dict) and (
rpm_dict[item] + 1 > _deployment_rpm rpm_dict[item] + 1 >= _deployment_rpm
): ):
continue continue
elif item_tpm == lowest_tpm:
potential_deployments.append(_deployment)
elif item_tpm < lowest_tpm: elif item_tpm < lowest_tpm:
lowest_tpm = item_tpm lowest_tpm = item_tpm
deployment = _deployment potential_deployments = [_deployment]
print_verbose("returning picked lowest tpm/rpm deployment.") print_verbose("returning picked lowest tpm/rpm deployment.")
return deployment
if len(potential_deployments) > 0:
return random.choice(potential_deployments)
else:
return None
async def async_get_available_deployments( async def async_get_available_deployments(
self, self,
@ -394,6 +404,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
dt = get_utc_datetime() dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M") current_minute = dt.strftime("%H-%M")
tpm_keys = [] tpm_keys = []
rpm_keys = [] rpm_keys = []
for m in healthy_deployments: for m in healthy_deployments:
@ -416,7 +427,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
tpm_values = combined_tpm_rpm_values[: len(tpm_keys)] tpm_values = combined_tpm_rpm_values[: len(tpm_keys)]
rpm_values = combined_tpm_rpm_values[len(tpm_keys) :] rpm_values = combined_tpm_rpm_values[len(tpm_keys) :]
return self._common_checks_available_deployment( deployment = self._common_checks_available_deployment(
model_group=model_group, model_group=model_group,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments,
tpm_keys=tpm_keys, tpm_keys=tpm_keys,
@ -427,6 +438,61 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
input=input, input=input,
) )
try:
assert deployment is not None
return deployment
except Exception as e:
### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
deployment_dict = {}
for index, _deployment in enumerate(healthy_deployments):
if isinstance(_deployment, dict):
id = _deployment.get("model_info", {}).get("id")
### GET DEPLOYMENT TPM LIMIT ###
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm", None)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = float("inf")
### GET CURRENT TPM ###
current_tpm = tpm_values[index]
### GET DEPLOYMENT TPM LIMIT ###
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm", None)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = float("inf")
### GET CURRENT RPM ###
current_rpm = rpm_values[index]
deployment_dict[id] = {
"current_tpm": current_tpm,
"tpm_limit": _deployment_tpm,
"current_rpm": current_rpm,
"rpm_limit": _deployment_rpm,
}
raise ValueError(
f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}"
)
def get_available_deployments( def get_available_deployments(
self, self,
model_group: str, model_group: str,
@ -464,7 +530,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
keys=rpm_keys keys=rpm_keys
) # [1, 2, None, ..] ) # [1, 2, None, ..]
return self._common_checks_available_deployment( deployment = self._common_checks_available_deployment(
model_group=model_group, model_group=model_group,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments,
tpm_keys=tpm_keys, tpm_keys=tpm_keys,
@ -474,3 +540,58 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
messages=messages, messages=messages,
input=input, input=input,
) )
try:
assert deployment is not None
return deployment
except Exception as e:
### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
deployment_dict = {}
for index, _deployment in enumerate(healthy_deployments):
if isinstance(_deployment, dict):
id = _deployment.get("model_info", {}).get("id")
### GET DEPLOYMENT TPM LIMIT ###
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm", None)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = float("inf")
### GET CURRENT TPM ###
current_tpm = tpm_values[index]
### GET DEPLOYMENT TPM LIMIT ###
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm", None)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = float("inf")
### GET CURRENT RPM ###
current_rpm = rpm_values[index]
deployment_dict[id] = {
"current_tpm": current_tpm,
"tpm_limit": _deployment_tpm,
"current_rpm": current_rpm,
"rpm_limit": _deployment_rpm,
}
raise ValueError(
f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}"
)

View file

@ -19,6 +19,7 @@ def setup_and_teardown():
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the project directory to the system path ) # Adds the project directory to the system path
import litellm import litellm
from litellm import Router
importlib.reload(litellm) importlib.reload(litellm)
import asyncio import asyncio

View file

@ -0,0 +1,88 @@
int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
Traceback (most recent call last):
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/client.py", line 778, in generation
"usage": _convert_usage_input(usage) if usage is not None else None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/utils.py", line 77, in _convert_usage_input
"totalCost": extract_by_priority(usage, ["totalCost", "total_cost"]),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/utils.py", line 32, in extract_by_priority
return int(usage[key])
^^^^^^^^^^^^^^^
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
Traceback (most recent call last):
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/client.py", line 778, in generation
"usage": _convert_usage_input(usage) if usage is not None else None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/utils.py", line 77, in _convert_usage_input
"totalCost": extract_by_priority(usage, ["totalCost", "total_cost"]),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/utils.py", line 32, in extract_by_priority
return int(usage[key])
^^^^^^^^^^^^^^^
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
Traceback (most recent call last):
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/client.py", line 778, in generation
"usage": _convert_usage_input(usage) if usage is not None else None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/utils.py", line 77, in _convert_usage_input
"totalCost": extract_by_priority(usage, ["totalCost", "total_cost"]),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/utils.py", line 32, in extract_by_priority
return int(usage[key])
^^^^^^^^^^^^^^^
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
Traceback (most recent call last):
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/client.py", line 778, in generation
"usage": _convert_usage_input(usage) if usage is not None else None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/utils.py", line 77, in _convert_usage_input
"totalCost": extract_by_priority(usage, ["totalCost", "total_cost"]),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/utils.py", line 32, in extract_by_priority
return int(usage[key])
^^^^^^^^^^^^^^^
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
Traceback (most recent call last):
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/client.py", line 778, in generation
"usage": _convert_usage_input(usage) if usage is not None else None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/utils.py", line 77, in _convert_usage_input
"totalCost": extract_by_priority(usage, ["totalCost", "total_cost"]),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/langfuse/utils.py", line 32, in extract_by_priority
return int(usage[key])
^^^^^^^^^^^^^^^
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
consumer is running...
Getting observations... None, None, None, None, litellm-test-98e1cc75-bef8-4280-a2b9-e08633b81acd, None, GENERATION
consumer is running...
Getting observations... None, None, None, None, litellm-test-532d2bc8-f8d6-42fd-8f78-416bae79925d, None, GENERATION
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined
joining 1 consumer threads
consumer thread 0 joined

View file

@ -5,74 +5,99 @@ plugins: timeout-2.2.0, asyncio-0.23.2, anyio-3.7.1, xdist-3.3.1
asyncio: mode=Mode.STRICT asyncio: mode=Mode.STRICT
collected 1 item collected 1 item
test_custom_logger.py Chunks have a created at hidden param test_completion.py F [100%]
Chunks sorted
token_counter messages received: [{'role': 'user', 'content': 'write a one sentence poem about: 73348'}]
Token Counter - using OpenAI token counter, for model=gpt-3.5-turbo
LiteLLM: Utils - Counting tokens for OpenAI model=gpt-3.5-turbo
Logging Details LiteLLM-Success Call: None
success callbacks: []
Token Counter - using OpenAI token counter, for model=gpt-3.5-turbo
LiteLLM: Utils - Counting tokens for OpenAI model=gpt-3.5-turbo
Logging Details LiteLLM-Success Call streaming complete
Looking up model=gpt-3.5-turbo in model_cost_map
Success: model=gpt-3.5-turbo in model_cost_map
prompt_tokens=17; completion_tokens=0
Returned custom cost for model=gpt-3.5-turbo - prompt_tokens_cost_usd_dollar: 2.55e-05, completion_tokens_cost_usd_dollar: 0.0
final cost: 2.55e-05; prompt_tokens_cost_usd_dollar: 2.55e-05; completion_tokens_cost_usd_dollar: 0.0
. [100%]
=================================== FAILURES ===================================
______________________ test_completion_anthropic_hanging _______________________
def test_completion_anthropic_hanging():
litellm.set_verbose = True
litellm.modify_params = True
messages = [
{
"role": "user",
"content": "What's the capital of fictional country Ubabababababaaba? Use your tools.",
},
{
"role": "assistant",
"function_call": {
"name": "get_capital",
"arguments": '{"country": "Ubabababababaaba"}',
},
},
{"role": "function", "name": "get_capital", "content": "Kokoko"},
]
converted_messages = anthropic_messages_pt(messages)
print(f"converted_messages: {converted_messages}")
## ENSURE USER / ASSISTANT ALTERNATING
for i, msg in enumerate(converted_messages):
if i < len(converted_messages) - 1:
> assert msg["role"] != converted_messages[i + 1]["role"]
E AssertionError: assert 'user' != 'user'
test_completion.py:2406: AssertionError
---------------------------- Captured stdout setup -----------------------------
<module 'litellm' from '/Users/krrishdholakia/Documents/litellm/litellm/__init__.py'>
pytest fixture - resetting callbacks
----------------------------- Captured stdout call -----------------------------
message: {'role': 'user', 'content': "What's the capital of fictional country Ubabababababaaba? Use your tools."}
message: {'role': 'function', 'name': 'get_capital', 'content': 'Kokoko'}
converted_messages: [{'role': 'user', 'content': [{'type': 'text', 'text': "What's the capital of fictional country Ubabababababaaba? Use your tools."}]}, {'role': 'user', 'content': [{'type': 'tool_result', 'tool_use_id': '10e9f4d4-bdc9-4514-8b7a-c10bc555d67c', 'content': 'Kokoko'}]}]
=============================== warnings summary =============================== =============================== warnings summary ===============================
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 18 warnings ../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 23 warnings
/opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning) warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
../proxy/_types.py:218 ../proxy/_types.py:219
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:218: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:219: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:305 ../proxy/_types.py:306
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:305: PydanticDeprecatedSince20: `pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:306: PydanticDeprecatedSince20: `pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
extra = Extra.allow # Allow extra fields extra = Extra.allow # Allow extra fields
../proxy/_types.py:308 ../proxy/_types.py:309
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:308: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:309: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:337 ../proxy/_types.py:338
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:337: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:338: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:384 ../proxy/_types.py:385
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:384: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:385: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:450 ../proxy/_types.py:454
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:450: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:454: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:462 ../proxy/_types.py:466
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:462: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:466: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:502 ../proxy/_types.py:509
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:502: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:509: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:536 ../proxy/_types.py:546
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:536: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:546: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:823 ../proxy/_types.py:840
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:823: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:840: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:850 ../proxy/_types.py:867
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:850: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:867: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:869 ../proxy/_types.py:886
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:869: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:886: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
@root_validator(pre=True) @root_validator(pre=True)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:121 ../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:121
@ -126,30 +151,7 @@ final cost: 2.55e-05; prompt_tokens_cost_usd_dollar: 2.55e-05; completion_tokens
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg) declare_namespace(pkg)
test_custom_logger.py::test_redis_cache_completion_stream
/opt/homebrew/lib/python3.11/site-packages/_pytest/unraisableexception.py:78: PytestUnraisableExceptionWarning: Exception ignored in: <function StreamWriter.__del__ at 0x1019c28e0>
Traceback (most recent call last):
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/streams.py", line 395, in __del__
self.close()
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/streams.py", line 343, in close
return self._transport.close()
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/sslproto.py", line 112, in close
self._ssl_protocol._start_shutdown()
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/sslproto.py", line 620, in _start_shutdown
self._shutdown_timeout_handle = self._loop.call_later(
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 727, in call_later
timer = self.call_at(self.time() + delay, callback, *args,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 740, in call_at
self._check_closed()
File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 519, in _check_closed
raise RuntimeError('Event loop is closed')
RuntimeError: Event loop is closed
warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 1 passed, 56 warnings in 2.43s ======================== =========================== short test summary info ============================
FAILED test_completion.py::test_completion_anthropic_hanging - AssertionError...
======================== 1 failed, 60 warnings in 0.15s ========================

View file

@ -119,7 +119,9 @@ def test_multiple_deployments_parallel():
# test_multiple_deployments_parallel() # test_multiple_deployments_parallel()
def test_cooldown_same_model_name(): @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_cooldown_same_model_name(sync_mode):
# users could have the same model with different api_base # users could have the same model with different api_base
# example # example
# azure/chatgpt, api_base: 1234 # azure/chatgpt, api_base: 1234
@ -161,22 +163,40 @@ def test_cooldown_same_model_name():
num_retries=3, num_retries=3,
) # type: ignore ) # type: ignore
response = router.completion( if sync_mode:
model="gpt-3.5-turbo", response = router.completion(
messages=[{"role": "user", "content": "hello this request will pass"}], model="gpt-3.5-turbo",
) messages=[{"role": "user", "content": "hello this request will pass"}],
print(router.model_list) )
model_ids = [] print(router.model_list)
for model in router.model_list: model_ids = []
model_ids.append(model["model_info"]["id"]) for model in router.model_list:
print("\n litellm model ids ", model_ids) model_ids.append(model["model_info"]["id"])
print("\n litellm model ids ", model_ids)
# example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960'] # example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960']
assert ( assert (
model_ids[0] != model_ids[1] model_ids[0] != model_ids[1]
) # ensure both models have a uuid added, and they have different names ) # ensure both models have a uuid added, and they have different names
print("\ngot response\n", response) print("\ngot response\n", response)
else:
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hello this request will pass"}],
)
print(router.model_list)
model_ids = []
for model in router.model_list:
model_ids.append(model["model_info"]["id"])
print("\n litellm model ids ", model_ids)
# example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960']
assert (
model_ids[0] != model_ids[1]
) # ensure both models have a uuid added, and they have different names
print("\ngot response\n", response)
except Exception as e: except Exception as e:
pytest.fail(f"Got unexpected exception on router! - {e}") pytest.fail(f"Got unexpected exception on router! - {e}")

View file

@ -161,40 +161,54 @@ async def make_async_calls():
return total_time return total_time
# def test_langfuse_logging_async_text_completion(): @pytest.mark.asyncio
# try: @pytest.mark.parametrize("stream", [False, True])
# pre_langfuse_setup() async def test_langfuse_logging_without_request_response(stream):
# litellm.set_verbose = False try:
# litellm.success_callback = ["langfuse"] import uuid
# async def _test_langfuse(): _unique_trace_name = f"litellm-test-{str(uuid.uuid4())}"
# response = await litellm.atext_completion( litellm.set_verbose = True
# model="gpt-3.5-turbo-instruct", litellm.turn_off_message_logging = True
# prompt="this is a test", litellm.success_callback = ["langfuse"]
# max_tokens=5, response = await litellm.acompletion(
# temperature=0.7, model="gpt-3.5-turbo",
# timeout=5, mock_response="It's simple to use and easy to get started",
# user="test_user", messages=[{"role": "user", "content": "Hi 👋 - i'm claude"}],
# stream=True max_tokens=10,
# ) temperature=0.2,
# async for chunk in response: stream=stream,
# print() metadata={"trace_id": _unique_trace_name},
# print(chunk) )
# await asyncio.sleep(1) print(response)
# return response if stream:
async for chunk in response:
print(chunk)
# response = asyncio.run(_test_langfuse()) await asyncio.sleep(3)
# print(f"response: {response}")
# # # check langfuse.log to see if there was a failed response import langfuse
# search_logs("langfuse.log")
# except litellm.Timeout as e:
# pass
# except Exception as e:
# pytest.fail(f"An exception occurred - {e}")
langfuse_client = langfuse.Langfuse(
public_key=os.environ["LANGFUSE_PUBLIC_KEY"],
secret_key=os.environ["LANGFUSE_SECRET_KEY"],
)
# test_langfuse_logging_async_text_completion() # get trace with _unique_trace_name
trace = langfuse_client.get_generations(trace_id=_unique_trace_name)
print("trace_from_langfuse", trace)
_trace_data = trace.data
assert _trace_data[0].input == {"messages": "redacted-by-litellm"}
assert _trace_data[0].output == {
"role": "assistant",
"content": "redacted-by-litellm",
}
except Exception as e:
pytest.fail(f"An exception occurred - {e}")
@pytest.mark.skip(reason="beta test - checking langfuse output") @pytest.mark.skip(reason="beta test - checking langfuse output")
@ -334,6 +348,228 @@ def test_langfuse_logging_function_calling():
# test_langfuse_logging_function_calling() # test_langfuse_logging_function_calling()
def test_langfuse_existing_trace_id():
"""
When existing trace id is passed, don't set trace params -> prevents overwriting the trace
Pass 1 logging object with a trace
Pass 2nd logging object with the trace id
Assert no changes to the trace
"""
# Test - if the logs were sent to the correct team on langfuse
import litellm, datetime
from litellm.integrations.langfuse import LangFuseLogger
langfuse_Logger = LangFuseLogger(
langfuse_public_key=os.getenv("LANGFUSE_PROJECT2_PUBLIC"),
langfuse_secret=os.getenv("LANGFUSE_PROJECT2_SECRET"),
)
litellm.success_callback = ["langfuse"]
# langfuse_args = {'kwargs': { 'start_time': 'end_time': datetime.datetime(2024, 5, 1, 7, 31, 29, 903685), 'user_id': None, 'print_verbose': <function print_verbose at 0x109d1f420>, 'level': 'DEFAULT', 'status_message': None}
response_obj = litellm.ModelResponse(
id="chatcmpl-9K5HUAbVRqFrMZKXL0WoC295xhguY",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="I'm sorry, I am an AI assistant and do not have real-time information. I recommend checking a reliable weather website or app for the most up-to-date weather information in Boston.",
role="assistant",
),
)
],
created=1714573888,
model="gpt-3.5-turbo-0125",
object="chat.completion",
system_fingerprint="fp_3b956da36b",
usage=litellm.Usage(completion_tokens=37, prompt_tokens=14, total_tokens=51),
)
### NEW TRACE ###
message = [{"role": "user", "content": "what's the weather in boston"}]
langfuse_args = {
"response_obj": response_obj,
"kwargs": {
"model": "gpt-3.5-turbo",
"litellm_params": {
"acompletion": False,
"api_key": None,
"force_timeout": 600,
"logger_fn": None,
"verbose": False,
"custom_llm_provider": "openai",
"api_base": "https://api.openai.com/v1/",
"litellm_call_id": "508113a1-c6f1-48ce-a3e1-01c6cce9330e",
"model_alias_map": {},
"completion_call_id": None,
"metadata": None,
"model_info": None,
"proxy_server_request": None,
"preset_cache_key": None,
"no-log": False,
"stream_response": {},
},
"messages": message,
"optional_params": {"temperature": 0.1, "extra_body": {}},
"start_time": "2024-05-01 07:31:27.986164",
"stream": False,
"user": None,
"call_type": "completion",
"litellm_call_id": "508113a1-c6f1-48ce-a3e1-01c6cce9330e",
"completion_start_time": "2024-05-01 07:31:29.903685",
"temperature": 0.1,
"extra_body": {},
"input": [{"role": "user", "content": "what's the weather in boston"}],
"api_key": "my-api-key",
"additional_args": {
"complete_input_dict": {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "what's the weather in boston"}
],
"temperature": 0.1,
"extra_body": {},
}
},
"log_event_type": "successful_api_call",
"end_time": "2024-05-01 07:31:29.903685",
"cache_hit": None,
"response_cost": 6.25e-05,
},
"start_time": datetime.datetime(2024, 5, 1, 7, 31, 27, 986164),
"end_time": datetime.datetime(2024, 5, 1, 7, 31, 29, 903685),
"user_id": None,
"print_verbose": litellm.print_verbose,
"level": "DEFAULT",
"status_message": None,
}
langfuse_response_object = langfuse_Logger.log_event(**langfuse_args)
import langfuse
langfuse_client = langfuse.Langfuse(
public_key=os.getenv("LANGFUSE_PROJECT2_PUBLIC"),
secret_key=os.getenv("LANGFUSE_PROJECT2_SECRET"),
)
trace_id = langfuse_response_object["trace_id"]
langfuse_client.flush()
time.sleep(2)
print(langfuse_client.get_trace(id=trace_id))
initial_langfuse_trace = langfuse_client.get_trace(id=trace_id)
### EXISTING TRACE ###
new_metadata = {"existing_trace_id": trace_id}
new_messages = [{"role": "user", "content": "What do you know?"}]
new_response_obj = litellm.ModelResponse(
id="chatcmpl-9K5HUAbVRqFrMZKXL0WoC295xhguY",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="What do I know?",
role="assistant",
),
)
],
created=1714573888,
model="gpt-3.5-turbo-0125",
object="chat.completion",
system_fingerprint="fp_3b956da36b",
usage=litellm.Usage(completion_tokens=37, prompt_tokens=14, total_tokens=51),
)
langfuse_args = {
"response_obj": new_response_obj,
"kwargs": {
"model": "gpt-3.5-turbo",
"litellm_params": {
"acompletion": False,
"api_key": None,
"force_timeout": 600,
"logger_fn": None,
"verbose": False,
"custom_llm_provider": "openai",
"api_base": "https://api.openai.com/v1/",
"litellm_call_id": "508113a1-c6f1-48ce-a3e1-01c6cce9330e",
"model_alias_map": {},
"completion_call_id": None,
"metadata": new_metadata,
"model_info": None,
"proxy_server_request": None,
"preset_cache_key": None,
"no-log": False,
"stream_response": {},
},
"messages": new_messages,
"optional_params": {"temperature": 0.1, "extra_body": {}},
"start_time": "2024-05-01 07:31:27.986164",
"stream": False,
"user": None,
"call_type": "completion",
"litellm_call_id": "508113a1-c6f1-48ce-a3e1-01c6cce9330e",
"completion_start_time": "2024-05-01 07:31:29.903685",
"temperature": 0.1,
"extra_body": {},
"input": [{"role": "user", "content": "what's the weather in boston"}],
"api_key": "my-api-key",
"additional_args": {
"complete_input_dict": {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "what's the weather in boston"}
],
"temperature": 0.1,
"extra_body": {},
}
},
"log_event_type": "successful_api_call",
"end_time": "2024-05-01 07:31:29.903685",
"cache_hit": None,
"response_cost": 6.25e-05,
},
"start_time": datetime.datetime(2024, 5, 1, 7, 31, 27, 986164),
"end_time": datetime.datetime(2024, 5, 1, 7, 31, 29, 903685),
"user_id": None,
"print_verbose": litellm.print_verbose,
"level": "DEFAULT",
"status_message": None,
}
langfuse_response_object = langfuse_Logger.log_event(**langfuse_args)
new_trace_id = langfuse_response_object["trace_id"]
assert new_trace_id == trace_id
langfuse_client.flush()
time.sleep(2)
print(langfuse_client.get_trace(id=trace_id))
new_langfuse_trace = langfuse_client.get_trace(id=trace_id)
initial_langfuse_trace_dict = dict(initial_langfuse_trace)
initial_langfuse_trace_dict.pop("updatedAt")
initial_langfuse_trace_dict.pop("timestamp")
new_langfuse_trace_dict = dict(new_langfuse_trace)
new_langfuse_trace_dict.pop("updatedAt")
new_langfuse_trace_dict.pop("timestamp")
assert initial_langfuse_trace_dict == new_langfuse_trace_dict
def test_langfuse_logging_tool_calling(): def test_langfuse_logging_tool_calling():
litellm.set_verbose = True litellm.set_verbose = True

View file

@ -15,10 +15,24 @@ import litellm
import pytest import pytest
import asyncio import asyncio
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from litellm.utils import get_api_base
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.slack_alerting import SlackAlerting from litellm.integrations.slack_alerting import SlackAlerting
@pytest.mark.parametrize(
"model, optional_params, expected_api_base",
[
("openai/my-fake-model", {"api_base": "my-fake-api-base"}, "my-fake-api-base"),
("gpt-3.5-turbo", {}, "https://api.openai.com"),
],
)
def test_get_api_base_unit_test(model, optional_params, expected_api_base):
api_base = get_api_base(model=model, optional_params=optional_params)
assert api_base == expected_api_base
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_api_base(): async def test_get_api_base():
_pl = ProxyLogging(user_api_key_cache=DualCache()) _pl = ProxyLogging(user_api_key_cache=DualCache())
@ -94,3 +108,80 @@ def test_init():
assert slack_no_alerting.alerting == [] assert slack_no_alerting.alerting == []
print("passed testing slack alerting init") print("passed testing slack alerting init")
from unittest.mock import patch, AsyncMock
from datetime import datetime, timedelta
@pytest.fixture
def slack_alerting():
return SlackAlerting(alerting_threshold=1)
# Test for hanging LLM responses
@pytest.mark.asyncio
async def test_response_taking_too_long_hanging(slack_alerting):
request_data = {
"model": "test_model",
"messages": "test_messages",
"litellm_status": "running",
}
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
await slack_alerting.response_taking_too_long(
type="hanging_request", request_data=request_data
)
mock_send_alert.assert_awaited_once()
# Test for slow LLM responses
@pytest.mark.asyncio
async def test_response_taking_too_long_callback(slack_alerting):
start_time = datetime.now()
end_time = start_time + timedelta(seconds=301)
kwargs = {"model": "test_model", "messages": "test_messages", "litellm_params": {}}
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
await slack_alerting.response_taking_too_long_callback(
kwargs, None, start_time, end_time
)
mock_send_alert.assert_awaited_once()
# Test for budget crossed
@pytest.mark.asyncio
async def test_budget_alerts_crossed(slack_alerting):
user_max_budget = 100
user_current_spend = 101
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
await slack_alerting.budget_alerts(
"user_budget", user_max_budget, user_current_spend
)
mock_send_alert.assert_awaited_once()
# Test for budget crossed again (should not fire alert 2nd time)
@pytest.mark.asyncio
async def test_budget_alerts_crossed_again(slack_alerting):
user_max_budget = 100
user_current_spend = 101
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
await slack_alerting.budget_alerts(
"user_budget", user_max_budget, user_current_spend
)
mock_send_alert.assert_awaited_once()
mock_send_alert.reset_mock()
await slack_alerting.budget_alerts(
"user_budget", user_max_budget, user_current_spend
)
mock_send_alert.assert_not_awaited()
# Test for send_alert - should be called once
@pytest.mark.asyncio
async def test_send_alert(slack_alerting):
with patch.object(
slack_alerting.async_http_handler, "post", new=AsyncMock()
) as mock_post:
mock_post.return_value.status_code = 200
await slack_alerting.send_alert("Test message", "Low", "budget_alerts")
mock_post.assert_awaited_once()

View file

@ -394,6 +394,8 @@ async def test_async_vertexai_response():
pass pass
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
except litellm.APIError as e:
pass
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
@ -546,42 +548,6 @@ def test_gemini_pro_vision_base64():
def test_gemini_pro_function_calling(): def test_gemini_pro_function_calling():
load_vertex_ai_credentials()
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
]
completion = litellm.completion(
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
)
print(f"completion: {completion}")
if hasattr(completion.choices[0].message, "tool_calls") and isinstance(
completion.choices[0].message.tool_calls, list
):
assert len(completion.choices[0].message.tool_calls) == 1
try: try:
load_vertex_ai_credentials() load_vertex_ai_credentials()
tools = [ tools = [

View file

@ -0,0 +1,102 @@
# What is this?
## Unit Tests for OpenAI Assistants API
import sys, os, json
import traceback
from dotenv import load_dotenv
load_dotenv()
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm
from litellm import create_thread, get_thread
from litellm.llms.openai import (
OpenAIAssistantsAPI,
MessageData,
Thread,
OpenAIMessage as Message,
)
"""
V0 Scope:
- Add Message -> `/v1/threads/{thread_id}/messages`
- Run Thread -> `/v1/threads/{thread_id}/run`
"""
def test_create_thread_litellm() -> Thread:
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
new_thread = create_thread(
custom_llm_provider="openai",
messages=[message], # type: ignore
)
assert isinstance(
new_thread, Thread
), f"type of thread={type(new_thread)}. Expected Thread-type"
return new_thread
def test_get_thread_litellm():
new_thread = test_create_thread_litellm()
received_thread = get_thread(
custom_llm_provider="openai",
thread_id=new_thread.id,
)
assert isinstance(
received_thread, Thread
), f"type of thread={type(received_thread)}. Expected Thread-type"
return new_thread
def test_add_message_litellm():
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
new_thread = test_create_thread_litellm()
# add message to thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
added_message = litellm.add_message(
thread_id=new_thread.id, custom_llm_provider="openai", **message
)
print(f"added message: {added_message}")
assert isinstance(added_message, Message)
def test_run_thread_litellm():
"""
- Get Assistants
- Create thread
- Create run w/ Assistants + Thread
"""
assistants = litellm.get_assistants(custom_llm_provider="openai")
## get the first assistant ###
assistant_id = assistants.data[0].id
new_thread = test_create_thread_litellm()
thread_id = new_thread.id
# add message to thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
added_message = litellm.add_message(
thread_id=new_thread.id, custom_llm_provider="openai", **message
)
run = litellm.run_thread(
custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
)
if run.status == "completed":
messages = litellm.get_messages(
thread_id=new_thread.id, custom_llm_provider="openai"
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")

View file

@ -207,7 +207,7 @@ def test_completion_bedrock_claude_sts_client_auth():
# test_completion_bedrock_claude_sts_client_auth() # test_completion_bedrock_claude_sts_client_auth()
def test_bedrock_claude_3(): def test_bedrock_extra_headers():
try: try:
litellm.set_verbose = True litellm.set_verbose = True
response: ModelResponse = completion( response: ModelResponse = completion(
@ -215,6 +215,7 @@ def test_bedrock_claude_3():
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
temperature=0.78, temperature=0.78,
extra_headers={"x-key": "x_key_value"}
) )
# Add any assertions here to check the response # Add any assertions here to check the response
assert len(response.choices) > 0 assert len(response.choices) > 0
@ -225,6 +226,48 @@ def test_bedrock_claude_3():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3():
try:
litellm.set_verbose = True
data = {
"max_tokens": 2000,
"stream": False,
"temperature": 0.3,
"messages": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hi"},
{
"role": "user",
"content": [
{"text": "describe this image", "type": "text"},
{
"image_url": {
"detail": "high",
"url": "",
},
"type": "image_url",
},
],
},
],
}
response: ModelResponse = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
# messages=messages,
# max_tokens=10,
# temperature=0.78,
**data,
)
# Add any assertions here to check the response
assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3_tool_calling(): def test_bedrock_claude_3_tool_calling():
try: try:
litellm.set_verbose = True litellm.set_verbose = True

View file

@ -12,6 +12,7 @@ import pytest
import litellm import litellm
from litellm import embedding, completion, completion_cost, Timeout from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError from litellm import RateLimitError
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
# litellm.num_retries=3 # litellm.num_retries=3
litellm.cache = None litellm.cache = None
@ -57,7 +58,7 @@ def test_completion_custom_provider_model_name():
messages=messages, messages=messages,
logger_fn=logger_fn, logger_fn=logger_fn,
) )
# Add any assertions here to, check the response # Add any assertions here to,check the response
print(response) print(response)
print(response["choices"][0]["finish_reason"]) print(response["choices"][0]["finish_reason"])
except litellm.Timeout as e: except litellm.Timeout as e:
@ -230,49 +231,144 @@ def test_completion_claude_3_function_call():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3_with_text_content_dictionaries():
@pytest.mark.asyncio
async def test_anthropic_no_content_error():
"""
https://github.com/BerriAI/litellm/discussions/3440#discussioncomment-9323402
"""
try:
litellm.drop_params = True
response = await litellm.acompletion(
model="anthropic/claude-3-opus-20240229",
api_key=os.getenv("ANTHROPIC_API_KEY"),
messages=[
{
"role": "system",
"content": "You will be given a list of fruits. Use the submitFruit function to submit a fruit. Don't say anything after.",
},
{"role": "user", "content": "I like apples"},
{
"content": "<thinking>The most relevant tool for this request is the submitFruit function.</thinking>",
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": '{"name": "Apple"}',
"name": "submitFruit",
},
"id": "toolu_012ZTYKWD4VqrXGXyE7kEnAK",
"type": "function",
}
],
},
{
"role": "tool",
"content": '{"success":true}',
"tool_call_id": "toolu_012ZTYKWD4VqrXGXyE7kEnAK",
},
],
max_tokens=2000,
temperature=1,
tools=[
{
"type": "function",
"function": {
"name": "submitFruit",
"description": "Submits a fruit",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "The name of the fruit",
}
},
"required": ["name"],
},
},
}
],
frequency_penalty=0.8,
)
pass
except litellm.APIError as e:
assert e.status_code == 500
except Exception as e:
pytest.fail(f"An unexpected error occurred - {str(e)}")
def test_completion_cohere_command_r_plus_function_call():
litellm.set_verbose = True litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [ messages = [
{ {
"role": "user", "role": "user",
"content": [ "content": "What's the weather like in Boston today in Fahrenheit?",
{
"type": "text",
"text": "Hello"
}
]
},
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "Hello! How can I assist you today?"
}
]
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "Hello again!"
}
]
} }
] ]
try: try:
# test without max tokens # test without max tokens
response = completion( response = completion(
model="anthropic/claude-3-opus-20240229", model="command-r-plus",
messages=messages, messages=messages,
tools=tools,
tool_choice="auto",
) )
# Add any assertions, here to check response args # Add any assertions, here to check response args
print(response) print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
messages.append(
response.choices[0].message.model_dump()
) # Add assistant tool invokes
tool_result = (
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
)
# Add user submitted tool results in the OpenAI format
messages.append(
{
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"role": "tool",
"name": response.choices[0].message.tool_calls[0].function.name,
"content": tool_result,
}
)
# In the second response, Cohere should deduce answer from tool results
second_response = completion(
model="command-r-plus",
messages=messages,
tools=tools,
tool_choice="auto",
)
print(second_response)
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_parse_xml_params(): def test_parse_xml_params():
from litellm.llms.prompt_templates.factory import parse_xml_params from litellm.llms.prompt_templates.factory import parse_xml_params
@ -1412,6 +1508,198 @@ def test_completion_ollama_hosted():
# test_completion_ollama_hosted() # test_completion_ollama_hosted()
@pytest.mark.skip(reason="Local test")
@pytest.mark.parametrize(
("model"),
[
"ollama/llama2",
"ollama_chat/llama2",
],
)
def test_completion_ollama_function_call(model):
messages = [
{"role": "user", "content": "What's the weather like in San Francisco?"}
]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
try:
litellm.set_verbose = True
response = litellm.completion(model=model, messages=messages, tools=tools)
print(response)
assert response.choices[0].message.tool_calls
assert (
response.choices[0].message.tool_calls[0].function.name
== "get_current_weather"
)
assert response.choices[0].finish_reason == "tool_calls"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="Local test")
@pytest.mark.parametrize(
("model"),
[
"ollama/llama2",
"ollama_chat/llama2",
],
)
def test_completion_ollama_function_call_stream(model):
messages = [
{"role": "user", "content": "What's the weather like in San Francisco?"}
]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
try:
litellm.set_verbose = True
response = litellm.completion(
model=model, messages=messages, tools=tools, stream=True
)
print(response)
first_chunk = next(response)
assert first_chunk.choices[0].delta.tool_calls
assert (
first_chunk.choices[0].delta.tool_calls[0].function.name
== "get_current_weather"
)
assert first_chunk.choices[0].finish_reason == "tool_calls"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize(
("model"),
[
"ollama/llama2",
"ollama_chat/llama2",
],
)
@pytest.mark.asyncio
async def test_acompletion_ollama_function_call(model):
messages = [
{"role": "user", "content": "What's the weather like in San Francisco?"}
]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
try:
litellm.set_verbose = True
response = await litellm.acompletion(
model=model, messages=messages, tools=tools
)
print(response)
assert response.choices[0].message.tool_calls
assert (
response.choices[0].message.tool_calls[0].function.name
== "get_current_weather"
)
assert response.choices[0].finish_reason == "tool_calls"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize(
("model"),
[
"ollama/llama2",
"ollama_chat/llama2",
],
)
@pytest.mark.asyncio
async def test_acompletion_ollama_function_call_stream(model):
messages = [
{"role": "user", "content": "What's the weather like in San Francisco?"}
]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
try:
litellm.set_verbose = True
response = await litellm.acompletion(
model=model, messages=messages, tools=tools, stream=True
)
print(response)
first_chunk = await anext(response)
assert first_chunk.choices[0].delta.tool_calls
assert (
first_chunk.choices[0].delta.tool_calls[0].function.name
== "get_current_weather"
)
assert first_chunk.choices[0].finish_reason == "tool_calls"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_openrouter1(): def test_completion_openrouter1():
try: try:
litellm.set_verbose = True litellm.set_verbose = True
@ -2327,6 +2615,56 @@ def test_completion_with_fallbacks():
# test_completion_with_fallbacks() # test_completion_with_fallbacks()
# @pytest.mark.parametrize(
# "function_call",
# [
# [{"role": "function", "name": "get_capital", "content": "Kokoko"}],
# [
# {"role": "function", "name": "get_capital", "content": "Kokoko"},
# {"role": "function", "name": "get_capital", "content": "Kokoko"},
# ],
# ],
# )
# @pytest.mark.parametrize(
# "tool_call",
# [
# [{"role": "tool", "tool_call_id": "1234", "content": "Kokoko"}],
# [
# {"role": "tool", "tool_call_id": "12344", "content": "Kokoko"},
# {"role": "tool", "tool_call_id": "1214", "content": "Kokoko"},
# ],
# ],
# )
def test_completion_anthropic_hanging():
litellm.set_verbose = True
litellm.modify_params = True
messages = [
{
"role": "user",
"content": "What's the capital of fictional country Ubabababababaaba? Use your tools.",
},
{
"role": "assistant",
"function_call": {
"name": "get_capital",
"arguments": '{"country": "Ubabababababaaba"}',
},
},
{"role": "function", "name": "get_capital", "content": "Kokoko"},
]
converted_messages = anthropic_messages_pt(messages)
print(f"converted_messages: {converted_messages}")
## ENSURE USER / ASSISTANT ALTERNATING
for i, msg in enumerate(converted_messages):
if i < len(converted_messages) - 1:
assert msg["role"] != converted_messages[i + 1]["role"]
def test_completion_anyscale_api(): def test_completion_anyscale_api():
try: try:
# litellm.set_verbose=True # litellm.set_verbose=True
@ -2696,6 +3034,7 @@ def test_completion_palm_stream():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
def test_completion_watsonx(): def test_completion_watsonx():
litellm.set_verbose = True litellm.set_verbose = True
model_name = "watsonx/ibm/granite-13b-chat-v2" model_name = "watsonx/ibm/granite-13b-chat-v2"
@ -2713,10 +3052,57 @@ def test_completion_watsonx():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize(
"provider, model, project, region_name, token",
[
("azure", "chatgpt-v-2", None, None, "test-token"),
("vertex_ai", "anthropic-claude-3", "adroit-crow-1", "us-east1", None),
("watsonx", "ibm/granite", "96946574", "dallas", "1234"),
("bedrock", "anthropic.claude-3", None, "us-east-1", None),
],
)
def test_unified_auth_params(provider, model, project, region_name, token):
"""
Check if params = ["project", "region_name", "token"]
are correctly translated for = ["azure", "vertex_ai", "watsonx", "aws"]
tests get_optional_params
"""
data = {
"project": project,
"region_name": region_name,
"token": token,
"custom_llm_provider": provider,
"model": model,
}
translated_optional_params = litellm.utils.get_optional_params(**data)
if provider == "azure":
special_auth_params = (
litellm.AzureOpenAIConfig().get_mapped_special_auth_params()
)
elif provider == "bedrock":
special_auth_params = (
litellm.AmazonBedrockGlobalConfig().get_mapped_special_auth_params()
)
elif provider == "vertex_ai":
special_auth_params = litellm.VertexAIConfig().get_mapped_special_auth_params()
elif provider == "watsonx":
special_auth_params = (
litellm.IBMWatsonXAIConfig().get_mapped_special_auth_params()
)
for param, value in special_auth_params.items():
assert param in data
assert value in translated_optional_params
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_acompletion_watsonx(): async def test_acompletion_watsonx():
litellm.set_verbose = True litellm.set_verbose = True
model_name = "watsonx/deployment/"+os.getenv("WATSONX_DEPLOYMENT_ID") model_name = "watsonx/ibm/granite-13b-chat-v2"
print("testing watsonx") print("testing watsonx")
try: try:
response = await litellm.acompletion( response = await litellm.acompletion(
@ -2724,7 +3110,6 @@ async def test_acompletion_watsonx():
messages=messages, messages=messages,
temperature=0.2, temperature=0.2,
max_tokens=80, max_tokens=80,
space_id=os.getenv("WATSONX_SPACE_ID_TEST"),
) )
# Add any assertions here to check the response # Add any assertions here to check the response
print(response) print(response)

View file

@ -328,3 +328,56 @@ def test_dalle_3_azure_cost_tracking():
completion_response=response, call_type="image_generation" completion_response=response, call_type="image_generation"
) )
assert cost > 0 assert cost > 0
def test_replicate_llama3_cost_tracking():
litellm.set_verbose = True
model = "replicate/meta/meta-llama-3-8b-instruct"
litellm.register_model(
{
"replicate/meta/meta-llama-3-8b-instruct": {
"input_cost_per_token": 0.00000005,
"output_cost_per_token": 0.00000025,
"litellm_provider": "replicate",
}
}
)
response = litellm.ModelResponse(
id="chatcmpl-cad7282f-7f68-41e7-a5ab-9eb33ae301dc",
choices=[
litellm.utils.Choices(
finish_reason="stop",
index=0,
message=litellm.utils.Message(
content="I'm doing well, thanks for asking! I'm here to help you with any questions or tasks you may have. How can I assist you today?",
role="assistant",
),
)
],
created=1714401369,
model="replicate/meta/meta-llama-3-8b-instruct",
object="chat.completion",
system_fingerprint=None,
usage=litellm.utils.Usage(
prompt_tokens=48, completion_tokens=31, total_tokens=79
),
)
cost = litellm.completion_cost(
completion_response=response,
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(f"cost: {cost}")
cost = round(cost, 5)
expected_cost = round(
litellm.model_cost["replicate/meta/meta-llama-3-8b-instruct"][
"input_cost_per_token"
]
* 48
+ litellm.model_cost["replicate/meta/meta-llama-3-8b-instruct"][
"output_cost_per_token"
]
* 31,
5,
)
assert cost == expected_cost

View file

@ -26,6 +26,9 @@ class DBModel(BaseModel):
model_info: dict model_info: dict
litellm_params: dict litellm_params: dict
class Config:
protected_namespaces = ()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_deployment(): async def test_delete_deployment():

View file

@ -529,6 +529,7 @@ def test_chat_bedrock_stream():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_chat_bedrock_stream(): async def test_async_chat_bedrock_stream():
try: try:
litellm.set_verbose = True
customHandler = CompletionCustomHandler() customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
response = await litellm.acompletion( response = await litellm.acompletion(

View file

@ -483,6 +483,8 @@ def test_mistral_embeddings():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="local test")
def test_watsonx_embeddings(): def test_watsonx_embeddings():
try: try:
litellm.set_verbose = True litellm.set_verbose = True

View file

@ -41,6 +41,30 @@ exception_models = [
] ]
@pytest.mark.asyncio
async def test_content_policy_exception_azure():
try:
# this is ony a test - we needed some way to invoke the exception :(
litellm.set_verbose = True
response = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "where do I buy lethal drugs from"}],
)
except litellm.ContentPolicyViolationError as e:
print("caught a content policy violation error! Passed")
print("exception", e)
# assert that the first 100 chars of the message is returned in the exception
assert (
"Messages: [{'role': 'user', 'content': 'where do I buy lethal drugs from'}]"
in str(e)
)
assert "Model: azure/chatgpt-v-2" in str(e)
pass
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# Test 1: Context Window Errors # Test 1: Context Window Errors
@pytest.mark.skip(reason="AWS Suspended Account") @pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.parametrize("model", exception_models) @pytest.mark.parametrize("model", exception_models)
@ -561,7 +585,7 @@ def test_router_completion_vertex_exception():
pytest.fail("Request should have failed - bad api key") pytest.fail("Request should have failed - bad api key")
except Exception as e: except Exception as e:
print("exception: ", e) print("exception: ", e)
assert "model: vertex_ai/gemini-pro" in str(e) assert "Model: gemini-pro" in str(e)
assert "model_group: vertex-gemini-pro" in str(e) assert "model_group: vertex-gemini-pro" in str(e)
assert "deployment: vertex_ai/gemini-pro" in str(e) assert "deployment: vertex_ai/gemini-pro" in str(e)
@ -580,9 +604,8 @@ def test_litellm_completion_vertex_exception():
pytest.fail("Request should have failed - bad api key") pytest.fail("Request should have failed - bad api key")
except Exception as e: except Exception as e:
print("exception: ", e) print("exception: ", e)
assert "model: vertex_ai/gemini-pro" in str(e) assert "Model: gemini-pro" in str(e)
assert "model_group" not in str(e) assert "vertex_project: bad-project" in str(e)
assert "deployment" not in str(e)
# # test_invalid_request_error(model="command-nightly") # # test_invalid_request_error(model="command-nightly")

View file

@ -40,3 +40,32 @@ def test_vertex_projects():
# test_vertex_projects() # test_vertex_projects()
def test_bedrock_embed_v2_regular():
model, custom_llm_provider, _, _ = get_llm_provider(
model="bedrock/amazon.titan-embed-text-v2:0"
)
optional_params = get_optional_params_embeddings(
model=model,
dimensions=512,
custom_llm_provider=custom_llm_provider,
)
print(f"received optional_params: {optional_params}")
assert optional_params == {"dimensions": 512}
def test_bedrock_embed_v2_with_drop_params():
litellm.drop_params = True
model, custom_llm_provider, _, _ = get_llm_provider(
model="bedrock/amazon.titan-embed-text-v2:0"
)
optional_params = get_optional_params_embeddings(
model=model,
dimensions=512,
user="test-litellm-user-5",
encoding_format="base64",
custom_llm_provider=custom_llm_provider,
)
print(f"received optional_params: {optional_params}")
assert optional_params == {"dimensions": 512}

View file

@ -136,8 +136,8 @@ def test_image_generation_bedrock():
litellm.set_verbose = True litellm.set_verbose = True
response = litellm.image_generation( response = litellm.image_generation(
prompt="A cute baby sea otter", prompt="A cute baby sea otter",
model="bedrock/stability.stable-diffusion-xl-v0", model="bedrock/stability.stable-diffusion-xl-v1",
aws_region_name="us-east-1", aws_region_name="us-west-2",
) )
print(f"response: {response}") print(f"response: {response}")
except litellm.RateLimitError as e: except litellm.RateLimitError as e:
@ -156,8 +156,8 @@ async def test_aimage_generation_bedrock_with_optional_params():
try: try:
response = await litellm.aimage_generation( response = await litellm.aimage_generation(
prompt="A cute baby sea otter", prompt="A cute baby sea otter",
model="bedrock/stability.stable-diffusion-xl-v0", model="bedrock/stability.stable-diffusion-xl-v1",
size="128x128", size="256x256",
) )
print(f"response: {response}") print(f"response: {response}")
except litellm.RateLimitError as e: except litellm.RateLimitError as e:

View file

@ -201,6 +201,7 @@ async def test_router_atext_completion_streaming():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_router_completion_streaming(): async def test_router_completion_streaming():
litellm.set_verbose = True
messages = [ messages = [
{"role": "user", "content": "Hello, can you generate a 500 words poem?"} {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
] ]
@ -219,9 +220,9 @@ async def test_router_completion_streaming():
{ {
"model_name": "azure-model", "model_name": "azure-model",
"litellm_params": { "litellm_params": {
"model": "azure/gpt-35-turbo", "model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY", "api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com", "api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 6, "rpm": 6,
}, },
"model_info": {"id": 2}, "model_info": {"id": 2},
@ -229,9 +230,9 @@ async def test_router_completion_streaming():
{ {
"model_name": "azure-model", "model_name": "azure-model",
"litellm_params": { "litellm_params": {
"model": "azure/gpt-35-turbo", "model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_CANADA_API_KEY", "api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com", "api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 6, "rpm": 6,
}, },
"model_info": {"id": 3}, "model_info": {"id": 3},
@ -262,4 +263,4 @@ async def test_router_completion_streaming():
## check if calls equally distributed ## check if calls equally distributed
cache_dict = router.cache.get_cache(key=cache_key) cache_dict = router.cache.get_cache(key=cache_key)
for k, v in cache_dict.items(): for k, v in cache_dict.items():
assert v == 1 assert v == 1, f"Failed. K={k} called v={v} times, cache_dict={cache_dict}"

View file

@ -7,7 +7,7 @@ import traceback
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
import os import os, copy
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -20,6 +20,96 @@ from litellm.caching import DualCache
### UNIT TESTS FOR LATENCY ROUTING ### ### UNIT TESTS FOR LATENCY ROUTING ###
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_latency_memory_leak(sync_mode):
"""
Test to make sure there's no memory leak caused by lowest latency routing
- make 10 calls -> check memory
- make 11th call -> no change in memory
"""
test_cache = DualCache()
model_list = []
lowest_latency_logger = LowestLatencyLoggingHandler(
router_cache=test_cache, model_list=model_list
)
model_group = "gpt-3.5-turbo"
deployment_id = "1234"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": deployment_id},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
time.sleep(5)
end_time = time.time()
for _ in range(10):
if sync_mode:
lowest_latency_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
else:
await lowest_latency_logger.async_log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
latency_key = f"{model_group}_map"
cache_value = copy.deepcopy(
test_cache.get_cache(key=latency_key)
) # MAKE SURE NO MEMORY LEAK IN CACHING OBJECT
if sync_mode:
lowest_latency_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
else:
await lowest_latency_logger.async_log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
new_cache_value = test_cache.get_cache(key=latency_key)
# Assert that the size of the cache doesn't grow unreasonably
assert get_size(new_cache_value) <= get_size(
cache_value
), f"Memory leak detected in function call! new_cache size={get_size(new_cache_value)}, old cache size={get_size(cache_value)}"
def get_size(obj, seen=None):
# From https://goshippo.com/blog/measure-real-size-any-python-object/
# Recursively finds size of objects
size = sys.getsizeof(obj)
if seen is None:
seen = set()
obj_id = id(obj)
if obj_id in seen:
return 0
seen.add(obj_id)
if isinstance(obj, dict):
size += sum([get_size(v, seen) for v in obj.values()])
size += sum([get_size(k, seen) for k in obj.keys()])
elif hasattr(obj, "__dict__"):
size += get_size(obj.__dict__, seen)
elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, bytearray)):
size += sum([get_size(i, seen) for i in obj])
return size
def test_latency_updated(): def test_latency_updated():
test_cache = DualCache() test_cache = DualCache()
model_list = [] model_list = []
@ -555,3 +645,171 @@ async def test_lowest_latency_routing_with_timeouts():
# ALL the Requests should have been routed to the fast-endpoint # ALL the Requests should have been routed to the fast-endpoint
assert deployments["fast-endpoint"] == 10 assert deployments["fast-endpoint"] == 10
@pytest.mark.asyncio
async def test_lowest_latency_routing_first_pick():
"""
PROD Test:
- When all deployments are latency=0, it should randomly pick a deployment
- IT SHOULD NEVER PICK THE Very First deployment everytime all deployment latencies are 0
- This ensures that after the ttl window resets it randomly picks a deployment
"""
import litellm
litellm.set_verbose = True
router = Router(
model_list=[
{
"model_name": "azure-model",
"litellm_params": {
"model": "openai/fast-endpoint",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
"api_key": "fake-key",
},
"model_info": {"id": "fast-endpoint"},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "openai/fast-endpoint-2",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
"api_key": "fake-key",
},
"model_info": {"id": "fast-endpoint-2"},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "openai/fast-endpoint-2",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
"api_key": "fake-key",
},
"model_info": {"id": "fast-endpoint-3"},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "openai/fast-endpoint-2",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
"api_key": "fake-key",
},
"model_info": {"id": "fast-endpoint-4"},
},
],
routing_strategy="latency-based-routing",
routing_strategy_args={"ttl": 0.0000000001},
set_verbose=True,
debug_level="DEBUG",
) # type: ignore
deployments = {}
for _ in range(5):
response = await router.acompletion(
model="azure-model", messages=[{"role": "user", "content": "hello"}]
)
print(response)
_picked_model_id = response._hidden_params["model_id"]
if _picked_model_id not in deployments:
deployments[_picked_model_id] = 1
else:
deployments[_picked_model_id] += 1
await asyncio.sleep(0.000000000005)
print("deployments", deployments)
# assert that len(deployments) >1
assert len(deployments) > 1
@pytest.mark.parametrize("buffer", [0, 1])
@pytest.mark.asyncio
async def test_lowest_latency_routing_buffer(buffer):
"""
Allow shuffling calls within a certain latency buffer
"""
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
]
router = Router(
model_list=model_list,
routing_strategy="latency-based-routing",
set_verbose=False,
num_retries=3,
routing_strategy_args={"lowest_latency_buffer": buffer},
) # type: ignore
## DEPLOYMENT 1 ##
deployment_id = 1
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "azure-model",
},
"model_info": {"id": 1},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
time.sleep(3)
end_time = time.time()
router.lowestlatency_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## DEPLOYMENT 2 ##
deployment_id = 2
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "azure-model",
},
"model_info": {"id": 2},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 20}}
time.sleep(2)
end_time = time.time()
router.lowestlatency_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## CHECK WHAT'S SELECTED ##
# print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
selected_deployments = {}
for _ in range(50):
print(router.get_available_deployment(model="azure-model"))
selected_deployments[
router.get_available_deployment(model="azure-model")["model_info"]["id"]
] = 1
if buffer == 0:
assert len(selected_deployments.keys()) == 1
else:
assert len(selected_deployments.keys()) == 2

View file

@ -5,13 +5,58 @@ import pytest
sys.path.insert(0, os.path.abspath("../..")) sys.path.insert(0, os.path.abspath("../.."))
import litellm import litellm
from litellm.utils import get_optional_params_embeddings from litellm.utils import get_optional_params_embeddings, get_optional_params
from litellm.llms.prompt_templates.factory import (
map_system_message_pt,
)
from litellm.types.completion import (
ChatCompletionUserMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionMessageParam,
)
## get_optional_params_embeddings ## get_optional_params_embeddings
### Models: OpenAI, Azure, Bedrock ### Models: OpenAI, Azure, Bedrock
### Scenarios: w/ optional params + litellm.drop_params = True ### Scenarios: w/ optional params + litellm.drop_params = True
def test_supports_system_message():
"""
Check if litellm.completion(...,supports_system_message=False)
"""
messages = [
ChatCompletionSystemMessageParam(role="system", content="Listen here!"),
ChatCompletionUserMessageParam(role="user", content="Hello there!"),
]
new_messages = map_system_message_pt(messages=messages)
assert len(new_messages) == 1
assert new_messages[0]["role"] == "user"
## confirm you can make a openai call with this param
response = litellm.completion(
model="gpt-3.5-turbo", messages=new_messages, supports_system_message=False
)
assert isinstance(response, litellm.ModelResponse)
@pytest.mark.parametrize(
"stop_sequence, expected_count", [("\n", 0), (["\n"], 0), (["finish_reason"], 1)]
)
def test_anthropic_optional_params(stop_sequence, expected_count):
"""
Test if whitespace character optional param is dropped by anthropic
"""
litellm.drop_params = True
optional_params = get_optional_params(
model="claude-3", custom_llm_provider="anthropic", stop=stop_sequence
)
assert len(optional_params) == expected_count
def test_bedrock_optional_params_embeddings(): def test_bedrock_optional_params_embeddings():
litellm.drop_params = True litellm.drop_params = True
optional_params = get_optional_params_embeddings( optional_params = get_optional_params_embeddings(

View file

@ -1,6 +1,8 @@
# test that the proxy actually does exception mapping to the OpenAI format # test that the proxy actually does exception mapping to the OpenAI format
import sys, os import sys, os
from unittest import mock
import json
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@ -12,13 +14,30 @@ sys.path.insert(
import pytest import pytest
import litellm, openai import litellm, openai
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from fastapi import FastAPI from fastapi import Response
from litellm.proxy.proxy_server import ( from litellm.proxy.proxy_server import (
router, router,
save_worker_config, save_worker_config,
initialize, initialize,
) # Replace with the actual module where your FastAPI router is defined ) # Replace with the actual module where your FastAPI router is defined
invalid_authentication_error_response = Response(
status_code=401,
content=json.dumps({"error": "Invalid Authentication"}),
)
context_length_exceeded_error_response_dict = {
"error": {
"message": "AzureException - Error code: 400 - {'error': {'message': \"This model's maximum context length is 4096 tokens. However, your messages resulted in 10007 tokens. Please reduce the length of the messages.\", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}",
"type": None,
"param": None,
"code": 400,
},
}
context_length_exceeded_error_response = Response(
status_code=400,
content=json.dumps(context_length_exceeded_error_response_dict),
)
@pytest.fixture @pytest.fixture
def client(): def client():
@ -60,7 +79,11 @@ def test_chat_completion_exception(client):
# raise openai.AuthenticationError # raise openai.AuthenticationError
def test_chat_completion_exception_azure(client): @mock.patch(
"litellm.proxy.proxy_server.llm_router.acompletion",
return_value=invalid_authentication_error_response,
)
def test_chat_completion_exception_azure(mock_acompletion, client):
try: try:
# Your test data # Your test data
test_data = { test_data = {
@ -73,6 +96,15 @@ def test_chat_completion_exception_azure(client):
response = client.post("/chat/completions", json=test_data) response = client.post("/chat/completions", json=test_data)
mock_acompletion.assert_called_once_with(
**test_data,
litellm_call_id=mock.ANY,
litellm_logging_obj=mock.ANY,
request_timeout=mock.ANY,
metadata=mock.ANY,
proxy_server_request=mock.ANY,
)
json_response = response.json() json_response = response.json()
print("keys in json response", json_response.keys()) print("keys in json response", json_response.keys())
assert json_response.keys() == {"error"} assert json_response.keys() == {"error"}
@ -90,12 +122,21 @@ def test_chat_completion_exception_azure(client):
# raise openai.AuthenticationError # raise openai.AuthenticationError
def test_embedding_auth_exception_azure(client): @mock.patch(
"litellm.proxy.proxy_server.llm_router.aembedding",
return_value=invalid_authentication_error_response,
)
def test_embedding_auth_exception_azure(mock_aembedding, client):
try: try:
# Your test data # Your test data
test_data = {"model": "azure-embedding", "input": ["hi"]} test_data = {"model": "azure-embedding", "input": ["hi"]}
response = client.post("/embeddings", json=test_data) response = client.post("/embeddings", json=test_data)
mock_aembedding.assert_called_once_with(
**test_data,
metadata=mock.ANY,
proxy_server_request=mock.ANY,
)
print("Response from proxy=", response) print("Response from proxy=", response)
json_response = response.json() json_response = response.json()
@ -169,7 +210,7 @@ def test_chat_completion_exception_any_model(client):
) )
assert isinstance(openai_exception, openai.BadRequestError) assert isinstance(openai_exception, openai.BadRequestError)
_error_message = openai_exception.message _error_message = openai_exception.message
assert "Invalid model name passed in model=Lite-GPT-12" in str(_error_message) assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str(_error_message)
except Exception as e: except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
@ -197,14 +238,18 @@ def test_embedding_exception_any_model(client):
print("Exception raised=", openai_exception) print("Exception raised=", openai_exception)
assert isinstance(openai_exception, openai.BadRequestError) assert isinstance(openai_exception, openai.BadRequestError)
_error_message = openai_exception.message _error_message = openai_exception.message
assert "Invalid model name passed in model=Lite-GPT-12" in str(_error_message) assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str(_error_message)
except Exception as e: except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
# raise openai.BadRequestError # raise openai.BadRequestError
def test_chat_completion_exception_azure_context_window(client): @mock.patch(
"litellm.proxy.proxy_server.llm_router.acompletion",
return_value=context_length_exceeded_error_response,
)
def test_chat_completion_exception_azure_context_window(mock_acompletion, client):
try: try:
# Your test data # Your test data
test_data = { test_data = {
@ -219,20 +264,22 @@ def test_chat_completion_exception_azure_context_window(client):
response = client.post("/chat/completions", json=test_data) response = client.post("/chat/completions", json=test_data)
print("got response from server", response) print("got response from server", response)
mock_acompletion.assert_called_once_with(
**test_data,
litellm_call_id=mock.ANY,
litellm_logging_obj=mock.ANY,
request_timeout=mock.ANY,
metadata=mock.ANY,
proxy_server_request=mock.ANY,
)
json_response = response.json() json_response = response.json()
print("keys in json response", json_response.keys()) print("keys in json response", json_response.keys())
assert json_response.keys() == {"error"} assert json_response.keys() == {"error"}
assert json_response == { assert json_response == context_length_exceeded_error_response_dict
"error": {
"message": "AzureException - Error code: 400 - {'error': {'message': \"This model's maximum context length is 4096 tokens. However, your messages resulted in 10007 tokens. Please reduce the length of the messages.\", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}",
"type": None,
"param": None,
"code": 400,
}
}
# make an openai client to call _make_status_error_from_response # make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything") openai_client = openai.OpenAI(api_key="anything")

View file

@ -1,5 +1,6 @@
import sys, os import sys, os
import traceback import traceback
from unittest import mock
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@ -35,6 +36,77 @@ token = "sk-1234"
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
example_completion_result = {
"choices": [
{
"message": {
"content": "Whispers of the wind carry dreams to me.",
"role": "assistant"
}
}
],
}
example_embedding_result = {
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
-0.006929283495992422,
-0.005336422007530928,
-4.547132266452536e-05,
-0.024047505110502243,
-0.006929283495992422,
-0.005336422007530928,
-4.547132266452536e-05,
-0.024047505110502243,
-0.006929283495992422,
-0.005336422007530928,
-4.547132266452536e-05,
-0.024047505110502243,
],
}
],
"model": "text-embedding-3-small",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}
example_image_generation_result = {
"created": 1589478378,
"data": [
{
"url": "https://..."
},
{
"url": "https://..."
}
]
}
def mock_patch_acompletion():
return mock.patch(
"litellm.proxy.proxy_server.llm_router.acompletion",
return_value=example_completion_result,
)
def mock_patch_aembedding():
return mock.patch(
"litellm.proxy.proxy_server.llm_router.aembedding",
return_value=example_embedding_result,
)
def mock_patch_aimage_generation():
return mock.patch(
"litellm.proxy.proxy_server.llm_router.aimage_generation",
return_value=example_image_generation_result,
)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def client_no_auth(): def client_no_auth():
@ -52,7 +124,8 @@ def client_no_auth():
return TestClient(app) return TestClient(app)
def test_chat_completion(client_no_auth): @mock_patch_acompletion()
def test_chat_completion(mock_acompletion, client_no_auth):
global headers global headers
try: try:
# Your test data # Your test data
@ -66,6 +139,19 @@ def test_chat_completion(client_no_auth):
print("testing proxy server with chat completions") print("testing proxy server with chat completions")
response = client_no_auth.post("/v1/chat/completions", json=test_data) response = client_no_auth.post("/v1/chat/completions", json=test_data)
mock_acompletion.assert_called_once_with(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "hi"},
],
max_tokens=10,
litellm_call_id=mock.ANY,
litellm_logging_obj=mock.ANY,
request_timeout=mock.ANY,
specific_deployment=True,
metadata=mock.ANY,
proxy_server_request=mock.ANY,
)
print(f"response - {response.text}") print(f"response - {response.text}")
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
@ -77,7 +163,8 @@ def test_chat_completion(client_no_auth):
# Run the test # Run the test
def test_chat_completion_azure(client_no_auth): @mock_patch_acompletion()
def test_chat_completion_azure(mock_acompletion, client_no_auth):
global headers global headers
try: try:
# Your test data # Your test data
@ -92,6 +179,19 @@ def test_chat_completion_azure(client_no_auth):
print("testing proxy server with Azure Request /chat/completions") print("testing proxy server with Azure Request /chat/completions")
response = client_no_auth.post("/v1/chat/completions", json=test_data) response = client_no_auth.post("/v1/chat/completions", json=test_data)
mock_acompletion.assert_called_once_with(
model="azure/chatgpt-v-2",
messages=[
{"role": "user", "content": "write 1 sentence poem"},
],
max_tokens=10,
litellm_call_id=mock.ANY,
litellm_logging_obj=mock.ANY,
request_timeout=mock.ANY,
specific_deployment=True,
metadata=mock.ANY,
proxy_server_request=mock.ANY,
)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
print(f"Received response: {result}") print(f"Received response: {result}")
@ -104,8 +204,51 @@ def test_chat_completion_azure(client_no_auth):
# test_chat_completion_azure() # test_chat_completion_azure()
@mock_patch_acompletion()
def test_openai_deployments_model_chat_completions_azure(mock_acompletion, client_no_auth):
global headers
try:
# Your test data
test_data = {
"model": "azure/chatgpt-v-2",
"messages": [
{"role": "user", "content": "write 1 sentence poem"},
],
"max_tokens": 10,
}
url = "/openai/deployments/azure/chatgpt-v-2/chat/completions"
print(f"testing proxy server with Azure Request {url}")
response = client_no_auth.post(url, json=test_data)
mock_acompletion.assert_called_once_with(
model="azure/chatgpt-v-2",
messages=[
{"role": "user", "content": "write 1 sentence poem"},
],
max_tokens=10,
litellm_call_id=mock.ANY,
litellm_logging_obj=mock.ANY,
request_timeout=mock.ANY,
specific_deployment=True,
metadata=mock.ANY,
proxy_server_request=mock.ANY,
)
assert response.status_code == 200
result = response.json()
print(f"Received response: {result}")
assert len(result["choices"][0]["message"]["content"]) > 0
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
# Run the test
# test_openai_deployments_model_chat_completions_azure()
### EMBEDDING ### EMBEDDING
def test_embedding(client_no_auth): @mock_patch_aembedding()
def test_embedding(mock_aembedding, client_no_auth):
global headers global headers
from litellm.proxy.proxy_server import user_custom_auth from litellm.proxy.proxy_server import user_custom_auth
@ -117,6 +260,13 @@ def test_embedding(client_no_auth):
response = client_no_auth.post("/v1/embeddings", json=test_data) response = client_no_auth.post("/v1/embeddings", json=test_data)
mock_aembedding.assert_called_once_with(
model="azure/azure-embedding-model",
input=["good morning from litellm"],
specific_deployment=True,
metadata=mock.ANY,
proxy_server_request=mock.ANY,
)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
print(len(result["data"][0]["embedding"])) print(len(result["data"][0]["embedding"]))
@ -125,7 +275,8 @@ def test_embedding(client_no_auth):
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
def test_bedrock_embedding(client_no_auth): @mock_patch_aembedding()
def test_bedrock_embedding(mock_aembedding, client_no_auth):
global headers global headers
from litellm.proxy.proxy_server import user_custom_auth from litellm.proxy.proxy_server import user_custom_auth
@ -137,6 +288,12 @@ def test_bedrock_embedding(client_no_auth):
response = client_no_auth.post("/v1/embeddings", json=test_data) response = client_no_auth.post("/v1/embeddings", json=test_data)
mock_aembedding.assert_called_once_with(
model="amazon-embeddings",
input=["good morning from litellm"],
metadata=mock.ANY,
proxy_server_request=mock.ANY,
)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
print(len(result["data"][0]["embedding"])) print(len(result["data"][0]["embedding"]))
@ -171,7 +328,8 @@ def test_sagemaker_embedding(client_no_auth):
#### IMAGE GENERATION #### IMAGE GENERATION
def test_img_gen(client_no_auth): @mock_patch_aimage_generation()
def test_img_gen(mock_aimage_generation, client_no_auth):
global headers global headers
from litellm.proxy.proxy_server import user_custom_auth from litellm.proxy.proxy_server import user_custom_auth
@ -185,6 +343,14 @@ def test_img_gen(client_no_auth):
response = client_no_auth.post("/v1/images/generations", json=test_data) response = client_no_auth.post("/v1/images/generations", json=test_data)
mock_aimage_generation.assert_called_once_with(
model='dall-e-3',
prompt='A cute baby sea otter',
n=1,
size='1024x1024',
metadata=mock.ANY,
proxy_server_request=mock.ANY,
)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
print(len(result["data"][0]["url"])) print(len(result["data"][0]["url"]))
@ -249,7 +415,8 @@ class MyCustomHandler(CustomLogger):
customHandler = MyCustomHandler() customHandler = MyCustomHandler()
def test_chat_completion_optional_params(client_no_auth): @mock_patch_acompletion()
def test_chat_completion_optional_params(mock_acompletion, client_no_auth):
# [PROXY: PROD TEST] - DO NOT DELETE # [PROXY: PROD TEST] - DO NOT DELETE
# This tests if all the /chat/completion params are passed to litellm # This tests if all the /chat/completion params are passed to litellm
try: try:
@ -267,6 +434,20 @@ def test_chat_completion_optional_params(client_no_auth):
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
print("testing proxy server: optional params") print("testing proxy server: optional params")
response = client_no_auth.post("/v1/chat/completions", json=test_data) response = client_no_auth.post("/v1/chat/completions", json=test_data)
mock_acompletion.assert_called_once_with(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "hi"},
],
max_tokens=10,
user="proxy-user",
litellm_call_id=mock.ANY,
litellm_logging_obj=mock.ANY,
request_timeout=mock.ANY,
specific_deployment=True,
metadata=mock.ANY,
proxy_server_request=mock.ANY,
)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
print(f"Received response: {result}") print(f"Received response: {result}")

View file

@ -0,0 +1,10 @@
import warnings
import pytest
def test_namespace_conflict_warning():
with warnings.catch_warnings(record=True) as recorded_warnings:
warnings.simplefilter("always") # Capture all warnings
import litellm
# Check that no warning with the specific message was raised
assert not any("conflict with protected namespace" in str(w.message) for w in recorded_warnings), "Test failed: 'conflict with protected namespace' warning was encountered!"

View file

@ -1,7 +1,7 @@
#### What this tests #### #### What this tests ####
# This tests litellm router # This tests litellm router
import sys, os, time import sys, os, time, openai
import traceback, asyncio import traceback, asyncio
import pytest import pytest
@ -19,6 +19,45 @@ import os, httpx
load_dotenv() load_dotenv()
@pytest.mark.parametrize("num_retries", [None, 2])
@pytest.mark.parametrize("max_retries", [None, 4])
def test_router_num_retries_init(num_retries, max_retries):
"""
- test when num_retries set v/s not
- test client value when max retries set v/s not
"""
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
"max_retries": max_retries,
},
"model_info": {"id": 12345},
},
],
num_retries=num_retries,
)
if num_retries is not None:
assert router.num_retries == num_retries
else:
assert router.num_retries == openai.DEFAULT_MAX_RETRIES
model_client = router._get_client(
{"model_info": {"id": 12345}}, client_type="async", kwargs={}
)
if max_retries is not None:
assert getattr(model_client, "max_retries") == max_retries
else:
assert getattr(model_client, "max_retries") == 0
@pytest.mark.parametrize( @pytest.mark.parametrize(
"timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)] "timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)]
) )
@ -65,6 +104,42 @@ def test_router_timeout_init(timeout, ssl_verify):
) )
@pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.asyncio
async def test_router_retries(sync_mode):
"""
- make sure retries work as expected
"""
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo", "api_key": "bad-key"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION"),
},
},
]
router = Router(model_list=model_list, num_retries=2)
if sync_mode:
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
else:
await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mistral_api_base", "mistral_api_base",
[ [
@ -99,6 +174,7 @@ def test_router_azure_ai_studio_init(mistral_api_base):
print(f"uri_reference: {uri_reference}") print(f"uri_reference: {uri_reference}")
assert "/v1/" in uri_reference assert "/v1/" in uri_reference
assert uri_reference.count("v1") == 1
def test_exception_raising(): def test_exception_raising():
@ -1078,6 +1154,7 @@ def test_consistent_model_id():
assert id1 == id2 assert id1 == id2
@pytest.mark.skip(reason="local test")
def test_reading_keys_os_environ(): def test_reading_keys_os_environ():
import openai import openai
@ -1177,6 +1254,7 @@ def test_reading_keys_os_environ():
# test_reading_keys_os_environ() # test_reading_keys_os_environ()
@pytest.mark.skip(reason="local test")
def test_reading_openai_keys_os_environ(): def test_reading_openai_keys_os_environ():
import openai import openai

View file

@ -46,6 +46,7 @@ def test_async_fallbacks(caplog):
router = Router( router = Router(
model_list=model_list, model_list=model_list,
fallbacks=[{"gpt-3.5-turbo": ["azure/gpt-3.5-turbo"]}], fallbacks=[{"gpt-3.5-turbo": ["azure/gpt-3.5-turbo"]}],
num_retries=1,
) )
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
@ -81,8 +82,8 @@ def test_async_fallbacks(caplog):
# Define the expected log messages # Define the expected log messages
# - error request, falling back notice, success notice # - error request, falling back notice, success notice
expected_logs = [ expected_logs = [
"Intialized router with Routing strategy: simple-shuffle\n\nRouting fallbacks: [{'gpt-3.5-turbo': ['azure/gpt-3.5-turbo']}]\n\nRouting context window fallbacks: None\n\nRouter Redis Caching=None", "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}} \nModel: gpt-3.5-turbo\nAPI Base: https://api.openai.com\nMessages: [{'content': 'Hello, how are you?', 'role': 'user'}]\nmodel_group: gpt-3.5-turbo\n\ndeployment: gpt-3.5-turbo\n\x1b[0m",
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m", "litellm.acompletion(model=None)\x1b[31m Exception No deployments available for selected model, passed model=gpt-3.5-turbo\x1b[0m",
"Falling back to model_group = azure/gpt-3.5-turbo", "Falling back to model_group = azure/gpt-3.5-turbo",
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m", "litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",
] ]

View file

@ -22,10 +22,10 @@ class MyCustomHandler(CustomLogger):
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):
print(f"Pre-API Call") print(f"Pre-API Call")
print( print(
f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}" f"previous_models: {kwargs['litellm_params']['metadata'].get('previous_models', None)}"
) )
self.previous_models += len( self.previous_models = len(
kwargs["litellm_params"]["metadata"]["previous_models"] kwargs["litellm_params"]["metadata"].get("previous_models", [])
) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]} ) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
print(f"self.previous_models: {self.previous_models}") print(f"self.previous_models: {self.previous_models}")
@ -127,7 +127,7 @@ def test_sync_fallbacks():
response = router.completion(**kwargs) response = router.completion(**kwargs)
print(f"response: {response}") print(f"response: {response}")
time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback assert customHandler.previous_models == 4
print("Passed ! Test router_fallbacks: test_sync_fallbacks()") print("Passed ! Test router_fallbacks: test_sync_fallbacks()")
router.reset() router.reset()
@ -140,7 +140,7 @@ def test_sync_fallbacks():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_fallbacks(): async def test_async_fallbacks():
litellm.set_verbose = False litellm.set_verbose = True
model_list = [ model_list = [
{ # list of model deployments { # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name "model_name": "azure/gpt-3.5-turbo", # openai model name
@ -209,12 +209,13 @@ async def test_async_fallbacks():
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
try: try:
kwargs["model"] = "azure/gpt-3.5-turbo"
response = await router.acompletion(**kwargs) response = await router.acompletion(**kwargs)
print(f"customHandler.previous_models: {customHandler.previous_models}") print(f"customHandler.previous_models: {customHandler.previous_models}")
await asyncio.sleep( await asyncio.sleep(
0.05 0.05
) # allow a delay as success_callbacks are on a separate thread ) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback
router.reset() router.reset()
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
@ -268,7 +269,7 @@ def test_sync_fallbacks_embeddings():
response = router.embedding(**kwargs) response = router.embedding(**kwargs)
print(f"customHandler.previous_models: {customHandler.previous_models}") print(f"customHandler.previous_models: {customHandler.previous_models}")
time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback
router.reset() router.reset()
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
@ -322,7 +323,7 @@ async def test_async_fallbacks_embeddings():
await asyncio.sleep( await asyncio.sleep(
0.05 0.05
) # allow a delay as success_callbacks are on a separate thread ) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback
router.reset() router.reset()
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
@ -401,7 +402,7 @@ def test_dynamic_fallbacks_sync():
response = router.completion(**kwargs) response = router.completion(**kwargs)
print(f"response: {response}") print(f"response: {response}")
time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback
router.reset() router.reset()
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred - {e}") pytest.fail(f"An exception occurred - {e}")
@ -487,7 +488,7 @@ async def test_dynamic_fallbacks_async():
await asyncio.sleep( await asyncio.sleep(
0.05 0.05
) # allow a delay as success_callbacks are on a separate thread ) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback
router.reset() router.reset()
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred - {e}") pytest.fail(f"An exception occurred - {e}")
@ -572,7 +573,7 @@ async def test_async_fallbacks_streaming():
await asyncio.sleep( await asyncio.sleep(
0.05 0.05
) # allow a delay as success_callbacks are on a separate thread ) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback
router.reset() router.reset()
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
@ -751,7 +752,7 @@ async def test_async_fallbacks_max_retries_per_request():
router.reset() router.reset()
def test_usage_based_routing_fallbacks(): def test_ausage_based_routing_fallbacks():
try: try:
# [Prod Test] # [Prod Test]
# IT tests Usage Based Routing with fallbacks # IT tests Usage Based Routing with fallbacks
@ -765,10 +766,10 @@ def test_usage_based_routing_fallbacks():
load_dotenv() load_dotenv()
# Constants for TPM and RPM allocation # Constants for TPM and RPM allocation
AZURE_FAST_TPM = 3 AZURE_FAST_RPM = 1
AZURE_BASIC_TPM = 4 AZURE_BASIC_RPM = 1
OPENAI_TPM = 400 OPENAI_RPM = 2
ANTHROPIC_TPM = 100000 ANTHROPIC_RPM = 100000
def get_azure_params(deployment_name: str): def get_azure_params(deployment_name: str):
params = { params = {
@ -797,22 +798,26 @@ def test_usage_based_routing_fallbacks():
{ {
"model_name": "azure/gpt-4-fast", "model_name": "azure/gpt-4-fast",
"litellm_params": get_azure_params("chatgpt-v-2"), "litellm_params": get_azure_params("chatgpt-v-2"),
"tpm": AZURE_FAST_TPM, "model_info": {"id": 1},
"rpm": AZURE_FAST_RPM,
}, },
{ {
"model_name": "azure/gpt-4-basic", "model_name": "azure/gpt-4-basic",
"litellm_params": get_azure_params("chatgpt-v-2"), "litellm_params": get_azure_params("chatgpt-v-2"),
"tpm": AZURE_BASIC_TPM, "model_info": {"id": 2},
"rpm": AZURE_BASIC_RPM,
}, },
{ {
"model_name": "openai-gpt-4", "model_name": "openai-gpt-4",
"litellm_params": get_openai_params("gpt-3.5-turbo"), "litellm_params": get_openai_params("gpt-3.5-turbo"),
"tpm": OPENAI_TPM, "model_info": {"id": 3},
"rpm": OPENAI_RPM,
}, },
{ {
"model_name": "anthropic-claude-instant-1.2", "model_name": "anthropic-claude-instant-1.2",
"litellm_params": get_anthropic_params("claude-instant-1.2"), "litellm_params": get_anthropic_params("claude-instant-1.2"),
"tpm": ANTHROPIC_TPM, "model_info": {"id": 4},
"rpm": ANTHROPIC_RPM,
}, },
] ]
# litellm.set_verbose=True # litellm.set_verbose=True
@ -830,6 +835,7 @@ def test_usage_based_routing_fallbacks():
routing_strategy="usage-based-routing", routing_strategy="usage-based-routing",
redis_host=os.environ["REDIS_HOST"], redis_host=os.environ["REDIS_HOST"],
redis_port=os.environ["REDIS_PORT"], redis_port=os.environ["REDIS_PORT"],
num_retries=0,
) )
messages = [ messages = [
@ -842,13 +848,13 @@ def test_usage_based_routing_fallbacks():
mock_response="very nice to meet you", mock_response="very nice to meet you",
) )
print("response: ", response) print("response: ", response)
print("response._hidden_params: ", response._hidden_params) print(f"response._hidden_params: {response._hidden_params}")
# in this test, we expect azure/gpt-4 fast to fail, then azure-gpt-4 basic to fail and then openai-gpt-4 to pass # in this test, we expect azure/gpt-4 fast to fail, then azure-gpt-4 basic to fail and then openai-gpt-4 to pass
# the token count of this message is > AZURE_FAST_TPM, > AZURE_BASIC_TPM # the token count of this message is > AZURE_FAST_TPM, > AZURE_BASIC_TPM
assert response._hidden_params["custom_llm_provider"] == "openai" assert response._hidden_params["model_id"] == "1"
# now make 100 mock requests to OpenAI - expect it to fallback to anthropic-claude-instant-1.2 # now make 100 mock requests to OpenAI - expect it to fallback to anthropic-claude-instant-1.2
for i in range(20): for i in range(21):
response = router.completion( response = router.completion(
model="azure/gpt-4-fast", model="azure/gpt-4-fast",
messages=messages, messages=messages,
@ -857,9 +863,9 @@ def test_usage_based_routing_fallbacks():
) )
print("response: ", response) print("response: ", response)
print("response._hidden_params: ", response._hidden_params) print("response._hidden_params: ", response._hidden_params)
if i == 19: if i == 20:
# by the 19th call we should have hit TPM LIMIT for OpenAI, it should fallback to anthropic-claude-instant-1.2 # by the 19th call we should have hit TPM LIMIT for OpenAI, it should fallback to anthropic-claude-instant-1.2
assert response._hidden_params["custom_llm_provider"] == "anthropic" assert response._hidden_params["model_id"] == "4"
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred {e}") pytest.fail(f"An exception occurred {e}")

Some files were not shown because too many files have changed in this diff Show more