Merge branch 'main' into improve-langchain-integration

Max Deichmann 2023-12-21 23:50:01 +01:00 committed by GitHub
commit 1c68f5557d
41 changed files with 1169 additions and 322 deletions


@ -1,66 +0,0 @@
name: Build Docker Images
on:
workflow_dispatch:
inputs:
tag:
description: "The tag version you want to build"
jobs:
build:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
env:
REPO_NAME: ${{ github.repository }}
steps:
- name: Convert repo name to lowercase
run: echo "REPO_NAME=$(echo "$REPO_NAME" | tr '[:upper:]' '[:lower:]')" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GHCR_TOKEN }}
logout: false
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v5
with:
images: ghcr.io/berriai/litellm
- name: Get tag to build
id: tag
run: |
echo "latest=ghcr.io/${{ env.REPO_NAME }}:latest" >> $GITHUB_OUTPUT
if [[ -z "${{ github.event.inputs.tag }}" ]]; then
echo "versioned=ghcr.io/${{ env.REPO_NAME }}:${{ github.ref_name }}" >> $GITHUB_OUTPUT
else
echo "versioned=ghcr.io/${{ env.REPO_NAME }}:${{ github.event.inputs.tag }}" >> $GITHUB_OUTPUT
fi
- name: Debug Info
run: |
echo "GHCR_TOKEN=${{ secrets.GHCR_TOKEN }}"
echo "REPO_NAME=${{ env.REPO_NAME }}"
echo "ACTOR=${{ github.actor }}"
- name: Build and push container image to registry
uses: docker/build-push-action@v2
with:
push: true
tags: ghcr.io/${{ env.REPO_NAME }}:${{ github.sha }}
file: ./Dockerfile
- name: Build and release Docker images
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64
tags: |
${{ steps.tag.outputs.latest }}
${{ steps.tag.outputs.versioned }}
labels: ${{ steps.meta.outputs.labels }}
push: true


@ -1,10 +1,12 @@
#
name: Build & Publich to GHCR
name: Build & Publish to GHCR
on:
workflow_dispatch:
inputs:
tag:
description: "The tag version you want to build"
release:
types: [published]
# Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds.
env:
@ -19,7 +21,7 @@ jobs:
permissions:
contents: read
packages: write
#
steps:
- name: Checkout repository
uses: actions/checkout@v4
@ -44,5 +46,5 @@ jobs:
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag }} # Add the input tag to the image tags
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || github.event.release.tag_name }} # if a tag is provided, use that, otherwise use the release tag
labels: ${{ steps.meta.outputs.labels }}


@ -1,13 +1,19 @@
# Exception Mapping
LiteLLM maps exceptions across all providers to their OpenAI counterparts.
- Rate Limit Errors
- Invalid Request Errors
- Authentication Errors
- Timeout Errors `openai.APITimeoutError`
- ServiceUnavailableError
- APIError
- APIConnectionError
| Status Code | Error Type |
|-------------|--------------------------|
| 400 | BadRequestError |
| 401 | AuthenticationError |
| 403 | PermissionDeniedError |
| 404 | NotFoundError |
| 422 | UnprocessableEntityError |
| 429 | RateLimitError |
| >=500 | InternalServerError |
| N/A         | ContextWindowExceededError |
| N/A | APIConnectionError |
Base case: we return `APIConnectionError`
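For illustration, a minimal sketch of catching the mapped errors in application code (this assumes the classes in the table above are importable from `litellm.exceptions`):
```python
import litellm
from litellm import completion

messages = [{"role": "user", "content": "Hey, how's it going?"}]
try:
    response = completion(model="gpt-3.5-turbo", messages=messages)
except litellm.exceptions.RateLimitError as e:              # 429
    print(f"Rate limited, back off and retry: {e}")
except litellm.exceptions.ContextWindowExceededError as e:  # prompt too long for the model
    print(f"Context window exceeded: {e}")
except litellm.exceptions.APIConnectionError as e:          # base case / unmapped errors
    print(f"Connection-level error: {e}")
```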
@ -83,6 +89,7 @@ Base case - we return the original exception.
|---------------|----------------------------|---------------------|---------------------|---------------|-------------------------|
| Anthropic | ✅ | ✅ | ✅ | ✅ | |
| OpenAI | ✅ | ✅ |✅ |✅ |✅|
| Azure OpenAI | ✅ | ✅ |✅ |✅ |✅|
| Replicate | ✅ | ✅ | ✅ | ✅ | ✅ |
| Cohere | ✅ | ✅ | ✅ | ✅ | ✅ |
| Huggingface | ✅ | ✅ | ✅ | ✅ | |


@ -15,7 +15,7 @@ join our [discord](https://discord.gg/wuPM9dRgDw)
## Pre-Requisites
Ensure you have run `pip install langfuse` for this integration
```shell
pip install langfuse litellm
pip install langfuse==1.14.0 litellm
```
## Quick Start


@ -41,6 +41,18 @@ def send_slack_alert(
# get it from https://api.slack.com/messaging/webhooks
slack_webhook_url = os.environ['SLACK_WEBHOOK_URL'] # "https://hooks.slack.com/services/<>/<>/<>"
# Remove api_key from kwargs under litellm_params
if kwargs.get('litellm_params'):
kwargs['litellm_params'].pop('api_key', None)
if kwargs['litellm_params'].get('metadata'):
kwargs['litellm_params']['metadata'].pop('deployment', None)
# Remove deployment under metadata
if kwargs.get('metadata'):
kwargs['metadata'].pop('deployment', None)
# Prevent api_key from being logged
if kwargs.get('api_key'):
kwargs.pop('api_key', None)
# Define the text payload, send data available in litellm custom_callbacks
text_payload = f"""LiteLLM Logging: kwargs: {str(kwargs)}\n\n, response: {str(completion_response)}\n\n, start time{str(start_time)} end time: {str(end_time)}
"""
@ -90,4 +102,4 @@ response = litellm.completion(
- [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version)
- [Community Discord 💭](https://discord.gg/wuPM9dRgDw)
- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai


@ -77,6 +77,45 @@ Ollama supported models: https://github.com/jmorganca/ollama
| Nous-Hermes 13B | `completion(model='ollama/nous-hermes:13b', messages, api_base="http://localhost:11434", stream=True)` |
| Wizard Vicuna Uncensored | `completion(model='ollama/wizard-vicuna', messages, api_base="http://localhost:11434", stream=True)` |
## Ollama Vision Models
| Model Name | Function Call |
|------------------|--------------------------------------|
| llava | `completion('ollama/llava', messages)` |
#### Using Ollama Vision Models
Call `ollama/llava` in the same input/output format as OpenAI [`gpt-4-vision`](https://docs.litellm.ai/docs/providers/openai#openai-vision-models)
LiteLLM Supports the following image types passed in `url`
- Base64 encoded images
**Example Request**
```python
import litellm
response = litellm.completion(
model = "ollama/llava",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a
2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"
}
}
]
}
],
)
print(response)
```
## LiteLLM/Ollama Docker Image


@ -1,5 +1,5 @@
# OpenRouter
LiteLLM supports all the text models from [OpenRouter](https://openrouter.ai/docs)
LiteLLM supports all the text / chat / vision models from [OpenRouter](https://openrouter.ai/docs)
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_OpenRouter.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@ -20,10 +20,10 @@ response = completion(
)
```
### OpenRouter Completion Models
## OpenRouter Completion Models
| Model Name | Function Call | Required OS Variables |
|---------------------------|-----------------------------------------------------|--------------------------------------------------------------|
| Model Name | Function Call |
|---------------------------|-----------------------------------------------------|
| openrouter/openai/gpt-3.5-turbo | `completion('openrouter/openai/gpt-3.5-turbo', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
| openrouter/openai/gpt-3.5-turbo-16k | `completion('openrouter/openai/gpt-3.5-turbo-16k', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
| openrouter/openai/gpt-4 | `completion('openrouter/openai/gpt-4', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
@ -35,3 +35,19 @@ response = completion(
| openrouter/meta-llama/llama-2-13b-chat | `completion('openrouter/meta-llama/llama-2-13b-chat', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
| openrouter/meta-llama/llama-2-70b-chat | `completion('openrouter/meta-llama/llama-2-70b-chat', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
## Passing OpenRouter Params - transforms, models, route
Pass `transforms`, `models`, and `route` as arguments to `litellm.completion()`
```python
import os
from litellm import completion
os.environ["OPENROUTER_API_KEY"] = ""
response = completion(
model="openrouter/google/palm-2-chat-bison",
messages=messages,
transforms = [""],
route= ""
)
```


@ -59,4 +59,21 @@ $ litellm /path/to/config.yaml
3. Query health endpoint:
```
curl --location 'http://0.0.0.0:8000/health'
```
## Embedding Models
We need a way to know whether a model is an embedding model when running health checks. If you set `mode: embedding` under `model_info` in your config, the health check for that model is run as an embedding health check.
```yaml
model_list:
- model_name: azure-embedding-model
litellm_params:
model: azure/azure-embedding-model
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
mode: embedding # 👈 ADD THIS
```
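A minimal sketch of polling the health endpoint from Python, assuming the proxy is running locally on port 8000 (the `healthy_*` / `unhealthy_*` keys mirror the response built in `proxy_server.py` later in this diff):
```python
import requests

resp = requests.get("http://0.0.0.0:8000/health")  # add an Authorization header if your proxy requires a key
resp.raise_for_status()
report = resp.json()
print(f"healthy: {report['healthy_count']}, unhealthy: {report['unhealthy_count']}")
print(report["unhealthy_endpoints"])
```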


@ -461,7 +461,7 @@ We will use the `--config` to set `litellm.success_callback = ["langfuse"]` this
**Step 1** Install langfuse
```shell
pip install langfuse
pip install langfuse==1.14.0
```
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`


@ -1,6 +1,17 @@
# Model Management
Add new models + Get model info without restarting proxy.
## In Config.yaml
```yaml
model_list:
- model_name: text-davinci-003
litellm_params:
model: "text-completion-openai/text-davinci-003"
model_info:
metadata: "here's additional metadata on the model" # returned via GET /model/info
```
## Get Model Information
Retrieve detailed information about each model listed in the `/models` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled from the `model_info` you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes.
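A minimal sketch of querying that information over HTTP, assuming the proxy is running locally and `sk-1234` is a placeholder proxy key:
```python
import requests

resp = requests.get(
    "http://0.0.0.0:8000/model/info",
    headers={"Authorization": "Bearer sk-1234"},  # placeholder proxy key
)
print(resp.json())  # per-model descriptions, model_info metadata, costs; no API keys
```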


@ -349,6 +349,12 @@ litellm --config your_config.yaml
[**More Info**](./configs.md)
## Server Endpoints
:::note
You can see the Swagger docs for the server at its root: http://0.0.0.0:8000
:::
- POST `/chat/completions` - chat completions endpoint to call 100+ LLMs
- POST `/completions` - completions endpoint
- POST `/embeddings` - embedding endpoint for Azure, OpenAI, Huggingface endpoints
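For illustration, a minimal sketch of calling the chat completions endpoint with the OpenAI v1 Python client, assuming the proxy is running on `http://0.0.0.0:8000` and `sk-1234` is a valid key for it:
```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")
response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # any model configured on the proxy
    messages=[{"role": "user", "content": "Hello from the proxy"}],
)
print(response.choices[0].message.content)
```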


@ -39,14 +39,17 @@ litellm --config /path/to/config.yaml
```shell
curl 'http://0.0.0.0:8000/key/generate' \
--h 'Authorization: Bearer sk-1234' \
--d '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m"}'
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data-raw '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m","metadata": {"user": "ishaan@berri.ai", "team": "core-infra"}}'
```
- `models`: *list or null (optional)* - Specify the models a token has access to. If null, the token has access to all models on the server.
- `duration`: *str or null (optional)* Specify the length of time the token is valid for. If null, default is set to 1 hour. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
- `metadata`: *dict or null (optional)* Pass metadata for the created token. If null, defaults to `{}`
Expected response:
```python


@ -6,6 +6,7 @@ from typing import Callable, Optional
from litellm import OpenAIConfig
import litellm, json
import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
class AzureOpenAIError(Exception):
@ -261,7 +262,10 @@ class AzureChatCompletion(BaseLLM):
exception_mapping_worked = True
raise e
except Exception as e:
raise AzureOpenAIError(status_code=500, message=str(e))
if hasattr(e, "status_code"):
raise e
else:
raise AzureOpenAIError(status_code=500, message=str(e))
def streaming(self,
logging_obj,
@ -463,13 +467,52 @@ class AzureChatCompletion(BaseLLM):
import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
async def aimage_generation(
self,
data: dict,
model_response: ModelResponse,
azure_client_params: dict,
api_key: str,
input: list,
client=None,
logging_obj=None
):
response = None
try:
if client is None:
client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
else:
openai_aclient = client
response = await openai_aclient.images.generate(**data)
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
def image_generation(self,
prompt: list,
prompt: str,
timeout: float,
model: Optional[str]=None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
azure_ad_token: Optional[str]=None,
logging_obj=None,
optional_params=None,
client=None,
@ -477,9 +520,12 @@ class AzureChatCompletion(BaseLLM):
):
exception_mapping_worked = False
try:
model = model
if model and len(model) > 0:
model = model
else:
model = None
data = {
# "model": model,
"model": model,
"prompt": prompt,
**optional_params
}
@ -487,12 +533,26 @@ class AzureChatCompletion(BaseLLM):
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
# if aembedding == True:
# response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
# return response
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation == True:
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
return response
if client is None:
azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) # type: ignore
client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
else:
azure_client = client


@ -0,0 +1,118 @@
import time, json, httpx, asyncio
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
"""
Async implementation of custom http transport
"""
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[
"api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview",
"2023-07-01-preview",
"2023-08-01-preview",
"2023-09-01-preview",
"2023-10-01-preview",
]:
request.url = request.url.copy_with(path="/openai/images/generations:submit")
response = await super().handle_async_request(request)
operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url)
request.method = "GET"
response = await super().handle_async_request(request)
await response.aread()
timeout_secs: int = 120
start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
return httpx.Response(
status_code=400,
headers=response.headers,
content=json.dumps(timeout).encode("utf-8"),
request=request,
)
await asyncio.sleep(int(response.headers.get("retry-after") or 10))  # avoid blocking the event loop; default to 10s if the header is missing
response = await super().handle_async_request(request)
await response.aread()
if response.json()["status"] == "failed":
error_data = response.json()
return httpx.Response(
status_code=400,
headers=response.headers,
content=json.dumps(error_data).encode("utf-8"),
request=request,
)
result = response.json()["result"]
return httpx.Response(
status_code=200,
headers=response.headers,
content=json.dumps(result).encode("utf-8"),
request=request,
)
return await super().handle_async_request(request)
class CustomHTTPTransport(httpx.HTTPTransport):
"""
This class was written as a workaround to support dall-e-2 on openai > v1.x
Refer to this issue for more: https://github.com/openai/openai-python/issues/692
"""
def handle_request(
self,
request: httpx.Request,
) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[
"api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview",
"2023-07-01-preview",
"2023-08-01-preview",
"2023-09-01-preview",
"2023-10-01-preview",
]:
request.url = request.url.copy_with(path="/openai/images/generations:submit")
response = super().handle_request(request)
operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url)
request.method = "GET"
response = super().handle_request(request)
response.read()
timeout_secs: int = 120
start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
return httpx.Response(
status_code=400,
headers=response.headers,
content=json.dumps(timeout).encode("utf-8"),
request=request,
)
time.sleep(int(response.headers.get("retry-after") or 10))  # default to 10s if no retry-after header
response = super().handle_request(request)
response.read()
if response.json()["status"] == "failed":
error_data = response.json()
return httpx.Response(
status_code=400,
headers=response.headers,
content=json.dumps(error_data).encode("utf-8"),
request=request,
)
result = response.json()["result"]
return httpx.Response(
status_code=200,
headers=response.headers,
content=json.dumps(result).encode("utf-8"),
request=request,
)
return super().handle_request(request)
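For context, a hedged sketch of how this transport is wired into a client, mirroring the `azure.py` changes earlier in this diff (endpoint, key, and deployment names are placeholders):
```python
import httpx
from openai import AzureOpenAI

# CustomHTTPTransport is the class defined above
client = AzureOpenAI(
    api_key="my-azure-key",                                 # placeholder
    api_version="2023-09-01-preview",                       # a pre dall-e-3 API version
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder
    http_client=httpx.Client(transport=CustomHTTPTransport()),
)
image = client.images.generate(model="dall-e-2", prompt="a red bicycle")  # model = deployment name
print(image)
```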


@ -1,14 +1,10 @@
import requests, types, time
import json
import json, uuid
import traceback
from typing import Optional
import litellm
import httpx, aiohttp, asyncio
try:
from async_generator import async_generator, yield_ # optional dependency
async_generator_imported = True
except ImportError:
async_generator_imported = False # this should not throw an error, it will impact the 'import litellm' statement
from .prompt_templates.factory import prompt_factory, custom_prompt
class OllamaError(Exception):
def __init__(self, status_code, message):
@ -106,9 +102,8 @@ class OllamaConfig():
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
# ollama implementation
def get_ollama_response_stream(
def get_ollama_response(
api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
@ -129,6 +124,7 @@ def get_ollama_response_stream(
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
optional_params["stream"] = optional_params.get("stream", False)
data = {
"model": model,
"prompt": prompt,
@ -146,9 +142,41 @@ def get_ollama_response_stream(
else:
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
return response
else:
elif optional_params.get("stream", False):
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
response = requests.post(
url=f"{url}",
json=data,
)
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response.text,
additional_args={
"headers": None,
"api_base": api_base,
},
)
response_json = response.json()
## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
if optional_params.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
model_response["choices"][0]["message"] = message
else:
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
return model_response
def ollama_completion_stream(url, data, logging_obj):
with httpx.stream(
@ -157,13 +185,15 @@ def ollama_completion_stream(url, data, logging_obj):
method="POST",
timeout=litellm.request_timeout
) as response:
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
for transformed_chunk in streamwrapper:
yield transformed_chunk
try:
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
try:
@ -194,39 +224,31 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
text = await resp.text()
raise OllamaError(status_code=resp.status, message=text)
completion_string = ""
async for line in resp.content.iter_any():
if line:
try:
json_chunk = line.decode("utf-8")
chunks = json_chunk.split("\n")
for chunk in chunks:
if chunk.strip() != "":
j = json.loads(chunk)
if "error" in j:
completion_obj = {
"role": "assistant",
"content": "",
"error": j
}
raise Exception(f"OllamError - {chunk}")
if "response" in j:
completion_obj = {
"role": "assistant",
"content": j["response"],
}
completion_string = completion_string + completion_obj["content"]
except Exception as e:
traceback.print_exc()
## LOGGING
logging_obj.post_call(
input=data['prompt'],
api_key="",
original_response=resp.text,
additional_args={
"headers": None,
"api_base": url,
},
)
response_json = await resp.json()
## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
model_response["choices"][0]["message"]["content"] = completion_string
if data.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
model_response["choices"][0]["message"] = message
else:
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data['model']
prompt_tokens = len(encoding.encode(data['prompt'])) # type: ignore
completion_tokens = len(encoding.encode(completion_string))
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
return model_response
except Exception as e:
traceback.print_exc()
raise e


@ -284,7 +284,7 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except Exception as e:
raise e
def streaming(self,
@ -445,6 +445,43 @@ class OpenAIChatCompletion(BaseLLM):
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
async def aimage_generation(
self,
prompt: str,
data: dict,
model_response: ModelResponse,
timeout: float,
api_key: Optional[str]=None,
api_base: Optional[str]=None,
client=None,
max_retries=None,
logging_obj=None
):
response = None
try:
if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
else:
openai_aclient = client
response = await openai_aclient.images.generate(**data) # type: ignore
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=str(e),
)
raise e
def image_generation(self,
model: Optional[str],
prompt: str,
@ -631,24 +668,27 @@ class OpenAITextCompletion(BaseLLM):
api_key: str,
model: str):
async with httpx.AsyncClient() as client:
response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
response_json = response.json()
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
try:
response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
response_json = response.json()
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
except Exception as e:
raise e
def streaming(self,
logging_obj,
@ -687,9 +727,12 @@ class OpenAITextCompletion(BaseLLM):
method="POST",
timeout=litellm.request_timeout
) as response:
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
try:
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
raise e


@ -348,7 +348,7 @@ def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/r
# Function call template
def function_call_prompt(messages: list, functions: list):
function_prompt = "The following functions are available to you:"
function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:"
for function in functions:
function_prompt += f"""\n{function}\n"""
@ -425,6 +425,6 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
return alpaca_pt(messages=messages)
else:
return hf_chat_template(original_model_name, messages)
except:
except Exception as e:
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)


@ -181,6 +181,7 @@ def completion(
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
vertexai.init(
@ -193,6 +194,16 @@ def completion(
if k not in optional_params:
optional_params[k] = v
## Process safety settings into format expected by vertex AI
safety_settings = None
if "safety_settings" in optional_params:
safety_settings = optional_params.pop("safety_settings")
if not isinstance(safety_settings, list):
raise ValueError("safety_settings must be a list")
if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
raise ValueError("safety_settings must be a list of dicts")
safety_settings=[gapic_content_types.SafetySetting(x) for x in safety_settings]
# vertexai does not use an API key, it looks for credentials.json in the environment
prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)])
@ -238,16 +249,16 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream)
optional_params["stream"] = True
return model_response
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params})).text\n"
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params))
response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings)
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
@ -258,12 +269,13 @@ def completion(
content = [prompt] + images
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = llm_model.generate_content(
contents=content,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=True
)
optional_params["stream"] = True
@ -276,7 +288,8 @@ def completion(
## LLM Call
response = llm_model.generate_content(
contents=content,
generation_config=GenerationConfig(**optional_params)
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
)
completion_response = response.text
response_obj = response._raw_response
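A hedged usage sketch of the new `safety_settings` pass-through; the list-of-dicts shape matches the validation above, and the category/threshold strings are illustrative Vertex AI enum names:
```python
import litellm

response = litellm.completion(
    model="vertex_ai/gemini-pro",
    messages=[{"role": "user", "content": "Tell me a short story"}],
    safety_settings=[
        {
            "category": "HARM_CATEGORY_HARASSMENT",  # illustrative enum name
            "threshold": "BLOCK_ONLY_HIGH",           # illustrative enum name
        },
    ],
)
print(response.choices[0].message.content)
```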


@ -1329,23 +1329,11 @@ def completion(
optional_params["images"] = images
## LOGGING
generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
generator = ollama.get_ollama_response(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
if acompletion is True or optional_params.get("stream", False) == True:
return generator
else:
response_string = ""
for chunk in generator:
response_string+=chunk['content']
## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
model_response["choices"][0]["message"]["content"] = response_string
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = len(encoding.encode(prompt)) # type: ignore
completion_tokens = len(encoding.encode(response_string))
model_response["usage"] = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
response = model_response
response = generator
elif (
custom_llm_provider == "baseten"
or litellm.api_base == "https://app.baseten.co"
@ -2026,11 +2014,9 @@ async def atext_completion(*args, **kwargs):
response = text_completion(*args, **kwargs)
else:
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(response):
response = await response
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
@ -2205,7 +2191,8 @@ def text_completion(
if stream == True or kwargs.get("stream", False) == True:
response = TextCompletionStreamWrapper(completion_stream=response, model=model)
return response
if kwargs.get("acompletion", False) == True:
return response
transformed_logprobs = None
# only supported for TGI models
try:
@ -2243,6 +2230,49 @@ def moderation(input: str, api_key: Optional[str]=None):
return response
##### Image Generation #######################
@client
async def aimage_generation(*args, **kwargs):
"""
Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.
Parameters:
- `args` (tuple): Positional arguments to be passed to the `image_generation` function.
- `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function.
Returns:
- `response` (Any): The response returned by the `image_generation` function.
"""
loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"]
### PASS ARGS TO Image Generation ###
kwargs["aimg_generation"] = True
custom_llm_provider = None
try:
# Use a partial function to pass your keyword arguments
func = partial(image_generation, *args, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
)
@client
def image_generation(prompt: str,
model: Optional[str]=None,
@ -2264,6 +2294,7 @@ def image_generation(prompt: str,
Currently supports just Azure + OpenAI.
"""
aimg_generation = kwargs.get("aimg_generation", False)
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
proxy_server_request = kwargs.get('proxy_server_request', None)
@ -2277,7 +2308,7 @@ def image_generation(prompt: str,
model = "dall-e-2"
custom_llm_provider = "openai" # default to dall-e-2 on openai
openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"]
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
litellm_params = ["metadata", "aimg_generation", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
optional_params = get_optional_params_image_gen(n=n,
@ -2320,10 +2351,9 @@ def image_generation(prompt: str,
get_secret("AZURE_AD_TOKEN")
)
# model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
pass
model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimg_generation)
elif custom_llm_provider == "openai":
model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimg_generation)
return model_response
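A hedged usage sketch of the new entry points; this assumes `OPENAI_API_KEY` is set, and relies on the default shown above (`dall-e-2` on OpenAI) when no model is passed:
```python
import asyncio
import litellm

# synchronous call
image = litellm.image_generation(prompt="a robot reading a book")
print(image)

# asynchronous call via the new aimage_generation wrapper
async def main():
    image = await litellm.aimage_generation(prompt="a robot reading a book")
    print(image)

asyncio.run(main())
```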


@ -120,6 +120,8 @@ class GenerateKeyRequest(LiteLLMBase):
spend: Optional[float] = 0
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}
max_budget: Optional[float] = None
class UpdateKeyRequest(LiteLLMBase):
key: str
@ -130,21 +132,8 @@ class UpdateKeyRequest(LiteLLMBase):
spend: Optional[float] = None
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
class GenerateKeyResponse(LiteLLMBase):
key: str
expires: datetime
user_id: str
class _DeleteKeyObject(LiteLLMBase):
key: str
class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject]
metadata: Optional[dict] = {}
max_budget: Optional[float] = None
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
"""
@ -158,6 +147,20 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
duration: str = "1h"
metadata: dict = {}
max_budget: Optional[float] = None
class GenerateKeyResponse(LiteLLMBase):
key: str
expires: datetime
user_id: str
class _DeleteKeyObject(LiteLLMBase):
key: str
class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject]
class ConfigGeneralSettings(LiteLLMBase):
"""


@ -96,7 +96,7 @@ async def _perform_health_check(model_list: list):
async def perform_health_check(model_list: list, model: Optional[str] = None):
async def perform_health_check(model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None):
"""
Perform a health check on the system.
@ -104,7 +104,10 @@ async def perform_health_check(model_list: list, model: Optional[str] = None):
(bool): True if the health check passes, False otherwise.
"""
if not model_list:
return [], []
if cli_model:
model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}]
else:
return [], []
if model is not None:
model_list = [x for x in model_list if x["litellm_params"]["model"] == model]


@ -0,0 +1,35 @@
from typing import Optional
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
class MaxBudgetLimiter(CustomLogger):
# Class variables or attributes
def __init__(self):
pass
def print_verbose(self, print_statement):
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_budget = user_api_key_dict.max_budget
curr_spend = user_api_key_dict.spend
if api_key is None:
return
if max_budget is None:
return
if curr_spend is None:
return
# CHECK IF REQUEST ALLOWED
if curr_spend >= max_budget:
raise HTTPException(status_code=429, detail="Max budget limit reached.")
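A hedged sketch of exercising this hook: generate a key with a spend cap via `/key/generate` (using the new `max_budget` and `metadata` fields); once the key's spend reaches the cap, the limiter rejects further requests with a 429. The proxy URL and master key below are placeholders:
```python
import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
    json={"max_budget": 10.0, "metadata": {"team": "core-infra"}},
)
print(resp.json())  # returns the generated key, its expiry, and user_id
```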


@ -47,7 +47,7 @@ litellm_settings:
# setting callback class
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
general_settings:
# general_settings:
environment_variables:
# otel: True # OpenTelemetry Logger


@ -110,7 +110,7 @@ import json
import logging
from typing import Union
app = FastAPI(docs_url="/", title="LiteLLM API")
app = FastAPI(docs_url="/", title="LiteLLM API", description="Proxy Server to call 100+ LLMs in the OpenAI format")
router = APIRouter()
origins = ["*"]
@ -616,7 +616,16 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
router = litellm.Router(**router_params) # type:ignore
return router, model_list, general_settings
async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
async def generate_key_helper_fn(duration: Optional[str],
models: list,
aliases: dict,
config: dict,
spend: float,
max_budget: Optional[float]=None,
token: Optional[str]=None,
user_id: Optional[str]=None,
max_parallel_requests: Optional[int]=None,
metadata: Optional[dict] = {},):
global prisma_client
if prisma_client is None:
@ -653,6 +662,7 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
aliases_json = json.dumps(aliases)
config_json = json.dumps(config)
metadata_json = json.dumps(metadata)
user_id = user_id or str(uuid.uuid4())
try:
# Create a new verification token (you may want to enhance this logic based on your needs)
@ -664,7 +674,9 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
"config": config_json,
"spend": spend,
"user_id": user_id,
"max_parallel_requests": max_parallel_requests
"max_parallel_requests": max_parallel_requests,
"metadata": metadata_json,
"max_budget": max_budget
}
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
except Exception as e:
@ -774,6 +786,7 @@ def initialize(
print(f"\033[1;34mLiteLLM: Test your local proxy with: \"litellm --test\" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n")
print(f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n")
print("\033[1;34mDocs: https://docs.litellm.ai/docs/simple_proxy\033[0m\n")
print(f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:8000 \033[0m\n")
# for streaming
def data_generator(response):
print_verbose("inside generator")
@ -871,9 +884,9 @@ def model_list():
object="list",
)
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"])
@router.post("/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"])
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"])
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
try:
@ -1044,8 +1057,8 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
detail=error_msg
)
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["embeddings"])
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["embeddings"])
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global proxy_logging_obj
try:
@ -1124,6 +1137,72 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
detail=error_msg
)
@router.post("/v1/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"])
@router.post("/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"])
async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global proxy_logging_obj
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data) # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
data["model"] = (
general_settings.get("image_generation_model", None) # server default
or user_model # model name passed via cli args
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers)
else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
## ROUTE TO CORRECT ENDPOINT ##
if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.aimage_generation(**data)
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
response = await llm_router.aimage_generation(**data, specific_deployment = True)
elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
response = await llm_router.aimage_generation(**data) # ensure this goes the llm_router, router will do the correct alias mapping
else:
response = await litellm.aimage_generation(**data)
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
try:
status = e.status_code # type: ignore
except:
status = 500
raise HTTPException(
status_code=status,
detail=error_msg
)
#### KEY MANAGEMENT ####
@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
@ -1140,6 +1219,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml
- spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
- max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
- metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
Returns:
- key: (str) The generated api key
@ -1165,7 +1245,6 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest):
raise Exception("Not connected to DB!")
non_default_values = {k: v for k, v in data_json.items() if v is not None}
print(f"non_default_values: {non_default_values}")
response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key})
return {"key": key, **non_default_values}
# update based on remaining passed in values
@ -1514,9 +1593,18 @@ async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query
```
Otherwise, health checks are run on the models when /health is called.
"""
global health_check_results, use_background_health_checks
global health_check_results, use_background_health_checks, user_model
if llm_model_list is None:
# if no router set, check if user set a model using litellm --model ollama/llama2
if user_model is not None:
healthy_endpoints, unhealthy_endpoints = await perform_health_check(model_list=[], cli_model=user_model)
return {
"healthy_endpoints": healthy_endpoints,
"unhealthy_endpoints": unhealthy_endpoints,
"healthy_count": len(healthy_endpoints),
"unhealthy_count": len(unhealthy_endpoints),
}
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"error": "Model list not initialized"},

View file

@ -17,4 +17,6 @@ model LiteLLM_VerificationToken {
config Json @default("{}")
user_id String?
max_parallel_requests Int?
metadata Json @default("{}")
max_budget Float?
}
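
As a rough sketch of how the new `metadata` Json column might be written through the Prisma client; the `db` handle and helper name are placeholders, and the call shape mirrors the `litellm_verificationtoken.update(...)` usage shown later in this diff:

```python
# Illustrative only: persist key metadata as a JSON string, matching the
# jsonify_object() behaviour added in proxy/utils.py below.
import json

async def set_key_metadata(db, hashed_token: str, metadata: dict):
    # Json columns expect serialized data, hence json.dumps
    return await db.litellm_verificationtoken.update(
        where={"token": hashed_token},
        data={"metadata": json.dumps(metadata)},
    )
```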

View file

@ -1,9 +1,10 @@
from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib, asyncio, copy
import os, subprocess, hashlib, importlib, asyncio, copy, json
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger
def print_verbose(print_statement):
if litellm.set_verbose:
@ -23,11 +24,13 @@ class ProxyLogging:
self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
self.max_budget_limiter = MaxBudgetLimiter()
pass
def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_budget_limiter)
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
@ -147,6 +150,14 @@ class PrismaClient:
return hashed_token
def jsonify_object(self, data: dict) -> dict:
db_data = copy.deepcopy(data)
for k, v in db_data.items():
if isinstance(v, dict):
db_data[k] = json.dumps(v)
return db_data
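
A standalone illustration of what `jsonify_object` does before data hits the DB; the sample input values are made up:

```python
# Nested dicts are serialized to JSON strings so Prisma Json columns accept them.
import copy, json

def jsonify_object(data: dict) -> dict:
    db_data = copy.deepcopy(data)
    for k, v in db_data.items():
        if isinstance(v, dict):
            db_data[k] = json.dumps(v)
    return db_data

print(jsonify_object({"token": "sk-1234", "metadata": {"team": "core-infra"}}))
# -> {'token': 'sk-1234', 'metadata': '{"team": "core-infra"}'}
```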
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
@ -193,9 +204,8 @@ class PrismaClient:
try:
token = data["token"]
hashed_token = self.hash_token(token=token)
db_data = copy.deepcopy(data)
db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,
@ -228,7 +238,7 @@ class PrismaClient:
if token.startswith("sk-"):
token = self.hash_token(token=token)
db_data = copy.deepcopy(data)
db_data = self.jsonify_object(data=data)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={

View file

@ -7,7 +7,7 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import copy
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback, uuid
@ -18,6 +18,7 @@ import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
import copy
class Router:
"""
@ -84,11 +85,11 @@ class Router:
self.set_verbose = set_verbose
self.deployment_names: List = [] # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
if model_list:
model_list = copy.deepcopy(model_list)
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list
self.deployment_latency_map = {}
for m in model_list:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
@ -166,7 +167,7 @@ class Router:
self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n")
### COMPLETION + EMBEDDING FUNCTIONS
### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS
def completion(self,
model: str,
@ -260,6 +261,94 @@ class Router:
self.fail_calls[model_name] +=1
raise e
def image_generation(self,
prompt: str,
model: str,
**kwargs):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._image_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
def _image_generation(self,
prompt: str,
model: str,
**kwargs):
try:
self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] +=1
response = litellm.image_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[model_name] +=1
return response
except Exception as e:
if model_name is not None:
self.fail_calls[model_name] +=1
raise e
async def aimage_generation(self,
prompt: str,
model: str,
**kwargs):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._aimage_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _aimage_generation(self,
prompt: str,
model: str,
**kwargs):
try:
self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] +=1
response = await litellm.aimage_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[model_name] +=1
return response
except Exception as e:
if model_name is not None:
self.fail_calls[model_name] +=1
raise e
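
For context, a minimal usage sketch of the new router image-generation surface; the model list and env vars mirror the router tests later in this diff and are placeholders:

```python
# Minimal sketch of Router.image_generation / Router.aimage_generation;
# deployments and credentials are illustrative.
import os, asyncio
from litellm import Router

model_list = [
    {"model_name": "dall-e-3", "litellm_params": {"model": "dall-e-3"}},
    {
        "model_name": "dall-e-3",
        "litellm_params": {
            "model": "azure/dall-e-3-test",
            "api_version": "2023-12-01-preview",
            "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
            "api_key": os.getenv("AZURE_SWEDEN_API_KEY"),
        },
    },
]
router = Router(model_list=model_list)

# sync call - routed across the dall-e-3 deployments above
image = router.image_generation(model="dall-e-3", prompt="A cute baby sea otter")

# async call
image = asyncio.run(router.aimage_generation(model="dall-e-3", prompt="A cute baby sea otter"))
```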
def text_completion(self,
model: str,
prompt: str,
@ -436,7 +525,6 @@ class Router:
async def async_function_with_retries(self, *args, **kwargs):
self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
backoff_factor = 1
original_function = kwargs.pop("original_function")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
@ -1009,14 +1097,16 @@ class Router:
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries
max_retries=max_retries,
http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) # type: ignore
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries
max_retries=max_retries,
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
)
# streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI(
@ -1024,7 +1114,7 @@ class Router:
azure_endpoint=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
max_retries=max_retries,
)
model["stream_client"] = openai.AzureOpenAI(

View file

@ -20,7 +20,7 @@ import tempfile
litellm.num_retries = 3
litellm.cache = None
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
messages = [{"content": user_message, "role": "user"}]
def load_vertex_ai_credentials():
@ -91,7 +91,7 @@ def test_vertex_ai():
load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
litellm.set_verbose=False
litellm.vertex_project = "hardy-device-386718"
litellm.vertex_project = "reliablekeys"
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
@ -113,7 +113,7 @@ def test_vertex_ai():
def test_vertex_ai_stream():
load_vertex_ai_credentials()
litellm.set_verbose=False
litellm.vertex_project = "hardy-device-386718"
litellm.vertex_project = "reliablekeys"
import random
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models

View file

@ -599,34 +599,34 @@ def test_completion_hf_model_no_provider():
# test_completion_hf_model_no_provider()
# def test_completion_openai_azure_with_functions():
# function1 = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
# },
# "required": ["location"],
# },
# }
# ]
# try:
# messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
# response = completion(
# model="azure/chatgpt-functioncalling", messages=messages, functions=function1
# )
# # Add any assertions here to check the response
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_completion_openai_azure_with_functions()
def test_completion_anyscale_with_functions():
function1 = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
]
try:
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
response = completion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, functions=function1
)
# Add any assertions here to check the response
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_completion_anyscale_with_functions()
def test_completion_azure_key_completion_arg():
# this tests if we can pass api_key to completion, when it's not in the env
@ -727,7 +727,7 @@ def test_completion_azure():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_azure()
test_completion_azure()
def test_azure_openai_ad_token():
# this tests if the azure ad token is set in the request header

View file

@ -78,4 +78,19 @@ model_list:
model: "bedrock/amazon.titan-embed-text-v1"
- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)"
litellm_params:
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
- model_name: dall-e-3
litellm_params:
model: dall-e-3
- model_name: dall-e-3
litellm_params:
model: "azure/dall-e-3-test"
api_version: "2023-12-01-preview"
api_base: "os.environ/AZURE_SWEDEN_API_BASE"
api_key: "os.environ/AZURE_SWEDEN_API_KEY"
- model_name: dall-e-2
litellm_params:
model: "azure/"
api_version: "2023-06-01-preview"
api_base: "os.environ/AZURE_API_BASE"
api_key: "os.environ/AZURE_API_KEY"

View file

@ -17,6 +17,14 @@ def test_prompt_formatting():
assert prompt == "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
def test_prompt_formatting_custom_model():
try:
prompt = prompt_factory(model="ehartford/dolphin-2.5-mixtral-8x7b", messages=[{"role": "system", "content": "Be a good bot"}, {"role": "user", "content": "Hello world"}], custom_llm_provider="huggingface")
print(f"prompt: {prompt}")
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_prompt_formatting_custom_model()
# def logger_fn(user_model_dict):
# return
# print(f"user_model_dict: {user_model_dict}")

View file

@ -4,10 +4,11 @@
import sys, os
import traceback
from dotenv import load_dotenv
import logging
logging.basicConfig(level=logging.DEBUG)
load_dotenv()
import os
import asyncio
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
@ -18,20 +19,31 @@ def test_image_generation_openai():
litellm.set_verbose = True
response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_openai()
# def test_image_generation_azure():
# response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure")
# print(f"response: {response}")
def test_image_generation_azure():
response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview")
print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_azure()
# @pytest.mark.asyncio
# async def test_async_image_generation_openai():
# response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
# print(f"response: {response}")
def test_image_generation_azure_dall_e_3():
litellm.set_verbose = True
response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test", api_version="2023-12-01-preview", api_base=os.getenv("AZURE_SWEDEN_API_BASE"), api_key=os.getenv("AZURE_SWEDEN_API_KEY"))
print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_azure_dall_e_3()
@pytest.mark.asyncio
async def test_async_image_generation_openai():
response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
print(f"response: {response}")
assert len(response.data) > 0
# @pytest.mark.asyncio
# async def test_async_image_generation_azure():
# response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3")
# print(f"response: {response}")
# asyncio.run(test_async_image_generation_openai())
@pytest.mark.asyncio
async def test_async_image_generation_azure():
response = await litellm.aimage_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test")
print(f"response: {response}")

View file

@ -16,23 +16,61 @@
# user_message = "respond in 20 words. who are you?"
# messages = [{ "content": user_message,"role": "user"}]
# def test_ollama_streaming():
# try:
# litellm.set_verbose = False
# messages = [
# {"role": "user", "content": "What is the weather like in Boston?"}
# ]
# functions = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": ["location"]
# }
# }
# ]
# response = litellm.completion(model="ollama/mistral",
# messages=messages,
# functions=functions,
# stream=True)
# for chunk in response:
# print(f"CHUNK: {chunk}")
# except Exception as e:
# print(e)
# test_ollama_streaming()
# async def test_async_ollama_streaming():
# try:
# litellm.set_verbose = True
# litellm.set_verbose = False
# response = await litellm.acompletion(model="ollama/mistral-openorca",
# messages=[{"role": "user", "content": "Hey, how's it going?"}],
# stream=True)
# async for chunk in response:
# print(chunk)
# print(f"CHUNK: {chunk}")
# except Exception as e:
# print(e)
# asyncio.run(test_async_ollama_streaming())
# # asyncio.run(test_async_ollama_streaming())
# def test_completion_ollama():
# try:
# litellm.set_verbose = True
# response = completion(
# model="ollama/llama2",
# model="ollama/mistral",
# messages=[{"role": "user", "content": "Hey, how's it going?"}],
# max_tokens=200,
# request_timeout = 10,
@ -44,7 +82,87 @@
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_completion_ollama()
# # test_completion_ollama()
# def test_completion_ollama_function_calling():
# try:
# litellm.set_verbose = True
# messages = [
# {"role": "user", "content": "What is the weather like in Boston?"}
# ]
# functions = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": ["location"]
# }
# }
# ]
# response = completion(
# model="ollama/mistral",
# messages=messages,
# functions=functions,
# max_tokens=200,
# request_timeout = 10,
# )
# for chunk in response:
# print(chunk)
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama_function_calling()
# async def async_test_completion_ollama_function_calling():
# try:
# litellm.set_verbose = True
# messages = [
# {"role": "user", "content": "What is the weather like in Boston?"}
# ]
# functions = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": ["location"]
# }
# }
# ]
# response = await litellm.acompletion(
# model="ollama/mistral",
# messages=messages,
# functions=functions,
# max_tokens=200,
# request_timeout = 10,
# )
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # asyncio.run(async_test_completion_ollama_function_calling())
# def test_completion_ollama_with_api_base():
# try:
@ -197,7 +315,7 @@
# )
# print("Response from ollama/llava")
# print(response)
# test_ollama_llava()
# # test_ollama_llava()
# # PROCESSED CHUNK PRE CHUNK CREATOR

View file

@ -99,8 +99,6 @@ def test_embedding(client):
def test_chat_completion(client):
try:
# Your test data
print("initialized proxy")
litellm.set_verbose=False
from litellm.proxy.utils import get_instance_fn
my_custom_logger = get_instance_fn(

View file

@ -68,6 +68,7 @@ def test_chat_completion_exception_azure(client):
# make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything")
openai_exception = openai_client._make_status_error_from_response(response=response)
print(openai_exception)
assert isinstance(openai_exception, openai.AuthenticationError)
except Exception as e:

View file

@ -101,7 +101,7 @@ def test_chat_completion_azure(client_no_auth):
# Run the test
# test_chat_completion_azure()
### EMBEDDING
def test_embedding(client_no_auth):
global headers
from litellm.proxy.proxy_server import user_custom_auth
@ -161,7 +161,30 @@ def test_sagemaker_embedding(client_no_auth):
# Run the test
# test_embedding()
#### IMAGE GENERATION
def test_img_gen(client_no_auth):
global headers
from litellm.proxy.proxy_server import user_custom_auth
try:
test_data = {
"model": "dall-e-3",
"prompt": "A cute baby sea otter",
"n": 1,
"size": "1024x1024"
}
response = client_no_auth.post("/v1/images/generations", json=test_data)
assert response.status_code == 200
result = response.json()
print(len(result["data"][0]["url"]))
assert len(result["data"][0]["url"]) > 10
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
#### ADDITIONAL
# @pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
def test_add_new_model(client_no_auth):
global headers

View file

@ -423,6 +423,94 @@ def test_function_calling_on_router():
# test_function_calling_on_router()
### IMAGE GENERATION
@pytest.mark.asyncio
async def test_aimg_gen_on_router():
litellm.set_verbose = True
try:
model_list = [
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "dall-e-3",
},
},
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "azure/dall-e-3-test",
"api_version": "2023-12-01-preview",
"api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
"api_key": os.getenv("AZURE_SWEDEN_API_KEY")
}
},
{
"model_name": "dall-e-2",
"litellm_params": {
"model": "azure/",
"api_version": "2023-06-01-preview",
"api_base": os.getenv("AZURE_API_BASE"),
"api_key": os.getenv("AZURE_API_KEY")
}
}
]
router = Router(model_list=model_list)
response = await router.aimage_generation(
model="dall-e-3",
prompt="A cute baby sea otter"
)
print(response)
assert len(response.data) > 0
response = await router.aimage_generation(
model="dall-e-2",
prompt="A cute baby sea otter"
)
print(response)
assert len(response.data) > 0
router.reset()
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
# asyncio.run(test_aimg_gen_on_router())
def test_img_gen_on_router():
litellm.set_verbose = True
try:
model_list = [
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "dall-e-3",
},
},
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "azure/dall-e-3-test",
"api_version": "2023-12-01-preview",
"api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
"api_key": os.getenv("AZURE_SWEDEN_API_KEY")
}
}
]
router = Router(model_list=model_list)
response = router.image_generation(
model="dall-e-3",
prompt="A cute baby sea otter"
)
print(response)
assert len(response.data) > 0
router.reset()
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
# test_img_gen_on_router()
###
def test_aembedding_on_router():
litellm.set_verbose = True
try:
@ -556,7 +644,7 @@ async def test_mistral_on_router():
]
)
print(response)
asyncio.run(test_mistral_on_router())
# asyncio.run(test_mistral_on_router())
def test_openai_completion_on_router():
# [PROD Use Case] - Makes async acompletion calls, a sync completion call, and a sync completion call with stream=True

View file

@ -169,17 +169,37 @@ def test_text_completion_stream():
# test_text_completion_stream()
async def test_text_completion_async_stream():
try:
response = await atext_completion(
model="text-completion-openai/text-davinci-003",
prompt="good morning",
stream=True,
max_tokens=10,
)
async for chunk in response:
print(f"chunk: {chunk}")
except Exception as e:
pytest.fail(f"GOT exception for HF In streaming{e}")
# async def test_text_completion_async_stream():
# try:
# response = await atext_completion(
# model="text-completion-openai/text-davinci-003",
# prompt="good morning",
# stream=True,
# max_tokens=10,
# )
# async for chunk in response:
# print(f"chunk: {chunk}")
# except Exception as e:
# pytest.fail(f"GOT exception for HF In streaming{e}")
asyncio.run(test_text_completion_async_stream())
# asyncio.run(test_text_completion_async_stream())
def test_async_text_completion():
litellm.set_verbose = True
print('test_async_text_completion')
async def test_get_response():
try:
response = await litellm.atext_completion(
model="gpt-3.5-turbo-instruct",
prompt="good morning",
stream=False,
max_tokens=10
)
print(f"response: {response}")
except litellm.Timeout as e:
print(e)
except Exception as e:
print(e)
asyncio.run(test_get_response())
test_async_text_completion()

View file

@ -1,13 +1,13 @@
{
"type": "service_account",
"project_id": "hardy-device-386718",
"project_id": "reliablekeys",
"private_key_id": "",
"private_key": "",
"client_email": "litellm-vertexai-ci-cd@hardy-device-386718.iam.gserviceaccount.com",
"client_id": "110281020501213430254",
"client_email": "73470430121-compute@developer.gserviceaccount.com",
"client_id": "108560959659377334173",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/litellm-vertexai-ci-cd%40hardy-device-386718.iam.gserviceaccount.com",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/73470430121-compute%40developer.gserviceaccount.com",
"universe_domain": "googleapis.com"
}

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.15.1"
version = "1.15.4"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"
@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "1.15.1"
version = "1.15.4"
version_files = [
"pyproject.toml:^version"
]

View file

@ -3,24 +3,24 @@ anyio==4.2.0 # openai + http req.
openai>=1.0.0 # openai req.
fastapi # server dep
pydantic>=2.5 # openai req.
appdirs # server dep
backoff # server dep
pyyaml # server dep
uvicorn # server dep
boto3 # aws bedrock/sagemaker calls
redis # caching
prisma # for db
mangum # for aws lambda functions
google-generativeai # for vertex ai calls
appdirs==1.4.4 # server dep
backoff==2.2.1 # server dep
pyyaml==6.0 # server dep
uvicorn==0.22.0 # server dep
boto3==1.28.58 # aws bedrock/sagemaker calls
redis==4.6.0 # caching
prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions
google-generativeai==0.1.0 # for vertex ai calls
traceloop-sdk==0.5.3 # for open telemetry logging
langfuse==1.14.0 # for langfuse self-hosted logging
### LITELLM PACKAGE DEPENDENCIES
python-dotenv>=0.2.0 # for env
tiktoken>=0.4.0 # for calculating usage
importlib-metadata>=6.8.0 # for random utils
tokenizers # for calculating usage
click # for proxy cli
tokenizers==0.14.0 # for calculating usage
click==8.1.7 # for proxy cli
jinja2==3.1.2 # for prompt templates
certifi>=2023.7.22 # [TODO] clean up
aiohttp # for network calls
aiohttp==3.8.4 # for network calls
####