Merge remote-tracking branch 'upstream/main' into patch-1

This commit is contained in:
Cam Parry 2023-12-13 09:22:25 +10:00
commit cb13018a28
No known key found for this signature in database
121 changed files with 23783 additions and 1997 deletions

View file

@ -29,12 +29,16 @@ jobs:
pip install pytest-asyncio
pip install mypy
pip install -q google-generativeai
pip install google-cloud-aiplatform
pip install "boto3>=1.28.57"
pip install appdirs
pip install langchain
pip install langfuse
pip install numpydoc
pip install traceloop-sdk==0.0.69
pip install openai
pip install prisma
pip install langfuse
- save_cache:
paths:
- ./venv
@ -44,7 +48,7 @@ jobs:
command: |
cd litellm
python -m pip install types-requests types-setuptools types-redis
if ! python -m mypy . --ignore-missing-imports --explicit-package-bases; then
if ! python -m mypy . --ignore-missing-imports; then
echo "mypy detected errors"
exit 1
fi
@ -57,7 +61,7 @@ jobs:
command: |
pwd
ls
python -m pytest -vv litellm/tests/ -x --junitxml=test-results/junit.xml --durations=5
python -m pytest -vv litellm/tests/ -x --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
# Store test results
@ -74,6 +78,11 @@ jobs:
steps:
- checkout
- run:
name: Copy model_prices_and_context_window File to model_prices_and_context_window_backup
command: |
cp model_prices_and_context_window.json litellm/model_prices_and_context_window_backup.json
- run:
name: Check if litellm dir was updated or if pyproject.toml was modified

View file

@ -10,4 +10,5 @@ anthropic
boto3
appdirs
orjson
pydantic
pydantic
google-cloud-aiplatform

3
.gitignore vendored
View file

@ -19,3 +19,6 @@ litellm/proxy/_secret_config.yaml
litellm/tests/aiologs.log
litellm/tests/exception_data.txt
litellm/tests/config_*.yaml
litellm/tests/langfuse.log
litellm/tests/test_custom_logger.py
litellm/tests/langfuse.log

View file

@ -3,6 +3,15 @@ repos:
rev: 3.8.4 # The version of flake8 to use
hooks:
- id: flake8
exclude: ^litellm/tests/|^litellm/proxy/|^litellm/integrations/
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/
additional_dependencies: [flake8-print]
files: litellm/.*\.py
files: litellm/.*\.py
- repo: local
hooks:
- id: mypy
name: mypy
entry: python3 -m mypy --ignore-missing-imports
language: system
types: [python]
files: ^litellm/
exclude: ^litellm/tests/

View file

@ -34,6 +34,9 @@ COPY --from=builder /app/wheels /app/wheels
RUN pip install --no-index --find-links=/app/wheels -r requirements.txt
# Trigger the Prisma CLI to be installed
RUN prisma -v
EXPOSE 4000/tcp
# Start the litellm proxy, using the `litellm` cli command https://docs.litellm.ai/docs/simple_proxy

View file

@ -5,7 +5,7 @@
<p align="center">Call all LLM APIs using the OpenAI format [Bedrock, Huggingface, Cohere, TogetherAI, Azure, OpenAI, etc.]
<br>
</p>
<h4 align="center"><a href="https://github.com/BerriAI/litellm/tree/main/litellm/proxy" target="_blank">OpenAI-Compatible Server</a></h4>
<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">OpenAI Proxy Server</a></h4>
<h4 align="center">
<a href="https://pypi.org/project/litellm/" target="_blank">
<img src="https://img.shields.io/pypi/v/litellm.svg" alt="PyPI Version">
@ -62,6 +62,22 @@ response = completion(model="command-nightly", messages=messages)
print(response)
```
## Async ([Docs](https://docs.litellm.ai/docs/completion/stream#async-completion))
```python
from litellm import acompletion
import asyncio
async def test_get_response():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
response = await acompletion(model="gpt-3.5-turbo", messages=messages)
return response
response = asyncio.run(test_get_response())
print(response)
```
## Streaming ([Docs](https://docs.litellm.ai/docs/completion/stream))
liteLLM supports streaming the model response back. Pass `stream=True` to get a streaming iterator in the response.
Streaming is supported for all models (Bedrock, Huggingface, TogetherAI, Azure, OpenAI, etc.)

Binary file not shown.

BIN
dist/litellm-1.12.5.dev1.tar.gz vendored Normal file

Binary file not shown.

View file

@ -0,0 +1,12 @@
version: "3.9"
services:
litellm:
image: ghcr.io/berriai/litellm:main
ports:
- "8000:8000" # Map the container port to the host, change the host port if necessary
volumes:
- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
# You can change the port or number of workers as per your requirements or pass any new supported CLI argument. Make sure the port passed here matches the container port defined above in the `ports` value
command: [ "--config", "/app/config.yaml", "--port", "8000", "--num_workers", "8" ]
# ...rest of your docker-compose config if any

View file

@ -1,11 +1,14 @@
# Redis Cache
[**See Code**](https://github.com/BerriAI/litellm/blob/4d7ff1b33b9991dcf38d821266290631d9bcd2dd/litellm/caching.py#L71)
### Pre-requisites
Install redis
```
pip install redis
```
For the hosted version you can set up your own Redis DB here: https://app.redislabs.com/
### Usage
### Quick Start
```python
import litellm
from litellm import completion
@ -55,6 +58,11 @@ litellm.cache = cache # set litellm.cache to your cache
### Detecting Cached Responses
For responses that were returned as a cache hit, the response includes a param `cache` = True
:::info
Only valid for OpenAI <= 0.28.1 [Let us know if you still need this](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=bug&projects=&template=bug_report.yml&title=%5BBug%5D%3A+)
:::
Example response with cache hit
```python
{

View file

@ -6,7 +6,7 @@ import TabItem from '@theme/TabItem';
## Common Params
LiteLLM accepts and translates the [OpenAI Chat Completion params](https://platform.openai.com/docs/api-reference/chat/create) across all providers.
### usage
### Usage
```python
import litellm
@ -23,7 +23,7 @@ response = litellm.completion(
print(response)
```
### translated OpenAI params
### Translated OpenAI params
This is a list of openai params we translate across providers.
This list is constantly being updated.
@ -40,7 +40,7 @@ This list is constantly being updated.
|AI21| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
|VertexAI| ✅ | ✅ | | ✅ | | | | | | |
|Bedrock| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|Sagemaker| ✅ | ✅ | | ✅ | | | | | | |
|Sagemaker| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|AlephAlpha| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|Palm| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
@ -185,6 +185,25 @@ def completion(
- `metadata`: *dict (optional)* - Any additional data you want to be logged when the call is made (sent to logging integrations, eg. promptlayer and accessible via custom callback function)
**CUSTOM MODEL COST**
- `input_cost_per_token`: *float (optional)* - The cost per input token for the completion call
- `output_cost_per_token`: *float (optional)* - The cost per output token for the completion call
**CUSTOM PROMPT TEMPLATE** (See [prompt formatting for more info](./prompt_formatting.md#format-prompt-yourself))
- `initial_prompt_value`: *string (optional)* - Initial string applied at the start of the input messages
- `roles`: *dict (optional)* - Dictionary specifying how to format the prompt based on the role + message passed in via `messages`.
- `final_prompt_value`: *string (optional)* - Final string applied at the end of the input messages
- `bos_token`: *string (optional)* - Initial string applied at the start of a sequence
- `eos_token`: *string (optional)* - Final string applied at the end of a sequence
- `hf_model_name`: *string (optional)* - [Sagemaker Only] The corresponding huggingface name of the model, used to pull the right chat template for the model.
## Provider-specific Params
Providers might offer params not supported by OpenAI (e.g. top_k). You can pass those in 2 ways:
- via completion(): We'll pass the non-openai param, straight to the provider as part of the request body.

View file

@ -142,6 +142,8 @@ print(response)
| Model Name | Function Call |
|----------------------|---------------------------------------------|
| Titan Embeddings - G1 | `embedding(model="amazon.titan-embed-text-v1", input=input)` |
| Cohere Embeddings - English | `embedding(model="cohere.embed-english-v3", input=input)` |
| Cohere Embeddings - Multilingual | `embedding(model="cohere.embed-multilingual-v3", input=input)` |
## Cohere Embedding Models
@ -182,6 +184,17 @@ response = embedding(
input=["good morning from litellm"]
)
```
### Usage - Custom API Base
```python
from litellm import embedding
import os
os.environ['HUGGINGFACE_API_KEY'] = ""
response = embedding(
model='huggingface/microsoft/codebert-base',
input=["good morning from litellm"],
api_base = "https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud"
)
```
| Model Name | Function Call | Required OS Variables |
|-----------------------|--------------------------------------------------------------|-------------------------------------------------|

View file

@ -85,6 +85,43 @@ print(response)
```
## Async Callback Functions
LiteLLM currently supports just async success callback functions for async completion/embedding calls.
```python
import asyncio, litellm
async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time):
print(f"On Async Success!")
async def test_chat_openai():
try:
# litellm.set_verbose = True
litellm.success_callback = [async_test_logging_fn]
response = await litellm.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
stream=True)
async for chunk in response:
continue
except Exception as e:
print(e)
pytest.fail(f"An error occurred - {str(e)}")
asyncio.run(test_chat_openai())
```
:::info
We're actively trying to expand this to other event types. [Tell us if you need this!](https://github.com/BerriAI/litellm/issues/1007)
:::
## What's in kwargs?
Notice we pass in a kwargs argument to custom callback.

View file

@ -1,5 +1,5 @@
# AWS Sagemaker
LiteLLM supports Llama2 on Sagemaker
LiteLLM supports All Sagemaker Huggingface Jumpstart Models
### API KEYS
```python
@ -42,6 +42,28 @@ response = completion(
)
```
### Applying Prompt Templates
To apply the correct prompt template for your sagemaker deployment, pass in its hf model name as well.
```python
import os
from litellm import completion
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = completion(
model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
messages=messages,
temperature=0.2,
max_tokens=80,
hf_model_name="meta-llama/Llama-2-7b",
)
```
You can also pass in your own [custom prompt template](../completion/prompt_formatting.md#format-prompt-yourself)
### Usage - Streaming
Sagemaker currently does not support streaming - LiteLLM fakes streaming by returning chunks of the response string
@ -64,14 +86,32 @@ for chunk in response:
print(chunk)
```
### AWS Sagemaker Models
### Completion Models
Here's an example of using a sagemaker model with LiteLLM
| Model Name | Function Call | Required OS Variables |
|-------------------------------|-------------------------------------------------------------------------------------------|-------------------------------------------------|
| Your Custom Huggingface Model | `completion(model='sagemaker/<your-deployment-name>', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']`
| Meta Llama 2 7B | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 7B (Chat/Fine-tuned) | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 13B | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-13b', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 13B (Chat/Fine-tuned) | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-13b-f', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 70B | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-70b', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 70B (Chat/Fine-tuned) | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-70b-f', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
### Embedding Models
LiteLLM supports all Sagemaker Jumpstart Huggingface Embedding models. Here's how to call it:
```python
import os
import litellm
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = litellm.embedding(model="sagemaker/<your-deployment-name>", input=["good morning from litellm", "this is another item"])
print(f"response: {response}")
```

View file

@ -174,3 +174,5 @@ print(response)
| Model Name | Function Call |
|----------------------|---------------------------------------------|
| Titan Embeddings - G1 | `embedding(model="amazon.titan-embed-text-v1", input=input)` |
| Cohere Embeddings - English | `embedding(model="cohere.embed-english-v3", input=input)` |
| Cohere Embeddings - Multilingual | `embedding(model="cohere.embed-multilingual-v3", input=input)` |

View file

@ -49,8 +49,8 @@ Below are examples on how to call replicate LLMs using liteLLM
Model Name | Function Call | Required OS Variables |
-----------------------------|----------------------------------------------------------------|--------------------------------------|
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` |
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` |
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` |
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` |
replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` |
daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` |
custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` |

View file

@ -11,18 +11,27 @@ model_list:
litellm_settings:
set_verbose: True
cache: # init cache
type: redis # tell litellm to use redis caching
cache: True # set cache responses to True, litellm defaults to using a redis cache
```
#### Step 2: Add Redis Credentials to .env
LiteLLM requires the following REDIS credentials in your env to enable caching
Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching.
```shell
REDIS_URL = "" # REDIS_URL='redis://username:password@hostname:port/database'
## OR ##
REDIS_HOST = "" # REDIS_HOST='redis-18841.c274.us-east-1-3.ec2.cloud.redislabs.com'
REDIS_PORT = "" # REDIS_PORT='18841'
REDIS_PASSWORD = "" # REDIS_PASSWORD='liteLlmIsAmazing'
```
**Additional kwargs**
You can pass in any additional redis.Redis arg, by storing the variable + value in your os environment, like this:
```shell
REDIS_<redis-kwarg-name> = ""
```
[**See how it's read from the environment**](https://github.com/BerriAI/litellm/blob/4d7ff1b33b9991dcf38d821266290631d9bcd2dd/litellm/_redis.py#L40)
#### Step 3: Run proxy with config
```shell
$ litellm --config /path/to/config.yaml

View file

@ -1,91 +1,86 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Proxy Config.yaml
Set model list, `api_base`, `api_key`, `temperature` & proxy server settings (`master-key`) on the config.yaml.
| Param Name | Description |
|----------------------|---------------------------------------------------------------|
| `model_list` | List of supported models on the server, with model-specific configs |
| `litellm_settings` | litellm Module settings, example `litellm.drop_params=True`, `litellm.set_verbose=True`, `litellm.api_base`, `litellm.cache` |
| `router_settings` | litellm Router settings, example `routing_strategy="least-busy"` [**see all**](https://github.com/BerriAI/litellm/blob/6ef0e8485e0e720c0efa6f3075ce8119f2f62eea/litellm/router.py#L64)|
| `litellm_settings` | litellm Module settings, example `litellm.drop_params=True`, `litellm.set_verbose=True`, `litellm.api_base`, `litellm.cache` [**see all**](https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py)|
| `general_settings` | Server settings, example setting `master_key: sk-my_special_key` |
| `environment_variables` | Environment Variables example, `REDIS_HOST`, `REDIS_PORT` |
#### Example Config
**Complete List:** Check the Swagger UI docs on `<your-proxy-url>/#/config.yaml` (e.g. http://0.0.0.0:8000/#/config.yaml), for everything you can pass in the config.yaml.
## Quick Start
Set a model alias for your deployments.
In the `config.yaml` the model_name parameter is the user-facing name to use for your deployment.
In the config below requests with:
- `model=vllm-models` will route to `openai/facebook/opt-125m`.
- `model=gpt-3.5-turbo` will load balance between `azure/gpt-turbo-small-eu` and `azure/gpt-turbo-small-ca`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
- model_name: gpt-3.5-turbo # user-facing model alias
litellm_params: # all params accepted by litellm.completion() - https://docs.litellm.ai/docs/completion/input
model: azure/gpt-turbo-small-eu
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key:
api_key: "os.environ/AZURE_API_KEY_EU" # does os.getenv("AZURE_API_KEY_EU")
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
- model_name: bedrock-claude-v1
litellm_params:
model: bedrock/anthropic.claude-instant-v1
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key:
api_key: "os.environ/AZURE_API_KEY_CA"
rpm: 6
- model_name: gpt-3.5-turbo
- model_name: vllm-models
litellm_params:
model: azure/gpt-turbo-large
api_base: https://openai-france-1234.openai.azure.com/
api_key:
model: openai/facebook/opt-125m # the `openai/` prefix tells litellm it's openai compatible
api_base: http://0.0.0.0:8000
rpm: 1440
model_info:
version: 2
litellm_settings:
litellm_settings: # module level litellm settings - https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py
drop_params: True
set_verbose: True
general_settings:
master_key: sk-1234 # [OPTIONAL] Only use this if you want to require all calls to contain this key (Authorization: Bearer sk-1234)
environment_variables:
OPENAI_API_KEY: sk-123
REPLICATE_API_KEY: sk-cohere-is-okay
REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
REDIS_PORT: "16337"
REDIS_PASSWORD:
```
### Config for Multiple Models - GPT-4, Claude-2
Here's how you can use multiple llms with one proxy `config.yaml`.
#### Step 1: Setup Config
```yaml
model_list:
- model_name: zephyr-alpha # the 1st model is the default on the proxy
litellm_params: # params for litellm.completion() - https://docs.litellm.ai/docs/completion/input#input---request-body
model: huggingface/HuggingFaceH4/zephyr-7b-alpha
api_base: http://0.0.0.0:8001
- model_name: gpt-4
litellm_params:
model: gpt-4
api_key: sk-1233
- model_name: claude-2
litellm_params:
model: claude-2
api_key: sk-claude
```
:::info
The proxy uses the first model in the config as the default model - in this config the default model is `zephyr-alpha`
:::
#### Step 2: Start Proxy with config
```shell
$ litellm --config /path/to/config.yaml
```
#### Step 3: Use proxy
Curl Command
### Using Proxy - Curl Request, OpenAI Package, Langchain, Langchain JS
Calling a model group
<Tabs>
<TabItem value="Curl" label="Curl Request">
Sends request to model where `model_name=gpt-3.5-turbo` on config.yaml.
If multiple deployments have `model_name=gpt-3.5-turbo`, the proxy does [Load Balancing](https://docs.litellm.ai/docs/proxy/load_balancing) across them.
```shell
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "zephyr-alpha",
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
@ -95,33 +90,109 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
}
'
```
</TabItem>
### Config for Embedding Models - xorbitsai/inference
<TabItem value="Curl2" label="Curl Request: Bedrock">
Here's how you can use multiple llms with one proxy `config.yaml`.
Here is how [LiteLLM calls OpenAI Compatible Embedding models](https://docs.litellm.ai/docs/embedding/supported_embedding#openai-compatible-embedding-models)
Sends this request to model where `model_name=bedrock-claude-v1` on config.yaml
#### Config
```yaml
model_list:
- model_name: custom_embedding_model
litellm_params:
model: openai/custom_embedding # the `openai/` prefix tells litellm it's openai compatible
api_base: http://0.0.0.0:8000/
- model_name: custom_embedding_model
litellm_params:
model: openai/custom_embedding # the `openai/` prefix tells litellm it's openai compatible
api_base: http://0.0.0.0:8001/
```
Run the proxy using this config
```shell
$ litellm --config /path/to/config.yaml
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "bedrock-claude-v1",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:8000"
)
# Sends request to model where `model_name=gpt-3.5-turbo` on config.yaml.
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
# Sends this request to model where `model_name=bedrock-claude-v1` on config.yaml
response = client.chat.completions.create(model="bedrock-claude-v1", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
### Save Model-specific params (API Base, API Keys, Temperature, Headers etc.)
</TabItem>
<TabItem value="langchain" label="Langchain Python">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
# Sends request to model where `model_name=gpt-3.5-turbo` on config.yaml.
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:8000", # set openai base to the proxy
model = "gpt-3.5-turbo",
temperature=0.1
)
response = chat(messages)
print(response)
# Sends request to model where `model_name=bedrock-claude-v1` on config.yaml.
claude_chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:8000", # set openai base to the proxy
model = "bedrock-claude-v1",
temperature=0.1
)
response = claude_chat(messages)
print(response)
```
</TabItem>
</Tabs>
## Save Model-specific params (API Base, API Keys, Temperature, Headers etc.)
You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc.
[**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1)
**Step 1**: Create a `config.yaml` file
```yaml
model_list:
@ -152,9 +223,11 @@ model_list:
$ litellm --config /path/to/config.yaml
```
### Load API Keys from Vault
## Load API Keys
If you have secrets saved in Azure Vault, etc. and don't want to expose them in the config.yaml, here's how to load model-specific keys from the environment.
### Load API Keys from Environment
If you have secrets saved in your environment, and don't want to expose them in the config.yaml, here's how to load model-specific keys from the environment.
```python
os.environ["AZURE_NORTH_AMERICA_API_KEY"] = "your-azure-api-key"
@ -174,30 +247,42 @@ model_list:
s/o to [@David Manouchehri](https://www.linkedin.com/in/davidmanouchehri/) for helping with this.
### Config for setting Model Aliases
### Load API Keys from Azure Vault
Set a model alias for your deployments.
1. Install Proxy dependencies
```bash
$ pip install litellm[proxy] litellm[extra_proxy]
```
In the `config.yaml` the model_name parameter is the user-facing name to use for your deployment.
In the config below requests with `model=gpt-4` will route to `ollama/llama2`
2. Save Azure details in your environment
```bash
export AZURE_CLIENT_ID="your-azure-app-client-id"
export AZURE_CLIENT_SECRET="your-azure-app-client-secret"
export AZURE_TENANT_ID="your-azure-tenant-id"
export AZURE_KEY_VAULT_URI="your-azure-key-vault-uri"
```
3. Add to proxy config.yaml
```yaml
model_list:
- model_name: text-davinci-003
litellm_params:
model: ollama/zephyr
- model_name: gpt-4
litellm_params:
model: ollama/llama2
- model_name: gpt-3.5-turbo
litellm_params:
model: ollama/llama2
model_list:
- model_name: "my-azure-models" # model alias
litellm_params:
model: "azure/<your-deployment-name>"
api_key: "os.environ/AZURE-API-KEY" # reads from key vault - get_secret("AZURE_API_KEY")
api_base: "os.environ/AZURE-API-BASE" # reads from key vault - get_secret("AZURE_API_BASE")
general_settings:
use_azure_key_vault: True
```
You can now test this by starting your proxy:
```bash
litellm --config /path/to/config.yaml
```
### Set Custom Prompt Templates
LiteLLM by default checks if a model has a [prompt template and applies it](./completion/prompt_formatting.md) (e.g. if a huggingface model has a saved chat template in it's tokenizer_config.json). However, you can also set a custom prompt template on your proxy in the `config.yaml`:
LiteLLM by default checks if a model has a [prompt template and applies it](../completion/prompt_formatting.md) (e.g. if a huggingface model has a saved chat template in its tokenizer_config.json). However, you can also set a custom prompt template on your proxy in the `config.yaml`:
**Step 1**: Save your prompt template in a `config.yaml`
```yaml
@ -220,4 +305,41 @@ model_list:
```shell
$ litellm --config /path/to/config.yaml
```
```
## Router Settings
Use this to configure things like routing strategy.
```yaml
router_settings:
routing_strategy: "least-busy"
model_list: # will route requests to the least busy ollama model
- model_name: ollama-models
litellm_params:
model: "ollama/mistral"
api_base: "http://127.0.0.1:8001"
- model_name: ollama-models
litellm_params:
model: "ollama/codellama"
api_base: "http://127.0.0.1:8002"
- model_name: ollama-models
litellm_params:
model: "ollama/llama2"
api_base: "http://127.0.0.1:8003"
```
## Max Parallel Requests
To rate limit a user based on the number of parallel requests, e.g.:
if user's parallel requests > x, send a 429 error
if user's parallel requests <= x, let them use the API freely.
Set the max parallel request limit on the config.yaml (note: this expects the user to be passing in an api key).
```yaml
general_settings:
max_parallel_requests: 100 # max parallel requests for a user = 100
```

View file

@ -1,5 +1,87 @@
# Deploying LiteLLM Proxy
### Deploy on Render https://render.com/
## Quick Start Docker Image: Github Container Registry
### Pull the litellm ghcr docker image
See the latest available ghcr docker image here:
https://github.com/berriai/litellm/pkgs/container/litellm
```shell
docker pull ghcr.io/berriai/litellm:main-v1.10.0
```
### Run the Docker Image
```shell
docker run ghcr.io/berriai/litellm:main-v1.10.0
```
#### Run the Docker Image with LiteLLM CLI args
See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli):
Here's how you can run the docker image and pass your config to `litellm`
```shell
docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml
```
Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8`
```shell
docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8
```
#### Run the Docker Image using docker compose
**Step 1**
- (Recommended) Use the example file `docker-compose.example.yml` given in the project root. e.g. https://github.com/BerriAI/litellm/blob/main/docker-compose.example.yml
- Rename the file `docker-compose.example.yml` to `docker-compose.yml`.
Here's an example `docker-compose.yml` file
```yaml
version: "3.9"
services:
litellm:
image: ghcr.io/berriai/litellm:main
ports:
- "8000:8000" # Map the container port to the host, change the host port if necessary
volumes:
- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
# You can change the port or number of workers as per your requirements or pass any new supported CLI argument. Make sure the port passed here matches the container port defined above in the `ports` value
command: [ "--config", "/app/config.yaml", "--port", "8000", "--num_workers", "8" ]
# ...rest of your docker-compose config if any
```
**Step 2**
Create a `litellm-config.yaml` file with your LiteLLM config relative to your `docker-compose.yml` file.
Check the config doc [here](https://docs.litellm.ai/docs/proxy/configs)
**Step 3**
Run the command `docker-compose up` or `docker compose up` as per your docker installation.
> Use `-d` flag to run the container in detached mode (background) e.g. `docker compose up -d`
Your LiteLLM container should be running now on the defined port e.g. `8000`.
## Deploy on Render https://render.com/
<iframe width="840" height="500" src="https://www.loom.com/embed/805964b3c8384b41be180a61442389a3" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
## LiteLLM Proxy Performance
LiteLLM proxy has been load tested to handle 1500 req/s.
### Throughput - 30% Increase
LiteLLM proxy + Load Balancer gives **30% increase** in throughput compared to Raw OpenAI API
<Image img={require('../../img/throughput.png')} />
### Latency Added - 0.00325 seconds
LiteLLM proxy adds **0.00325 seconds** latency as compared to using the Raw OpenAI API
<Image img={require('../../img/latency.png')} />

View file

@ -3,38 +3,39 @@
Load balance multiple instances of the same model
The proxy will handle routing requests (using LiteLLM's Router). **Set `rpm` in the config if you want to maximize throughput**
## Quick Start - Load Balancing
### Step 1 - Set deployments on config
#### Example config
requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo`
**Example config below**. Here requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-eu
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key:
model: azure/<your-deployment-name>
api_base: <your-azure-endpoint>
api_key: <your-azure-api-key>
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key:
api_key: <your-azure-api-key>
rpm: 6
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-large
api_base: https://openai-france-1234.openai.azure.com/
api_key:
api_key: <your-azure-api-key>
rpm: 1440
```
#### Step 2: Start Proxy with config
### Step 2: Start Proxy with config
```shell
$ litellm --config /path/to/config.yaml
```
#### Step 3: Use proxy
### Step 3: Use proxy - Call a model group [Load Balancing]
Curl Command
```shell
curl --location 'http://0.0.0.0:8000/chat/completions' \
@ -51,7 +52,28 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
'
```
### Fallbacks + Cooldowns + Retries + Timeouts
### Usage - Call a specific model deployment
If you want to call a specific model defined in the `config.yaml`, you can call the `litellm_params: model`
In this example it will call `azure/gpt-turbo-small-ca`, which is defined in the config in Step 1
```bash
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "azure/gpt-turbo-small-ca",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}
'
```
## Fallbacks + Cooldowns + Retries + Timeouts
If a call fails after num_retries, fall back to another model group.
@ -85,7 +107,7 @@ model_list:
litellm_settings:
num_retries: 3 # retry call 3 times on each model_name (e.g. zephyr-beta)
request_timeout: 10 # raise Timeout error if call takes longer than 10s
request_timeout: 10 # raise Timeout error if call takes longer than 10s. Sets litellm.request_timeout
fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo"]}] # fallback to gpt-3.5-turbo if call fails num_retries
context_window_fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}] # fallback to gpt-3.5-turbo-16k if context window error
allowed_fails: 3 # cooldown model if it fails > 3 calls in a minute.
@ -107,7 +129,71 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
"fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
"context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
"num_retries": 2,
"request_timeout": 10
"timeout": 10
}
'
```
## Custom Timeouts, Stream Timeouts - Per Model
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-eu
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key: <your-key>
timeout: 0.1 # timeout in (seconds)
stream_timeout: 0.01 # timeout for stream requests (seconds)
max_retries: 5
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key:
timeout: 0.1 # timeout in (seconds)
stream_timeout: 0.01 # timeout for stream requests (seconds)
max_retries: 5
```
#### Start Proxy
```shell
$ litellm --config /path/to/config.yaml
```
## Health Check LLMs on Proxy
Use this to health check all LLMs defined in your config.yaml
#### Request
Make a GET Request to `/health` on the proxy
```shell
curl --location 'http://0.0.0.0:8000/health'
```
You can also run `litellm --health`; it makes a `get` request to `http://0.0.0.0:8000/health` for you
```
litellm --health
```
#### Response
```shell
{
"healthy_endpoints": [
{
"model": "azure/gpt-35-turbo",
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com/"
},
{
"model": "azure/gpt-35-turbo",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/"
}
],
"unhealthy_endpoints": [
{
"model": "azure/gpt-35-turbo",
"api_base": "https://openai-france-1234.openai.azure.com/"
}
]
}
```

View file

