forked from phoenix/litellm-mirror
Merge branch 'BerriAI:main' into feature_allow_claude_prefill
This commit is contained in:
commit
53e5e1df07
42 changed files with 1395 additions and 490 deletions
|
@ -134,11 +134,15 @@ jobs:
|
||||||
- run:
|
- run:
|
||||||
name: Trigger Github Action for new Docker Container
|
name: Trigger Github Action for new Docker Container
|
||||||
command: |
|
command: |
|
||||||
|
echo "Install TOML package."
|
||||||
|
python3 -m pip install toml
|
||||||
|
VERSION=$(python3 -c "import toml; print(toml.load('pyproject.toml')['tool']['poetry']['version'])")
|
||||||
|
echo "LiteLLM Version ${VERSION}"
|
||||||
curl -X POST \
|
curl -X POST \
|
||||||
-H "Accept: application/vnd.github.v3+json" \
|
-H "Accept: application/vnd.github.v3+json" \
|
||||||
-H "Authorization: Bearer $GITHUB_TOKEN" \
|
-H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||||
"https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches" \
|
"https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches" \
|
||||||
-d '{"ref":"main"}'
|
-d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"${VERSION}\"}}"
|
||||||
|
|
||||||
workflows:
|
workflows:
|
||||||
version: 2
|
version: 2
|
||||||
|
|
43
.github/workflows/ghcr_deploy.yml
vendored
43
.github/workflows/ghcr_deploy.yml
vendored
|
@ -1,12 +1,10 @@
|
||||||
#
|
# this workflow is triggered by an API call when there is a new PyPI release of LiteLLM
|
||||||
name: Build, Publish LiteLLM Docker Image
|
name: Build, Publish LiteLLM Docker Image. New Release
|
||||||
on:
|
on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
tag:
|
tag:
|
||||||
description: "The tag version you want to build"
|
description: "The tag version you want to build"
|
||||||
release:
|
|
||||||
types: [published]
|
|
||||||
|
|
||||||
# Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds.
|
# Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds.
|
||||||
env:
|
env:
|
||||||
|
@ -46,7 +44,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
push: true
|
push: true
|
||||||
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || github.event.release.tag_name || 'latest' }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
|
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
build-and-push-image-alpine:
|
build-and-push-image-alpine:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
@ -76,5 +74,38 @@ jobs:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: Dockerfile.alpine
|
dockerfile: Dockerfile.alpine
|
||||||
push: true
|
push: true
|
||||||
tags: ${{ steps.meta-alpine.outputs.tags }}-${{ github.event.inputs.tag || github.event.release.tag_name || 'latest' }}
|
tags: ${{ steps.meta-alpine.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}
|
||||||
labels: ${{ steps.meta-alpine.outputs.labels }}
|
labels: ${{ steps.meta-alpine.outputs.labels }}
|
||||||
|
release:
|
||||||
|
name: "New LiteLLM Release"
|
||||||
|
|
||||||
|
runs-on: "ubuntu-latest"
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Display version
|
||||||
|
run: echo "Current version is ${{ github.event.inputs.tag }}"
|
||||||
|
- name: "Set Release Tag"
|
||||||
|
run: echo "RELEASE_TAG=${{ github.event.inputs.tag }}" >> $GITHUB_ENV
|
||||||
|
- name: Display release tag
|
||||||
|
run: echo "RELEASE_TAG is $RELEASE_TAG"
|
||||||
|
- name: "Create release"
|
||||||
|
uses: "actions/github-script@v6"
|
||||||
|
with:
|
||||||
|
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||||
|
script: |
|
||||||
|
try {
|
||||||
|
const response = await github.rest.repos.createRelease({
|
||||||
|
draft: false,
|
||||||
|
generate_release_notes: true,
|
||||||
|
name: process.env.RELEASE_TAG,
|
||||||
|
owner: context.repo.owner,
|
||||||
|
prerelease: false,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
tag_name: process.env.RELEASE_TAG,
|
||||||
|
});
|
||||||
|
|
||||||
|
core.exportVariable('RELEASE_ID', response.data.id);
|
||||||
|
core.exportVariable('RELEASE_UPLOAD_URL', response.data.upload_url);
|
||||||
|
} catch (error) {
|
||||||
|
core.setFailed(error.message);
|
||||||
|
}
|
||||||
|
|
31
.github/workflows/read_pyproject_version.yml
vendored
Normal file
31
.github/workflows/read_pyproject_version.yml
vendored
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
name: Read Version from pyproject.toml
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main # Change this to the default branch of your repository
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
read-version:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: 3.8 # Adjust the Python version as needed
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install toml
|
||||||
|
|
||||||
|
- name: Read version from pyproject.toml
|
||||||
|
id: read-version
|
||||||
|
run: |
|
||||||
|
version=$(python -c 'import toml; print(toml.load("pyproject.toml")["tool"]["commitizen"]["version"])')
|
||||||
|
printf "LITELLM_VERSION=%s" "$version" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Display version
|
||||||
|
run: echo "Current version is $LITELLM_VERSION"
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -31,3 +31,4 @@ proxy_server_config_@.yaml
|
||||||
.gitignore
|
.gitignore
|
||||||
proxy_server_config_2.yaml
|
proxy_server_config_2.yaml
|
||||||
litellm/proxy/secret_managers/credentials.json
|
litellm/proxy/secret_managers/credentials.json
|
||||||
|
hosted_config.yaml
|
||||||
|
|
13
Dockerfile
13
Dockerfile
|
@ -3,7 +3,6 @@ ARG LITELLM_BUILD_IMAGE=python:3.9
|
||||||
|
|
||||||
# Runtime image
|
# Runtime image
|
||||||
ARG LITELLM_RUNTIME_IMAGE=python:3.9-slim
|
ARG LITELLM_RUNTIME_IMAGE=python:3.9-slim
|
||||||
|
|
||||||
# Builder stage
|
# Builder stage
|
||||||
FROM $LITELLM_BUILD_IMAGE as builder
|
FROM $LITELLM_BUILD_IMAGE as builder
|
||||||
|
|
||||||
|
@ -35,8 +34,12 @@ RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt
|
||||||
|
|
||||||
# Runtime stage
|
# Runtime stage
|
||||||
FROM $LITELLM_RUNTIME_IMAGE as runtime
|
FROM $LITELLM_RUNTIME_IMAGE as runtime
|
||||||
|
ARG with_database
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
# Copy the current directory contents into the container at /app
|
||||||
|
COPY . .
|
||||||
|
RUN ls -la /app
|
||||||
|
|
||||||
# Copy the built wheel from the builder stage to the runtime stage; assumes only one wheel file is present
|
# Copy the built wheel from the builder stage to the runtime stage; assumes only one wheel file is present
|
||||||
COPY --from=builder /app/dist/*.whl .
|
COPY --from=builder /app/dist/*.whl .
|
||||||
|
@ -45,6 +48,14 @@ COPY --from=builder /wheels/ /wheels/
|
||||||
# Install the built wheel using pip; again using a wildcard if it's the only file
|
# Install the built wheel using pip; again using a wildcard if it's the only file
|
||||||
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels
|
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels
|
||||||
|
|
||||||
|
# Check if the with_database argument is set to 'true'
|
||||||
|
RUN echo "Value of with_database is: ${with_database}"
|
||||||
|
# If true, execute the following instructions
|
||||||
|
RUN if [ "$with_database" = "true" ]; then \
|
||||||
|
prisma generate; \
|
||||||
|
chmod +x /app/retry_push.sh; \
|
||||||
|
/app/retry_push.sh; \
|
||||||
|
fi
|
||||||
|
|
||||||
EXPOSE 4000/tcp
|
EXPOSE 4000/tcp
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,10 @@
|
||||||
LITELLM_MASTER_KEY="sk-1234"
|
LITELLM_MASTER_KEY="sk-1234"
|
||||||
|
|
||||||
############
|
############
|
||||||
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
|
# Database - You can change these to any PostgreSQL database.
|
||||||
############
|
############
|
||||||
|
|
||||||
# LITELLM_DATABASE_URL="your-postgres-db-url"
|
DATABASE_URL="your-postgres-db-url"
|
||||||
|
|
||||||
|
|
||||||
############
|
############
|
||||||
|
@ -19,4 +19,4 @@ LITELLM_MASTER_KEY="sk-1234"
|
||||||
# SMTP_HOST = "fake-mail-host"
|
# SMTP_HOST = "fake-mail-host"
|
||||||
# SMTP_USERNAME = "fake-mail-user"
|
# SMTP_USERNAME = "fake-mail-user"
|
||||||
# SMTP_PASSWORD="fake-mail-password"
|
# SMTP_PASSWORD="fake-mail-password"
|
||||||
# SMTP_SENDER_EMAIL="fake-sender-email"
|
# SMTP_SENDER_EMAIL="fake-sender-email"
|
||||||
|
|
|
@ -396,7 +396,48 @@ response = completion(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## OpenAI Proxy
|
||||||
|
|
||||||
|
Track spend across multiple projects/people
|
||||||
|
|
||||||
|
The proxy provides:
|
||||||
|
1. [Hooks for auth](https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth)
|
||||||
|
2. [Hooks for logging](https://docs.litellm.ai/docs/proxy/logging#step-1---create-your-custom-litellm-callback-class)
|
||||||
|
3. [Cost tracking](https://docs.litellm.ai/docs/proxy/virtual_keys#tracking-spend)
|
||||||
|
4. [Rate Limiting](https://docs.litellm.ai/docs/proxy/users#set-rate-limits)
|
||||||
|
|
||||||
|
### 📖 Proxy Endpoints - [Swagger Docs](https://litellm-api.up.railway.app/)
|
||||||
|
|
||||||
|
### Quick Start Proxy - CLI
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip install litellm[proxy]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 1: Start litellm proxy
|
||||||
|
```shell
|
||||||
|
$ litellm --model huggingface/bigcode/starcoder
|
||||||
|
|
||||||
|
#INFO: Proxy running on http://0.0.0.0:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 2: Make ChatCompletions Request to Proxy
|
||||||
|
```python
|
||||||
|
import openai # openai v1.0.0+
|
||||||
|
client = openai.OpenAI(api_key="anything",base_url="http://0.0.0.0:8000") # set proxy to base_url
|
||||||
|
# request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## More details
|
## More details
|
||||||
* [exception mapping](./exception_mapping.md)
|
* [exception mapping](./exception_mapping.md)
|
||||||
* [retries + model fallbacks for completion()](./completion/reliable_completions.md)
|
* [retries + model fallbacks for completion()](./completion/reliable_completions.md)
|
||||||
* [tutorial for model fallbacks with completion()](./tutorials/fallbacks.md)
|
* [tutorial for model fallbacks with completion()](./tutorials/fallbacks.md)
|
||||||
|
|
|
@ -161,7 +161,7 @@ litellm_settings:
|
||||||
The proxy support 3 cache-controls:
|
The proxy support 3 cache-controls:
|
||||||
|
|
||||||
- `ttl`: Will cache the response for the user-defined amount of time (in seconds).
|
- `ttl`: Will cache the response for the user-defined amount of time (in seconds).
|
||||||
- `s-max-age`: Will only accept cached responses that are within user-defined range (in seconds).
|
- `s-maxage`: Will only accept cached responses that are within user-defined range (in seconds).
|
||||||
- `no-cache`: Will not return a cached response, but instead call the actual endpoint.
|
- `no-cache`: Will not return a cached response, but instead call the actual endpoint.
|
||||||
|
|
||||||
[Let us know if you need more](https://github.com/BerriAI/litellm/issues/1218)
|
[Let us know if you need more](https://github.com/BerriAI/litellm/issues/1218)
|
||||||
|
@ -237,7 +237,7 @@ chat_completion = client.chat.completions.create(
|
||||||
],
|
],
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
cache={
|
cache={
|
||||||
"s-max-age": 600 # only get responses cached within last 10 minutes
|
"s-maxage": 600 # only get responses cached within last 10 minutes
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
43
docs/my-website/docs/proxy/rules.md
Normal file
43
docs/my-website/docs/proxy/rules.md
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
# Post-Call Rules
|
||||||
|
|
||||||
|
Use this to fail a request based on the output of an llm api call.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Step 1: Create a file (e.g. post_call_rules.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
def my_custom_rule(input): # receives the model response
|
||||||
|
if len(input) < 5: # trigger fallback if the model response is too short
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2. Point it to your proxy
|
||||||
|
|
||||||
|
```python
|
||||||
|
litellm_settings:
|
||||||
|
post_call_rules: post_call_rules.my_custom_rule
|
||||||
|
num_retries: 3
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 3. Start + test your proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ litellm /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl --location 'http://0.0.0.0:8000/v1/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--data '{
|
||||||
|
"model": "deepseek-coder",
|
||||||
|
"messages": [{"role":"user","content":"What llm are you?"}],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 10,
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
---
|
||||||
|
|
||||||
|
This will now check if a response is > len 5, and if it fails, it'll retry a call 3 times before failing.
|
|
@ -112,6 +112,7 @@ const sidebars = {
|
||||||
"proxy/reliability",
|
"proxy/reliability",
|
||||||
"proxy/health",
|
"proxy/health",
|
||||||
"proxy/call_hooks",
|
"proxy/call_hooks",
|
||||||
|
"proxy/rules",
|
||||||
"proxy/caching",
|
"proxy/caching",
|
||||||
"proxy/alerting",
|
"proxy/alerting",
|
||||||
"proxy/logging",
|
"proxy/logging",
|
||||||
|
|
|
@ -375,6 +375,45 @@ response = completion(
|
||||||
|
|
||||||
Need a dedicated key? Email us @ krrish@berri.ai
|
Need a dedicated key? Email us @ krrish@berri.ai
|
||||||
|
|
||||||
|
## OpenAI Proxy
|
||||||
|
|
||||||
|
Track spend across multiple projects/people
|
||||||
|
|
||||||
|
The proxy provides:
|
||||||
|
1. [Hooks for auth](https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth)
|
||||||
|
2. [Hooks for logging](https://docs.litellm.ai/docs/proxy/logging#step-1---create-your-custom-litellm-callback-class)
|
||||||
|
3. [Cost tracking](https://docs.litellm.ai/docs/proxy/virtual_keys#tracking-spend)
|
||||||
|
4. [Rate Limiting](https://docs.litellm.ai/docs/proxy/users#set-rate-limits)
|
||||||
|
|
||||||
|
### 📖 Proxy Endpoints - [Swagger Docs](https://litellm-api.up.railway.app/)
|
||||||
|
|
||||||
|
### Quick Start Proxy - CLI
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip install litellm[proxy]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 1: Start litellm proxy
|
||||||
|
```shell
|
||||||
|
$ litellm --model huggingface/bigcode/starcoder
|
||||||
|
|
||||||
|
#INFO: Proxy running on http://0.0.0.0:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 2: Make ChatCompletions Request to Proxy
|
||||||
|
```python
|
||||||
|
import openai # openai v1.0.0+
|
||||||
|
client = openai.OpenAI(api_key="anything",base_url="http://0.0.0.0:8000") # set proxy to base_url
|
||||||
|
# request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
## More details
|
## More details
|
||||||
* [exception mapping](./exception_mapping.md)
|
* [exception mapping](./exception_mapping.md)
|
||||||
|
|
|
@ -338,7 +338,8 @@ baseten_models: List = [
|
||||||
] # FALCON 7B # WizardLM # Mosaic ML
|
] # FALCON 7B # WizardLM # Mosaic ML
|
||||||
|
|
||||||
|
|
||||||
# used for token counting
|
# used for Cost Tracking & Token counting
|
||||||
|
# https://azure.microsoft.com/en-in/pricing/details/cognitive-services/openai-service/
|
||||||
# Azure returns gpt-35-turbo in their responses, we need to map this to azure/gpt-3.5-turbo for token counting
|
# Azure returns gpt-35-turbo in their responses, we need to map this to azure/gpt-3.5-turbo for token counting
|
||||||
azure_llms = {
|
azure_llms = {
|
||||||
"gpt-35-turbo": "azure/gpt-35-turbo",
|
"gpt-35-turbo": "azure/gpt-35-turbo",
|
||||||
|
@ -346,6 +347,10 @@ azure_llms = {
|
||||||
"gpt-35-turbo-instruct": "azure/gpt-35-turbo-instruct",
|
"gpt-35-turbo-instruct": "azure/gpt-35-turbo-instruct",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
azure_embedding_models = {
|
||||||
|
"ada": "azure/ada",
|
||||||
|
}
|
||||||
|
|
||||||
petals_models = [
|
petals_models = [
|
||||||
"petals-team/StableBeluga2",
|
"petals-team/StableBeluga2",
|
||||||
]
|
]
|
||||||
|
|
|
@ -11,6 +11,7 @@ import litellm
|
||||||
import time, logging
|
import time, logging
|
||||||
import json, traceback, ast, hashlib
|
import json, traceback, ast, hashlib
|
||||||
from typing import Optional, Literal, List, Union, Any
|
from typing import Optional, Literal, List, Union, Any
|
||||||
|
from openai._models import BaseModel as OpenAIObject
|
||||||
|
|
||||||
|
|
||||||
def print_verbose(print_statement):
|
def print_verbose(print_statement):
|
||||||
|
@ -472,7 +473,10 @@ class Cache:
|
||||||
else:
|
else:
|
||||||
cache_key = self.get_cache_key(*args, **kwargs)
|
cache_key = self.get_cache_key(*args, **kwargs)
|
||||||
if cache_key is not None:
|
if cache_key is not None:
|
||||||
max_age = kwargs.get("cache", {}).get("s-max-age", float("inf"))
|
cache_control_args = kwargs.get("cache", {})
|
||||||
|
max_age = cache_control_args.get(
|
||||||
|
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
|
||||||
|
)
|
||||||
cached_result = self.cache.get_cache(cache_key)
|
cached_result = self.cache.get_cache(cache_key)
|
||||||
# Check if a timestamp was stored with the cached response
|
# Check if a timestamp was stored with the cached response
|
||||||
if (
|
if (
|
||||||
|
@ -529,7 +533,7 @@ class Cache:
|
||||||
else:
|
else:
|
||||||
cache_key = self.get_cache_key(*args, **kwargs)
|
cache_key = self.get_cache_key(*args, **kwargs)
|
||||||
if cache_key is not None:
|
if cache_key is not None:
|
||||||
if isinstance(result, litellm.ModelResponse):
|
if isinstance(result, OpenAIObject):
|
||||||
result = result.model_dump_json()
|
result = result.model_dump_json()
|
||||||
|
|
||||||
## Get Cache-Controls ##
|
## Get Cache-Controls ##
|
||||||
|
|
|
@ -724,16 +724,32 @@ class AzureChatCompletion(BaseLLM):
|
||||||
client_session = litellm.aclient_session or httpx.AsyncClient(
|
client_session = litellm.aclient_session or httpx.AsyncClient(
|
||||||
transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls
|
transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls
|
||||||
)
|
)
|
||||||
client = AsyncAzureOpenAI(
|
if "gateway.ai.cloudflare.com" in api_base:
|
||||||
api_version=api_version,
|
## build base url - assume api base includes resource name
|
||||||
azure_endpoint=api_base,
|
if not api_base.endswith("/"):
|
||||||
api_key=api_key,
|
api_base += "/"
|
||||||
timeout=timeout,
|
api_base += f"{model}"
|
||||||
http_client=client_session,
|
client = AsyncAzureOpenAI(
|
||||||
)
|
base_url=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
http_client=client_session,
|
||||||
|
)
|
||||||
|
model = None
|
||||||
|
# cloudflare ai gateway, needs model=None
|
||||||
|
else:
|
||||||
|
client = AsyncAzureOpenAI(
|
||||||
|
api_version=api_version,
|
||||||
|
azure_endpoint=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
http_client=client_session,
|
||||||
|
)
|
||||||
|
|
||||||
if model is None and mode != "image_generation":
|
# only run this check if it's not cloudflare ai gateway
|
||||||
raise Exception("model is not set")
|
if model is None and mode != "image_generation":
|
||||||
|
raise Exception("model is not set")
|
||||||
|
|
||||||
completion = None
|
completion = None
|
||||||
|
|
||||||
|
|
0
litellm/llms/custom_httpx/bedrock_async.py
Normal file
0
litellm/llms/custom_httpx/bedrock_async.py
Normal file
|
@ -14,12 +14,18 @@ model_list:
|
||||||
- model_name: BEDROCK_GROUP
|
- model_name: BEDROCK_GROUP
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: bedrock/cohere.command-text-v14
|
model: bedrock/cohere.command-text-v14
|
||||||
- model_name: Azure OpenAI GPT-4 Canada-East (External)
|
- model_name: openai-gpt-3.5
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: gpt-3.5-turbo
|
model: gpt-3.5-turbo
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
model_info:
|
model_info:
|
||||||
mode: chat
|
mode: chat
|
||||||
|
- model_name: azure-cloudflare
|
||||||
|
litellm_params:
|
||||||
|
model: azure/chatgpt-v-2
|
||||||
|
api_base: https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
api_version: "2023-07-01-preview"
|
||||||
- model_name: azure-embedding-model
|
- model_name: azure-embedding-model
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: azure/azure-embedding-model
|
model: azure/azure-embedding-model
|
||||||
|
|
|
@ -307,9 +307,8 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def prisma_setup(database_url: Optional[str]):
|
async def prisma_setup(database_url: Optional[str]):
|
||||||
global prisma_client, proxy_logging_obj, user_api_key_cache
|
global prisma_client, proxy_logging_obj, user_api_key_cache
|
||||||
|
|
||||||
if (
|
if (
|
||||||
database_url is not None and prisma_client is None
|
database_url is not None and prisma_client is None
|
||||||
): # don't re-initialize prisma client after initial init
|
): # don't re-initialize prisma client after initial init
|
||||||
|
@ -321,6 +320,8 @@ def prisma_setup(database_url: Optional[str]):
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Error when initializing prisma, Ensure you run pip install prisma {str(e)}"
|
f"Error when initializing prisma, Ensure you run pip install prisma {str(e)}"
|
||||||
)
|
)
|
||||||
|
if prisma_client is not None and prisma_client.db.is_connected() == False:
|
||||||
|
await prisma_client.connect()
|
||||||
|
|
||||||
|
|
||||||
def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
||||||
|
@ -502,232 +503,330 @@ async def _run_background_health_check():
|
||||||
await asyncio.sleep(health_check_interval)
|
await asyncio.sleep(health_check_interval)
|
||||||
|
|
||||||
|
|
||||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
class ProxyConfig:
|
||||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
|
"""
|
||||||
config = {}
|
Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
|
||||||
try:
|
"""
|
||||||
if os.path.exists(config_file_path):
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def get_config(self, config_file_path: Optional[str] = None) -> dict:
|
||||||
|
global prisma_client, user_config_file_path
|
||||||
|
|
||||||
|
file_path = config_file_path or user_config_file_path
|
||||||
|
if config_file_path is not None:
|
||||||
user_config_file_path = config_file_path
|
user_config_file_path = config_file_path
|
||||||
with open(config_file_path, "r") as file:
|
# Load existing config
|
||||||
config = yaml.safe_load(file)
|
## Yaml
|
||||||
else:
|
if file_path is not None:
|
||||||
raise Exception(
|
if os.path.exists(f"{file_path}"):
|
||||||
f"Path to config does not exist, Current working directory: {os.getcwd()}, 'os.path.exists({config_file_path})' returned False"
|
with open(f"{file_path}", "r") as config_file:
|
||||||
|
config = yaml.safe_load(config_file)
|
||||||
|
else:
|
||||||
|
raise Exception(f"File not found! - {file_path}")
|
||||||
|
|
||||||
|
## DB
|
||||||
|
if (
|
||||||
|
prisma_client is not None
|
||||||
|
and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
|
||||||
|
):
|
||||||
|
await prisma_setup(database_url=None) # in case it's not been connected yet
|
||||||
|
_tasks = []
|
||||||
|
keys = [
|
||||||
|
"model_list",
|
||||||
|
"general_settings",
|
||||||
|
"router_settings",
|
||||||
|
"litellm_settings",
|
||||||
|
]
|
||||||
|
for k in keys:
|
||||||
|
response = prisma_client.get_generic_data(
|
||||||
|
key="param_name", value=k, table_name="config"
|
||||||
|
)
|
||||||
|
_tasks.append(response)
|
||||||
|
|
||||||
|
responses = await asyncio.gather(*_tasks)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
async def save_config(self, new_config: dict):
|
||||||
|
global prisma_client, llm_router, user_config_file_path, llm_model_list, general_settings
|
||||||
|
# Load existing config
|
||||||
|
backup_config = await self.get_config()
|
||||||
|
|
||||||
|
# Save the updated config
|
||||||
|
## YAML
|
||||||
|
with open(f"{user_config_file_path}", "w") as config_file:
|
||||||
|
yaml.dump(new_config, config_file, default_flow_style=False)
|
||||||
|
|
||||||
|
# update Router - verifies if this is a valid config
|
||||||
|
try:
|
||||||
|
(
|
||||||
|
llm_router,
|
||||||
|
llm_model_list,
|
||||||
|
general_settings,
|
||||||
|
) = await proxy_config.load_config(
|
||||||
|
router=llm_router, config_file_path=user_config_file_path
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Exception while reading Config: {e}")
|
traceback.print_exc()
|
||||||
|
# Revert to old config instead
|
||||||
|
with open(f"{user_config_file_path}", "w") as config_file:
|
||||||
|
yaml.dump(backup_config, config_file, default_flow_style=False)
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid config passed in")
|
||||||
|
|
||||||
## PRINT YAML FOR CONFIRMING IT WORKS
|
## DB - writes valid config to db
|
||||||
printed_yaml = copy.deepcopy(config)
|
"""
|
||||||
printed_yaml.pop("environment_variables", None)
|
- Do not write restricted params like 'api_key' to the database
|
||||||
|
- if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`)
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
prisma_client is not None
|
||||||
|
and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
|
||||||
|
):
|
||||||
|
### KEY REMOVAL ###
|
||||||
|
models = new_config.get("model_list", [])
|
||||||
|
for m in models:
|
||||||
|
if m.get("litellm_params", {}).get("api_key", None) is not None:
|
||||||
|
# pop the key
|
||||||
|
api_key = m["litellm_params"].pop("api_key")
|
||||||
|
# store in local env
|
||||||
|
key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
|
||||||
|
os.environ[key_name] = api_key
|
||||||
|
# save the key name (not the value)
|
||||||
|
m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
|
||||||
|
await prisma_client.insert_data(data=new_config, table_name="config")
|
||||||
|
|
||||||
print_verbose(
|
async def load_config(
|
||||||
f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
|
self, router: Optional[litellm.Router], config_file_path: str
|
||||||
)
|
):
|
||||||
|
"""
|
||||||
|
Load config values into proxy global state
|
||||||
|
"""
|
||||||
|
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
|
||||||
|
|
||||||
## ENVIRONMENT VARIABLES
|
# Load existing config
|
||||||
environment_variables = config.get("environment_variables", None)
|
config = await self.get_config(config_file_path=config_file_path)
|
||||||
if environment_variables:
|
## PRINT YAML FOR CONFIRMING IT WORKS
|
||||||
for key, value in environment_variables.items():
|
printed_yaml = copy.deepcopy(config)
|
||||||
os.environ[key] = value
|
printed_yaml.pop("environment_variables", None)
|
||||||
|
|
||||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
print_verbose(
|
||||||
litellm_settings = config.get("litellm_settings", None)
|
f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
|
||||||
if litellm_settings is None:
|
)
|
||||||
litellm_settings = {}
|
|
||||||
if litellm_settings:
|
|
||||||
# ANSI escape code for blue text
|
|
||||||
blue_color_code = "\033[94m"
|
|
||||||
reset_color_code = "\033[0m"
|
|
||||||
for key, value in litellm_settings.items():
|
|
||||||
if key == "cache":
|
|
||||||
print(f"{blue_color_code}\nSetting Cache on Proxy") # noqa
|
|
||||||
from litellm.caching import Cache
|
|
||||||
|
|
||||||
cache_params = {}
|
## ENVIRONMENT VARIABLES
|
||||||
if "cache_params" in litellm_settings:
|
environment_variables = config.get("environment_variables", None)
|
||||||
cache_params_in_config = litellm_settings["cache_params"]
|
if environment_variables:
|
||||||
# overwrie cache_params with cache_params_in_config
|
for key, value in environment_variables.items():
|
||||||
cache_params.update(cache_params_in_config)
|
os.environ[key] = value
|
||||||
|
|
||||||
cache_type = cache_params.get("type", "redis")
|
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||||
|
litellm_settings = config.get("litellm_settings", None)
|
||||||
|
if litellm_settings is None:
|
||||||
|
litellm_settings = {}
|
||||||
|
if litellm_settings:
|
||||||
|
# ANSI escape code for blue text
|
||||||
|
blue_color_code = "\033[94m"
|
||||||
|
reset_color_code = "\033[0m"
|
||||||
|
for key, value in litellm_settings.items():
|
||||||
|
if key == "cache":
|
||||||
|
print(f"{blue_color_code}\nSetting Cache on Proxy") # noqa
|
||||||
|
from litellm.caching import Cache
|
||||||
|
|
||||||
print_verbose(f"passed cache type={cache_type}")
|
cache_params = {}
|
||||||
|
if "cache_params" in litellm_settings:
|
||||||
|
cache_params_in_config = litellm_settings["cache_params"]
|
||||||
|
# overwrie cache_params with cache_params_in_config
|
||||||
|
cache_params.update(cache_params_in_config)
|
||||||
|
|
||||||
if cache_type == "redis":
|
cache_type = cache_params.get("type", "redis")
|
||||||
cache_host = litellm.get_secret("REDIS_HOST", None)
|
|
||||||
cache_port = litellm.get_secret("REDIS_PORT", None)
|
|
||||||
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
|
|
||||||
|
|
||||||
cache_params = {
|
print_verbose(f"passed cache type={cache_type}")
|
||||||
"type": cache_type,
|
|
||||||
"host": cache_host,
|
if cache_type == "redis":
|
||||||
"port": cache_port,
|
cache_host = litellm.get_secret("REDIS_HOST", None)
|
||||||
"password": cache_password,
|
cache_port = litellm.get_secret("REDIS_PORT", None)
|
||||||
}
|
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
|
||||||
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
|
|
||||||
|
cache_params.update(
|
||||||
|
{
|
||||||
|
"type": cache_type,
|
||||||
|
"host": cache_host,
|
||||||
|
"port": cache_port,
|
||||||
|
"password": cache_password,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
|
||||||
|
print( # noqa
|
||||||
|
f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
|
||||||
|
) # noqa
|
||||||
|
print( # noqa
|
||||||
|
f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
|
||||||
|
) # noqa
|
||||||
|
print( # noqa
|
||||||
|
f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
|
||||||
|
) # noqa
|
||||||
|
print( # noqa
|
||||||
|
f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
|
||||||
|
)
|
||||||
|
print() # noqa
|
||||||
|
|
||||||
|
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
|
||||||
|
litellm.cache = Cache(**cache_params)
|
||||||
print( # noqa
|
print( # noqa
|
||||||
f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
|
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
|
||||||
) # noqa
|
|
||||||
print( # noqa
|
|
||||||
f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
|
|
||||||
) # noqa
|
|
||||||
print( # noqa
|
|
||||||
f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
|
|
||||||
) # noqa
|
|
||||||
print( # noqa
|
|
||||||
f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
|
|
||||||
)
|
)
|
||||||
print() # noqa
|
elif key == "callbacks":
|
||||||
|
litellm.callbacks = [
|
||||||
|
get_instance_fn(value=value, config_file_path=config_file_path)
|
||||||
|
]
|
||||||
|
print_verbose(
|
||||||
|
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
|
||||||
|
)
|
||||||
|
elif key == "post_call_rules":
|
||||||
|
litellm.post_call_rules = [
|
||||||
|
get_instance_fn(value=value, config_file_path=config_file_path)
|
||||||
|
]
|
||||||
|
print_verbose(f"litellm.post_call_rules: {litellm.post_call_rules}")
|
||||||
|
elif key == "success_callback":
|
||||||
|
litellm.success_callback = []
|
||||||
|
|
||||||
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
|
# intialize success callbacks
|
||||||
litellm.cache = Cache(**cache_params)
|
for callback in value:
|
||||||
print( # noqa
|
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
||||||
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
|
if "." in callback:
|
||||||
)
|
litellm.success_callback.append(
|
||||||
elif key == "callbacks":
|
get_instance_fn(value=callback)
|
||||||
litellm.callbacks = [
|
)
|
||||||
get_instance_fn(value=value, config_file_path=config_file_path)
|
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
||||||
]
|
else:
|
||||||
print_verbose(
|
litellm.success_callback.append(callback)
|
||||||
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
|
print_verbose(
|
||||||
)
|
f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
|
||||||
elif key == "post_call_rules":
|
)
|
||||||
litellm.post_call_rules = [
|
elif key == "failure_callback":
|
||||||
get_instance_fn(value=value, config_file_path=config_file_path)
|
litellm.failure_callback = []
|
||||||
]
|
|
||||||
print_verbose(f"litellm.post_call_rules: {litellm.post_call_rules}")
|
|
||||||
elif key == "success_callback":
|
|
||||||
litellm.success_callback = []
|
|
||||||
|
|
||||||
# intialize success callbacks
|
# intialize success callbacks
|
||||||
for callback in value:
|
for callback in value:
|
||||||
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
||||||
if "." in callback:
|
if "." in callback:
|
||||||
litellm.success_callback.append(get_instance_fn(value=callback))
|
litellm.failure_callback.append(
|
||||||
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
get_instance_fn(value=callback)
|
||||||
else:
|
)
|
||||||
litellm.success_callback.append(callback)
|
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
||||||
print_verbose(
|
else:
|
||||||
f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
|
litellm.failure_callback.append(callback)
|
||||||
)
|
print_verbose(
|
||||||
elif key == "failure_callback":
|
f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}"
|
||||||
litellm.failure_callback = []
|
)
|
||||||
|
elif key == "cache_params":
|
||||||
|
# this is set in the cache branch
|
||||||
|
# see usage here: https://docs.litellm.ai/docs/proxy/caching
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
setattr(litellm, key, value)
|
||||||
|
|
||||||
# intialize success callbacks
|
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
|
||||||
for callback in value:
|
general_settings = config.get("general_settings", {})
|
||||||
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
if general_settings is None:
|
||||||
if "." in callback:
|
general_settings = {}
|
||||||
litellm.failure_callback.append(get_instance_fn(value=callback))
|
if general_settings:
|
||||||
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
### LOAD SECRET MANAGER ###
|
||||||
else:
|
key_management_system = general_settings.get("key_management_system", None)
|
||||||
litellm.failure_callback.append(callback)
|
if key_management_system is not None:
|
||||||
print_verbose(
|
if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
|
||||||
f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}"
|
### LOAD FROM AZURE KEY VAULT ###
|
||||||
)
|
load_from_azure_key_vault(use_azure_key_vault=True)
|
||||||
elif key == "cache_params":
|
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
|
||||||
# this is set in the cache branch
|
### LOAD FROM GOOGLE KMS ###
|
||||||
# see usage here: https://docs.litellm.ai/docs/proxy/caching
|
load_google_kms(use_google_kms=True)
|
||||||
pass
|
else:
|
||||||
else:
|
raise ValueError("Invalid Key Management System selected")
|
||||||
setattr(litellm, key, value)
|
### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
|
||||||
|
use_google_kms = general_settings.get("use_google_kms", False)
|
||||||
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
|
load_google_kms(use_google_kms=use_google_kms)
|
||||||
general_settings = config.get("general_settings", {})
|
### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
|
||||||
if general_settings is None:
|
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
||||||
general_settings = {}
|
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
||||||
if general_settings:
|
### ALERTING ###
|
||||||
### LOAD SECRET MANAGER ###
|
proxy_logging_obj.update_values(
|
||||||
key_management_system = general_settings.get("key_management_system", None)
|
alerting=general_settings.get("alerting", None),
|
||||||
if key_management_system is not None:
|
alerting_threshold=general_settings.get("alerting_threshold", 600),
|
||||||
if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
|
|
||||||
### LOAD FROM AZURE KEY VAULT ###
|
|
||||||
load_from_azure_key_vault(use_azure_key_vault=True)
|
|
||||||
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
|
|
||||||
### LOAD FROM GOOGLE KMS ###
|
|
||||||
load_google_kms(use_google_kms=True)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid Key Management System selected")
|
|
||||||
### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
|
|
||||||
use_google_kms = general_settings.get("use_google_kms", False)
|
|
||||||
load_google_kms(use_google_kms=use_google_kms)
|
|
||||||
### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
|
|
||||||
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
|
||||||
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
|
||||||
### ALERTING ###
|
|
||||||
proxy_logging_obj.update_values(
|
|
||||||
alerting=general_settings.get("alerting", None),
|
|
||||||
alerting_threshold=general_settings.get("alerting_threshold", 600),
|
|
||||||
)
|
|
||||||
### CONNECT TO DATABASE ###
|
|
||||||
database_url = general_settings.get("database_url", None)
|
|
||||||
if database_url and database_url.startswith("os.environ/"):
|
|
||||||
print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
|
|
||||||
database_url = litellm.get_secret(database_url)
|
|
||||||
print_verbose(f"RETRIEVED DB URL: {database_url}")
|
|
||||||
prisma_setup(database_url=database_url)
|
|
||||||
## COST TRACKING ##
|
|
||||||
cost_tracking()
|
|
||||||
### MASTER KEY ###
|
|
||||||
master_key = general_settings.get(
|
|
||||||
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
|
||||||
)
|
|
||||||
if master_key and master_key.startswith("os.environ/"):
|
|
||||||
master_key = litellm.get_secret(master_key)
|
|
||||||
### CUSTOM API KEY AUTH ###
|
|
||||||
custom_auth = general_settings.get("custom_auth", None)
|
|
||||||
if custom_auth:
|
|
||||||
user_custom_auth = get_instance_fn(
|
|
||||||
value=custom_auth, config_file_path=config_file_path
|
|
||||||
)
|
)
|
||||||
### BACKGROUND HEALTH CHECKS ###
|
### CONNECT TO DATABASE ###
|
||||||
# Enable background health checks
|
database_url = general_settings.get("database_url", None)
|
||||||
use_background_health_checks = general_settings.get(
|
if database_url and database_url.startswith("os.environ/"):
|
||||||
"background_health_checks", False
|
print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
|
||||||
)
|
database_url = litellm.get_secret(database_url)
|
||||||
health_check_interval = general_settings.get("health_check_interval", 300)
|
print_verbose(f"RETRIEVED DB URL: {database_url}")
|
||||||
|
await prisma_setup(database_url=database_url)
|
||||||
|
## COST TRACKING ##
|
||||||
|
cost_tracking()
|
||||||
|
### MASTER KEY ###
|
||||||
|
master_key = general_settings.get(
|
||||||
|
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||||
|
)
|
||||||
|
if master_key and master_key.startswith("os.environ/"):
|
||||||
|
master_key = litellm.get_secret(master_key)
|
||||||
|
### CUSTOM API KEY AUTH ###
|
||||||
|
custom_auth = general_settings.get("custom_auth", None)
|
||||||
|
if custom_auth:
|
||||||
|
user_custom_auth = get_instance_fn(
|
||||||
|
value=custom_auth, config_file_path=config_file_path
|
||||||
|
)
|
||||||
|
### BACKGROUND HEALTH CHECKS ###
|
||||||
|
# Enable background health checks
|
||||||
|
use_background_health_checks = general_settings.get(
|
||||||
|
"background_health_checks", False
|
||||||
|
)
|
||||||
|
health_check_interval = general_settings.get("health_check_interval", 300)
|
||||||
|
|
||||||
router_params: dict = {
|
router_params: dict = {
|
||||||
"num_retries": 3,
|
"num_retries": 3,
|
||||||
"cache_responses": litellm.cache
|
"cache_responses": litellm.cache
|
||||||
!= None, # cache if user passed in cache values
|
!= None, # cache if user passed in cache values
|
||||||
}
|
|
||||||
## MODEL LIST
|
|
||||||
model_list = config.get("model_list", None)
|
|
||||||
if model_list:
|
|
||||||
router_params["model_list"] = model_list
|
|
||||||
print( # noqa
|
|
||||||
f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
|
|
||||||
) # noqa
|
|
||||||
for model in model_list:
|
|
||||||
### LOAD FROM os.environ/ ###
|
|
||||||
for k, v in model["litellm_params"].items():
|
|
||||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
|
||||||
model["litellm_params"][k] = litellm.get_secret(v)
|
|
||||||
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
|
|
||||||
litellm_model_name = model["litellm_params"]["model"]
|
|
||||||
litellm_model_api_base = model["litellm_params"].get("api_base", None)
|
|
||||||
if "ollama" in litellm_model_name and litellm_model_api_base is None:
|
|
||||||
run_ollama_serve()
|
|
||||||
|
|
||||||
## ROUTER SETTINGS (e.g. routing_strategy, ...)
|
|
||||||
router_settings = config.get("router_settings", None)
|
|
||||||
if router_settings and isinstance(router_settings, dict):
|
|
||||||
arg_spec = inspect.getfullargspec(litellm.Router)
|
|
||||||
# model list already set
|
|
||||||
exclude_args = {
|
|
||||||
"self",
|
|
||||||
"model_list",
|
|
||||||
}
|
}
|
||||||
|
## MODEL LIST
|
||||||
|
model_list = config.get("model_list", None)
|
||||||
|
if model_list:
|
||||||
|
router_params["model_list"] = model_list
|
||||||
|
print( # noqa
|
||||||
|
f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
|
||||||
|
) # noqa
|
||||||
|
for model in model_list:
|
||||||
|
### LOAD FROM os.environ/ ###
|
||||||
|
for k, v in model["litellm_params"].items():
|
||||||
|
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||||
|
model["litellm_params"][k] = litellm.get_secret(v)
|
||||||
|
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
|
||||||
|
litellm_model_name = model["litellm_params"]["model"]
|
||||||
|
litellm_model_api_base = model["litellm_params"].get("api_base", None)
|
||||||
|
if "ollama" in litellm_model_name and litellm_model_api_base is None:
|
||||||
|
run_ollama_serve()
|
||||||
|
|
||||||
available_args = [x for x in arg_spec.args if x not in exclude_args]
|
## ROUTER SETTINGS (e.g. routing_strategy, ...)
|
||||||
|
router_settings = config.get("router_settings", None)
|
||||||
|
if router_settings and isinstance(router_settings, dict):
|
||||||
|
arg_spec = inspect.getfullargspec(litellm.Router)
|
||||||
|
# model list already set
|
||||||
|
exclude_args = {
|
||||||
|
"self",
|
||||||
|
"model_list",
|
||||||
|
}
|
||||||
|
|
||||||
for k, v in router_settings.items():
|
available_args = [x for x in arg_spec.args if x not in exclude_args]
|
||||||
if k in available_args:
|
|
||||||
router_params[k] = v
|
|
||||||
|
|
||||||
router = litellm.Router(**router_params) # type:ignore
|
for k, v in router_settings.items():
|
||||||
return router, model_list, general_settings
|
if k in available_args:
|
||||||
|
router_params[k] = v
|
||||||
|
|
||||||
|
router = litellm.Router(**router_params) # type:ignore
|
||||||
|
return router, model_list, general_settings
|
||||||
|
|
||||||
|
|
||||||
|
proxy_config = ProxyConfig()
|
||||||
|
|
||||||
|
|
||||||
async def generate_key_helper_fn(
|
async def generate_key_helper_fn(
|
||||||
|
@ -797,6 +896,7 @@ async def generate_key_helper_fn(
|
||||||
"max_budget": max_budget,
|
"max_budget": max_budget,
|
||||||
"user_email": user_email,
|
"user_email": user_email,
|
||||||
}
|
}
|
||||||
|
print_verbose("PrismaClient: Before Insert Data")
|
||||||
new_verification_token = await prisma_client.insert_data(
|
new_verification_token = await prisma_client.insert_data(
|
||||||
data=verification_token_data
|
data=verification_token_data
|
||||||
)
|
)
|
||||||
|
@ -831,7 +931,7 @@ def save_worker_config(**data):
|
||||||
os.environ["WORKER_CONFIG"] = json.dumps(data)
|
os.environ["WORKER_CONFIG"] = json.dumps(data)
|
||||||
|
|
||||||
|
|
||||||
def initialize(
|
async def initialize(
|
||||||
model=None,
|
model=None,
|
||||||
alias=None,
|
alias=None,
|
||||||
api_base=None,
|
api_base=None,
|
||||||
|
@ -849,7 +949,7 @@ def initialize(
|
||||||
use_queue=False,
|
use_queue=False,
|
||||||
config=None,
|
config=None,
|
||||||
):
|
):
|
||||||
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth
|
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
|
||||||
generate_feedback_box()
|
generate_feedback_box()
|
||||||
user_model = model
|
user_model = model
|
||||||
user_debug = debug
|
user_debug = debug
|
||||||
|
@ -857,9 +957,11 @@ def initialize(
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
dynamic_config = {"general": {}, user_model: {}}
|
dynamic_config = {"general": {}, user_model: {}}
|
||||||
if config:
|
if config:
|
||||||
llm_router, llm_model_list, general_settings = load_router_config(
|
(
|
||||||
router=llm_router, config_file_path=config
|
llm_router,
|
||||||
)
|
llm_model_list,
|
||||||
|
general_settings,
|
||||||
|
) = await proxy_config.load_config(router=llm_router, config_file_path=config)
|
||||||
if headers: # model-specific param
|
if headers: # model-specific param
|
||||||
user_headers = headers
|
user_headers = headers
|
||||||
dynamic_config[user_model]["headers"] = headers
|
dynamic_config[user_model]["headers"] = headers
|
||||||
|
@ -988,7 +1090,7 @@ def parse_cache_control(cache_control):
|
||||||
|
|
||||||
@router.on_event("startup")
|
@router.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
global prisma_client, master_key, use_background_health_checks
|
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
|
||||||
import json
|
import json
|
||||||
|
|
||||||
### LOAD MASTER KEY ###
|
### LOAD MASTER KEY ###
|
||||||
|
@ -1000,12 +1102,11 @@ async def startup_event():
|
||||||
print_verbose(f"worker_config: {worker_config}")
|
print_verbose(f"worker_config: {worker_config}")
|
||||||
# check if it's a valid file path
|
# check if it's a valid file path
|
||||||
if os.path.isfile(worker_config):
|
if os.path.isfile(worker_config):
|
||||||
initialize(config=worker_config)
|
await initialize(**worker_config)
|
||||||
else:
|
else:
|
||||||
# if not, assume it's a json string
|
# if not, assume it's a json string
|
||||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||||
initialize(**worker_config)
|
await initialize(**worker_config)
|
||||||
|
|
||||||
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||||
|
|
||||||
if use_background_health_checks:
|
if use_background_health_checks:
|
||||||
|
@ -1013,10 +1114,6 @@ async def startup_event():
|
||||||
_run_background_health_check()
|
_run_background_health_check()
|
||||||
) # start the background health check coroutine.
|
) # start the background health check coroutine.
|
||||||
|
|
||||||
print_verbose(f"prisma client - {prisma_client}")
|
|
||||||
if prisma_client is not None:
|
|
||||||
await prisma_client.connect()
|
|
||||||
|
|
||||||
if prisma_client is not None and master_key is not None:
|
if prisma_client is not None and master_key is not None:
|
||||||
# add master key to db
|
# add master key to db
|
||||||
await generate_key_helper_fn(
|
await generate_key_helper_fn(
|
||||||
|
@ -1220,7 +1317,7 @@ async def chat_completion(
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||||
):
|
):
|
||||||
global general_settings, user_debug, proxy_logging_obj
|
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
||||||
try:
|
try:
|
||||||
data = {}
|
data = {}
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
|
@ -1673,6 +1770,7 @@ async def generate_key_fn(
|
||||||
- expires: (datetime) Datetime object for when key expires.
|
- expires: (datetime) Datetime object for when key expires.
|
||||||
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
|
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
|
||||||
"""
|
"""
|
||||||
|
print_verbose("entered /key/generate")
|
||||||
data_json = data.json() # type: ignore
|
data_json = data.json() # type: ignore
|
||||||
response = await generate_key_helper_fn(**data_json)
|
response = await generate_key_helper_fn(**data_json)
|
||||||
return GenerateKeyResponse(
|
return GenerateKeyResponse(
|
||||||
|
@ -1825,7 +1923,7 @@ async def user_auth(request: Request):
|
||||||
|
|
||||||
### Check if user email in user table
|
### Check if user email in user table
|
||||||
response = await prisma_client.get_generic_data(
|
response = await prisma_client.get_generic_data(
|
||||||
key="user_email", value=user_email, db="users"
|
key="user_email", value=user_email, table_name="users"
|
||||||
)
|
)
|
||||||
### if so - generate a 24 hr key with that user id
|
### if so - generate a 24 hr key with that user id
|
||||||
if response is not None:
|
if response is not None:
|
||||||
|
@ -1883,16 +1981,13 @@ async def user_update(request: Request):
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
)
|
)
|
||||||
async def add_new_model(model_params: ModelParams):
|
async def add_new_model(model_params: ModelParams):
|
||||||
global llm_router, llm_model_list, general_settings, user_config_file_path
|
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
|
||||||
try:
|
try:
|
||||||
print_verbose(f"User config path: {user_config_file_path}")
|
|
||||||
# Load existing config
|
# Load existing config
|
||||||
if os.path.exists(f"{user_config_file_path}"):
|
config = await proxy_config.get_config()
|
||||||
with open(f"{user_config_file_path}", "r") as config_file:
|
|
||||||
config = yaml.safe_load(config_file)
|
print_verbose(f"User config path: {user_config_file_path}")
|
||||||
else:
|
|
||||||
config = {"model_list": []}
|
|
||||||
backup_config = copy.deepcopy(config)
|
|
||||||
print_verbose(f"Loaded config: {config}")
|
print_verbose(f"Loaded config: {config}")
|
||||||
# Add the new model to the config
|
# Add the new model to the config
|
||||||
model_info = model_params.model_info.json()
|
model_info = model_params.model_info.json()
|
||||||
|
@ -1907,22 +2002,8 @@ async def add_new_model(model_params: ModelParams):
|
||||||
|
|
||||||
print_verbose(f"updated model list: {config['model_list']}")
|
print_verbose(f"updated model list: {config['model_list']}")
|
||||||
|
|
||||||
# Save the updated config
|
# Save new config
|
||||||
with open(f"{user_config_file_path}", "w") as config_file:
|
await proxy_config.save_config(new_config=config)
|
||||||
yaml.dump(config, config_file, default_flow_style=False)
|
|
||||||
|
|
||||||
# update Router
|
|
||||||
try:
|
|
||||||
llm_router, llm_model_list, general_settings = load_router_config(
|
|
||||||
router=llm_router, config_file_path=user_config_file_path
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
# Rever to old config instead
|
|
||||||
with open(f"{user_config_file_path}", "w") as config_file:
|
|
||||||
yaml.dump(backup_config, config_file, default_flow_style=False)
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid Model passed in")
|
|
||||||
|
|
||||||
print_verbose(f"llm_model_list: {llm_model_list}")
|
|
||||||
return {"message": "Model added successfully"}
|
return {"message": "Model added successfully"}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1949,13 +2030,10 @@ async def add_new_model(model_params: ModelParams):
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
)
|
)
|
||||||
async def model_info_v1(request: Request):
|
async def model_info_v1(request: Request):
|
||||||
global llm_model_list, general_settings, user_config_file_path
|
global llm_model_list, general_settings, user_config_file_path, proxy_config
|
||||||
|
|
||||||
# Load existing config
|
# Load existing config
|
||||||
if os.path.exists(f"{user_config_file_path}"):
|
config = await proxy_config.get_config()
|
||||||
with open(f"{user_config_file_path}", "r") as config_file:
|
|
||||||
config = yaml.safe_load(config_file)
|
|
||||||
else:
|
|
||||||
config = {"model_list": []} # handle base case
|
|
||||||
|
|
||||||
all_models = config["model_list"]
|
all_models = config["model_list"]
|
||||||
for model in all_models:
|
for model in all_models:
|
||||||
|
@ -1984,18 +2062,18 @@ async def model_info_v1(request: Request):
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
)
|
)
|
||||||
async def delete_model(model_info: ModelInfoDelete):
|
async def delete_model(model_info: ModelInfoDelete):
|
||||||
global llm_router, llm_model_list, general_settings, user_config_file_path
|
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
|
||||||
try:
|
try:
|
||||||
if not os.path.exists(user_config_file_path):
|
if not os.path.exists(user_config_file_path):
|
||||||
raise HTTPException(status_code=404, detail="Config file does not exist.")
|
raise HTTPException(status_code=404, detail="Config file does not exist.")
|
||||||
|
|
||||||
with open(user_config_file_path, "r") as config_file:
|
# Load existing config
|
||||||
config = yaml.safe_load(config_file)
|
config = await proxy_config.get_config()
|
||||||
|
|
||||||
# If model_list is not in the config, nothing can be deleted
|
# If model_list is not in the config, nothing can be deleted
|
||||||
if "model_list" not in config:
|
if len(config.get("model_list", [])) == 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail="No model list available in the config."
|
status_code=400, detail="No model list available in the config."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the model with the specified model_id exists
|
# Check if the model with the specified model_id exists
|
||||||
|
@ -2008,19 +2086,14 @@ async def delete_model(model_info: ModelInfoDelete):
|
||||||
# If the model was not found, return an error
|
# If the model was not found, return an error
|
||||||
if model_to_delete is None:
|
if model_to_delete is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail="Model with given model_id not found."
|
status_code=400, detail="Model with given model_id not found."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove model from the list and save the updated config
|
# Remove model from the list and save the updated config
|
||||||
config["model_list"].remove(model_to_delete)
|
config["model_list"].remove(model_to_delete)
|
||||||
with open(user_config_file_path, "w") as config_file:
|
|
||||||
yaml.dump(config, config_file, default_flow_style=False)
|
|
||||||
|
|
||||||
# Update Router
|
|
||||||
llm_router, llm_model_list, general_settings = load_router_config(
|
|
||||||
router=llm_router, config_file_path=user_config_file_path
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Save updated config
|
||||||
|
config = await proxy_config.save_config(new_config=config)
|
||||||
return {"message": "Model deleted successfully"}
|
return {"message": "Model deleted successfully"}
|
||||||
|
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
|
@ -2200,14 +2273,11 @@ async def update_config(config_info: ConfigYAML):
|
||||||
|
|
||||||
Currently supports modifying General Settings + LiteLLM settings
|
Currently supports modifying General Settings + LiteLLM settings
|
||||||
"""
|
"""
|
||||||
global llm_router, llm_model_list, general_settings
|
global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj
|
||||||
try:
|
try:
|
||||||
# Load existing config
|
# Load existing config
|
||||||
if os.path.exists(f"{user_config_file_path}"):
|
config = await proxy_config.get_config()
|
||||||
with open(f"{user_config_file_path}", "r") as config_file:
|
|
||||||
config = yaml.safe_load(config_file)
|
|
||||||
else:
|
|
||||||
config = {}
|
|
||||||
backup_config = copy.deepcopy(config)
|
backup_config = copy.deepcopy(config)
|
||||||
print_verbose(f"Loaded config: {config}")
|
print_verbose(f"Loaded config: {config}")
|
||||||
|
|
||||||
|
@ -2240,20 +2310,13 @@ async def update_config(config_info: ConfigYAML):
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save the updated config
|
# Save the updated config
|
||||||
with open(f"{user_config_file_path}", "w") as config_file:
|
await proxy_config.save_config(new_config=config)
|
||||||
yaml.dump(config, config_file, default_flow_style=False)
|
|
||||||
|
|
||||||
# update Router
|
# Test new connections
|
||||||
try:
|
## Slack
|
||||||
llm_router, llm_model_list, general_settings = load_router_config(
|
if "slack" in config.get("general_settings", {}).get("alerting", []):
|
||||||
router=llm_router, config_file_path=user_config_file_path
|
await proxy_logging_obj.alerting_handler(
|
||||||
)
|
message="This is a test", level="Low"
|
||||||
except Exception as e:
|
|
||||||
# Rever to old config instead
|
|
||||||
with open(f"{user_config_file_path}", "w") as config_file:
|
|
||||||
yaml.dump(backup_config, config_file, default_flow_style=False)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail=f"Invalid config passed in. Errror - {str(e)}"
|
|
||||||
)
|
)
|
||||||
return {"message": "Config updated successfully"}
|
return {"message": "Config updated successfully"}
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
|
@ -2263,6 +2326,21 @@ async def update_config(config_info: ConfigYAML):
|
||||||
raise HTTPException(status_code=500, detail=f"An error occurred - {str(e)}")
|
raise HTTPException(status_code=500, detail=f"An error occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/config/get",
|
||||||
|
tags=["config.yaml"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
)
|
||||||
|
async def get_config():
|
||||||
|
"""
|
||||||
|
Master key only.
|
||||||
|
|
||||||
|
Returns the config. Mainly used for testing.
|
||||||
|
"""
|
||||||
|
global proxy_config
|
||||||
|
return await proxy_config.get_config()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config/yaml", tags=["config.yaml"])
|
@router.get("/config/yaml", tags=["config.yaml"])
|
||||||
async def config_yaml_endpoint(config_info: ConfigYAML):
|
async def config_yaml_endpoint(config_info: ConfigYAML):
|
||||||
"""
|
"""
|
||||||
|
@ -2351,6 +2429,28 @@ async def health_endpoint(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health/readiness", tags=["health"])
|
||||||
|
async def health_readiness():
|
||||||
|
"""
|
||||||
|
Unprotected endpoint for checking if worker can receive requests
|
||||||
|
"""
|
||||||
|
global prisma_client
|
||||||
|
if prisma_client is not None: # if db passed in, check if it's connected
|
||||||
|
if prisma_client.db.is_connected() == True:
|
||||||
|
return {"status": "healthy"}
|
||||||
|
else:
|
||||||
|
return {"status": "healthy"}
|
||||||
|
raise HTTPException(status_code=503, detail="Service Unhealthy")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health/liveliness", tags=["health"])
|
||||||
|
async def health_liveliness():
|
||||||
|
"""
|
||||||
|
Unprotected endpoint for checking if worker is alive
|
||||||
|
"""
|
||||||
|
return "I'm alive!"
|
||||||
|
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
async def home(request: Request):
|
async def home(request: Request):
|
||||||
return "LiteLLM: RUNNING"
|
return "LiteLLM: RUNNING"
|
||||||
|
|
|
@ -25,4 +25,9 @@ model LiteLLM_VerificationToken {
|
||||||
user_id String?
|
user_id String?
|
||||||
max_parallel_requests Int?
|
max_parallel_requests Int?
|
||||||
metadata Json @default("{}")
|
metadata Json @default("{}")
|
||||||
|
}
|
||||||
|
|
||||||
|
model LiteLLM_Config {
|
||||||
|
param_name String @id
|
||||||
|
param_value Json?
|
||||||
}
|
}
|
|
@ -250,31 +250,37 @@ def on_backoff(details):
|
||||||
|
|
||||||
class PrismaClient:
|
class PrismaClient:
|
||||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||||
print_verbose(
|
### Check if prisma client can be imported (setup done in Docker build)
|
||||||
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
|
||||||
)
|
|
||||||
## init logging object
|
|
||||||
self.proxy_logging_obj = proxy_logging_obj
|
|
||||||
self.connected = False
|
|
||||||
os.environ["DATABASE_URL"] = database_url
|
|
||||||
# Save the current working directory
|
|
||||||
original_dir = os.getcwd()
|
|
||||||
# set the working directory to where this script is
|
|
||||||
abspath = os.path.abspath(__file__)
|
|
||||||
dname = os.path.dirname(abspath)
|
|
||||||
os.chdir(dname)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.run(["prisma", "generate"])
|
from prisma import Client # type: ignore
|
||||||
subprocess.run(
|
|
||||||
["prisma", "db", "push", "--accept-data-loss"]
|
|
||||||
) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
|
|
||||||
finally:
|
|
||||||
os.chdir(original_dir)
|
|
||||||
# Now you can import the Prisma Client
|
|
||||||
from prisma import Client # type: ignore
|
|
||||||
|
|
||||||
self.db = Client() # Client to connect to Prisma db
|
os.environ["DATABASE_URL"] = database_url
|
||||||
|
self.db = Client() # Client to connect to Prisma db
|
||||||
|
except: # if not - go through normal setup process
|
||||||
|
print_verbose(
|
||||||
|
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
||||||
|
)
|
||||||
|
## init logging object
|
||||||
|
self.proxy_logging_obj = proxy_logging_obj
|
||||||
|
os.environ["DATABASE_URL"] = database_url
|
||||||
|
# Save the current working directory
|
||||||
|
original_dir = os.getcwd()
|
||||||
|
# set the working directory to where this script is
|
||||||
|
abspath = os.path.abspath(__file__)
|
||||||
|
dname = os.path.dirname(abspath)
|
||||||
|
os.chdir(dname)
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(["prisma", "generate"])
|
||||||
|
subprocess.run(
|
||||||
|
["prisma", "db", "push", "--accept-data-loss"]
|
||||||
|
) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
|
||||||
|
finally:
|
||||||
|
os.chdir(original_dir)
|
||||||
|
# Now you can import the Prisma Client
|
||||||
|
from prisma import Client # type: ignore
|
||||||
|
|
||||||
|
self.db = Client() # Client to connect to Prisma db
|
||||||
|
|
||||||
def hash_token(self, token: str):
|
def hash_token(self, token: str):
|
||||||
# Hash the string using SHA-256
|
# Hash the string using SHA-256
|
||||||
|
@ -301,20 +307,24 @@ class PrismaClient:
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
value: Any,
|
value: Any,
|
||||||
db: Literal["users", "keys"],
|
table_name: Literal["users", "keys", "config"],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generic implementation of get data
|
Generic implementation of get data
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if db == "users":
|
if table_name == "users":
|
||||||
response = await self.db.litellm_usertable.find_first(
|
response = await self.db.litellm_usertable.find_first(
|
||||||
where={key: value} # type: ignore
|
where={key: value} # type: ignore
|
||||||
)
|
)
|
||||||
elif db == "keys":
|
elif table_name == "keys":
|
||||||
response = await self.db.litellm_verificationtoken.find_first( # type: ignore
|
response = await self.db.litellm_verificationtoken.find_first( # type: ignore
|
||||||
where={key: value} # type: ignore
|
where={key: value} # type: ignore
|
||||||
)
|
)
|
||||||
|
elif table_name == "config":
|
||||||
|
response = await self.db.litellm_config.find_first( # type: ignore
|
||||||
|
where={key: value} # type: ignore
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
|
@ -336,15 +346,19 @@ class PrismaClient:
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
print_verbose("PrismaClient: get_data")
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
if token is not None:
|
if token is not None:
|
||||||
# check if plain text or hash
|
# check if plain text or hash
|
||||||
hashed_token = token
|
hashed_token = token
|
||||||
if token.startswith("sk-"):
|
if token.startswith("sk-"):
|
||||||
hashed_token = self.hash_token(token=token)
|
hashed_token = self.hash_token(token=token)
|
||||||
|
print_verbose("PrismaClient: find_unique")
|
||||||
response = await self.db.litellm_verificationtoken.find_unique(
|
response = await self.db.litellm_verificationtoken.find_unique(
|
||||||
where={"token": hashed_token}
|
where={"token": hashed_token}
|
||||||
)
|
)
|
||||||
|
print_verbose(f"PrismaClient: response={response}")
|
||||||
if response:
|
if response:
|
||||||
# Token exists, now check expiration.
|
# Token exists, now check expiration.
|
||||||
if response.expires is not None and expires is not None:
|
if response.expires is not None and expires is not None:
|
||||||
|
@ -372,6 +386,10 @@ class PrismaClient:
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||||
)
|
)
|
||||||
|
@ -385,40 +403,71 @@ class PrismaClient:
|
||||||
max_time=10, # maximum total time to retry for
|
max_time=10, # maximum total time to retry for
|
||||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||||
)
|
)
|
||||||
async def insert_data(self, data: dict):
|
async def insert_data(
|
||||||
|
self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Add a key to the database. If it already exists, do nothing.
|
Add a key to the database. If it already exists, do nothing.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
token = data["token"]
|
if table_name == "user+key":
|
||||||
hashed_token = self.hash_token(token=token)
|
token = data["token"]
|
||||||
db_data = self.jsonify_object(data=data)
|
hashed_token = self.hash_token(token=token)
|
||||||
db_data["token"] = hashed_token
|
db_data = self.jsonify_object(data=data)
|
||||||
max_budget = db_data.pop("max_budget", None)
|
db_data["token"] = hashed_token
|
||||||
user_email = db_data.pop("user_email", None)
|
max_budget = db_data.pop("max_budget", None)
|
||||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
user_email = db_data.pop("user_email", None)
|
||||||
where={
|
print_verbose(
|
||||||
"token": hashed_token,
|
"PrismaClient: Before upsert into litellm_verificationtoken"
|
||||||
},
|
)
|
||||||
data={
|
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||||
"create": {**db_data}, # type: ignore
|
where={
|
||||||
"update": {}, # don't do anything if it already exists
|
"token": hashed_token,
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
new_user_row = await self.db.litellm_usertable.upsert(
|
|
||||||
where={"user_id": data["user_id"]},
|
|
||||||
data={
|
|
||||||
"create": {
|
|
||||||
"user_id": data["user_id"],
|
|
||||||
"max_budget": max_budget,
|
|
||||||
"user_email": user_email,
|
|
||||||
},
|
},
|
||||||
"update": {}, # don't do anything if it already exists
|
data={
|
||||||
},
|
"create": {**db_data}, # type: ignore
|
||||||
)
|
"update": {}, # don't do anything if it already exists
|
||||||
return new_verification_token
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
new_user_row = await self.db.litellm_usertable.upsert(
|
||||||
|
where={"user_id": data["user_id"]},
|
||||||
|
data={
|
||||||
|
"create": {
|
||||||
|
"user_id": data["user_id"],
|
||||||
|
"max_budget": max_budget,
|
||||||
|
"user_email": user_email,
|
||||||
|
},
|
||||||
|
"update": {}, # don't do anything if it already exists
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return new_verification_token
|
||||||
|
elif table_name == "config":
|
||||||
|
"""
|
||||||
|
For each param,
|
||||||
|
get the existing table values
|
||||||
|
|
||||||
|
Add the new values
|
||||||
|
|
||||||
|
Update DB
|
||||||
|
"""
|
||||||
|
tasks = []
|
||||||
|
for k, v in data.items():
|
||||||
|
updated_data = v
|
||||||
|
updated_data = json.dumps(updated_data)
|
||||||
|
updated_table_row = self.db.litellm_config.upsert(
|
||||||
|
where={"param_name": k},
|
||||||
|
data={
|
||||||
|
"create": {"param_name": k, "param_value": updated_data},
|
||||||
|
"update": {"param_value": updated_data},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
tasks.append(updated_table_row)
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||||
)
|
)
|
||||||
|
@ -505,11 +554,7 @@ class PrismaClient:
|
||||||
)
|
)
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
try:
|
try:
|
||||||
if self.connected == False:
|
await self.db.connect()
|
||||||
await self.db.connect()
|
|
||||||
self.connected = True
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||||
|
|
|
@ -773,6 +773,10 @@ class Router:
|
||||||
)
|
)
|
||||||
original_exception = e
|
original_exception = e
|
||||||
try:
|
try:
|
||||||
|
if (
|
||||||
|
hasattr(e, "status_code") and e.status_code == 400
|
||||||
|
): # don't retry a malformed request
|
||||||
|
raise e
|
||||||
self.print_verbose(f"Trying to fallback b/w models")
|
self.print_verbose(f"Trying to fallback b/w models")
|
||||||
if (
|
if (
|
||||||
isinstance(e, litellm.ContextWindowExceededError)
|
isinstance(e, litellm.ContextWindowExceededError)
|
||||||
|
@ -846,7 +850,7 @@ class Router:
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
original_exception = e
|
original_exception = e
|
||||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
|
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
|
||||||
if (
|
if (
|
||||||
isinstance(original_exception, litellm.ContextWindowExceededError)
|
isinstance(original_exception, litellm.ContextWindowExceededError)
|
||||||
and context_window_fallbacks is None
|
and context_window_fallbacks is None
|
||||||
|
@ -864,12 +868,12 @@ class Router:
|
||||||
min_timeout=self.retry_after,
|
min_timeout=self.retry_after,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(timeout)
|
await asyncio.sleep(timeout)
|
||||||
elif (
|
elif hasattr(original_exception, "status_code") and litellm._should_retry(
|
||||||
hasattr(original_exception, "status_code")
|
status_code=original_exception.status_code
|
||||||
and hasattr(original_exception, "response")
|
|
||||||
and litellm._should_retry(status_code=original_exception.status_code)
|
|
||||||
):
|
):
|
||||||
if hasattr(original_exception.response, "headers"):
|
if hasattr(original_exception, "response") and hasattr(
|
||||||
|
original_exception.response, "headers"
|
||||||
|
):
|
||||||
timeout = litellm._calculate_retry_after(
|
timeout = litellm._calculate_retry_after(
|
||||||
remaining_retries=num_retries,
|
remaining_retries=num_retries,
|
||||||
max_retries=num_retries,
|
max_retries=num_retries,
|
||||||
|
@ -1326,6 +1330,7 @@ class Router:
|
||||||
local_only=True,
|
local_only=True,
|
||||||
) # cache for 1 hr
|
) # cache for 1 hr
|
||||||
|
|
||||||
|
cache_key = f"{model_id}_client"
|
||||||
_client = openai.AzureOpenAI( # type: ignore
|
_client = openai.AzureOpenAI( # type: ignore
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
|
|
|
@ -138,14 +138,15 @@ def test_async_completion_cloudflare():
|
||||||
response = await litellm.acompletion(
|
response = await litellm.acompletion(
|
||||||
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
|
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
|
||||||
messages=[{"content": "what llm are you", "role": "user"}],
|
messages=[{"content": "what llm are you", "role": "user"}],
|
||||||
max_tokens=50,
|
max_tokens=5,
|
||||||
|
num_retries=3,
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
response = asyncio.run(test())
|
response = asyncio.run(test())
|
||||||
text_response = response["choices"][0]["message"]["content"]
|
text_response = response["choices"][0]["message"]["content"]
|
||||||
assert len(text_response) > 5 # more than 5 chars in response
|
assert len(text_response) > 1 # more than 1 chars in response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
@ -166,7 +167,7 @@ def test_get_cloudflare_response_streaming():
|
||||||
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
|
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
stream=True,
|
stream=True,
|
||||||
timeout=5,
|
num_retries=3, # cloudflare ai workers is EXTREMELY UNSTABLE
|
||||||
)
|
)
|
||||||
print(type(response))
|
print(type(response))
|
||||||
|
|
||||||
|
|
|
@ -91,7 +91,7 @@ def test_caching_with_cache_controls():
|
||||||
model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
|
model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
|
||||||
)
|
)
|
||||||
response2 = completion(
|
response2 = completion(
|
||||||
model="gpt-3.5-turbo", messages=messages, cache={"s-max-age": 10}
|
model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
|
||||||
)
|
)
|
||||||
print(f"response1: {response1}")
|
print(f"response1: {response1}")
|
||||||
print(f"response2: {response2}")
|
print(f"response2: {response2}")
|
||||||
|
@ -105,7 +105,7 @@ def test_caching_with_cache_controls():
|
||||||
model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
|
model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
|
||||||
)
|
)
|
||||||
response2 = completion(
|
response2 = completion(
|
||||||
model="gpt-3.5-turbo", messages=messages, cache={"s-max-age": 5}
|
model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 5}
|
||||||
)
|
)
|
||||||
print(f"response1: {response1}")
|
print(f"response1: {response1}")
|
||||||
print(f"response2: {response2}")
|
print(f"response2: {response2}")
|
||||||
|
@ -167,6 +167,8 @@ small text
|
||||||
def test_embedding_caching():
|
def test_embedding_caching():
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
# litellm.set_verbose = True
|
||||||
|
|
||||||
litellm.cache = Cache()
|
litellm.cache = Cache()
|
||||||
text_to_embed = [embedding_large_text]
|
text_to_embed = [embedding_large_text]
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -182,7 +184,7 @@ def test_embedding_caching():
|
||||||
model="text-embedding-ada-002", input=text_to_embed, caching=True
|
model="text-embedding-ada-002", input=text_to_embed, caching=True
|
||||||
)
|
)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"embedding2: {embedding2}")
|
# print(f"embedding2: {embedding2}")
|
||||||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||||
|
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
@ -274,7 +276,7 @@ def test_redis_cache_completion():
|
||||||
port=os.environ["REDIS_PORT"],
|
port=os.environ["REDIS_PORT"],
|
||||||
password=os.environ["REDIS_PASSWORD"],
|
password=os.environ["REDIS_PASSWORD"],
|
||||||
)
|
)
|
||||||
print("test2 for caching")
|
print("test2 for Redis Caching - non streaming")
|
||||||
response1 = completion(
|
response1 = completion(
|
||||||
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
|
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
|
||||||
)
|
)
|
||||||
|
@ -326,6 +328,10 @@ def test_redis_cache_completion():
|
||||||
print(f"response4: {response4}")
|
print(f"response4: {response4}")
|
||||||
pytest.fail(f"Error occurred:")
|
pytest.fail(f"Error occurred:")
|
||||||
|
|
||||||
|
assert response1.id == response2.id
|
||||||
|
assert response1.created == response2.created
|
||||||
|
assert response1.choices[0].message.content == response2.choices[0].message.content
|
||||||
|
|
||||||
|
|
||||||
# test_redis_cache_completion()
|
# test_redis_cache_completion()
|
||||||
|
|
||||||
|
@ -395,7 +401,7 @@ def test_redis_cache_completion_stream():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# test_redis_cache_completion_stream()
|
test_redis_cache_completion_stream()
|
||||||
|
|
||||||
|
|
||||||
def test_redis_cache_acompletion_stream():
|
def test_redis_cache_acompletion_stream():
|
||||||
|
@ -529,6 +535,7 @@ def test_redis_cache_acompletion_stream_bedrock():
|
||||||
assert (
|
assert (
|
||||||
response_1_content == response_2_content
|
response_1_content == response_2_content
|
||||||
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||||
|
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
litellm.success_callback = []
|
litellm.success_callback = []
|
||||||
litellm._async_success_callback = []
|
litellm._async_success_callback = []
|
||||||
|
@ -537,7 +544,7 @@ def test_redis_cache_acompletion_stream_bedrock():
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def test_s3_cache_acompletion_stream_bedrock():
|
def test_s3_cache_acompletion_stream_azure():
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -556,10 +563,13 @@ def test_s3_cache_acompletion_stream_bedrock():
|
||||||
response_1_content = ""
|
response_1_content = ""
|
||||||
response_2_content = ""
|
response_2_content = ""
|
||||||
|
|
||||||
|
response_1_created = ""
|
||||||
|
response_2_created = ""
|
||||||
|
|
||||||
async def call1():
|
async def call1():
|
||||||
nonlocal response_1_content
|
nonlocal response_1_content, response_1_created
|
||||||
response1 = await litellm.acompletion(
|
response1 = await litellm.acompletion(
|
||||||
model="bedrock/anthropic.claude-v1",
|
model="azure/chatgpt-v-2",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=40,
|
max_tokens=40,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
|
@ -567,6 +577,7 @@ def test_s3_cache_acompletion_stream_bedrock():
|
||||||
)
|
)
|
||||||
async for chunk in response1:
|
async for chunk in response1:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
response_1_created = chunk.created
|
||||||
response_1_content += chunk.choices[0].delta.content or ""
|
response_1_content += chunk.choices[0].delta.content or ""
|
||||||
print(response_1_content)
|
print(response_1_content)
|
||||||
|
|
||||||
|
@ -575,9 +586,9 @@ def test_s3_cache_acompletion_stream_bedrock():
|
||||||
print("\n\n Response 1 content: ", response_1_content, "\n\n")
|
print("\n\n Response 1 content: ", response_1_content, "\n\n")
|
||||||
|
|
||||||
async def call2():
|
async def call2():
|
||||||
nonlocal response_2_content
|
nonlocal response_2_content, response_2_created
|
||||||
response2 = await litellm.acompletion(
|
response2 = await litellm.acompletion(
|
||||||
model="bedrock/anthropic.claude-v1",
|
model="azure/chatgpt-v-2",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=40,
|
max_tokens=40,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
|
@ -586,14 +597,23 @@ def test_s3_cache_acompletion_stream_bedrock():
|
||||||
async for chunk in response2:
|
async for chunk in response2:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
response_2_content += chunk.choices[0].delta.content or ""
|
response_2_content += chunk.choices[0].delta.content or ""
|
||||||
|
response_2_created = chunk.created
|
||||||
print(response_2_content)
|
print(response_2_content)
|
||||||
|
|
||||||
asyncio.run(call2())
|
asyncio.run(call2())
|
||||||
print("\nresponse 1", response_1_content)
|
print("\nresponse 1", response_1_content)
|
||||||
print("\nresponse 2", response_2_content)
|
print("\nresponse 2", response_2_content)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
response_1_content == response_2_content
|
response_1_content == response_2_content
|
||||||
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||||
|
|
||||||
|
# prioritizing getting a new deploy out - will look at this in the next deploy
|
||||||
|
# print("response 1 created", response_1_created)
|
||||||
|
# print("response 2 created", response_2_created)
|
||||||
|
|
||||||
|
# assert response_1_created == response_2_created
|
||||||
|
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
litellm.success_callback = []
|
litellm.success_callback = []
|
||||||
litellm._async_success_callback = []
|
litellm._async_success_callback = []
|
||||||
|
@ -602,7 +622,7 @@ def test_s3_cache_acompletion_stream_bedrock():
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
test_s3_cache_acompletion_stream_bedrock()
|
# test_s3_cache_acompletion_stream_azure()
|
||||||
|
|
||||||
|
|
||||||
# test_redis_cache_acompletion_stream_bedrock()
|
# test_redis_cache_acompletion_stream_bedrock()
|
||||||
|
|
|
@ -749,10 +749,14 @@ def test_completion_ollama_hosted():
|
||||||
model="ollama/phi",
|
model="ollama/phi",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
|
num_retries=3,
|
||||||
|
timeout=90,
|
||||||
api_base="https://test-ollama-endpoint.onrender.com",
|
api_base="https://test-ollama-endpoint.onrender.com",
|
||||||
)
|
)
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
print(response)
|
print(response)
|
||||||
|
except Timeout as e:
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
@ -1626,6 +1630,7 @@ def test_completion_anyscale_api():
|
||||||
|
|
||||||
|
|
||||||
def test_azure_cloudflare_api():
|
def test_azure_cloudflare_api():
|
||||||
|
litellm.set_verbose = True
|
||||||
try:
|
try:
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
|
@ -1641,11 +1646,12 @@ def test_azure_cloudflare_api():
|
||||||
)
|
)
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# test_azure_cloudflare_api()
|
test_azure_cloudflare_api()
|
||||||
|
|
||||||
|
|
||||||
def test_completion_anyscale_2():
|
def test_completion_anyscale_2():
|
||||||
|
@ -1931,6 +1937,7 @@ def test_completion_cloudflare():
|
||||||
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
|
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
|
||||||
messages=[{"content": "what llm are you", "role": "user"}],
|
messages=[{"content": "what llm are you", "role": "user"}],
|
||||||
max_tokens=15,
|
max_tokens=15,
|
||||||
|
num_retries=3,
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
@ -1938,7 +1945,7 @@ def test_completion_cloudflare():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
# test_completion_cloudflare()
|
test_completion_cloudflare()
|
||||||
|
|
||||||
|
|
||||||
def test_moderation():
|
def test_moderation():
|
||||||
|
|
|
@ -103,7 +103,7 @@ def test_cost_azure_gpt_35():
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
model="azure/gpt-35-turbo", # azure always has model written like this
|
model="gpt-35-turbo", # azure always has model written like this
|
||||||
usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
|
usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -125,3 +125,36 @@ def test_cost_azure_gpt_35():
|
||||||
|
|
||||||
|
|
||||||
test_cost_azure_gpt_35()
|
test_cost_azure_gpt_35()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cost_azure_embedding():
|
||||||
|
try:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
async def _test():
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
model="azure/azure-embedding-model",
|
||||||
|
input=["good morning from litellm", "gm"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
response = asyncio.run(_test())
|
||||||
|
|
||||||
|
cost = litellm.completion_cost(completion_response=response)
|
||||||
|
|
||||||
|
print("Cost", cost)
|
||||||
|
expected_cost = float("7e-07")
|
||||||
|
assert cost == expected_cost
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(
|
||||||
|
f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# test_cost_azure_embedding()
|
|
@ -0,0 +1,15 @@
|
||||||
|
model_list:
|
||||||
|
- model_name: azure-cloudflare
|
||||||
|
litellm_params:
|
||||||
|
model: azure/chatgpt-v-2
|
||||||
|
api_base: https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
api_version: 2023-07-01-preview
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
set_verbose: True
|
||||||
|
cache: True # set cache responses to True
|
||||||
|
cache_params: # set cache params for s3
|
||||||
|
type: s3
|
||||||
|
s3_bucket_name: cache-bucket-litellm # AWS Bucket Name for S3
|
||||||
|
s3_region_name: us-west-2 # AWS Region Name for S3
|
|
@ -9,6 +9,11 @@ model_list:
|
||||||
api_key: os.environ/AZURE_CANADA_API_KEY
|
api_key: os.environ/AZURE_CANADA_API_KEY
|
||||||
model: azure/gpt-35-turbo
|
model: azure/gpt-35-turbo
|
||||||
model_name: azure-model
|
model_name: azure-model
|
||||||
|
- litellm_params:
|
||||||
|
api_base: https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
model: azure/chatgpt-v-2
|
||||||
|
model_name: azure-cloudflare-model
|
||||||
- litellm_params:
|
- litellm_params:
|
||||||
api_base: https://openai-france-1234.openai.azure.com
|
api_base: https://openai-france-1234.openai.azure.com
|
||||||
api_key: os.environ/AZURE_FRANCE_API_KEY
|
api_key: os.environ/AZURE_FRANCE_API_KEY
|
||||||
|
|
|
@ -59,6 +59,7 @@ def test_openai_embedding():
|
||||||
|
|
||||||
def test_openai_azure_embedding_simple():
|
def test_openai_azure_embedding_simple():
|
||||||
try:
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
response = embedding(
|
response = embedding(
|
||||||
model="azure/azure-embedding-model",
|
model="azure/azure-embedding-model",
|
||||||
input=["good morning from litellm"],
|
input=["good morning from litellm"],
|
||||||
|
@ -70,6 +71,10 @@ def test_openai_azure_embedding_simple():
|
||||||
response_keys
|
response_keys
|
||||||
) # assert litellm response has expected keys from OpenAI embedding response
|
) # assert litellm response has expected keys from OpenAI embedding response
|
||||||
|
|
||||||
|
request_cost = litellm.completion_cost(completion_response=response)
|
||||||
|
|
||||||
|
print("Calculated request cost=", request_cost)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
@ -260,15 +265,22 @@ def test_aembedding():
|
||||||
input=["good morning from litellm", "this is another item"],
|
input=["good morning from litellm", "this is another item"],
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
asyncio.run(embedding_call())
|
response = asyncio.run(embedding_call())
|
||||||
|
print("Before caclulating cost, response", response)
|
||||||
|
|
||||||
|
cost = litellm.completion_cost(completion_response=response)
|
||||||
|
|
||||||
|
print("COST=", cost)
|
||||||
|
assert cost == float("1e-06")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
# test_aembedding()
|
test_aembedding()
|
||||||
|
|
||||||
|
|
||||||
def test_aembedding_azure():
|
def test_aembedding_azure():
|
||||||
|
|
|
@ -10,7 +10,7 @@ import os, io
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest
|
import pytest, asyncio
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion, completion_cost, Timeout
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
from litellm import RateLimitError
|
from litellm import RateLimitError
|
||||||
|
@ -22,6 +22,7 @@ from litellm.proxy.proxy_server import (
|
||||||
router,
|
router,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
initialize,
|
initialize,
|
||||||
|
ProxyConfig,
|
||||||
) # Replace with the actual module where your FastAPI router is defined
|
) # Replace with the actual module where your FastAPI router is defined
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,7 +37,7 @@ def client():
|
||||||
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
||||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
initialize(config=config_fp)
|
asyncio.run(initialize(config=config_fp))
|
||||||
|
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
|
@ -23,6 +23,7 @@ from litellm.proxy.proxy_server import (
|
||||||
router,
|
router,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
initialize,
|
initialize,
|
||||||
|
startup_event,
|
||||||
) # Replace with the actual module where your FastAPI router is defined
|
) # Replace with the actual module where your FastAPI router is defined
|
||||||
|
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
@ -39,8 +40,8 @@ python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
|
||||||
def client():
|
def client():
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
|
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
|
||||||
initialize(config=config_fp)
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
asyncio.run(initialize(config=config_fp))
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ from litellm.proxy.proxy_server import (
|
||||||
def client():
|
def client():
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_bad_config.yaml"
|
config_fp = f"{filepath}/test_configs/test_bad_config.yaml"
|
||||||
initialize(config=config_fp)
|
asyncio.run(initialize(config=config_fp))
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
@ -149,7 +149,7 @@ def test_chat_completion_exception_any_model(client):
|
||||||
response=response
|
response=response
|
||||||
)
|
)
|
||||||
print("Exception raised=", openai_exception)
|
print("Exception raised=", openai_exception)
|
||||||
assert isinstance(openai_exception, openai.NotFoundError)
|
assert isinstance(openai_exception, openai.BadRequestError)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
||||||
|
@ -170,7 +170,7 @@ def test_embedding_exception_any_model(client):
|
||||||
response=response
|
response=response
|
||||||
)
|
)
|
||||||
print("Exception raised=", openai_exception)
|
print("Exception raised=", openai_exception)
|
||||||
assert isinstance(openai_exception, openai.NotFoundError)
|
assert isinstance(openai_exception, openai.BadRequestError)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
||||||
|
|
|
@ -10,7 +10,7 @@ import os, io
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest, logging
|
import pytest, logging, asyncio
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion, completion_cost, Timeout
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
from litellm import RateLimitError
|
from litellm import RateLimitError
|
||||||
|
@ -46,7 +46,7 @@ def client_no_auth():
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
||||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||||
initialize(config=config_fp, debug=True)
|
asyncio.run(initialize(config=config_fp, debug=True))
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ import os, io
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest, logging
|
import pytest, logging, asyncio
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion, completion_cost, Timeout
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
from litellm import RateLimitError
|
from litellm import RateLimitError
|
||||||
|
@ -45,7 +45,7 @@ def client_no_auth():
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
||||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||||
initialize(config=config_fp)
|
asyncio.run(initialize(config=config_fp, debug=True))
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
|
|
||||||
|
@ -280,33 +280,42 @@ def test_chat_completion_optional_params(client_no_auth):
|
||||||
# test_chat_completion_optional_params()
|
# test_chat_completion_optional_params()
|
||||||
|
|
||||||
# Test Reading config.yaml file
|
# Test Reading config.yaml file
|
||||||
from litellm.proxy.proxy_server import load_router_config
|
from litellm.proxy.proxy_server import ProxyConfig
|
||||||
|
|
||||||
|
|
||||||
def test_load_router_config():
|
def test_load_router_config():
|
||||||
try:
|
try:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
print("testing reading config")
|
print("testing reading config")
|
||||||
# this is a basic config.yaml with only a model
|
# this is a basic config.yaml with only a model
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
result = load_router_config(
|
proxy_config = ProxyConfig()
|
||||||
router=None,
|
result = asyncio.run(
|
||||||
config_file_path=f"{filepath}/example_config_yaml/simple_config.yaml",
|
proxy_config.load_config(
|
||||||
|
router=None,
|
||||||
|
config_file_path=f"{filepath}/example_config_yaml/simple_config.yaml",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
print(result)
|
print(result)
|
||||||
assert len(result[1]) == 1
|
assert len(result[1]) == 1
|
||||||
|
|
||||||
# this is a load balancing config yaml
|
# this is a load balancing config yaml
|
||||||
result = load_router_config(
|
result = asyncio.run(
|
||||||
router=None,
|
proxy_config.load_config(
|
||||||
config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
|
router=None,
|
||||||
|
config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
print(result)
|
print(result)
|
||||||
assert len(result[1]) == 2
|
assert len(result[1]) == 2
|
||||||
|
|
||||||
# config with general settings - custom callbacks
|
# config with general settings - custom callbacks
|
||||||
result = load_router_config(
|
result = asyncio.run(
|
||||||
router=None,
|
proxy_config.load_config(
|
||||||
config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
|
router=None,
|
||||||
|
config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
print(result)
|
print(result)
|
||||||
assert len(result[1]) == 2
|
assert len(result[1]) == 2
|
||||||
|
@ -314,9 +323,11 @@ def test_load_router_config():
|
||||||
# tests for litellm.cache set from config
|
# tests for litellm.cache set from config
|
||||||
print("testing reading proxy config for cache")
|
print("testing reading proxy config for cache")
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
load_router_config(
|
asyncio.run(
|
||||||
router=None,
|
proxy_config.load_config(
|
||||||
config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml",
|
router=None,
|
||||||
|
config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
assert litellm.cache is not None
|
assert litellm.cache is not None
|
||||||
assert "redis_client" in vars(
|
assert "redis_client" in vars(
|
||||||
|
@ -329,10 +340,14 @@ def test_load_router_config():
|
||||||
"aembedding",
|
"aembedding",
|
||||||
] # init with all call types
|
] # init with all call types
|
||||||
|
|
||||||
|
litellm.disable_cache()
|
||||||
|
|
||||||
print("testing reading proxy config for cache with params")
|
print("testing reading proxy config for cache with params")
|
||||||
load_router_config(
|
asyncio.run(
|
||||||
router=None,
|
proxy_config.load_config(
|
||||||
config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml",
|
router=None,
|
||||||
|
config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
assert litellm.cache is not None
|
assert litellm.cache is not None
|
||||||
print(litellm.cache)
|
print(litellm.cache)
|
||||||
|
|
|
@ -1,38 +1,103 @@
|
||||||
# #### What this tests ####
|
#### What this tests ####
|
||||||
# # This tests using caching w/ litellm which requires SSL=True
|
# This tests using caching w/ litellm which requires SSL=True
|
||||||
|
import sys, os
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# import sys, os
|
load_dotenv()
|
||||||
# import time
|
import os, io
|
||||||
# import traceback
|
|
||||||
# from dotenv import load_dotenv
|
|
||||||
|
|
||||||
# load_dotenv()
|
# this file is to test litellm/proxy
|
||||||
# import os
|
|
||||||
|
|
||||||
# sys.path.insert(
|
sys.path.insert(
|
||||||
# 0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
# ) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
# import pytest
|
import pytest, logging, asyncio
|
||||||
# import litellm
|
import litellm
|
||||||
# from litellm import embedding, completion
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
# from litellm.caching import Cache
|
from litellm import RateLimitError
|
||||||
|
|
||||||
# messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG, # Set the desired logging level
|
||||||
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
# @pytest.mark.skip(reason="local proxy test")
|
# test /chat/completion request to the proxy
|
||||||
# def test_caching_v2(): # test in memory cache
|
from fastapi.testclient import TestClient
|
||||||
# try:
|
from fastapi import FastAPI
|
||||||
# response1 = completion(model="openai/gpt-3.5-turbo", messages=messages, api_base="http://0.0.0.0:8000")
|
from litellm.proxy.proxy_server import (
|
||||||
# response2 = completion(model="openai/gpt-3.5-turbo", messages=messages, api_base="http://0.0.0.0:8000")
|
router,
|
||||||
# print(f"response1: {response1}")
|
save_worker_config,
|
||||||
# print(f"response2: {response2}")
|
initialize,
|
||||||
# litellm.cache = None # disable cache
|
) # Replace with the actual module where your FastAPI router is defined
|
||||||
# if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
|
|
||||||
# print(f"response1: {response1}")
|
|
||||||
# print(f"response2: {response2}")
|
|
||||||
# raise Exception()
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"error occurred: {traceback.format_exc()}")
|
|
||||||
# pytest.fail(f"Error occurred: {e}")
|
|
||||||
|
|
||||||
# test_caching_v2()
|
# Your bearer token
|
||||||
|
token = ""
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def client_no_auth():
|
||||||
|
# Assuming litellm.proxy.proxy_server is an object
|
||||||
|
from litellm.proxy.proxy_server import cleanup_router_config_variables
|
||||||
|
|
||||||
|
cleanup_router_config_variables()
|
||||||
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
config_fp = f"{filepath}/test_configs/test_cloudflare_azure_with_cache_config.yaml"
|
||||||
|
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||||
|
asyncio.run(initialize(config=config_fp, debug=True))
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router) # Include your router in the test app
|
||||||
|
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_random_word(length=4):
|
||||||
|
import string, random
|
||||||
|
|
||||||
|
letters = string.ascii_lowercase
|
||||||
|
return "".join(random.choice(letters) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completion(client_no_auth):
|
||||||
|
global headers
|
||||||
|
try:
|
||||||
|
user_message = f"Write a poem about {generate_random_word()}"
|
||||||
|
messages = [{"content": user_message, "role": "user"}]
|
||||||
|
# Your test data
|
||||||
|
test_data = {
|
||||||
|
"model": "azure-cloudflare",
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
print("testing proxy server with chat completions")
|
||||||
|
response = client_no_auth.post("/v1/chat/completions", json=test_data)
|
||||||
|
print(f"response - {response.text}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
response = response.json()
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
content = response["choices"][0]["message"]["content"]
|
||||||
|
response1_id = response["id"]
|
||||||
|
|
||||||
|
print("\n content", content)
|
||||||
|
|
||||||
|
assert len(content) > 1
|
||||||
|
|
||||||
|
print("\nmaking 2nd request to proxy. Testing caching + non streaming")
|
||||||
|
response = client_no_auth.post("/v1/chat/completions", json=test_data)
|
||||||
|
print(f"response - {response.text}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
response = response.json()
|
||||||
|
print(response)
|
||||||
|
response2_id = response["id"]
|
||||||
|
assert response1_id == response2_id
|
||||||
|
litellm.disable_cache()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
||||||
|
|
|
@ -29,6 +29,7 @@ from litellm.proxy.proxy_server import (
|
||||||
router,
|
router,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
startup_event,
|
startup_event,
|
||||||
|
asyncio,
|
||||||
) # Replace with the actual module where your FastAPI router is defined
|
) # Replace with the actual module where your FastAPI router is defined
|
||||||
|
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
@ -39,7 +40,7 @@ save_worker_config(
|
||||||
alias=None,
|
alias=None,
|
||||||
api_base=None,
|
api_base=None,
|
||||||
api_version=None,
|
api_version=None,
|
||||||
debug=False,
|
debug=True,
|
||||||
temperature=None,
|
temperature=None,
|
||||||
max_tokens=None,
|
max_tokens=None,
|
||||||
request_timeout=600,
|
request_timeout=600,
|
||||||
|
@ -51,24 +52,38 @@ save_worker_config(
|
||||||
save=False,
|
save=False,
|
||||||
use_queue=False,
|
use_queue=False,
|
||||||
)
|
)
|
||||||
app = FastAPI()
|
|
||||||
app.include_router(router) # Include your router in the test app
|
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
import asyncio
|
||||||
async def wrapper_startup_event():
|
|
||||||
await startup_event()
|
|
||||||
|
@pytest.fixture
|
||||||
|
def event_loop():
|
||||||
|
"""Create an instance of the default event loop for each test case."""
|
||||||
|
policy = asyncio.WindowsSelectorEventLoopPolicy()
|
||||||
|
res = policy.new_event_loop()
|
||||||
|
asyncio.set_event_loop(res)
|
||||||
|
res._close = res.close
|
||||||
|
res.close = lambda: None
|
||||||
|
|
||||||
|
yield res
|
||||||
|
|
||||||
|
res._close()
|
||||||
|
|
||||||
|
|
||||||
# Here you create a fixture that will be used by your tests
|
# Here you create a fixture that will be used by your tests
|
||||||
# Make sure the fixture returns TestClient(app)
|
# Make sure the fixture returns TestClient(app)
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(scope="function")
|
||||||
def client():
|
def client():
|
||||||
from litellm.proxy.proxy_server import cleanup_router_config_variables
|
from litellm.proxy.proxy_server import cleanup_router_config_variables, initialize
|
||||||
|
|
||||||
cleanup_router_config_variables()
|
cleanup_router_config_variables() # rest proxy before test
|
||||||
with TestClient(app) as client:
|
|
||||||
yield client
|
asyncio.run(initialize(config=config_fp, debug=True))
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router) # Include your router in the test app
|
||||||
|
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
def test_add_new_key(client):
|
def test_add_new_key(client):
|
||||||
|
@ -79,7 +94,7 @@ def test_add_new_key(client):
|
||||||
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
||||||
"duration": "20m",
|
"duration": "20m",
|
||||||
}
|
}
|
||||||
print("testing proxy server")
|
print("testing proxy server - test_add_new_key")
|
||||||
# Your bearer token
|
# Your bearer token
|
||||||
token = os.getenv("PROXY_MASTER_KEY")
|
token = os.getenv("PROXY_MASTER_KEY")
|
||||||
|
|
||||||
|
@ -121,7 +136,7 @@ def test_update_new_key(client):
|
||||||
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
||||||
"duration": "20m",
|
"duration": "20m",
|
||||||
}
|
}
|
||||||
print("testing proxy server")
|
print("testing proxy server-test_update_new_key")
|
||||||
# Your bearer token
|
# Your bearer token
|
||||||
token = os.getenv("PROXY_MASTER_KEY")
|
token = os.getenv("PROXY_MASTER_KEY")
|
||||||
|
|
||||||
|
|
|
@ -98,6 +98,73 @@ def test_init_clients_basic():
|
||||||
# test_init_clients_basic()
|
# test_init_clients_basic()
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_clients_basic_azure_cloudflare():
|
||||||
|
# init azure + cloudflare
|
||||||
|
# init OpenAI gpt-3.5
|
||||||
|
# init OpenAI text-embedding
|
||||||
|
# init OpenAI comptaible - Mistral/mistral-medium
|
||||||
|
# init OpenAI compatible - xinference/bge
|
||||||
|
litellm.set_verbose = True
|
||||||
|
try:
|
||||||
|
print("Test basic client init")
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "azure-cloudflare",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": "https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-openai",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "text-embedding-ada-002",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "text-embedding-ada-002",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "mistral",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "mistral/mistral-tiny",
|
||||||
|
"api_key": os.getenv("MISTRAL_API_KEY"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "bge-base-en",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "xinference/bge-base-en",
|
||||||
|
"api_base": "http://127.0.0.1:9997/v1",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
router = Router(model_list=model_list)
|
||||||
|
for elem in router.model_list:
|
||||||
|
model_id = elem["model_info"]["id"]
|
||||||
|
assert router.cache.get_cache(f"{model_id}_client") is not None
|
||||||
|
assert router.cache.get_cache(f"{model_id}_async_client") is not None
|
||||||
|
assert router.cache.get_cache(f"{model_id}_stream_client") is not None
|
||||||
|
assert router.cache.get_cache(f"{model_id}_stream_async_client") is not None
|
||||||
|
print("PASSED !")
|
||||||
|
|
||||||
|
# see if we can init clients without timeout or max retries set
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# test_init_clients_basic_azure_cloudflare()
|
||||||
|
|
||||||
|
|
||||||
def test_timeouts_router():
|
def test_timeouts_router():
|
||||||
"""
|
"""
|
||||||
Test the timeouts of the router with multiple clients. This HASas to raise a timeout error
|
Test the timeouts of the router with multiple clients. This HASas to raise a timeout error
|
||||||
|
|
137
litellm/tests/test_router_policy_violation.py
Normal file
137
litellm/tests/test_router_policy_violation.py
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
#### What this tests ####
|
||||||
|
# This tests if the router sends back a policy violation, without retries
|
||||||
|
|
||||||
|
import sys, os, time
|
||||||
|
import traceback, asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import Router
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
|
||||||
|
|
||||||
|
class MyCustomHandler(CustomLogger):
|
||||||
|
success: bool = False
|
||||||
|
failure: bool = False
|
||||||
|
previous_models: int = 0
|
||||||
|
|
||||||
|
def log_pre_api_call(self, model, messages, kwargs):
|
||||||
|
print(f"Pre-API Call")
|
||||||
|
print(
|
||||||
|
f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
|
||||||
|
)
|
||||||
|
self.previous_models += len(
|
||||||
|
kwargs["litellm_params"]["metadata"]["previous_models"]
|
||||||
|
) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
|
||||||
|
print(f"self.previous_models: {self.previous_models}")
|
||||||
|
|
||||||
|
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
print(
|
||||||
|
f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
print(f"On Stream")
|
||||||
|
|
||||||
|
def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
print(f"On Stream")
|
||||||
|
|
||||||
|
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
print(f"On Success")
|
||||||
|
|
||||||
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
print(f"On Success")
|
||||||
|
|
||||||
|
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
print(f"On Failure")
|
||||||
|
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"model": "azure/gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "vorrei vedere la cosa più bella ad Ercolano. Qual’è?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_fallbacks():
|
||||||
|
litellm.set_verbose = False
|
||||||
|
model_list = [
|
||||||
|
{ # list of model deployments
|
||||||
|
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
"tpm": 240000,
|
||||||
|
"rpm": 1800,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "azure/gpt-3.5-turbo", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-functioncalling",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
"tpm": 240000,
|
||||||
|
"rpm": 1800,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
"tpm": 1000000,
|
||||||
|
"rpm": 9000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo-16k", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "gpt-3.5-turbo-16k",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
"tpm": 1000000,
|
||||||
|
"rpm": 9000,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
model_list=model_list,
|
||||||
|
num_retries=3,
|
||||||
|
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
|
||||||
|
# context_window_fallbacks=[
|
||||||
|
# {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]},
|
||||||
|
# {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]},
|
||||||
|
# ],
|
||||||
|
set_verbose=False,
|
||||||
|
)
|
||||||
|
customHandler = MyCustomHandler()
|
||||||
|
litellm.callbacks = [customHandler]
|
||||||
|
try:
|
||||||
|
response = await router.acompletion(**kwargs)
|
||||||
|
pytest.fail(
|
||||||
|
f"An exception occurred: {e}"
|
||||||
|
) # should've raised azure policy error
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
await asyncio.sleep(
|
||||||
|
0.05
|
||||||
|
) # allow a delay as success_callbacks are on a separate thread
|
||||||
|
assert customHandler.previous_models == 0 # 0 retries, 0 fallback
|
||||||
|
router.reset()
|
||||||
|
finally:
|
||||||
|
router.reset()
|
|
@ -306,6 +306,8 @@ def test_completion_ollama_hosted_stream():
|
||||||
model="ollama/phi",
|
model="ollama/phi",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
|
num_retries=3,
|
||||||
|
timeout=90,
|
||||||
api_base="https://test-ollama-endpoint.onrender.com",
|
api_base="https://test-ollama-endpoint.onrender.com",
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
|
|
||||||
import sys, re, binascii, struct
|
import sys, re, binascii, struct
|
||||||
import litellm
|
import litellm
|
||||||
import dotenv, json, traceback, threading, base64
|
import dotenv, json, traceback, threading, base64, ast
|
||||||
import subprocess, os
|
import subprocess, os
|
||||||
import litellm, openai
|
import litellm, openai
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -1975,7 +1975,10 @@ def client(original_function):
|
||||||
if (
|
if (
|
||||||
(kwargs.get("caching", None) is None and litellm.cache is not None)
|
(kwargs.get("caching", None) is None and litellm.cache is not None)
|
||||||
or kwargs.get("caching", False) == True
|
or kwargs.get("caching", False) == True
|
||||||
or kwargs.get("cache", {}).get("no-cache", False) != True
|
or (
|
||||||
|
kwargs.get("cache", None) is not None
|
||||||
|
and kwargs.get("cache", {}).get("no-cache", False) != True
|
||||||
|
)
|
||||||
): # allow users to control returning cached responses from the completion function
|
): # allow users to control returning cached responses from the completion function
|
||||||
# checking cache
|
# checking cache
|
||||||
print_verbose(f"INSIDE CHECKING CACHE")
|
print_verbose(f"INSIDE CHECKING CACHE")
|
||||||
|
@ -2737,6 +2740,8 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
|
||||||
completion_tokens_cost_usd_dollar = 0
|
completion_tokens_cost_usd_dollar = 0
|
||||||
model_cost_ref = litellm.model_cost
|
model_cost_ref = litellm.model_cost
|
||||||
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
|
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
|
||||||
|
print_verbose(f"Looking up model={model} in model_cost_map")
|
||||||
|
|
||||||
if model in model_cost_ref:
|
if model in model_cost_ref:
|
||||||
prompt_tokens_cost_usd_dollar = (
|
prompt_tokens_cost_usd_dollar = (
|
||||||
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
||||||
|
@ -2746,6 +2751,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
|
||||||
)
|
)
|
||||||
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||||
elif "ft:gpt-3.5-turbo" in model:
|
elif "ft:gpt-3.5-turbo" in model:
|
||||||
|
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
|
||||||
# fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
|
# fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
|
||||||
prompt_tokens_cost_usd_dollar = (
|
prompt_tokens_cost_usd_dollar = (
|
||||||
model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
|
model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
|
||||||
|
@ -2756,6 +2762,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
|
||||||
)
|
)
|
||||||
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||||
elif model in litellm.azure_llms:
|
elif model in litellm.azure_llms:
|
||||||
|
print_verbose(f"Cost Tracking: {model} is an Azure LLM")
|
||||||
model = litellm.azure_llms[model]
|
model = litellm.azure_llms[model]
|
||||||
prompt_tokens_cost_usd_dollar = (
|
prompt_tokens_cost_usd_dollar = (
|
||||||
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
||||||
|
@ -2764,19 +2771,29 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
|
||||||
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
|
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
|
||||||
)
|
)
|
||||||
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||||
else:
|
elif model in litellm.azure_embedding_models:
|
||||||
# calculate average input cost, azure/gpt-deployments can potentially go here if users don't specify, gpt-4, gpt-3.5-turbo. LLMs litellm knows
|
print_verbose(f"Cost Tracking: {model} is an Azure Embedding Model")
|
||||||
input_cost_sum = 0
|
model = litellm.azure_embedding_models[model]
|
||||||
output_cost_sum = 0
|
prompt_tokens_cost_usd_dollar = (
|
||||||
model_cost_ref = litellm.model_cost
|
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
||||||
for model in model_cost_ref:
|
)
|
||||||
input_cost_sum += model_cost_ref[model]["input_cost_per_token"]
|
completion_tokens_cost_usd_dollar = (
|
||||||
output_cost_sum += model_cost_ref[model]["output_cost_per_token"]
|
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
|
||||||
avg_input_cost = input_cost_sum / len(model_cost_ref.keys())
|
)
|
||||||
avg_output_cost = output_cost_sum / len(model_cost_ref.keys())
|
|
||||||
prompt_tokens_cost_usd_dollar = avg_input_cost * prompt_tokens
|
|
||||||
completion_tokens_cost_usd_dollar = avg_output_cost * completion_tokens
|
|
||||||
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||||
|
else:
|
||||||
|
# if model is not in model_prices_and_context_window.json. Raise an exception-let users know
|
||||||
|
error_str = f"Model not in model_prices_and_context_window.json. You passed model={model}\n"
|
||||||
|
raise litellm.exceptions.NotFoundError( # type: ignore
|
||||||
|
message=error_str,
|
||||||
|
model=model,
|
||||||
|
response=httpx.Response(
|
||||||
|
status_code=404,
|
||||||
|
content=error_str,
|
||||||
|
request=httpx.request(method="cost_per_token", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||||
|
),
|
||||||
|
llm_provider="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def completion_cost(
|
def completion_cost(
|
||||||
|
@ -2818,8 +2835,10 @@ def completion_cost(
|
||||||
completion_tokens = 0
|
completion_tokens = 0
|
||||||
if completion_response is not None:
|
if completion_response is not None:
|
||||||
# get input/output tokens from completion_response
|
# get input/output tokens from completion_response
|
||||||
prompt_tokens = completion_response["usage"]["prompt_tokens"]
|
prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0)
|
||||||
completion_tokens = completion_response["usage"]["completion_tokens"]
|
completion_tokens = completion_response.get("usage", {}).get(
|
||||||
|
"completion_tokens", 0
|
||||||
|
)
|
||||||
model = (
|
model = (
|
||||||
model or completion_response["model"]
|
model or completion_response["model"]
|
||||||
) # check if user passed an override for model, if it's none check completion_response['model']
|
) # check if user passed an override for model, if it's none check completion_response['model']
|
||||||
|
@ -2829,6 +2848,10 @@ def completion_cost(
|
||||||
elif len(prompt) > 0:
|
elif len(prompt) > 0:
|
||||||
prompt_tokens = token_counter(model=model, text=prompt)
|
prompt_tokens = token_counter(model=model, text=prompt)
|
||||||
completion_tokens = token_counter(model=model, text=completion)
|
completion_tokens = token_counter(model=model, text=completion)
|
||||||
|
if model == None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate cost based on prompt_tokens, completion_tokens
|
# Calculate cost based on prompt_tokens, completion_tokens
|
||||||
if "togethercomputer" in model or "together_ai" in model:
|
if "togethercomputer" in model or "together_ai" in model:
|
||||||
|
@ -2849,8 +2872,7 @@ def completion_cost(
|
||||||
)
|
)
|
||||||
return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
|
return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_verbose(f"LiteLLM: Excepton when cost calculating {str(e)}")
|
raise e
|
||||||
return 0.0 # this should not block a users execution path
|
|
||||||
|
|
||||||
|
|
||||||
####### HELPER FUNCTIONS ################
|
####### HELPER FUNCTIONS ################
|
||||||
|
@ -4081,11 +4103,11 @@ def get_llm_provider(
|
||||||
print() # noqa
|
print() # noqa
|
||||||
error_str = f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model={model}\n Pass model as E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers"
|
error_str = f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model={model}\n Pass model as E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers"
|
||||||
# maps to openai.NotFoundError, this is raised when openai does not recognize the llm
|
# maps to openai.NotFoundError, this is raised when openai does not recognize the llm
|
||||||
raise litellm.exceptions.NotFoundError( # type: ignore
|
raise litellm.exceptions.BadRequestError( # type: ignore
|
||||||
message=error_str,
|
message=error_str,
|
||||||
model=model,
|
model=model,
|
||||||
response=httpx.Response(
|
response=httpx.Response(
|
||||||
status_code=404,
|
status_code=400,
|
||||||
content=error_str,
|
content=error_str,
|
||||||
request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
|
request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||||
),
|
),
|
||||||
|
@ -4915,6 +4937,9 @@ async def convert_to_streaming_response_async(response_object: Optional[dict] =
|
||||||
if "id" in response_object:
|
if "id" in response_object:
|
||||||
model_response_object.id = response_object["id"]
|
model_response_object.id = response_object["id"]
|
||||||
|
|
||||||
|
if "created" in response_object:
|
||||||
|
model_response_object.created = response_object["created"]
|
||||||
|
|
||||||
if "system_fingerprint" in response_object:
|
if "system_fingerprint" in response_object:
|
||||||
model_response_object.system_fingerprint = response_object["system_fingerprint"]
|
model_response_object.system_fingerprint = response_object["system_fingerprint"]
|
||||||
|
|
||||||
|
@ -4959,6 +4984,9 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
|
||||||
if "id" in response_object:
|
if "id" in response_object:
|
||||||
model_response_object.id = response_object["id"]
|
model_response_object.id = response_object["id"]
|
||||||
|
|
||||||
|
if "created" in response_object:
|
||||||
|
model_response_object.created = response_object["created"]
|
||||||
|
|
||||||
if "system_fingerprint" in response_object:
|
if "system_fingerprint" in response_object:
|
||||||
model_response_object.system_fingerprint = response_object["system_fingerprint"]
|
model_response_object.system_fingerprint = response_object["system_fingerprint"]
|
||||||
|
|
||||||
|
@ -5014,6 +5042,9 @@ def convert_to_model_response_object(
|
||||||
model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
|
model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
|
||||||
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
|
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
|
||||||
|
|
||||||
|
if "created" in response_object:
|
||||||
|
model_response_object.created = response_object["created"]
|
||||||
|
|
||||||
if "id" in response_object:
|
if "id" in response_object:
|
||||||
model_response_object.id = response_object["id"]
|
model_response_object.id = response_object["id"]
|
||||||
|
|
||||||
|
@ -6621,7 +6652,7 @@ def _is_base64(s):
|
||||||
|
|
||||||
def get_secret(
|
def get_secret(
|
||||||
secret_name: str,
|
secret_name: str,
|
||||||
default_value: Optional[str] = None,
|
default_value: Optional[Union[str, bool]] = None,
|
||||||
):
|
):
|
||||||
key_management_system = litellm._key_management_system
|
key_management_system = litellm._key_management_system
|
||||||
if secret_name.startswith("os.environ/"):
|
if secret_name.startswith("os.environ/"):
|
||||||
|
@ -6672,9 +6703,24 @@ def get_secret(
|
||||||
secret = client.get_secret(secret_name).secret_value
|
secret = client.get_secret(secret_name).secret_value
|
||||||
except Exception as e: # check if it's in os.environ
|
except Exception as e: # check if it's in os.environ
|
||||||
secret = os.getenv(secret_name)
|
secret = os.getenv(secret_name)
|
||||||
return secret
|
try:
|
||||||
|
secret_value_as_bool = ast.literal_eval(secret)
|
||||||
|
if isinstance(secret_value_as_bool, bool):
|
||||||
|
return secret_value_as_bool
|
||||||
|
else:
|
||||||
|
return secret
|
||||||
|
except:
|
||||||
|
return secret
|
||||||
else:
|
else:
|
||||||
return os.environ.get(secret_name)
|
secret = os.environ.get(secret_name)
|
||||||
|
try:
|
||||||
|
secret_value_as_bool = ast.literal_eval(secret)
|
||||||
|
if isinstance(secret_value_as_bool, bool):
|
||||||
|
return secret_value_as_bool
|
||||||
|
else:
|
||||||
|
return secret
|
||||||
|
except:
|
||||||
|
return secret
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if default_value is not None:
|
if default_value is not None:
|
||||||
return default_value
|
return default_value
|
||||||
|
|
|
@ -111,6 +111,13 @@
|
||||||
"litellm_provider": "openai",
|
"litellm_provider": "openai",
|
||||||
"mode": "embedding"
|
"mode": "embedding"
|
||||||
},
|
},
|
||||||
|
"text-embedding-ada-002-v2": {
|
||||||
|
"max_tokens": 8191,
|
||||||
|
"input_cost_per_token": 0.0000001,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "openai",
|
||||||
|
"mode": "embedding"
|
||||||
|
},
|
||||||
"256-x-256/dall-e-2": {
|
"256-x-256/dall-e-2": {
|
||||||
"mode": "image_generation",
|
"mode": "image_generation",
|
||||||
"input_cost_per_pixel": 0.00000024414,
|
"input_cost_per_pixel": 0.00000024414,
|
||||||
|
@ -242,6 +249,13 @@
|
||||||
"litellm_provider": "azure",
|
"litellm_provider": "azure",
|
||||||
"mode": "chat"
|
"mode": "chat"
|
||||||
},
|
},
|
||||||
|
"azure/ada": {
|
||||||
|
"max_tokens": 8191,
|
||||||
|
"input_cost_per_token": 0.0000001,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "azure",
|
||||||
|
"mode": "embedding"
|
||||||
|
},
|
||||||
"azure/text-embedding-ada-002": {
|
"azure/text-embedding-ada-002": {
|
||||||
"max_tokens": 8191,
|
"max_tokens": 8191,
|
||||||
"input_cost_per_token": 0.0000001,
|
"input_cost_per_token": 0.0000001,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "litellm"
|
name = "litellm"
|
||||||
version = "1.16.13"
|
version = "1.16.14"
|
||||||
description = "Library to easily interface with LLM API providers"
|
description = "Library to easily interface with LLM API providers"
|
||||||
authors = ["BerriAI"]
|
authors = ["BerriAI"]
|
||||||
license = "MIT License"
|
license = "MIT License"
|
||||||
|
@ -59,7 +59,7 @@ requires = ["poetry-core", "wheel"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.commitizen]
|
[tool.commitizen]
|
||||||
version = "1.16.13"
|
version = "1.16.14"
|
||||||
version_files = [
|
version_files = [
|
||||||
"pyproject.toml:^version"
|
"pyproject.toml:^version"
|
||||||
]
|
]
|
||||||
|
|
28
retry_push.sh
Normal file
28
retry_push.sh
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
retry_count=0
|
||||||
|
max_retries=3
|
||||||
|
exit_code=1
|
||||||
|
|
||||||
|
until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
|
||||||
|
do
|
||||||
|
retry_count=$((retry_count+1))
|
||||||
|
echo "Attempt $retry_count..."
|
||||||
|
|
||||||
|
# Run the Prisma db push command
|
||||||
|
prisma db push --accept-data-loss
|
||||||
|
|
||||||
|
exit_code=$?
|
||||||
|
|
||||||
|
if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
|
||||||
|
echo "Retrying in 10 seconds..."
|
||||||
|
sleep 10
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ $exit_code -ne 0 ]; then
|
||||||
|
echo "Unable to push database changes after $max_retries retries."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Database push successful!"
|
33
schema.prisma
Normal file
33
schema.prisma
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
datasource client {
|
||||||
|
provider = "postgresql"
|
||||||
|
url = env("DATABASE_URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
generator client {
|
||||||
|
provider = "prisma-client-py"
|
||||||
|
}
|
||||||
|
|
||||||
|
model LiteLLM_UserTable {
|
||||||
|
user_id String @unique
|
||||||
|
max_budget Float?
|
||||||
|
spend Float @default(0.0)
|
||||||
|
user_email String?
|
||||||
|
}
|
||||||
|
|
||||||
|
// required for token gen
|
||||||
|
model LiteLLM_VerificationToken {
|
||||||
|
token String @unique
|
||||||
|
spend Float @default(0.0)
|
||||||
|
expires DateTime?
|
||||||
|
models String[]
|
||||||
|
aliases Json @default("{}")
|
||||||
|
config Json @default("{}")
|
||||||
|
user_id String?
|
||||||
|
max_parallel_requests Int?
|
||||||
|
metadata Json @default("{}")
|
||||||
|
}
|
||||||
|
|
||||||
|
model LiteLLM_Config {
|
||||||
|
param_name String @id
|
||||||
|
param_value Json?
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue