Merge branch 'BerriAI:main' into feature_allow_claude_prefill

This commit is contained in:
Dustin Miller 2024-01-05 15:15:29 -06:00 committed by GitHub
commit 53e5e1df07
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
42 changed files with 1395 additions and 490 deletions

View file

@ -134,11 +134,15 @@ jobs:
- run: - run:
name: Trigger Github Action for new Docker Container name: Trigger Github Action for new Docker Container
command: | command: |
echo "Install TOML package."
python3 -m pip install toml
VERSION=$(python3 -c "import toml; print(toml.load('pyproject.toml')['tool']['poetry']['version'])")
echo "LiteLLM Version ${VERSION}"
curl -X POST \ curl -X POST \
-H "Accept: application/vnd.github.v3+json" \ -H "Accept: application/vnd.github.v3+json" \
-H "Authorization: Bearer $GITHUB_TOKEN" \ -H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches" \ "https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches" \
-d '{"ref":"main"}' -d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"${VERSION}\"}}"
workflows: workflows:
version: 2 version: 2

View file

@ -1,12 +1,10 @@
# # this workflow is triggered by an API call when there is a new PyPI release of LiteLLM
name: Build, Publish LiteLLM Docker Image name: Build, Publish LiteLLM Docker Image. New Release
on: on:
workflow_dispatch: workflow_dispatch:
inputs: inputs:
tag: tag:
description: "The tag version you want to build" description: "The tag version you want to build"
release:
types: [published]
# Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds. # Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds.
env: env:
@ -46,7 +44,7 @@ jobs:
with: with:
context: . context: .
push: true push: true
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || github.event.release.tag_name || 'latest' }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest' tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
build-and-push-image-alpine: build-and-push-image-alpine:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -76,5 +74,38 @@ jobs:
context: . context: .
dockerfile: Dockerfile.alpine dockerfile: Dockerfile.alpine
push: true push: true
tags: ${{ steps.meta-alpine.outputs.tags }}-${{ github.event.inputs.tag || github.event.release.tag_name || 'latest' }} tags: ${{ steps.meta-alpine.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}
labels: ${{ steps.meta-alpine.outputs.labels }} labels: ${{ steps.meta-alpine.outputs.labels }}
release:
name: "New LiteLLM Release"
runs-on: "ubuntu-latest"
steps:
- name: Display version
run: echo "Current version is ${{ github.event.inputs.tag }}"
- name: "Set Release Tag"
run: echo "RELEASE_TAG=${{ github.event.inputs.tag }}" >> $GITHUB_ENV
- name: Display release tag
run: echo "RELEASE_TAG is $RELEASE_TAG"
- name: "Create release"
uses: "actions/github-script@v6"
with:
github-token: "${{ secrets.GITHUB_TOKEN }}"
script: |
try {
const response = await github.rest.repos.createRelease({
draft: false,
generate_release_notes: true,
name: process.env.RELEASE_TAG,
owner: context.repo.owner,
prerelease: false,
repo: context.repo.repo,
tag_name: process.env.RELEASE_TAG,
});
core.exportVariable('RELEASE_ID', response.data.id);
core.exportVariable('RELEASE_UPLOAD_URL', response.data.upload_url);
} catch (error) {
core.setFailed(error.message);
}

View file

@ -0,0 +1,31 @@
name: Read Version from pyproject.toml
on:
push:
branches:
- main # Change this to the default branch of your repository
jobs:
read-version:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.8 # Adjust the Python version as needed
- name: Install dependencies
run: pip install toml
- name: Read version from pyproject.toml
id: read-version
run: |
version=$(python -c 'import toml; print(toml.load("pyproject.toml")["tool"]["commitizen"]["version"])')
printf "LITELLM_VERSION=%s" "$version" >> $GITHUB_ENV
- name: Display version
run: echo "Current version is $LITELLM_VERSION"

1
.gitignore vendored
View file

@ -31,3 +31,4 @@ proxy_server_config_@.yaml
.gitignore .gitignore
proxy_server_config_2.yaml proxy_server_config_2.yaml
litellm/proxy/secret_managers/credentials.json litellm/proxy/secret_managers/credentials.json
hosted_config.yaml

View file

@ -3,7 +3,6 @@ ARG LITELLM_BUILD_IMAGE=python:3.9
# Runtime image # Runtime image
ARG LITELLM_RUNTIME_IMAGE=python:3.9-slim ARG LITELLM_RUNTIME_IMAGE=python:3.9-slim
# Builder stage # Builder stage
FROM $LITELLM_BUILD_IMAGE as builder FROM $LITELLM_BUILD_IMAGE as builder
@ -35,8 +34,12 @@ RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt
# Runtime stage # Runtime stage
FROM $LITELLM_RUNTIME_IMAGE as runtime FROM $LITELLM_RUNTIME_IMAGE as runtime
ARG with_database
WORKDIR /app WORKDIR /app
# Copy the current directory contents into the container at /app
COPY . .
RUN ls -la /app
# Copy the built wheel from the builder stage to the runtime stage; assumes only one wheel file is present # Copy the built wheel from the builder stage to the runtime stage; assumes only one wheel file is present
COPY --from=builder /app/dist/*.whl . COPY --from=builder /app/dist/*.whl .
@ -45,6 +48,14 @@ COPY --from=builder /wheels/ /wheels/
# Install the built wheel using pip; again using a wildcard if it's the only file # Install the built wheel using pip; again using a wildcard if it's the only file
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels
# Check if the with_database argument is set to 'true'
RUN echo "Value of with_database is: ${with_database}"
# If true, execute the following instructions
RUN if [ "$with_database" = "true" ]; then \
prisma generate; \
chmod +x /app/retry_push.sh; \
/app/retry_push.sh; \
fi
EXPOSE 4000/tcp EXPOSE 4000/tcp

View file

@ -6,10 +6,10 @@
LITELLM_MASTER_KEY="sk-1234" LITELLM_MASTER_KEY="sk-1234"
############ ############
# Database - You can change these to any PostgreSQL database that has logical replication enabled. # Database - You can change these to any PostgreSQL database.
############ ############
# LITELLM_DATABASE_URL="your-postgres-db-url" DATABASE_URL="your-postgres-db-url"
############ ############
@ -19,4 +19,4 @@ LITELLM_MASTER_KEY="sk-1234"
# SMTP_HOST = "fake-mail-host" # SMTP_HOST = "fake-mail-host"
# SMTP_USERNAME = "fake-mail-user" # SMTP_USERNAME = "fake-mail-user"
# SMTP_PASSWORD="fake-mail-password" # SMTP_PASSWORD="fake-mail-password"
# SMTP_SENDER_EMAIL="fake-sender-email" # SMTP_SENDER_EMAIL="fake-sender-email"

View file

@ -396,7 +396,48 @@ response = completion(
) )
``` ```
## OpenAI Proxy
Track spend across multiple projects/people
The proxy provides:
1. [Hooks for auth](https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth)
2. [Hooks for logging](https://docs.litellm.ai/docs/proxy/logging#step-1---create-your-custom-litellm-callback-class)
3. [Cost tracking](https://docs.litellm.ai/docs/proxy/virtual_keys#tracking-spend)
4. [Rate Limiting](https://docs.litellm.ai/docs/proxy/users#set-rate-limits)
### 📖 Proxy Endpoints - [Swagger Docs](https://litellm-api.up.railway.app/)
### Quick Start Proxy - CLI
```shell
pip install litellm[proxy]
```
#### Step 1: Start litellm proxy
```shell
$ litellm --model huggingface/bigcode/starcoder
#INFO: Proxy running on http://0.0.0.0:8000
```
#### Step 2: Make ChatCompletions Request to Proxy
```python
import openai # openai v1.0.0+
client = openai.OpenAI(api_key="anything",base_url="http://0.0.0.0:8000") # set proxy to base_url
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
## More details ## More details
* [exception mapping](./exception_mapping.md) * [exception mapping](./exception_mapping.md)
* [retries + model fallbacks for completion()](./completion/reliable_completions.md) * [retries + model fallbacks for completion()](./completion/reliable_completions.md)
* [tutorial for model fallbacks with completion()](./tutorials/fallbacks.md) * [tutorial for model fallbacks with completion()](./tutorials/fallbacks.md)

View file

@ -161,7 +161,7 @@ litellm_settings:
The proxy support 3 cache-controls: The proxy support 3 cache-controls:
- `ttl`: Will cache the response for the user-defined amount of time (in seconds). - `ttl`: Will cache the response for the user-defined amount of time (in seconds).
- `s-max-age`: Will only accept cached responses that are within user-defined range (in seconds). - `s-maxage`: Will only accept cached responses that are within user-defined range (in seconds).
- `no-cache`: Will not return a cached response, but instead call the actual endpoint. - `no-cache`: Will not return a cached response, but instead call the actual endpoint.
[Let us know if you need more](https://github.com/BerriAI/litellm/issues/1218) [Let us know if you need more](https://github.com/BerriAI/litellm/issues/1218)
@ -237,7 +237,7 @@ chat_completion = client.chat.completions.create(
], ],
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
cache={ cache={
"s-max-age": 600 # only get responses cached within last 10 minutes "s-maxage": 600 # only get responses cached within last 10 minutes
} }
) )
``` ```

View file

@ -0,0 +1,43 @@
# Post-Call Rules
Use this to fail a request based on the output of an llm api call.
## Quick Start
### Step 1: Create a file (e.g. post_call_rules.py)
```python
def my_custom_rule(input): # receives the model response
if len(input) < 5: # trigger fallback if the model response is too short
return False
return True
```
### Step 2. Point it to your proxy
```python
litellm_settings:
post_call_rules: post_call_rules.my_custom_rule
num_retries: 3
```
### Step 3. Start + test your proxy
```bash
$ litellm /path/to/config.yaml
```
```bash
curl --location 'http://0.0.0.0:8000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
"model": "deepseek-coder",
"messages": [{"role":"user","content":"What llm are you?"}],
"temperature": 0.7,
"max_tokens": 10,
}'
```
---
This will now check if a response is > len 5, and if it fails, it'll retry a call 3 times before failing.

View file

@ -112,6 +112,7 @@ const sidebars = {
"proxy/reliability", "proxy/reliability",
"proxy/health", "proxy/health",
"proxy/call_hooks", "proxy/call_hooks",
"proxy/rules",
"proxy/caching", "proxy/caching",
"proxy/alerting", "proxy/alerting",
"proxy/logging", "proxy/logging",

View file

@ -375,6 +375,45 @@ response = completion(
Need a dedicated key? Email us @ krrish@berri.ai Need a dedicated key? Email us @ krrish@berri.ai
## OpenAI Proxy
Track spend across multiple projects/people
The proxy provides:
1. [Hooks for auth](https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth)
2. [Hooks for logging](https://docs.litellm.ai/docs/proxy/logging#step-1---create-your-custom-litellm-callback-class)
3. [Cost tracking](https://docs.litellm.ai/docs/proxy/virtual_keys#tracking-spend)
4. [Rate Limiting](https://docs.litellm.ai/docs/proxy/users#set-rate-limits)
### 📖 Proxy Endpoints - [Swagger Docs](https://litellm-api.up.railway.app/)
### Quick Start Proxy - CLI
```shell
pip install litellm[proxy]
```
#### Step 1: Start litellm proxy
```shell
$ litellm --model huggingface/bigcode/starcoder
#INFO: Proxy running on http://0.0.0.0:8000
```
#### Step 2: Make ChatCompletions Request to Proxy
```python
import openai # openai v1.0.0+
client = openai.OpenAI(api_key="anything",base_url="http://0.0.0.0:8000") # set proxy to base_url
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
## More details ## More details
* [exception mapping](./exception_mapping.md) * [exception mapping](./exception_mapping.md)

View file

@ -338,7 +338,8 @@ baseten_models: List = [
] # FALCON 7B # WizardLM # Mosaic ML ] # FALCON 7B # WizardLM # Mosaic ML
# used for token counting # used for Cost Tracking & Token counting
# https://azure.microsoft.com/en-in/pricing/details/cognitive-services/openai-service/
# Azure returns gpt-35-turbo in their responses, we need to map this to azure/gpt-3.5-turbo for token counting # Azure returns gpt-35-turbo in their responses, we need to map this to azure/gpt-3.5-turbo for token counting
azure_llms = { azure_llms = {
"gpt-35-turbo": "azure/gpt-35-turbo", "gpt-35-turbo": "azure/gpt-35-turbo",
@ -346,6 +347,10 @@ azure_llms = {
"gpt-35-turbo-instruct": "azure/gpt-35-turbo-instruct", "gpt-35-turbo-instruct": "azure/gpt-35-turbo-instruct",
} }
azure_embedding_models = {
"ada": "azure/ada",
}
petals_models = [ petals_models = [
"petals-team/StableBeluga2", "petals-team/StableBeluga2",
] ]

View file

@ -11,6 +11,7 @@ import litellm
import time, logging import time, logging
import json, traceback, ast, hashlib import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any from typing import Optional, Literal, List, Union, Any
from openai._models import BaseModel as OpenAIObject
def print_verbose(print_statement): def print_verbose(print_statement):
@ -472,7 +473,10 @@ class Cache:
else: else:
cache_key = self.get_cache_key(*args, **kwargs) cache_key = self.get_cache_key(*args, **kwargs)
if cache_key is not None: if cache_key is not None:
max_age = kwargs.get("cache", {}).get("s-max-age", float("inf")) cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = self.cache.get_cache(cache_key) cached_result = self.cache.get_cache(cache_key)
# Check if a timestamp was stored with the cached response # Check if a timestamp was stored with the cached response
if ( if (
@ -529,7 +533,7 @@ class Cache:
else: else:
cache_key = self.get_cache_key(*args, **kwargs) cache_key = self.get_cache_key(*args, **kwargs)
if cache_key is not None: if cache_key is not None:
if isinstance(result, litellm.ModelResponse): if isinstance(result, OpenAIObject):
result = result.model_dump_json() result = result.model_dump_json()
## Get Cache-Controls ## ## Get Cache-Controls ##

View file

@ -724,16 +724,32 @@ class AzureChatCompletion(BaseLLM):
client_session = litellm.aclient_session or httpx.AsyncClient( client_session = litellm.aclient_session or httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls
) )
client = AsyncAzureOpenAI( if "gateway.ai.cloudflare.com" in api_base:
api_version=api_version, ## build base url - assume api base includes resource name
azure_endpoint=api_base, if not api_base.endswith("/"):
api_key=api_key, api_base += "/"
timeout=timeout, api_base += f"{model}"
http_client=client_session, client = AsyncAzureOpenAI(
) base_url=api_base,
api_version=api_version,
api_key=api_key,
timeout=timeout,
http_client=client_session,
)
model = None
# cloudflare ai gateway, needs model=None
else:
client = AsyncAzureOpenAI(
api_version=api_version,
azure_endpoint=api_base,
api_key=api_key,
timeout=timeout,
http_client=client_session,
)
if model is None and mode != "image_generation": # only run this check if it's not cloudflare ai gateway
raise Exception("model is not set") if model is None and mode != "image_generation":
raise Exception("model is not set")
completion = None completion = None

View file

@ -14,12 +14,18 @@ model_list:
- model_name: BEDROCK_GROUP - model_name: BEDROCK_GROUP
litellm_params: litellm_params:
model: bedrock/cohere.command-text-v14 model: bedrock/cohere.command-text-v14
- model_name: Azure OpenAI GPT-4 Canada-East (External) - model_name: openai-gpt-3.5
litellm_params: litellm_params:
model: gpt-3.5-turbo model: gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY api_key: os.environ/OPENAI_API_KEY
model_info: model_info:
mode: chat mode: chat
- model_name: azure-cloudflare
litellm_params:
model: azure/chatgpt-v-2
api_base: https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
- model_name: azure-embedding-model - model_name: azure-embedding-model
litellm_params: litellm_params:
model: azure/azure-embedding-model model: azure/azure-embedding-model

View file

@ -307,9 +307,8 @@ async def user_api_key_auth(
) )
def prisma_setup(database_url: Optional[str]): async def prisma_setup(database_url: Optional[str]):
global prisma_client, proxy_logging_obj, user_api_key_cache global prisma_client, proxy_logging_obj, user_api_key_cache
if ( if (
database_url is not None and prisma_client is None database_url is not None and prisma_client is None
): # don't re-initialize prisma client after initial init ): # don't re-initialize prisma client after initial init
@ -321,6 +320,8 @@ def prisma_setup(database_url: Optional[str]):
print_verbose( print_verbose(
f"Error when initializing prisma, Ensure you run pip install prisma {str(e)}" f"Error when initializing prisma, Ensure you run pip install prisma {str(e)}"
) )
if prisma_client is not None and prisma_client.db.is_connected() == False:
await prisma_client.connect()
def load_from_azure_key_vault(use_azure_key_vault: bool = False): def load_from_azure_key_vault(use_azure_key_vault: bool = False):
@ -502,232 +503,330 @@ async def _run_background_health_check():
await asyncio.sleep(health_check_interval) await asyncio.sleep(health_check_interval)
def load_router_config(router: Optional[litellm.Router], config_file_path: str): class ProxyConfig:
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue """
config = {} Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
try: """
if os.path.exists(config_file_path):
def __init__(self) -> None:
pass
async def get_config(self, config_file_path: Optional[str] = None) -> dict:
global prisma_client, user_config_file_path
file_path = config_file_path or user_config_file_path
if config_file_path is not None:
user_config_file_path = config_file_path user_config_file_path = config_file_path
with open(config_file_path, "r") as file: # Load existing config
config = yaml.safe_load(file) ## Yaml
else: if file_path is not None:
raise Exception( if os.path.exists(f"{file_path}"):
f"Path to config does not exist, Current working directory: {os.getcwd()}, 'os.path.exists({config_file_path})' returned False" with open(f"{file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
raise Exception(f"File not found! - {file_path}")
## DB
if (
prisma_client is not None
and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
):
await prisma_setup(database_url=None) # in case it's not been connected yet
_tasks = []
keys = [
"model_list",
"general_settings",
"router_settings",
"litellm_settings",
]
for k in keys:
response = prisma_client.get_generic_data(
key="param_name", value=k, table_name="config"
)
_tasks.append(response)
responses = await asyncio.gather(*_tasks)
return config
async def save_config(self, new_config: dict):
global prisma_client, llm_router, user_config_file_path, llm_model_list, general_settings
# Load existing config
backup_config = await self.get_config()
# Save the updated config
## YAML
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(new_config, config_file, default_flow_style=False)
# update Router - verifies if this is a valid config
try:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=user_config_file_path
) )
except Exception as e: except Exception as e:
raise Exception(f"Exception while reading Config: {e}") traceback.print_exc()
# Revert to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(status_code=400, detail="Invalid config passed in")
## PRINT YAML FOR CONFIRMING IT WORKS ## DB - writes valid config to db
printed_yaml = copy.deepcopy(config) """
printed_yaml.pop("environment_variables", None) - Do not write restricted params like 'api_key' to the database
- if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`)
"""
if (
prisma_client is not None
and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
):
### KEY REMOVAL ###
models = new_config.get("model_list", [])
for m in models:
if m.get("litellm_params", {}).get("api_key", None) is not None:
# pop the key
api_key = m["litellm_params"].pop("api_key")
# store in local env
key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
os.environ[key_name] = api_key
# save the key name (not the value)
m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
await prisma_client.insert_data(data=new_config, table_name="config")
print_verbose( async def load_config(
f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}" self, router: Optional[litellm.Router], config_file_path: str
) ):
"""
Load config values into proxy global state
"""
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
## ENVIRONMENT VARIABLES # Load existing config
environment_variables = config.get("environment_variables", None) config = await self.get_config(config_file_path=config_file_path)
if environment_variables: ## PRINT YAML FOR CONFIRMING IT WORKS
for key, value in environment_variables.items(): printed_yaml = copy.deepcopy(config)
os.environ[key] = value printed_yaml.pop("environment_variables", None)
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) print_verbose(
litellm_settings = config.get("litellm_settings", None) f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
if litellm_settings is None: )
litellm_settings = {}
if litellm_settings:
# ANSI escape code for blue text
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
for key, value in litellm_settings.items():
if key == "cache":
print(f"{blue_color_code}\nSetting Cache on Proxy") # noqa
from litellm.caching import Cache
cache_params = {} ## ENVIRONMENT VARIABLES
if "cache_params" in litellm_settings: environment_variables = config.get("environment_variables", None)
cache_params_in_config = litellm_settings["cache_params"] if environment_variables:
# overwrie cache_params with cache_params_in_config for key, value in environment_variables.items():
cache_params.update(cache_params_in_config) os.environ[key] = value
cache_type = cache_params.get("type", "redis") ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get("litellm_settings", None)
if litellm_settings is None:
litellm_settings = {}
if litellm_settings:
# ANSI escape code for blue text
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
for key, value in litellm_settings.items():
if key == "cache":
print(f"{blue_color_code}\nSetting Cache on Proxy") # noqa
from litellm.caching import Cache
print_verbose(f"passed cache type={cache_type}") cache_params = {}
if "cache_params" in litellm_settings:
cache_params_in_config = litellm_settings["cache_params"]
# overwrie cache_params with cache_params_in_config
cache_params.update(cache_params_in_config)
if cache_type == "redis": cache_type = cache_params.get("type", "redis")
cache_host = litellm.get_secret("REDIS_HOST", None)
cache_port = litellm.get_secret("REDIS_PORT", None)
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
cache_params = { print_verbose(f"passed cache type={cache_type}")
"type": cache_type,
"host": cache_host, if cache_type == "redis":
"port": cache_port, cache_host = litellm.get_secret("REDIS_HOST", None)
"password": cache_password, cache_port = litellm.get_secret("REDIS_PORT", None)
} cache_password = litellm.get_secret("REDIS_PASSWORD", None)
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
cache_params.update(
{
"type": cache_type,
"host": cache_host,
"port": cache_port,
"password": cache_password,
}
)
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
print( # noqa
f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
)
print() # noqa
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
litellm.cache = Cache(**cache_params)
print( # noqa print( # noqa
f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}" f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
) # noqa
print( # noqa
f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
) )
print() # noqa elif key == "callbacks":
litellm.callbacks = [
get_instance_fn(value=value, config_file_path=config_file_path)
]
print_verbose(
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
)
elif key == "post_call_rules":
litellm.post_call_rules = [
get_instance_fn(value=value, config_file_path=config_file_path)
]
print_verbose(f"litellm.post_call_rules: {litellm.post_call_rules}")
elif key == "success_callback":
litellm.success_callback = []
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables # intialize success callbacks
litellm.cache = Cache(**cache_params) for callback in value:
print( # noqa # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}" if "." in callback:
) litellm.success_callback.append(
elif key == "callbacks": get_instance_fn(value=callback)
litellm.callbacks = [ )
get_instance_fn(value=value, config_file_path=config_file_path) # these are litellm callbacks - "langfuse", "sentry", "wandb"
] else:
print_verbose( litellm.success_callback.append(callback)
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}" print_verbose(
) f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
elif key == "post_call_rules": )
litellm.post_call_rules = [ elif key == "failure_callback":
get_instance_fn(value=value, config_file_path=config_file_path) litellm.failure_callback = []
]
print_verbose(f"litellm.post_call_rules: {litellm.post_call_rules}")
elif key == "success_callback":
litellm.success_callback = []
# intialize success callbacks # intialize success callbacks
for callback in value: for callback in value:
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
if "." in callback: if "." in callback:
litellm.success_callback.append(get_instance_fn(value=callback)) litellm.failure_callback.append(
# these are litellm callbacks - "langfuse", "sentry", "wandb" get_instance_fn(value=callback)
else: )
litellm.success_callback.append(callback) # these are litellm callbacks - "langfuse", "sentry", "wandb"
print_verbose( else:
f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}" litellm.failure_callback.append(callback)
) print_verbose(
elif key == "failure_callback": f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}"
litellm.failure_callback = [] )
elif key == "cache_params":
# this is set in the cache branch
# see usage here: https://docs.litellm.ai/docs/proxy/caching
pass
else:
setattr(litellm, key, value)
# intialize success callbacks ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
for callback in value: general_settings = config.get("general_settings", {})
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function if general_settings is None:
if "." in callback: general_settings = {}
litellm.failure_callback.append(get_instance_fn(value=callback)) if general_settings:
# these are litellm callbacks - "langfuse", "sentry", "wandb" ### LOAD SECRET MANAGER ###
else: key_management_system = general_settings.get("key_management_system", None)
litellm.failure_callback.append(callback) if key_management_system is not None:
print_verbose( if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}" ### LOAD FROM AZURE KEY VAULT ###
) load_from_azure_key_vault(use_azure_key_vault=True)
elif key == "cache_params": elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
# this is set in the cache branch ### LOAD FROM GOOGLE KMS ###
# see usage here: https://docs.litellm.ai/docs/proxy/caching load_google_kms(use_google_kms=True)
pass else:
else: raise ValueError("Invalid Key Management System selected")
setattr(litellm, key, value) ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
use_google_kms = general_settings.get("use_google_kms", False)
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging load_google_kms(use_google_kms=use_google_kms)
general_settings = config.get("general_settings", {}) ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
if general_settings is None: use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
general_settings = {} load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
if general_settings: ### ALERTING ###
### LOAD SECRET MANAGER ### proxy_logging_obj.update_values(
key_management_system = general_settings.get("key_management_system", None) alerting=general_settings.get("alerting", None),
if key_management_system is not None: alerting_threshold=general_settings.get("alerting_threshold", 600),
if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
### LOAD FROM AZURE KEY VAULT ###
load_from_azure_key_vault(use_azure_key_vault=True)
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
### LOAD FROM GOOGLE KMS ###
load_google_kms(use_google_kms=True)
else:
raise ValueError("Invalid Key Management System selected")
### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
use_google_kms = general_settings.get("use_google_kms", False)
load_google_kms(use_google_kms=use_google_kms)
### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### ALERTING ###
proxy_logging_obj.update_values(
alerting=general_settings.get("alerting", None),
alerting_threshold=general_settings.get("alerting_threshold", 600),
)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"):
print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url)
print_verbose(f"RETRIEVED DB URL: {database_url}")
prisma_setup(database_url=database_url)
## COST TRACKING ##
cost_tracking()
### MASTER KEY ###
master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
)
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
### CUSTOM API KEY AUTH ###
custom_auth = general_settings.get("custom_auth", None)
if custom_auth:
user_custom_auth = get_instance_fn(
value=custom_auth, config_file_path=config_file_path
) )
### BACKGROUND HEALTH CHECKS ### ### CONNECT TO DATABASE ###
# Enable background health checks database_url = general_settings.get("database_url", None)
use_background_health_checks = general_settings.get( if database_url and database_url.startswith("os.environ/"):
"background_health_checks", False print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
) database_url = litellm.get_secret(database_url)
health_check_interval = general_settings.get("health_check_interval", 300) print_verbose(f"RETRIEVED DB URL: {database_url}")
await prisma_setup(database_url=database_url)
## COST TRACKING ##
cost_tracking()
### MASTER KEY ###
master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
)
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
### CUSTOM API KEY AUTH ###
custom_auth = general_settings.get("custom_auth", None)
if custom_auth:
user_custom_auth = get_instance_fn(
value=custom_auth, config_file_path=config_file_path
)
### BACKGROUND HEALTH CHECKS ###
# Enable background health checks
use_background_health_checks = general_settings.get(
"background_health_checks", False
)
health_check_interval = general_settings.get("health_check_interval", 300)
router_params: dict = { router_params: dict = {
"num_retries": 3, "num_retries": 3,
"cache_responses": litellm.cache "cache_responses": litellm.cache
!= None, # cache if user passed in cache values != None, # cache if user passed in cache values
}
## MODEL LIST
model_list = config.get("model_list", None)
if model_list:
router_params["model_list"] = model_list
print( # noqa
f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
) # noqa
for model in model_list:
### LOAD FROM os.environ/ ###
for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v)
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
litellm_model_name = model["litellm_params"]["model"]
litellm_model_api_base = model["litellm_params"].get("api_base", None)
if "ollama" in litellm_model_name and litellm_model_api_base is None:
run_ollama_serve()
## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None)
if router_settings and isinstance(router_settings, dict):
arg_spec = inspect.getfullargspec(litellm.Router)
# model list already set
exclude_args = {
"self",
"model_list",
} }
## MODEL LIST
model_list = config.get("model_list", None)
if model_list:
router_params["model_list"] = model_list
print( # noqa
f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
) # noqa
for model in model_list:
### LOAD FROM os.environ/ ###
for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v)
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
litellm_model_name = model["litellm_params"]["model"]
litellm_model_api_base = model["litellm_params"].get("api_base", None)
if "ollama" in litellm_model_name and litellm_model_api_base is None:
run_ollama_serve()
available_args = [x for x in arg_spec.args if x not in exclude_args] ## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None)
if router_settings and isinstance(router_settings, dict):
arg_spec = inspect.getfullargspec(litellm.Router)
# model list already set
exclude_args = {
"self",
"model_list",
}
for k, v in router_settings.items(): available_args = [x for x in arg_spec.args if x not in exclude_args]
if k in available_args:
router_params[k] = v
router = litellm.Router(**router_params) # type:ignore for k, v in router_settings.items():
return router, model_list, general_settings if k in available_args:
router_params[k] = v
router = litellm.Router(**router_params) # type:ignore
return router, model_list, general_settings
proxy_config = ProxyConfig()
async def generate_key_helper_fn( async def generate_key_helper_fn(
@ -797,6 +896,7 @@ async def generate_key_helper_fn(
"max_budget": max_budget, "max_budget": max_budget,
"user_email": user_email, "user_email": user_email,
} }
print_verbose("PrismaClient: Before Insert Data")
new_verification_token = await prisma_client.insert_data( new_verification_token = await prisma_client.insert_data(
data=verification_token_data data=verification_token_data
) )
@ -831,7 +931,7 @@ def save_worker_config(**data):
os.environ["WORKER_CONFIG"] = json.dumps(data) os.environ["WORKER_CONFIG"] = json.dumps(data)
def initialize( async def initialize(
model=None, model=None,
alias=None, alias=None,
api_base=None, api_base=None,
@ -849,7 +949,7 @@ def initialize(
use_queue=False, use_queue=False,
config=None, config=None,
): ):
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
generate_feedback_box() generate_feedback_box()
user_model = model user_model = model
user_debug = debug user_debug = debug
@ -857,9 +957,11 @@ def initialize(
litellm.set_verbose = True litellm.set_verbose = True
dynamic_config = {"general": {}, user_model: {}} dynamic_config = {"general": {}, user_model: {}}
if config: if config:
llm_router, llm_model_list, general_settings = load_router_config( (
router=llm_router, config_file_path=config llm_router,
) llm_model_list,
general_settings,
) = await proxy_config.load_config(router=llm_router, config_file_path=config)
if headers: # model-specific param if headers: # model-specific param
user_headers = headers user_headers = headers
dynamic_config[user_model]["headers"] = headers dynamic_config[user_model]["headers"] = headers
@ -988,7 +1090,7 @@ def parse_cache_control(cache_control):
@router.on_event("startup") @router.on_event("startup")
async def startup_event(): async def startup_event():
global prisma_client, master_key, use_background_health_checks global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
import json import json
### LOAD MASTER KEY ### ### LOAD MASTER KEY ###
@ -1000,12 +1102,11 @@ async def startup_event():
print_verbose(f"worker_config: {worker_config}") print_verbose(f"worker_config: {worker_config}")
# check if it's a valid file path # check if it's a valid file path
if os.path.isfile(worker_config): if os.path.isfile(worker_config):
initialize(config=worker_config) await initialize(**worker_config)
else: else:
# if not, assume it's a json string # if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG")) worker_config = json.loads(os.getenv("WORKER_CONFIG"))
initialize(**worker_config) await initialize(**worker_config)
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
if use_background_health_checks: if use_background_health_checks:
@ -1013,10 +1114,6 @@ async def startup_event():
_run_background_health_check() _run_background_health_check()
) # start the background health check coroutine. ) # start the background health check coroutine.
print_verbose(f"prisma client - {prisma_client}")
if prisma_client is not None:
await prisma_client.connect()
if prisma_client is not None and master_key is not None: if prisma_client is not None and master_key is not None:
# add master key to db # add master key to db
await generate_key_helper_fn( await generate_key_helper_fn(
@ -1220,7 +1317,7 @@ async def chat_completion(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
background_tasks: BackgroundTasks = BackgroundTasks(), background_tasks: BackgroundTasks = BackgroundTasks(),
): ):
global general_settings, user_debug, proxy_logging_obj global general_settings, user_debug, proxy_logging_obj, llm_model_list
try: try:
data = {} data = {}
body = await request.body() body = await request.body()
@ -1673,6 +1770,7 @@ async def generate_key_fn(
- expires: (datetime) Datetime object for when key expires. - expires: (datetime) Datetime object for when key expires.
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
""" """
print_verbose("entered /key/generate")
data_json = data.json() # type: ignore data_json = data.json() # type: ignore
response = await generate_key_helper_fn(**data_json) response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse( return GenerateKeyResponse(
@ -1825,7 +1923,7 @@ async def user_auth(request: Request):
### Check if user email in user table ### Check if user email in user table
response = await prisma_client.get_generic_data( response = await prisma_client.get_generic_data(
key="user_email", value=user_email, db="users" key="user_email", value=user_email, table_name="users"
) )
### if so - generate a 24 hr key with that user id ### if so - generate a 24 hr key with that user id
if response is not None: if response is not None:
@ -1883,16 +1981,13 @@ async def user_update(request: Request):
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def add_new_model(model_params: ModelParams): async def add_new_model(model_params: ModelParams):
global llm_router, llm_model_list, general_settings, user_config_file_path global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
try: try:
print_verbose(f"User config path: {user_config_file_path}")
# Load existing config # Load existing config
if os.path.exists(f"{user_config_file_path}"): config = await proxy_config.get_config()
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file) print_verbose(f"User config path: {user_config_file_path}")
else:
config = {"model_list": []}
backup_config = copy.deepcopy(config)
print_verbose(f"Loaded config: {config}") print_verbose(f"Loaded config: {config}")
# Add the new model to the config # Add the new model to the config
model_info = model_params.model_info.json() model_info = model_params.model_info.json()
@ -1907,22 +2002,8 @@ async def add_new_model(model_params: ModelParams):
print_verbose(f"updated model list: {config['model_list']}") print_verbose(f"updated model list: {config['model_list']}")
# Save the updated config # Save new config
with open(f"{user_config_file_path}", "w") as config_file: await proxy_config.save_config(new_config=config)
yaml.dump(config, config_file, default_flow_style=False)
# update Router
try:
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
except Exception as e:
# Rever to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(status_code=400, detail="Invalid Model passed in")
print_verbose(f"llm_model_list: {llm_model_list}")
return {"message": "Model added successfully"} return {"message": "Model added successfully"}
except Exception as e: except Exception as e:
@ -1949,13 +2030,10 @@ async def add_new_model(model_params: ModelParams):
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def model_info_v1(request: Request): async def model_info_v1(request: Request):
global llm_model_list, general_settings, user_config_file_path global llm_model_list, general_settings, user_config_file_path, proxy_config
# Load existing config # Load existing config
if os.path.exists(f"{user_config_file_path}"): config = await proxy_config.get_config()
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
config = {"model_list": []} # handle base case
all_models = config["model_list"] all_models = config["model_list"]
for model in all_models: for model in all_models:
@ -1984,18 +2062,18 @@ async def model_info_v1(request: Request):
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def delete_model(model_info: ModelInfoDelete): async def delete_model(model_info: ModelInfoDelete):
global llm_router, llm_model_list, general_settings, user_config_file_path global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
try: try:
if not os.path.exists(user_config_file_path): if not os.path.exists(user_config_file_path):
raise HTTPException(status_code=404, detail="Config file does not exist.") raise HTTPException(status_code=404, detail="Config file does not exist.")
with open(user_config_file_path, "r") as config_file: # Load existing config
config = yaml.safe_load(config_file) config = await proxy_config.get_config()
# If model_list is not in the config, nothing can be deleted # If model_list is not in the config, nothing can be deleted
if "model_list" not in config: if len(config.get("model_list", [])) == 0:
raise HTTPException( raise HTTPException(
status_code=404, detail="No model list available in the config." status_code=400, detail="No model list available in the config."
) )
# Check if the model with the specified model_id exists # Check if the model with the specified model_id exists
@ -2008,19 +2086,14 @@ async def delete_model(model_info: ModelInfoDelete):
# If the model was not found, return an error # If the model was not found, return an error
if model_to_delete is None: if model_to_delete is None:
raise HTTPException( raise HTTPException(
status_code=404, detail="Model with given model_id not found." status_code=400, detail="Model with given model_id not found."
) )
# Remove model from the list and save the updated config # Remove model from the list and save the updated config
config["model_list"].remove(model_to_delete) config["model_list"].remove(model_to_delete)
with open(user_config_file_path, "w") as config_file:
yaml.dump(config, config_file, default_flow_style=False)
# Update Router
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
# Save updated config
config = await proxy_config.save_config(new_config=config)
return {"message": "Model deleted successfully"} return {"message": "Model deleted successfully"}
except HTTPException as e: except HTTPException as e:
@ -2200,14 +2273,11 @@ async def update_config(config_info: ConfigYAML):
Currently supports modifying General Settings + LiteLLM settings Currently supports modifying General Settings + LiteLLM settings
""" """
global llm_router, llm_model_list, general_settings global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj
try: try:
# Load existing config # Load existing config
if os.path.exists(f"{user_config_file_path}"): config = await proxy_config.get_config()
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
config = {}
backup_config = copy.deepcopy(config) backup_config = copy.deepcopy(config)
print_verbose(f"Loaded config: {config}") print_verbose(f"Loaded config: {config}")
@ -2240,20 +2310,13 @@ async def update_config(config_info: ConfigYAML):
} }
# Save the updated config # Save the updated config
with open(f"{user_config_file_path}", "w") as config_file: await proxy_config.save_config(new_config=config)
yaml.dump(config, config_file, default_flow_style=False)
# update Router # Test new connections
try: ## Slack
llm_router, llm_model_list, general_settings = load_router_config( if "slack" in config.get("general_settings", {}).get("alerting", []):
router=llm_router, config_file_path=user_config_file_path await proxy_logging_obj.alerting_handler(
) message="This is a test", level="Low"
except Exception as e:
# Rever to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(
status_code=400, detail=f"Invalid config passed in. Errror - {str(e)}"
) )
return {"message": "Config updated successfully"} return {"message": "Config updated successfully"}
except HTTPException as e: except HTTPException as e:
@ -2263,6 +2326,21 @@ async def update_config(config_info: ConfigYAML):
raise HTTPException(status_code=500, detail=f"An error occurred - {str(e)}") raise HTTPException(status_code=500, detail=f"An error occurred - {str(e)}")
@router.get(
"/config/get",
tags=["config.yaml"],
dependencies=[Depends(user_api_key_auth)],
)
async def get_config():
"""
Master key only.
Returns the config. Mainly used for testing.
"""
global proxy_config
return await proxy_config.get_config()
@router.get("/config/yaml", tags=["config.yaml"]) @router.get("/config/yaml", tags=["config.yaml"])
async def config_yaml_endpoint(config_info: ConfigYAML): async def config_yaml_endpoint(config_info: ConfigYAML):
""" """
@ -2351,6 +2429,28 @@ async def health_endpoint(
} }
@router.get("/health/readiness", tags=["health"])
async def health_readiness():
"""
Unprotected endpoint for checking if worker can receive requests
"""
global prisma_client
if prisma_client is not None: # if db passed in, check if it's connected
if prisma_client.db.is_connected() == True:
return {"status": "healthy"}
else:
return {"status": "healthy"}
raise HTTPException(status_code=503, detail="Service Unhealthy")
@router.get("/health/liveliness", tags=["health"])
async def health_liveliness():
"""
Unprotected endpoint for checking if worker is alive
"""
return "I'm alive!"
@router.get("/") @router.get("/")
async def home(request: Request): async def home(request: Request):
return "LiteLLM: RUNNING" return "LiteLLM: RUNNING"

View file

@ -25,4 +25,9 @@ model LiteLLM_VerificationToken {
user_id String? user_id String?
max_parallel_requests Int? max_parallel_requests Int?
metadata Json @default("{}") metadata Json @default("{}")
}
model LiteLLM_Config {
param_name String @id
param_value Json?
} }

View file

@ -250,31 +250,37 @@ def on_backoff(details):
class PrismaClient: class PrismaClient:
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print_verbose( ### Check if prisma client can be imported (setup done in Docker build)
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
## init logging object
self.proxy_logging_obj = proxy_logging_obj
self.connected = False
os.environ["DATABASE_URL"] = database_url
# Save the current working directory
original_dir = os.getcwd()
# set the working directory to where this script is
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
try: try:
subprocess.run(["prisma", "generate"]) from prisma import Client # type: ignore
subprocess.run(
["prisma", "db", "push", "--accept-data-loss"]
) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
finally:
os.chdir(original_dir)
# Now you can import the Prisma Client
from prisma import Client # type: ignore
self.db = Client() # Client to connect to Prisma db os.environ["DATABASE_URL"] = database_url
self.db = Client() # Client to connect to Prisma db
except: # if not - go through normal setup process
print_verbose(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
## init logging object
self.proxy_logging_obj = proxy_logging_obj
os.environ["DATABASE_URL"] = database_url
# Save the current working directory
original_dir = os.getcwd()
# set the working directory to where this script is
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
try:
subprocess.run(["prisma", "generate"])
subprocess.run(
["prisma", "db", "push", "--accept-data-loss"]
) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
finally:
os.chdir(original_dir)
# Now you can import the Prisma Client
from prisma import Client # type: ignore
self.db = Client() # Client to connect to Prisma db
def hash_token(self, token: str): def hash_token(self, token: str):
# Hash the string using SHA-256 # Hash the string using SHA-256
@ -301,20 +307,24 @@ class PrismaClient:
self, self,
key: str, key: str,
value: Any, value: Any,
db: Literal["users", "keys"], table_name: Literal["users", "keys", "config"],
): ):
""" """
Generic implementation of get data Generic implementation of get data
""" """
try: try:
if db == "users": if table_name == "users":
response = await self.db.litellm_usertable.find_first( response = await self.db.litellm_usertable.find_first(
where={key: value} # type: ignore where={key: value} # type: ignore
) )
elif db == "keys": elif table_name == "keys":
response = await self.db.litellm_verificationtoken.find_first( # type: ignore response = await self.db.litellm_verificationtoken.find_first( # type: ignore
where={key: value} # type: ignore where={key: value} # type: ignore
) )
elif table_name == "config":
response = await self.db.litellm_config.find_first( # type: ignore
where={key: value} # type: ignore
)
return response return response
except Exception as e: except Exception as e:
asyncio.create_task( asyncio.create_task(
@ -336,15 +346,19 @@ class PrismaClient:
user_id: Optional[str] = None, user_id: Optional[str] = None,
): ):
try: try:
print_verbose("PrismaClient: get_data")
response = None response = None
if token is not None: if token is not None:
# check if plain text or hash # check if plain text or hash
hashed_token = token hashed_token = token
if token.startswith("sk-"): if token.startswith("sk-"):
hashed_token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
print_verbose("PrismaClient: find_unique")
response = await self.db.litellm_verificationtoken.find_unique( response = await self.db.litellm_verificationtoken.find_unique(
where={"token": hashed_token} where={"token": hashed_token}
) )
print_verbose(f"PrismaClient: response={response}")
if response: if response:
# Token exists, now check expiration. # Token exists, now check expiration.
if response.expires is not None and expires is not None: if response.expires is not None and expires is not None:
@ -372,6 +386,10 @@ class PrismaClient:
) )
return response return response
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
import traceback
traceback.print_exc()
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(original_exception=e)
) )
@@ -385,40 +403,71 @@ class PrismaClient:
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
-    async def insert_data(self, data: dict):
+    async def insert_data(
+        self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
+    ):
        """
        Add a key to the database. If it already exists, do nothing.
        """
        try:
-            token = data["token"]
-            hashed_token = self.hash_token(token=token)
-            db_data = self.jsonify_object(data=data)
-            db_data["token"] = hashed_token
-            max_budget = db_data.pop("max_budget", None)
-            user_email = db_data.pop("user_email", None)
-            new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
-                where={
-                    "token": hashed_token,
-                },
-                data={
-                    "create": {**db_data},  # type: ignore
-                    "update": {},  # don't do anything if it already exists
-                },
-            )
-            new_user_row = await self.db.litellm_usertable.upsert(
-                where={"user_id": data["user_id"]},
-                data={
-                    "create": {
-                        "user_id": data["user_id"],
-                        "max_budget": max_budget,
-                        "user_email": user_email,
-                    },
-                    "update": {},  # don't do anything if it already exists
-                },
-            )
-            return new_verification_token
+            if table_name == "user+key":
+                token = data["token"]
+                hashed_token = self.hash_token(token=token)
+                db_data = self.jsonify_object(data=data)
+                db_data["token"] = hashed_token
+                max_budget = db_data.pop("max_budget", None)
+                user_email = db_data.pop("user_email", None)
+                print_verbose(
+                    "PrismaClient: Before upsert into litellm_verificationtoken"
+                )
+                new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
+                    where={
+                        "token": hashed_token,
+                    },
+                    data={
+                        "create": {**db_data},  # type: ignore
+                        "update": {},  # don't do anything if it already exists
+                    },
+                )
+                new_user_row = await self.db.litellm_usertable.upsert(
+                    where={"user_id": data["user_id"]},
+                    data={
+                        "create": {
+                            "user_id": data["user_id"],
+                            "max_budget": max_budget,
+                            "user_email": user_email,
+                        },
+                        "update": {},  # don't do anything if it already exists
+                    },
+                )
+                return new_verification_token
+            elif table_name == "config":
+                """
+                For each param,
+                get the existing table values
+                Add the new values
+                Update DB
+                """
+                tasks = []
+                for k, v in data.items():
+                    updated_data = v
+                    updated_data = json.dumps(updated_data)
+                    updated_table_row = self.db.litellm_config.upsert(
+                        where={"param_name": k},
+                        data={
+                            "create": {"param_name": k, "param_value": updated_data},
+                            "update": {"param_value": updated_data},
+                        },
+                    )
+                    tasks.append(updated_table_row)
+                await asyncio.gather(*tasks)
        except Exception as e:
+            print_verbose(f"LiteLLM Prisma Client Exception: {e}")
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
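For orientation, a minimal sketch of how the new table_name="config" branch of insert_data might be driven. The prisma_client name and the settings payload are illustrative stand-ins, not lines from this commit:

# Hypothetical driver for PrismaClient.insert_data(table_name="config")
import asyncio

async def save_config(prisma_client):
    # each key/value pair is JSON-serialized and upserted into the LiteLLM_Config table
    await prisma_client.insert_data(
        data={"litellm_settings": {"set_verbose": True, "cache": True}},
        table_name="config",
    )

# asyncio.run(save_config(prisma_client))  # prisma_client: an initialized PrismaClient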
@@ -505,11 +554,7 @@ class PrismaClient:
    )
    async def connect(self):
        try:
-            if self.connected == False:
-                await self.db.connect()
-                self.connected = True
-            else:
-                return
+            await self.db.connect()
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
@@ -773,6 +773,10 @@ class Router:
            )
            original_exception = e
            try:
+                if (
+                    hasattr(e, "status_code") and e.status_code == 400
+                ):  # don't retry a malformed request
+                    raise e
                self.print_verbose(f"Trying to fallback b/w models")
                if (
                    isinstance(e, litellm.ContextWindowExceededError)
@@ -846,7 +850,7 @@ class Router:
            return response
        except Exception as e:
            original_exception = e
-            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
+            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
            if (
                isinstance(original_exception, litellm.ContextWindowExceededError)
                and context_window_fallbacks is None
@@ -864,12 +868,12 @@ class Router:
                    min_timeout=self.retry_after,
                )
                await asyncio.sleep(timeout)
-            elif (
-                hasattr(original_exception, "status_code")
-                and hasattr(original_exception, "response")
-                and litellm._should_retry(status_code=original_exception.status_code)
+            elif hasattr(original_exception, "status_code") and litellm._should_retry(
+                status_code=original_exception.status_code
            ):
-                if hasattr(original_exception.response, "headers"):
+                if hasattr(original_exception, "response") and hasattr(
+                    original_exception.response, "headers"
+                ):
                    timeout = litellm._calculate_retry_after(
                        remaining_retries=num_retries,
                        max_retries=num_retries,
@@ -1326,6 +1330,7 @@ class Router:
                local_only=True,
            )  # cache for 1 hr
+            cache_key = f"{model_id}_client"
            _client = openai.AzureOpenAI(  # type: ignore
                api_key=api_key,
                base_url=api_base,
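For context, a small self-contained sketch of the retry gate this hunk adds: HTTP 400 errors are re-raised immediately, while other status codes are checked against a retry predicate. should_retry and FakeAPIError below are stand-ins for illustration, not litellm internals:

RETRIABLE_STATUS_CODES = {408, 429, 500, 502, 503, 504}  # assumption for illustration

def should_retry(status_code: int) -> bool:
    # retry transient failures; never retry malformed requests
    return status_code in RETRIABLE_STATUS_CODES

class FakeAPIError(Exception):
    def __init__(self, status_code: int):
        super().__init__(f"status={status_code}")
        self.status_code = status_code

def handle(exc: Exception) -> str:
    if getattr(exc, "status_code", None) == 400:
        raise exc  # don't retry a malformed request
    if hasattr(exc, "status_code") and should_retry(status_code=exc.status_code):
        return "retry"
    return "give up"

print(handle(FakeAPIError(429)))  # retry
print(handle(FakeAPIError(503)))  # retry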
@@ -138,14 +138,15 @@ def test_async_completion_cloudflare():
            response = await litellm.acompletion(
                model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
                messages=[{"content": "what llm are you", "role": "user"}],
-                max_tokens=50,
+                max_tokens=5,
+                num_retries=3,
            )
            print(response)
            return response

        response = asyncio.run(test())
        text_response = response["choices"][0]["message"]["content"]
-        assert len(text_response) > 5  # more than 5 chars in response
+        assert len(text_response) > 1  # more than 1 chars in response
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -166,7 +167,7 @@ def test_get_cloudflare_response_streaming():
            model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
            messages=messages,
            stream=True,
-            timeout=5,
+            num_retries=3,  # cloudflare ai workers is EXTREMELY UNSTABLE
        )
        print(type(response))
@@ -91,7 +91,7 @@ def test_caching_with_cache_controls():
            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
        )
        response2 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-max-age": 10}
+            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
        )
        print(f"response1: {response1}")
        print(f"response2: {response2}")
@@ -105,7 +105,7 @@ def test_caching_with_cache_controls():
            model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
        )
        response2 = completion(
-            model="gpt-3.5-turbo", messages=messages, cache={"s-max-age": 5}
+            model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 5}
        )
        print(f"response1: {response1}")
        print(f"response2: {response2}")
@@ -167,6 +167,8 @@ small text
def test_embedding_caching():
    import time

+    # litellm.set_verbose = True
+
    litellm.cache = Cache()
    text_to_embed = [embedding_large_text]
    start_time = time.time()
@@ -182,7 +184,7 @@ def test_embedding_caching():
        model="text-embedding-ada-002", input=text_to_embed, caching=True
    )
    end_time = time.time()
-    print(f"embedding2: {embedding2}")
+    # print(f"embedding2: {embedding2}")
    print(f"Embedding 2 response time: {end_time - start_time} seconds")

    litellm.cache = None
@@ -274,7 +276,7 @@ def test_redis_cache_completion():
            port=os.environ["REDIS_PORT"],
            password=os.environ["REDIS_PASSWORD"],
        )
-        print("test2 for caching")
+        print("test2 for Redis Caching - non streaming")
        response1 = completion(
            model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
        )
@@ -326,6 +328,10 @@ def test_redis_cache_completion():
            print(f"response4: {response4}")
            pytest.fail(f"Error occurred:")

+    assert response1.id == response2.id
+    assert response1.created == response2.created
+    assert response1.choices[0].message.content == response2.choices[0].message.content
+
    # test_redis_cache_completion()
@@ -395,7 +401,7 @@ def test_redis_cache_completion_stream():
    """

-# test_redis_cache_completion_stream()
+test_redis_cache_completion_stream()


def test_redis_cache_acompletion_stream():
@@ -529,6 +535,7 @@ def test_redis_cache_acompletion_stream_bedrock():
        assert (
            response_1_content == response_2_content
        ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
+
        litellm.cache = None
        litellm.success_callback = []
        litellm._async_success_callback = []
@@ -537,7 +544,7 @@ def test_redis_cache_acompletion_stream_bedrock():
        raise e


-def test_s3_cache_acompletion_stream_bedrock():
+def test_s3_cache_acompletion_stream_azure():
    import asyncio

    try:
@@ -556,10 +563,13 @@ def test_s3_cache_acompletion_stream_azure():
        response_1_content = ""
        response_2_content = ""

+        response_1_created = ""
+        response_2_created = ""
+
        async def call1():
-            nonlocal response_1_content
+            nonlocal response_1_content, response_1_created
            response1 = await litellm.acompletion(
-                model="bedrock/anthropic.claude-v1",
+                model="azure/chatgpt-v-2",
                messages=messages,
                max_tokens=40,
                temperature=1,
@@ -567,6 +577,7 @@ def test_s3_cache_acompletion_stream_azure():
            )
            async for chunk in response1:
                print(chunk)
+                response_1_created = chunk.created
                response_1_content += chunk.choices[0].delta.content or ""
            print(response_1_content)
@@ -575,9 +586,9 @@ def test_s3_cache_acompletion_stream_azure():
        print("\n\n Response 1 content: ", response_1_content, "\n\n")

        async def call2():
-            nonlocal response_2_content
+            nonlocal response_2_content, response_2_created
            response2 = await litellm.acompletion(
-                model="bedrock/anthropic.claude-v1",
+                model="azure/chatgpt-v-2",
                messages=messages,
                max_tokens=40,
                temperature=1,
@@ -586,14 +597,23 @@ def test_s3_cache_acompletion_stream_azure():
            async for chunk in response2:
                print(chunk)
                response_2_content += chunk.choices[0].delta.content or ""
+                response_2_created = chunk.created
            print(response_2_content)

        asyncio.run(call2())

        print("\nresponse 1", response_1_content)
        print("\nresponse 2", response_2_content)

        assert (
            response_1_content == response_2_content
        ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
+
+        # prioritizing getting a new deploy out - will look at this in the next deploy
+        # print("response 1 created", response_1_created)
+        # print("response 2 created", response_2_created)
+        # assert response_1_created == response_2_created
+
        litellm.cache = None
        litellm.success_callback = []
        litellm._async_success_callback = []
@@ -602,7 +622,7 @@ def test_s3_cache_acompletion_stream_azure():
        raise e


-test_s3_cache_acompletion_stream_bedrock()
+# test_s3_cache_acompletion_stream_azure()

# test_redis_cache_acompletion_stream_bedrock()
@@ -749,10 +749,14 @@ def test_completion_ollama_hosted():
            model="ollama/phi",
            messages=messages,
            max_tokens=10,
+            num_retries=3,
+            timeout=90,
            api_base="https://test-ollama-endpoint.onrender.com",
        )
        # Add any assertions here to check the response
        print(response)
+    except Timeout as e:
+        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -1626,6 +1630,7 @@ def test_completion_anyscale_api():

def test_azure_cloudflare_api():
+    litellm.set_verbose = True
    try:
        messages = [
            {
@@ -1641,11 +1646,12 @@ def test_azure_cloudflare_api():
        )
        print(f"response: {response}")
    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
        traceback.print_exc()
        pass

-# test_azure_cloudflare_api()
+test_azure_cloudflare_api()


def test_completion_anyscale_2():
@@ -1931,6 +1937,7 @@ def test_completion_cloudflare():
            model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
            messages=[{"content": "what llm are you", "role": "user"}],
            max_tokens=15,
+            num_retries=3,
        )
        print(response)
@@ -1938,7 +1945,7 @@ def test_completion_cloudflare():
        pytest.fail(f"Error occurred: {e}")

-# test_completion_cloudflare()
+test_completion_cloudflare()


def test_moderation():
@@ -103,7 +103,7 @@ def test_cost_azure_gpt_35():
                ),
            )
        ],
-        model="azure/gpt-35-turbo",  # azure always has model written like this
+        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
    )
@@ -125,3 +125,36 @@ def test_cost_azure_gpt_35():
test_cost_azure_gpt_35()
def test_cost_azure_embedding():
try:
import asyncio
litellm.set_verbose = True
async def _test():
response = await litellm.aembedding(
model="azure/azure-embedding-model",
input=["good morning from litellm", "gm"],
)
print(response)
return response
response = asyncio.run(_test())
cost = litellm.completion_cost(completion_response=response)
print("Cost", cost)
expected_cost = float("7e-07")
assert cost == expected_cost
except Exception as e:
pytest.fail(
f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}"
)
# test_cost_azure_embedding()
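A quick sanity check on the expected_cost above, using the azure embedding pricing added later in this diff ($0.0000001 per input token); the token count is inferred from the asserted cost rather than measured:

input_cost_per_token = 0.0000001  # azure/azure-embedding-model price in this PR
prompt_tokens = 7                 # implied by the asserted cost of 7e-07
print(prompt_tokens * input_cost_per_token)  # ~7e-07, matching expected_cost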
@@ -0,0 +1,15 @@
model_list:
- model_name: azure-cloudflare
litellm_params:
model: azure/chatgpt-v-2
api_base: https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1
api_key: os.environ/AZURE_API_KEY
api_version: 2023-07-01-preview
litellm_settings:
set_verbose: True
cache: True # set cache responses to True
cache_params: # set cache params for s3
type: s3
s3_bucket_name: cache-bucket-litellm # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
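A rough programmatic equivalent of the cache section above, assuming litellm.Cache accepts the cache_params keys as keyword arguments (the s3 keyword names mirror the YAML and are an assumption; bucket and region are placeholders):

import litellm
from litellm.caching import Cache

litellm.set_verbose = True
litellm.cache = Cache(
    type="s3",
    s3_bucket_name="cache-bucket-litellm",  # AWS bucket name for the S3 cache
    s3_region_name="us-west-2",             # AWS region for the S3 cache
)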
@@ -9,6 +9,11 @@ model_list:
      api_key: os.environ/AZURE_CANADA_API_KEY
      model: azure/gpt-35-turbo
    model_name: azure-model
+  - litellm_params:
+      api_base: https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1
+      api_key: os.environ/AZURE_API_KEY
+      model: azure/chatgpt-v-2
+    model_name: azure-cloudflare-model
  - litellm_params:
      api_base: https://openai-france-1234.openai.azure.com
      api_key: os.environ/AZURE_FRANCE_API_KEY
@@ -59,6 +59,7 @@ def test_openai_embedding():

def test_openai_azure_embedding_simple():
    try:
+        litellm.set_verbose = True
        response = embedding(
            model="azure/azure-embedding-model",
            input=["good morning from litellm"],
@@ -70,6 +71,10 @@ def test_openai_azure_embedding_simple():
            response_keys
        )  # assert litellm response has expected keys from OpenAI embedding response

+        request_cost = litellm.completion_cost(completion_response=response)
+
+        print("Calculated request cost=", request_cost)
+
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -260,15 +265,22 @@ def test_aembedding():
                input=["good morning from litellm", "this is another item"],
            )
            print(response)
+            return response
        except Exception as e:
            pytest.fail(f"Error occurred: {e}")

-        asyncio.run(embedding_call())
+        response = asyncio.run(embedding_call())
+        print("Before caclulating cost, response", response)
+        cost = litellm.completion_cost(completion_response=response)
+        print("COST=", cost)
+        assert cost == float("1e-06")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

-# test_aembedding()
+test_aembedding()


def test_aembedding_azure():
@@ -10,7 +10,7 @@ import os, io
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
-import pytest
+import pytest, asyncio
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
@@ -22,6 +22,7 @@ from litellm.proxy.proxy_server import (
    router,
    save_worker_config,
    initialize,
+    ProxyConfig,
)  # Replace with the actual module where your FastAPI router is defined
@@ -36,7 +37,7 @@ def client():
    config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
    # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
    app = FastAPI()
-    initialize(config=config_fp)
+    asyncio.run(initialize(config=config_fp))
    app.include_router(router)  # Include your router in the test app
    return TestClient(app)
@@ -23,6 +23,7 @@ from litellm.proxy.proxy_server import (
    router,
    save_worker_config,
    initialize,
+    startup_event,
)  # Replace with the actual module where your FastAPI router is defined

filepath = os.path.dirname(os.path.abspath(__file__))
@@ -39,8 +40,8 @@ python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
def client():
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
-    initialize(config=config_fp)
    app = FastAPI()
+    asyncio.run(initialize(config=config_fp))
    app.include_router(router)  # Include your router in the test app
    return TestClient(app)
@@ -24,7 +24,7 @@ from litellm.proxy.proxy_server import (
def client():
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_bad_config.yaml"
-    initialize(config=config_fp)
+    asyncio.run(initialize(config=config_fp))
    app = FastAPI()
    app.include_router(router)  # Include your router in the test app
    return TestClient(app)
@@ -149,7 +149,7 @@ def test_chat_completion_exception_any_model(client):
            response=response
        )
        print("Exception raised=", openai_exception)
-        assert isinstance(openai_exception, openai.NotFoundError)
+        assert isinstance(openai_exception, openai.BadRequestError)
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
@@ -170,7 +170,7 @@ def test_embedding_exception_any_model(client):
            response=response
        )
        print("Exception raised=", openai_exception)
-        assert isinstance(openai_exception, openai.NotFoundError)
+        assert isinstance(openai_exception, openai.BadRequestError)
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
@@ -10,7 +10,7 @@ import os, io
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
-import pytest, logging
+import pytest, logging, asyncio
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
@@ -46,7 +46,7 @@ def client_no_auth():
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
-    initialize(config=config_fp, debug=True)
+    asyncio.run(initialize(config=config_fp, debug=True))
    app = FastAPI()
    app.include_router(router)  # Include your router in the test app
@@ -10,7 +10,7 @@ import os, io
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
-import pytest, logging
+import pytest, logging, asyncio
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
@@ -45,7 +45,7 @@ def client_no_auth():
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
-    initialize(config=config_fp)
+    asyncio.run(initialize(config=config_fp, debug=True))
    app = FastAPI()
    app.include_router(router)  # Include your router in the test app
@@ -280,33 +280,42 @@ def test_chat_completion_optional_params(client_no_auth):
# test_chat_completion_optional_params()

# Test Reading config.yaml file
-from litellm.proxy.proxy_server import load_router_config
+from litellm.proxy.proxy_server import ProxyConfig


def test_load_router_config():
    try:
+        import asyncio
+
        print("testing reading config")
        # this is a basic config.yaml with only a model
        filepath = os.path.dirname(os.path.abspath(__file__))
-        result = load_router_config(
-            router=None,
-            config_file_path=f"{filepath}/example_config_yaml/simple_config.yaml",
+        proxy_config = ProxyConfig()
+        result = asyncio.run(
+            proxy_config.load_config(
+                router=None,
+                config_file_path=f"{filepath}/example_config_yaml/simple_config.yaml",
+            )
        )
        print(result)
        assert len(result[1]) == 1
        # this is a load balancing config yaml
-        result = load_router_config(
-            router=None,
-            config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
+        result = asyncio.run(
+            proxy_config.load_config(
+                router=None,
+                config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
+            )
        )
        print(result)
        assert len(result[1]) == 2
        # config with general settings - custom callbacks
-        result = load_router_config(
-            router=None,
-            config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
+        result = asyncio.run(
+            proxy_config.load_config(
+                router=None,
+                config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
+            )
        )
        print(result)
        assert len(result[1]) == 2
@@ -314,9 +323,11 @@ def test_load_router_config():
        # tests for litellm.cache set from config
        print("testing reading proxy config for cache")
        litellm.cache = None
-        load_router_config(
-            router=None,
-            config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml",
+        asyncio.run(
+            proxy_config.load_config(
+                router=None,
+                config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml",
+            )
        )
        assert litellm.cache is not None
        assert "redis_client" in vars(
@@ -329,10 +340,14 @@ def test_load_router_config():
            "aembedding",
        ]  # init with all call types

+        litellm.disable_cache()
+
        print("testing reading proxy config for cache with params")
-        load_router_config(
-            router=None,
-            config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml",
+        asyncio.run(
+            proxy_config.load_config(
+                router=None,
+                config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml",
+            )
        )
        assert litellm.cache is not None
        print(litellm.cache)
@@ -1,38 +1,103 @@
-# #### What this tests ####
-# # This tests using caching w/ litellm which requires SSL=True
-# import sys, os
-# import time
-# import traceback
-# from dotenv import load_dotenv
-# load_dotenv()
-# import os
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# )  # Adds the parent directory to the system path
-# import pytest
-# import litellm
-# from litellm import embedding, completion
-# from litellm.caching import Cache
-# messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]
-# @pytest.mark.skip(reason="local proxy test")
-# def test_caching_v2():  # test in memory cache
-#     try:
-#         response1 = completion(model="openai/gpt-3.5-turbo", messages=messages, api_base="http://0.0.0.0:8000")
-#         response2 = completion(model="openai/gpt-3.5-turbo", messages=messages, api_base="http://0.0.0.0:8000")
-#         print(f"response1: {response1}")
-#         print(f"response2: {response2}")
-#         litellm.cache = None # disable cache
-#         if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
-#             print(f"response1: {response1}")
-#             print(f"response2: {response2}")
-#             raise Exception()
-#     except Exception as e:
-#         print(f"error occurred: {traceback.format_exc()}")
-#         pytest.fail(f"Error occurred: {e}")
-# test_caching_v2()
#### What this tests ####
# This tests using caching w/ litellm which requires SSL=True
import sys, os
import traceback
from dotenv import load_dotenv

load_dotenv()
import os, io

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,  # Set the desired logging level
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# test /chat/completion request to the proxy
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import (
    router,
    save_worker_config,
    initialize,
)  # Replace with the actual module where your FastAPI router is defined

# Your bearer token
token = ""
headers = {"Authorization": f"Bearer {token}"}
@pytest.fixture(scope="function")
def client_no_auth():
# Assuming litellm.proxy.proxy_server is an object
from litellm.proxy.proxy_server import cleanup_router_config_variables
cleanup_router_config_variables()
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_cloudflare_azure_with_cache_config.yaml"
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
asyncio.run(initialize(config=config_fp, debug=True))
app = FastAPI()
app.include_router(router) # Include your router in the test app
return TestClient(app)
def generate_random_word(length=4):
import string, random
letters = string.ascii_lowercase
return "".join(random.choice(letters) for _ in range(length))
def test_chat_completion(client_no_auth):
global headers
try:
user_message = f"Write a poem about {generate_random_word()}"
messages = [{"content": user_message, "role": "user"}]
# Your test data
test_data = {
"model": "azure-cloudflare",
"messages": messages,
"max_tokens": 10,
}
print("testing proxy server with chat completions")
response = client_no_auth.post("/v1/chat/completions", json=test_data)
print(f"response - {response.text}")
assert response.status_code == 200
response = response.json()
print(response)
content = response["choices"][0]["message"]["content"]
response1_id = response["id"]
print("\n content", content)
assert len(content) > 1
print("\nmaking 2nd request to proxy. Testing caching + non streaming")
response = client_no_auth.post("/v1/chat/completions", json=test_data)
print(f"response - {response.text}")
assert response.status_code == 200
response = response.json()
print(response)
response2_id = response["id"]
assert response1_id == response2_id
litellm.disable_cache()
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
@@ -29,6 +29,7 @@ from litellm.proxy.proxy_server import (
    router,
    save_worker_config,
    startup_event,
+    asyncio,
)  # Replace with the actual module where your FastAPI router is defined

filepath = os.path.dirname(os.path.abspath(__file__))
@@ -39,7 +40,7 @@ save_worker_config(
    alias=None,
    api_base=None,
    api_version=None,
-    debug=False,
+    debug=True,
    temperature=None,
    max_tokens=None,
    request_timeout=600,
@@ -51,24 +52,38 @@ save_worker_config(
    save=False,
    use_queue=False,
)
-app = FastAPI()
-app.include_router(router)  # Include your router in the test app
-
-
-@app.on_event("startup")
-async def wrapper_startup_event():
-    await startup_event()
+
+import asyncio
+
+
+@pytest.fixture
+def event_loop():
+    """Create an instance of the default event loop for each test case."""
+    policy = asyncio.WindowsSelectorEventLoopPolicy()
+    res = policy.new_event_loop()
+    asyncio.set_event_loop(res)
+    res._close = res.close
+    res.close = lambda: None
+    yield res
+    res._close()


# Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app)
-@pytest.fixture(autouse=True)
+@pytest.fixture(scope="function")
def client():
-    from litellm.proxy.proxy_server import cleanup_router_config_variables
+    from litellm.proxy.proxy_server import cleanup_router_config_variables, initialize

-    cleanup_router_config_variables()
-    with TestClient(app) as client:
-        yield client
+    cleanup_router_config_variables()  # rest proxy before test
+
+    asyncio.run(initialize(config=config_fp, debug=True))
+    app = FastAPI()
+    app.include_router(router)  # Include your router in the test app
+    return TestClient(app)


def test_add_new_key(client):
@@ -79,7 +94,7 @@ def test_add_new_key(client):
        "aliases": {"mistral-7b": "gpt-3.5-turbo"},
        "duration": "20m",
    }
-    print("testing proxy server")
+    print("testing proxy server - test_add_new_key")
    # Your bearer token
    token = os.getenv("PROXY_MASTER_KEY")
@@ -121,7 +136,7 @@ def test_update_new_key(client):
        "aliases": {"mistral-7b": "gpt-3.5-turbo"},
        "duration": "20m",
    }
-    print("testing proxy server")
+    print("testing proxy server-test_update_new_key")
    # Your bearer token
    token = os.getenv("PROXY_MASTER_KEY")
@@ -98,6 +98,73 @@ def test_init_clients_basic():
# test_init_clients_basic()
def test_init_clients_basic_azure_cloudflare():
# init azure + cloudflare
# init OpenAI gpt-3.5
# init OpenAI text-embedding
# init OpenAI comptaible - Mistral/mistral-medium
# init OpenAI compatible - xinference/bge
litellm.set_verbose = True
try:
print("Test basic client init")
model_list = [
{
"model_name": "azure-cloudflare",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": "https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1",
},
},
{
"model_name": "gpt-openai",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "text-embedding-ada-002",
"litellm_params": {
"model": "text-embedding-ada-002",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "mistral",
"litellm_params": {
"model": "mistral/mistral-tiny",
"api_key": os.getenv("MISTRAL_API_KEY"),
},
},
{
"model_name": "bge-base-en",
"litellm_params": {
"model": "xinference/bge-base-en",
"api_base": "http://127.0.0.1:9997/v1",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
]
router = Router(model_list=model_list)
for elem in router.model_list:
model_id = elem["model_info"]["id"]
assert router.cache.get_cache(f"{model_id}_client") is not None
assert router.cache.get_cache(f"{model_id}_async_client") is not None
assert router.cache.get_cache(f"{model_id}_stream_client") is not None
assert router.cache.get_cache(f"{model_id}_stream_async_client") is not None
print("PASSED !")
# see if we can init clients without timeout or max retries set
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
# test_init_clients_basic_azure_cloudflare()
def test_timeouts_router():
    """
    Test the timeouts of the router with multiple clients. This HASas to raise a timeout error
@@ -0,0 +1,137 @@
#### What this tests ####
# This tests if the router sends back a policy violation, without retries
import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm.integrations.custom_logger import CustomLogger
class MyCustomHandler(CustomLogger):
success: bool = False
failure: bool = False
previous_models: int = 0
def log_pre_api_call(self, model, messages, kwargs):
print(f"Pre-API Call")
print(
f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
)
self.previous_models += len(
kwargs["litellm_params"]["metadata"]["previous_models"]
) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
print(f"self.previous_models: {self.previous_models}")
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print(
f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}"
)
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Stream")
def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Stream")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Success")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Success")
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Failure")
kwargs = {
"model": "azure/gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "vorrei vedere la cosa più bella ad Ercolano. Qualè?",
},
],
}
@pytest.mark.asyncio
async def test_async_fallbacks():
litellm.set_verbose = False
model_list = [
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000,
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000,
},
]
router = Router(
model_list=model_list,
num_retries=3,
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
# context_window_fallbacks=[
# {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]},
# {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]},
# ],
set_verbose=False,
)
customHandler = MyCustomHandler()
litellm.callbacks = [customHandler]
try:
response = await router.acompletion(**kwargs)
pytest.fail(
f"An exception occurred: {e}"
) # should've raised azure policy error
except litellm.Timeout as e:
pass
except Exception as e:
await asyncio.sleep(
0.05
) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 0 # 0 retries, 0 fallback
router.reset()
finally:
router.reset()
@@ -306,6 +306,8 @@ def test_completion_ollama_hosted_stream():
            model="ollama/phi",
            messages=messages,
            max_tokens=10,
+            num_retries=3,
+            timeout=90,
            api_base="https://test-ollama-endpoint.onrender.com",
            stream=True,
        )
@@ -9,7 +9,7 @@
import sys, re, binascii, struct
import litellm
-import dotenv, json, traceback, threading, base64
+import dotenv, json, traceback, threading, base64, ast
import subprocess, os
import litellm, openai
import itertools
@@ -1975,7 +1975,10 @@ def client(original_function):
            if (
                (kwargs.get("caching", None) is None and litellm.cache is not None)
                or kwargs.get("caching", False) == True
-                or kwargs.get("cache", {}).get("no-cache", False) != True
+                or (
+                    kwargs.get("cache", None) is not None
+                    and kwargs.get("cache", {}).get("no-cache", False) != True
+                )
            ):  # allow users to control returning cached responses from the completion function
                # checking cache
                print_verbose(f"INSIDE CHECKING CACHE")
@ -2737,6 +2740,8 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
completion_tokens_cost_usd_dollar = 0 completion_tokens_cost_usd_dollar = 0
model_cost_ref = litellm.model_cost model_cost_ref = litellm.model_cost
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
print_verbose(f"Looking up model={model} in model_cost_map")
if model in model_cost_ref: if model in model_cost_ref:
prompt_tokens_cost_usd_dollar = ( prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
@ -2746,6 +2751,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
) )
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif "ft:gpt-3.5-turbo" in model: elif "ft:gpt-3.5-turbo" in model:
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
# fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm # fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
prompt_tokens_cost_usd_dollar = ( prompt_tokens_cost_usd_dollar = (
model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
@ -2756,6 +2762,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
) )
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model in litellm.azure_llms: elif model in litellm.azure_llms:
print_verbose(f"Cost Tracking: {model} is an Azure LLM")
model = litellm.azure_llms[model] model = litellm.azure_llms[model]
prompt_tokens_cost_usd_dollar = ( prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
@ -2764,19 +2771,29 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
model_cost_ref[model]["output_cost_per_token"] * completion_tokens model_cost_ref[model]["output_cost_per_token"] * completion_tokens
) )
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
else: elif model in litellm.azure_embedding_models:
# calculate average input cost, azure/gpt-deployments can potentially go here if users don't specify, gpt-4, gpt-3.5-turbo. LLMs litellm knows print_verbose(f"Cost Tracking: {model} is an Azure Embedding Model")
input_cost_sum = 0 model = litellm.azure_embedding_models[model]
output_cost_sum = 0 prompt_tokens_cost_usd_dollar = (
model_cost_ref = litellm.model_cost model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
for model in model_cost_ref: )
input_cost_sum += model_cost_ref[model]["input_cost_per_token"] completion_tokens_cost_usd_dollar = (
output_cost_sum += model_cost_ref[model]["output_cost_per_token"] model_cost_ref[model]["output_cost_per_token"] * completion_tokens
avg_input_cost = input_cost_sum / len(model_cost_ref.keys()) )
avg_output_cost = output_cost_sum / len(model_cost_ref.keys())
prompt_tokens_cost_usd_dollar = avg_input_cost * prompt_tokens
completion_tokens_cost_usd_dollar = avg_output_cost * completion_tokens
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
else:
# if model is not in model_prices_and_context_window.json. Raise an exception-let users know
error_str = f"Model not in model_prices_and_context_window.json. You passed model={model}\n"
raise litellm.exceptions.NotFoundError( # type: ignore
message=error_str,
model=model,
response=httpx.Response(
status_code=404,
content=error_str,
request=httpx.request(method="cost_per_token", url="https://github.com/BerriAI/litellm"), # type: ignore
),
llm_provider="",
)
def completion_cost( def completion_cost(
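A brief usage sketch of the stricter behavior above: mapped models return a (prompt_cost, completion_cost) tuple, while an unmapped model now raises instead of silently falling back to an averaged per-token price. Model names and token counts are illustrative:

from litellm.utils import cost_per_token

prompt_cost, completion_cost_usd = cost_per_token(
    model="gpt-3.5-turbo", prompt_tokens=21, completion_tokens=17
)
print(prompt_cost, completion_cost_usd)

try:
    cost_per_token(model="some-unknown-model", prompt_tokens=10, completion_tokens=5)
except Exception as e:
    print("unmapped model now raises:", e)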
@@ -2818,8 +2835,10 @@ def completion_cost(
        completion_tokens = 0
        if completion_response is not None:
            # get input/output tokens from completion_response
-            prompt_tokens = completion_response["usage"]["prompt_tokens"]
-            completion_tokens = completion_response["usage"]["completion_tokens"]
+            prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0)
+            completion_tokens = completion_response.get("usage", {}).get(
+                "completion_tokens", 0
+            )
            model = (
                model or completion_response["model"]
            )  # check if user passed an override for model, if it's none check completion_response['model']
@@ -2829,6 +2848,10 @@ def completion_cost(
        elif len(prompt) > 0:
            prompt_tokens = token_counter(model=model, text=prompt)
            completion_tokens = token_counter(model=model, text=completion)
+        if model == None:
+            raise ValueError(
+                f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
+            )

        # Calculate cost based on prompt_tokens, completion_tokens
        if "togethercomputer" in model or "together_ai" in model:
@@ -2849,8 +2872,7 @@ def completion_cost(
        )
        return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
    except Exception as e:
-        print_verbose(f"LiteLLM: Excepton when cost calculating {str(e)}")
-        return 0.0  # this should not block a users execution path
+        raise e


####### HELPER FUNCTIONS ################
@@ -4081,11 +4103,11 @@ def get_llm_provider(
            print()  # noqa
        error_str = f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model={model}\n Pass model as E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers"
        # maps to openai.NotFoundError, this is raised when openai does not recognize the llm
-        raise litellm.exceptions.NotFoundError(  # type: ignore
+        raise litellm.exceptions.BadRequestError(  # type: ignore
            message=error_str,
            model=model,
            response=httpx.Response(
-                status_code=404,
+                status_code=400,
                content=error_str,
                request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"),  # type: ignore
            ),
@@ -4915,6 +4937,9 @@ async def convert_to_streaming_response_async(response_object: Optional[dict] =
        if "id" in response_object:
            model_response_object.id = response_object["id"]

+        if "created" in response_object:
+            model_response_object.created = response_object["created"]
+
        if "system_fingerprint" in response_object:
            model_response_object.system_fingerprint = response_object["system_fingerprint"]
@@ -4959,6 +4984,9 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
        if "id" in response_object:
            model_response_object.id = response_object["id"]

+        if "created" in response_object:
+            model_response_object.created = response_object["created"]
+
        if "system_fingerprint" in response_object:
            model_response_object.system_fingerprint = response_object["system_fingerprint"]
@@ -5014,6 +5042,9 @@ def convert_to_model_response_object(
            model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0)  # type: ignore
            model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)  # type: ignore

+        if "created" in response_object:
+            model_response_object.created = response_object["created"]
+
        if "id" in response_object:
            model_response_object.id = response_object["id"]
@@ -6621,7 +6652,7 @@ def _is_base64(s):
def get_secret(
    secret_name: str,
-    default_value: Optional[str] = None,
+    default_value: Optional[Union[str, bool]] = None,
):
    key_management_system = litellm._key_management_system
    if secret_name.startswith("os.environ/"):
@@ -6672,9 +6703,24 @@ def get_secret(
                secret = client.get_secret(secret_name).secret_value
            except Exception as e:  # check if it's in os.environ
                secret = os.getenv(secret_name)
-            return secret
+            try:
+                secret_value_as_bool = ast.literal_eval(secret)
+                if isinstance(secret_value_as_bool, bool):
+                    return secret_value_as_bool
+                else:
+                    return secret
+            except:
+                return secret
        else:
-            return os.environ.get(secret_name)
+            secret = os.environ.get(secret_name)
+            try:
+                secret_value_as_bool = ast.literal_eval(secret)
+                if isinstance(secret_value_as_bool, bool):
+                    return secret_value_as_bool
+                else:
+                    return secret
+            except:
+                return secret
    except Exception as e:
        if default_value is not None:
            return default_value
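For intuition on the change above: environment variables are always strings, and ast.literal_eval lets "True"/"False" come back as real booleans while anything that is not a Python literal falls through unchanged. A standalone illustration (the env var names are made up):

import ast, os

os.environ["EXAMPLE_FLAG"] = "True"   # env values are always strings
os.environ["EXAMPLE_KEY"] = "sk-1234"

def read_env(name):
    raw = os.environ.get(name)
    try:
        parsed = ast.literal_eval(raw)
        return parsed if isinstance(parsed, bool) else raw
    except Exception:
        return raw  # not a Python literal, keep the raw string

print(read_env("EXAMPLE_FLAG"))  # True (bool)
print(read_env("EXAMPLE_KEY"))   # 'sk-1234' (str)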
@@ -111,6 +111,13 @@
        "litellm_provider": "openai",
        "mode": "embedding"
    },
+    "text-embedding-ada-002-v2": {
+        "max_tokens": 8191,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.000000,
+        "litellm_provider": "openai",
+        "mode": "embedding"
+    },
    "256-x-256/dall-e-2": {
        "mode": "image_generation",
        "input_cost_per_pixel": 0.00000024414,
@@ -242,6 +249,13 @@
        "litellm_provider": "azure",
        "mode": "chat"
    },
+    "azure/ada": {
+        "max_tokens": 8191,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.000000,
+        "litellm_provider": "azure",
+        "mode": "embedding"
+    },
    "azure/text-embedding-ada-002": {
        "max_tokens": 8191,
        "input_cost_per_token": 0.0000001,
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "1.16.13"
+version = "1.16.14"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"
@@ -59,7 +59,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.commitizen]
-version = "1.16.13"
+version = "1.16.14"
version_files = [
    "pyproject.toml:^version"
]
retry_push.sh (new file, 28 lines)
@@ -0,0 +1,28 @@
#!/bin/bash
retry_count=0
max_retries=3
exit_code=1
until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
do
retry_count=$((retry_count+1))
echo "Attempt $retry_count..."
# Run the Prisma db push command
prisma db push --accept-data-loss
exit_code=$?
if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
echo "Retrying in 10 seconds..."
sleep 10
fi
done
if [ $exit_code -ne 0 ]; then
echo "Unable to push database changes after $max_retries retries."
exit 1
fi
echo "Database push successful!"
schema.prisma (new file, 33 lines)
@@ -0,0 +1,33 @@
datasource client {
provider = "postgresql"
url = env("DATABASE_URL")
}
generator client {
provider = "prisma-client-py"
}
model LiteLLM_UserTable {
user_id String @unique
max_budget Float?
spend Float @default(0.0)
user_email String?
}
// required for token gen
model LiteLLM_VerificationToken {
token String @unique
spend Float @default(0.0)
expires DateTime?
models String[]
aliases Json @default("{}")
config Json @default("{}")
user_id String?
max_parallel_requests Int?
metadata Json @default("{}")
}
model LiteLLM_Config {
param_name String @id
param_value Json?
}
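To close the loop on the new LiteLLM_Config model, a hedged sketch of writing and reading it with the generated prisma-client-py client, mirroring the upsert/find_first calls in the PrismaClient changes earlier in this diff; DATABASE_URL is assumed to point at a reachable Postgres instance and the param values are placeholders:

import asyncio, json
from prisma import Client  # type: ignore

async def main():
    db = Client()
    await db.connect()
    payload = json.dumps({"set_verbose": True})
    # upsert a config param, as PrismaClient.insert_data(table_name="config") does
    await db.litellm_config.upsert(
        where={"param_name": "litellm_settings"},
        data={
            "create": {"param_name": "litellm_settings", "param_value": payload},
            "update": {"param_value": payload},
        },
    )
    row = await db.litellm_config.find_first(where={"param_name": "litellm_settings"})
    print(row)
    await db.disconnect()

asyncio.run(main())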