@ -1,8 +1,340 @@
# Logging - OpenTelemetry, Langfuse, ElasticSearch
Log Proxy Input, Output, Exceptions to Langfuse, OpenTelemetry
## OpenTelemetry, ElasticSearch
# Logging - Custom Callbacks, OpenTelemetry, Langfuse, Sentry
### Step 1 Start OpenTelemetry Collecter Docker Container
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry
## Custom Callback Class [Async]
Use this when you want to run custom callbacks in `python`
### Step 1 - Create your custom `litellm` callback class
We use `litellm.integrations.custom_logger` for this, **more details about litellm custom callbacks [here](https://docs.litellm.ai/docs/observability/custom_callback)**
Define your custom callback class in a python file.
Here's an example custom logger for tracking `key, user, model, prompt, response, tokens, cost`. We create a file called `custom_callbacks.py` and initialize `proxy_handler_instance`
```python
from litellm.integrations.custom_logger import CustomLogger
import litellm
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class MyCustomHandler(CustomLogger):
def log_pre_api_call(self, model, messages, kwargs):
print(f"Pre-API Call")
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print(f"Post-API Call")
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Stream")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
print("On Success")
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Failure")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Success!")
# log: key, user, model, prompt, response, tokens, cost
# Access kwargs passed to litellm.completion()
model = kwargs.get("model", None)
messages = kwargs.get("messages", None)
user = kwargs.get("user", None)
# Access litellm_params passed to litellm.completion(), example access `metadata`
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here
# Calculate cost using litellm.completion_cost()
cost = litellm.completion_cost(completion_response=response_obj)
response = response_obj
# tokens used in response
usage = response_obj["usage"]
print(
f"""
Model: {model},
Messages: {messages},
User: {user},
Usage: {usage},
Cost: {cost},
Response: {response}
Proxy Metadata: {metadata}
"""
)
return
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
print(f"On Async Failure !")
print("\nkwargs", kwargs)
# Access kwargs passed to litellm.completion()
model = kwargs.get("model", None)
messages = kwargs.get("messages", None)
user = kwargs.get("user", None)
# Access litellm_params passed to litellm.completion(), example access `metadata`
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here
# Access Exceptions & Traceback
exception_event = kwargs.get("exception", None)
traceback_event = kwargs.get("traceback_exception", None)
# Calculate cost using litellm.completion_cost()
cost = litellm.completion_cost(completion_response=response_obj)
print("now checking response obj")
print(
f"""
Model: {model},
Messages: {messages},
User: {user},
Cost: {cost},
Response: {response_obj}
Proxy Metadata: {metadata}
Exception: {exception_event}
Traceback: {traceback_event}
"""
)
except Exception as e:
print(f"Exception: {e}")
proxy_handler_instance = MyCustomHandler()
# Set litellm.callbacks = [proxy_handler_instance] on the proxy
# need to set litellm.callbacks = [proxy_handler_instance] # on the proxy
```
### Step 2 - Pass your custom callback class in `config.yaml`
We pass the custom callback class defined in **Step 1** to the config.yaml.
Set `callbacks` to `python_filename.logger_instance_name`
In the config below, we pass
- python_filename: `custom_callbacks.py`
- logger_instance_name: `proxy_handler_instance`. This is defined in Step 1
`callbacks: custom_callbacks.proxy_handler_instance`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
```
### Step 3 - Start proxy + test request
```shell
litellm --config proxy_config.yaml
```
```shell
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "good morning good sir"
}
],
"user": "ishaan-app",
"temperature": 0.2
}'
```
#### Resulting Log on Proxy
```shell
On Success
Model: gpt-3.5-turbo,
Messages: [{'role': 'user', 'content': 'good morning good sir'}],
User: ishaan-app,
Usage: {'completion_tokens': 10, 'prompt_tokens': 11, 'total_tokens': 21},
Cost: 3.65e-05,
Response: {'id': 'chatcmpl-8S8avKJ1aVBg941y5xzGMSKrYCMvN', 'choices': [{'finish_reason': 'stop', 'index': 0, 'message': {'content': 'Good morning! How can I assist you today?', 'role': 'assistant'}}], 'created': 1701716913, 'model': 'gpt-3.5-turbo-0613', 'object': 'chat.completion', 'system_fingerprint': None, 'usage': {'completion_tokens': 10, 'prompt_tokens': 11, 'total_tokens': 21}}
Proxy Metadata: {'user_api_key': None, 'headers': Headers({'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'authorization': 'Bearer sk-1234', 'content-length': '199', 'content-type': 'application/x-www-form-urlencoded'}), 'model_group': 'gpt-3.5-turbo', 'deployment': 'gpt-3.5-turbo-ModelID-gpt-3.5-turbo'}
```
### Logging Proxy Request Object, Header, Url
Here's how you can access the `url`, `headers`, `request body` sent to the proxy for each request
```python
class MyCustomHandler(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Success!")
litellm_params = kwargs.get("litellm_params", None)
proxy_server_request = litellm_params.get("proxy_server_request")
print(proxy_server_request)
```
**Expected Output**
```shell
{
"url": "http://testserver/chat/completions",
"method": "POST",
"headers": {
"host": "testserver",
"accept": "*/*",
"accept-encoding": "gzip, deflate",
"connection": "keep-alive",
"user-agent": "testclient",
"authorization": "Bearer None",
"content-length": "105",
"content-type": "application/json"
},
"body": {
"model": "Azure OpenAI GPT-4 Canada",
"messages": [
{
"role": "user",
"content": "hi"
}
],
"max_tokens": 10
}
}
```
### Logging `model_info` set in config.yaml
Here is how to log the `model_info` set in your proxy `config.yaml`. Information on setting `model_info` on [config.yaml](https://docs.litellm.ai/docs/proxy/configs)
```python
class MyCustomHandler(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Success!")
litellm_params = kwargs.get("litellm_params", None)
model_info = litellm_params.get("model_info")
print(model_info)
```
**Expected Output**
```json
{'mode': 'embedding', 'input_cost_per_token': 0.002}
```
### Logging responses from proxy
Both `/chat/completions` and `/embeddings` responses are available as `response_obj`
**Note: for `/chat/completions`, both `stream=True` and `non stream` responses are available as `response_obj`**
```python
class MyCustomHandler(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Success!")
print(response_obj)
```
**Expected Output /chat/completion [for both `stream` and `non-stream` responses]**
```json
ModelResponse(
id='chatcmpl-8Tfu8GoMElwOZuj2JlHBhNHG01PPo',
choices=[
Choices(
finish_reason='stop',
index=0,
message=Message(
content='As an AI language model, I do not have a physical body and therefore do not possess any degree or educational qualifications. My knowledge and abilities come from the programming and algorithms that have been developed by my creators.',
role='assistant'
)
)
],
created=1702083284,
model='chatgpt-v-2',
object='chat.completion',
system_fingerprint=None,
usage=Usage(
completion_tokens=42,
prompt_tokens=5,
total_tokens=47
)
)
```
**Expected Output /embeddings**
```json
{
'model': 'ada',
'data': [
{
'embedding': [
-0.035126980394124985, -0.020624293014407158, -0.015343423001468182,
-0.03980357199907303, -0.02750781551003456, 0.02111034281551838,
-0.022069307044148445, -0.019442008808255196, -0.00955679826438427,
-0.013143060728907585, 0.029583381488919258, -0.004725852981209755,
-0.015198921784758568, -0.014069183729588985, 0.00897879246622324,
0.01521205808967352,
# ... (truncated for brevity)
]
}
]
}
```
## OpenTelemetry - Traceloop
Traceloop allows you to log LLM Input/Output in the OpenTelemetry format
We will use the `--config` to set `litellm.success_callback = ["traceloop"]`. This will log all successful LLM calls to Traceloop
**Step 1** Install traceloop-sdk and set Traceloop API key
```shell
pip install traceloop-sdk -U
```
Traceloop outputs standard OpenTelemetry data that can be connected to your observability stack. Send standard OpenTelemetry from LiteLLM Proxy to [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop), [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace), [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog)
, [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic), [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb), [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana), [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk), [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
success_callback: ["traceloop"]
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --debug
```
Test Request
```
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}'
```
<!-- ### Step 1 Start OpenTelemetry Collector Docker Container
This container sends logs to your selected destination
#### Install OpenTelemetry Collector Docker Image
@ -113,48 +445,13 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
On successful logging you should be able to see this log on your `OpenTelemetry Collector` Docker Container
```shell
Events:
SpanEvent #0
-> Name: LiteLLM: Request Input
-> Timestamp: 2023-12-02 05:05:53.71063 +0000 UTC
-> DroppedAttributesCount: 0
-> Attributes::
-> type: Str(http)
-> asgi: Str({'version': '3.0', 'spec_version': '2.3'})
-> http_version: Str(1.1)
-> server: Str(('127.0.0.1', 8000))
-> client: Str(('127.0.0.1', 62796))
-> scheme: Str(http)
-> method: Str(POST)
-> root_path: Str()
-> path: Str(/chat/completions)
-> raw_path: Str(b'/chat/completions')
-> query_string: Str(b'')
-> headers: Str([(b'host', b'0.0.0.0:8000'), (b'user-agent', b'curl/7.88.1'), (b'accept', b'*/*'), (b'authorization', b'Bearer sk-1244'), (b'content-length', b'147'), (b'content-type', b'application/x-www-form-urlencoded')])
-> state: Str({})
-> app: Str(<fastapi.applications.FastAPI object at 0x1253dd960>)
-> fastapi_astack: Str(<contextlib.AsyncExitStack object at 0x127c8b7c0>)
-> router: Str(<fastapi.routing.APIRouter object at 0x1253dda50>)
-> endpoint: Str(<function chat_completion at 0x1254383a0>)
-> path_params: Str({})
-> route: Str(APIRoute(path='/chat/completions', name='chat_completion', methods=['POST']))
SpanEvent #1
-> Name: LiteLLM: Request Headers
-> Timestamp: 2023-12-02 05:05:53.710652 +0000 UTC
-> DroppedAttributesCount: 0
-> Attributes::
-> host: Str(0.0.0.0:8000)
-> user-agent: Str(curl/7.88.1)
-> accept: Str(*/*)
-> authorization: Str(Bearer sk-1244)
-> content-length: Str(147)
-> content-type: Str(application/x-www-form-urlencoded)
SpanEvent #2
```
### View Log on Elastic Search
Here's the log view on Elastic Search. You can see the request `input`, `output` and `headers`
<Image img={require('../../img/elastic_otel.png')} />
<Image img={require('../../img/elastic_otel.png')} /> -->
## Logging Proxy Input/Output - Langfuse
We will use the `--config` to set `litellm.success_callback = ["langfuse"]`. This will log all successful LLM calls to Langfuse
@ -190,3 +487,41 @@ litellm --test
Expected output on Langfuse
<Image img={require('../../img/langfuse_small.png')} />
## Logging Proxy Input/Output - Sentry
If api calls fail (llm/database) you can log those to Sentry:
**Step 1** Install Sentry
```shell
pip install --upgrade sentry-sdk
```
**Step 2**: Save your Sentry_DSN and add `litellm_settings`: `failure_callback`
```shell
export SENTRY_DSN="your-sentry-dsn"
```
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
# other settings
failure_callback: ["sentry"]
general_settings:
database_url: "my-bad-url" # set a fake url to trigger a sentry exception
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --debug
```
Test Request
```
litellm --test
```

View file

@ -0,0 +1,74 @@
# Model Management
Add new models + Get model info without restarting proxy.
## Get Model Information
Retrieve detailed information about each model listed in the `/models` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled from the model_info you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes.
<Tabs
defaultValue="curl"
values={[
{ label: 'cURL', value: 'curl', },
]}>
<TabItem value="curl">
```bash
curl -X GET "http://0.0.0.0:8000/model/info" \
-H "accept: application/json"
```
</TabItem>
</Tabs>
## Add a New Model
Add a new model to the list in the `config.yaml` by providing the model parameters. This allows you to update the model list without restarting the proxy.
<Tabs
defaultValue="curl"
values={[
{ label: 'cURL', value: 'curl', },
]}>
<TabItem value="curl">
```bash
curl -X POST "http://0.0.0.0:8000/model/new" \
-H "accept: application/json" \
-H "Content-Type: application/json" \
-d '{ "model_name": "azure-gpt-turbo", "litellm_params": {"model": "azure/gpt-3.5-turbo", "api_key": "os.environ/AZURE_API_KEY", "api_base": "my-azure-api-base"} }'
```
</TabItem>
</Tabs>
### Model Parameters Structure
When adding a new model, your JSON payload should conform to the following structure:
- `model_name`: The name of the new model (required).
- `litellm_params`: A dictionary containing parameters specific to the Litellm setup (required).
- `model_info`: An optional dictionary to provide additional information about the model.
Here's an example of how to structure your `ModelParams`:
```json
{
"model_name": "my_awesome_model",
"litellm_params": {
"some_parameter": "some_value",
"another_parameter": "another_value"
},
"model_info": {
"author": "Your Name",
"version": "1.0",
"description": "A brief description of the model."
}
}
```
---
Keep in mind that as both endpoints are in [BETA], you may need to visit the associated GitHub issues linked in the API descriptions to check for updates or provide feedback:
- Get Model Information: [Issue #933](https://github.com/BerriAI/litellm/issues/933)
- Add a New Model: [Issue #964](https://github.com/BerriAI/litellm/issues/964)
Feedback on the beta endpoints is valuable and helps improve the API for all users.

View file

@ -43,7 +43,7 @@ litellm --test
This will now automatically route any requests for gpt-3.5-turbo to bigcode starcoder, hosted on huggingface inference endpoints.
### Using LiteLLM Proxy - Curl Request, OpenAI Package
### Using LiteLLM Proxy - Curl Request, OpenAI Package, Langchain, Langchain JS
<Tabs>
<TabItem value="Curl" label="Curl Request">
@ -84,72 +84,75 @@ print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:8000", # set openai_api_base to the LiteLLM Proxy
model = "gpt-3.5-turbo",
temperature=0.1
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
<TabItem value="langchain-embedding" label="Langchain Embeddings">
```python
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(model="sagemaker-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
text = "This is a test document."
query_result = embeddings.embed_query(text)
print(f"SAGEMAKER EMBEDDINGS")
print(query_result[:5])
embeddings = OpenAIEmbeddings(model="bedrock-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
text = "This is a test document."
query_result = embeddings.embed_query(text)
print(f"BEDROCK EMBEDDINGS")
print(query_result[:5])
embeddings = OpenAIEmbeddings(model="bedrock-titan-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
text = "This is a test document."
query_result = embeddings.embed_query(text)
print(f"TITAN EMBEDDINGS")
print(query_result[:5])
```
</TabItem>
</Tabs>
## Quick Start - LiteLLM Proxy + Config.yaml
The config allows you to create a model list and set `api_base`, `max_tokens` (all litellm params). See more details about the config [here](https://docs.litellm.ai/docs/proxy/configs)
### Create a Config for LiteLLM Proxy
Example config
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/<your-deployment-name>
api_base: <your-azure-api-endpoint>
api_key: <your-azure-api-key>
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key: <your-azure-api-key>
```
### Run proxy with config
```shell
litellm --config your_config.yaml
```
## Quick Start Docker Image: Github Container Registry
### Pull the litellm ghcr docker image
See the latest available ghcr docker image here:
https://github.com/berriai/litellm/pkgs/container/litellm
```shell
docker pull ghcr.io/berriai/litellm:main-v1.10.1
```
### Run the Docker Image
```shell
docker run ghcr.io/berriai/litellm:main-v1.10.0
```
#### Run the Docker Image with LiteLLM CLI args
See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli):
Here's how you can run the docker image and pass your config to `litellm`
```shell
docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml
```
Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8`
```shell
docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8
```
## Server Endpoints
- POST `/chat/completions` - chat completions endpoint to call 100+ LLMs
- POST `/completions` - completions endpoint
- POST `/embeddings` - embedding endpoint for Azure, OpenAI, Huggingface endpoints
- GET `/models` - available models on server
- POST `/key/generate` - generate a key to access the proxy
## Supported LLMs
### Supported LLMs
All LiteLLM supported LLMs are supported on the Proxy. See all [supported llms](https://docs.litellm.ai/docs/providers)
<Tabs>
<TabItem value="bedrock" label="AWS Bedrock">
@ -175,7 +178,7 @@ $ litellm --model azure/my-deployment-name
```
</TabItem>
<TabItem value="openai-proxy" label="OpenAI">
<TabItem value="openai" label="OpenAI">
```shell
$ export OPENAI_API_KEY=my-api-key
@ -185,13 +188,23 @@ $ export OPENAI_API_KEY=my-api-key
$ litellm --model gpt-3.5-turbo
```
</TabItem>
<TabItem value="openai-proxy" label="OpenAI Compatible Endpoint">
```shell
$ export OPENAI_API_KEY=my-api-key
```
```shell
$ litellm --model openai/<your model name> --api_base <your-api-base> # e.g. http://0.0.0.0:3000
```
</TabItem>
<TabItem value="huggingface" label="Huggingface (TGI) Deployed">
```shell
$ export HUGGINGFACE_API_KEY=my-api-key #[OPTIONAL]
```
```shell
$ litellm --model huggingface/<your model name> --api_base https://k58ory32yinf1ly0.us-east-1.aws.endpoints.huggingface.cloud
$ litellm --model huggingface/<your model name> --api_base <your-api-base> # e.g. http://0.0.0.0:3000
```
</TabItem>
@ -301,6 +314,124 @@ $ litellm --model command-nightly
</Tabs>
## Quick Start - LiteLLM Proxy + Config.yaml
The config allows you to create a model list and set `api_base`, `max_tokens` (all litellm params). See more details about the config [here](https://docs.litellm.ai/docs/proxy/configs)
### Create a Config for LiteLLM Proxy
Example config
```yaml
model_list:
- model_name: gpt-3.5-turbo # user-facing model alias
litellm_params: # all params accepted by litellm.completion() - https://docs.litellm.ai/docs/completion/input
model: azure/<your-deployment-name>
api_base: <your-azure-api-endpoint>
api_key: <your-azure-api-key>
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key: <your-azure-api-key>
- model_name: vllm-model
litellm_params:
model: openai/<your-model-name>
api_base: <your-api-base> # e.g. http://0.0.0.0:3000
```
### Run proxy with config
```shell
litellm --config your_config.yaml
```
[**More Info**](./configs.md)
## Server Endpoints
- POST `/chat/completions` - chat completions endpoint to call 100+ LLMs
- POST `/completions` - completions endpoint
- POST `/embeddings` - embedding endpoint for Azure, OpenAI, Huggingface endpoints
- GET `/models` - available models on server
- POST `/key/generate` - generate a key to access the proxy
## Gunicorn + Proxy
Command:
```python
cmd = f"gunicorn litellm.proxy.proxy_server:app --workers {num_workers} --worker-class uvicorn.workers.UvicornWorker --bind {host}:{port}"
```
[**Code**](https://github.com/BerriAI/litellm/blob/077f6b1298101079b72396bdf04f8ca0cf737720/litellm/tests/test_proxy_gunicorn.py#L4)
## Quick Start Docker Image: Github Container Registry
### Pull the litellm ghcr docker image
See the latest available ghcr docker image here:
https://github.com/berriai/litellm/pkgs/container/litellm
```shell
docker pull ghcr.io/berriai/litellm:main-v1.10.1
```
### Run the Docker Image
```shell
docker run ghcr.io/berriai/litellm:main-v1.10.0
```
#### Run the Docker Image with LiteLLM CLI args
See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli):
Here's how you can run the docker image and pass your config to `litellm`
```shell
docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml
```
Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8`
```shell
docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8
```
#### Run the Docker Image using docker compose
**Step 1**
- (Recommended) Use the example file `docker-compose.example.yml` given in the project root. e.g. https://github.com/BerriAI/litellm/blob/main/docker-compose.example.yml
- Rename the file `docker-compose.example.yml` to `docker-compose.yml`.
Here's an example `docker-compose.yml` file
```yaml
version: "3.9"
services:
litellm:
image: ghcr.io/berriai/litellm:main
ports:
- "8000:8000" # Map the container port to the host, change the host port if necessary
volumes:
- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
# You can change the port or number of workers as per your requirements, or pass any other supported CLI argument. Make sure the port passed here matches the container port defined above in the `ports` value
command: [ "--config", "/app/config.yaml", "--port", "8000", "--num_workers", "8" ]
# ...rest of your docker-compose config if any
```
**Step 2**
Create a `litellm-config.yaml` file with your LiteLLM config relative to your `docker-compose.yml` file.
Check the config doc [here](https://docs.litellm.ai/docs/proxy/configs)
**Step 3**
Run the command `docker-compose up` or `docker compose up` as per your docker installation.
> Use `-d` flag to run the container in detached mode (background) e.g. `docker compose up -d`
Your LiteLLM container should be running now on the defined port e.g. `8000`.
## Using with OpenAI compatible projects
Set `base_url` to the LiteLLM Proxy server
@ -473,37 +604,4 @@ curl -X POST \
https://api.openai.com/v1/chat/completions \
-H 'content-type: application/json' -H 'Authorization: Bearer sk-qnWGUIW9****************************************' \
-d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "this is a test request, write a short poem"}]}'
```
## Health Check LLMs on Proxy
Use this to health check all LLMs defined in your config.yaml
#### Request
```shell
curl --location 'http://0.0.0.0:8000/health'
```
You can also run `litellm --health`; it makes a `get` request to `http://0.0.0.0:8000/health` for you
```
litellm --health
```
#### Response
```shell
{
"healthy_endpoints": [
{
"model": "azure/gpt-35-turbo",
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com/"
},
{
"model": "azure/gpt-35-turbo",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/"
}
],
"unhealthy_endpoints": [
{
"model": "azure/gpt-35-turbo",
"api_base": "https://openai-france-1234.openai.azure.com/"
}
]
}
```

View file

@ -1,9 +1,10 @@
# Cost Tracking & Virtual Keys
# Key Management
Track Spend and create virtual keys for the proxy
Grant others temporary access to your proxy, with keys that expire after a set duration.
## Quick Start
Requirements:
- Need a postgres database (e.g. [Supabase](https://supabase.com/), [Neon](https://neon.tech/), etc)
@ -55,7 +56,7 @@ Expected response:
}
```
### Managing Auth - Upgrade/Downgrade Models
## Managing Auth - Upgrade/Downgrade Models
If a user is expected to use a given model (i.e. gpt3-5), and you want to:
@ -104,7 +105,7 @@ curl -X POST "https://0.0.0.0:8000/key/generate" \
- **How to upgrade / downgrade request?** Change the alias mapping
- **How is routing between diff keys/api bases done?** litellm handles this by shuffling between different models in the model list with the same model_name. [**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
### Managing Auth - Tracking Spend
## Managing Auth - Tracking Spend
You can get spend for a key by using the `/key/info` endpoint.
@ -136,4 +137,54 @@ This is automatically updated (in USD) when calls are made to /completions, /cha
"config": {}
}
}
```
```
## Custom Auth
You can now override the default api key auth.
Here's how:
### 1. Create a custom auth file.
Make sure the response type follows the `UserAPIKeyAuth` pydantic object. This is used for logging usage specific to that user key.
```python
from fastapi import Request
from litellm.proxy._types import UserAPIKeyAuth
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
modified_master_key = "sk-my-master-key"
if api_key == modified_master_key:
return UserAPIKeyAuth(api_key=api_key)
raise Exception
except:
raise Exception
```
### 2. Pass the filepath (relative to the config.yaml)
Pass the filepath to the config.yaml
e.g. if they're both in the same dir - `./config.yaml` and `./custom_auth.py`, this is what it looks like:
```yaml
model_list:
- model_name: "openai-model"
litellm_params:
model: "gpt-3.5-turbo"
litellm_settings:
drop_params: True
set_verbose: True
general_settings:
custom_auth: custom_auth.user_api_key_auth
```
[**Implementation Code**](https://github.com/BerriAI/litellm/blob/caf2a6b279ddbe89ebd1d8f4499f65715d684851/litellm/proxy/utils.py#L122)
### 3. Start the proxy
```bash
$ litellm --config /path/to/config.yaml
```

View file

@ -356,6 +356,16 @@ router = Router(model_list=model_list,
print(response)
```
**Pass in Redis URL, additional kwargs**
```python
router = Router(model_list: Optional[list] = None,
## CACHING ##
redis_url=os.getenv("REDIS_URL"),
cache_kwargs= {}, # additional kwargs to pass to RedisCache (see caching.py)
cache_responses=True)
```
#### Default litellm.completion/embedding params
You can also set default params for litellm completion/embedding calls. Here's how to do that:

View file

@ -68,7 +68,7 @@ You can now test this by starting your proxy:
litellm --config /path/to/config.yaml
```
[Quick Test Proxy](./simple_proxy.md#using-litellm-proxy---curl-request-openai-package)
[Quick Test Proxy](./proxy/quick_start#using-litellm-proxy---curl-request-openai-package-langchain-langchain-js)
## Infisical Secret Manager
Integrates with [Infisical's Secret Manager](https://infisical.com/) for secure storage and retrieval of API keys and sensitive data.

File diff suppressed because it is too large Load diff

View file

@ -20,6 +20,7 @@
"@docusaurus/preset-classic": "2.4.1",
"@mdx-js/react": "^1.6.22",
"clsx": "^1.2.1",
"docusaurus": "^1.14.7",
"docusaurus-lunr-search": "^2.4.1",
"prism-react-renderer": "^1.3.5",
"react": "^17.0.2",

View file

@ -99,6 +99,7 @@ const sidebars = {
"proxy/configs",
"proxy/load_balancing",
"proxy/virtual_keys",
"proxy/model_management",
"proxy/caching",
"proxy/logging",
"proxy/cli",

File diff suppressed because it is too large Load diff

View file

@ -2,15 +2,18 @@
import threading, requests
from typing import Callable, List, Optional, Dict, Union, Any
from litellm.caching import Cache
from litellm._logging import set_verbose
import httpx
input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
set_verbose = False
email: Optional[
str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
@ -43,6 +46,7 @@ caching: bool = False # Not used anymore, will be removed in next MAJOR release
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
_current_cost = 0 # private variable, used if max budget is set
error_logs: Dict = {}
@ -338,7 +342,7 @@ cohere_embedding_models: List = [
"embed-english-light-v2.0",
"embed-multilingual-v2.0",
]
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1"]
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models

8
litellm/_logging.py Normal file
View file

@ -0,0 +1,8 @@
set_verbose = False  # module-level switch read by print_verbose()


def print_verbose(print_statement):
    """Print ``print_statement`` to stdout when ``set_verbose`` is truthy.

    Debug output is best-effort: an ordinary error raised while printing
    (for example a failing ``__str__`` on the argument) is suppressed so that
    logging never breaks the caller.  Only ``Exception`` subclasses are
    suppressed; ``KeyboardInterrupt`` and ``SystemExit`` still propagate.

    Args:
        print_statement: Any object accepted by ``print``.

    Returns:
        None.
    """
    try:
        if set_verbose:
            print(print_statement)  # noqa
    except Exception:
        # Deliberately ignored: verbose output must never raise into callers.
        pass

93
litellm/_redis.py Normal file
View file

@ -0,0 +1,93 @@
# +-----------------------------------------------+
# | |
# | Give Feedback / Get Help |
# | https://github.com/BerriAI/litellm/issues/new |
# | |
# +-----------------------------------------------+
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
import inspect
import redis, litellm
from typing import List, Optional
def _get_redis_kwargs():
    """Return the keyword names that may be forwarded to the redis client.

    The list is built from the argument names reported by
    ``inspect.getfullargspec(redis.Redis)``, minus ``self``,
    ``connection_pool`` and ``retry``, with ``"url"`` appended at the end.
    """
    # Only allow primitive arguments
    excluded = {"self", "connection_pool", "retry"}
    accepted = []
    for arg_name in inspect.getfullargspec(redis.Redis).args:
        if arg_name not in excluded:
            accepted.append(arg_name)
    # "url" is accepted in addition to whatever the argspec reports.
    accepted.append("url")
    return accepted
def _get_redis_env_kwarg_mapping():
    """Map ``REDIS_<UPPERCASED NAME>`` variable names to redis kwarg names.

    One entry is produced for every name returned by ``_get_redis_kwargs()``.
    """
    env_prefix = "REDIS_"
    mapping = {}
    for kwarg_name in _get_redis_kwargs():
        mapping[env_prefix + kwarg_name.upper()] = kwarg_name
    return mapping
def _redis_kwargs_from_environment():
    """Collect redis kwargs whose ``REDIS_*`` secret resolves to a value.

    Each name from ``_get_redis_env_kwarg_mapping()`` is looked up through
    ``litellm.get_secret``; names that resolve to ``None`` are left out.
    """
    resolved = {}
    for env_name, kwarg_name in _get_redis_env_kwarg_mapping().items():
        secret = litellm.get_secret(env_name, default_value=None)  # check os.environ/key vault
        if secret is None:
            continue
        resolved[kwarg_name] = secret
    return resolved
def get_redis_url_from_environment():
    """Build a redis URL from process environment variables.

    Returns:
        ``REDIS_URL`` verbatim when it is set; otherwise
        ``redis://[:<REDIS_PASSWORD>@]<REDIS_HOST>:<REDIS_PORT>``.

    Raises:
        ValueError: ``REDIS_URL`` is unset and ``REDIS_HOST`` or
            ``REDIS_PORT`` is missing.
    """
    env = os.environ

    explicit_url = env.get("REDIS_URL")
    if explicit_url is not None:
        return explicit_url

    if not ("REDIS_HOST" in env and "REDIS_PORT" in env):
        raise ValueError("Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis.")

    # The password segment is optional and is emitted as ":<password>@".
    credentials = f":{env['REDIS_PASSWORD']}@" if "REDIS_PASSWORD" in env else ""
    return f"redis://{credentials}{env['REDIS_HOST']}:{env['REDIS_PORT']}"
def get_redis_client(**env_overrides):
    """Create a redis client from ``REDIS_*`` secrets plus explicit overrides.

    String overrides that start with ``"os.environ/"`` are resolved through
    ``litellm.get_secret`` first.  Overrides take precedence over the values
    returned by ``_redis_kwargs_from_environment()``.

    Args:
        **env_overrides: Keyword arguments for the redis client (for example
            ``host``, ``port``, ``password`` or ``url``).

    Returns:
        ``redis.Redis.from_url(...)`` when a non-None ``url`` is present
        (``host``, ``port``, ``db`` and ``password`` are dropped in that
        case), otherwise ``redis.Redis(...)``.

    Raises:
        ValueError: Neither a ``url`` nor a ``host`` is available.
    """
    ### check if "os.environ/<key-name>" passed in
    for k, v in list(env_overrides.items()):
        if isinstance(v, str) and v.startswith("os.environ/"):
            env_overrides[k] = litellm.get_secret(v.replace("os.environ/", ""))

    redis_kwargs = {
        **_redis_kwargs_from_environment(),
        **env_overrides,
    }

    if redis_kwargs.get("url") is not None:
        # The connection is described by the URL; drop the separate fields.
        for field in ("host", "port", "db", "password"):
            redis_kwargs.pop(field, None)
        return redis.Redis.from_url(**redis_kwargs)
    elif redis_kwargs.get("host") is None:
        raise ValueError("Either 'host' or 'url' must be specified for redis.")

    # Never write the password to verbose output; mask it in the logged copy.
    loggable_kwargs = {
        k: ("***" if k == "password" and v is not None else v)
        for k, v in redis_kwargs.items()
    }
    litellm.print_verbose(f"redis_kwargs: {loggable_kwargs}")
    return redis.Redis(**redis_kwargs)

View file

@ -13,9 +13,12 @@ class BudgetManager:
self.load_data()
def print_verbose(self, print_statement):
if litellm.set_verbose:
import logging
logging.info(print_statement)
try:
if litellm.set_verbose:
import logging
logging.info(print_statement)
except:
pass
def load_data(self):
if self.client_type == "local":

View file

@ -25,8 +25,11 @@ def get_prompt(*args, **kwargs):
return None
def print_verbose(print_statement):
if litellm.set_verbose:
print(print_statement) # noqa
try:
if litellm.set_verbose:
print(print_statement) # noqa
except:
pass
class BaseCache:
def set_cache(self, key, value, **kwargs):
@ -58,8 +61,6 @@ class InMemoryCache(BaseCache):
cached_response = json.loads(original_cached_response)
except:
cached_response = original_cached_response
if isinstance(cached_response, dict):
cached_response['cache'] = True # set cache-hit flag to True
return cached_response
return None
@ -69,13 +70,26 @@ class InMemoryCache(BaseCache):
class RedisCache(BaseCache):
def __init__(self, host, port, password):
def __init__(self, host=None, port=None, password=None, **kwargs):
import redis
# if users don't provider one, use the default litellm cache
self.redis_client = redis.Redis(host=host, port=port, password=password)
from ._redis import get_redis_client
redis_kwargs = {}
if host is not None:
redis_kwargs["host"] = host
if port is not None:
redis_kwargs["port"] = port
if password is not None:
redis_kwargs["password"] = password
redis_kwargs.update(kwargs)
self.redis_client = get_redis_client(**redis_kwargs)
def set_cache(self, key, value, **kwargs):
ttl = kwargs.get("ttl", None)
print_verbose(f"Set Redis Cache: key: {key}\nValue {value}")
try:
self.redis_client.set(name=key, value=str(value), ex=ttl)
except Exception as e:
@ -84,8 +98,9 @@ class RedisCache(BaseCache):
def get_cache(self, key, **kwargs):
try:
# TODO convert this to a ModelResponse object
print_verbose(f"Get Redis Cache: key: {key}")
cached_response = self.redis_client.get(key)
print_verbose(f"Got Redis Cache: key: {key}, cached_response {cached_response}")
if cached_response != None:
# cached_response is in `b{} convert it to ModelResponse
cached_response = cached_response.decode("utf-8") # Convert bytes to string
@ -93,8 +108,6 @@ class RedisCache(BaseCache):
cached_response = json.loads(cached_response) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response)
if isinstance(cached_response, dict):
cached_response['cache'] = True # set cache-hit flag to True
return cached_response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
@ -168,7 +181,8 @@ class Cache:
type="local",
host=None,
port=None,
password=None
password=None,
**kwargs
):
"""
Initializes the cache based on the given type.
@ -178,6 +192,7 @@ class Cache:
host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis".
**kwargs: Additional keyword arguments for redis.Redis() cache
Raises:
ValueError: If an invalid cache type is provided.
@ -186,13 +201,15 @@ class Cache:
None
"""
if type == "redis":
self.cache = RedisCache(host, port, password)
self.cache = RedisCache(host, port, password, **kwargs)
if type == "local":
self.cache = InMemoryCache()
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback:
litellm.success_callback.append("cache")
if "cache" not in litellm._async_success_callback:
litellm._async_success_callback.append("cache")
def get_cache_key(self, *args, **kwargs):
"""
@ -205,16 +222,37 @@ class Cache:
Returns:
str: The cache key generated from the arguments, or None if no cache key could be generated.
"""
cache_key =""
for param in kwargs:
cache_key = ""
print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
# for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
print_verbose(f"\nReturning preset cache key: {cache_key}")
return kwargs.get("litellm_params", {}).get("preset_cache_key", None)
# sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]
for param in completion_kwargs:
# ignore litellm params here
if param in set(["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]):
if param in kwargs:
# check if param == model and model_group is passed in, then override model with model_group
if param == "model" and kwargs.get("metadata", None) is not None and kwargs["metadata"].get("model_group", None) is not None:
param_value = kwargs["metadata"].get("model_group", None) # for litellm.Router use model_group for caching over `model`
if param == "model":
model_group = None
metadata = kwargs.get("metadata", None)
litellm_params = kwargs.get("litellm_params", {})
if metadata is not None:
model_group = metadata.get("model_group")
if litellm_params is not None:
metadata = litellm_params.get("metadata", None)
if metadata is not None:
model_group = metadata.get("model_group", None)
param_value = model_group or kwargs[param] # use model_group if it exists, else use kwargs["model"]
else:
if kwargs[param] is None:
continue # ignore None params
param_value = kwargs[param]
cache_key+= f"{str(param)}: {str(param_value)}"
print_verbose(f"\nCreated cache key: {cache_key}")
return cache_key
def generate_streaming_content(self, content):
@ -241,9 +279,6 @@ class Cache:
cache_key = self.get_cache_key(*args, **kwargs)
if cache_key is not None:
cached_result = self.cache.get_cache(cache_key)
if cached_result != None and 'stream' in kwargs and kwargs['stream'] == True:
# if streaming is true and we got a cache hit, return a generator
return self.generate_streaming_content(cached_result["choices"][0]['message']['content'])
return cached_result
except Exception as e:
logging.debug(f"An exception occurred: {traceback.format_exc()}")

View file

@ -8,7 +8,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
class CustomLogger:
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes
def __init__(self):
pass
@ -27,9 +27,20 @@ class CustomLogger:
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
pass
async def async_log_pre_api_call(self, model, messages, kwargs):
pass
#### DEPRECATED ####
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
try:
@ -46,6 +57,22 @@ class CustomLogger:
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
    """Await a user-supplied async callback before the API call is made.

    ``kwargs`` is updated in place with ``model``, ``messages`` and
    ``log_event_type="pre_api_call"`` and then passed to ``callback_func``.

    Args:
        model: Stored under ``kwargs["model"]``.
        messages: Stored under ``kwargs["messages"]``.
        kwargs: Call-details dict; mutated and handed to the callback.
        print_verbose: Callable used for debug output.
        callback_func: Async callable invoked as ``callback_func(kwargs)``.

    Ordinary errors raised by the callback are printed and reported through
    ``print_verbose`` instead of being raised.
    """
    try:
        kwargs["model"] = model
        kwargs["messages"] = messages
        kwargs["log_event_type"] = "pre_api_call"
        await callback_func(
            kwargs,
        )
        print_verbose(
            f"Custom Logger - model call details: {kwargs}"
        )
    except Exception:
        # `except Exception` rather than a bare `except`, so that
        # asyncio.CancelledError (a BaseException) raised at the await
        # still cancels the surrounding task.
        traceback.print_exc()
        print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
# Method definition
try:
@ -63,3 +90,21 @@ class CustomLogger:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass
async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
    """Await a user-supplied async callback after the API call has finished.

    ``kwargs["log_event_type"]`` is set to ``"post_api_call"`` and the
    callback is invoked as
    ``callback_func(kwargs, response_obj, start_time, end_time)``.

    Args:
        kwargs: Call-details dict; mutated and handed to the callback.
        response_obj: Forwarded to the callback unchanged.
        start_time: Forwarded to the callback unchanged.
        end_time: Forwarded to the callback unchanged.
        print_verbose: Callable used for debug output.
        callback_func: Async callable receiving the four values above.

    Ordinary errors raised by the callback are reported through
    ``print_verbose`` instead of being raised.
    """
    try:
        kwargs["log_event_type"] = "post_api_call"
        await callback_func(
            kwargs,  # kwargs to func
            response_obj,
            start_time,
            end_time,
        )
        print_verbose(
            f"Custom Logger - final response object: {response_obj}"
        )
    except Exception:
        # `except Exception` rather than a bare `except`, so that
        # asyncio.CancelledError (a BaseException) raised at the await
        # still cancels the surrounding task.
        print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

View file

@ -37,7 +37,7 @@ class LangFuseLogger:
f"Langfuse Logging - Enters logging function for model {kwargs}"
)
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
prompt = [kwargs.get('messages')]
optional_params = kwargs.get("optional_params", {})
@ -70,6 +70,6 @@ class LangFuseLogger:
f"Langfuse Layer Logging - final response object: {response_obj}"
)
except:
# traceback.print_exc()
traceback.print_exc()
print_verbose(f"Langfuse Layer Error - {traceback.format_exc()}")
pass

View file

@ -196,8 +196,19 @@ class AzureChatCompletion(BaseLLM):
else:
azure_client = client
response = azure_client.chat.completions.create(**data) # type: ignore
response.model = "azure/" + str(response.model)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=stringified_response,
additional_args={
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
@ -318,7 +329,10 @@ class AzureChatCompletion(BaseLLM):
data: dict,
model_response: ModelResponse,
azure_client_params: dict,
api_key: str,
input: list,
client=None,
logging_obj=None
):
response = None
try:
@ -327,8 +341,23 @@ class AzureChatCompletion(BaseLLM):
else:
openai_aclient = client
response = await openai_aclient.embeddings.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
def embedding(self,
@ -372,13 +401,7 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if aembedding == True:
response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
return response
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
## LOGGING
logging_obj.pre_call(
input=input,
@ -391,6 +414,14 @@ class AzureChatCompletion(BaseLLM):
}
},
)
if aembedding == True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
return response
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
## COMPLETION CALL
response = azure_client.embeddings.create(**data) # type: ignore
## LOGGING

View file

@ -2,7 +2,7 @@ import json, copy, types
import os
from enum import Enum
import time
from typing import Callable, Optional
from typing import Callable, Optional, Any
import litellm
from litellm.utils import ModelResponse, get_secret, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
@ -205,15 +205,25 @@ class AmazonLlamaConfig():
def init_bedrock_client(
region_name = None,
aws_access_key_id = None,
aws_secret_access_key = None,
aws_region_name=None,
aws_bedrock_runtime_endpoint=None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] =None,
aws_bedrock_runtime_endpoint: Optional[str]=None,
):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME")
standard_aws_region_name = get_secret("AWS_REGION")
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
standard_aws_region_name = get_secret("AWS_REGION", None)
## CHECK IS 'os.environ/' passed in
# Define the list of parameters to check
params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith('os.environ/'):
params_to_check[i] = get_secret(param)
# Assign updated values back to parameters
aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check
if region_name:
pass
elif aws_region_name:
@ -472,7 +482,7 @@ def completion(
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response_body,
original_response=json.dumps(response_body),
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response}")
@ -533,37 +543,59 @@ def completion(
def _embedding_func_single(
model: str,
input: str,
client: Any,
optional_params=None,
encoding=None,
logging_obj=None,
):
# logic for parsing in - calling - parsing out model embedding calls
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop(
"aws_bedrock_client",
# only pass variables that are not None
init_bedrock_client(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
),
## FORMAT EMBEDDING INPUT ##
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
if provider == "amazon":
input = input.replace(os.linesep, " ")
data = {"inputText": input, **inference_params}
# data = json.dumps(data)
elif provider == "cohere":
inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
data = {"texts": [input], **inference_params} # type: ignore
body = json.dumps(data).encode("utf-8")
## LOGGING
request_str = f"""
response = client.invoke_model(
body={body},
modelId={model},
accept="*/*",
contentType="application/json",
)""" # type: ignore
logging_obj.pre_call(
input=input,
api_key="", # boto3 is used for init.
additional_args={"complete_input_dict": {"model": model,
"texts": input}, "request_str": request_str},
)
input = input.replace(os.linesep, " ")
body = json.dumps({"inputText": input})
try:
response = client.invoke_model(
body=body,
modelId=model,
accept="application/json",
accept="*/*",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
return response_body.get("embedding")
## LOGGING
logging_obj.post_call(
input=input,
api_key="",
additional_args={"complete_input_dict": data},
original_response=json.dumps(response_body),
)
if provider == "cohere":
response = response_body.get("embeddings")
# flatten list
response = [item for sublist in response for item in sublist]
return response
elif provider == "amazon":
return response_body.get("embedding")
except Exception as e:
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
@ -576,17 +608,21 @@ def embedding(
optional_params=None,
encoding=None,
):
### BOTO3 INIT ###
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": {"model": model,
"texts": input}},
)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
## Embedding Call
embeddings = [_embedding_func_single(model, i, optional_params) for i in input]
embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls
## Populate OpenAI compliant dictionary
@ -614,14 +650,5 @@ def embedding(
total_tokens=input_tokens + 0
)
model_response.usage = usage
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": {"model": model,
"texts": input}},
original_response=embeddings,
)
return model_response

View file

@ -170,6 +170,11 @@ class Huggingface(BaseLLM):
"content"
] = completion_response["generated_text"] # type: ignore
elif task == "text-generation-inference":
if (not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]):
raise HuggingfaceError(status_code=422, message=f"response is not in expected format - {completion_response}")
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"

View file

@ -1,10 +1,9 @@
import requests, types
import requests, types, time
import json
import traceback
from typing import Optional
import litellm
import httpx
import httpx, aiohttp, asyncio
try:
from async_generator import async_generator, yield_ # optional dependency
async_generator_imported = True
@ -115,6 +114,9 @@ def get_ollama_response_stream(
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
acompletion: bool = False,
model_response=None,
encoding=None
):
if api_base.endswith("/api/generate"):
url = api_base
@ -136,8 +138,15 @@ def get_ollama_response_stream(
logging_obj.pre_call(
input=None,
api_key=None,
additional_args={"api_base": url, "complete_input_dict": data},
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
)
if acompletion is True:
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
return response
else:
return ollama_completion_stream(url=url, data=data)
def ollama_completion_stream(url, data):
session = requests.Session()
with session.post(url, json=data, stream=True) as resp:
@ -169,6 +178,52 @@ def get_ollama_response_stream(
traceback.print_exc()
session.close()
async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
    """POST ``data`` to ``url`` and fold the streamed reply into ``model_response``.

    The response body is read in chunks; each chunk is decoded as UTF-8,
    split on newlines, and every non-blank piece is parsed as JSON.  The
    ``"response"`` values of the parsed objects are concatenated into the
    completion text.

    Args:
        url: Endpoint the request is posted to.
        data: JSON payload; ``data['model']`` and ``data['prompt']`` are read
            when filling in the response.
        model_response: Response object populated in place and returned.
        encoding: Tokenizer whose ``encode`` is used to count tokens.
        logging_obj: Accepted for interface compatibility; not used here.

    Returns:
        ``model_response`` on success.  Any exception (including the
        ``OllamaError`` raised for a non-200 status) is printed and the
        coroutine then returns ``None``.
    """
    # NOTE(review): errors are printed and swallowed, so callers receive None
    # on failure - confirm whether they should be re-raised instead.
    try:
        timeout = aiohttp.ClientTimeout(total=600)  # 10 minutes
        async with aiohttp.ClientSession(timeout=timeout) as session:
            resp = await session.post(url, json=data)

            if resp.status != 200:
                text = await resp.text()
                raise OllamaError(status_code=resp.status, message=text)

            # Initialised once, before the read loop, so text from every
            # network chunk is kept (it used to be reset per chunk, which
            # dropped everything but the last chunk).
            completion_string = ""
            async for line in resp.content.iter_any():
                if not line:
                    continue
                try:
                    for chunk in line.decode("utf-8").split("\n"):
                        if chunk.strip() == "":
                            continue
                        j = json.loads(chunk)
                        # Objects without a "response" key (error objects or
                        # bookkeeping messages) contribute no text; they must
                        # not re-append the previous chunk's content.
                        if "response" in j:
                            completion_string += j["response"]
                except Exception:
                    # A malformed piece is reported and the stream continues.
                    traceback.print_exc()

            ## RESPONSE OBJECT
            model_response["choices"][0]["finish_reason"] = "stop"
            model_response["choices"][0]["message"]["content"] = completion_string
            model_response["created"] = int(time.time())
            model_response["model"] = "ollama/" + data['model']
            prompt_tokens = len(encoding.encode(data['prompt']))  # type: ignore
            completion_tokens = len(encoding.encode(completion_string))
            model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
            return model_response
    except Exception:
        traceback.print_exc()
if async_generator_imported:
# ollama implementation
@async_generator

View file

@ -195,23 +195,23 @@ class OpenAIChatCompletion(BaseLLM):
**optional_params
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
)
try:
max_retries = data.pop("max_retries", 2)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
return self.async_streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
else:
return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
return self.acompletion(data=data, headers=headers, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
return self.streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
else:
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
if client is None:
@ -219,13 +219,14 @@ class OpenAIChatCompletion(BaseLLM):
else:
openai_client = client
response = openai_client.chat.completions.create(**data) # type: ignore
stringified_response = response.model_dump_json()
logging_obj.post_call(
input=None,
input=messages,
api_key=api_key,
original_response=response,
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
@ -259,6 +260,8 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str]=None,
client=None,
max_retries=None,
logging_obj=None,
headers=None
):
response = None
try:
@ -266,8 +269,21 @@ class OpenAIChatCompletion(BaseLLM):
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
else:
openai_aclient = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
)
response = await openai_aclient.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
stringified_response = response.model_dump_json()
logging_obj.post_call(
input=data['messages'],
api_key=api_key,
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except Exception as e:
if response and hasattr(response, "text"):
raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
@ -285,12 +301,19 @@ class OpenAIChatCompletion(BaseLLM):
api_key: Optional[str]=None,
api_base: Optional[str]=None,
client = None,
max_retries=None
max_retries=None,
headers=None
):
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else:
openai_client = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": False, "complete_input_dict": data},
)
response = openai_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
return streamwrapper
@ -304,6 +327,7 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str]=None,
client=None,
max_retries=None,
headers=None
):
response = None
try:
@ -311,6 +335,13 @@ class OpenAIChatCompletion(BaseLLM):
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
else:
openai_aclient = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
)
response = await openai_aclient.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
@ -325,6 +356,7 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(status_code=500, message=f"{str(e)}")
async def aembedding(
self,
input: list,
data: dict,
model_response: ModelResponse,
timeout: float,
@ -332,6 +364,7 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str]=None,
client=None,
max_retries=None,
logging_obj=None
):
response = None
try:
@ -340,9 +373,24 @@ class OpenAIChatCompletion(BaseLLM):
else:
openai_aclient = client
response = await openai_aclient.embeddings.create(**data) # type: ignore
return response
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
)
raise e
def embedding(self,
model: str,
input: list,
@ -367,13 +415,7 @@ class OpenAIChatCompletion(BaseLLM):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
if aembedding == True:
response = self.aembedding(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else:
openai_client = client
## LOGGING
logging_obj.pre_call(
input=input,
@ -381,6 +423,14 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data, "api_base": api_base},
)
if aembedding == True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else:
openai_client = client
## COMPLETION CALL
response = openai_client.embeddings.create(**data) # type: ignore
## LOGGING

View file

@ -2,7 +2,7 @@ from enum import Enum
import requests, traceback
import json
from jinja2 import Template, exceptions, Environment, meta
from typing import Optional
from typing import Optional, Any
def default_pt(messages):
    """Build a plain prompt by joining every message's "content" with single spaces."""
    contents = [entry["content"] for entry in messages]
    return " ".join(contents)
@ -74,7 +74,7 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
messages=messages
)
else:
prompt = "".join(m["content"] for m in messages)
prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages)
return prompt
def mistral_instruct_pt(messages):
@ -159,26 +159,27 @@ def phind_codellama_pt(messages):
prompt += "### Assistant\n" + message["content"] + "\n\n"
return prompt
def hf_chat_template(model: str, messages: list):
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
## get the tokenizer config from huggingface
def _get_tokenizer_config(hf_model_name):
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
# Make a GET request to fetch the JSON data
response = requests.get(url)
if response.status_code == 200:
# Parse the JSON data
tokenizer_config = json.loads(response.content)
return {"status": "success", "tokenizer": tokenizer_config}
else:
return {"status": "failure"}
tokenizer_config = _get_tokenizer_config(model)
if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
raise Exception("No chat template found")
## read the bos token, eos token and chat template from the json
tokenizer_config = tokenizer_config["tokenizer"]
bos_token = tokenizer_config["bos_token"]
eos_token = tokenizer_config["eos_token"]
chat_template = tokenizer_config["chat_template"]
if chat_template is None:
def _get_tokenizer_config(hf_model_name):
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
# Make a GET request to fetch the JSON data
response = requests.get(url)
if response.status_code == 200:
# Parse the JSON data
tokenizer_config = json.loads(response.content)
return {"status": "success", "tokenizer": tokenizer_config}
else:
return {"status": "failure"}
tokenizer_config = _get_tokenizer_config(model)
if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
raise Exception("No chat template found")
## read the bos token, eos token and chat template from the json
tokenizer_config = tokenizer_config["tokenizer"]
bos_token = tokenizer_config["bos_token"]
eos_token = tokenizer_config["eos_token"]
chat_template = tokenizer_config["chat_template"]
def raise_exception(message):
raise Exception(f"Error message - {message}")
@ -231,12 +232,20 @@ def hf_chat_template(model: str, messages: list):
# Anthropic template
def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
"""
Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human:
- you can't just pass a system message
- you can't pass a system message and follow that with an assistant message
if system message is passed in, you can only do system, human, assistant or system, human
if a system message is passed in and followed by an assistant message, insert a blank human message between them.
"""
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
prompt = ""
for idx, message in enumerate(messages): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
for idx, message in enumerate(messages):
if message["role"] == "user":
prompt += (
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
@ -245,15 +254,44 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
prompt += (
f"{message['content']}"
)
else:
elif message["role"] == "assistant":
if idx > 0 and messages[idx - 1]["role"] == "system":
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message
prompt += (
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
)
if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: `
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt
prompt += f"{AnthropicConstants.AI_PROMPT.value}"
prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
return prompt
### TOGETHER AI
def get_model_info(token, model):
    """Look up a Together AI model's prompt settings.

    Calls ``https://api.together.xyz/models/info`` with ``token`` as a bearer
    token and searches the returned list for an entry whose ``name`` equals
    ``model`` (compared case-insensitively, ignoring surrounding whitespace).

    Args:
        token: API key sent in the ``Authorization`` header.
        model: Model name to search for.

    Returns:
        A ``(prompt_format, chat_template)`` tuple read from the matching
        entry's ``config``; either element is ``None`` when absent.
        ``(None, None)`` when the request does not return HTTP 200 or no
        entry matches.
    """
    headers = {
        'Authorization': f'Bearer {token}'
    }
    response = requests.get('https://api.together.xyz/models/info', headers=headers)
    if response.status_code != 200:
        return None, None

    # Normalise both sides of the comparison so a mixed-case `model` still matches.
    wanted = model.lower().strip()
    for m in response.json():
        if m["name"].lower().strip() == wanted:
            # An entry without a config block simply has no prompt settings.
            config = m.get('config') or {}
            return config.get('prompt_format', None), config.get('chat_template', None)
    return None, None
def format_prompt_togetherai(messages, prompt_format, chat_template):
    """Render ``messages`` into a single prompt string for a Together AI model.

    Precedence:
        1. ``chat_template`` is given -> rendered through ``hf_chat_template``.
        2. ``prompt_format`` is given -> the text before ``{prompt}`` becomes the
           initial prompt value and the text after it the final prompt value
           passed to ``custom_prompt``.
        3. neither is given -> ``default_pt``.

    Args:
        messages: list of chat message dicts.
        prompt_format: format string containing a ``{prompt}`` placeholder, or None.
        chat_template: chat template, or None.

    Returns:
        The formatted prompt string.
    """
    if chat_template is not None:
        return hf_chat_template(model=None, messages=messages, chat_template=chat_template)
    if prompt_format is not None:
        # Split only once we know prompt_format is a string. partition() also
        # copes with a missing or repeated '{prompt}' placeholder, where a
        # two-way unpack of split() would raise.
        human_prompt, _, assistant_prompt = prompt_format.partition('{prompt}')
        return custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt)
    return default_pt(messages)
###
def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
@ -320,7 +358,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
prompt += final_prompt_value
return prompt
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
original_model_name = model
model = model.lower()
if custom_llm_provider == "ollama":
@ -330,7 +368,9 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
return claude_2_1_pt(messages=messages)
else:
return anthropic_pt(messages=messages)
elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
try:
if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages)
@ -342,7 +382,7 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
elif "mosaicml/mpt" in model:
if "chat" in model:
return mpt_chat_pt(messages=messages)
elif "codellama/codellama" in model:
elif "codellama/codellama" in model or "togethercomputer/codellama" in model:
if "instruct" in model:
return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
elif "wizardlm/wizardcoder" in model:
@ -357,4 +397,4 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
return hf_chat_template(original_model_name, messages)
except:
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)

View file

@ -100,7 +100,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
logging_obj.pre_call(
input=input_data["prompt"],
api_key="",
additional_args={"complete_input_dict": initial_prediction_data, "headers": headers},
additional_args={"complete_input_dict": initial_prediction_data, "headers": headers, "api_base": base_url},
)
response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
@ -169,6 +169,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")
# Function to extract version ID from model string
def model_to_version_id(model):
@ -194,41 +195,48 @@ def completion(
):
# Start a prediction and get the prediction URL
version_id = model_to_version_id(model)
## Load Config
config = litellm.ReplicateConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
system_prompt = None
if optional_params is not None and "supports_system_prompt" in optional_params:
supports_sys_prompt = optional_params.pop("supports_system_prompt")
else:
supports_sys_prompt = False
if supports_sys_prompt:
for i in range(len(messages)):
if messages[i]["role"] == "system":
first_sys_message = messages.pop(i)
system_prompt = first_sys_message["content"]
break
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
if "meta/llama-2-13b-chat" in model:
system_prompt = ""
prompt = ""
for message in messages:
if message["role"] == "system":
system_prompt = message["content"]
else:
prompt += message["content"]
# If system prompt is supported, and a system prompt is provided, use it
if system_prompt is not None:
input_data = {
"system_prompt": system_prompt,
"prompt": prompt,
"system_prompt": system_prompt,
**optional_params
}
# Otherwise, use the prompt as is
else:
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
input_data = {
"prompt": prompt,
**optional_params

View file

@ -5,10 +5,11 @@ import requests
import time
from typing import Callable, Optional
import litellm
from litellm.utils import ModelResponse, get_secret, Usage
from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
import sys
from copy import deepcopy
import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
class SagemakerError(Exception):
def __init__(self, status_code, message):
@ -61,6 +62,8 @@ def completion(
print_verbose: Callable,
encoding,
logging_obj,
custom_prompt_dict={},
hf_model_name=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
@ -107,19 +110,24 @@ def completion(
inference_params[k] = v
model = model
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += (
f"{message['content']}"
)
else:
prompt += (
f"{message['content']}"
)
else:
prompt += f"{message['content']}"
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages
)
else:
if hf_model_name is None:
if "llama-2" in model.lower(): # llama-2 model
if "chat" in model.lower(): # apply llama2 chat template
hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
else: # apply regular llama2 template
hf_model_name = "meta-llama/Llama-2-7b"
hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
prompt = prompt_factory(model=hf_model_name, messages=messages)
data = json.dumps({
"inputs": prompt,
@ -138,15 +146,18 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name},
)
## COMPLETION CALL
response = client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
try:
response = client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
response = response["Body"].read().decode("utf8")
## LOGGING
logging_obj.post_call(
@ -185,6 +196,133 @@ def completion(
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
def embedding(model: str,
              input: list,
              model_response: EmbeddingResponse,
              print_verbose: Callable,
              encoding,
              logging_obj,
              custom_prompt_dict={},
              optional_params=None,
              litellm_params=None,
              logger_fn=None):
    """
    Supports Huggingface Jumpstart embeddings like GPT-6B

    Sends ``input`` to the SageMaker endpoint named ``model`` as
    ``{"text_inputs": input}`` and copies the endpoint's ``"embedding"`` list
    into ``model_response`` as a list of embedding objects.

    Args:
        model: SageMaker endpoint name.
        input: list of texts to embed.
        model_response: response object that is filled in and returned.
        print_verbose: callable used for debug output.
        encoding: tokenizer exposing ``encode(text)``; used to count prompt tokens.
        logging_obj: logger exposing ``pre_call`` / ``post_call``.
        custom_prompt_dict: currently unused.
        optional_params: dict that may carry ``aws_secret_access_key``,
            ``aws_access_key_id`` and ``aws_region_name``; these keys are
            popped from the dict. No other entry is forwarded to the endpoint.
        litellm_params: currently unused.
        logger_fn: currently unused.

    Returns:
        ``model_response`` with ``object``, ``data``, ``model`` and ``usage`` set.

    Raises:
        SagemakerError: status 500 when ``invoke_endpoint`` fails or the
            response has no ``"embedding"`` key; status 422 when
            ``"embedding"`` is not a list.
    """
    ### BOTO3 INIT
    import boto3

    # optional_params defaults to None; normalise so the pops below cannot fail.
    if optional_params is None:
        optional_params = {}

    # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
    aws_region_name = optional_params.pop("aws_region_name", None)

    if aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
        client = boto3.client(
            service_name="sagemaker-runtime",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=aws_region_name,
        )
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables
        # we need to read region name from env
        region_name = (
            get_secret("AWS_REGION_NAME") or
            "us-west-2"  # default to us-west-2 if user not specified
        )
        client = boto3.client(
            service_name="sagemaker-runtime",
            region_name=region_name,
        )

    #### HF EMBEDDING LOGIC
    data = json.dumps({
        "text_inputs": input
    }).encode('utf-8')

    ## LOGGING
    request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)""" # type: ignore
    logging_obj.pre_call(
        input=input,
        api_key="",
        additional_args={"complete_input_dict": data, "request_str": request_str},
    )

    ## EMBEDDING CALL
    try:
        response = client.invoke_endpoint(
            EndpointName=model,
            ContentType="application/json",
            Body=data,
            CustomAttributes="accept_eula=true",
        )
    except Exception as e:
        raise SagemakerError(status_code=500, message=f"{str(e)}") from e

    # Decode the body first so the post-call log receives the parsed payload
    # exactly once, rather than also logging the raw, unread boto3 response.
    response = json.loads(response["Body"].read().decode("utf8"))
    ## LOGGING
    logging_obj.post_call(
        input=input,
        api_key="",
        original_response=response,
        additional_args={"complete_input_dict": data},
    )
    print_verbose(f"raw model_response: {response}")

    if "embedding" not in response:
        raise SagemakerError(status_code=500, message="embedding not found in response")
    embeddings = response['embedding']
    if not isinstance(embeddings, list):
        raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}")

    model_response["object"] = "list"
    model_response["data"] = [
        {
            "object": "embedding",
            "index": idx,
            "embedding": vector
        }
        for idx, vector in enumerate(embeddings)
    ]
    model_response["model"] = model

    # Embedding calls produce no completion tokens, so total == prompt tokens.
    input_tokens = sum(len(encoding.encode(text)) for text in input)
    model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
    return model_response

View file

@ -115,7 +115,7 @@ def completion(
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list
data = {
"model": model,

View file

@ -93,45 +93,58 @@ def completion(
prompt = " ".join([message["content"] for message in messages])
mode = ""
request_str = ""
if model in litellm.vertex_chat_models:
chat_model = ChatModel.from_pretrained(model)
mode = "chat"
request_str += f"chat_model = ChatModel.from_pretrained({model})\n"
elif model in litellm.vertex_text_models:
text_model = TextGenerationModel.from_pretrained(model)
mode = "text"
request_str += f"text_model = TextGenerationModel.from_pretrained({model})\n"
elif model in litellm.vertex_code_text_models:
text_model = CodeGenerationModel.from_pretrained(model)
mode = "text"
request_str += f"text_model = CodeGenerationModel.from_pretrained({model})\n"
else: # vertex_code_chat_models
chat_model = CodeChatModel.from_pretrained(model)
mode = "chat"
request_str += f"chat_model = CodeChatModel.from_pretrained({model})\n"
if mode == "chat":
chat = chat_model.start_chat()
request_str+= f"chat = chat_model.start_chat()\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params})
if "stream" in optional_params and optional_params["stream"] == True:
# NOTE: VertexAI does not accept stream=True as a param and raises an error,
# we handle this by removing 'stream' from optional params and sending the request
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = chat.send_message_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
completion_response = chat.send_message(prompt, **optional_params).text
elif mode == "text":
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None)
if "stream" in optional_params and optional_params["stream"] == True:
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
request_str += f"text_model.predict_streaming({prompt}, **{optional_params})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = text_model.predict_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response
request_str += f"text_model.predict({prompt}, **{optional_params}).text\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
completion_response = text_model.predict(prompt, **optional_params).text
## LOGGING

View file

@ -99,17 +99,18 @@ class Chat():
# Chat.__init__: keep the params dict on the instance and build the
# Completions helper from it, exposed as `self.completions`.
def __init__(self, params):
self.params = params
self.completions = Completions(self.params)  # Completions receives the same dict object, not a copy
class Completions():
    """Thin wrapper that forwards ``create(...)`` calls to ``completion`` using stored params."""

    def __init__(self, params):
        # Dict of keyword arguments forwarded to every `completion` call.
        self.params = params

    def create(self, messages, model=None, **kwargs):
        """Run a completion with the stored params plus any per-call overrides.

        Args:
            messages: chat messages passed to ``completion``.
            model: model name; falls back to ``self.params['model']`` when falsy.
            **kwargs: extra arguments; they are written into ``self.params``
                and therefore persist for later calls on this object.

        Returns:
            Whatever ``completion`` returns.
        """
        for k, v in kwargs.items():
            self.params[k] = v
        model = model or self.params.get('model')
        # `model` is passed explicitly below, so drop any copy stored in params;
        # otherwise the call fails with "got multiple values for keyword argument 'model'".
        call_params = {k: v for k, v in self.params.items() if k != 'model'}
        response = completion(model=model, messages=messages, **call_params)
        return response
@client
async def acompletion(*args, **kwargs):
@ -174,7 +175,8 @@ async def acompletion(*args, **kwargs):
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "text-completion-openai"
or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
or custom_llm_provider == "huggingface"
or custom_llm_provider == "ollama"): # currently implemented aiohttp calls for just azure and openai, soon all.
if kwargs.get("stream", False):
response = completion(*args, **kwargs)
else:
@ -318,7 +320,6 @@ def completion(
######### unpacking kwargs #####################
args = locals()
api_base = kwargs.get('api_base', None)
return_async = kwargs.get('return_async', False)
mock_response = kwargs.get('mock_response', None)
force_timeout= kwargs.get('force_timeout', 600) ## deprecated
logger_fn = kwargs.get('logger_fn', None)
@ -327,6 +328,8 @@ def completion(
litellm_logging_obj = kwargs.get('litellm_logging_obj', None)
id = kwargs.get('id', None)
metadata = kwargs.get('metadata', None)
model_info = kwargs.get('model_info', None)
proxy_server_request = kwargs.get('proxy_server_request', None)
fallbacks = kwargs.get('fallbacks', None)
headers = kwargs.get("headers", None)
num_retries = kwargs.get("num_retries", None) ## deprecated
@ -341,11 +344,14 @@ def completion(
final_prompt_value = kwargs.get("final_prompt_value", None)
bos_token = kwargs.get("bos_token", None)
eos_token = kwargs.get("eos_token", None)
preset_cache_key = kwargs.get("preset_cache_key", None)
hf_model_name = kwargs.get("hf_model_name", None)
### ASYNC CALLS ###
acompletion = kwargs.get("acompletion", False)
client = kwargs.get("client", None)
######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token"]
litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
@ -442,7 +448,7 @@ def completion(
# For logging - save the values of the litellm-specific params passed in
litellm_params = get_litellm_params(
return_async=return_async,
acompletion=acompletion,
api_key=api_key,
force_timeout=force_timeout,
logger_fn=logger_fn,
@ -452,7 +458,10 @@ def completion(
litellm_call_id=kwargs.get('litellm_call_id', None),
model_alias_map=litellm.model_alias_map,
completion_call_id=id,
metadata=metadata
metadata=metadata,
model_info=model_info,
proxy_server_request=proxy_server_request,
preset_cache_key=preset_cache_key
)
logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params)
if custom_llm_provider == "azure":
@ -516,17 +525,18 @@ def completion(
client=client # pass AsyncAzureOpenAI, AzureOpenAI client
)
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
)
elif (
model in litellm.open_ai_chat_completion_models
or custom_llm_provider == "custom_openai"
@ -596,13 +606,14 @@ def completion(
)
raise e
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={"headers": headers},
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={"headers": headers},
)
elif (
custom_llm_provider == "text-completion-openai"
or "ft:babbage-002" in model
@ -673,9 +684,14 @@ def completion(
logger_fn=logger_fn
)
# if "stream" in optional_params and optional_params["stream"] == True:
# response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
# return response
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=model_response,
additional_args={"headers": headers},
)
response = model_response
elif (
"replicate" in model or
@ -720,8 +736,16 @@ def completion(
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")
return response
model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=replicate_key,
original_response=model_response,
)
response = model_response
elif custom_llm_provider=="anthropic":
@ -741,7 +765,7 @@ def completion(
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = anthropic.completion(
response = anthropic.completion(
model=model,
messages=messages,
api_base=api_base,
@ -757,9 +781,16 @@ def completion(
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(model_response, model, custom_llm_provider="anthropic", logging_obj=logging)
return response
response = model_response
response = CustomStreamWrapper(response, model, custom_llm_provider="anthropic", logging_obj=logging)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
)
response = response
elif custom_llm_provider == "nlp_cloud":
nlp_cloud_key = (
api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key
@ -772,7 +803,7 @@ def completion(
or "https://api.nlpcloud.io/v1/gpu/"
)
model_response = nlp_cloud.completion(
response = nlp_cloud.completion(
model=model,
messages=messages,
api_base=api_base,
@ -788,9 +819,17 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(model_response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
return response
response = model_response
response = CustomStreamWrapper(response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
)
response = response
elif custom_llm_provider == "aleph_alpha":
aleph_alpha_key = (
api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key
@ -1166,6 +1205,8 @@ def completion(
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
custom_prompt_dict=custom_prompt_dict,
hf_model_name=hf_model_name,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging
@ -1190,7 +1231,7 @@ def completion(
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = bedrock.completion(
response = bedrock.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
@ -1208,16 +1249,24 @@ def completion(
# don't try to access stream object,
if "ai21" in model:
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="bedrock", logging_obj=logging
response, model, custom_llm_provider="bedrock", logging_obj=logging
)
else:
response = CustomStreamWrapper(
iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
iter(response), model, custom_llm_provider="bedrock", logging_obj=logging
)
return response
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=response,
)
## RESPONSE OBJECT
response = model_response
response = response
elif custom_llm_provider == "vllm":
model_response = vllm.completion(
model=model,
@ -1270,7 +1319,9 @@ def completion(
async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
return async_generator
generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
if acompletion is True:
return generator
if optional_params.get("stream", False) == True:
# assume all ollama responses are streamed
response = CustomStreamWrapper(
@ -1673,7 +1724,7 @@ def batch_completion_models_all_responses(*args, **kwargs):
return responses
### EMBEDDING ENDPOINTS ####################
@client
async def aembedding(*args, **kwargs):
"""
Asynchronously calls the `embedding` function with the given arguments and keyword arguments.
@ -1770,17 +1821,24 @@ def embedding(
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.pop("aembedding", None)
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
optional_params = {}
for param in kwargs:
if param != "metadata": # filter out metadata from optional_params
optional_params[param] = kwargs[param]
for param in non_default_params:
optional_params[param] = kwargs[param]
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
try:
response = None
logging = litellm_logging_obj
logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding, "preset_cache_key": None, "stream_response": {}})
if azure == True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
@ -1899,9 +1957,19 @@ def embedding(
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=kwargs,
optional_params=optional_params,
model_response= EmbeddingResponse()
)
elif custom_llm_provider == "sagemaker":
response = sagemaker.embedding(
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response= EmbeddingResponse(),
print_verbose=print_verbose
)
else:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")
@ -1910,7 +1978,7 @@ def embedding(
## LOGGING
logging.post_call(
input=input,
api_key=openai.api_key,
api_key=api_key,
original_response=str(e),
)
## Map to OpenAI Exception
@ -2123,8 +2191,11 @@ def moderation(input: str, api_key: Optional[str]=None):
####### HELPER FUNCTIONS ################
## Set verbose to true -> ```litellm.set_verbose = True```
def print_verbose(print_statement):
if litellm.set_verbose:
print(print_statement) # noqa
try:
if litellm.set_verbose:
print(print_statement) # noqa
except:
pass
def config_completion(**kwargs):
if litellm.config_path != None:

View file

@ -41,6 +41,20 @@
"litellm_provider": "openai",
"mode": "chat"
},
"gpt-4-1106-preview": {
"max_tokens": 128000,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"litellm_provider": "openai",
"mode": "chat"
},
"gpt-4-vision-preview": {
"max_tokens": 128000,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"litellm_provider": "openai",
"mode": "chat"
},
"gpt-3.5-turbo": {
"max_tokens": 4097,
"input_cost_per_token": 0.0000015,
@ -62,6 +76,13 @@
"litellm_provider": "openai",
"mode": "chat"
},
"gpt-3.5-turbo-1106": {
"max_tokens": 16385,
"input_cost_per_token": 0.0000010,
"output_cost_per_token": 0.0000020,
"litellm_provider": "openai",
"mode": "chat"
},
"gpt-3.5-turbo-16k": {
"max_tokens": 16385,
"input_cost_per_token": 0.000003,
@ -76,6 +97,62 @@
"litellm_provider": "openai",
"mode": "chat"
},
"ft:gpt-3.5-turbo": {
"max_tokens": 4097,
"input_cost_per_token": 0.000012,
"output_cost_per_token": 0.000016,
"litellm_provider": "openai",
"mode": "chat"
},
"text-embedding-ada-002": {
"max_tokens": 8191,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.000000,
"litellm_provider": "openai",
"mode": "embedding"
},
"azure/gpt-4-1106-preview": {
"max_tokens": 128000,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"litellm_provider": "azure",
"mode": "chat"
},
"azure/gpt-4-32k": {
"max_tokens": 8192,
"input_cost_per_token": 0.00006,
"output_cost_per_token": 0.00012,
"litellm_provider": "azure",
"mode": "chat"
},
"azure/gpt-4": {
"max_tokens": 16385,
"input_cost_per_token": 0.00003,
"output_cost_per_token": 0.00006,
"litellm_provider": "azure",
"mode": "chat"
},
"azure/gpt-3.5-turbo-16k": {
"max_tokens": 16385,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000004,
"litellm_provider": "azure",
"mode": "chat"
},
"azure/gpt-3.5-turbo": {
"max_tokens": 4097,
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000002,
"litellm_provider": "azure",
"mode": "chat"
},
"azure/text-embedding-ada-002": {
"max_tokens": 8191,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.000000,
"litellm_provider": "azure",
"mode": "embedding"
},
"text-davinci-003": {
"max_tokens": 4097,
"input_cost_per_token": 0.000002,
@ -127,6 +204,7 @@
},
"claude-instant-1": {
"max_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00000163,
"output_cost_per_token": 0.00000551,
"litellm_provider": "anthropic",
@ -134,15 +212,25 @@
},
"claude-instant-1.2": {
"max_tokens": 100000,
"input_cost_per_token": 0.00000163,
"output_cost_per_token": 0.00000551,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000000163,
"output_cost_per_token": 0.000000551,
"litellm_provider": "anthropic",
"mode": "chat"
},
"claude-2": {
"max_tokens": 100000,
"input_cost_per_token": 0.00001102,
"output_cost_per_token": 0.00003268,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "anthropic",
"mode": "chat"
},
"claude-2.1": {
"max_tokens": 200000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "anthropic",
"mode": "chat"
},
@ -227,9 +315,51 @@
"max_tokens": 32000,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "vertex_ai-chat-models",
"litellm_provider": "vertex_ai-code-chat-models",
"mode": "chat"
},
"palm/chat-bison": {
"max_tokens": 4096,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "chat"
},
"palm/chat-bison-001": {
"max_tokens": 4096,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "chat"
},
"palm/text-bison": {
"max_tokens": 8196,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "completion"
},
"palm/text-bison-001": {
"max_tokens": 8196,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "completion"
},
"palm/text-bison-safety-off": {
"max_tokens": 8196,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "completion"
},
"palm/text-bison-safety-recitation-off": {
"max_tokens": 8196,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "completion"
},
"command-nightly": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
@ -267,6 +397,8 @@
},
"replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000,
"output_cost_per_token": 0.0000,
"litellm_provider": "replicate",
"mode": "chat"
},
@ -293,6 +425,7 @@
},
"openrouter/anthropic/claude-instant-v1": {
"max_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00000163,
"output_cost_per_token": 0.00000551,
"litellm_provider": "openrouter",
@ -300,6 +433,7 @@
},
"openrouter/anthropic/claude-2": {
"max_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00001102,
"output_cost_per_token": 0.00003268,
"litellm_provider": "openrouter",
@ -496,20 +630,31 @@
},
"anthropic.claude-v1": {
"max_tokens": 100000,
"input_cost_per_token": 0.00001102,
"output_cost_per_token": 0.00003268,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "bedrock",
"mode": "chat"
},
"anthropic.claude-v2": {
"max_tokens": 100000,
"input_cost_per_token": 0.00001102,
"output_cost_per_token": 0.00003268,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "bedrock",
"mode": "chat"
},
"anthropic.claude-v2:1": {
"max_tokens": 200000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "bedrock",
"mode": "chat"
},
"anthropic.claude-instant-v1": {
"max_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00000163,
"output_cost_per_token": 0.00000551,
"litellm_provider": "bedrock",
@ -529,26 +674,80 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
"meta.llama2-70b-chat-v1": {
"max_tokens": 4096,
"input_cost_per_token": 0.00000195,
"output_cost_per_token": 0.00000256,
"litellm_provider": "bedrock",
"mode": "chat"
},
"sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "completion"
},
"sagemaker/meta-textgeneration-llama-2-7b-f": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "chat"
},
"sagemaker/meta-textgeneration-llama-2-13b": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "completion"
},
"sagemaker/meta-textgeneration-llama-2-13b-f": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "chat"
},
"sagemaker/meta-textgeneration-llama-2-70b": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "completion"
},
"sagemaker/meta-textgeneration-llama-2-70b-b-f": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "chat"
},
"together-ai-up-to-3b": {
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.0000001
"output_cost_per_token": 0.0000001,
"litellm_provider": "together_ai"
},
"together-ai-3.1b-7b": {
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002
"output_cost_per_token": 0.0000002,
"litellm_provider": "together_ai"
},
"together-ai-7.1b-20b": {
"max_tokens": 1000,
"input_cost_per_token": 0.0000004,
"output_cost_per_token": 0.0000004
"output_cost_per_token": 0.0000004,
"litellm_provider": "together_ai"
},
"together-ai-20.1b-40b": {
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000008,
"litellm_provider": "together_ai"
},
"together-ai-40.1b-70b": {
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000003
"input_cost_per_token": 0.0000009,
"output_cost_per_token": 0.0000009,
"litellm_provider": "together_ai"
},
"ollama/llama2": {
"max_tokens": 4096,
@ -578,10 +777,38 @@
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/mistral": {
"max_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/codellama": {
"max_tokens": 4096,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/orca-mini": {
"max_tokens": 4096,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/vicuna": {
"max_tokens": 2048,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "completion"
},
"deepinfra/meta-llama/Llama-2-70b-chat-hf": {
"max_tokens": 6144,
"input_cost_per_token": 0.000001875,
"output_cost_per_token": 0.000001875,
"max_tokens": 4096,
"input_cost_per_token": 0.000000700,
"output_cost_per_token": 0.000000950,
"litellm_provider": "deepinfra",
"mode": "chat"
},
@ -619,5 +846,103 @@
"output_cost_per_token": 0.00000095,
"litellm_provider": "deepinfra",
"mode": "chat"
},
"perplexity/pplx-7b-chat": {
"max_tokens": 8192,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/pplx-70b-chat": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/pplx-7b-online": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.0005,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/pplx-70b-online": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.0005,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-2-13b-chat": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-2-70b-chat": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/mistral-7b-instruct": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/replit-code-v1.5-3b": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
"max_tokens": 16384,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/HuggingFaceH4/zephyr-7b-beta": {
"max_tokens": 16384,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/meta-llama/Llama-2-7b-chat-hf": {
"max_tokens": 4096,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/meta-llama/Llama-2-13b-chat-hf": {
"max_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000025,
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/meta-llama/Llama-2-70b-chat-hf": {
"max_tokens": 4096,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/codellama/CodeLlama-34b-Instruct-hf": {
"max_tokens": 16384,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "anyscale",
"mode": "chat"
}
}

173
litellm/proxy/_types.py Normal file
View file

@ -0,0 +1,173 @@
from pydantic import BaseModel, Extra, Field, root_validator
from typing import Optional, List, Union, Dict, Literal
from datetime import datetime
import uuid, json
class LiteLLMBase(BaseModel):
    """
    Implements default functions, all pydantic objects should have.
    """

    def json(self, **kwargs):
        """
        Return the model's data via `model_dump()`, falling back to `dict()`.

        `**kwargs` is accepted so existing callers keep working, but it is not
        forwarded to either serializer.
        """
        try:
            return self.model_dump()  # noqa
        except Exception:
            # if using pydantic v1 (no `model_dump`). Catching `Exception`
            # rather than using a bare `except` keeps KeyboardInterrupt /
            # SystemExit from being swallowed by the fallback.
            return self.dict()
######### Request Class Definition ######
class ProxyChatCompletionRequest(LiteLLMBase):
    """
    Request body for a chat completion call made through the proxy.

    Only `model` and `messages` are required; every other declared field
    defaults to `None` (or `{}` for `metadata`). Undeclared fields are kept
    as well, because `Config.extra` is `'allow'`.
    """

    # --- required ---
    model: str
    messages: List[Dict[str, str]]

    # --- optional chat completion params ---
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    response_format: Optional[Dict[str, str]] = None
    seed: Optional[int] = None
    tools: Optional[List[str]] = None
    tool_choice: Optional[str] = None
    functions: Optional[List[str]] = None  # soon to be deprecated
    function_call: Optional[str] = None  # soon to be deprecated

    # --- Optional LiteLLM params ---
    caching: Optional[bool] = None
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    api_key: Optional[str] = None
    num_retries: Optional[int] = None
    context_window_fallback_dict: Optional[Dict[str, str]] = None
    fallbacks: Optional[List[str]] = None
    metadata: Optional[Dict[str, str]] = {}
    deployment_id: Optional[str] = None
    request_timeout: Optional[int] = None

    class Config:
        # allow params not defined here, these fall in litellm.completion(**kwargs)
        extra = 'allow'
class ModelInfoDelete(LiteLLMBase):
    """
    Carries the `id` of a model entry.

    NOTE(review): by its name this is the body of a delete request - confirm
    against the endpoint that consumes it.
    """

    id: Optional[str]
class ModelInfo(LiteLLMBase):
    """
    Descriptive metadata for one model entry.

    The pre-validator below runs before field validation: it generates a
    UUID string for a missing `id` and writes an explicit `None` for every
    other declared field that is missing or `None`. Extra fields are kept.
    """

    id: Optional[str]
    mode: Optional[Literal['embedding', 'chat', 'completion']]
    input_cost_per_token: Optional[float] = 0.0
    output_cost_per_token: Optional[float] = 0.0
    max_tokens: Optional[int] = 2048  # assume 2048 if not set

    # for azure models we need users to specify the base model, on azure you can call deployments - azure/my-random-model
    # we look up the base model in model_prices_and_context_window.json
    base_model: Optional[
        Literal[
            'gpt-4-1106-preview',
            'gpt-4-32k',
            'gpt-4',
            'gpt-3.5-turbo-16k',
            'gpt-3.5-turbo',
            'text-embedding-ada-002',
        ]
    ]

    class Config:
        extra = Extra.allow  # Allow extra fields
        protected_namespaces = ()

    @root_validator(pre=True)
    def set_model_info(cls, values):
        """Fill in `id` with a fresh UUID and the remaining fields with `None` when absent."""
        if values.get("id") is None:
            values["id"] = str(uuid.uuid4())

        # Same order as the field declarations above.
        nullable_fields = (
            "mode",
            "input_cost_per_token",
            "output_cost_per_token",
            "max_tokens",
            "base_model",
        )
        for field_name in nullable_fields:
            if values.get(field_name) is None:
                values[field_name] = None
        return values
class ModelParams(LiteLLMBase):
    """
    One model entry: its `model_name`, its `litellm_params` dict and its
    `model_info`. A missing or `None` `model_info` is replaced with a
    default-constructed `ModelInfo` before validation.
    """

    model_name: str
    litellm_params: dict
    model_info: ModelInfo

    class Config:
        protected_namespaces = ()

    @root_validator(pre=True)
    def set_model_info(cls, values):
        """Insert a default `ModelInfo` when the caller did not provide one."""
        if values.get("model_info") is None:
            values["model_info"] = ModelInfo()
        return values
class GenerateKeyRequest(LiteLLMBase):
    """
    Parameters accepted when requesting a new key; every field is optional
    and falls back to the default shown here.
    """

    duration: Optional[str] = "1h"
    models: Optional[list] = []
    aliases: Optional[dict] = {}
    config: Optional[dict] = {}
    spend: Optional[float] = 0
    user_id: Optional[str] = None
    max_parallel_requests: Optional[int] = None
class GenerateKeyResponse(LiteLLMBase):
    """
    Result of a key generation: the `key`, its `expires` timestamp and the
    `user_id` it is associated with. All three fields are required.
    """

    key: str
    expires: datetime
    user_id: str
class _DeleteKeyObject(LiteLLMBase):
    """Single element of `DeleteKeyRequest.keys`: wraps one required `key` string."""

    key: str
class DeleteKeyRequest(LiteLLMBase):
    """Request body holding the list of `_DeleteKeyObject` entries in `keys`."""

    keys: List[_DeleteKeyObject]
class UserAPIKeyAuth(LiteLLMBase):  # the expected response object for user api key auth
    """
    Return the row in the db

    Every field has a default, so an instance can be built from just an
    `api_key` (or from nothing at all).
    """

    api_key: Optional[str] = None
    models: list = []
    aliases: dict = {}
    config: dict = {}
    spend: Optional[float] = 0
    user_id: Optional[str] = None
    max_parallel_requests: Optional[int] = None
    duration: str = "1h"
class ConfigGeneralSettings(LiteLLMBase):
    """
    Documents all the fields supported by `general_settings` in config.yaml

    Each field carries its own `description`; all default to `None` except
    `health_check_interval`, which defaults to 300.
    """

    completion_model: Optional[str] = Field(
        None,
        description="proxy level default model for all chat completion calls",
    )
    use_azure_key_vault: Optional[bool] = Field(
        None,
        description="load keys from azure key vault",
    )
    master_key: Optional[str] = Field(
        None,
        description="require a key for all calls to proxy",
    )
    database_url: Optional[str] = Field(
        None,
        description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
    )
    otel: Optional[bool] = Field(
        None,
        description="[BETA] OpenTelemetry support - this might change, use with caution.",
    )
    custom_auth: Optional[str] = Field(
        None,
        description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth",
    )
    max_parallel_requests: Optional[int] = Field(
        None,
        description="maximum parallel requests for each api key",
    )
    infer_model_from_keys: Optional[bool] = Field(
        None,
        description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)",
    )
    background_health_checks: Optional[bool] = Field(
        None,
        description="run health checks in background",
    )
    health_check_interval: int = Field(
        300,
        description="background health check interval in seconds",
    )
class ConfigYAML(LiteLLMBase):
    """
    Documents all the fields supported by the config.yaml

    Three optional top-level sections: `model_list`, `litellm_settings` and
    `general_settings`.
    """

    model_list: Optional[List[ModelParams]] = Field(
        None,
        description="List of supported models on the server, with model-specific configs",
    )
    litellm_settings: Optional[dict] = Field(
        None,
        description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache",
    )
    general_settings: Optional[ConfigGeneralSettings] = None

    class Config:
        protected_namespaces = ()

View file

@ -0,0 +1,14 @@
from litellm.proxy._types import UserAPIKeyAuth
from fastapi import Request
from dotenv import load_dotenv
import os
load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    """
    Example custom auth hook: accept only the key `<PROXY_MASTER_KEY>-1234`.

    Args:
        request: the incoming request (not inspected here).
        api_key: the key presented by the caller.

    Returns:
        A `UserAPIKeyAuth` carrying the accepted `api_key`.

    Raises:
        Exception: when the key does not match, when `PROXY_MASTER_KEY` is
            not set, or when building the response object fails.
    """
    master_key = os.getenv("PROXY_MASTER_KEY")
    # Without this guard an unset variable would be formatted as the string
    # "None", making the guessable literal "None-1234" a valid key.
    if master_key is None or api_key != f"{master_key}-1234":
        raise Exception
    try:
        return UserAPIKeyAuth(api_key=api_key)
    except Exception:
        # Keep the original contract: every failure surfaces as a plain Exception.
        raise Exception

View file

@ -0,0 +1,59 @@
from litellm.integrations.custom_logger import CustomLogger
import litellm
import inspect
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
def print_verbose(print_statement):
    """Print `print_statement` only when `litellm.set_verbose` is truthy."""
    if not litellm.set_verbose:
        return
    print(print_statement)  # noqa
class MyCustomHandler(CustomLogger):
    """
    Example custom logger for the proxy.

    Every hook reports through `print_verbose` when it fires. The async
    success hook additionally computes the completion cost of the response
    and asserts that it is greater than zero.
    """

    def __init__(self):
        ansi_blue = "\033[94m"
        ansi_reset = "\033[0m"
        print_verbose(f"{ansi_blue}Initialized LiteLLM custom logger")
        try:
            print_verbose("Logger Initialized with following methods:")
            # Collect all names first so nothing is listed if introspection fails.
            bound_method_names = [
                attr_name
                for attr_name in dir(self)
                if inspect.ismethod(getattr(self, attr_name))
            ]
            # Pretty print_verbose the methods
            for attr_name in bound_method_names:
                print_verbose(f" - {attr_name}")
            print_verbose(f"{ansi_reset}")
        except:  # noqa: E722 - listing the methods is purely cosmetic
            pass

    def log_pre_api_call(self, model, messages, kwargs):
        print_verbose("Pre-API Call")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        print_verbose("Post-API Call")

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print_verbose("On Stream")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print_verbose("On Success!")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print_verbose("On Async Success!")
        cost = litellm.completion_cost(completion_response=response_obj)
        assert cost > 0.0
        return

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            print_verbose("On Async Failure !")
        except Exception as err:
            print_verbose(f"Exception: {err}")
proxy_handler_instance = MyCustomHandler()
# need to set litellm.callbacks = [customHandler] # on the proxy
# litellm.success_callback = [async_on_succes_logger]

View file

@ -4,14 +4,18 @@ model_list:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
azure_ad_token: eyJ0eXAiOiJ
api_key: os.environ/AZURE_API_KEY
tpm: 20_000
timeout: 5 # 5 second timeout
stream_timeout: 0.5 # 0.5 second timeout for streaming requests
max_retries: 4
- model_name: gpt-4-team2
litellm_params:
model: azure/gpt-4
api_key: sk-123
api_key: os.environ/AZURE_API_KEY
api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
- model_name: gpt-4-team3
litellm_params:
model: azure/gpt-4
api_key: sk-123
tpm: 100_000
timeout: 5 # 5 second timeout
stream_timeout: 0.5 # 0.5 second timeout for streaming requests
max_retries: 4

View file

@ -1,13 +0,0 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-35-1
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key: 73g
tpm: 80_000
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-35-2
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key: 9kj
tpm: 80_000

View file

@ -1,8 +0,0 @@
litellm_settings:
set_verbose: True
general_settings:
master_key: sk-hosted-litellm
use_queue: True
database_url: " # [OPTIONAL] use for token-based auth to proxy

View file

@ -1,11 +0,0 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
api_key:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2 # actual model name
api_key:
api_version: 2023-07-01-preview
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/

View file

@ -0,0 +1,4 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo

View file

@ -0,0 +1,117 @@
# This file runs a health check for the LLM, used on litellm/proxy
import asyncio
import random
from typing import Optional
import litellm
import logging
from litellm._logging import print_verbose
logger = logging.getLogger(__name__)
ILLEGAL_DISPLAY_PARAMS = [
"messages",
"api_key"
]
def _get_random_llm_message():
    """
    Build a one-element chat `messages` list whose user content is picked
    at random from two canned prompts.
    """
    candidate_prompts = [
        "Hey how's it going?",
        "What's 1 + 1?"
    ]
    chosen_prompt = random.choice(candidate_prompts)
    return [{"role": "user", "content": chosen_prompt}]
def _clean_litellm_params(litellm_params: dict):
    """
    Return a copy of `litellm_params` without any key listed in
    `ILLEGAL_DISPLAY_PARAMS`, so the result can be shown to users.
    """
    displayable = {}
    for param_name, param_value in litellm_params.items():
        if param_name in ILLEGAL_DISPLAY_PARAMS:
            continue
        displayable[param_name] = param_value
    return displayable
async def _perform_health_check(model_list: list):
    """
    Perform a health check for each model in the list.

    Each entry's `litellm_params` is copied, given a cleaned model name and a
    test payload, and sent to `litellm.aembedding` (when the entry's
    `model_info["mode"]` is "embedding") or `litellm.acompletion` otherwise.
    All probes run concurrently.

    The caller's `model_list` is never modified: the probe-only keys
    (`messages` / `input`) are written to per-model copies, so they cannot
    leak into configuration that is shared with the rest of the process.

    Returns:
        (healthy_endpoints, unhealthy_endpoints): two lists of the probed
        params with the `ILLEGAL_DISPLAY_PARAMS` keys removed.
    """

    async def _check_embedding_model(model_params: dict):
        # Embedding calls take `input`, not chat `messages`.
        model_params.pop("messages", None)
        model_params["input"] = ["test from litellm"]
        try:
            await litellm.aembedding(**model_params)
        except Exception as e:
            print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
            return False
        return True

    async def _check_model(model_params: dict):
        try:
            await litellm.acompletion(**model_params)
        except Exception as e:
            print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
            return False
        return True

    prepped_params = []
    tasks = []
    for model in model_list:
        # Shallow copy: only top-level keys are added/replaced below.
        litellm_params = dict(model["litellm_params"])
        # Tolerate an explicit `model_info: None` as well as a missing key.
        model_info = model.get("model_info") or {}

        litellm_params["model"] = litellm.utils.remove_model_id(litellm_params["model"])
        litellm_params["messages"] = _get_random_llm_message()
        prepped_params.append(litellm_params)

        if model_info.get("mode", None) == "embedding":
            # this is an embedding model
            tasks.append(_check_embedding_model(litellm_params))
        else:
            tasks.append(_check_model(litellm_params))

    results = await asyncio.gather(*tasks)

    healthy_endpoints = []
    unhealthy_endpoints = []
    # Report from the probed copies, which hold the cleaned model name.
    for is_healthy, probed_params in zip(results, prepped_params):
        cleaned_litellm_params = _clean_litellm_params(probed_params)
        if is_healthy:
            healthy_endpoints.append(cleaned_litellm_params)
        else:
            unhealthy_endpoints.append(cleaned_litellm_params)

    return healthy_endpoints, unhealthy_endpoints
async def perform_health_check(model_list: list, model: Optional[str] = None):
    """
    Run the health check over `model_list`, optionally for a single model.

    Args:
        model_list: entries each holding a `litellm_params` dict.
        model: when given, only entries whose `litellm_params["model"]`
            equals this value are checked.

    Returns:
        (healthy_endpoints, unhealthy_endpoints) - two lists. Both are empty
        when `model_list` is empty or `None`.
    """
    if not model_list:
        return [], []

    candidates = model_list
    if model is not None:
        candidates = [
            entry for entry in model_list
            if entry["litellm_params"]["model"] == model
        ]

    healthy, unhealthy = await _perform_health_check(candidates)
    return healthy, unhealthy

View file

@ -0,0 +1 @@
from . import *

View file

@ -0,0 +1,70 @@
from typing import Optional
import litellm
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
class MaxParallelRequestsHandler(CustomLogger):
    """
    Limits the number of in-flight requests per API key.

    A counter is kept in the supplied cache under `"<api_key>_request_count"`.
    `max_parallel_request_allow_request` increments it (or raises HTTP 429
    once the limit is reached); the success and failure hooks decrement it.
    """

    # Class variables or attributes
    def __init__(self):
        pass

    def print_verbose(self, print_statement):
        """Print only when `litellm.set_verbose` is exactly `True`."""
        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    @staticmethod
    def _decremented(current):
        """Return the counter value after releasing one slot (a missing or zero counter yields 0)."""
        # `int()` mirrors the conversion applied on the increment path.
        return int(current or 1) - 1

    async def max_parallel_request_allow_request(self, max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
        """
        Reserve a slot for `api_key`.

        Does nothing when `api_key` or `max_parallel_requests` is `None`.

        Raises:
            HTTPException: status 429 when the key already holds
                `max_parallel_requests` slots.
        """
        if api_key is None:
            return

        if max_parallel_requests is None:
            return

        self.user_api_key_cache = user_api_key_cache  # save the api key cache for updating the value

        # CHECK IF REQUEST ALLOWED
        request_count_api_key = f"{api_key}_request_count"
        current = user_api_key_cache.get_cache(key=request_count_api_key)
        self.print_verbose(f"current: {current}")
        if current is None:
            user_api_key_cache.set_cache(request_count_api_key, 1)
        elif int(current) < max_parallel_requests:
            # Increase count for this token
            user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
        else:
            raise HTTPException(status_code=429, detail="Max parallel request limit reached.")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Release the slot held by the request's `user_api_key` once the call
        has finished (for streams: once the complete response was collected).
        Errors are reported via `print_verbose` and never propagated.
        """
        try:
            self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
            if user_api_key is None:
                return

            # No slot was ever reserved through this handler.
            cache = getattr(self, "user_api_key_cache", None)
            if cache is None:
                return

            request_count_api_key = f"{user_api_key}_request_count"
            # check if it has collected an entire stream response
            self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}")
            # `.get` so that a missing "stream" key is treated as non-streaming
            # instead of raising and leaving the slot reserved forever.
            if "complete_streaming_response" in kwargs or kwargs.get("stream", False) != True:  # noqa: E712
                # Decrease count for this token
                new_val = self._decremented(cache.get_cache(key=request_count_api_key))
                self.print_verbose(f"updated_value in success call: {new_val}")
                cache.set_cache(request_count_api_key, new_val)
        except Exception as e:
            self.print_verbose(e)  # noqa

    async def async_log_failure_call(self, api_key, user_api_key_cache):
        """
        Release the slot held by `api_key` after a failed call.

        Uses the `user_api_key_cache` passed in; falls back to the cache saved
        by `max_parallel_request_allow_request` only when `None` is passed.
        Errors are reported via `print_verbose` and never propagated.
        """
        try:
            if api_key is None:
                return

            cache = user_api_key_cache
            if cache is None:
                cache = getattr(self, "user_api_key_cache", None)
            if cache is None:
                return

            request_count_api_key = f"{api_key}_request_count"
            # Decrease count for this token
            new_val = self._decremented(cache.get_cache(key=request_count_api_key))
            self.print_verbose(f"updated_value in failure call: {new_val}")
            cache.set_cache(request_count_api_key, new_val)
        except Exception as e:
            self.print_verbose(f"An exception occurred - {str(e)}")  # noqa

View file

@ -26,7 +26,7 @@ def run_ollama_serve():
except Exception as e:
print(f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""")
""") # noqa
def clone_subfolder(repo_url, subfolder, destination):
# Clone the full repo
@ -73,8 +73,7 @@ def is_port_in_use(port):
@click.option('--request_timeout', default=600, type=int, help='Set timeout in seconds for completion calls')
@click.option('--drop_params', is_flag=True, help='Drop any unmapped params')
@click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt')
@click.option('--config', '-c', default=None, help='Configure Litellm')
@click.option('--file', '-f', help='Path to config file')
@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`')
@click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`')
@click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
@click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.')
@ -83,7 +82,7 @@ def is_port_in_use(port):
@click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response')
@click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with')
@click.option('--local', is_flag=True, default=False, help='for local debugging')
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, file, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health):
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health):
global feature_telemetry
args = locals()
if local:
@ -110,11 +109,11 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
# get n recent logs
recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]}
print(json.dumps(recent_logs, indent=4))
print(json.dumps(recent_logs, indent=4)) # noqa
except:
print("LiteLLM: No logs saved!")
raise Exception("LiteLLM: No logs saved!")
return
if model and "ollama" in model:
if model and "ollama" in model and api_base is None:
run_ollama_serve()
if test_async is True:
import requests, concurrent, time
@ -141,7 +140,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
if status == "finished":
llm_response = polling_response["result"]
break
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa
time.sleep(0.5)
except Exception as e:
print("got exception in polling", e)

View file

@ -1,8 +1,54 @@
model_list:
- model_name: gpt-3.5-turbo
- model_name: Azure OpenAI GPT-4 Canada-East (External)
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
mode: chat
input_cost_per_token: 0.00006
output_cost_per_token: 0.00003
max_tokens: 4096
base_model: gpt-3.5-turbo
- model_name: BEDROCK_GROUP
litellm_params:
model: bedrock/cohere.command-text-v14
- model_name: Azure OpenAI GPT-4 Canada-East (External)
litellm_params:
model: gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
model_info:
mode: chat
- model_name: azure-embedding-model
litellm_params:
model: azure/azure-embedding-model
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
mode: embedding
base_model: text-embedding-ada-002
- model_name: text-embedding-ada-002
litellm_params:
model: text-embedding-ada-002
api_key: os.environ/OPENAI_API_KEY
model_info:
mode: embedding
- model_name: text-davinci-003
litellm_params:
model: text-davinci-003
api_key: os.environ/OPENAI_API_KEY
model_info:
mode: completion
litellm_settings:
# cache: True
# setting callback class
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
general_settings:
environment_variables:
# otel: True # OpenTelemetry Logger
# master_key: sk-1234 # [OPTIONAL] Only use this if you want to require all calls to contain this key (Authorization: Bearer sk-1234)

File diff suppressed because it is too large Load diff

View file

@ -45,7 +45,7 @@ celery_app.conf.update(
@celery_app.task(name='process_job', max_retries=3)
def process_job(*args, **kwargs):
try:
llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list"))
llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore
response = llm_router.completion(*args, **kwargs) # type: ignore
if isinstance(response, litellm.ModelResponse):
response = response.model_dump_json()

View file

@ -16,4 +16,5 @@ model LiteLLM_VerificationToken {
aliases Json @default("{}")
config Json @default("{}")
user_id String?
max_parallel_requests Int?
}

View file

@ -0,0 +1,40 @@
## LOCAL TEST
# from langchain.chat_models import ChatOpenAI
# from langchain.prompts.chat import (
# ChatPromptTemplate,
# HumanMessagePromptTemplate,
# SystemMessagePromptTemplate,
# )
# from langchain.schema import HumanMessage, SystemMessage
# chat = ChatOpenAI(
# openai_api_base="http://0.0.0.0:8000",
# model = "gpt-3.5-turbo",
# temperature=0.1
# )
# messages = [
# SystemMessage(
# content="You are a helpful assistant that im using to make a test request to."
# ),
# HumanMessage(
# content="test from litellm. tell me why it's amazing in 1 sentence"
# ),
# ]
# response = chat(messages)
# print(response)
# claude_chat = ChatOpenAI(
# openai_api_base="http://0.0.0.0:8000",
# model = "claude-v1",
# temperature=0.1
# )
# response = claude_chat(messages)
# print(response)

View file

@ -1,87 +1,342 @@
from typing import Optional, List, Any
import os, subprocess, hashlib
from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib, asyncio, copy
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
def print_verbose(print_statement):
    """Print ``print_statement`` when ``litellm.set_verbose`` is truthy."""
    if not litellm.set_verbose:
        return
    print(print_statement)  # noqa
### LOGGING ###
class ProxyLogging:
    """
    Logging and custom-hook dispatcher for the proxy.

    What it does here:
    - reports failed db read/writes
    - drives the max-parallel-requests limiter before a call and after a
      failed call
    """

    def __init__(self, user_api_key_cache: DualCache):
        # Scratch state shared by the hooks; the key cache is stored here so
        # it can be handed to the limiter on every call.
        self.call_details: dict = {"user_api_key_cache": user_api_key_cache}
        self.max_parallel_request_limiter = MaxParallelRequestsHandler()

    def _init_litellm_callbacks(self):
        """Register the limiter with litellm and fan every callback out to all hook lists."""
        litellm.callbacks.append(self.max_parallel_request_limiter)
        for registered in litellm.callbacks:
            hook_lists = (
                litellm.input_callback,
                litellm.success_callback,
                litellm.failure_callback,
                litellm._async_success_callback,
                litellm._async_failure_callback,
            )
            for hook_list in hook_lists:
                if registered not in hook_list:
                    hook_list.append(registered)

        # Only the input/success/failure lists are passed to set_callbacks.
        if litellm.input_callback or litellm.success_callback or litellm.failure_callback:
            unique_callbacks = set(
                litellm.input_callback
                + litellm.success_callback
                + litellm.failure_callback
            )
            litellm.utils.set_callbacks(callback_list=list(unique_callbacks))

    async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
        """
        Hook run before a request is forwarded; lets the proxy modify or
        reject it without parsing the Request body.

        Covers:
        1. /chat/completions
        2. /embeddings

        Returns ``data`` unchanged. Any exception raised by the limiter
        propagates to the caller.
        """
        self.call_details["data"] = data
        self.call_details["call_type"] = call_type

        # Consult the limiter only for keys that have a limit configured.
        parallel_limit = user_api_key_dict.max_parallel_requests
        if parallel_limit is not None:
            await self.max_parallel_request_limiter.max_parallel_request_allow_request(
                max_parallel_requests=parallel_limit,
                api_key=user_api_key_dict.api_key,
                user_api_key_cache=self.call_details["user_api_key_cache"],
            )

        return data

    async def success_handler(self, *args, **kwargs):
        """
        Log successful db read/writes (currently a no-op).
        """
        pass

    async def failure_handler(self, original_exception):
        """
        Log failed db read/writes.

        Forwards the exception to ``litellm.utils.capture_exception`` when
        that hook is set (the original note says this reports to sentry).
        """
        capture = litellm.utils.capture_exception
        if capture:
            capture(error=original_exception)

    async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
        """
        Hook run after a call fails; lets the proxy react without parsing the
        Request body.

        Covers:
        1. /chat/completions
        2. /embeddings
        """
        # Nothing to release when the key has no parallel-request limit.
        if user_api_key_dict is None or user_api_key_dict.max_parallel_requests is None:
            return

        # Skip the decrement when the failure is the limiter's own 429.
        rejected_by_limiter = (
            getattr(original_exception, "status_code", None) == 429
            and "Max parallel request limit reached" in str(original_exception)
        )
        if rejected_by_limiter:
            return

        ## decrement call count if call failed
        await self.max_parallel_request_limiter.async_log_failure_call(
            api_key=user_api_key_dict.api_key,
            user_api_key_cache=self.call_details["user_api_key_cache"],
        )
        return
### DB CONNECTOR ###
# Define the retry decorator with backoff strategy
# Function to be called whenever a retry is about to happen
def on_backoff(details):
    """Backoff callback: log which attempt just finished before the next retry."""
    # 'tries' holds the number of attempts completed so far
    completed_attempts = details["tries"]
    print_verbose(f"Backing off... this was attempt #{completed_attempts}")
class PrismaClient:
def __init__(self, database_url: str):
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
## init logging object
self.proxy_logging_obj = proxy_logging_obj
os.environ["DATABASE_URL"] = database_url
subprocess.run(['prisma', 'generate'])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
# Save the current working directory
original_dir = os.getcwd()
# set the working directory to where this script is
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
try:
subprocess.run(['prisma', 'generate'])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
finally:
os.chdir(original_dir)
# Now you can import the Prisma Client
from prisma import Client
from prisma import Client # type: ignore
self.db = Client() #Client to connect to Prisma db
def hash_token(self, token: str):
    """Return the hex-encoded SHA-256 digest of ``token``."""
    digest = hashlib.sha256(token.encode())
    return digest.hexdigest()
async def get_data(self, token: str, expires: Optional[Any]=None):
hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def get_data(self, token: str, expires: Optional[Any]=None):
try:
# check if plain text or hash
hashed_token = token
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
where={
"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
"token": hashed_token
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": hashed_token
}
)
return response
return response
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def insert_data(self, data: dict):
"""
Add a key to the database. If it already exists, do nothing.
"""
token = data["token"]
hashed_token = self.hash_token(token=token)
data["token"] = hashed_token
print(f"passed in data: {data}; hashed_token: {hashed_token}")
try:
token = data["token"]
hashed_token = self.hash_token(token=token)
db_data = copy.deepcopy(data)
db_data["token"] = hashed_token
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,
},
data={
"create": {**data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)
return new_verification_token
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,
},
data={
"create": {**db_data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)
return new_verification_token
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def update_data(self, token: str, data: dict):
"""
Update existing data
"""
hashed_token = self.hash_token(token=token)
data["token"] = hashed_token
await self.db.litellm_verificationtoken.update(
where={
"token": hashed_token
},
data={**data} # type: ignore
)
return {"token": token, "data": data}
try:
print_verbose(f"token: {token}")
# check if plain text or hash
if token.startswith("sk-"):
token = self.hash_token(token=token)
db_data = copy.deepcopy(data)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={
"token": token
},
data={**db_data} # type: ignore
)
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
return {"token": token, "data": db_data}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def delete_data(self, tokens: List):
"""
Allow user to delete a key(s)
"""
hashed_tokens = [self.hash_token(token=token) for token in tokens]
await self.db.litellm_verificationtoken.delete_many(
where={"token": {"in": hashed_tokens}}
)
return {"deleted_keys": tokens}
try:
hashed_tokens = [self.hash_token(token=token) for token in tokens]
await self.db.litellm_verificationtoken.delete_many(
where={"token": {"in": hashed_tokens}}
)
return {"deleted_keys": tokens}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def connect(self):
await self.db.connect()
try:
await self.db.connect()
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def disconnect(self):
await self.db.disconnect()
try:
await self.db.disconnect()
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
raise e
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
    """
    Resolve a dotted path (``"pkg.module.attr"``) to the object it names.

    Args:
        value: dotted path; everything before the last dot is the module,
            the last component is the attribute looked up on it.
        config_file_path: when given, the module is loaded from a ``.py``
            file located relative to the directory of this path instead of
            through the normal import system.

    Returns:
        The attribute ``attr`` of the loaded module.

    Raises:
        ImportError: when the module cannot be located or imported (the
            original error is chained as the cause).
        AttributeError: when the module has no such attribute.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly before using it.
    import importlib.util

    print_verbose(f"value: {value}")

    # The module path is all but the last part; the instance name is the last part.
    module_name, _, instance_name = value.rpartition(".")

    try:
        if config_file_path is not None:
            # Load the module from a file that sits next to the config file.
            directory = os.path.dirname(config_file_path)
            module_file_path = os.path.join(directory, *module_name.split(".")) + ".py"
            spec = importlib.util.spec_from_file_location(module_name, module_file_path)
            if spec is None or spec.loader is None:
                raise ImportError(f"Could not find a module specification for {module_file_path}")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        else:
            # Dynamically import the module
            module = importlib.import_module(module_name)

        # Get the instance from the module
        return getattr(module, instance_name)
    except ImportError as e:
        # Re-raise the exception with a user-friendly message
        raise ImportError(f"Could not import {instance_name} from {module_name}") from e

View file

@ -7,16 +7,18 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import copy
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal
import random, threading, time, traceback
from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback, uuid
import litellm, openai
from litellm.caching import RedisCache, InMemoryCache, DualCache
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
import copy
class Router:
"""
Example usage:
@ -53,17 +55,22 @@ class Router:
```
"""
model_names: List = []
cache_responses: bool = False
cache_responses: Optional[bool] = False
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
num_retries: int = 0
tenacity = None
leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
def __init__(self,
model_list: Optional[list] = None,
## CACHING ##
redis_url: Optional[str] = None,
redis_host: Optional[str] = None,
redis_port: Optional[int] = None,
redis_password: Optional[str] = None,
cache_responses: bool = False,
cache_responses: Optional[bool] = False,
cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
## RELIABILITY ##
num_retries: int = 0,
timeout: Optional[float] = None,
default_litellm_params = {}, # default params for Router.chat.completion.create
@ -74,7 +81,9 @@ class Router:
routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
self.set_verbose = set_verbose
self.deployment_names: List = [] # names of models under litellm_params. ex. azure/chatgpt-v-2
if model_list:
model_list = copy.deepcopy(model_list)
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list
self.deployment_latency_map = {}
@ -92,8 +101,8 @@ class Router:
self.total_calls: defaultdict = defaultdict(int) # dict to store total calls made to each model
self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model
self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model
self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call)
# make Router.chat.completions.create compatible for openai.chat.completions.create
self.chat = litellm.Chat(params=default_litellm_params)
@ -102,28 +111,44 @@ class Router:
self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0)
### HEALTH CHECK THREAD ###
if self.routing_strategy == "least-busy":
self._start_health_check_thread()
### CACHING ###
cache_type = "local" # default to an in-memory cache
redis_cache = None
if redis_host is not None and redis_port is not None and redis_password is not None:
cache_config = {
'type': 'redis',
'host': redis_host,
'port': redis_port,
'password': redis_password
}
redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password)
else: # use an in-memory cache
cache_config = {
"type": "local"
}
cache_config = {}
if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None):
cache_type = "redis"
if redis_url is not None:
cache_config['url'] = redis_url
if redis_host is not None:
cache_config['host'] = redis_host
if redis_port is not None:
cache_config['port'] = str(redis_port) # type: ignore
if redis_password is not None:
cache_config['password'] = redis_password
# Add additional key-value pairs from cache_kwargs
cache_config.update(cache_kwargs)
redis_cache = RedisCache(**cache_config)
if cache_responses:
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
if litellm.cache is None:
# the cache can be initialized on the proxy server. We should not overwrite it
litellm.cache = litellm.Cache(type=cache_type, **cache_config)
self.cache_responses = cache_responses
self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
### ROUTING SETUP ###
if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
## add callback
if isinstance(litellm.input_callback, list):
litellm.input_callback.append(self.leastbusy_logger) # type: ignore
else:
litellm.input_callback = [self.leastbusy_logger] # type: ignore
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.leastbusy_logger) # type: ignore
## USAGE TRACKING ##
if isinstance(litellm.success_callback, list):
litellm.success_callback.append(self.deployment_callback)
@ -171,9 +196,10 @@ class Router:
try:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages)
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
@ -188,7 +214,7 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = deployment.get("client", None)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
except Exception as e:
raise e
@ -219,8 +245,9 @@ class Router:
try:
self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
original_model_string = None # set a default for this variable
deployment = self.get_available_deployment(model=model, messages=messages)
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
@ -234,7 +261,7 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = deployment.get("async_client", None)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[original_model_string] +=1
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[original_model_string] +=1
@ -255,7 +282,7 @@ class Router:
kwargs.setdefault("metadata", {}).update({"model_group": model})
messages=[{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages)
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
@ -288,8 +315,9 @@ class Router:
is_async: Optional[bool] = False,
**kwargs) -> Union[List[float], None]:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, input=input)
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("model_info", {})
kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
@ -303,7 +331,7 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = deployment.get("client", None)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@ -313,9 +341,10 @@ class Router:
is_async: Optional[bool] = True,
**kwargs) -> Union[List[float], None]:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, input=input)
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]})
data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
@ -328,7 +357,7 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = deployment.get("async_client", None)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@ -345,7 +374,7 @@ class Router:
self.print_verbose(f'Async Response: {response}')
return response
except Exception as e:
self.print_verbose(f"An exception occurs")
self.print_verbose(f"An exception occurs: {e}")
original_exception = e
try:
self.print_verbose(f"Trying to fallback b/w models")
@ -380,6 +409,8 @@ class Router:
Iterate through the model groups and try calling that deployment
"""
try:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
kwargs["model"] = mg
kwargs["metadata"]["model_group"] = mg
response = await self.async_function_with_retries(*args, **kwargs)
@ -423,6 +454,10 @@ class Router:
else:
raise original_exception
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
for current_attempt in range(num_retries):
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
try:
@ -433,6 +468,8 @@ class Router:
return response
except Exception as e:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt
if "No models available" in str(e):
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, min_timeout=1)
@ -458,13 +495,12 @@ class Router:
try:
response = self.function_with_retries(*args, **kwargs)
return response
except Exception as e:
except Exception as e:
original_exception = e
self.print_verbose(f"An exception occurs {original_exception}")
try:
self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}")
if isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None:
self.print_verbose(f"inside context window fallbacks: {context_window_fallbacks}")
fallback_model_group = None
for item in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
@ -480,6 +516,8 @@ class Router:
Iterate through the model groups and try calling that deployment
"""
try:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
kwargs["model"] = mg
response = self.function_with_fallbacks(*args, **kwargs)
return response
@ -501,11 +539,13 @@ class Router:
Iterate through the model groups and try calling that deployment
"""
try:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
kwargs["model"] = mg
response = self.function_with_fallbacks(*args, **kwargs)
return response
except Exception as e:
pass
raise e
except Exception as e:
raise e
raise original_exception
@ -515,7 +555,6 @@ class Router:
Try calling the model 3 times. Shuffle between available deployments.
"""
self.print_verbose(f"Inside function with retries: args - {args}; kwargs - {kwargs}")
backoff_factor = 1
original_function = kwargs.pop("original_function")
num_retries = kwargs.pop("num_retries")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
@ -531,6 +570,9 @@ class Router:
if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
raise original_exception
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
### RETRY
for current_attempt in range(num_retries):
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}")
@ -539,19 +581,19 @@ class Router:
response = original_function(*args, **kwargs)
return response
except openai.RateLimitError as e:
if num_retries > 0:
remaining_retries = num_retries - current_attempt
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
# on RateLimitError we'll wait for an exponential time before trying again
except Exception as e:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt
if "No models available" in str(e):
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, min_timeout=1)
time.sleep(timeout)
elif hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
if hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, response_headers=e.response.headers)
else:
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
time.sleep(timeout)
else:
raise e
except Exception as e:
# for any other exception types, immediately retry
if num_retries > 0:
pass
else:
raise e
raise original_exception
@ -614,6 +656,27 @@ class Router:
except Exception as e:
raise e
def log_retry(self, kwargs: dict, e: Exception) -> dict:
    """
    When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing

    A snapshot of the failed call is appended to `self.previous_models`, and that
    list is exposed to the next attempt as `kwargs["metadata"]["previous_models"]`.

    Parameters:
        kwargs (dict): The keyword arguments of the call that just failed. Mutated in place.
        e (Exception): The exception raised by the failed call.

    Returns:
        The same `kwargs` dict, with `metadata["previous_models"]` set.
    """
    # Log failed model as the previous model
    previous_model = {"exception_type": type(e).__name__, "exception_string": str(e)}
    for k, v in kwargs.items():
        if k != "metadata":
            previous_model[k] = v
        elif isinstance(v, dict):
            # log everything in metadata except the old previous_models value - prevent nesting
            previous_model["metadata"] = {  # type: ignore
                metadata_k: metadata_v
                for metadata_k, metadata_v in v.items()
                if metadata_k != "previous_models"
            }
    self.previous_models.append(previous_model)
    # the caller may not have set metadata (or set it to None) - create it instead of failing
    if kwargs.get("metadata") is None:
        kwargs["metadata"] = {}
    kwargs["metadata"]["previous_models"] = self.previous_models
    return kwargs
def _set_cooldown_deployments(self,
deployment: str):
"""
@ -817,12 +880,16 @@ class Router:
return chosen_item
def set_model_list(self, model_list: list):
self.model_list = model_list
self.model_list = copy.deepcopy(model_list)
# we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
import os
for model in self.model_list:
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
#### MODEL ID INIT ########
model_info = model.get("model_info", {})
model_info["id"] = model_info.get("id", str(uuid.uuid4()))
model["model_info"] = model_info
#### for OpenAI / Azure we need to initialize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
if custom_llm_provider is None:
@ -845,6 +912,7 @@ class Router:
if api_key and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url")
@ -852,13 +920,35 @@ class Router:
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name)
litellm_params["api_base"] = api_base
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name)
self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
litellm_params["api_version"] = api_version
timeout = litellm_params.pop("timeout", None)
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name)
litellm_params["timeout"] = timeout
stream_timeout = litellm_params.pop("stream_timeout", timeout) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name)
litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop("max_retries", 2)
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries
if "azure" in model_name:
if api_base is None:
raise ValueError("api_base is required for Azure OpenAI. Set it on your config")
if api_version is None:
api_version = "2023-07-01-preview"
if "gateway.ai.cloudflare.com" in api_base:
@ -869,40 +959,107 @@ class Router:
model["async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version
api_version=api_version,
timeout=timeout,
max_retries=max_retries
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version
api_version=api_version,
timeout=timeout,
max_retries=max_retries
)
# streaming clients can have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
)
model["stream_client"] = openai.AzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
)
else:
self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}")
model["async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version
api_version=api_version,
timeout=timeout,
max_retries=max_retries
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version
api_version=api_version,
timeout=timeout,
max_retries=max_retries
)
# streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
)
model["stream_client"] = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
)
else:
self.print_verbose(f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}")
model["async_client"] = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries
)
model["client"] = openai.OpenAI(
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries
)
# streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries
)
# streaming clients should have diff timeouts
model["stream_client"] = openai.OpenAI(
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries
)
############ End of initializing Clients for OpenAI/Azure ###################
self.deployment_names.append(model["litellm_params"]["model"])
model_id = ""
for key in model["litellm_params"]:
if key != "api_key":
if key != "api_key" and key != "metadata":
model_id+= str(model["litellm_params"][key])
model["litellm_params"]["model"] += "-ModelID-" + model_id
self.print_verbose(f"\n Initialized Model List {self.model_list}")
############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
# for get_available_deployment, we use the litellm_param["rpm"]
# in this snippet we also set rpm to be a litellm_param
@ -916,17 +1073,63 @@ class Router:
def get_model_names(self):
return self.model_names
def _get_client(self, deployment, kwargs, client_type=None):
"""
Returns the appropriate client based on the given deployment, kwargs, and client_type.
Parameters:
deployment (dict): The deployment dictionary containing the clients.
kwargs (dict): The keyword arguments passed to the function.
client_type (str): The type of client to return.
Returns:
The appropriate client based on the given client_type and kwargs.
"""
if client_type == "async":
if kwargs.get("stream") == True:
return deployment.get("stream_async_client", None)
else:
return deployment.get("async_client", None)
else:
if kwargs.get("stream") == True:
return deployment.get("stream_client", None)
else:
return deployment.get("client", None)
def print_verbose(self, print_statement):
if self.set_verbose or litellm.set_verbose:
print(f"LiteLLM.Router: {print_statement}") # noqa
try:
if self.set_verbose or litellm.set_verbose:
print(f"LiteLLM.Router: {print_statement}") # noqa
except:
pass
def get_available_deployment(self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None):
input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False
):
"""
Returns the deployment based on routing strategy
"""
# users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
# When this was not explicit we had several issues with fallbacks timing out
if specific_deployment == True:
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
for deployment in self.model_list:
cleaned_model = litellm.utils.remove_model_id(deployment.get("litellm_params").get("model"))
if cleaned_model == model:
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
# return the first deployment where the `model` matches the specified deployment name
return deployment
raise ValueError(f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}")
# check if aliases set on litellm model alias map
if model in litellm.model_group_alias_map:
self.print_verbose(f"Using a model alias. Got Request for {model}, sending requests to {litellm.model_group_alias_map.get(model)}")
model = litellm.model_group_alias_map[model]
## get healthy deployments
### get all deployments
### filter out the deployments currently cooling down
@ -934,6 +1137,7 @@ class Router:
if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead
healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
self.print_verbose(f"initial list of deployments: {healthy_deployments}")
deployments_to_remove = []
cooldown_deployments = self._get_cooldown_deployments()
@ -953,13 +1157,24 @@ class Router:
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
if self.routing_strategy == "least-busy":
if len(self.healthy_deployments) > 0:
for item in self.healthy_deployments:
if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
return item[0]
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
deployments = self.leastbusy_logger.get_available_deployments(model_group=model)
# pick least busy deployment
min_traffic = float('inf')
min_deployment = None
for k, v in deployments.items():
if v < min_traffic:
min_deployment = k
############## No Available Deployments passed, we do a random pick #################
if min_deployment is None:
min_deployment = random.choice(healthy_deployments)
############## Available Deployments passed, we find the relevant item #################
else:
raise ValueError("No models available.")
for m in healthy_deployments:
if m["model_info"]["id"] == min_deployment:
return m
min_deployment = random.choice(healthy_deployments)
return min_deployment
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if we can do a RPM/TPM based weighted pick #################
@ -1010,11 +1225,14 @@ class Router:
raise ValueError("No models available.")
def flush_cache(self):
    """
    Unset the global `litellm.cache` and flush this router's own cache.
    """
    litellm.cache = None
    self.cache.flush_cache()
