Merge branch 'main' into improve-langchain-integration

This commit is contained in:
Max Deichmann 2023-12-21 23:50:01 +01:00 committed by GitHub
commit 1c68f5557d
41 changed files with 1169 additions and 322 deletions

View file

@ -1,66 +0,0 @@
name: Build Docker Images
on:
workflow_dispatch:
inputs:
tag:
description: "The tag version you want to build"
jobs:
build:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
env:
REPO_NAME: ${{ github.repository }}
steps:
- name: Convert repo name to lowercase
run: echo "REPO_NAME=$(echo "$REPO_NAME" | tr '[:upper:]' '[:lower:]')" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GHCR_TOKEN }}
logout: false
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v5
with:
images: ghcr.io/berriai/litellm
- name: Get tag to build
id: tag
run: |
echo "latest=ghcr.io/${{ env.REPO_NAME }}:latest" >> $GITHUB_OUTPUT
if [[ -z "${{ github.event.inputs.tag }}" ]]; then
echo "versioned=ghcr.io/${{ env.REPO_NAME }}:${{ github.ref_name }}" >> $GITHUB_OUTPUT
else
echo "versioned=ghcr.io/${{ env.REPO_NAME }}:${{ github.event.inputs.tag }}" >> $GITHUB_OUTPUT
fi
- name: Debug Info
run: |
echo "GHCR_TOKEN=${{ secrets.GHCR_TOKEN }}"
echo "REPO_NAME=${{ env.REPO_NAME }}"
echo "ACTOR=${{ github.actor }}"
- name: Build and push container image to registry
uses: docker/build-push-action@v2
with:
push: true
tags: ghcr.io/${{ env.REPO_NAME }}:${{ github.sha }}
file: ./Dockerfile
- name: Build and release Docker images
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64
tags: |
${{ steps.tag.outputs.latest }}
${{ steps.tag.outputs.versioned }}
labels: ${{ steps.meta.outputs.labels }}
push: true

View file

@ -1,10 +1,12 @@
#
-name: Build & Publich to GHCR
+name: Build & Publish to GHCR
on:
workflow_dispatch:
inputs:
tag:
description: "The tag version you want to build"
+release:
+types: [published]
# Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds.
env:
@ -44,5 +46,5 @@ jobs:
with:
context: .
push: true
-tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag }} # Add the input tag to the image tags
+tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || github.event.release.tag_name }} # if a tag is provided, use that, otherwise use the release tag
labels: ${{ steps.meta.outputs.labels }}

View file

@ -1,13 +1,19 @@
# Exception Mapping
LiteLLM maps exceptions across all providers to their OpenAI counterparts.
-- Rate Limit Errors
-- Invalid Request Errors
-- Authentication Errors
-- Timeout Errors `openai.APITimeoutError`
-- ServiceUnavailableError
-- APIError
-- APIConnectionError
+| Status Code | Error Type |
+|-------------|--------------------------|
+| 400 | BadRequestError |
+| 401 | AuthenticationError |
+| 403 | PermissionDeniedError |
+| 404 | NotFoundError |
+| 422 | UnprocessableEntityError |
+| 429 | RateLimitError |
+| >=500 | InternalServerError |
+| N/A | ContextWindowExceededError |
+| N/A | APIConnectionError |
Base case we return APIConnectionError
@ -83,6 +89,7 @@ Base case - we return the original exception.
|---------------|----------------------------|---------------------|---------------------|---------------|-------------------------|
| Anthropic | ✅ | ✅ | ✅ | ✅ | |
| OpenAI | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Azure OpenAI | ✅ | ✅ | ✅ | ✅ | ✅ |
| Replicate | ✅ | ✅ | ✅ | ✅ | ✅ |
| Cohere | ✅ | ✅ | ✅ | ✅ | ✅ |
| Huggingface | ✅ | ✅ | ✅ | ✅ | |
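A minimal sketch of what the mapping above implies for callers, assuming (as the docs state) that provider errors are raised as their OpenAI counterpart exception types from the `openai` package:

```python
import litellm
import openai

try:
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
except openai.RateLimitError as e:
    # 429s from any provider surface as RateLimitError
    print("Rate limited, back off and retry:", e)
except openai.APIConnectionError as e:
    # the documented base case when nothing more specific matches
    print("Connection problem:", e)
```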

View file

@ -15,7 +15,7 @@ join our [discord](https://discord.gg/wuPM9dRgDw)
## Pre-Requisites
Ensure you have run `pip install langfuse` for this integration
```shell
-pip install langfuse litellm
+pip install langfuse==1.14.0 litellm
```
## Quick Start

View file

@ -41,6 +41,18 @@ def send_slack_alert(
# get it from https://api.slack.com/messaging/webhooks
slack_webhook_url = os.environ['SLACK_WEBHOOK_URL'] # "https://hooks.slack.com/services/<>/<>/<>"
# Remove api_key from kwargs under litellm_params
if kwargs.get('litellm_params'):
kwargs['litellm_params'].pop('api_key', None)
if kwargs['litellm_params'].get('metadata'):
kwargs['litellm_params']['metadata'].pop('deployment', None)
# Remove deployment under metadata
if kwargs.get('metadata'):
kwargs['metadata'].pop('deployment', None)
# Prevent api_key from being logged
if kwargs.get('api_key'):
kwargs.pop('api_key', None)
# Define the text payload, send data available in litellm custom_callbacks
text_payload = f"""LiteLLM Logging: kwargs: {str(kwargs)}\n\n, response: {str(completion_response)}\n\n, start time{str(start_time)} end time: {str(end_time)}
"""

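For reference, a minimal sketch of wiring the callback above into LiteLLM, assuming `send_slack_alert` is importable in your script and `SLACK_WEBHOOK_URL` is set:

```python
import litellm

# register the custom callback defined above; it fires on successful calls
litellm.success_callback = [send_slack_alert]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋"}],
)
```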
View file

@ -77,6 +77,45 @@ Ollama supported models: https://github.com/jmorganca/ollama
| Nous-Hermes 13B | `completion(model='ollama/nous-hermes:13b', messages, api_base="http://localhost:11434", stream=True)` |
| Wizard Vicuna Uncensored | `completion(model='ollama/wizard-vicuna', messages, api_base="http://localhost:11434", stream=True)` |
## Ollama Vision Models
| Model Name | Function Call |
|------------------|--------------------------------------|
| llava | `completion('ollama/llava', messages)` |
#### Using Ollama Vision Models
Call `ollama/llava` in the same input/output format as OpenAI [`gpt-4-vision`](https://docs.litellm.ai/docs/providers/openai#openai-vision-models)
LiteLLM Supports the following image types passed in `url`
- Base64 encoded svgs
**Example Request**
```python
import litellm
response = litellm.completion(
model = "ollama/llava",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a
2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"
}
}
]
}
],
)
print(response)
```
## LiteLLM/Ollama Docker Image

View file

@ -1,5 +1,5 @@
# OpenRouter
-LiteLLM supports all the text models from [OpenRouter](https://openrouter.ai/docs)
+LiteLLM supports all the text / chat / vision models from [OpenRouter](https://openrouter.ai/docs)
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_OpenRouter.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@ -20,10 +20,10 @@ response = completion(
)
```
-### OpenRouter Completion Models
+## OpenRouter Completion Models
-| Model Name | Function Call | Required OS Variables |
+| Model Name | Function Call |
-|---------------------------|-----------------------------------------------------|--------------------------------------------------------------|
+|---------------------------|-----------------------------------------------------|
| openrouter/openai/gpt-3.5-turbo | `completion('openrouter/openai/gpt-3.5-turbo', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
| openrouter/openai/gpt-3.5-turbo-16k | `completion('openrouter/openai/gpt-3.5-turbo-16k', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
| openrouter/openai/gpt-4 | `completion('openrouter/openai/gpt-4', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
@ -35,3 +35,19 @@ response = completion(
| openrouter/meta-llama/llama-2-13b-chat | `completion('openrouter/meta-llama/llama-2-13b-chat', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
| openrouter/meta-llama/llama-2-70b-chat | `completion('openrouter/meta-llama/llama-2-70b-chat', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
## Passing OpenRouter Params - transforms, models, route
Pass `transforms`, `models`, `route` as arguments to `litellm.completion()`
```python
import os
from litellm import completion
os.environ["OPENROUTER_API_KEY"] = ""
response = completion(
model="openrouter/google/palm-2-chat-bison",
messages=messages,
transforms = [""],
route= ""
)
```

View file

@ -60,3 +60,20 @@ $ litellm /path/to/config.yaml
```
curl --location 'http://0.0.0.0:8000/health'
```
## Embedding Models
The health check needs a way to know that a model is an embedding model. If you set `mode: embedding` under `model_info` in your config, as shown below, the health check will make an embedding request for that model.
```yaml
model_list:
- model_name: azure-embedding-model
litellm_params:
model: azure/azure-embedding-model
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
mode: embedding # 👈 ADD THIS
```
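Once the proxy is running with this config, the same `/health` endpoint shown above should exercise the embedding check. An illustrative Python equivalent of that curl command:

```python
import requests

# assumes the proxy from this page is running locally on port 8000
resp = requests.get("http://0.0.0.0:8000/health")
print(resp.json())
```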

View file

@ -461,7 +461,7 @@ We will use the `--config` to set `litellm.success_callback = ["langfuse"]` this
**Step 1** Install langfuse
```shell
-pip install langfuse
+pip install langfuse==1.14.0
```
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`

View file

@ -1,6 +1,17 @@
# Model Management
Add new models + Get model info without restarting proxy.
## In Config.yaml
```yaml
model_list:
- model_name: text-davinci-003
litellm_params:
model: "text-completion-openai/text-davinci-003"
model_info:
metadata: "here's additional metadata on the model" # returned via GET /model/info
```
## Get Model Information
Retrieve detailed information about each model listed in the `/models` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled from the model_info you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes.
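For illustration, a hypothetical call to the model info endpoint mentioned above, assuming a locally running proxy and a placeholder master key of `sk-1234`:

```python
import requests

# returns the model list plus the model_info/metadata set in config.yaml
resp = requests.get(
    "http://0.0.0.0:8000/model/info",
    headers={"Authorization": "Bearer sk-1234"},  # hypothetical master key
)
print(resp.json())
```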

View file

@ -349,6 +349,12 @@ litellm --config your_config.yaml
[**More Info**](./configs.md)
## Server Endpoints
:::note
You can see the Swagger docs for the server at the root, http://0.0.0.0:8000
:::
- POST `/chat/completions` - chat completions endpoint to call 100+ LLMs
- POST `/completions` - completions endpoint
- POST `/embeddings` - embedding endpoint for Azure, OpenAI, Huggingface endpoints

View file

@ -39,14 +39,17 @@ litellm --config /path/to/config.yaml
```shell
curl 'http://0.0.0.0:8000/key/generate' \
---h 'Authorization: Bearer sk-1234' \
---d '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m"}'
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data-raw '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m","metadata": {"user": "ishaan@berri.ai", "team": "core-infra"}}'
```
- `models`: *list or null (optional)* - Specify the models a token has access to. If null, then token has access to all models on server.
- `duration`: *str or null (optional)* Specify the length of time the token is valid for. If null, default is set to 1 hour. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
- `metadata`: *dict or null (optional)* Pass metadata for the created token. If null defaults to {}
Expected response:
```python

View file

@ -6,6 +6,7 @@ from typing import Callable, Optional
from litellm import OpenAIConfig from litellm import OpenAIConfig
import litellm, json import litellm, json
import httpx import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI from openai import AzureOpenAI, AsyncAzureOpenAI
class AzureOpenAIError(Exception): class AzureOpenAIError(Exception):
@ -261,6 +262,9 @@ class AzureChatCompletion(BaseLLM):
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
if hasattr(e, "status_code"):
raise e
else:
raise AzureOpenAIError(status_code=500, message=str(e)) raise AzureOpenAIError(status_code=500, message=str(e))
def streaming(self, def streaming(self,
@ -463,13 +467,52 @@ class AzureChatCompletion(BaseLLM):
import traceback import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc()) raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
async def aimage_generation(
self,
data: dict,
model_response: ModelResponse,
azure_client_params: dict,
api_key: str,
input: list,
client=None,
logging_obj=None
):
response = None
try:
if client is None:
client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
else:
openai_aclient = client
response = await openai_aclient.images.generate(**data)
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
def image_generation(self, def image_generation(self,
prompt: list, prompt: str,
timeout: float, timeout: float,
model: Optional[str]=None, model: Optional[str]=None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
api_version: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None, model_response: Optional[litellm.utils.ImageResponse] = None,
azure_ad_token: Optional[str]=None,
logging_obj=None, logging_obj=None,
optional_params=None, optional_params=None,
client=None, client=None,
@ -477,9 +520,12 @@ class AzureChatCompletion(BaseLLM):
): ):
exception_mapping_worked = False exception_mapping_worked = False
try: try:
if model and len(model) > 0:
model = model model = model
else:
model = None
data = { data = {
# "model": model, "model": model,
"prompt": prompt, "prompt": prompt,
**optional_params **optional_params
} }
@ -487,12 +533,26 @@ class AzureChatCompletion(BaseLLM):
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(status_code=422, message="max retries must be an int")
# if aembedding == True: # init AzureOpenAI Client
# response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore azure_client_params = {
# return response "api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation == True:
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
return response
if client is None: if client is None:
azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) # type: ignore client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
else: else:
azure_client = client azure_client = client
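For context, a hedged sketch of how the new Azure image-generation path might be exercised from the top-level API; the endpoint, key, and deployment name below are placeholders:

```python
import litellm

response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="azure/dall-e-2-deployment",                  # placeholder deployment name
    api_base="https://my-endpoint.openai.azure.com/",   # placeholder
    api_key="my-azure-api-key",                         # placeholder
    api_version="2023-06-01-preview",
)
print(response)
```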

View file

@ -0,0 +1,118 @@
import time, json, httpx, asyncio
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
"""
Async implementation of custom http transport
"""
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[
"api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview",
"2023-07-01-preview",
"2023-08-01-preview",
"2023-09-01-preview",
"2023-10-01-preview",
]:
request.url = request.url.copy_with(path="/openai/images/generations:submit")
response = await super().handle_async_request(request)
operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url)
request.method = "GET"
response = await super().handle_async_request(request)
await response.aread()
timeout_secs: int = 120
start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
return httpx.Response(
status_code=400,
headers=response.headers,
content=json.dumps(timeout).encode("utf-8"),
request=request,
)
time.sleep(int(response.headers.get("retry-after")) or 10)
response = await super().handle_async_request(request)
await response.aread()
if response.json()["status"] == "failed":
error_data = response.json()
return httpx.Response(
status_code=400,
headers=response.headers,
content=json.dumps(error_data).encode("utf-8"),
request=request,
)
result = response.json()["result"]
return httpx.Response(
status_code=200,
headers=response.headers,
content=json.dumps(result).encode("utf-8"),
request=request,
)
return await super().handle_async_request(request)
class CustomHTTPTransport(httpx.HTTPTransport):
"""
This class was written as a workaround to support dall-e-2 on openai > v1.x
Refer to this issue for more: https://github.com/openai/openai-python/issues/692
"""
def handle_request(
self,
request: httpx.Request,
) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[
"api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview",
"2023-07-01-preview",
"2023-08-01-preview",
"2023-09-01-preview",
"2023-10-01-preview",
]:
request.url = request.url.copy_with(path="/openai/images/generations:submit")
response = super().handle_request(request)
operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url)
request.method = "GET"
response = super().handle_request(request)
response.read()
timeout_secs: int = 120
start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
return httpx.Response(
status_code=400,
headers=response.headers,
content=json.dumps(timeout).encode("utf-8"),
request=request,
)
time.sleep(int(response.headers.get("retry-after")) or 10)
response = super().handle_request(request)
response.read()
if response.json()["status"] == "failed":
error_data = response.json()
return httpx.Response(
status_code=400,
headers=response.headers,
content=json.dumps(error_data).encode("utf-8"),
request=request,
)
result = response.json()["result"]
return httpx.Response(
status_code=200,
headers=response.headers,
content=json.dumps(result).encode("utf-8"),
request=request,
)
return super().handle_request(request)
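As a usage note, the Azure handler above plugs this transport into its `httpx` client; a minimal sketch of the same wiring in isolation, with endpoint, key, and model name as placeholders:

```python
import httpx
from openai import AzureOpenAI

client = AzureOpenAI(
    api_key="my-azure-api-key",                              # placeholder
    azure_endpoint="https://my-endpoint.openai.azure.com/",  # placeholder
    api_version="2023-06-01-preview",
    http_client=httpx.Client(transport=CustomHTTPTransport()),
)
# the transport intercepts the images/generations call and polls the submit endpoint
image = client.images.generate(model="dall-e-2", prompt="a sunset over the sea")
print(image.data[0].url)
```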

View file

@ -1,14 +1,10 @@
import requests, types, time import requests, types, time
import json import json, uuid
import traceback import traceback
from typing import Optional from typing import Optional
import litellm import litellm
import httpx, aiohttp, asyncio import httpx, aiohttp, asyncio
try: from .prompt_templates.factory import prompt_factory, custom_prompt
from async_generator import async_generator, yield_ # optional dependency
async_generator_imported = True
except ImportError:
async_generator_imported = False # this should not throw an error, it will impact the 'import litellm' statement
class OllamaError(Exception): class OllamaError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
@ -106,9 +102,8 @@ class OllamaConfig():
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None} and v is not None}
# ollama implementation # ollama implementation
def get_ollama_response_stream( def get_ollama_response(
api_base="http://localhost:11434", api_base="http://localhost:11434",
model="llama2", model="llama2",
prompt="Why is the sky blue?", prompt="Why is the sky blue?",
@ -129,6 +124,7 @@ def get_ollama_response_stream(
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
optional_params["stream"] = optional_params.get("stream", False)
data = { data = {
"model": model, "model": model,
"prompt": prompt, "prompt": prompt,
@ -146,9 +142,41 @@ def get_ollama_response_stream(
else: else:
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj) response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
return response return response
elif optional_params.get("stream", False):
else:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj) return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
response = requests.post(
url=f"{url}",
json=data,
)
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response.text,
additional_args={
"headers": None,
"api_base": api_base,
},
)
response_json = response.json()
## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
if optional_params.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
model_response["choices"][0]["message"] = message
else:
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
return model_response
def ollama_completion_stream(url, data, logging_obj): def ollama_completion_stream(url, data, logging_obj):
with httpx.stream( with httpx.stream(
@ -157,13 +185,15 @@ def ollama_completion_stream(url, data, logging_obj):
method="POST", method="POST",
timeout=litellm.request_timeout timeout=litellm.request_timeout
) as response: ) as response:
try:
if response.status_code != 200: if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text) raise OllamaError(status_code=response.status_code, message=response.text)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj) streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
for transformed_chunk in streamwrapper: for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e:
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj): async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
try: try:
@ -194,39 +224,31 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
text = await resp.text() text = await resp.text()
raise OllamaError(status_code=resp.status, message=text) raise OllamaError(status_code=resp.status, message=text)
completion_string = "" ## LOGGING
async for line in resp.content.iter_any(): logging_obj.post_call(
if line: input=data['prompt'],
try: api_key="",
json_chunk = line.decode("utf-8") original_response=resp.text,
chunks = json_chunk.split("\n") additional_args={
for chunk in chunks: "headers": None,
if chunk.strip() != "": "api_base": url,
j = json.loads(chunk) },
if "error" in j: )
completion_obj = {
"role": "assistant",
"content": "",
"error": j
}
raise Exception(f"OllamError - {chunk}")
if "response" in j:
completion_obj = {
"role": "assistant",
"content": j["response"],
}
completion_string = completion_string + completion_obj["content"]
except Exception as e:
traceback.print_exc()
response_json = await resp.json()
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response["choices"][0]["finish_reason"] = "stop"
model_response["choices"][0]["message"]["content"] = completion_string if data.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
model_response["choices"][0]["message"] = message
else:
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data['model'] model_response["model"] = "ollama/" + data['model']
prompt_tokens = len(encoding.encode(data['prompt'])) # type: ignore prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = len(encoding.encode(completion_string)) completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens) model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
return model_response return model_response
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise e
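To illustrate the new non-streaming and JSON-format paths, a hedged example call, assuming a local Ollama server and that `format` is forwarded to Ollama like other provider-specific parameters:

```python
import litellm

response = litellm.completion(
    model="ollama/llama2",
    messages=[{"role": "user", "content": "Describe Paris as a JSON object."}],
    api_base="http://localhost:11434",
    format="json",  # when set, the response text comes back as a tool call (see above)
)
print(response)
```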

View file

@ -445,6 +445,43 @@ class OpenAIChatCompletion(BaseLLM):
import traceback import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc()) raise OpenAIError(status_code=500, message=traceback.format_exc())
async def aimage_generation(
self,
prompt: str,
data: dict,
model_response: ModelResponse,
timeout: float,
api_key: Optional[str]=None,
api_base: Optional[str]=None,
client=None,
max_retries=None,
logging_obj=None
):
response = None
try:
if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
else:
openai_aclient = client
response = await openai_aclient.images.generate(**data) # type: ignore
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
)
raise e
def image_generation(self, def image_generation(self,
model: Optional[str], model: Optional[str],
prompt: str, prompt: str,
@ -631,6 +668,7 @@ class OpenAITextCompletion(BaseLLM):
api_key: str, api_key: str,
model: str): model: str):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try:
response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout) response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
@ -649,6 +687,8 @@ class OpenAITextCompletion(BaseLLM):
## RESPONSE OBJECT ## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response) return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
except Exception as e:
raise e
def streaming(self, def streaming(self,
logging_obj, logging_obj,
@ -687,9 +727,12 @@ class OpenAITextCompletion(BaseLLM):
method="POST", method="POST",
timeout=litellm.request_timeout timeout=litellm.request_timeout
) as response: ) as response:
try:
if response.status_code != 200: if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text) raise OpenAIError(status_code=response.status_code, message=response.text)
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e:
raise e

View file

@ -348,7 +348,7 @@ def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/r
# Function call template # Function call template
def function_call_prompt(messages: list, functions: list): def function_call_prompt(messages: list, functions: list):
function_prompt = "The following functions are available to you:" function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:"
for function in functions: for function in functions:
function_prompt += f"""\n{function}\n""" function_prompt += f"""\n{function}\n"""
@ -425,6 +425,6 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
return alpaca_pt(messages=messages) return alpaca_pt(messages=messages)
else: else:
return hf_chat_template(original_model_name, messages) return hf_chat_template(original_model_name, messages)
except: except Exception as e:
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2) return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)

View file

@ -181,6 +181,7 @@ def completion(
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
from vertexai.language_models import TextGenerationModel, CodeGenerationModel from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
vertexai.init( vertexai.init(
@ -193,6 +194,16 @@ def completion(
if k not in optional_params: if k not in optional_params:
optional_params[k] = v optional_params[k] = v
## Process safety settings into format expected by vertex AI
safety_settings = None
if "safety_settings" in optional_params:
safety_settings = optional_params.pop("safety_settings")
if not isinstance(safety_settings, list):
raise ValueError("safety_settings must be a list")
if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
raise ValueError("safety_settings must be a list of dicts")
safety_settings=[gapic_content_types.SafetySetting(x) for x in safety_settings]
# vertexai does not use an API key, it looks for credentials.json in the environment # vertexai does not use an API key, it looks for credentials.json in the environment
prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)]) prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)])
@ -238,16 +249,16 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), stream=stream) model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream)
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params})).text\n" request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params)) response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings)
completion_response = response_obj.text completion_response = response_obj.text
response_obj = response_obj._raw_response response_obj = response_obj._raw_response
elif mode == "vision": elif mode == "vision":
@ -258,12 +269,13 @@ def completion(
content = [prompt] + images content = [prompt] + images
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = llm_model.generate_content( model_response = llm_model.generate_content(
contents=content, contents=content,
generation_config=GenerationConfig(**optional_params), generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=True stream=True
) )
optional_params["stream"] = True optional_params["stream"] = True
@ -276,7 +288,8 @@ def completion(
## LLM Call ## LLM Call
response = llm_model.generate_content( response = llm_model.generate_content(
contents=content, contents=content,
generation_config=GenerationConfig(**optional_params) generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
) )
completion_response = response.text completion_response = response.text
response_obj = response._raw_response response_obj = response._raw_response
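A hedged sketch of what the new `safety_settings` plumbing implies for callers, assuming Vertex AI credentials and project are already configured and that the dicts use the field names `gapic_content_types.SafetySetting` expects:

```python
import litellm

response = litellm.completion(
    model="gemini-pro",  # assumes this routes to the vertex_ai provider in your setup
    messages=[{"role": "user", "content": "Hello"}],
    safety_settings=[
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
    ],
)
print(response)
```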

View file

@ -1329,23 +1329,11 @@ def completion(
optional_params["images"] = images optional_params["images"] = images
## LOGGING ## LOGGING
generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding) generator = ollama.get_ollama_response(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
if acompletion is True or optional_params.get("stream", False) == True: if acompletion is True or optional_params.get("stream", False) == True:
return generator return generator
else:
response_string = ""
for chunk in generator:
response_string+=chunk['content']
## RESPONSE OBJECT response = generator
model_response["choices"][0]["finish_reason"] = "stop"
model_response["choices"][0]["message"]["content"] = response_string
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = len(encoding.encode(prompt)) # type: ignore
completion_tokens = len(encoding.encode(response_string))
model_response["usage"] = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
response = model_response
elif ( elif (
custom_llm_provider == "baseten" custom_llm_provider == "baseten"
or litellm.api_base == "https://app.baseten.co" or litellm.api_base == "https://app.baseten.co"
@ -2026,11 +2014,9 @@ async def atext_completion(*args, **kwargs):
response = text_completion(*args, **kwargs) response = text_completion(*args, **kwargs)
else: else:
# Await normally # Await normally
init_response = await loop.run_in_executor(None, func_with_context) response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO if asyncio.iscoroutine(response):
response = init_response response = await response
elif asyncio.iscoroutine(init_response):
response = await init_response
else: else:
# Call the synchronous function using run_in_executor # Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context) response = await loop.run_in_executor(None, func_with_context)
@ -2205,7 +2191,8 @@ def text_completion(
if stream == True or kwargs.get("stream", False) == True: if stream == True or kwargs.get("stream", False) == True:
response = TextCompletionStreamWrapper(completion_stream=response, model=model) response = TextCompletionStreamWrapper(completion_stream=response, model=model)
return response return response
if kwargs.get("acompletion", False) == True:
return response
transformed_logprobs = None transformed_logprobs = None
# only supported for TGI models # only supported for TGI models
try: try:
@ -2243,6 +2230,49 @@ def moderation(input: str, api_key: Optional[str]=None):
return response return response
##### Image Generation ####################### ##### Image Generation #######################
@client
async def aimage_generation(*args, **kwargs):
"""
Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.
Parameters:
- `args` (tuple): Positional arguments to be passed to the `image_generation` function.
- `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function.
Returns:
- `response` (Any): The response returned by the `image_generation` function.
"""
loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"]
### PASS ARGS TO Image Generation ###
kwargs["aimg_generation"] = True
custom_llm_provider = None
try:
# Use a partial function to pass your keyword arguments
func = partial(image_generation, *args, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
)
@client @client
def image_generation(prompt: str, def image_generation(prompt: str,
model: Optional[str]=None, model: Optional[str]=None,
@ -2264,6 +2294,7 @@ def image_generation(prompt: str,
Currently supports just Azure + OpenAI. Currently supports just Azure + OpenAI.
""" """
aimg_generation = kwargs.get("aimg_generation", False)
litellm_call_id = kwargs.get("litellm_call_id", None) litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None) logger_fn = kwargs.get("logger_fn", None)
proxy_server_request = kwargs.get('proxy_server_request', None) proxy_server_request = kwargs.get('proxy_server_request', None)
@ -2277,7 +2308,7 @@ def image_generation(prompt: str,
model = "dall-e-2" model = "dall-e-2"
custom_llm_provider = "openai" # default to dall-e-2 on openai custom_llm_provider = "openai" # default to dall-e-2 on openai
openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"] openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"]
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"] litellm_params = ["metadata", "aimg_generation", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
optional_params = get_optional_params_image_gen(n=n, optional_params = get_optional_params_image_gen(n=n,
@ -2320,10 +2351,9 @@ def image_generation(prompt: str,
get_secret("AZURE_AD_TOKEN") get_secret("AZURE_AD_TOKEN")
) )
# model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response) model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimg_generation)
pass
elif custom_llm_provider == "openai": elif custom_llm_provider == "openai":
model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response) model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimg_generation)
return model_response return model_response
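For completeness, a minimal sketch of the new async entrypoint, assuming `OPENAI_API_KEY` is set and defaults are otherwise fine:

```python
import asyncio
import litellm

async def main():
    # async wrapper around image_generation, added in this change
    response = await litellm.aimage_generation(
        prompt="A cute baby sea otter",
        model="dall-e-3",
    )
    print(response)

asyncio.run(main())
```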

View file

@ -120,6 +120,8 @@ class GenerateKeyRequest(LiteLLMBase):
spend: Optional[float] = 0 spend: Optional[float] = 0
user_id: Optional[str] = None user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}
max_budget: Optional[float] = None
class UpdateKeyRequest(LiteLLMBase): class UpdateKeyRequest(LiteLLMBase):
key: str key: str
@ -130,21 +132,8 @@ class UpdateKeyRequest(LiteLLMBase):
spend: Optional[float] = None spend: Optional[float] = None
user_id: Optional[str] = None user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}
class GenerateKeyResponse(LiteLLMBase): max_budget: Optional[float] = None
key: str
expires: datetime
user_id: str
class _DeleteKeyObject(LiteLLMBase):
key: str
class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject]
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
""" """
@ -158,6 +147,20 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
user_id: Optional[str] = None user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
duration: str = "1h" duration: str = "1h"
metadata: dict = {}
max_budget: Optional[float] = None
class GenerateKeyResponse(LiteLLMBase):
key: str
expires: datetime
user_id: str
class _DeleteKeyObject(LiteLLMBase):
key: str
class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject]
class ConfigGeneralSettings(LiteLLMBase): class ConfigGeneralSettings(LiteLLMBase):
""" """

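A hypothetical `/key/generate` request exercising the new `metadata` and `max_budget` fields; the URL and master key are placeholders, mirroring the proxy docs earlier on this page:

```python
import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",
    headers={
        "Authorization": "Bearer sk-1234",   # placeholder master key
        "Content-Type": "application/json",
    },
    json={
        "models": ["gpt-3.5-turbo"],
        "duration": "20m",
        "metadata": {"user": "ishaan@berri.ai", "team": "core-infra"},
        "max_budget": 10.0,  # new per-key budget, checked by the MaxBudgetLimiter hook
    },
)
print(resp.json())
```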
View file

@ -96,7 +96,7 @@ async def _perform_health_check(model_list: list):
async def perform_health_check(model_list: list, model: Optional[str] = None): async def perform_health_check(model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None):
""" """
Perform a health check on the system. Perform a health check on the system.
@ -104,6 +104,9 @@ async def perform_health_check(model_list: list, model: Optional[str] = None):
(bool): True if the health check passes, False otherwise. (bool): True if the health check passes, False otherwise.
""" """
if not model_list: if not model_list:
if cli_model:
model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}]
else:
return [], [] return [], []
if model is not None: if model is not None:

View file

@ -0,0 +1,35 @@
from typing import Optional
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
class MaxBudgetLimiter(CustomLogger):
# Class variables or attributes
def __init__(self):
pass
def print_verbose(self, print_statement):
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_budget = user_api_key_dict.max_budget
curr_spend = user_api_key_dict.spend
if api_key is None:
return
if max_budget is None:
return
if curr_spend is None:
return
# CHECK IF REQUEST ALLOWED
if curr_spend >= max_budget:
raise HTTPException(status_code=429, detail="Max budget limit reached.")
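
A rough sketch of what this hook does at request time: once a key's tracked spend reaches its max_budget, the pre-call hook rejects the request with a 429 (the direct call and the field values below are illustrative; inside the proxy the hook is invoked through ProxyLogging.pre_call_hook):

# Sketch: MaxBudgetLimiter raises a 429 once spend >= max_budget for the key.
import asyncio
from fastapi import HTTPException
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter

async def demo():
    limiter = MaxBudgetLimiter()
    key = UserAPIKeyAuth(api_key="sk-1234", max_budget=10.0, spend=12.5)  # illustrative values
    try:
        await limiter.async_pre_call_hook(user_api_key_dict=key, cache=DualCache(), data={}, call_type="completion")
    except HTTPException as e:
        print(e.status_code, e.detail)  # 429 "Max budget limit reached."

asyncio.run(demo())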

View file

@ -47,7 +47,7 @@ litellm_settings:
# setting callback class # setting callback class
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance] # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
-general_settings:
+# general_settings:
 environment_variables:
# otel: True # OpenTelemetry Logger # otel: True # OpenTelemetry Logger

View file

@ -110,7 +110,7 @@ import json
import logging import logging
from typing import Union from typing import Union
app = FastAPI(docs_url="/", title="LiteLLM API") app = FastAPI(docs_url="/", title="LiteLLM API", description="Proxy Server to call 100+ LLMs in the OpenAI format")
router = APIRouter() router = APIRouter()
origins = ["*"] origins = ["*"]
@ -616,7 +616,16 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
router = litellm.Router(**router_params) # type:ignore router = litellm.Router(**router_params) # type:ignore
return router, model_list, general_settings return router, model_list, general_settings
-async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
+async def generate_key_helper_fn(duration: Optional[str],
+                                 models: list,
+                                 aliases: dict,
+                                 config: dict,
+                                 spend: float,
+                                 max_budget: Optional[float]=None,
+                                 token: Optional[str]=None,
+                                 user_id: Optional[str]=None,
+                                 max_parallel_requests: Optional[int]=None,
+                                 metadata: Optional[dict] = {},):
global prisma_client global prisma_client
if prisma_client is None: if prisma_client is None:
@ -653,6 +662,7 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
aliases_json = json.dumps(aliases) aliases_json = json.dumps(aliases)
config_json = json.dumps(config) config_json = json.dumps(config)
metadata_json = json.dumps(metadata)
user_id = user_id or str(uuid.uuid4()) user_id = user_id or str(uuid.uuid4())
try: try:
# Create a new verification token (you may want to enhance this logic based on your needs) # Create a new verification token (you may want to enhance this logic based on your needs)
@ -664,7 +674,9 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
"config": config_json, "config": config_json,
"spend": spend, "spend": spend,
"user_id": user_id, "user_id": user_id,
"max_parallel_requests": max_parallel_requests "max_parallel_requests": max_parallel_requests,
"metadata": metadata_json,
"max_budget": max_budget
} }
new_verification_token = await prisma_client.insert_data(data=verification_token_data) new_verification_token = await prisma_client.insert_data(data=verification_token_data)
except Exception as e: except Exception as e:
@ -774,6 +786,7 @@ def initialize(
print(f"\033[1;34mLiteLLM: Test your local proxy with: \"litellm --test\" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n") print(f"\033[1;34mLiteLLM: Test your local proxy with: \"litellm --test\" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n")
print(f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n") print(f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n")
print("\033[1;34mDocs: https://docs.litellm.ai/docs/simple_proxy\033[0m\n") print("\033[1;34mDocs: https://docs.litellm.ai/docs/simple_proxy\033[0m\n")
print(f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:8000 \033[0m\n")
# for streaming # for streaming
def data_generator(response): def data_generator(response):
print_verbose("inside generator") print_verbose("inside generator")
@ -871,9 +884,9 @@ def model_list():
object="list", object="list",
) )
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"])
@router.post("/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"])
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"])
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global user_temperature, user_request_timeout, user_max_tokens, user_api_base global user_temperature, user_request_timeout, user_max_tokens, user_api_base
try: try:
@ -1044,8 +1057,8 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
detail=error_msg detail=error_msg
) )
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse) @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["embeddings"])
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse) @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["embeddings"])
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global proxy_logging_obj global proxy_logging_obj
try: try:
@ -1124,6 +1137,72 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
detail=error_msg detail=error_msg
) )
@router.post("/v1/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"])
@router.post("/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"])
async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global proxy_logging_obj
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data) # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
data["model"] = (
general_settings.get("image_generation_model", None) # server default
or user_model # model name passed via cli args
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers)
else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
## ROUTE TO CORRECT ENDPOINT ##
if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.aimage_generation(**data)
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
response = await llm_router.aimage_generation(**data, specific_deployment = True)
elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
response = await llm_router.aimage_generation(**data) # ensure this goes the llm_router, router will do the correct alias mapping
else:
response = await litellm.aimage_generation(**data)
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
try:
status = e.status_code # type: ignore
except:
status = 500
raise HTTPException(
status_code=status,
detail=error_msg
)
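
With this route in place, the proxy speaks the OpenAI images API; a minimal client-side sketch (URL, port, and key are placeholders for a locally running proxy):

# Sketch: call the proxy's /v1/images/generations route with the OpenAI SDK.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")  # proxy virtual key + proxy URL
image = client.images.generate(model="dall-e-3", prompt="A cute baby sea otter", n=1, size="1024x1024")
print(image.data[0].url)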
#### KEY MANAGEMENT #### #### KEY MANAGEMENT ####
@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse) @router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
@ -1140,6 +1219,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
     - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
     - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
     - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
+    - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }

     Returns:
     - key: (str) The generated api key
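
A hedged example of generating a key that uses the new fields, against a locally running proxy (host, master key, and values are placeholders):

# Sketch: POST /key/generate with the new metadata and max_budget fields.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},  # proxy master key
    json={
        "models": ["dall-e-3"],
        "metadata": {"team": "core-infra", "app": "app2"},
        "max_budget": 10.0,
    },
)
print(resp.json()["key"])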
@ -1165,7 +1245,6 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest):
raise Exception("Not connected to DB!") raise Exception("Not connected to DB!")
non_default_values = {k: v for k, v in data_json.items() if v is not None} non_default_values = {k: v for k, v in data_json.items() if v is not None}
print(f"non_default_values: {non_default_values}")
response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key}) response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key})
return {"key": key, **non_default_values} return {"key": key, **non_default_values}
# update based on remaining passed in values # update based on remaining passed in values
@ -1514,9 +1593,18 @@ async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query
     ```
     else, the health checks will be run on models when /health is called.
     """
-    global health_check_results, use_background_health_checks
+    global health_check_results, use_background_health_checks, user_model
     if llm_model_list is None:
+        # if no router set, check if user set a model using litellm --model ollama/llama2
+        if user_model is not None:
+            healthy_endpoints, unhealthy_endpoints = await perform_health_check(model_list=[], cli_model=user_model)
+            return {
+                "healthy_endpoints": healthy_endpoints,
+                "unhealthy_endpoints": unhealthy_endpoints,
+                "healthy_count": len(healthy_endpoints),
+                "unhealthy_count": len(unhealthy_endpoints),
+            }
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail={"error": "Model list not initialized"},

View file

@ -17,4 +17,6 @@ model LiteLLM_VerificationToken {
   config Json @default("{}")
   user_id String?
   max_parallel_requests Int?
+  metadata Json @default("{}")
+  max_budget Float?
} }

View file

@ -1,9 +1,10 @@
 from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib, asyncio, copy
+import os, subprocess, hashlib, importlib, asyncio, copy, json
 import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
+from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
 from litellm.integrations.custom_logger import CustomLogger

 def print_verbose(print_statement):
     if litellm.set_verbose:
@ -23,11 +24,13 @@ class ProxyLogging:
         self.call_details: dict = {}
         self.call_details["user_api_key_cache"] = user_api_key_cache
         self.max_parallel_request_limiter = MaxParallelRequestsHandler()
+        self.max_budget_limiter = MaxBudgetLimiter()
         pass

     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
+        litellm.callbacks.append(self.max_budget_limiter)
         for callback in litellm.callbacks:
             if callback not in litellm.input_callback:
                 litellm.input_callback.append(callback)
@ -147,6 +150,14 @@ class PrismaClient:
         return hashed_token

+    def jsonify_object(self, data: dict) -> dict:
+        db_data = copy.deepcopy(data)
+        for k, v in db_data.items():
+            if isinstance(v, dict):
+                db_data[k] = json.dumps(v)
+        return db_data

     @backoff.on_exception(
         backoff.expo,
         Exception,  # base exception to catch for the backoff
@ -193,9 +204,8 @@ class PrismaClient:
         try:
             token = data["token"]
             hashed_token = self.hash_token(token=token)
-            db_data = copy.deepcopy(data)
+            db_data = self.jsonify_object(data=data)
             db_data["token"] = hashed_token
             new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
                 where={
                     'token': hashed_token,
@ -228,7 +238,7 @@ class PrismaClient:
             if token.startswith("sk-"):
                 token = self.hash_token(token=token)
-            db_data = copy.deepcopy(data)
+            db_data = self.jsonify_object(data=data)
             db_data["token"] = token
             response = await self.db.litellm_verificationtoken.update(
                 where={
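
The jsonify_object helper used above serializes any dict fields (such as the new metadata column) into JSON strings before they are written through Prisma, leaving every other value untouched; roughly:

# Sketch: given an initialized PrismaClient instance `prisma_client`
data = {"token": "sk-1234", "max_budget": 10.0, "metadata": {"team": "core-infra"}}
db_data = prisma_client.jsonify_object(data=data)
# db_data -> {"token": "sk-1234", "max_budget": 10.0, "metadata": '{"team": "core-infra"}'}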

View file

@ -7,7 +7,7 @@
 #
 # Thank you ! We ❤️ you! - Krrish & Ishaan

-import copy
+import copy, httpx
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal, Any
 import random, threading, time, traceback, uuid
@ -18,6 +18,7 @@ import inspect, concurrent
 from openai import AsyncOpenAI
 from collections import defaultdict
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 import copy

 class Router:
     """
@ -84,11 +85,11 @@ class Router:
         self.set_verbose = set_verbose
         self.deployment_names: List = []  # names of models under litellm_params. ex. azure/chatgpt-v-2
+        self.deployment_latency_map = {}
         if model_list:
             model_list = copy.deepcopy(model_list)
             self.set_model_list(model_list)
             self.healthy_deployments: List = self.model_list
-            self.deployment_latency_map = {}
             for m in model_list:
                 self.deployment_latency_map[m["litellm_params"]["model"]] = 0
@ -166,7 +167,7 @@ class Router:
self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n") self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n")
### COMPLETION + EMBEDDING FUNCTIONS ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS
def completion(self, def completion(self,
model: str, model: str,
@ -260,6 +261,94 @@ class Router:
self.fail_calls[model_name] +=1 self.fail_calls[model_name] +=1
raise e raise e
def image_generation(self,
prompt: str,
model: str,
**kwargs):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._image_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
def _image_generation(self,
prompt: str,
model: str,
**kwargs):
try:
self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] +=1
response = litellm.image_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[model_name] +=1
return response
except Exception as e:
if model_name is not None:
self.fail_calls[model_name] +=1
raise e
async def aimage_generation(self,
prompt: str,
model: str,
**kwargs):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._aimage_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _aimage_generation(self,
prompt: str,
model: str,
**kwargs):
try:
self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] +=1
response = await litellm.aimage_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[model_name] +=1
return response
except Exception as e:
if model_name is not None:
self.fail_calls[model_name] +=1
raise e
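
The new image-generation entry points mirror the existing completion helpers: the public methods attach metadata and fallbacks, and the underscored methods pick a deployment and call litellm. A minimal usage sketch (the deployment params mirror the router test added later in this commit):

# Sketch: route the "dall-e-3" model group across an OpenAI and an Azure deployment.
import os, asyncio
from litellm import Router

model_list = [
    {"model_name": "dall-e-3", "litellm_params": {"model": "dall-e-3"}},
    {"model_name": "dall-e-3", "litellm_params": {
        "model": "azure/dall-e-3-test",
        "api_version": "2023-12-01-preview",
        "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
        "api_key": os.getenv("AZURE_SWEDEN_API_KEY"),
    }},
]
router = Router(model_list=model_list)

sync_response = router.image_generation(model="dall-e-3", prompt="A cute baby sea otter")
async_response = asyncio.run(router.aimage_generation(model="dall-e-3", prompt="A cute baby sea otter"))
print(async_response.data[0])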
def text_completion(self, def text_completion(self,
model: str, model: str,
prompt: str, prompt: str,
@ -436,7 +525,6 @@ class Router:
async def async_function_with_retries(self, *args, **kwargs): async def async_function_with_retries(self, *args, **kwargs):
self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}") self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
backoff_factor = 1
original_function = kwargs.pop("original_function") original_function = kwargs.pop("original_function")
fallbacks = kwargs.pop("fallbacks", self.fallbacks) fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks) context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
@ -1009,14 +1097,16 @@ class Router:
                    azure_endpoint=api_base,
                    api_version=api_version,
                    timeout=timeout,
-                    max_retries=max_retries
+                    max_retries=max_retries,
+                    http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)  # type: ignore
                )
                model["client"] = openai.AzureOpenAI(
                    api_key=api_key,
                    azure_endpoint=api_base,
                    api_version=api_version,
                    timeout=timeout,
-                    max_retries=max_retries
+                    max_retries=max_retries,
+                    http_client=httpx.Client(transport=CustomHTTPTransport(),)  # type: ignore
                )
                # streaming clients should have diff timeouts
                model["stream_async_client"] = openai.AsyncAzureOpenAI(
@ -1024,7 +1114,7 @@ class Router:
                    azure_endpoint=api_base,
                    api_version=api_version,
                    timeout=stream_timeout,
-                    max_retries=max_retries
+                    max_retries=max_retries,
                )
                model["stream_client"] = openai.AzureOpenAI(

View file

@ -91,7 +91,7 @@ def test_vertex_ai():
load_vertex_ai_credentials() load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
litellm.set_verbose=False litellm.set_verbose=False
litellm.vertex_project = "hardy-device-386718" litellm.vertex_project = "reliablekeys"
test_models = random.sample(test_models, 1) test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro test_models += litellm.vertex_language_models # always test gemini-pro
@ -113,7 +113,7 @@ def test_vertex_ai():
def test_vertex_ai_stream(): def test_vertex_ai_stream():
load_vertex_ai_credentials() load_vertex_ai_credentials()
litellm.set_verbose=False litellm.set_verbose=False
litellm.vertex_project = "hardy-device-386718" litellm.vertex_project = "reliablekeys"
import random import random
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models

View file

@ -599,34 +599,34 @@ def test_completion_hf_model_no_provider():
# test_completion_hf_model_no_provider() # test_completion_hf_model_no_provider()
# def test_completion_openai_azure_with_functions(): def test_completion_anyscale_with_functions():
# function1 = [ function1 = [
# { {
# "name": "get_current_weather", "name": "get_current_weather",
# "description": "Get the current weather in a given location", "description": "Get the current weather in a given location",
# "parameters": { "parameters": {
# "type": "object", "type": "object",
# "properties": { "properties": {
# "location": { "location": {
# "type": "string", "type": "string",
# "description": "The city and state, e.g. San Francisco, CA", "description": "The city and state, e.g. San Francisco, CA",
# }, },
# "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
# }, },
# "required": ["location"], "required": ["location"],
# }, },
# } }
# ] ]
# try: try:
# messages = [{"role": "user", "content": "What is the weather like in Boston?"}] messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
# response = completion( response = completion(
# model="azure/chatgpt-functioncalling", messages=messages, functions=function1 model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, functions=function1
# ) )
# # Add any assertions here to check the response # Add any assertions here to check the response
# print(response) print(response)
# except Exception as e: except Exception as e:
# pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_openai_azure_with_functions() test_completion_anyscale_with_functions()
def test_completion_azure_key_completion_arg(): def test_completion_azure_key_completion_arg():
# this tests if we can pass api_key to completion, when it's not in the env # this tests if we can pass api_key to completion, when it's not in the env
@ -727,7 +727,7 @@ def test_completion_azure():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_azure() test_completion_azure()
def test_azure_openai_ad_token(): def test_azure_openai_ad_token():
# this tests if the azure ad token is set in the request header # this tests if the azure ad token is set in the request header

View file

@ -79,3 +79,18 @@ model_list:
- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)" - model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)"
litellm_params: litellm_params:
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16" model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
- model_name: dall-e-3
litellm_params:
model: dall-e-3
- model_name: dall-e-3
litellm_params:
model: "azure/dall-e-3-test"
api_version: "2023-12-01-preview"
api_base: "os.environ/AZURE_SWEDEN_API_BASE"
api_key: "os.environ/AZURE_SWEDEN_API_KEY"
- model_name: dall-e-2
litellm_params:
model: "azure/"
api_version: "2023-06-01-preview"
api_base: "os.environ/AZURE_API_BASE"
api_key: "os.environ/AZURE_API_KEY"
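
Each of these entries maps to a litellm.image_generation call with the same params; the Azure deployment above can also be exercised directly through the SDK, mirroring the test added in this commit (env vars as referenced in the config):

# Sketch: call the azure/dall-e-3-test deployment directly via litellm.
import os
import litellm

response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="azure/dall-e-3-test",
    api_version="2023-12-01-preview",
    api_base=os.getenv("AZURE_SWEDEN_API_BASE"),
    api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
)
print(response.data[0])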

View file

@ -17,6 +17,14 @@ def test_prompt_formatting():
assert prompt == "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]" assert prompt == "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
def test_prompt_formatting_custom_model():
try:
prompt = prompt_factory(model="ehartford/dolphin-2.5-mixtral-8x7b", messages=[{"role": "system", "content": "Be a good bot"}, {"role": "user", "content": "Hello world"}], custom_llm_provider="huggingface")
print(f"prompt: {prompt}")
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_prompt_formatting_custom_model()
# def logger_fn(user_model_dict): # def logger_fn(user_model_dict):
# return # return
# print(f"user_model_dict: {user_model_dict}") # print(f"user_model_dict: {user_model_dict}")

View file

@ -4,10 +4,11 @@
 import sys, os
 import traceback
 from dotenv import load_dotenv
+import logging
+logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
 import os
-import asyncio
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
@ -18,20 +19,31 @@ def test_image_generation_openai():
     litellm.set_verbose = True
     response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
     print(f"response: {response}")
+    assert len(response.data) > 0
 # test_image_generation_openai()

-# def test_image_generation_azure():
-#     response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure")
-#     print(f"response: {response}")
+def test_image_generation_azure():
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview")
+    print(f"response: {response}")
+    assert len(response.data) > 0
 # test_image_generation_azure()

-# @pytest.mark.asyncio
-# async def test_async_image_generation_openai():
-#     response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
-#     print(f"response: {response}")
+def test_image_generation_azure_dall_e_3():
+    litellm.set_verbose = True
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test", api_version="2023-12-01-preview", api_base=os.getenv("AZURE_SWEDEN_API_BASE"), api_key=os.getenv("AZURE_SWEDEN_API_KEY"))
+    print(f"response: {response}")
+    assert len(response.data) > 0
+# test_image_generation_azure_dall_e_3()
+
+@pytest.mark.asyncio
+async def test_async_image_generation_openai():
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
+    print(f"response: {response}")
+    assert len(response.data) > 0

-# @pytest.mark.asyncio
-# async def test_async_image_generation_azure():
-#     response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3")
-#     print(f"response: {response}")
+# asyncio.run(test_async_image_generation_openai())
+
+@pytest.mark.asyncio
+async def test_async_image_generation_azure():
+    response = await litellm.aimage_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test")
+    print(f"response: {response}")

View file

@ -16,23 +16,61 @@
# user_message = "respond in 20 words. who are you?" # user_message = "respond in 20 words. who are you?"
# messages = [{ "content": user_message,"role": "user"}] # messages = [{ "content": user_message,"role": "user"}]
# def test_ollama_streaming():
# try:
# litellm.set_verbose = False
# messages = [
# {"role": "user", "content": "What is the weather like in Boston?"}
# ]
# functions = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": ["location"]
# }
# }
# ]
# response = litellm.completion(model="ollama/mistral",
# messages=messages,
# functions=functions,
# stream=True)
# for chunk in response:
# print(f"CHUNK: {chunk}")
# except Exception as e:
# print(e)
# test_ollama_streaming()
# async def test_async_ollama_streaming(): # async def test_async_ollama_streaming():
# try: # try:
# litellm.set_verbose = True # litellm.set_verbose = False
# response = await litellm.acompletion(model="ollama/mistral-openorca", # response = await litellm.acompletion(model="ollama/mistral-openorca",
# messages=[{"role": "user", "content": "Hey, how's it going?"}], # messages=[{"role": "user", "content": "Hey, how's it going?"}],
# stream=True) # stream=True)
# async for chunk in response: # async for chunk in response:
# print(chunk) # print(f"CHUNK: {chunk}")
# except Exception as e: # except Exception as e:
# print(e) # print(e)
# asyncio.run(test_async_ollama_streaming()) # # asyncio.run(test_async_ollama_streaming())
# def test_completion_ollama(): # def test_completion_ollama():
# try: # try:
# litellm.set_verbose = True
# response = completion( # response = completion(
# model="ollama/llama2", # model="ollama/mistral",
# messages=[{"role": "user", "content": "Hey, how's it going?"}], # messages=[{"role": "user", "content": "Hey, how's it going?"}],
# max_tokens=200, # max_tokens=200,
# request_timeout = 10, # request_timeout = 10,
@ -44,7 +82,87 @@
# except Exception as e: # except Exception as e:
# pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")
# test_completion_ollama() # # test_completion_ollama()
# def test_completion_ollama_function_calling():
# try:
# litellm.set_verbose = True
# messages = [
# {"role": "user", "content": "What is the weather like in Boston?"}
# ]
# functions = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": ["location"]
# }
# }
# ]
# response = completion(
# model="ollama/mistral",
# messages=messages,
# functions=functions,
# max_tokens=200,
# request_timeout = 10,
# )
# for chunk in response:
# print(chunk)
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama_function_calling()
# async def async_test_completion_ollama_function_calling():
# try:
# litellm.set_verbose = True
# messages = [
# {"role": "user", "content": "What is the weather like in Boston?"}
# ]
# functions = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": ["location"]
# }
# }
# ]
# response = await litellm.acompletion(
# model="ollama/mistral",
# messages=messages,
# functions=functions,
# max_tokens=200,
# request_timeout = 10,
# )
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # asyncio.run(async_test_completion_ollama_function_calling())
# def test_completion_ollama_with_api_base(): # def test_completion_ollama_with_api_base():
# try: # try:
@ -197,7 +315,7 @@
# ) # )
# print("Response from ollama/llava") # print("Response from ollama/llava")
# print(response) # print(response)
# test_ollama_llava() # # test_ollama_llava()
# # PROCESSED CHUNK PRE CHUNK CREATOR # # PROCESSED CHUNK PRE CHUNK CREATOR

View file

@ -99,8 +99,6 @@ def test_embedding(client):
def test_chat_completion(client): def test_chat_completion(client):
try: try:
# Your test data # Your test data
print("initialized proxy")
litellm.set_verbose=False litellm.set_verbose=False
from litellm.proxy.utils import get_instance_fn from litellm.proxy.utils import get_instance_fn
my_custom_logger = get_instance_fn( my_custom_logger = get_instance_fn(

View file

@ -68,6 +68,7 @@ def test_chat_completion_exception_azure(client):
# make an openai client to call _make_status_error_from_response # make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything") openai_client = openai.OpenAI(api_key="anything")
openai_exception = openai_client._make_status_error_from_response(response=response) openai_exception = openai_client._make_status_error_from_response(response=response)
print(openai_exception)
assert isinstance(openai_exception, openai.AuthenticationError) assert isinstance(openai_exception, openai.AuthenticationError)
except Exception as e: except Exception as e:

View file

@ -101,7 +101,7 @@ def test_chat_completion_azure(client_no_auth):
# Run the test # Run the test
# test_chat_completion_azure() # test_chat_completion_azure()
### EMBEDDING
def test_embedding(client_no_auth): def test_embedding(client_no_auth):
global headers global headers
from litellm.proxy.proxy_server import user_custom_auth from litellm.proxy.proxy_server import user_custom_auth
@ -161,7 +161,30 @@ def test_sagemaker_embedding(client_no_auth):
# Run the test # Run the test
# test_embedding() # test_embedding()
#### IMAGE GENERATION
def test_img_gen(client_no_auth):
global headers
from litellm.proxy.proxy_server import user_custom_auth
try:
test_data = {
"model": "dall-e-3",
"prompt": "A cute baby sea otter",
"n": 1,
"size": "1024x1024"
}
response = client_no_auth.post("/v1/images/generations", json=test_data)
assert response.status_code == 200
result = response.json()
print(len(result["data"][0]["url"]))
assert len(result["data"][0]["url"]) > 10
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
#### ADDITIONAL
# @pytest.mark.skip(reason="hitting yaml load issues on circle-ci") # @pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
def test_add_new_model(client_no_auth): def test_add_new_model(client_no_auth):
global headers global headers

View file

@ -423,6 +423,94 @@ def test_function_calling_on_router():
# test_function_calling_on_router() # test_function_calling_on_router()
### IMAGE GENERATION
@pytest.mark.asyncio
async def test_aimg_gen_on_router():
litellm.set_verbose = True
try:
model_list = [
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "dall-e-3",
},
},
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "azure/dall-e-3-test",
"api_version": "2023-12-01-preview",
"api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
"api_key": os.getenv("AZURE_SWEDEN_API_KEY")
}
},
{
"model_name": "dall-e-2",
"litellm_params": {
"model": "azure/",
"api_version": "2023-06-01-preview",
"api_base": os.getenv("AZURE_API_BASE"),
"api_key": os.getenv("AZURE_API_KEY")
}
}
]
router = Router(model_list=model_list)
response = await router.aimage_generation(
model="dall-e-3",
prompt="A cute baby sea otter"
)
print(response)
assert len(response.data) > 0
response = await router.aimage_generation(
model="dall-e-2",
prompt="A cute baby sea otter"
)
print(response)
assert len(response.data) > 0
router.reset()
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
# asyncio.run(test_aimg_gen_on_router())
def test_img_gen_on_router():
litellm.set_verbose = True
try:
model_list = [
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "dall-e-3",
},
},
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "azure/dall-e-3-test",
"api_version": "2023-12-01-preview",
"api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
"api_key": os.getenv("AZURE_SWEDEN_API_KEY")
}
}
]
router = Router(model_list=model_list)
response = router.image_generation(
model="dall-e-3",
prompt="A cute baby sea otter"
)
print(response)
assert len(response.data) > 0
router.reset()
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
# test_img_gen_on_router()
###
def test_aembedding_on_router(): def test_aembedding_on_router():
litellm.set_verbose = True litellm.set_verbose = True
try: try:
@ -556,7 +644,7 @@ async def test_mistral_on_router():
] ]
) )
print(response) print(response)
asyncio.run(test_mistral_on_router()) # asyncio.run(test_mistral_on_router())
def test_openai_completion_on_router(): def test_openai_completion_on_router():
# [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream

View file

@ -169,17 +169,37 @@ def test_text_completion_stream():
# test_text_completion_stream() # test_text_completion_stream()
async def test_text_completion_async_stream(): # async def test_text_completion_async_stream():
try: # try:
response = await atext_completion( # response = await atext_completion(
model="text-completion-openai/text-davinci-003", # model="text-completion-openai/text-davinci-003",
prompt="good morning", # prompt="good morning",
stream=True, # stream=True,
max_tokens=10, # max_tokens=10,
) # )
async for chunk in response: # async for chunk in response:
print(f"chunk: {chunk}") # print(f"chunk: {chunk}")
except Exception as e: # except Exception as e:
pytest.fail(f"GOT exception for HF In streaming{e}") # pytest.fail(f"GOT exception for HF In streaming{e}")
asyncio.run(test_text_completion_async_stream()) # asyncio.run(test_text_completion_async_stream())
def test_async_text_completion():
litellm.set_verbose = True
print('test_async_text_completion')
async def test_get_response():
try:
response = await litellm.atext_completion(
model="gpt-3.5-turbo-instruct",
prompt="good morning",
stream=False,
max_tokens=10
)
print(f"response: {response}")
except litellm.Timeout as e:
print(e)
except Exception as e:
print(e)
asyncio.run(test_get_response())
test_async_text_completion()

View file

@ -1,13 +1,13 @@
{ {
"type": "service_account", "type": "service_account",
"project_id": "hardy-device-386718", "project_id": "reliablekeys",
"private_key_id": "", "private_key_id": "",
"private_key": "", "private_key": "",
"client_email": "litellm-vertexai-ci-cd@hardy-device-386718.iam.gserviceaccount.com", "client_email": "73470430121-compute@developer.gserviceaccount.com",
"client_id": "110281020501213430254", "client_id": "108560959659377334173",
"auth_uri": "https://accounts.google.com/o/oauth2/auth", "auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token", "token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/litellm-vertexai-ci-cd%40hardy-device-386718.iam.gserviceaccount.com", "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/73470430121-compute%40developer.gserviceaccount.com",
"universe_domain": "googleapis.com" "universe_domain": "googleapis.com"
} }

View file

@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.15.1"
+version = "1.15.4"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.15.1"
+version = "1.15.4"
 version_files = [
     "pyproject.toml:^version"
 ]

View file

@ -3,24 +3,24 @@ anyio==4.2.0 # openai + http req.
 openai>=1.0.0 # openai req.
 fastapi # server dep
 pydantic>=2.5 # openai req.
-appdirs # server dep
-backoff # server dep
-pyyaml # server dep
-uvicorn # server dep
-boto3 # aws bedrock/sagemaker calls
-redis # caching
-prisma # for db
-mangum # for aws lambda functions
-google-generativeai # for vertex ai calls
+appdirs==1.4.4 # server dep
+backoff==2.2.1 # server dep
+pyyaml==6.0 # server dep
+uvicorn==0.22.0 # server dep
+boto3==1.28.58 # aws bedrock/sagemaker calls
+redis==4.6.0 # caching
+prisma==0.11.0 # for db
+mangum==0.17.0 # for aws lambda functions
+google-generativeai==0.1.0 # for vertex ai calls
 traceloop-sdk==0.5.3 # for open telemetry logging
 langfuse==1.14.0 # for langfuse self-hosted logging

 ### LITELLM PACKAGE DEPENDENCIES
 python-dotenv>=0.2.0 # for env
 tiktoken>=0.4.0 # for calculating usage
 importlib-metadata>=6.8.0 # for random utils
-tokenizers # for calculating usage
-click # for proxy cli
+tokenizers==0.14.0 # for calculating usage
+click==8.1.7 # for proxy cli
 jinja2==3.1.2 # for prompt templates
 certifi>=2023.7.22 # [TODO] clean up
-aiohttp # for network calls
+aiohttp==3.8.4 # for network calls
 ####