forked from phoenix/litellm-mirror

commit 1c68f5557d: Merge branch 'main' into improve-langchain-integration

41 changed files with 1169 additions and 322 deletions
.github/workflows/docker.yml (vendored): 66 deletions

@@ -1,66 +0,0 @@
name: Build Docker Images

on:
  workflow_dispatch:
    inputs:
      tag:
        description: "The tag version you want to build"

jobs:
  build:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
    env:
      REPO_NAME: ${{ github.repository }}
    steps:
      - name: Convert repo name to lowercase
        run: echo "REPO_NAME=$(echo "$REPO_NAME" | tr '[:upper:]' '[:lower:]')" >> $GITHUB_ENV
      - name: Checkout
        uses: actions/checkout@v4
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GHCR_TOKEN }}
          logout: false
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ghcr.io/berriai/litellm
      - name: Get tag to build
        id: tag
        run: |
          echo "latest=ghcr.io/${{ env.REPO_NAME }}:latest" >> $GITHUB_OUTPUT
          if [[ -z "${{ github.event.inputs.tag }}" ]]; then
            echo "versioned=ghcr.io/${{ env.REPO_NAME }}:${{ github.ref_name }}" >> $GITHUB_OUTPUT
          else
            echo "versioned=ghcr.io/${{ env.REPO_NAME }}:${{ github.event.inputs.tag }}" >> $GITHUB_OUTPUT
          fi
      - name: Debug Info
        run: |
          echo "GHCR_TOKEN=${{ secrets.GHCR_TOKEN }}"
          echo "REPO_NAME=${{ env.REPO_NAME }}"
          echo "ACTOR=${{ github.actor }}"
      - name: Build and push container image to registry
        uses: docker/build-push-action@v2
        with:
          push: true
          tags: ghcr.io/${{ env.REPO_NAME }}:${{ github.sha }}
          file: ./Dockerfile
      - name: Build and release Docker images
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/amd64
          tags: |
            ${{ steps.tag.outputs.latest }}
            ${{ steps.tag.outputs.versioned }}
          labels: ${{ steps.meta.outputs.labels }}
          push: true
.github/workflows/ghcr_deploy.yml (vendored): 8 changes

@@ -1,10 +1,12 @@
#
name: Build & Publich to GHCR
name: Build & Publish to GHCR
on:
  workflow_dispatch:
    inputs:
      tag:
        description: "The tag version you want to build"
  release:
    types: [published]

# Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds.
env:

@@ -19,7 +21,7 @@ jobs:
    permissions:
      contents: read
      packages: write
      #
      #
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

@@ -44,5 +46,5 @@ jobs:
        with:
          context: .
          push: true
          tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag }} # Add the input tag to the image tags
          tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || github.event.release.tag_name }} # if a tag is provided, use that, otherwise use the release tag
          labels: ${{ steps.meta.outputs.labels }}
@@ -1,13 +1,19 @@
# Exception Mapping

LiteLLM maps exceptions across all providers to their OpenAI counterparts.
- Rate Limit Errors
- Invalid Request Errors
- Authentication Errors
- Timeout Errors `openai.APITimeoutError`
- ServiceUnavailableError
- APIError
- APIConnectionError

| Status Code | Error Type                 |
|-------------|----------------------------|
| 400         | BadRequestError            |
| 401         | AuthenticationError        |
| 403         | PermissionDeniedError      |
| 404         | NotFoundError              |
| 422         | UnprocessableEntityError   |
| 429         | RateLimitError             |
| >=500       | InternalServerError        |
| N/A         | ContextWindowExceededError |
| N/A         | APIConnectionError         |

Base case - we return `APIConnectionError`
@@ -83,6 +89,7 @@ Base case - we return the original exception.
|---------------|----------------------------|---------------------|---------------------|---------------|-------------------------|
| Anthropic     | ✅ | ✅ | ✅ | ✅ |    |
| OpenAI        | ✅ | ✅ | ✅ | ✅ | ✅ |
| Azure OpenAI  | ✅ | ✅ | ✅ | ✅ | ✅ |
| Replicate     | ✅ | ✅ | ✅ | ✅ | ✅ |
| Cohere        | ✅ | ✅ | ✅ | ✅ | ✅ |
| Huggingface   | ✅ | ✅ | ✅ | ✅ |    |
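The mapping above can be exercised directly. A minimal sketch (not part of this diff; the model name and deliberately invalid key are illustrative) of catching the OpenAI-style exceptions LiteLLM raises:

```python
import os
import litellm
import openai  # LiteLLM's mapped exceptions subclass the OpenAI exception types

os.environ["OPENAI_API_KEY"] = "sk-invalid"  # bad key on purpose, to trigger a 401

try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
except openai.AuthenticationError as e:
    # a provider 401 surfaces as AuthenticationError, per the table above
    print("authentication error:", e)
except openai.APIConnectionError as e:
    # the documented fallback when no more specific mapping applies
    print("connection error:", e)
```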
@@ -15,7 +15,7 @@ join our [discord](https://discord.gg/wuPM9dRgDw)
## Pre-Requisites
Ensure you have run `pip install langfuse` for this integration
```shell
pip install langfuse litellm
pip install langfuse==1.14.0 litellm
```

## Quick Start
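For context, a sketch of the quick start this hunk leads into (not additional code from the diff; the key values are placeholders): with the pinned packages installed, the integration is enabled by setting the Langfuse keys and registering the callback.

```python
import os
import litellm

os.environ["LANGFUSE_PUBLIC_KEY"] = ""  # illustrative placeholders
os.environ["LANGFUSE_SECRET_KEY"] = ""

# log successful LLM calls to Langfuse
litellm.success_callback = ["langfuse"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋"}],
)
```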
@@ -41,6 +41,18 @@ def send_slack_alert(
    # get it from https://api.slack.com/messaging/webhooks
    slack_webhook_url = os.environ['SLACK_WEBHOOK_URL']  # "https://hooks.slack.com/services/<>/<>/<>"

    # Remove api_key from kwargs under litellm_params
    if kwargs.get('litellm_params'):
        kwargs['litellm_params'].pop('api_key', None)
        if kwargs['litellm_params'].get('metadata'):
            kwargs['litellm_params']['metadata'].pop('deployment', None)
    # Remove deployment under metadata
    if kwargs.get('metadata'):
        kwargs['metadata'].pop('deployment', None)
    # Prevent api_key from being logged
    if kwargs.get('api_key'):
        kwargs.pop('api_key', None)

    # Define the text payload, send data available in litellm custom_callbacks
    text_payload = f"""LiteLLM Logging: kwargs: {str(kwargs)}\n\n, response: {str(completion_response)}\n\n, start time{str(start_time)} end time: {str(end_time)}
    """

@@ -90,4 +102,4 @@ response = litellm.completion(
- [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version)
- [Community Discord 💭](https://discord.gg/wuPM9dRgDw)
- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai
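A sketch of how a custom callback such as `send_slack_alert` is typically wired up (not part of the diff; only the function name comes from the snippet above, the rest is illustrative):

```python
import litellm

# register the custom callback so it fires after each successful call
litellm.success_callback = [send_slack_alert]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "test the Slack alert"}],
)
```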
@@ -77,6 +77,45 @@ Ollama supported models: https://github.com/jmorganca/ollama
| Nous-Hermes 13B | `completion(model='ollama/nous-hermes:13b', messages, api_base="http://localhost:11434", stream=True)` |
| Wizard Vicuna Uncensored | `completion(model='ollama/wizard-vicuna', messages, api_base="http://localhost:11434", stream=True)` |

## Ollama Vision Models
| Model Name | Function Call |
|------------|----------------------------------------|
| llava      | `completion('ollama/llava', messages)` |

#### Using Ollama Vision Models

Call `ollama/llava` in the same input/output format as OpenAI [`gpt-4-vision`](https://docs.litellm.ai/docs/providers/openai#openai-vision-models)

LiteLLM Supports the following image types passed in `url`
- Base64 encoded svgs

**Example Request**
```python
import litellm

response = litellm.completion(
    model="ollama/llava",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Whats in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
"url": "iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a
2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"
                    }
                }
            ]
        }
    ],
)
print(response)
```

## LiteLLM/Ollama Docker Image
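The `url` field in the example above carries a raw base64 string. A small sketch (not from the docs; the file path is illustrative) for producing that value from a local image with the standard library:

```python
import base64

# read a local image and base64-encode it for the "url" field
with open("my_image.png", "rb") as f:
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

image_block = {
    "type": "image_url",
    "image_url": {"url": encoded_image},
}
```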
@@ -1,5 +1,5 @@
# OpenRouter
LiteLLM supports all the text models from [OpenRouter](https://openrouter.ai/docs)
LiteLLM supports all the text / chat / vision models from [OpenRouter](https://openrouter.ai/docs)

<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/LiteLLM_OpenRouter.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>

@@ -20,10 +20,10 @@ response = completion(
)
```

### OpenRouter Completion Models
## OpenRouter Completion Models

| Model Name | Function Call | Required OS Variables |
|---------------------------|-----------------------------------------------------|--------------------------------------------------------------|
| Model Name | Function Call |
|---------------------------|-----------------------------------------------------|
| openrouter/openai/gpt-3.5-turbo | `completion('openrouter/openai/gpt-3.5-turbo', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
| openrouter/openai/gpt-3.5-turbo-16k | `completion('openrouter/openai/gpt-3.5-turbo-16k', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
| openrouter/openai/gpt-4 | `completion('openrouter/openai/gpt-4', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |

@@ -35,3 +35,19 @@ response = completion(
| openrouter/meta-llama/llama-2-13b-chat | `completion('openrouter/meta-llama/llama-2-13b-chat', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |
| openrouter/meta-llama/llama-2-70b-chat | `completion('openrouter/meta-llama/llama-2-70b-chat', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` |

## Passing OpenRouter Params - transforms, models, route

Pass `transforms`, `models`, `route` as arguments to `litellm.completion()`

```python
import os
from litellm import completion
os.environ["OPENROUTER_API_KEY"] = ""

response = completion(
    model="openrouter/google/palm-2-chat-bison",
    messages=messages,
    transforms=[""],
    route=""
)
```
@@ -59,4 +59,21 @@ $ litellm /path/to/config.yaml
3. Query health endpoint:
```
curl --location 'http://0.0.0.0:8000/health'
```

## Embedding Models

The health check needs to know when a model is an embedding model. Set `mode: embedding` under `model_info` in your config, and the health check will run an embedding request for that model.

```yaml
model_list:
  - model_name: azure-embedding-model
    litellm_params:
      model: azure/azure-embedding-model
      api_base: os.environ/AZURE_API_BASE
      api_key: os.environ/AZURE_API_KEY
      api_version: "2023-07-01-preview"
    model_info:
      mode: embedding # 👈 ADD THIS
```
@@ -461,7 +461,7 @@ We will use the `--config` to set `litellm.success_callback = ["langfuse"]` this
**Step 1** Install langfuse

```shell
pip install langfuse
pip install langfuse==1.14.0
```

**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
@@ -1,6 +1,17 @@
# Model Management
Add new models + Get model info without restarting proxy.

## In Config.yaml

```yaml
model_list:
  - model_name: text-davinci-003
    litellm_params:
      model: "text-completion-openai/text-davinci-003"
    model_info:
      metadata: "here's additional metadata on the model" # returned via GET /model/info
```

## Get Model Information

Retrieve detailed information about each model listed in the `/models` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled from the `model_info` you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes.
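A sketch of calling that endpoint from Python (not part of the diff; the proxy address and master key are illustrative):

```python
import requests

resp = requests.get(
    "http://0.0.0.0:8000/model/info",
    headers={"Authorization": "Bearer sk-1234"},  # hypothetical proxy master key
)
print(resp.json())
```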
@@ -349,6 +349,12 @@ litellm --config your_config.yaml
[**More Info**](./configs.md)

## Server Endpoints

:::note

You can see Swagger Docs for the server on root http://0.0.0.0:8000

:::

- POST `/chat/completions` - chat completions endpoint to call 100+ LLMs
- POST `/completions` - completions endpoint
- POST `/embeddings` - embedding endpoint for Azure, OpenAI, Huggingface endpoints
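A sketch of hitting the chat completions endpoint listed above (not part of the diff; the address, key, and model name are illustrative):

```python
import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},  # hypothetical proxy key
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "what llm are you"}],
    },
)
print(resp.json())
```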
@@ -39,14 +39,17 @@ litellm --config /path/to/config.yaml

```shell
curl 'http://0.0.0.0:8000/key/generate' \
--h 'Authorization: Bearer sk-1234' \
--d '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m"}'
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data-raw '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m","metadata": {"user": "ishaan@berri.ai", "team": "core-infra"}}'
```

- `models`: *list or null (optional)* - Specify the models a token has access to. If null, the token has access to all models on the server.

- `duration`: *str or null (optional)* Specify the length of time the token is valid for. If null, the default is 1 hour. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").

- `metadata`: *dict or null (optional)* Pass metadata for the created token. If null, defaults to {}

Expected response:

```python
@@ -6,6 +6,7 @@ from typing import Callable, Optional
from litellm import OpenAIConfig
import litellm, json
import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI

class AzureOpenAIError(Exception):

@@ -261,7 +262,10 @@ class AzureChatCompletion(BaseLLM):
            exception_mapping_worked = True
            raise e
        except Exception as e:
            raise AzureOpenAIError(status_code=500, message=str(e))
            if hasattr(e, "status_code"):
                raise e
            else:
                raise AzureOpenAIError(status_code=500, message=str(e))

    def streaming(self,
                  logging_obj,

@@ -463,13 +467,52 @@ class AzureChatCompletion(BaseLLM):
            import traceback
            raise AzureOpenAIError(status_code=500, message=traceback.format_exc())

    async def aimage_generation(
        self,
        data: dict,
        model_response: ModelResponse,
        azure_client_params: dict,
        api_key: str,
        input: list,
        client=None,
        logging_obj=None
    ):
        response = None
        try:
            if client is None:
                client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
                openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
            else:
                openai_aclient = client
            response = await openai_aclient.images.generate(**data)
            stringified_response = response.model_dump_json()
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=str(e),
            )
            raise e

    def image_generation(self,
        prompt: list,
        prompt: str,
        timeout: float,
        model: Optional[str]=None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        model_response: Optional[litellm.utils.ImageResponse] = None,
        azure_ad_token: Optional[str]=None,
        logging_obj=None,
        optional_params=None,
        client=None,

@@ -477,9 +520,12 @@ class AzureChatCompletion(BaseLLM):
    ):
        exception_mapping_worked = False
        try:
            model = model
            if model and len(model) > 0:
                model = model
            else:
                model = None
            data = {
                # "model": model,
                "model": model,
                "prompt": prompt,
                **optional_params
            }

@@ -487,12 +533,26 @@ class AzureChatCompletion(BaseLLM):
            if not isinstance(max_retries, int):
                raise AzureOpenAIError(status_code=422, message="max retries must be an int")

            # if aembedding == True:
            #     response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
            #     return response
            # init AzureOpenAI Client
            azure_client_params = {
                "api_version": api_version,
                "azure_endpoint": api_base,
                "azure_deployment": model,
                "max_retries": max_retries,
                "timeout": timeout
            }
            if api_key is not None:
                azure_client_params["api_key"] = api_key
            elif azure_ad_token is not None:
                azure_client_params["azure_ad_token"] = azure_ad_token

            if aimg_generation == True:
                response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
                return response

            if client is None:
                azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) # type: ignore
                client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
                azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
            else:
                azure_client = client
litellm/llms/custom_httpx/azure_dall_e_2.py (new file, 118 lines)

@@ -0,0 +1,118 @@
import time, json, httpx, asyncio

class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
    """
    Async implementation of custom http transport
    """
    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        if "images/generations" in request.url.path and request.url.params[
            "api-version"
        ] in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
            "2023-06-01-preview",
            "2023-07-01-preview",
            "2023-08-01-preview",
            "2023-09-01-preview",
            "2023-10-01-preview",
        ]:
            request.url = request.url.copy_with(path="/openai/images/generations:submit")
            response = await super().handle_async_request(request)
            operation_location_url = response.headers["operation-location"]
            request.url = httpx.URL(operation_location_url)
            request.method = "GET"
            response = await super().handle_async_request(request)
            await response.aread()

            timeout_secs: int = 120
            start_time = time.time()
            while response.json()["status"] not in ["succeeded", "failed"]:
                if time.time() - start_time > timeout_secs:
                    timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
                    return httpx.Response(
                        status_code=400,
                        headers=response.headers,
                        content=json.dumps(timeout).encode("utf-8"),
                        request=request,
                    )

                time.sleep(int(response.headers.get("retry-after")) or 10)
                response = await super().handle_async_request(request)
                await response.aread()

            if response.json()["status"] == "failed":
                error_data = response.json()
                return httpx.Response(
                    status_code=400,
                    headers=response.headers,
                    content=json.dumps(error_data).encode("utf-8"),
                    request=request,
                )

            result = response.json()["result"]
            return httpx.Response(
                status_code=200,
                headers=response.headers,
                content=json.dumps(result).encode("utf-8"),
                request=request,
            )
        return await super().handle_async_request(request)

class CustomHTTPTransport(httpx.HTTPTransport):
    """
    This class was written as a workaround to support dall-e-2 on openai > v1.x

    Refer to this issue for more: https://github.com/openai/openai-python/issues/692
    """
    def handle_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        if "images/generations" in request.url.path and request.url.params[
            "api-version"
        ] in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
            "2023-06-01-preview",
            "2023-07-01-preview",
            "2023-08-01-preview",
            "2023-09-01-preview",
            "2023-10-01-preview",
        ]:
            request.url = request.url.copy_with(path="/openai/images/generations:submit")
            response = super().handle_request(request)
            operation_location_url = response.headers["operation-location"]
            request.url = httpx.URL(operation_location_url)
            request.method = "GET"
            response = super().handle_request(request)
            response.read()

            timeout_secs: int = 120
            start_time = time.time()
            while response.json()["status"] not in ["succeeded", "failed"]:
                if time.time() - start_time > timeout_secs:
                    timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
                    return httpx.Response(
                        status_code=400,
                        headers=response.headers,
                        content=json.dumps(timeout).encode("utf-8"),
                        request=request,
                    )

                time.sleep(int(response.headers.get("retry-after")) or 10)
                response = super().handle_request(request)
                response.read()

            if response.json()["status"] == "failed":
                error_data = response.json()
                return httpx.Response(
                    status_code=400,
                    headers=response.headers,
                    content=json.dumps(error_data).encode("utf-8"),
                    request=request,
                )

            result = response.json()["result"]
            return httpx.Response(
                status_code=200,
                headers=response.headers,
                content=json.dumps(result).encode("utf-8"),
                request=request,
            )
        return super().handle_request(request)
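For orientation, a sketch of how this transport is used, based on the azure.py hunks above (not additional code from the diff; the endpoint, key, and api-version values are illustrative): the transport is installed on an httpx client, which is then handed to the Azure SDK client.

```python
import httpx
from openai import AzureOpenAI
from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport

# route dall-e-2 image generations through the submit-and-poll workaround
client_session = httpx.Client(transport=CustomHTTPTransport())
azure_client = AzureOpenAI(
    http_client=client_session,
    api_key="my-azure-key",                                  # illustrative placeholders
    api_version="2023-09-01-preview",
    azure_endpoint="https://my-resource.openai.azure.com",
)
```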
@@ -1,14 +1,10 @@
import requests, types, time
import json
import json, uuid
import traceback
from typing import Optional
import litellm
import httpx, aiohttp, asyncio
try:
    from async_generator import async_generator, yield_ # optional dependency
    async_generator_imported = True
except ImportError:
    async_generator_imported = False # this should not throw an error, it will impact the 'import litellm' statement
from .prompt_templates.factory import prompt_factory, custom_prompt

class OllamaError(Exception):
    def __init__(self, status_code, message):

@@ -106,9 +102,8 @@ class OllamaConfig():
            and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
            and v is not None}

# ollama implementation
def get_ollama_response_stream(
def get_ollama_response(
        api_base="http://localhost:11434",
        model="llama2",
        prompt="Why is the sky blue?",

@@ -129,6 +124,7 @@ def get_ollama_response_stream(
        if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
            optional_params[k] = v

    optional_params["stream"] = optional_params.get("stream", False)
    data = {
        "model": model,
        "prompt": prompt,

@@ -146,9 +142,41 @@ def get_ollama_response_stream(
        else:
            response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
        return response

    else:
    elif optional_params.get("stream", False):
        return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
    response = requests.post(
        url=f"{url}",
        json=data,
    )
    if response.status_code != 200:
        raise OllamaError(status_code=response.status_code, message=response.text)

    ## LOGGING
    logging_obj.post_call(
        input=prompt,
        api_key="",
        original_response=response.text,
        additional_args={
            "headers": None,
            "api_base": api_base,
        },
    )

    response_json = response.json()

    ## RESPONSE OBJECT
    model_response["choices"][0]["finish_reason"] = "stop"
    if optional_params.get("format", "") == "json":
        message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
        model_response["choices"][0]["message"] = message
    else:
        model_response["choices"][0]["message"]["content"] = response_json["response"]
    model_response["created"] = int(time.time())
    model_response["model"] = "ollama/" + model
    prompt_tokens = response_json["prompt_eval_count"] # type: ignore
    completion_tokens = response_json["eval_count"]
    model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
    return model_response

def ollama_completion_stream(url, data, logging_obj):
    with httpx.stream(

@@ -157,13 +185,15 @@ def ollama_completion_stream(url, data, logging_obj):
        method="POST",
        timeout=litellm.request_timeout
    ) as response:
        if response.status_code != 200:
            raise OllamaError(status_code=response.status_code, message=response.text)

        streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
        for transformed_chunk in streamwrapper:
            yield transformed_chunk

        try:
            if response.status_code != 200:
                raise OllamaError(status_code=response.status_code, message=response.text)

            streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
            for transformed_chunk in streamwrapper:
                yield transformed_chunk
        except Exception as e:
            raise e

async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
    try:

@@ -194,39 +224,31 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
                text = await resp.text()
                raise OllamaError(status_code=resp.status, message=text)

            completion_string = ""
            async for line in resp.content.iter_any():
                if line:
                    try:
                        json_chunk = line.decode("utf-8")
                        chunks = json_chunk.split("\n")
                        for chunk in chunks:
                            if chunk.strip() != "":
                                j = json.loads(chunk)
                                if "error" in j:
                                    completion_obj = {
                                        "role": "assistant",
                                        "content": "",
                                        "error": j
                                    }
                                    raise Exception(f"OllamError - {chunk}")
                                if "response" in j:
                                    completion_obj = {
                                        "role": "assistant",
                                        "content": j["response"],
                                    }
                                    completion_string = completion_string + completion_obj["content"]
                    except Exception as e:
                        traceback.print_exc()

            ## LOGGING
            logging_obj.post_call(
                input=data['prompt'],
                api_key="",
                original_response=resp.text,
                additional_args={
                    "headers": None,
                    "api_base": url,
                },
            )

            response_json = await resp.json()
            ## RESPONSE OBJECT
            model_response["choices"][0]["finish_reason"] = "stop"
            model_response["choices"][0]["message"]["content"] = completion_string
            if data.get("format", "") == "json":
                message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
                model_response["choices"][0]["message"] = message
            else:
                model_response["choices"][0]["message"]["content"] = response_json["response"]
            model_response["created"] = int(time.time())
            model_response["model"] = "ollama/" + data['model']
            prompt_tokens = len(encoding.encode(data['prompt'])) # type: ignore
            completion_tokens = len(encoding.encode(completion_string))
            prompt_tokens = response_json["prompt_eval_count"] # type: ignore
            completion_tokens = response_json["eval_count"]
            model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
            return model_response
        except Exception as e:
            traceback.print_exc()
            raise e
@@ -284,7 +284,7 @@ class OpenAIChatCompletion(BaseLLM):
                additional_args={"complete_input_dict": data},
            )
            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
        except Exception as e:
            raise e

    def streaming(self,

@@ -445,6 +445,43 @@ class OpenAIChatCompletion(BaseLLM):
            import traceback
            raise OpenAIError(status_code=500, message=traceback.format_exc())

    async def aimage_generation(
        self,
        prompt: str,
        data: dict,
        model_response: ModelResponse,
        timeout: float,
        api_key: Optional[str]=None,
        api_base: Optional[str]=None,
        client=None,
        max_retries=None,
        logging_obj=None
    ):
        response = None
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
            else:
                openai_aclient = client
            response = await openai_aclient.images.generate(**data) # type: ignore
            stringified_response = response.model_dump_json()
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                original_response=str(e),
            )
            raise e

    def image_generation(self,
        model: Optional[str],
        prompt: str,

@@ -631,24 +668,27 @@ class OpenAITextCompletion(BaseLLM):
        api_key: str,
        model: str):
        async with httpx.AsyncClient() as client:
            response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
            response_json = response.json()
            if response.status_code != 200:
                raise OpenAIError(status_code=response.status_code, message=response.text)

            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                original_response=response,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                },
            )
            try:
                response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
                response_json = response.json()
                if response.status_code != 200:
                    raise OpenAIError(status_code=response.status_code, message=response.text)

                ## LOGGING
                logging_obj.post_call(
                    input=prompt,
                    api_key=api_key,
                    original_response=response,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                    },
                )

            ## RESPONSE OBJECT
            return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
            except Exception as e:
                raise e

    def streaming(self,
        logging_obj,

@@ -687,9 +727,12 @@ class OpenAITextCompletion(BaseLLM):
            method="POST",
            timeout=litellm.request_timeout
        ) as response:
            if response.status_code != 200:
                raise OpenAIError(status_code=response.status_code, message=response.text)

            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
            async for transformed_chunk in streamwrapper:
                yield transformed_chunk
            try:
                if response.status_code != 200:
                    raise OpenAIError(status_code=response.status_code, message=response.text)

                streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
                async for transformed_chunk in streamwrapper:
                    yield transformed_chunk
            except Exception as e:
                raise e
@@ -348,7 +348,7 @@ def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/r

# Function call template
def function_call_prompt(messages: list, functions: list):
    function_prompt = "The following functions are available to you:"
    function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:"
    for function in functions:
        function_prompt += f"""\n{function}\n"""

@@ -425,6 +425,6 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
            return alpaca_pt(messages=messages)
        else:
            return hf_chat_template(original_model_name, messages)
    except:
    except Exception as e:
        return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
@@ -181,6 +181,7 @@ def completion(
        from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
        from vertexai.language_models import TextGenerationModel, CodeGenerationModel
        from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
        from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types

        vertexai.init(

@@ -193,6 +194,16 @@ def completion(
            if k not in optional_params:
                optional_params[k] = v

        ## Process safety settings into format expected by vertex AI
        safety_settings = None
        if "safety_settings" in optional_params:
            safety_settings = optional_params.pop("safety_settings")
            if not isinstance(safety_settings, list):
                raise ValueError("safety_settings must be a list")
            if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
                raise ValueError("safety_settings must be a list of dicts")
            safety_settings=[gapic_content_types.SafetySetting(x) for x in safety_settings]

        # vertexai does not use an API key, it looks for credentials.json in the environment

        prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)])

@@ -238,16 +249,16 @@ def completion(

            if "stream" in optional_params and optional_params["stream"] == True:
                stream = optional_params.pop("stream")
                request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
                request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                ## LOGGING
                logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
                model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
                model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream)
                optional_params["stream"] = True
                return model_response
            request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params})).text\n"
            request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
            ## LOGGING
            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
            response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params))
            response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings)
            completion_response = response_obj.text
            response_obj = response_obj._raw_response
        elif mode == "vision":

@@ -258,12 +269,13 @@ def completion(
            content = [prompt] + images
            if "stream" in optional_params and optional_params["stream"] == True:
                stream = optional_params.pop("stream")
                request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
                request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})

                model_response = llm_model.generate_content(
                    contents=content,
                    generation_config=GenerationConfig(**optional_params),
                    safety_settings=safety_settings,
                    stream=True
                )
                optional_params["stream"] = True

@@ -276,7 +288,8 @@ def completion(
            ## LLM Call
            response = llm_model.generate_content(
                contents=content,
                generation_config=GenerationConfig(**optional_params)
                generation_config=GenerationConfig(**optional_params),
                safety_settings=safety_settings,
            )
            completion_response = response.text
            response_obj = response._raw_response
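As context for the change above, a sketch of how `safety_settings` would be passed through `litellm.completion` for a Vertex AI model (not from the diff; the model name and the category/threshold values are illustrative). Per the new code, it must be a list of dicts, which are converted to `gapic_content_types.SafetySetting` objects before the Vertex call.

```python
import litellm

response = litellm.completion(
    model="vertex_ai/gemini-pro",
    messages=[{"role": "user", "content": "hello"}],
    safety_settings=[
        {
            "category": "HARM_CATEGORY_HARASSMENT",  # illustrative values
            "threshold": "BLOCK_ONLY_HIGH",
        }
    ],
)
```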
@@ -1329,23 +1329,11 @@
            optional_params["images"] = images

            ## LOGGING
            generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
            generator = ollama.get_ollama_response(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
            if acompletion is True or optional_params.get("stream", False) == True:
                return generator
            else:
                response_string = ""
                for chunk in generator:
                    response_string+=chunk['content']

                ## RESPONSE OBJECT
                model_response["choices"][0]["finish_reason"] = "stop"
                model_response["choices"][0]["message"]["content"] = response_string
                model_response["created"] = int(time.time())
                model_response["model"] = "ollama/" + model
                prompt_tokens = len(encoding.encode(prompt)) # type: ignore
                completion_tokens = len(encoding.encode(response_string))
                model_response["usage"] = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
                response = model_response

            response = generator
        elif (
            custom_llm_provider == "baseten"
            or litellm.api_base == "https://app.baseten.co"

@@ -2026,11 +2014,9 @@ async def atext_completion(*args, **kwargs):
        response = text_completion(*args, **kwargs)
    else:
        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)
        if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
            response = init_response
        elif asyncio.iscoroutine(init_response):
            response = await init_response
        response = await loop.run_in_executor(None, func_with_context)
        if asyncio.iscoroutine(response):
            response = await response
    else:
        # Call the synchronous function using run_in_executor
        response = await loop.run_in_executor(None, func_with_context)

@@ -2205,7 +2191,8 @@ def text_completion(
    if stream == True or kwargs.get("stream", False) == True:
        response = TextCompletionStreamWrapper(completion_stream=response, model=model)
        return response

    if kwargs.get("acompletion", False) == True:
        return response
    transformed_logprobs = None
    # only supported for TGI models
    try:

@@ -2243,6 +2230,49 @@ def moderation(input: str, api_key: Optional[str]=None):
    return response

##### Image Generation #######################
@client
async def aimage_generation(*args, **kwargs):
    """
    Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.

    Parameters:
    - `args` (tuple): Positional arguments to be passed to the `embedding` function.
    - `kwargs` (dict): Keyword arguments to be passed to the `embedding` function.

    Returns:
    - `response` (Any): The response returned by the `embedding` function.
    """
    loop = asyncio.get_event_loop()
    model = args[0] if len(args) > 0 else kwargs["model"]
    ### PASS ARGS TO Image Generation ###
    kwargs["aimg_generation"] = True
    custom_llm_provider = None
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(image_generation, *args, **kwargs)

        # Add the context to the function
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)

        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))

        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)
        if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
            response = init_response
        elif asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
        return response
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
        )

@client
def image_generation(prompt: str,
    model: Optional[str]=None,

@@ -2264,6 +2294,7 @@ def image_generation(prompt: str,

    Currently supports just Azure + OpenAI.
    """
    aimg_generation = kwargs.get("aimg_generation", False)
    litellm_call_id = kwargs.get("litellm_call_id", None)
    logger_fn = kwargs.get("logger_fn", None)
    proxy_server_request = kwargs.get('proxy_server_request', None)

@@ -2277,7 +2308,7 @@ def image_generation(prompt: str,
        model = "dall-e-2"
        custom_llm_provider = "openai" # default to dall-e-2 on openai
    openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"]
    litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
    litellm_params = ["metadata", "aimg_generation", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
    default_params = openai_params + litellm_params
    non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
    optional_params = get_optional_params_image_gen(n=n,

@@ -2320,10 +2351,9 @@ def image_generation(prompt: str,
            get_secret("AZURE_AD_TOKEN")
        )

        # model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
        pass
        model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimg_generation)
    elif custom_llm_provider == "openai":
        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimg_generation)

    return model_response
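A sketch of the new entry points above being exercised (not from the diff; the prompt text and API key are illustrative, and the provider defaults to dall-e-2 on OpenAI per the code):

```python
import os
import asyncio
import litellm

os.environ["OPENAI_API_KEY"] = ""  # illustrative placeholder

# synchronous call
image = litellm.image_generation(prompt="a sunny litellm mascot")
print(image)

# async variant added in this commit
async def main():
    return await litellm.aimage_generation(prompt="a sunny litellm mascot")

print(asyncio.run(main()))
```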
@@ -120,6 +120,8 @@ class GenerateKeyRequest(LiteLLMBase):
    spend: Optional[float] = 0
    user_id: Optional[str] = None
    max_parallel_requests: Optional[int] = None
    metadata: Optional[dict] = {}
    max_budget: Optional[float] = None

class UpdateKeyRequest(LiteLLMBase):
    key: str

@@ -130,21 +132,8 @@ class UpdateKeyRequest(LiteLLMBase):
    spend: Optional[float] = None
    user_id: Optional[str] = None
    max_parallel_requests: Optional[int] = None

class GenerateKeyResponse(LiteLLMBase):
    key: str
    expires: datetime
    user_id: str

class _DeleteKeyObject(LiteLLMBase):
    key: str

class DeleteKeyRequest(LiteLLMBase):
    keys: List[_DeleteKeyObject]

    metadata: Optional[dict] = {}
    max_budget: Optional[float] = None

class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
    """

@@ -158,6 +147,20 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
    user_id: Optional[str] = None
    max_parallel_requests: Optional[int] = None
    duration: str = "1h"
    metadata: dict = {}
    max_budget: Optional[float] = None

class GenerateKeyResponse(LiteLLMBase):
    key: str
    expires: datetime
    user_id: str

class _DeleteKeyObject(LiteLLMBase):
    key: str

class DeleteKeyRequest(LiteLLMBase):
    keys: List[_DeleteKeyObject]

class ConfigGeneralSettings(LiteLLMBase):
    """
@@ -96,7 +96,7 @@ async def _perform_health_check(model_list: list):


async def perform_health_check(model_list: list, model: Optional[str] = None):
async def perform_health_check(model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None):
    """
    Perform a health check on the system.

@@ -104,7 +104,10 @@ async def perform_health_check(model_list: list, model: Optional[str] = None):
    (bool): True if the health check passes, False otherwise.
    """
    if not model_list:
        return [], []
        if cli_model:
            model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}]
        else:
            return [], []

    if model is not None:
        model_list = [x for x in model_list if x["litellm_params"]["model"] == model]
litellm/proxy/hooks/max_budget_limiter.py (new file, 35 lines)

@@ -0,0 +1,35 @@
from typing import Optional
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException

class MaxBudgetLimiter(CustomLogger):
    # Class variables or attributes
    def __init__(self):
        pass

    def print_verbose(self, print_statement):
        if litellm.set_verbose is True:
            print(print_statement) # noqa

    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
        self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
        api_key = user_api_key_dict.api_key
        max_budget = user_api_key_dict.max_budget
        curr_spend = user_api_key_dict.spend

        if api_key is None:
            return

        if max_budget is None:
            return

        if curr_spend is None:
            return

        # CHECK IF REQUEST ALLOWED
        if curr_spend >= max_budget:
            raise HTTPException(status_code=429, detail="Max budget limit reached.")
@@ -47,7 +47,7 @@ litellm_settings:
  # setting callback class
  # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]

general_settings:
# general_settings:

environment_variables:
  # otel: True # OpenTelemetry Logger
@@ -110,7 +110,7 @@ import json
import logging
from typing import Union

app = FastAPI(docs_url="/", title="LiteLLM API")
app = FastAPI(docs_url="/", title="LiteLLM API", description="Proxy Server to call 100+ LLMs in the OpenAI format")
router = APIRouter()
origins = ["*"]
@@ -616,7 +616,16 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
router = litellm.Router(**router_params) # type:ignore
return router, model_list, general_settings

async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
async def generate_key_helper_fn(duration: Optional[str],
models: list,
aliases: dict,
config: dict,
spend: float,
max_budget: Optional[float]=None,
token: Optional[str]=None,
user_id: Optional[str]=None,
max_parallel_requests: Optional[int]=None,
metadata: Optional[dict] = {},):
global prisma_client

if prisma_client is None:
@@ -653,6 +662,7 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
aliases_json = json.dumps(aliases)
config_json = json.dumps(config)
metadata_json = json.dumps(metadata)
user_id = user_id or str(uuid.uuid4())
try:
# Create a new verification token (you may want to enhance this logic based on your needs)

@@ -664,7 +674,9 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
"config": config_json,
"spend": spend,
"user_id": user_id,
"max_parallel_requests": max_parallel_requests
"max_parallel_requests": max_parallel_requests,
"metadata": metadata_json,
"max_budget": max_budget
}
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
except Exception as e:
@@ -774,6 +786,7 @@ def initialize(
print(f"\033[1;34mLiteLLM: Test your local proxy with: \"litellm --test\" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n")
print(f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n")
print("\033[1;34mDocs: https://docs.litellm.ai/docs/simple_proxy\033[0m\n")
print(f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:8000 \033[0m\n")
# for streaming
def data_generator(response):
print_verbose("inside generator")
@@ -871,9 +884,9 @@ def model_list():
object="list",
)

@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"])
@router.post("/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"])
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"])
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
try:
@@ -1044,8 +1057,8 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
detail=error_msg
)

@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["embeddings"])
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["embeddings"])
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global proxy_logging_obj
try:
@@ -1124,6 +1137,72 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
detail=error_msg
)


@router.post("/v1/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"])
@router.post("/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"])
async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global proxy_logging_obj
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)

# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data) # use copy instead of deepcopy
}

if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id

data["model"] = (
general_settings.get("image_generation_model", None) # server default
or user_model # model name passed via cli args
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers)
else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []

### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
## ROUTE TO CORRECT ENDPOINT ##
if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.aimage_generation(**data)
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
response = await llm_router.aimage_generation(**data, specific_deployment = True)
elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
response = await llm_router.aimage_generation(**data) # ensure this goes the llm_router, router will do the correct alias mapping
else:
response = await litellm.aimage_generation(**data)
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL

return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
try:
status = e.status_code # type: ignore
except:
status = 500
raise HTTPException(
status_code=status,
detail=error_msg
)
#### KEY MANAGEMENT ####

@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
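For context, a hedged example of calling the new image-generation route through the OpenAI SDK (assumes the proxy is listening on http://0.0.0.0:8000 and sk-1234 is a valid proxy key; both values are placeholders):

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")
# routed by the proxy to llm_router.aimage_generation / litellm.aimage_generation as above
image = client.images.generate(model="dall-e-3", prompt="A cute baby sea otter", n=1, size="1024x1024")
print(image.data[0].url)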
@@ -1140,6 +1219,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml
- spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
- max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
- metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }

Returns:
- key: (str) The generated api key
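A hedged example of minting a key with the new fields (placeholder host and master key; max_budget is assumed to be accepted by GenerateKeyRequest per the _types.py change above):

import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},  # proxy master key (placeholder)
    json={
        "duration": "1h",
        "models": ["dall-e-3"],
        "max_parallel_requests": 5,
        "max_budget": 10.0,
        "metadata": {"team": "core-infra", "app": "app2"},
    },
)
print(resp.json())  # expected shape: {"key": "...", "expires": "...", "user_id": "..."}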
@@ -1165,7 +1245,6 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest):
raise Exception("Not connected to DB!")

non_default_values = {k: v for k, v in data_json.items() if v is not None}
print(f"non_default_values: {non_default_values}")
response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key})
return {"key": key, **non_default_values}
# update based on remaining passed in values
@@ -1514,9 +1593,18 @@ async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query
```
else, the health checks will be run on models when /health is called.
"""
global health_check_results, use_background_health_checks
global health_check_results, use_background_health_checks, user_model

if llm_model_list is None:
# if no router set, check if user set a model using litellm --model ollama/llama2
if user_model is not None:
healthy_endpoints, unhealthy_endpoints = await perform_health_check(model_list=[], cli_model=user_model)
return {
"healthy_endpoints": healthy_endpoints,
"unhealthy_endpoints": unhealthy_endpoints,
"healthy_count": len(healthy_endpoints),
"unhealthy_count": len(unhealthy_endpoints),
}
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"error": "Model list not initialized"},
@@ -17,4 +17,6 @@ model LiteLLM_VerificationToken {
config Json @default("{}")
user_id String?
max_parallel_requests Int?
metadata Json @default("{}")
max_budget Float?
}
@@ -1,9 +1,10 @@
from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib, asyncio, copy
import os, subprocess, hashlib, importlib, asyncio, copy, json
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger
def print_verbose(print_statement):
if litellm.set_verbose:

@@ -23,11 +24,13 @@ class ProxyLogging:
self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
self.max_budget_limiter = MaxBudgetLimiter()
pass

def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_budget_limiter)
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
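A rough sketch of how the budget limiter ends up in litellm.callbacks (assumes ProxyLogging.__init__ takes a user_api_key_cache argument, as the assignment above suggests):

import litellm
from litellm.caching import DualCache
from litellm.proxy.utils import ProxyLogging

proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())  # assumed constructor signature
proxy_logging_obj._init_litellm_callbacks()
# MaxParallelRequestsHandler and MaxBudgetLimiter should now be registered
print([type(cb).__name__ for cb in litellm.callbacks])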
@@ -147,6 +150,14 @@ class PrismaClient:

return hashed_token

def jsonify_object(self, data: dict) -> dict:
db_data = copy.deepcopy(data)

for k, v in db_data.items():
if isinstance(v, dict):
db_data[k] = json.dumps(v)
return db_data

@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
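The helper simply JSON-encodes nested dicts before they are written to the DB; a standalone re-statement of the same logic for illustration:

import copy, json

def jsonify_object(data: dict) -> dict:
    db_data = copy.deepcopy(data)
    for k, v in db_data.items():
        if isinstance(v, dict):
            db_data[k] = json.dumps(v)  # nested dicts become JSON strings for the DB
    return db_data

print(jsonify_object({"token": "sk-abc", "metadata": {"team": "core-infra"}, "spend": 0.0}))
# {'token': 'sk-abc', 'metadata': '{"team": "core-infra"}', 'spend': 0.0}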
@@ -193,9 +204,8 @@ class PrismaClient:
try:
token = data["token"]
hashed_token = self.hash_token(token=token)
db_data = copy.deepcopy(data)
db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token

new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,

@@ -228,7 +238,7 @@ class PrismaClient:
if token.startswith("sk-"):
token = self.hash_token(token=token)

db_data = copy.deepcopy(data)
db_data = self.jsonify_object(data=data)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={
@@ -7,7 +7,7 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan

import copy
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback, uuid

@@ -18,6 +18,7 @@ import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
import copy
class Router:
"""
@@ -84,11 +85,11 @@ class Router:

self.set_verbose = set_verbose
self.deployment_names: List = [] # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
if model_list:
model_list = copy.deepcopy(model_list)
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list
self.deployment_latency_map = {}
for m in model_list:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0

@@ -166,7 +167,7 @@ class Router:
self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n")

### COMPLETION + EMBEDDING FUNCTIONS
### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS

def completion(self,
model: str,
@@ -260,6 +261,94 @@ class Router:
self.fail_calls[model_name] +=1
raise e

def image_generation(self,
prompt: str,
model: str,
**kwargs):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._image_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)

return response
except Exception as e:
raise e

def _image_generation(self,
prompt: str,
model: str,
**kwargs):
try:
self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)

model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] +=1
response = litellm.image_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[model_name] +=1
return response
except Exception as e:
if model_name is not None:
self.fail_calls[model_name] +=1
raise e

async def aimage_generation(self,
prompt: str,
model: str,
**kwargs):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._aimage_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)

return response
except Exception as e:
raise e

async def _aimage_generation(self,
prompt: str,
model: str,
**kwargs):
try:
self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)

model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] +=1
response = await litellm.aimage_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[model_name] +=1
return response
except Exception as e:
if model_name is not None:
self.fail_calls[model_name] +=1
raise e

def text_completion(self,
model: str,
prompt: str,
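A short usage sketch of the new router image-generation path, mirroring the dall-e-3 entries added to the test config further below (the AZURE_SWEDEN_* env vars are the same assumption the tests make):

import os
from litellm import Router

model_list = [
    {"model_name": "dall-e-3", "litellm_params": {"model": "dall-e-3"}},  # OpenAI deployment
    {
        "model_name": "dall-e-3",
        "litellm_params": {
            "model": "azure/dall-e-3-test",
            "api_version": "2023-12-01-preview",
            "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
            "api_key": os.getenv("AZURE_SWEDEN_API_KEY"),
        },
    },
]

router = Router(model_list=model_list)
# picks a deployment from the "dall-e-3" group and forwards to litellm.image_generation
response = router.image_generation(model="dall-e-3", prompt="A cute baby sea otter")
print(response.data)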
@@ -436,7 +525,6 @@ class Router:

async def async_function_with_retries(self, *args, **kwargs):
self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
backoff_factor = 1
original_function = kwargs.pop("original_function")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
@@ -1009,14 +1097,16 @@ class Router:
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries
max_retries=max_retries,
http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) # type: ignore
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries
max_retries=max_retries,
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
)
# streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI(

@@ -1024,7 +1114,7 @@ class Router:
azure_endpoint=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
max_retries=max_retries,
)

model["stream_client"] = openai.AzureOpenAI(
|
@ -20,7 +20,7 @@ import tempfile
|
|||
litellm.num_retries = 3
|
||||
litellm.cache = None
|
||||
user_message = "Write a short poem about the sky"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
|
||||
|
||||
def load_vertex_ai_credentials():
|
||||
|
@ -91,7 +91,7 @@ def test_vertex_ai():
|
|||
load_vertex_ai_credentials()
|
||||
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||
litellm.set_verbose=False
|
||||
litellm.vertex_project = "hardy-device-386718"
|
||||
litellm.vertex_project = "reliablekeys"
|
||||
|
||||
test_models = random.sample(test_models, 1)
|
||||
test_models += litellm.vertex_language_models # always test gemini-pro
|
||||
|
@@ -113,7 +113,7 @@ def test_vertex_ai():
def test_vertex_ai_stream():
load_vertex_ai_credentials()
litellm.set_verbose=False
litellm.vertex_project = "hardy-device-386718"
litellm.vertex_project = "reliablekeys"
import random

test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
@@ -599,34 +599,34 @@ def test_completion_hf_model_no_provider():

# test_completion_hf_model_no_provider()

# def test_completion_openai_azure_with_functions():
# function1 = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
# },
# "required": ["location"],
# },
# }
# ]
# try:
# messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
# response = completion(
# model="azure/chatgpt-functioncalling", messages=messages, functions=function1
# )
# # Add any assertions here to check the response
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_completion_openai_azure_with_functions()
def test_completion_anyscale_with_functions():
function1 = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
]
try:
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
response = completion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, functions=function1
)
# Add any assertions here to check the response
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_completion_anyscale_with_functions()

def test_completion_azure_key_completion_arg():
# this tests if we can pass api_key to completion, when it's not in the env
@@ -727,7 +727,7 @@ def test_completion_azure():
except Exception as e:
pytest.fail(f"Error occurred: {e}")

# test_completion_azure()
test_completion_azure()

def test_azure_openai_ad_token():
# this tests if the azure ad token is set in the request header
@@ -78,4 +78,19 @@ model_list:
model: "bedrock/amazon.titan-embed-text-v1"
- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)"
litellm_params:
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
- model_name: dall-e-3
litellm_params:
model: dall-e-3
- model_name: dall-e-3
litellm_params:
model: "azure/dall-e-3-test"
api_version: "2023-12-01-preview"
api_base: "os.environ/AZURE_SWEDEN_API_BASE"
api_key: "os.environ/AZURE_SWEDEN_API_KEY"
- model_name: dall-e-2
litellm_params:
model: "azure/"
api_version: "2023-06-01-preview"
api_base: "os.environ/AZURE_API_BASE"
api_key: "os.environ/AZURE_API_KEY"
@@ -17,6 +17,14 @@ def test_prompt_formatting():
assert prompt == "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")

def test_prompt_formatting_custom_model():
try:
prompt = prompt_factory(model="ehartford/dolphin-2.5-mixtral-8x7b", messages=[{"role": "system", "content": "Be a good bot"}, {"role": "user", "content": "Hello world"}], custom_llm_provider="huggingface")
print(f"prompt: {prompt}")
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_prompt_formatting_custom_model()
# def logger_fn(user_model_dict):
# return
# print(f"user_model_dict: {user_model_dict}")
@@ -4,10 +4,11 @@
import sys, os
import traceback
from dotenv import load_dotenv

import logging
logging.basicConfig(level=logging.DEBUG)
load_dotenv()
import os

import asyncio
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
@@ -18,20 +19,31 @@ def test_image_generation_openai():
litellm.set_verbose = True
response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
print(f"response: {response}")
assert len(response.data) > 0

# test_image_generation_openai()

# def test_image_generation_azure():
# response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure")
# print(f"response: {response}")
def test_image_generation_azure():
response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview")
print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_azure()

# @pytest.mark.asyncio
# async def test_async_image_generation_openai():
# response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
# print(f"response: {response}")
def test_image_generation_azure_dall_e_3():
litellm.set_verbose = True
response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test", api_version="2023-12-01-preview", api_base=os.getenv("AZURE_SWEDEN_API_BASE"), api_key=os.getenv("AZURE_SWEDEN_API_KEY"))
print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_azure_dall_e_3()
@pytest.mark.asyncio
async def test_async_image_generation_openai():
response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
print(f"response: {response}")
assert len(response.data) > 0

# @pytest.mark.asyncio
# async def test_async_image_generation_azure():
# response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3")
# print(f"response: {response}")
# asyncio.run(test_async_image_generation_openai())

@pytest.mark.asyncio
async def test_async_image_generation_azure():
response = await litellm.aimage_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test")
print(f"response: {response}")
@@ -16,23 +16,61 @@
# user_message = "respond in 20 words. who are you?"
# messages = [{ "content": user_message,"role": "user"}]

# def test_ollama_streaming():
# try:
# litellm.set_verbose = False
# messages = [
# {"role": "user", "content": "What is the weather like in Boston?"}
# ]
# functions = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": ["location"]
# }
# }
# ]
# response = litellm.completion(model="ollama/mistral",
# messages=messages,
# functions=functions,
# stream=True)
# for chunk in response:
# print(f"CHUNK: {chunk}")
# except Exception as e:
# print(e)

# test_ollama_streaming()

# async def test_async_ollama_streaming():
# try:
# litellm.set_verbose = True
# litellm.set_verbose = False
# response = await litellm.acompletion(model="ollama/mistral-openorca",
# messages=[{"role": "user", "content": "Hey, how's it going?"}],
# stream=True)
# async for chunk in response:
# print(chunk)
# print(f"CHUNK: {chunk}")
# except Exception as e:
# print(e)

# asyncio.run(test_async_ollama_streaming())
# # asyncio.run(test_async_ollama_streaming())

# def test_completion_ollama():
# try:
# litellm.set_verbose = True
# response = completion(
# model="ollama/llama2",
# model="ollama/mistral",
# messages=[{"role": "user", "content": "Hey, how's it going?"}],
# max_tokens=200,
# request_timeout = 10,
@@ -44,7 +82,87 @@
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")

# test_completion_ollama()
# # test_completion_ollama()

# def test_completion_ollama_function_calling():
# try:
# litellm.set_verbose = True
# messages = [
# {"role": "user", "content": "What is the weather like in Boston?"}
# ]
# functions = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": ["location"]
# }
# }
# ]
# response = completion(
# model="ollama/mistral",
# messages=messages,
# functions=functions,
# max_tokens=200,
# request_timeout = 10,
# )
# for chunk in response:
# print(chunk)
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama_function_calling()

# async def async_test_completion_ollama_function_calling():
# try:
# litellm.set_verbose = True
# messages = [
# {"role": "user", "content": "What is the weather like in Boston?"}
# ]
# functions = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": ["location"]
# }
# }
# ]
# response = await litellm.acompletion(
# model="ollama/mistral",
# messages=messages,
# functions=functions,
# max_tokens=200,
# request_timeout = 10,
# )
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")

# # asyncio.run(async_test_completion_ollama_function_calling())


# def test_completion_ollama_with_api_base():
# try:
@@ -197,7 +315,7 @@
# )
# print("Response from ollama/llava")
# print(response)
# test_ollama_llava()
# # test_ollama_llava()


# # PROCESSED CHUNK PRE CHUNK CREATOR
@@ -99,8 +99,6 @@ def test_embedding(client):
def test_chat_completion(client):
try:
# Your test data

print("initialized proxy")
litellm.set_verbose=False
from litellm.proxy.utils import get_instance_fn
my_custom_logger = get_instance_fn(
@@ -68,6 +68,7 @@ def test_chat_completion_exception_azure(client):
# make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything")
openai_exception = openai_client._make_status_error_from_response(response=response)
print(openai_exception)
assert isinstance(openai_exception, openai.AuthenticationError)

except Exception as e:
@@ -101,7 +101,7 @@ def test_chat_completion_azure(client_no_auth):
# Run the test
# test_chat_completion_azure()

### EMBEDDING
def test_embedding(client_no_auth):
global headers
from litellm.proxy.proxy_server import user_custom_auth
|
|||
|
||||
# Run the test
|
||||
# test_embedding()
|
||||
#### IMAGE GENERATION
|
||||
|
||||
def test_img_gen(client_no_auth):
|
||||
global headers
|
||||
from litellm.proxy.proxy_server import user_custom_auth
|
||||
|
||||
try:
|
||||
test_data = {
|
||||
"model": "dall-e-3",
|
||||
"prompt": "A cute baby sea otter",
|
||||
"n": 1,
|
||||
"size": "1024x1024"
|
||||
}
|
||||
|
||||
response = client_no_auth.post("/v1/images/generations", json=test_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
print(len(result["data"][0]["url"]))
|
||||
assert len(result["data"][0]["url"]) > 10
|
||||
except Exception as e:
|
||||
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
||||
|
||||
#### ADDITIONAL
|
||||
# @pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
|
||||
def test_add_new_model(client_no_auth):
|
||||
global headers
|
||||
|
|
|
@@ -423,6 +423,94 @@ def test_function_calling_on_router():

# test_function_calling_on_router()

### IMAGE GENERATION
@pytest.mark.asyncio
async def test_aimg_gen_on_router():
litellm.set_verbose = True
try:
model_list = [
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "dall-e-3",
},
},
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "azure/dall-e-3-test",
"api_version": "2023-12-01-preview",
"api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
"api_key": os.getenv("AZURE_SWEDEN_API_KEY")
}
},
{
"model_name": "dall-e-2",
"litellm_params": {
"model": "azure/",
"api_version": "2023-06-01-preview",
"api_base": os.getenv("AZURE_API_BASE"),
"api_key": os.getenv("AZURE_API_KEY")
}
}
]
router = Router(model_list=model_list)
response = await router.aimage_generation(
model="dall-e-3",
prompt="A cute baby sea otter"
)
print(response)
assert len(response.data) > 0

response = await router.aimage_generation(
model="dall-e-2",
prompt="A cute baby sea otter"
)
print(response)
assert len(response.data) > 0

router.reset()
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")

# asyncio.run(test_aimg_gen_on_router())

def test_img_gen_on_router():
litellm.set_verbose = True
try:
model_list = [
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "dall-e-3",
},
},
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "azure/dall-e-3-test",
"api_version": "2023-12-01-preview",
"api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
"api_key": os.getenv("AZURE_SWEDEN_API_KEY")
}
}
]
router = Router(model_list=model_list)
response = router.image_generation(
model="dall-e-3",
prompt="A cute baby sea otter"
)
print(response)
assert len(response.data) > 0
router.reset()
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")

# test_img_gen_on_router()
###

def test_aembedding_on_router():
litellm.set_verbose = True
try:
|
|||
]
|
||||
)
|
||||
print(response)
|
||||
asyncio.run(test_mistral_on_router())
|
||||
# asyncio.run(test_mistral_on_router())
|
||||
|
||||
def test_openai_completion_on_router():
|
||||
# [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream
|
||||
|
|
|
@@ -169,17 +169,37 @@ def test_text_completion_stream():

# test_text_completion_stream()

async def test_text_completion_async_stream():
try:
response = await atext_completion(
model="text-completion-openai/text-davinci-003",
prompt="good morning",
stream=True,
max_tokens=10,
)
async for chunk in response:
print(f"chunk: {chunk}")
except Exception as e:
pytest.fail(f"GOT exception for HF In streaming{e}")
# async def test_text_completion_async_stream():
# try:
# response = await atext_completion(
# model="text-completion-openai/text-davinci-003",
# prompt="good morning",
# stream=True,
# max_tokens=10,
# )
# async for chunk in response:
# print(f"chunk: {chunk}")
# except Exception as e:
# pytest.fail(f"GOT exception for HF In streaming{e}")

asyncio.run(test_text_completion_async_stream())
# asyncio.run(test_text_completion_async_stream())

def test_async_text_completion():
litellm.set_verbose = True
print('test_async_text_completion')
async def test_get_response():
try:
response = await litellm.atext_completion(
model="gpt-3.5-turbo-instruct",
prompt="good morning",
stream=False,
max_tokens=10
)
print(f"response: {response}")
except litellm.Timeout as e:
print(e)
except Exception as e:
print(e)

asyncio.run(test_get_response())
test_async_text_completion()
@@ -1,13 +1,13 @@
{
"type": "service_account",
"project_id": "hardy-device-386718",
"project_id": "reliablekeys",
"private_key_id": "",
"private_key": "",
"client_email": "litellm-vertexai-ci-cd@hardy-device-386718.iam.gserviceaccount.com",
"client_id": "110281020501213430254",
"client_email": "73470430121-compute@developer.gserviceaccount.com",
"client_id": "108560959659377334173",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/litellm-vertexai-ci-cd%40hardy-device-386718.iam.gserviceaccount.com",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/73470430121-compute%40developer.gserviceaccount.com",
"universe_domain": "googleapis.com"
}
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.15.1"
version = "1.15.4"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"
@@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.commitizen]
version = "1.15.1"
version = "1.15.4"
version_files = [
"pyproject.toml:^version"
]
@@ -3,24 +3,24 @@ anyio==4.2.0 # openai + http req.
openai>=1.0.0 # openai req.
fastapi # server dep
pydantic>=2.5 # openai req.
appdirs # server dep
backoff # server dep
pyyaml # server dep
uvicorn # server dep
boto3 # aws bedrock/sagemaker calls
redis # caching
prisma # for db
mangum # for aws lambda functions
google-generativeai # for vertex ai calls
appdirs==1.4.4 # server dep
backoff==2.2.1 # server dep
pyyaml==6.0 # server dep
uvicorn==0.22.0 # server dep
boto3==1.28.58 # aws bedrock/sagemaker calls
redis==4.6.0 # caching
prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions
google-generativeai==0.1.0 # for vertex ai calls
traceloop-sdk==0.5.3 # for open telemetry logging
langfuse==1.14.0 # for langfuse self-hosted logging
### LITELLM PACKAGE DEPENDENCIES
python-dotenv>=0.2.0 # for env
tiktoken>=0.4.0 # for calculating usage
importlib-metadata>=6.8.0 # for random utils
tokenizers # for calculating usage
click # for proxy cli
tokenizers==0.14.0 # for calculating usage
click==8.1.7 # for proxy cli
jinja2==3.1.2 # for prompt templates
certifi>=2023.7.22 # [TODO] clean up
aiohttp # for network calls
aiohttp==3.8.4 # for network calls
####