def reset(self):
    """
    Clean up on close: clear litellm's global callback lists, then flush the caches.
    """
    litellm.success_callback = []
    # must be the single-underscore name: inside a class body a double-underscore
    # attribute is name-mangled with the class name, so the real list was never cleared
    litellm._async_success_callback = []
    litellm.failure_callback = []
    litellm._async_failure_callback = []
    self.flush_cache()

View file

@ -0,0 +1,96 @@
#### What this does ####
# identifies least busy deployment
# How is this achieved?
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic
import dotenv, os, requests
from typing import Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
class LeastBusyLoggingHandler(CustomLogger):
    """
    Tracks how many requests are in flight per deployment, so the router can pick the least busy one.

    Counts live in the router cache under "{model_group}_request_count" as
    {deployment: requests_in_flight}. They are incremented just before an API call
    and decremented when the call succeeds or fails.
    """

    def __init__(self, router_cache: DualCache):
        self.router_cache = router_cache
        # deployment -> model id, filled in as calls are seen by log_pre_api_call
        self.mapping_deployment_to_id: dict = {}

    def log_pre_api_call(self, model, messages, kwargs):
        """
        Log when a model is being used.

        Caching based on model group. Does nothing when the call carries no
        metadata, deployment, model group or model id.
        """
        try:
            litellm_params = kwargs['litellm_params']
            metadata = litellm_params.get('metadata')
            if metadata is None:
                return

            deployment = metadata.get('deployment', None)
            model_group = metadata.get('model_group', None)
            deployment_id = (litellm_params.get('model_info') or {}).get('id', None)
            if deployment is None or model_group is None or deployment_id is None:
                return

            # map deployment to id
            self.mapping_deployment_to_id[deployment] = deployment_id

            request_count_api_key = f"{model_group}_request_count"
            # update cache
            request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
            request_count_dict[deployment] = request_count_dict.get(deployment, 0) + 1
            self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
        except Exception:
            # best-effort bookkeeping: a logging failure must not break the request itself
            pass

    def _decrement_request_count(self, kwargs):
        """Mark one in-flight request of the deployment named in `kwargs` as finished."""
        metadata = kwargs['litellm_params'].get('metadata')
        if metadata is None:
            return

        deployment = metadata.get('deployment', None)
        model_group = metadata.get('model_group', None)
        if deployment is None or model_group is None:
            return

        request_count_api_key = f"{model_group}_request_count"
        # decrement count in cache
        request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
        # clamp at 0 so a missing entry cannot produce a negative count
        request_count_dict[deployment] = max((request_count_dict.get(deployment) or 0) - 1, 0)
        self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Decrement the in-flight count of the deployment whose call succeeded."""
        try:
            self._decrement_request_count(kwargs)
        except Exception:
            # best-effort bookkeeping: a logging failure must not break the request itself
            pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Decrement the in-flight count of the deployment whose call failed."""
        try:
            self._decrement_request_count(kwargs)
        except Exception:
            # best-effort bookkeeping: a logging failure must not break the request itself
            pass

    def get_available_deployments(self, model_group: str):
        """
        Return {model id: requests_in_flight} for the given model group.

        Deployments that have a cached count but no known id in this handler are skipped.
        """
        request_count_api_key = f"{model_group}_request_count"
        request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
        # map deployment to id
        return_dict = {}
        for deployment, request_count in request_count_dict.items():
            deployment_id = self.mapping_deployment_to_id.get(deployment)
            if deployment_id is None:
                # no id recorded for this deployment yet - cannot report it by id
                continue
            return_dict[deployment_id] = request_count
        return return_dict

34
litellm/tests/conftest.py Normal file
View file

@ -0,0 +1,34 @@
# conftest.py
import pytest, sys, os
import importlib
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
    """
    Reload litellm before every test function.

    This speeds up testing by preventing callbacks set in one test from being chained onto the next.
    """
    project_root = os.path.abspath("../..")
    # only add the project directory once - this fixture runs for every test
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    import litellm
    importlib.reload(litellm)
    print(litellm)
    yield
def pytest_collection_modifyitems(config, items):
    """
    Reorder the collected tests in place.

    Tests whose parent name contains 'custom_logger' run first, all other tests
    after them; within each group tests are sorted by name.
    """
    def ordering(test_item):
        # False sorts before True, so custom_logger tests come first
        runs_in_second_group = 'custom_logger' not in test_item.parent.name
        return (runs_in_second_group, test_item.name)

    items[:] = sorted(items, key=ordering)

View file

@ -0,0 +1,30 @@
model_list:
- model_name: text-davinci-003
litellm_params:
model: ollama/zephyr
- model_name: gpt-4
litellm_params:
model: ollama/llama2
- model_name: gpt-3.5-turbo
litellm_params:
model: ollama/llama2
temperature: 0.1
max_tokens: 20
# request to gpt-4, response from ollama/llama2
# curl --location 'http://0.0.0.0:8000/chat/completions' \
# --header 'Content-Type: application/json' \
# --data ' {
# "model": "gpt-4",
# "messages": [
# {
# "role": "user",
# "content": "what llm are you"
# }
# ]
# }
# '
#
# {"id":"chatcmpl-27c85cf0-ab09-4bcf-8cb1-0ee950520743","choices":[{"finish_reason":"stop","index":0,"message":{"content":" Hello! I'm just an AI, I don't have personal experiences or emotions like humans do. However, I can help you with any questions or tasks you may have! Is there something specific you'd like to know or discuss?","role":"assistant","_logprobs":null}}],"created":1700094955.373751,"model":"ollama/llama2","object":"chat.completion","system_fingerprint":null,"usage":{"prompt_tokens":12,"completion_tokens":47,"total_tokens":59},"_response_ms":8028.017999999999}%

View file

@ -0,0 +1,15 @@
model_list:
- model_name: gpt-4-team1
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY
tpm: 20_000
- model_name: gpt-4-team2
litellm_params:
model: azure/gpt-4
api_key: os.environ/AZURE_API_KEY
api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
tpm: 100_000

View file

@ -0,0 +1,7 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_settings:
drop_params: True
success_callback: ["langfuse"] # https://docs.litellm.ai/docs/observability/langfuse_integration

View file

@ -0,0 +1,28 @@
litellm_settings:
drop_params: True
# Model-specific settings
model_list: # give several deployments the same model_name to have the litellm router pick between them (here: gpt-3.5-turbo)
- model_name: gpt-3.5-turbo # shared model_name - the router picks between the gpt-3.5-turbo deployments in this list
litellm_params:
model: gpt-3.5-turbo
api_key: sk-uj6F
tpm: 20000 # [OPTIONAL] REPLACE with your openai tpm
rpm: 3 # [OPTIONAL] REPLACE with your openai rpm
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
api_key: sk-Imn
tpm: 20000 # [OPTIONAL] REPLACE with your openai tpm
rpm: 3 # [OPTIONAL] REPLACE with your openai rpm
- model_name: gpt-3.5-turbo
litellm_params:
model: openrouter/gpt-3.5-turbo
- model_name: mistral-7b-instruct
litellm_params:
model: mistralai/mistral-7b-instruct
environment_variables:
REDIS_HOST: localhost
REDIS_PASSWORD:
REDIS_PORT:

View file

@ -0,0 +1,7 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
general_settings:
otel: True # OpenTelemetry Logger this logs OTEL data to your collector

View file

@ -0,0 +1,4 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo

View file

View file

@ -0,0 +1,117 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os, io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
import json
import os
import tempfile
litellm.num_retries = 3
litellm.cache = None
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
def load_vertex_ai_credentials():
    """
    Write a service-account key file for Vertex AI and export it as GOOGLE_APPLICATION_CREDENTIALS.

    The key data starts from `vertex_key.json` next to this file (an empty dict if
    the file is missing, empty, or not a valid JSON object). `private_key_id` and
    `private_key` are then filled in from the VERTEX_AI_PRIVATE_KEY_ID and
    VERTEX_AI_PRIVATE_KEY environment variables. The result is written to a
    temporary file, which is kept on disk because the environment variable points to it.
    """
    # Define the path to the vertex_key.json file
    print("loading vertex ai credentials")
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = os.path.join(filepath, 'vertex_key.json')

    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, 'r') as file:
            print("Read vertexai file path")
            content = file.read()
        # If the file is empty or not valid JSON, create an empty dictionary
        service_account_key_data = json.loads(content) if content.strip() else {}
    except (FileNotFoundError, json.JSONDecodeError):
        service_account_key_data = {}
    if not isinstance(service_account_key_data, dict):
        # valid JSON but not an object - the keys below could not be set on it
        service_account_key_data = {}

    # Update the service_account_key_data with environment variables
    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
    # the env var holds literal "\n" sequences; turn them into real newlines
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    # Create a temporary file (delete=False: it must outlive this function)
    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
        # Write the updated content to the temporary file
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
def test_vertex_ai():
    """
    Call completion() on a random sample of 4 Vertex AI models and check each returns text.

    Models the test account cannot access are skipped.
    """
    import random

    load_vertex_ai_credentials()
    litellm.set_verbose=False
    litellm.vertex_project = "hardy-device-386718"
    all_vertex_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
    )
    sampled_models = random.sample(all_vertex_models, 4)
    for model in sampled_models:
        try:
            # our account does not have access to these models
            if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
                continue
            print("making request", model)
            response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}])
            print("\nModel Response", response)
            print(response)
            reply = response.choices[0].message.content
            assert type(reply) == str
            assert len(response.choices[0].message.content) > 1
        except Exception as e:
            pytest.fail(f"Error occurred: {e}")
# test_vertex_ai()
def test_vertex_ai_stream():
    """
    Stream completion() from a random sample of 4 Vertex AI models and check text arrives.

    Models the test account cannot access are skipped.
    """
    import random

    load_vertex_ai_credentials()
    litellm.set_verbose=False
    litellm.vertex_project = "hardy-device-386718"
    all_vertex_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
    )
    sampled_models = random.sample(all_vertex_models, 4)
    prompt = [{"role": "user", "content": "write 10 line code code for saying hi"}]
    for model in sampled_models:
        try:
            # our account does not have access to these models
            if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
                continue
            print("making request", model)
            response = completion(model=model, messages=prompt, stream=True)
            completed_str = ""
            for chunk in response:
                print(chunk)
                content = chunk.choices[0].delta.content or ""
                print("\n content", content)
                completed_str += content
                assert type(content) == str
            assert len(completed_str) > 4
        except Exception as e:
            pytest.fail(f"Error occurred: {e}")
# test_vertex_ai_stream()

View file

@ -19,6 +19,13 @@ import random
messages = [{"role": "user", "content": "who is ishaan Github? "}]
# comment
import random
import string
def generate_random_word(length=4):
    """Return a random word made of `length` lowercase ASCII letters."""
    picked = []
    for _ in range(length):
        picked.append(random.choice(string.ascii_lowercase))
    return ''.join(picked)
messages = [{"role": "user", "content": "who is ishaan 5222"}]
def test_caching_v2(): # test in memory cache
try:
@ -28,6 +35,8 @@ def test_caching_v2(): # test in memory cache
print(f"response1: {response1}")
print(f"response2: {response2}")
litellm.cache = None # disable cache
litellm.success_callback = []
litellm._async_success_callback = []
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
print(f"response1: {response1}")
print(f"response2: {response2}")
@ -51,6 +60,8 @@ def test_caching_with_models_v2():
print(f"response2: {response2}")
print(f"response3: {response3}")
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
# if models are different, it should not return cached response
print(f"response2: {response2}")
@ -84,13 +95,15 @@ def test_embedding_caching():
print(f"Embedding 2 response time: {end_time - start_time} seconds")
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
print(f"embedding1: {embedding1}")
print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed")
test_embedding_caching()
# test_embedding_caching()
def test_embedding_caching_azure():
@ -138,6 +151,8 @@ def test_embedding_caching_azure():
print(f"Embedding 2 response time: {end_time - start_time} seconds")
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
print(f"embedding1: {embedding1}")
@ -158,9 +173,9 @@ def test_redis_cache_completion():
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
print("test2 for caching")
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=10, seed=1222)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=10, seed=1222)
response3 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, temperature=1)
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
response3 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5)
response4 = completion(model="command-nightly", messages=messages, caching=True)
print("\nresponse 1", response1)
@ -168,6 +183,8 @@ def test_redis_cache_completion():
print("\nresponse 3", response3)
print("\nresponse 4", response4)
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
"""
1 & 2 should be exactly the same
@ -190,11 +207,132 @@ def test_redis_cache_completion():
print(f"response4: {response4}")
pytest.fail(f"Error occurred:")
test_redis_cache_completion()
# test_redis_cache_completion()
def test_redis_cache_completion_stream():
    """
    Two identical streaming completion() calls with a redis cache must return the same content.

    The second call is made after a short pause. Global litellm state
    (cache + success callbacks) is reset afterwards, whether the test passes or fails.
    """
    try:
        litellm.success_callback = []
        litellm._async_success_callback = []
        litellm.callbacks = []
        litellm.set_verbose = True
        random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache
        messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
        litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
        print("test for caching, streaming + completion")

        response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
        response_1_content = ""
        for chunk in response1:
            print(chunk)
            response_1_content += chunk.choices[0].delta.content or ""
        print(response_1_content)

        time.sleep(0.5)

        response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
        response_2_content = ""
        for chunk in response2:
            print(chunk)
            response_2_content += chunk.choices[0].delta.content or ""

        print("\nresponse 1", response_1_content)
        print("\nresponse 2", response_2_content)
        assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
    except Exception as e:
        print(e)
        raise
    finally:
        # always undo the global state, so a failure here cannot leak the redis cache into other tests
        litellm.cache = None
        litellm.success_callback = []
        litellm._async_success_callback = []
"""
1 & 2 should be exactly the same
"""
# test_redis_cache_completion_stream()
def test_redis_cache_acompletion_stream():
    """
    Two identical streaming acompletion() calls with a redis cache must return the same content.

    Global litellm state (cache + success callbacks) is reset afterwards,
    whether the test passes or fails.
    """
    import asyncio

    async def stream_once(messages):
        """Make one streaming acompletion call and return the concatenated content."""
        content = ""
        response = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True)
        async for chunk in response:
            print(chunk)
            content += chunk.choices[0].delta.content or ""
        print(content)
        return content

    try:
        litellm.set_verbose = True
        random_word = generate_random_word()
        messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
        litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
        print("test for caching, streaming + completion")

        response_1_content = asyncio.run(stream_once(messages))
        time.sleep(0.5)
        print("\n\n Response 1 content: ", response_1_content, "\n\n")

        response_2_content = asyncio.run(stream_once(messages))
        print("\nresponse 1", response_1_content)
        print("\nresponse 2", response_2_content)
        assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
    except Exception as e:
        print(e)
        raise
    finally:
        # always undo the global state, so a failure here cannot leak the redis cache into other tests
        litellm.cache = None
        litellm.success_callback = []
        litellm._async_success_callback = []
# test_redis_cache_acompletion_stream()
def test_redis_cache_acompletion_stream_bedrock():
    """
    Two identical streaming acompletion() calls to bedrock with a redis cache must return the same content.

    Global litellm state (cache + success callbacks) is reset afterwards,
    whether the test passes or fails.
    """
    import asyncio

    async def stream_once(messages):
        """Make one streaming acompletion call and return the concatenated content."""
        content = ""
        response = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
        async for chunk in response:
            print(chunk)
            content += chunk.choices[0].delta.content or ""
        print(content)
        return content

    try:
        litellm.set_verbose = True
        random_word = generate_random_word()
        messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
        litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
        print("test for caching, streaming + completion")

        response_1_content = asyncio.run(stream_once(messages))
        time.sleep(0.5)
        print("\n\n Response 1 content: ", response_1_content, "\n\n")

        response_2_content = asyncio.run(stream_once(messages))
        print("\nresponse 1", response_1_content)
        print("\nresponse 2", response_2_content)
        assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
    except Exception as e:
        print(e)
        raise
    finally:
        # always undo the global state, so a failure here cannot leak the redis cache into other tests
        litellm.cache = None
        litellm.success_callback = []
        litellm._async_success_callback = []
# test_redis_cache_acompletion_stream_bedrock()
# redis cache with custom keys
def custom_get_cache_key(*args, **kwargs):
    """Build a cache key from the request's model, messages, temperature and logit_bias.

    Positional arguments are accepted but ignored. Missing keyword arguments
    contribute an empty string. The model value is concatenated as-is, the
    other three fields are passed through ``str()``.
    """
    # return key to use for your cache:
    stringified_fields = "".join(
        str(kwargs.get(field, "")) for field in ("messages", "temperature", "logit_bias")
    )
    return kwargs.get("model", "") + stringified_fields
@ -228,9 +366,50 @@ def test_custom_redis_cache_with_key():
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
pytest.fail(f"Error occurred:")
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
# test_custom_redis_cache_with_key()
def test_custom_redis_cache_params():
    """Check that ``Cache(type="redis", ...)`` accepts extra Redis ``**kwargs`` (db, ssl options).

    Only construction is exercised; the resulting Redis client is printed.
    Global cache/callback state is reset in ``finally`` whether or not
    construction succeeds.
    """
    # test if we can init redis with **kwargs
    try:
        litellm.cache = Cache(
            type="redis",
            host=os.environ['REDIS_HOST'],
            port=os.environ['REDIS_PORT'],
            password=os.environ['REDIS_PASSWORD'],
            db=0,
            ssl=True,
            ssl_certfile="./redis_user.crt",
            ssl_keyfile="./redis_user_private.key",
            ssl_ca_certs="./redis_ca.pem",
        )
        print(litellm.cache.cache.redis_client)
    except Exception as e:
        # pytest.fail's second positional parameter is `pytrace`, not a message
        # argument, so the exception has to be formatted into the reason string.
        pytest.fail(f"Error occurred: {e}")
    finally:
        litellm.cache = None
        litellm.success_callback = []
        litellm._async_success_callback = []
def test_get_cache_key():
    """Check the cache key that ``Cache.get_cache_key`` derives from completion kwargs.

    The expected key contains only model, messages, temperature and
    max_tokens; ``stream``, ``litellm_call_id`` and ``litellm_logging_obj``
    from the input must not appear in it.
    """
    from litellm.caching import Cache
    try:
        cache_instance = Cache()
        cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
        )
        assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
    except Exception as e:
        traceback.print_exc()
        # pytest.fail's second positional parameter is `pytrace`, not a message
        # argument, so the exception has to be formatted into the reason string.
        pytest.fail(f"Error occurred: {e}")
# test_get_cache_key()
# test_custom_redis_cache_params()
# def test_redis_cache_with_ttl():
# cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
# sample_model_response_object_str = """{

View file

@ -0,0 +1,74 @@
#### What this tests ####
# This tests using caching w/ litellm which requires SSL=True
import sys, os
import time
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, Router
from litellm.caching import Cache
messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]
def test_caching_v2():  # test the redis cache configured through os.environ/ references
    """Check that two identical ``completion`` calls with ``caching=True`` return the same content.

    The cache is a Redis ``Cache`` whose host, port, password and ssl values
    are given as ``os.environ/...`` reference strings. ``litellm.cache`` is
    cleared in ``finally`` so an error in either call cannot leave the cache
    enabled for later tests.
    """
    try:
        litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2")
        response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
        response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        content1 = response1['choices'][0]['message']['content']
        content2 = response2['choices'][0]['message']['content']
        if content2 != content1:
            # give the failure a reason instead of an empty exception message
            raise Exception(f"cache miss: second response differs from first ({content2!r} != {content1!r})")
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")
    finally:
        litellm.cache = None  # disable cache
# test_caching_v2()
def test_caching_router():
    """
    Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit.

    litellm.cache is cleared in ``finally`` so a failure in any step cannot
    leave the Redis cache enabled for later tests.
    """
    try:
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",  # openai model name
                "litellm_params": {  # params for litellm completion/embedding call
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE")
                },
                "tpm": 240000,
                "rpm": 1800
            }
        ]
        litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2")
        # NOTE(review): the router is constructed but the calls below go through
        # the module-level completion(), not router.completion() - confirm intent.
        router = Router(model_list=model_list,
                        routing_strategy="simple-shuffle",
                        set_verbose=False,
                        num_retries=1)  # type: ignore
        response1 = completion(model="gpt-3.5-turbo", messages=messages)
        response2 = completion(model="gpt-3.5-turbo", messages=messages)
        if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
            print(f"response1: {response1}")
            print(f"response2: {response2}")
        assert response2['choices'][0]['message']['content'] == response1['choices'][0]['message']['content']
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")
    finally:
        litellm.cache = None  # disable cache
# test_caching_router()

View file

@ -7,25 +7,33 @@ import os, io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
litellm.num_retries = 3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
def logger_fn(user_model_dict):
    """Print the dict passed to the ``logger_fn`` hook, prefixed with its name."""
    line = f"user_model_dict: {user_model_dict}"
    print(line)
@pytest.fixture(autouse=True)
def reset_callbacks():
    """Autouse fixture: start every test in this module with empty litellm callback lists."""
    print("\npytest fixture - resetting callbacks")
    for setting in ("success_callback", "_async_success_callback", "failure_callback", "callbacks"):
        # a fresh list per setting, so no two settings ever share one list object
        setattr(litellm, setting, [])
def test_completion_custom_provider_model_name():
try:
litellm.cache = None
response = completion(
model="together_ai/togethercomputer/llama-2-70b-chat",
model="together_ai/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
logger_fn=logger_fn,
)
@ -53,7 +61,7 @@ def test_completion_claude():
print(response)
print(response.usage)
print(response.usage.completion_tokens)
print(response["usage"]["completion_tokens"])
print(response["usage"]["completion_tokens"])
# print("new cost tracking")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@ -63,6 +71,7 @@ def test_completion_claude():
def test_completion_claude2_1():
try:
print("claude2.1 test request")
messages=[{'role': 'system', 'content': 'Your goal is generate a joke on the topic user gives'}, {'role': 'assistant', 'content': 'Hi, how can i assist you today?'}, {'role': 'user', 'content': 'Generate a 3 liner joke for me'}]
# test without max tokens
response = completion(
model="claude-2.1",
@ -131,6 +140,7 @@ def test_completion_gpt4_turbo():
pytest.fail(f"Error occurred: {e}")
# test_completion_gpt4_turbo()
@pytest.mark.skip(reason="this test is flaky")
def test_completion_gpt4_vision():
try:
litellm.set_verbose=True
@ -284,7 +294,7 @@ def hf_test_completion_tgi():
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# hf_test_completion_tgi()
# Run only when this file is executed directly: an unguarded module-level call
# fires the request at import time, i.e. during pytest collection.
if __name__ == "__main__":
    hf_test_completion_tgi()
# ################### Hugging Face Conversational models ########################
# def hf_test_completion_conv():
@ -442,9 +452,46 @@ def test_completion_text_openai():
pytest.fail(f"Error occurred: {e}")
# test_completion_text_openai()
def custom_callback(
    kwargs,                 # kwargs to completion
    completion_response,    # response from completion
    start_time, end_time    # start/end time
):
    """Success callback asserting that the optional params reached the completion call.

    Prints the received kwargs, then checks ``user``, ``model``, ``seed`` and
    ``temperature`` against the values the accompanying test sends. Any
    exception (missing key or failed assertion) is turned into ``pytest.fail``.
    """
    # Your custom code here
    try:
        print("LITELLM: in custom callback function")
        print("\nkwargs\n", kwargs)
        model, messages = kwargs["model"], kwargs["messages"]
        user = kwargs.get("user")

        #################################################

        summary = f"""
Model: {model},
Messages: {messages},
User: {user},
Seed: {kwargs["seed"]},
temperature: {kwargs["temperature"]},
"""
        print(summary)

        # checked in this order so the first mismatch reported stays the same
        expected_values = (
            ("user", "ishaans app"),
            ("model", "gpt-3.5-turbo-1106"),
            ("seed", 12),
            ("temperature", 0.5),
        )
        for field, wanted in expected_values:
            assert kwargs[field] == wanted
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
def test_completion_openai_with_optional_params():
# [Proxy PROD TEST] WARNING: DO NOT DELETE THIS TEST
# assert that `user` gets passed to the completion call
# Note: This tests that we actually send the optional params to the completion call
# We use custom callbacks to test this
try:
litellm.set_verbose = True
litellm.success_callback = [custom_callback]
response = completion(
model="gpt-3.5-turbo-1106",
messages=[
@ -458,11 +505,13 @@ def test_completion_openai_with_optional_params():
seed=12,
response_format={ "type": "json_object" },
logit_bias=None,
user = "ishaans app"
)
# Add any assertions here to check the response
print(response)
except litellm.Timeout as e:
pass
litellm.success_callback = [] # unset callbacks
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@ -569,7 +618,7 @@ def test_completion_azure_key_completion_arg():
os.environ.pop("AZURE_API_KEY", None)
try:
print("azure gpt-3.5 test\n\n")
litellm.set_verbose=False
litellm.set_verbose=True
## Test azure call
response = completion(
model="azure/chatgpt-v-2",
@ -654,11 +703,12 @@ def test_completion_azure():
print(response)
cost = completion_cost(completion_response=response)
assert cost > 0.0
print("Cost for azure completion request", cost)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_azure()
# pytest already collects and runs test_completion_azure; an unguarded call here
# would run it a second time at import (collection) time.
if __name__ == "__main__":
    test_completion_azure()
def test_azure_openai_ad_token():
# this tests if the azure ad token is set in the request header
@ -960,12 +1010,18 @@ def test_replicate_custom_prompt_dict():
######## Test TogetherAI ########
def test_completion_together_ai():
    """Run a multi-turn chat completion against TogetherAI and check that its cost is positive.

    Any exception, including a failed cost assertion, is reported through
    ``pytest.fail``.
    """
    model_name = "together_ai/togethercomputer/CodeLlama-13b-Instruct"
    try:
        messages = [
            {"role": "user", "content": "Who are you"},
            {"role": "assistant", "content": "I am your helpful assistant."},
            {"role": "user", "content": "Tell me a joke"},
        ]
        response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn)
        # Add any assertions here to check the response
        print(response)
        cost = completion_cost(completion_response=response)
        assert cost > 0.0
        # report the model actually called rather than a hard-coded (stale) name
        print(f"Cost for completion call {model_name}: ", f"${float(cost):.10f}")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@ -975,15 +1031,17 @@ def test_customprompt_together_ai():
try:
litellm.set_verbose = False
litellm.num_retries = 0
print("in test_customprompt_together_ai")
print(litellm.success_callback)
print(litellm._async_success_callback)
response = completion(
model="together_ai/togethercomputer/llama-2-70b-chat",
model="together_ai/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
)
print(response)
except litellm.exceptions.Timeout as e:
print(f"Timeout Error")
litellm.num_retries = 3 # reset retries
pass
except Exception as e:
print(f"ERROR TYPE {type(e)}")
@ -996,7 +1054,7 @@ def test_completion_sagemaker():
print("testing sagemaker")
litellm.set_verbose=True
response = completion(
model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=messages,
temperature=0.2,
max_tokens=80,
@ -1009,18 +1067,40 @@ def test_completion_sagemaker():
def test_completion_chat_sagemaker():
    """Stream a chat completion from the SageMaker endpoint and require non-empty text.

    The streamed chunks' delta contents are concatenated; an empty result or
    any exception fails the test through ``pytest.fail``.
    """
    try:
        print("testing sagemaker")
        messages = [{"role": "user", "content": "Hey, how's it going?"}]
        litellm.set_verbose = True
        response = completion(
            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
            messages=messages,
            max_tokens=100,
            temperature=0.7,
            stream=True,
        )
        print(response)
        # Add any assertions here to check the response
        # a chunk's delta content may be None, hence the `or ""`
        complete_response = "".join(chunk.choices[0].delta.content or "" for chunk in response)
        print(f"complete_response: {complete_response}")
        assert len(complete_response) > 0
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_completion_chat_sagemaker()
def test_completion_chat_sagemaker_mistral():
    """Send one chat message to the SageMaker Mistral-7B-instruct endpoint and print the reply.

    Passes as long as the call does not raise; any exception is reported
    through ``pytest.fail``.
    """
    chat = [{"role": "user", "content": "Hey, how's it going?"}]
    request = {
        "model": "sagemaker/jumpstart-dft-hf-llm-mistral-7b-instruct",
        "messages": chat,
        "max_tokens": 100,
    }
    try:
        reply = completion(**request)
        # Add any assertions here to check the response
        print(reply)
    except Exception as err:
        pytest.fail(f"An error occurred: {str(err)}")
# test_completion_chat_sagemaker_mistral()
def test_completion_bedrock_titan():
try:
response = completion(
@ -1251,43 +1331,6 @@ def test_completion_bedrock_claude_completion_auth():
# test_completion_custom_api_base()
# def test_vertex_ai():
# test_models = ["codechat-bison"] + litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
# # test_models = ["chat-bison"]
# for model in test_models:
# try:
# if model in ["code-gecko@001", "code-gecko@latest"]:
# # our account does not have access to this model
# continue
# print("making request", model)
# response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}])
# print(response)
# print(response.usage.completion_tokens)
# print(response['usage']['completion_tokens'])
# assert type(response.choices[0].message.content) == str
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_vertex_ai()
# def test_vertex_ai_stream():
# litellm.set_verbose=False
# test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
# for model in test_models:
# try:
# if model in ["code-gecko@001", "code-gecko@latest"]:
# # our account does not have access to this model
# continue
# print("making request", model)
# response = completion(model=model, messages=[{"role": "user", "content": "write 100 line code code for saying hi"}], stream=True)
# for chunk in response:
# print(chunk)
# # pass
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_vertex_ai_stream()
def test_completion_with_fallbacks():
print(f"RUNNING TEST COMPLETION WITH FALLBACKS - test_completion_with_fallbacks")
fallbacks = ["gpt-3.5-turbo", "gpt-3.5-turbo", "command-nightly"]
@ -1337,7 +1380,7 @@ def test_azure_cloudflare_api():
traceback.print_exc()
pass
test_azure_cloudflare_api()
# test_azure_cloudflare_api()
def test_completion_anyscale_2():
try:
@ -1567,7 +1610,7 @@ def test_completion_together_ai_stream():
messages = [{ "content": user_message,"role": "user"}]
try:
response = completion(
model="together_ai/togethercomputer/llama-2-70b-chat",
model="together_ai/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages, stream=True,
max_tokens=5
)

View file

@ -0,0 +1,14 @@
from litellm.proxy._types import UserAPIKeyAuth
from fastapi import Request
from dotenv import load_dotenv
import os
load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    """Custom auth hook: accept only the key ``<PROXY_MASTER_KEY>-1234``.

    Args:
        request: the incoming FastAPI request (not inspected).
        api_key: the key presented by the caller.

    Returns:
        A ``UserAPIKeyAuth`` built from ``api_key`` when it matches.

    Raises:
        Exception: if the key does not match, is not a string, or
            ``PROXY_MASTER_KEY`` is unset/empty.
    """
    import hmac  # function-scope import keeps this block self-contained

    master_key = os.getenv("PROXY_MASTER_KEY")
    # Without this guard an unset variable formats as "None-1234", and that
    # literal string would authenticate.
    if not master_key:
        raise Exception("PROXY_MASTER_KEY is not set")

    expected_key = f"{master_key}-1234"
    try:
        # Compare as bytes: compare_digest rejects non-ASCII str, and the
        # comparison time does not depend on where the inputs differ.
        authorized = hmac.compare_digest(api_key.encode("utf-8"), expected_key.encode("utf-8"))
    except Exception as err:
        raise Exception("invalid api key") from err

    # The presented key is deliberately not printed: it is a credential.
    if not authorized:
        raise Exception("invalid api key")
    return UserAPIKeyAuth(api_key=api_key)

View file

@ -0,0 +1,113 @@
from litellm.integrations.custom_logger import CustomLogger
import inspect
import litellm
class testCustomCallbackProxy(CustomLogger):
    """Custom logger that records which callback hooks fired and with what arguments.

    Each hook prints a marker; the success/failure hooks also set boolean
    flags and keep the received kwargs/response on the instance so a test can
    inspect them afterwards.
    """

    def __init__(self):
        self.success: bool = False  # type: ignore
        self.failure: bool = False  # type: ignore
        self.async_success: bool = False  # type: ignore
        self.async_success_embedding: bool = False  # type: ignore
        self.async_failure: bool = False  # type: ignore
        self.async_failure_embedding: bool = False  # type: ignore

        self.async_completion_kwargs = None  # type: ignore
        self.async_embedding_kwargs = None  # type: ignore
        self.async_embedding_response = None  # type: ignore

        self.async_completion_kwargs_fail = None  # type: ignore
        self.async_embedding_kwargs_fail = None  # type: ignore

        self.streaming_response_obj = None  # type: ignore
        blue_color_code = "\033[94m"
        reset_color_code = "\033[0m"
        print(f"{blue_color_code}Initialized LiteLLM custom logger")
        try:
            print(f"Logger Initialized with following methods:")
            methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]

            # Pretty print the methods
            for method in methods:
                print(f" - {method}")
            print(f"{reset_color_code}")
        except Exception:
            # listing the methods is purely informational; never fail construction over it
            pass

    def log_pre_api_call(self, model, messages, kwargs):
        print(f"Pre-API Call")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        print(f"Post-API Call")

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Stream")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Success")
        self.success = True

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Failure")
        self.failure = True

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Record an async success, keep kwargs/response, and print a usage/cost summary."""
        print(f"On Async success")
        self.async_success = True
        print("Value of async success: ", self.async_success)
        print("\n kwargs: ", kwargs)
        if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada":
            print("Got an embedding model", kwargs.get("model"))
            print("Setting embedding success to True")
            self.async_success_embedding = True
            print("Value of async success embedding: ", self.async_success_embedding)
            self.async_embedding_kwargs = kwargs
            self.async_embedding_response = response_obj
        if kwargs.get("stream") == True:
            self.streaming_response_obj = response_obj

        self.async_completion_kwargs = kwargs

        model = kwargs.get("model", None)
        messages = kwargs.get("messages", None)
        user = kwargs.get("user", None)

        # Access litellm_params passed to litellm.completion(), example access `metadata`
        litellm_params = kwargs.get("litellm_params", {})
        metadata = litellm_params.get("metadata", {})  # headers passed to LiteLLM proxy, can be found here

        # Calculate cost using litellm.completion_cost()
        cost = litellm.completion_cost(completion_response=response_obj)
        response = response_obj
        # tokens used in response
        usage = response_obj["usage"]

        # report this instance's own state rather than reaching for a module-level name
        print("\n\n in custom callback vars my custom logger, ", vars(self))

        print(
            f"""
Model: {model},
Messages: {messages},
User: {user},
Usage: {usage},
Cost: {cost},
Response: {response}
Proxy Metadata: {metadata}
"""
        )
        return

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Record an async failure and keep the kwargs it was called with."""
        print(f"On Async Failure")
        self.async_failure = True
        print("Value of async failure: ", self.async_failure)
        print("\n kwargs: ", kwargs)
        if kwargs.get("model") == "text-embedding-ada-002":
            self.async_failure_embedding = True
            self.async_embedding_kwargs_fail = kwargs

        self.async_completion_kwargs_fail = kwargs


# module-level instance; the proxy config refers to it as custom_callbacks.my_custom_logger
my_custom_logger = testCustomCallbackProxy()

View file

@ -0,0 +1,28 @@
general_settings:
database_url: os.environ/PROXY_DATABASE_URL
master_key: os.environ/PROXY_MASTER_KEY
litellm_settings:
drop_params: true
success_callback: ["langfuse"]
model_list:
- litellm_params:
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key: os.environ/AZURE_EUROPE_API_KEY
model: azure/gpt-35-turbo
model_name: azure-model
- litellm_params:
api_base: https://my-endpoint-canada-berri992.openai.azure.com
api_key: os.environ/AZURE_CANADA_API_KEY
model: azure/gpt-35-turbo
model_name: azure-model
- litellm_params:
api_base: https://openai-france-1234.openai.azure.com
api_key: os.environ/AZURE_FRANCE_API_KEY
model: azure/gpt-turbo
model_name: azure-model
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
model_name: test_openai_models

View file

@ -0,0 +1,11 @@
model_list:
- model_name: "openai-model"
litellm_params:
model: "gpt-3.5-turbo"
litellm_settings:
drop_params: True
set_verbose: True
general_settings:
custom_auth: custom_auth.user_api_key_auth

View file

@ -0,0 +1,75 @@
model_list:
- litellm_params:
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key: os.environ/AZURE_EUROPE_API_KEY
model: azure/gpt-35-turbo
model_name: azure-model
- litellm_params:
api_base: https://my-endpoint-canada-berri992.openai.azure.com
api_key: os.environ/AZURE_CANADA_API_KEY
model: azure/gpt-35-turbo
model_name: azure-model
- litellm_params:
api_base: https://openai-france-1234.openai.azure.com
api_key: os.environ/AZURE_FRANCE_API_KEY
model: azure/gpt-turbo
model_name: azure-model
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 56f1bd94-3b54-4b67-9ea2-7c70e9a3a709
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 4d1ee26c-abca-450c-8744-8e87fd6755e9
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 00e19c0f-b63d-42bb-88e9-016fb0c60764
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 79fc75bf-8e1b-47d5-8d24-9365a854af03
model_name: test_openai_models
- litellm_params:
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: 2023-07-01-preview
model: azure/azure-embedding-model
model_name: azure-embedding-model
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 55848c55-4162-40f9-a6e2-9a722b9ef404
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 34339b1e-e030-4bcc-a531-c48559f10ce4
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: f6f74e14-ac64-4403-9365-319e584dcdc5
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 9b1ef341-322c-410a-8992-903987fef439
model_name: test_openai_models

View file

@ -0,0 +1,26 @@
model_list:
- model_name: Azure OpenAI GPT-4 Canada
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
mode: chat
input_cost_per_token: 0.0002
id: gm
- model_name: azure-embedding-model
litellm_params:
model: azure/azure-embedding-model
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
mode: embedding
input_cost_per_token: 0.002
id: hello
litellm_settings:
drop_params: True
set_verbose: True
callbacks: custom_callbacks.my_custom_logger

View file

@ -0,0 +1,579 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional, Literal, List, Union
from litellm import completion, embedding
import litellm
from litellm.integrations.custom_logger import CustomLogger
# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
# Test models
## 1. OpenAI
## 2. Azure OpenAI
## 3. Non-OpenAI/Azure - e.g. Bedrock
# Test interfaces
## 1. litellm.completion() + litellm.embeddings()
## refer to test_custom_callback_input_router.py for the router + proxy tests
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
"""
The set of expected inputs to a custom handler for a
"""
# Class variables or attributes
def __init__(self):
self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
def log_pre_api_call(self, model, messages, kwargs):
try:
self.states.append("sync_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("post_api_call")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert end_time == None
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_stream")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_pre_api_call(self, model, messages, kwargs):
try:
self.states.append("async_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list) and isinstance(messages[0], dict)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """
    Async failure hook: record the state and type-check every argument.

    ``response_obj`` must be ``None`` for a failure. A failed check is not
    raised; its traceback is printed and appended to ``self.errors``.
    """
    try:
        self.states.append("async_failure")
        ## START TIME
        assert isinstance(start_time, datetime)
        ## END TIME
        assert isinstance(end_time, datetime)
        ## RESPONSE OBJECT
        assert response_obj is None
        ## KWARGS
        assert isinstance(kwargs['model'], str)
        assert isinstance(kwargs['messages'], list)
        assert isinstance(kwargs['optional_params'], dict)
        assert isinstance(kwargs['litellm_params'], dict)
        assert isinstance(kwargs['start_time'], (datetime, type(None)))
        assert isinstance(kwargs['stream'], bool)
        assert isinstance(kwargs['user'], (str, type(None)))
        assert isinstance(kwargs['input'], (list, str, dict))
        assert isinstance(kwargs['api_key'], (str, type(None)))
        assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] is None
        assert isinstance(kwargs['additional_args'], (dict, type(None)))
        assert isinstance(kwargs['log_event_type'], str)
    except Exception:
        # `except Exception` rather than a bare `except`, so task cancellation
        # and interpreter-exit signals still propagate out of the callback.
        print(f"Assertion Error: {traceback.format_exc()}")
        self.errors.append(traceback.format_exc())
# COMPLETION
## Test OpenAI + sync
def test_chat_openai_stream():
    """
    Run sync OpenAI completions with the custom handler attached.

    Covers a plain call, a streaming call and a streaming call with an
    invalid API key, then asserts the handler recorded no check failures.
    """
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(model="gpt-3.5-turbo",
                                      messages=[{
                                          "role": "user",
                                          "content": "Hi 👋 - i'm sync openai"
                                      }])
        ## test streaming
        response = litellm.completion(model="gpt-3.5-turbo",
                                      messages=[{
                                          "role": "user",
                                          "content": "Hi 👋 - i'm openai"
                                      }],
                                      stream=True)
        for chunk in response:
            continue
        ## test failure callback
        try:
            response = litellm.completion(model="gpt-3.5-turbo",
                                          messages=[{
                                              "role": "user",
                                              "content": "Hi 👋 - i'm openai"
                                          }],
                                          api_key="my-bad-key",
                                          stream=True)
            for chunk in response:
                continue
        except Exception:
            # the bad key is expected to fail; only the callbacks matter here
            pass
        time.sleep(1)  # presumably gives the logging callbacks time to finish
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always detach the handler so a failure here cannot leak into other tests
        litellm.callbacks = []
# test_chat_openai_stream()
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_chat_openai_stream():
    """
    Run async OpenAI completions with the custom handler attached.

    Covers a plain call, a streaming call and a streaming call with an
    invalid API key, then asserts the handler recorded no check failures.
    """
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(model="gpt-3.5-turbo",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm openai"
                                             }])
        ## test streaming
        response = await litellm.acompletion(model="gpt-3.5-turbo",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm openai"
                                             }],
                                             stream=True)
        async for chunk in response:
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(model="gpt-3.5-turbo",
                                                 messages=[{
                                                     "role": "user",
                                                     "content": "Hi 👋 - i'm openai"
                                                 }],
                                                 api_key="my-bad-key",
                                                 stream=True)
            async for chunk in response:
                continue
        except Exception:
            # the bad key is expected to fail; only the callbacks matter here
            pass
        # Yield to the event loop while waiting: a blocking time.sleep() would
        # stall the loop and keep pending async callbacks from running.
        await asyncio.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always detach the handler so a failure here cannot leak into other tests
        litellm.callbacks = []
# asyncio.run(test_async_chat_openai_stream())
## Test Azure + sync
def test_chat_azure_stream():
    """
    Run sync Azure completions with the custom handler attached.

    Covers a plain call, a streaming call and a streaming call with an
    invalid API key, then asserts the handler recorded no check failures.
    """
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(model="azure/chatgpt-v-2",
                                      messages=[{
                                          "role": "user",
                                          "content": "Hi 👋 - i'm sync azure"
                                      }])
        # test streaming
        response = litellm.completion(model="azure/chatgpt-v-2",
                                      messages=[{
                                          "role": "user",
                                          "content": "Hi 👋 - i'm sync azure"
                                      }],
                                      stream=True)
        for chunk in response:
            continue
        # test failure callback
        try:
            response = litellm.completion(model="azure/chatgpt-v-2",
                                          messages=[{
                                              "role": "user",
                                              "content": "Hi 👋 - i'm sync azure"
                                          }],
                                          api_key="my-bad-key",
                                          stream=True)
            for chunk in response:
                continue
        except Exception:
            # the bad key is expected to fail; only the callbacks matter here
            pass
        time.sleep(1)  # presumably gives the logging callbacks time to finish
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always detach the handler so a failure here cannot leak into other tests
        litellm.callbacks = []
# test_chat_azure_stream()
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_chat_azure_stream():
    """
    Run async Azure completions with the custom handler attached.

    Covers a plain call, a streaming call and a streaming call with an
    invalid API key, then asserts the handler recorded no check failures.
    """
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm async azure"
                                             }])
        ## test streaming
        response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm async azure"
                                             }],
                                             stream=True)
        async for chunk in response:
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                                 messages=[{
                                                     "role": "user",
                                                     "content": "Hi 👋 - i'm async azure"
                                                 }],
                                                 api_key="my-bad-key",
                                                 stream=True)
            async for chunk in response:
                continue
        except Exception:
            # the bad key is expected to fail; only the callbacks matter here
            pass
        await asyncio.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always detach the handler so a failure here cannot leak into other tests
        litellm.callbacks = []
# asyncio.run(test_async_chat_azure_stream())
## Test Bedrock + sync
def test_chat_bedrock_stream():
    """
    Run sync Bedrock completions with the custom handler attached.

    Covers a plain call, a streaming call and a streaming call with an
    invalid AWS region, then asserts the handler recorded no check failures.
    """
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(model="bedrock/anthropic.claude-v1",
                                      messages=[{
                                          "role": "user",
                                          "content": "Hi 👋 - i'm sync bedrock"
                                      }])
        # test streaming
        response = litellm.completion(model="bedrock/anthropic.claude-v1",
                                      messages=[{
                                          "role": "user",
                                          "content": "Hi 👋 - i'm sync bedrock"
                                      }],
                                      stream=True)
        for chunk in response:
            continue
        # test failure callback
        try:
            response = litellm.completion(model="bedrock/anthropic.claude-v1",
                                          messages=[{
                                              "role": "user",
                                              "content": "Hi 👋 - i'm sync bedrock"
                                          }],
                                          aws_region_name="my-bad-region",
                                          stream=True)
            for chunk in response:
                continue
        except Exception:
            # the bad region is expected to fail; only the callbacks matter here
            pass
        time.sleep(1)  # presumably gives the logging callbacks time to finish
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always detach the handler so a failure here cannot leak into other tests
        litellm.callbacks = []
# test_chat_bedrock_stream()
## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_chat_bedrock_stream():
    """
    Run async Bedrock completions with the custom handler attached.

    Covers a plain call, a streaming call and a streaming call with an
    invalid AWS region, then asserts the handler recorded no check failures.
    """
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm async bedrock"
                                             }])
        # test streaming
        response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm async bedrock"
                                             }],
                                             stream=True)
        print(f"response: {response}")
        async for chunk in response:
            print(f"chunk: {chunk}")
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
                                                 messages=[{
                                                     "role": "user",
                                                     "content": "Hi 👋 - i'm async bedrock"
                                                 }],
                                                 aws_region_name="my-bad-key",
                                                 stream=True)
            async for chunk in response:
                continue
        except Exception:
            # the bad region is expected to fail; only the callbacks matter here
            pass
        # Yield to the event loop while waiting: a blocking time.sleep() would
        # stall the loop and keep pending async callbacks from running.
        await asyncio.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always detach the handler so a failure here cannot leak into other tests
        litellm.callbacks = []
# asyncio.run(test_async_chat_bedrock_stream())
# EMBEDDING
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_embedding_openai():
    """
    Run async OpenAI embedding calls with the custom handler attached.

    A successful call and a call with an invalid API key each use their own
    handler; both must record exactly three states and no check failures.
    """
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        # Use the OpenAI embedding model here; the Azure deployment is already
        # covered by test_async_embedding_azure.
        response = await litellm.aembedding(model="text-embedding-ada-002",
                                            input=["good morning from litellm"])
        await asyncio.sleep(1)
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3 # pre, post, success
        # test failure callback
        litellm.callbacks = [customHandler_failure]
        try:
            response = await litellm.aembedding(model="text-embedding-ada-002",
                                                input=["good morning from litellm"],
                                                api_key="my-bad-key")
        except Exception:
            # the bad key is expected to fail; only the callbacks matter here
            pass
        await asyncio.sleep(1)
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3 # pre, post, failure
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always detach the handlers so they cannot leak into other tests
        litellm.callbacks = []
# asyncio.run(test_async_embedding_openai())
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_embedding_azure():
    """
    Run async Azure embedding calls with the custom handler attached.

    A successful call and a call with an invalid API key each use their own
    handler; both must record exactly three states and no check failures.
    """
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        response = await litellm.aembedding(model="azure/azure-embedding-model",
                                            input=["good morning from litellm"])
        await asyncio.sleep(1)
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3 # pre, post, success
        # test failure callback
        litellm.callbacks = [customHandler_failure]
        try:
            response = await litellm.aembedding(model="azure/azure-embedding-model",
                                                input=["good morning from litellm"],
                                                api_key="my-bad-key")
        except Exception:
            # the bad key is expected to fail; only the callbacks matter here
            pass
        await asyncio.sleep(1)
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3 # pre, post, failure
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always detach the handlers so they cannot leak into other tests
        litellm.callbacks = []
# asyncio.run(test_async_embedding_azure())
## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_embedding_bedrock():
    """
    Run async Bedrock embedding calls with the custom handler attached.

    A successful call and a call with an invalid AWS region each use their
    own handler; both must record exactly three states and no check failures.
    """
    previous_verbose = litellm.set_verbose
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        litellm.set_verbose = True
        response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
                                            input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2")
        await asyncio.sleep(1)
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3 # pre, post, success
        # test failure callback
        litellm.callbacks = [customHandler_failure]
        try:
            response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
                                                input=["good morning from litellm"],
                                                aws_region_name="my-bad-region")
        except Exception:
            # the bad region is expected to fail; only the callbacks matter here
            pass
        await asyncio.sleep(1)
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3 # pre, post, failure
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # restore the global litellm state this test changed
        litellm.callbacks = []
        litellm.set_verbose = previous_verbose
# asyncio.run(test_async_embedding_bedrock())

View file

@ -0,0 +1,437 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional, Literal, List
from litellm import Router
import litellm
from litellm.integrations.custom_logger import CustomLogger
# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## fallbacks
## retries
# Test cases
## 1. Simple Azure OpenAI acompletion + streaming call
## 2. Simple Azure OpenAI aembedding call
## 3. Azure OpenAI acompletion + streaming call with retries
## 4. Azure OpenAI aembedding call with retries
## 5. Azure OpenAI acompletion + streaming call with fallbacks
## 6. Azure OpenAI aembedding call with fallbacks
# Test interfaces
## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """
    Callback handler that type-checks the inputs litellm passes to each logging hook.

    Every implemented hook appends its name to ``self.states`` and then runs a
    series of ``isinstance`` assertions on its arguments. A failed check is
    never raised to the caller: its traceback is printed and appended to
    ``self.errors`` so a test can assert on both lists once the call is done.
    """
    # Class variables or attributes
    def __init__(self):
        self.errors = []
        self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []

    def _record_error(self):
        """Print and store the traceback of the exception currently being handled."""
        failure_trace = traceback.format_exc()
        print(f"Assertion Error: {failure_trace}")
        self.errors.append(failure_trace)

    @staticmethod
    def _assert_call_kwargs(kwargs, check_first_message=False):
        """Check the request-level kwargs; optionally require messages[0] to be a dict."""
        assert isinstance(kwargs['model'], str)
        assert isinstance(kwargs['messages'], list)
        if check_first_message:
            assert isinstance(kwargs['messages'][0], dict)
        assert isinstance(kwargs['optional_params'], dict)
        assert isinstance(kwargs['litellm_params'], dict)
        assert isinstance(kwargs['start_time'], (datetime, type(None)))
        assert isinstance(kwargs['stream'], bool)
        assert isinstance(kwargs['user'], (str, type(None)))

    @staticmethod
    def _assert_result_kwargs(kwargs, check_first_input=False, allow_async=False, allow_none=False):
        """
        Check the kwargs that are only present once the API call was attempted.

        check_first_input: a list ``input`` must hold a dict as first element.
        allow_async: ``original_response`` may be a coroutine or async generator.
        allow_none: ``original_response`` may be ``None``.
        """
        call_input = kwargs['input']
        if check_first_input:
            assert (isinstance(call_input, list) and isinstance(call_input[0], dict)) or isinstance(call_input, (dict, str))
        else:
            assert isinstance(call_input, (list, dict, str))
        assert isinstance(kwargs['api_key'], (str, type(None)))
        original_response = kwargs['original_response']
        assert (
            isinstance(original_response, (str, litellm.CustomStreamWrapper))
            or (allow_async and (inspect.isasyncgen(original_response) or inspect.iscoroutine(original_response)))
            or (allow_none and original_response is None)
        )
        assert isinstance(kwargs['additional_args'], (dict, type(None)))
        assert isinstance(kwargs['log_event_type'], str)

    @staticmethod
    def _assert_router_kwargs(kwargs):
        """Check the litellm_params entries this test expects on Router calls."""
        router_params = kwargs["litellm_params"]
        assert isinstance(router_params["metadata"], dict)
        assert isinstance(router_params["metadata"]["model_group"], str)
        assert isinstance(router_params["metadata"]["deployment"], str)
        assert isinstance(router_params["model_info"], dict)
        assert isinstance(router_params["model_info"]["id"], str)
        assert isinstance(router_params["proxy_server_request"], (str, type(None)))
        assert isinstance(router_params["preset_cache_key"], (str, type(None)))
        assert isinstance(router_params["stream_response"], dict)

    def log_pre_api_call(self, model, messages, kwargs):
        """Sync pre-API-call hook: check model, messages, request and router kwargs."""
        try:
            print(f'received kwargs in pre-input: {kwargs}')
            self.states.append("sync_pre_api_call")
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list)
            ## KWARGS
            self._assert_call_kwargs(kwargs)
            ### ROUTER-SPECIFIC KWARGS
            self._assert_router_kwargs(kwargs)
        except Exception:
            self._record_error()

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        """Post-API-call hook: no end time or response object is expected yet."""
        try:
            self.states.append("post_api_call")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert end_time is None
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            self._assert_call_kwargs(kwargs)
            self._assert_result_kwargs(kwargs, allow_async=True)
            ### ROUTER-SPECIFIC KWARGS
            self._assert_router_kwargs(kwargs)
        except Exception:
            self._record_error()

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        """Async streaming hook: each event must carry a ModelResponse."""
        try:
            self.states.append("async_stream")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            self._assert_call_kwargs(kwargs, check_first_message=True)
            self._assert_result_kwargs(kwargs, check_first_input=True, allow_async=True)
        except Exception:
            # `except Exception` (not bare) so task cancellation still propagates
            self._record_error()

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Sync success hook: the response must be a ModelResponse."""
        try:
            self.states.append("sync_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            self._assert_call_kwargs(kwargs, check_first_message=True)
            self._assert_result_kwargs(kwargs, check_first_input=True)
        except Exception:
            self._record_error()

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Sync failure hook: no response object; original_response may be None."""
        try:
            self.states.append("sync_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            self._assert_call_kwargs(kwargs, check_first_message=True)
            self._assert_result_kwargs(kwargs, check_first_input=True, allow_none=True)
        except Exception:
            self._record_error()

    async def async_log_pre_api_call(self, model, messages, kwargs):
        """
        No-op.
        Not implemented yet.
        """

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Async success hook: accepts completion and embedding responses."""
        try:
            self.states.append("async_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
            ## KWARGS
            self._assert_call_kwargs(kwargs)
            self._assert_result_kwargs(kwargs, allow_async=True)
            ### ROUTER-SPECIFIC KWARGS
            self._assert_router_kwargs(kwargs)
        except Exception:
            # `except Exception` (not bare) so task cancellation still propagates
            self._record_error()

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Async failure hook: no response object; original_response may be None."""
        try:
            print(f"received original response: {kwargs['original_response']}")
            self.states.append("async_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            self._assert_call_kwargs(kwargs)
            self._assert_result_kwargs(kwargs, allow_async=True, allow_none=True)
        except Exception:
            # `except Exception` (not bare) so task cancellation still propagates
            self._record_error()
# Simple Azure OpenAI call
## COMPLETION
@pytest.mark.asyncio
async def test_async_chat_azure():
    """
    Check the handler's states for Router completions on an Azure deployment.

    Three separate routers and handlers are used: a plain completion, a
    streaming completion, and a completion against a deployment configured
    with an invalid API key.
    """
    def build_model_list(api_key):
        # single Azure deployment exposed under the "gpt-3.5-turbo" group
        return [
            {
                "model_name": "gpt-3.5-turbo", # openai model name
                "litellm_params": { # params for litellm completion/embedding call
                    "model": "azure/chatgpt-v-2",
                    "api_key": api_key,
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE")
                },
                "tpm": 240000,
                "rpm": 1800
            },
        ]

    try:
        customHandler_completion_azure_router = CompletionCustomHandler()
        customHandler_streaming_azure_router = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_completion_azure_router]
        model_list = build_model_list(os.getenv("AZURE_API_KEY"))
        router = Router(model_list=model_list) # type: ignore
        response = await router.acompletion(model="gpt-3.5-turbo",
                                            messages=[{
                                                "role": "user",
                                                "content": "Hi 👋 - i'm openai"
                                            }])
        await asyncio.sleep(2)
        assert len(customHandler_completion_azure_router.errors) == 0
        assert len(customHandler_completion_azure_router.states) == 3 # pre, post, success
        # streaming
        litellm.callbacks = [customHandler_streaming_azure_router]
        router2 = Router(model_list=model_list) # type: ignore
        response = await router2.acompletion(model="gpt-3.5-turbo",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm openai"
                                             }],
                                             stream=True)
        async for chunk in response:
            print(f"async azure router chunk: {chunk}")
            continue
        await asyncio.sleep(1)
        print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
        assert len(customHandler_streaming_azure_router.errors) == 0
        assert len(customHandler_streaming_azure_router.states) >= 4 # pre, post, stream (multiple times), success
        # failure
        model_list = build_model_list("my-bad-key")
        litellm.callbacks = [customHandler_failure]
        router3 = Router(model_list=model_list) # type: ignore
        try:
            response = await router3.acompletion(model="gpt-3.5-turbo",
                                                 messages=[{
                                                     "role": "user",
                                                     "content": "Hi 👋 - i'm openai"
                                                 }])
            print(f"response in router3 acompletion: {response}")
        except Exception:
            # the bad key is expected to fail; only the callbacks matter here
            pass
        await asyncio.sleep(1)
        print(f"customHandler.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3 # pre, post, failure
        assert "async_failure" in customHandler_failure.states
    except Exception as e:
        print(f"Assertion Error: {traceback.format_exc()}")
        pytest.fail(f"An exception occurred - {str(e)}")
    finally:
        # always detach the handlers so they cannot leak into other tests
        litellm.callbacks = []
# asyncio.run(test_async_chat_azure())
## EMBEDDING
@pytest.mark.asyncio
async def test_async_embedding_azure():
    """
    Check the handler's states for Router embeddings on an Azure deployment.

    One router uses the real API key and must log a success; a second router
    uses an invalid key and must log an ``async_failure``.
    """
    def build_model_list(api_key):
        # single Azure embedding deployment
        return [
            {
                "model_name": "azure-embedding-model", # openai model name
                "litellm_params": { # params for litellm completion/embedding call
                    "model": "azure/azure-embedding-model",
                    "api_key": api_key,
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE")
                },
                "tpm": 240000,
                "rpm": 1800
            },
        ]

    try:
        customHandler = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        model_list = build_model_list(os.getenv("AZURE_API_KEY"))
        router = Router(model_list=model_list) # type: ignore
        response = await router.aembedding(model="azure-embedding-model",
                                           input=["hello from litellm!"])
        await asyncio.sleep(2)
        assert len(customHandler.errors) == 0
        assert len(customHandler.states) == 3 # pre, post, success
        # failure
        model_list = build_model_list("my-bad-key")
        litellm.callbacks = [customHandler_failure]
        router3 = Router(model_list=model_list) # type: ignore
        try:
            response = await router3.aembedding(model="azure-embedding-model",
                                                input=["hello from litellm!"])
            print(f"response in router3 aembedding: {response}")
        except Exception:
            # the bad key is expected to fail; only the callbacks matter here
            pass
        await asyncio.sleep(1)
        print(f"customHandler.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3 # pre, post, failure
        assert "async_failure" in customHandler_failure.states
    except Exception as e:
        print(f"Assertion Error: {traceback.format_exc()}")
        pytest.fail(f"An exception occurred - {str(e)}")
    finally:
        # always detach the handlers so they cannot leak into other tests
        litellm.callbacks = []
# asyncio.run(test_async_embedding_azure())
# Azure OpenAI call w/ Fallbacks
## COMPLETION
@pytest.mark.asyncio
async def test_async_chat_azure_with_fallbacks():
    """
    Check the handler's states when a Router call falls back to another model.

    The primary "gpt-3.5-turbo" group is configured with an invalid Azure key,
    and the router is told to fall back to "gpt-3.5-turbo-16k".
    """
    try:
        customHandler_fallbacks = CompletionCustomHandler()
        litellm.callbacks = [customHandler_fallbacks]
        # with fallbacks
        model_list = [
            {
                "model_name": "gpt-3.5-turbo", # openai model name
                "litellm_params": { # params for litellm completion/embedding call
                    "model": "azure/chatgpt-v-2",
                    "api_key": "my-bad-key",
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE")
                },
                "tpm": 240000,
                "rpm": 1800
            },
            {
                "model_name": "gpt-3.5-turbo-16k",
                "litellm_params": {
                    "model": "gpt-3.5-turbo-16k",
                },
                "tpm": 240000,
                "rpm": 1800
            }
        ]
        router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore
        response = await router.acompletion(model="gpt-3.5-turbo",
                                            messages=[{
                                                "role": "user",
                                                "content": "Hi 👋 - i'm openai"
                                            }])
        await asyncio.sleep(2)
        print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}")
        assert len(customHandler_fallbacks.errors) == 0
        assert len(customHandler_fallbacks.states) == 6 # pre, post, failure, pre, post, success
    except Exception as e:
        print(f"Assertion Error: {traceback.format_exc()}")
        pytest.fail(f"An exception occurred - {str(e)}")
    finally:
        # always detach the handler so a failure here cannot leak into other tests
        litellm.callbacks = []
# asyncio.run(test_async_chat_azure_with_fallbacks())

View file

@ -1,5 +1,5 @@
### What this tests ####
import sys, os, time
import sys, os, time, inspect, asyncio, traceback
import pytest
sys.path.insert(0, os.path.abspath('../..'))
@ -8,8 +8,24 @@ import litellm
from litellm.integrations.custom_logger import CustomLogger
class MyCustomHandler(CustomLogger):
success: bool = False
failure: bool = False
complete_streaming_response_in_callback = ""
# Initialise every flag and capture slot that the hooks of this handler write to.
def __init__(self):
# Set to True by the sync success / failure hooks.
self.success: bool = False # type: ignore
self.failure: bool = False # type: ignore
# Set to True by the async success / failure hooks; the *_embedding flags
# are only set when the logged model is "text-embedding-ada-002".
self.async_success: bool = False # type: ignore
self.async_success_embedding: bool = False # type: ignore
self.async_failure: bool = False # type: ignore
self.async_failure_embedding: bool = False # type: ignore
# kwargs / response objects captured by the async success hook.
self.async_completion_kwargs = None # type: ignore
self.async_embedding_kwargs = None # type: ignore
self.async_embedding_response = None # type: ignore
# kwargs captured by the async failure hook.
self.async_completion_kwargs_fail = None # type: ignore
self.async_embedding_kwargs_fail = None # type: ignore
# Response objects captured for streaming calls (async hook / sync hook).
self.stream_collected_response = None # type: ignore
self.sync_stream_collected_response = None # type: ignore
def log_pre_api_call(self, model, messages, kwargs):
print(f"Pre-API Call")
@ -23,29 +39,78 @@ class MyCustomHandler(CustomLogger):
def log_success_event(self, kwargs, response_obj, start_time, end_time):
    """Sync success hook: flag the success and keep the response of streaming calls."""
    print("On Success")
    self.success = True
    is_streaming = kwargs.get("stream") == True
    if is_streaming:
        self.sync_stream_collected_response = response_obj
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """Sync failure hook: flag that a failure was logged."""
    print("On Failure")
    self.failure = True
# def test_chat_openai():
# try:
# customHandler = MyCustomHandler()
# litellm.callbacks = [customHandler]
# response = completion(model="gpt-3.5-turbo",
# messages=[{
# "role": "user",
# "content": "Hi 👋 - i'm openai"
# }],
# stream=True)
# time.sleep(1)
# assert customHandler.success == True
# except Exception as e:
# pytest.fail(f"An error occurred - {str(e)}")
# pass
# Async success hook: flags the success and captures kwargs / response objects
# on the handler so the tests can inspect them afterwards.
# NOTE(review): indentation is lost in this view, so it is unclear which of the
# assignments below sit inside the two `if` blocks — confirm against the file.
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async success")
self.async_success = True
# Embedding calls are recognised by the hard-coded model name.
if kwargs.get("model") == "text-embedding-ada-002":
self.async_success_embedding = True
self.async_embedding_kwargs = kwargs
self.async_embedding_response = response_obj
# For streaming calls, keep the response object passed to the hook.
if kwargs.get("stream") == True:
self.stream_collected_response = response_obj
self.async_completion_kwargs = kwargs
# Async failure hook: flags the failure and captures the kwargs of the failed call.
# NOTE(review): indentation is lost in this view, so it is unclear which of the
# assignments below sit inside the `if` block — confirm against the file.
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Failure")
self.async_failure = True
# Embedding calls are recognised by the hard-coded model name.
if kwargs.get("model") == "text-embedding-ada-002":
self.async_failure_embedding = True
self.async_embedding_kwargs_fail = kwargs
self.async_completion_kwargs_fail = kwargs
class TmpFunction:
    """Holder for a function-style async success callback and what it observed."""
    async_success: bool = False
    complete_streaming_response_in_callback = ""

    async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_time):
        """Mark the callback as called and keep kwargs' "complete_streaming_response"."""
        print("ON ASYNC LOGGING")
        self.async_success = True
        final_response = kwargs.get("complete_streaming_response")
        self.complete_streaming_response_in_callback = final_response
# test_chat_openai()
def test_async_chat_openai_stream():
    """
    Check that the complete streaming response handed to an async success
    callback equals the text assembled from the streamed chunks.
    """
    # remember the global callback list so it can be restored afterwards
    previous_success_callback = litellm.success_callback
    try:
        tmp_function = TmpFunction()
        # litellm.set_verbose = True
        litellm.success_callback = [tmp_function.async_test_logging_fn]
        complete_streaming_response = ""
        async def call_gpt():
            nonlocal complete_streaming_response
            response = await litellm.acompletion(model="gpt-3.5-turbo",
                                                 messages=[{
                                                     "role": "user",
                                                     "content": "Hi 👋 - i'm openai"
                                                 }],
                                                 stream=True)
            async for chunk in response:
                # a chunk's delta content can be falsy; treat that as empty text
                complete_streaming_response += chunk["choices"][0]["delta"]["content"] or ""
                print(complete_streaming_response)
        asyncio.run(call_gpt())
        complete_streaming_response = complete_streaming_response.strip("'")
        callback_content = tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content']
        print(f"complete_streaming_response_in_callback: {callback_content}")
        print(f"type of complete_streaming_response_in_callback: {type(callback_content)}")
        print(f"hidden char complete_streaming_response_in_callback: {repr(callback_content)}")
        print(f"encoding complete_streaming_response_in_callback: {callback_content.encode('utf-8')}")
        print(f"complete_streaming_response: {complete_streaming_response}")
        print(f"type(complete_streaming_response): {type(complete_streaming_response)}")
        print(f"hidden char complete_streaming_response): {repr(complete_streaming_response)}")
        print(f"encoding complete_streaming_response): {repr(complete_streaming_response).encode('utf-8')}")
        response1 = callback_content
        response2 = complete_streaming_response
        # compare code points so invisible-character differences are caught
        assert [ord(c) for c in response1] == [ord(c) for c in response2]
        assert tmp_function.async_success == True
    except Exception as e:
        print(e)
        pytest.fail(f"An error occurred - {str(e)}")
    finally:
        # do not leave this test's callback registered for later tests
        litellm.success_callback = previous_success_callback
test_async_chat_openai_stream()
def test_completion_azure_stream_moderation_failure():
try:
@ -72,75 +137,192 @@ def test_completion_azure_stream_moderation_failure():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_azure_stream_moderation_failure()
def test_async_custom_handler_stream():
    """Async streaming: the handler's collected response must equal the streamed text.

    Registers a ``MyCustomHandler`` in ``litellm.callbacks``, streams an Azure
    completion with ``litellm.acompletion``, and compares the concatenated
    delta contents with the message content the handler stored in
    ``stream_collected_response``.
    """
    try:
        # [PROD Test] - Do not DELETE
        # checks if the model response available in the async + stream callbacks is equal to the received response
        customHandler2 = MyCustomHandler()
        litellm.callbacks = [customHandler2]
        litellm.set_verbose = False
        system_turn = {"role": "system", "content": "You are a helpful assistant."}
        user_turn = {
            "role": "user",
            "content": "write 1 sentence about litellm being amazing",
        }
        messages = [system_turn, user_turn]
        complete_streaming_response = ""

        async def test_1():
            nonlocal complete_streaming_response
            stream = await litellm.acompletion(
                model="azure/chatgpt-v-2",
                messages=messages,
                stream=True,
            )
            async for chunk in stream:
                delta_text = chunk["choices"][0]["delta"]["content"] or ""
                complete_streaming_response += delta_text
                print(complete_streaming_response)

        asyncio.run(test_1())

        collected = customHandler2.stream_collected_response
        response_in_success_handler = collected["choices"][0]["message"]["content"]
        print("\n\n")
        print("response_in_success_handler: ", response_in_success_handler)
        print("complete_streaming_response: ", complete_streaming_response)
        assert response_in_success_handler == complete_streaming_response
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_async_custom_handler_stream()
# def custom_callback(
# kwargs,
# completion_response,
# start_time,
# end_time,
# ):
# print(
# "in custom callback func"
# )
# print("kwargs", kwargs)
# print(completion_response)
# print(start_time)
# print(end_time)
# if "complete_streaming_response" in kwargs:
# print("\n\n complete response\n\n")
# complete_streaming_response = kwargs["complete_streaming_response"]
# print(kwargs["complete_streaming_response"])
# usage = complete_streaming_response["usage"]
# print("usage", usage)
# def send_slack_alert(
# kwargs,
# completion_response,
# start_time,
# end_time,
# ):
# print(
# "in custom slack callback func"
# )
# import requests
# import json
def test_azure_completion_stream():
    """Sync streaming: the sync custom logger must see the same complete response.

    Streams an Azure completion with ``litellm.completion``, concatenates the
    delta contents, waits briefly, and compares the result with the message
    content the handler stored in ``sync_stream_collected_response``.
    """
    # [PROD Test] - Do not DELETE
    # test if completion() + sync custom logger get the same complete stream response
    try:
        # checks if the model response available in the async + stream callbacks is equal to the received response
        customHandler2 = MyCustomHandler()
        litellm.callbacks = [customHandler2]
        litellm.set_verbose = True
        system_turn = {"role": "system", "content": "You are a helpful assistant."}
        user_turn = {
            "role": "user",
            "content": "write 1 sentence about litellm being amazing",
        }
        messages = [system_turn, user_turn]
        complete_streaming_response = ""

        stream = litellm.completion(
            model="azure/chatgpt-v-2",
            messages=messages,
            stream=True,
        )
        for chunk in stream:
            delta_text = chunk["choices"][0]["delta"]["content"] or ""
            complete_streaming_response += delta_text
            print(complete_streaming_response)

        time.sleep(0.5)  # wait 1/2 second before checking callbacks

        collected = customHandler2.sync_stream_collected_response
        response_in_success_handler = collected["choices"][0]["message"]["content"]
        print("\n\n")
        print("response_in_success_handler: ", response_in_success_handler)
        print("complete_streaming_response: ", complete_streaming_response)
        assert response_in_success_handler == complete_streaming_response
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# # Define the text payload, send data available in litellm custom_callbacks
# text_payload = f"""LiteLLM Logging: kwargs: {str(kwargs)}\n\n, response: {str(completion_response)}\n\n, start time{str(start_time)} end time: {str(end_time)}
# """
# payload = {
# "text": text_payload
# }
@pytest.mark.asyncio
async def test_async_custom_handler_completion():
    """Exercise the async success and failure hooks for ``litellm.acompletion``.

    Success: a normal call must set ``async_success`` on the handler and hand
    it kwargs whose "model" is "gpt-3.5-turbo".
    Failure: a call with an invalid API key must set ``async_failure`` and
    hand the handler kwargs containing the model and an "exception" entry.
    """
    try:
        customHandler_success = MyCustomHandler()
        customHandler_failure = MyCustomHandler()

        # success
        assert customHandler_success.async_success == False
        litellm.callbacks = [customHandler_success]
        await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{
                "role": "user",
                "content": "hello from litellm test",
            }],
        )
        await asyncio.sleep(1)  # give the async success callback time to run
        assert customHandler_success.async_success == True, "async success is not set to True even after success"
        assert customHandler_success.async_completion_kwargs.get("model") == "gpt-3.5-turbo"

        # failure
        litellm.callbacks = [customHandler_failure]
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": "how do i kill someone",
            },
        ]

        assert customHandler_failure.async_failure == False
        try:
            await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=messages,
                api_key="my-bad-key",
            )
        except Exception:
            # The call is expected to fail (bad key). `Exception` rather than
            # a bare `except` so KeyboardInterrupt/SystemExit still propagate.
            pass
        assert customHandler_failure.async_failure == True, "async failure is not set to True even after failure"
        assert customHandler_failure.async_completion_kwargs_fail.get("model") == "gpt-3.5-turbo"
        # The logged exception should be a real error description, e.g. an
        # APIError wrapping the provider's 401 "Incorrect API key provided".
        assert len(str(customHandler_failure.async_completion_kwargs_fail.get("exception"))) > 10
        print("Passed setting async failure")
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
    finally:
        # Reset on every exit path so a failed assertion cannot leave these
        # handlers registered for later tests.
        litellm.callbacks = []
# asyncio.run(test_async_custom_handler_completion())
# # Make the POST request
# response = requests.post(slack_webhook_url, json=payload, headers=headers)
@pytest.mark.asyncio
async def test_async_custom_handler_embedding():
    """Exercise the async success and failure hooks for ``litellm.aembedding``.

    Success: the handler must flag ``async_success_embedding``, receive kwargs
    for "text-embedding-ada-002", and see a response whose usage reports 2
    prompt tokens for the input ``["hello world"]``.
    Failure: a call with an invalid API key must flag
    ``async_failure_embedding`` and pass kwargs carrying an "exception" entry.
    """
    try:
        customHandler_embedding = MyCustomHandler()
        litellm.callbacks = [customHandler_embedding]

        # success
        assert customHandler_embedding.async_success_embedding == False
        await litellm.aembedding(
            model="text-embedding-ada-002",
            input=["hello world"],
        )
        await asyncio.sleep(1)  # give the async success callback time to run
        assert customHandler_embedding.async_success_embedding == True, "async_success_embedding is not set to True even after success"
        assert customHandler_embedding.async_embedding_kwargs.get("model") == "text-embedding-ada-002"
        assert customHandler_embedding.async_embedding_response["usage"]["prompt_tokens"] == 2
        print("Passed setting async success: Embedding")

        # failure
        assert customHandler_embedding.async_failure_embedding == False
        try:
            await litellm.aembedding(
                model="text-embedding-ada-002",
                input=["hello world"],
                api_key="my-bad-key",
            )
        except Exception:
            # The call is expected to fail (bad key). `Exception` rather than
            # a bare `except` so KeyboardInterrupt/SystemExit still propagate.
            pass
        assert customHandler_embedding.async_failure_embedding == True, "async failure embedding is not set to True even after failure"
        assert customHandler_embedding.async_embedding_kwargs_fail.get("model") == "text-embedding-ada-002"
        # The logged exception should be a real error description, e.g. an
        # APIError wrapping the provider's 401 "Incorrect API key provided".
        assert len(str(customHandler_embedding.async_embedding_kwargs_fail.get("exception"))) > 10
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")


# Only run directly when executed as a script; an unguarded asyncio.run()
# would perform the API calls every time pytest imports this module.
if __name__ == "__main__":
    asyncio.run(test_async_custom_handler_embedding())
from litellm import Cache
def test_redis_cache_completion_stream():
    """Streaming completions must be written to and served from the Redis cache.

    Important Test - This tests if we can add to streaming cache, when custom
    callbacks are set. Two identical streaming requests are issued; the second
    is expected to return exactly the text of the first.

    Reads REDIS_HOST, REDIS_PORT and REDIS_PASSWORD from the environment.
    """
    import random

    try:
        print("\nrunning test_redis_cache_completion_stream")
        litellm.set_verbose = True
        # add a random number to ensure it's always adding / reading from cache
        random_number = random.randint(1, 100000)
        messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
        litellm.cache = Cache(
            type="redis",
            host=os.environ['REDIS_HOST'],
            port=os.environ['REDIS_PORT'],
            password=os.environ['REDIS_PASSWORD'],
        )
        print("test for caching, streaming + completion")

        response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
        response_1_content = ""
        for chunk in response1:
            print(chunk)
            response_1_content += chunk.choices[0].delta.content or ""
        print(response_1_content)

        time.sleep(0.1)  # sleep for 0.1 seconds allow set cache to occur

        response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
        response_2_content = ""
        for chunk in response2:
            print(chunk)
            response_2_content += chunk.choices[0].delta.content or ""

        print("\nresponse 1", response_1_content)
        print("\nresponse 2", response_2_content)
        assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
    except Exception as e:
        print(e)
        raise
    finally:
        # Undo the global state on every exit path. Previously a failure left
        # the Redis cache and async callbacks installed for later tests.
        litellm.success_callback = []
        litellm._async_success_callback = []
        litellm.cache = None


# Only run directly when executed as a script, so importing the module
# (e.g. pytest collection) does not hit Redis and the completion API.
if __name__ == "__main__":
    test_redis_cache_completion_stream()

View file

@ -151,16 +151,36 @@ def test_cohere_embedding3():
# test_cohere_embedding3()
def test_bedrock_embedding():
def test_bedrock_embedding_titan():
    """Embed two strings with amazon.titan-embed-text-v1 and check the first vector.

    The first item's "embedding" must be a list made up only of floats.
    """
    try:
        litellm.set_verbose = True
        texts = [
            "good morning from litellm, attempting to embed data",
            "lets test a second string for good measure",
        ]
        response = embedding(model="amazon.titan-embed-text-v1", input=texts)
        print("response:", response)

        first_vector = response['data'][0]['embedding']
        assert isinstance(first_vector, list), "Expected response to be a list"
        print("type of first embedding:", type(first_vector[0]))
        assert all(isinstance(x, float) for x in first_vector), "Expected response to be a list of floats"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_bedrock_embedding()
# test_bedrock_embedding_titan()
def test_bedrock_embedding_cohere():
    """Embed two strings with cohere.embed-multilingual-v3 and check the first vector.

    The region is passed as the literal "os.environ/AWS_REGION_NAME_2"; how
    that string is resolved is up to ``embedding``. The first item's
    "embedding" must be a list made up only of floats.
    """
    try:
        litellm.set_verbose = False
        texts = [
            "good morning from litellm, attempting to embed data",
            "lets test a second string for good measure",
        ]
        response = embedding(
            model="cohere.embed-multilingual-v3",
            input=texts,
            aws_region_name="os.environ/AWS_REGION_NAME_2",
        )

        first_vector = response['data'][0]['embedding']
        assert isinstance(first_vector, list), "Expected response to be a list"
        print("type of first embedding:", type(first_vector[0]))
        assert all(isinstance(x, float) for x in first_vector), "Expected response to be a list of floats"
        # print(f"response:", response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_bedrock_embedding_cohere()
# comment out hf tests - since hf endpoints are unstable
def test_hf_embedding():
@ -214,7 +234,14 @@ def test_aembedding_azure():
# test_aembedding_azure()
# def test_custom_openai_embedding():
def test_sagemaker_embeddings():
    """Smoke test: embedding two strings through the SageMaker endpoint must not raise."""
    texts = ["good morning from litellm", "this is another item"]
    try:
        response = litellm.embedding(
            model="sagemaker/berri-benchmarking-gpt-j-6b-fp16",
            input=texts,
        )
        print(f"response: {response}")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_sagemaker_embeddings()
# def local_proxy_embeddings():
# litellm.set_verbose=True
# response = embedding(
# model="openai/custom_embedding",
@ -222,4 +249,5 @@ def test_aembedding_azure():
# api_base="http://0.0.0.0:8000/"
# )
# print(response)
# test_custom_openai_embedding()
# local_proxy_embeddings()

View file

@ -189,6 +189,7 @@ def test_completion_azure_exception():
}
],
)
os.environ["AZURE_API_KEY"] = old_azure_key
print(f"response: {response}")
print(response)
except openai.AuthenticationError as e:

View file

@ -0,0 +1,19 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os, io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
def test_get_llm_provider():
    """The second element returned by ``get_llm_provider`` for a Bedrock-style
    Anthropic model id ("anthropic.claude-v2:1") must be "bedrock"."""
    _, provider, _, _ = litellm.get_llm_provider(model="anthropic.claude-v2:1")
    assert provider == "bedrock"


# Only run directly when executed as a script; an unguarded call would run
# the test a second time on every import, i.e. during pytest collection.
if __name__ == "__main__":
    test_get_llm_provider()

View file

@ -9,33 +9,107 @@ from litellm import completion
import litellm
litellm.num_retries = 3
litellm.success_callback = ["langfuse"]
# litellm.set_verbose = True
os.environ["LANGFUSE_DEBUG"] = "True"
import time
import pytest
def search_logs(log_file_path):
    """
    Check the HTTP status of every "/api/public" request found in a log file.

    A line is considered only if it contains "/api/public" and also matches
    the "receive_response_headers.complete return_value=(b'HTTP/1.1', <code>,"
    pattern. Status 200 or 201 counts as a good log; any other status counts
    as a bad log.

    Parameters:
    - log_file_path (str): The path to the log file to be searched.

    Returns:
    - None

    Raises:
    - OSError: If the log file cannot be opened.
    - Exception: If any bad log is found, or if no good log is found at all.
    """
    import re

    # Compiled once up front rather than looked up again for every line.
    status_pattern = re.compile(
        r"receive_response_headers.complete return_value=\(b\'HTTP/1.1\', (\d+),"
    )

    print("\n searching logs")
    with open(log_file_path, 'r') as log_file:
        lines = log_file.readlines()
    print(f"searching logslines: {lines}")

    bad_logs = []
    good_logs = []
    all_logs = []
    for line in lines:
        stripped = line.strip()
        all_logs.append(stripped)
        if "/api/public" not in line:
            continue

        print("Found log with /api/public:")
        print(stripped)
        print("\n\n")
        match = status_pattern.search(line)
        if match is None:
            # Mentions the endpoint but carries no status line; ignore it.
            continue

        if int(match.group(1)) in (200, 201):
            good_logs.append(stripped)
        else:
            print("got a BAD log")
            bad_logs.append(stripped)

    print("\nBad Logs")
    print(bad_logs)
    if bad_logs:
        raise Exception(f"bad logs, Bad logs = {bad_logs}")

    print("\nGood Logs")
    print(good_logs)
    if not good_logs:
        raise Exception(f"There were no Good Logs from Langfuse. No logs with /api/public status 200. \nAll logs:{all_logs}")
def pre_langfuse_setup():
    """
    Route all root-logger output at DEBUG level into a fresh "langfuse.log".

    The file is created (or truncated) in the current working directory so
    that `search_logs("langfuse.log")` only sees output produced after this
    call. Calling the function again replaces the previous handler instead of
    stacking another one on top of it.
    """
    import logging
    import os

    log_path = os.path.abspath("langfuse.log")
    logger = logging.getLogger()

    # Set the level explicitly: logging.basicConfig() does nothing when the
    # root logger already has handlers, so it cannot be relied on for this.
    logger.setLevel(logging.DEBUG)

    # Drop any handler an earlier call attached for the same file, so records
    # are not written twice and handlers do not accumulate across tests.
    for handler in list(logger.handlers):
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path:
            logger.removeHandler(handler)
            handler.close()

    # mode='w' truncates the file so each setup starts with an empty log.
    file_handler = logging.FileHandler("langfuse.log", mode='w')
    file_handler.setLevel(logging.DEBUG)
    logger.addHandler(file_handler)
    return
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_async():
    """Run one async completion with debug logging captured in langfuse.log.

    A ``litellm.Timeout`` is tolerated; any other exception fails the test.
    """
    try:
        pre_langfuse_setup()
        litellm.set_verbose = True

        async def _test_langfuse():
            return await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "This is a test"}],
                # `max_tokens` must appear exactly once: a repeated keyword
                # argument is a SyntaxError.
                max_tokens=100,
                temperature=0.7,
                timeout=5,
            )

        response = asyncio.run(_test_langfuse())
        print(f"response: {response}")

        # time.sleep(2)
        # # check langfuse.log to see if there was a failed response
        # search_logs("langfuse.log")
    except litellm.Timeout as e:
        pass
    except Exception as e:
        pytest.fail(f"An exception occurred - {e}")


# Only run directly when executed as a script; an unguarded call bypasses the
# skip marker above and performs the API call on every import of this module.
if __name__ == "__main__":
    test_langfuse_logging_async()
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging():
try:
# litellm.set_verbose = True
pre_langfuse_setup()
litellm.set_verbose = True
response = completion(model="claude-instant-1.2",
messages=[{
"role": "user",
@ -43,17 +117,20 @@ def test_langfuse_logging():
}],
max_tokens=10,
temperature=0.2,
metadata={"langfuse/key": "foo"}
)
print(response)
# time.sleep(5)
# # check langfuse.log to see if there was a failed response
# search_logs("langfuse.log")
except litellm.Timeout as e:
pass
except Exception as e:
print(e)
pytest.fail(f"An exception occurred - {e}")
test_langfuse_logging()
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_stream():
try:
litellm.set_verbose=True
@ -77,6 +154,7 @@ def test_langfuse_logging_stream():
# test_langfuse_logging_stream()
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_custom_generation_name():
try:
litellm.set_verbose=True
@ -99,8 +177,8 @@ def test_langfuse_logging_custom_generation_name():
pytest.fail(f"An exception occurred - {e}")
print(e)
test_langfuse_logging_custom_generation_name()
# test_langfuse_logging_custom_generation_name()
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_function_calling():
function1 = [
{

View file

@ -0,0 +1,79 @@
# #### What this tests ####
# # This tests the router's ability to identify the least busy deployment
# #
# # How is this achieved?
# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# # - use litellm.success + failure callbacks to log when a request completed
# # - in get_available_deployment, for a given model group name -> pick based on traffic
# import sys, os, asyncio, time
# import traceback
# from dotenv import load_dotenv
# load_dotenv()
# import os
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import pytest
# from litellm import Router
# import litellm
# async def test_least_busy_routing():
# model_list = [{
# "model_name": "azure-model",
# "litellm_params": {
# "model": "azure/gpt-turbo",
# "api_key": "os.environ/AZURE_FRANCE_API_KEY",
# "api_base": "https://openai-france-1234.openai.azure.com",
# "rpm": 1440,
# }
# }, {
# "model_name": "azure-model",
# "litellm_params": {
# "model": "azure/gpt-35-turbo",
# "api_key": "os.environ/AZURE_EUROPE_API_KEY",
# "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
# "rpm": 6
# }
# }, {
# "model_name": "azure-model",
# "litellm_params": {
# "model": "azure/gpt-35-turbo",
# "api_key": "os.environ/AZURE_CANADA_API_KEY",
# "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
# "rpm": 6
# }
# }]
# router = Router(model_list=model_list,
# routing_strategy="least-busy",
# set_verbose=False,
# num_retries=3) # type: ignore
# async def call_azure_completion():
# try:
# response = await router.acompletion(
# model="azure-model",
# messages=[
# {
# "role": "user",
# "content": "hello this request will pass"
# }
# ]
# )
# print("\n response", response)
# return response
# except:
# return None
# n = 1000
# start_time = time.time()
# tasks = [call_azure_completion() for _ in range(n)]
# chat_completions = await asyncio.gather(*tasks)
# successful_completions = [c for c in chat_completions if c is not None]
# print(n, time.time() - start_time, len(successful_completions))
# asyncio.run(test_least_busy_routing())

View file

@ -17,10 +17,10 @@ model_alias_map = {
"good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf"
}
litellm.model_alias_map = model_alias_map
def test_model_alias_map():
try:
litellm.model_alias_map = model_alias_map
response = completion(
"good-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],

View file

@ -395,7 +395,7 @@ def sagemaker_test_completion():
try:
# OVERRIDE WITH DYNAMIC MAX TOKENS
response_1 = litellm.completion(
model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}],
max_tokens=100
)
@ -404,7 +404,7 @@ def sagemaker_test_completion():
# USE CONFIG TOKENS
response_2 = litellm.completion(
model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}],
)
response_2_text = response_2.choices[0].message.content

View file

@ -0,0 +1,65 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os, io
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
# test /chat/completion request to the proxy
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
# Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app)
@pytest.fixture(scope="function")
def client():
    """Build a TestClient around a fresh FastAPI app configured for custom auth.

    Router config variables are cleaned up first, then the proxy is
    initialized from test_configs/test_config_custom_auth.yaml located next
    to this file.
    """
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{tests_dir}/test_configs/test_config_custom_auth.yaml"
    # initialize() sets specific variables for the FastAPI app; since tests
    # can run in parallel, different tests may otherwise use the wrong variables.
    app = FastAPI()
    initialize(config=config_fp)

    app.include_router(router)  # Include your router in the test app
    test_client = TestClient(app)
    return test_client
def test_custom_auth(client):
    """A /chat/completions request must be rejected with HTTP 401 under custom auth.

    Sends a minimal chat request carrying the PROXY_MASTER_KEY value as a
    bearer token and asserts that the proxy answers 401.

    Parameters:
    - client: the TestClient provided by the `client` fixture.
    """
    try:
        # Your test data
        test_data = {
            "model": "openai-model",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }
        # Your bearer token
        token = os.getenv("PROXY_MASTER_KEY")

        headers = {
            "Authorization": f"Bearer {token}"
        }
        response = client.post("/chat/completions", json=test_data, headers=headers)
        print(f"response: {response.text}")
        assert response.status_code == 401
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        # pytest.fail's second positional parameter is `pytrace`, not extra
        # message text, so the exception has to be formatted into the reason.
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

View file

@ -0,0 +1,236 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os, io, asyncio
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
import importlib, inspect
# test /chat/completion request to the proxy
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__))
python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
# @app.on_event("startup")
# async def wrapper_startup_event():
# initialize(config=config_fp)
# Use the app fixture in your client fixture
@pytest.fixture
def client():
    """Build a TestClient for a proxy initialized with the custom-logger config.

    The proxy is initialized from test_configs/test_custom_logger.yaml located
    next to this file before the router is mounted on a new FastAPI app.
    """
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{tests_dir}/test_configs/test_custom_logger.yaml"
    initialize(config=config_fp)

    app = FastAPI()
    app.include_router(router)  # Include your router in the test app
    test_client = TestClient(app)
    return test_client
# Your bearer token
token = os.getenv("PROXY_MASTER_KEY")
headers = {
"Authorization": f"Bearer {token}"
}
print("Testing proxy custom logger")
def test_embedding(client):
    """POST /embeddings through the proxy and verify what the custom logger received.

    The logger instance is loaded with ``get_instance_fn`` from
    custom_callbacks.py and registered in ``litellm.callbacks``. After the
    request the test checks the async success flag, the logged model, the
    metadata keys, the echoed proxy request and the configured model info.
    """
    try:
        litellm.set_verbose = False
        from litellm.proxy.utils import get_instance_fn

        my_custom_logger = get_instance_fn(
            value="custom_callbacks.my_custom_logger",
            config_file_path=python_file_path,
        )
        print("id of initialized custom logger", id(my_custom_logger))
        litellm.callbacks = [my_custom_logger]
        # Your test data
        print("initialized proxy")
        # import the initialized custom logger
        print(litellm.callbacks)

        # assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
        print("my_custom_logger", my_custom_logger)
        assert my_custom_logger.async_success_embedding == False

        test_data = {
            "model": "azure-embedding-model",
            "input": ["hello"]
        }
        response = client.post("/embeddings", json=test_data, headers=headers)
        print("made request", response.status_code, response.text)
        print("vars my custom logger /embeddings", vars(my_custom_logger), "id", id(my_custom_logger))

        # only async_log_success_event can set this flag to True
        assert my_custom_logger.async_success_embedding == True
        # kwargs passed to async_log_success_event must carry the requested model
        assert my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model"

        kwargs = my_custom_logger.async_embedding_kwargs
        litellm_params = kwargs.get("litellm_params")
        metadata = litellm_params.get("metadata", None)
        print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
        assert metadata is not None
        for required_key in ("user_api_key", "headers"):
            assert required_key in metadata

        proxy_server_request = litellm_params.get("proxy_server_request")
        model_info = litellm_params.get("model_info")
        expected_request = {
            'url': 'http://testserver/embeddings',
            'method': 'POST',
            'headers': {
                'host': 'testserver',
                'accept': '*/*',
                'accept-encoding': 'gzip, deflate',
                'connection': 'keep-alive',
                'user-agent': 'testclient',
                'authorization': 'Bearer sk-1234',
                'content-length': '54',
                'content-type': 'application/json',
            },
            'body': {'model': 'azure-embedding-model', 'input': ['hello']},
        }
        assert proxy_server_request == expected_request
        assert model_info == {'input_cost_per_token': 0.002, 'mode': 'embedding', 'id': 'hello'}

        result = response.json()
        print(f"Received response: {result}")
        print("Passed Embedding custom logger on proxy!")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
def test_chat_completion(client):
    """POST /chat/completions through the proxy and verify what the custom logger received.

    The logger instance is loaded with ``get_instance_fn`` from
    custom_callbacks.py and registered in ``litellm.callbacks``. After the
    request the test checks the async success flag, the logged model, the
    metadata keys, the configured model info and the echoed proxy request.
    """
    import time

    try:
        # Your test data
        print("initialized proxy")
        litellm.set_verbose = False
        from litellm.proxy.utils import get_instance_fn

        my_custom_logger = get_instance_fn(
            value="custom_callbacks.my_custom_logger",
            config_file_path=python_file_path,
        )

        print("id of initialized custom logger", id(my_custom_logger))
        litellm.callbacks = [my_custom_logger]
        # import the initialized custom logger
        print(litellm.callbacks)

        # assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback

        print("LiteLLM Callbacks", litellm.callbacks)
        print("my_custom_logger", my_custom_logger)
        assert my_custom_logger.async_success == False

        test_data = {
            "model": "Azure OpenAI GPT-4 Canada",
            "messages": [
                {
                    "role": "user",
                    "content": "write a litellm poem"
                },
            ],
            "max_tokens": 10,
        }

        response = client.post("/chat/completions", json=test_data, headers=headers)
        print("made request", response.status_code, response.text)
        print("LiteLLM Callbacks", litellm.callbacks)
        # Sleep while waiting for the callback to run. This is a synchronous
        # test, so asyncio.sleep(1) without `await` would only create a
        # coroutine object and return immediately; time.sleep really waits.
        time.sleep(1)
        print("my_custom_logger in /chat/completions", my_custom_logger, "id", id(my_custom_logger))
        print("vars my custom logger, ", vars(my_custom_logger))

        # only async_log_success_event can set this flag to True
        assert my_custom_logger.async_success == True
        # kwargs passed to async_log_success_event must carry the deployed model
        assert my_custom_logger.async_completion_kwargs["model"] == "chatgpt-v-2"
        print("\n\n Custom Logger Async Completion args", my_custom_logger.async_completion_kwargs)

        litellm_params = my_custom_logger.async_completion_kwargs.get("litellm_params")
        metadata = litellm_params.get("metadata", None)
        print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
        assert metadata is not None
        assert "user_api_key" in metadata
        assert "headers" in metadata

        config_model_info = litellm_params.get("model_info")
        proxy_server_request_object = litellm_params.get("proxy_server_request")

        assert config_model_info == {'id': 'gm', 'input_cost_per_token': 0.0002, 'mode': 'chat'}
        assert proxy_server_request_object == {'url': 'http://testserver/chat/completions', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '123', 'content-type': 'application/json'}, 'body': {'model': 'Azure OpenAI GPT-4 Canada', 'messages': [{'role': 'user', 'content': 'write a litellm poem'}], 'max_tokens': 10}}

        result = response.json()
        print(f"Received response: {result}")
        print("\nPassed /chat/completions with Custom Logger!")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
def test_chat_completion_stream(client):
    """Stream /chat/completions through the proxy and compare with the logger's copy.

    Every non-empty streamed line has its 'data: ' prefix removed and is
    parsed as JSON; the delta contents are concatenated and must equal the
    message content of the logger's ``streaming_response_obj``.
    """
    try:
        # Your test data
        litellm.set_verbose = False
        from litellm.proxy.utils import get_instance_fn

        my_custom_logger = get_instance_fn(
            value="custom_callbacks.my_custom_logger",
            config_file_path=python_file_path,
        )

        print("id of initialized custom logger", id(my_custom_logger))
        litellm.callbacks = [my_custom_logger]
        import json

        print("initialized proxy")
        # import the initialized custom logger
        print(litellm.callbacks)

        print("LiteLLM Callbacks", litellm.callbacks)
        print("my_custom_logger", my_custom_logger)

        # no streaming response obj is set pre call
        assert my_custom_logger.streaming_response_obj == None

        test_data = {
            "model": "Azure OpenAI GPT-4 Canada",
            "messages": [
                {
                    "role": "user",
                    "content": "write 1 line poem about LiteLLM"
                },
            ],
            "max_tokens": 40,
            "stream": True  # streaming call
        }

        response = client.post("/chat/completions", json=test_data, headers=headers)
        print("made request", response.status_code, response.text)

        complete_response = ""
        for line in response.iter_lines():
            if not line:
                continue
            # Process the streaming data line here
            print("\n\n Line", line)
            print(line)
            payload_text = str(line).replace('data: ', '')

            # Parse the JSON string
            data = json.loads(payload_text)
            print("\n\n decode_data", data)

            # Access the content of choices[0]['delta']['content']
            content = data['choices'][0]['delta']['content'] or ""

            # Process the content as needed
            print("Content:", content)
            complete_response += content

        print("\n\nHERE is the complete streaming response string", complete_response)
        print("\n\nHERE IS the streaming Response from callback\n\n")
        print(my_custom_logger.streaming_response_obj)
        import time

        time.sleep(0.5)

        streamed_response = my_custom_logger.streaming_response_obj
        logged_content = streamed_response["choices"][0]["message"]["content"]
        assert complete_response == logged_content
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

View file

@ -0,0 +1,61 @@
# #### What this tests ####
# # Allow the user to easily run the local proxy server with Gunicorn
# # LOCAL TESTING ONLY
# import sys, os, subprocess
# import traceback
# from dotenv import load_dotenv
# load_dotenv()
# import os, io
# # this file is to test litellm/proxy
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import pytest
# import litellm
# ### LOCAL Proxy Server INIT ###
# from litellm.proxy.proxy_server import save_worker_config # Replace with the actual module where your FastAPI router is defined
# filepath = os.path.dirname(os.path.abspath(__file__))
# config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
# def get_openai_info():
# return {
# "api_key": os.getenv("AZURE_API_KEY"),
# "api_base": os.getenv("AZURE_API_BASE"),
# }
# def run_server(host="0.0.0.0",port=8008,num_workers=None):
# if num_workers is None:
# # Set it to min(4,cpu_count())
# import multiprocessing
# num_workers = min(4,multiprocessing.cpu_count())
# ### LOAD KEYS ###
# # Load the Azure keys. For now get them from openai-usage
# azure_info = get_openai_info()
# print(f"Azure info:{azure_info}")
# os.environ["AZURE_API_KEY"] = azure_info['api_key']
# os.environ["AZURE_API_BASE"] = azure_info['api_base']
# os.environ["AZURE_API_VERSION"] = "2023-09-01-preview"
# ### SAVE CONFIG ###
# os.environ["WORKER_CONFIG"] = config_fp
# # In order for the app to behave well with signals, run it with gunicorn
# # The first argument must be the "name of the command run"
# cmd = f"gunicorn litellm.proxy.proxy_server:app --workers {num_workers} --worker-class uvicorn.workers.UvicornWorker --bind {host}:{port}"
# cmd = cmd.split()
# print(f"Running command: {cmd}")
# import sys
# sys.stdout.flush()
# sys.stderr.flush()
# # Make sure to propagate env variables
# subprocess.run(cmd) # This line actually starts Gunicorn
# if __name__ == "__main__":
# run_server()

Some files were not shown because too many files have changed in this diff Show more