Merge branch 'main' into public-fix-1

Krish Dholakia 2023-12-16 12:27:58 -08:00 committed by GitHub
commit 47ba8082df
109 changed files with 8257 additions and 3200 deletions


@ -79,6 +79,11 @@ jobs:
steps:
  - checkout
  - run:
      name: Copy model_prices_and_context_window File to model_prices_and_context_window_backup
      command: |
        cp model_prices_and_context_window.json litellm/model_prices_and_context_window_backup.json
  - run:
      name: Check if litellm dir was updated or if pyproject.toml was modified
      command: |

.gitignore

@ -19,3 +19,9 @@ litellm/proxy/_secret_config.yaml
litellm/tests/aiologs.log
litellm/tests/exception_data.txt
litellm/tests/config_*.yaml
litellm/tests/langfuse.log
litellm/tests/test_custom_logger.py
litellm/tests/langfuse.log
litellm/tests/dynamo*.log
.vscode/settings.json
litellm/proxy/log.txt


@ -1,8 +1,11 @@
# Base image
-ARG LITELLM_BASE_IMAGE=python:3.9-slim
-# allow users to specify, else use python 3.9-slim
-FROM $LITELLM_BASE_IMAGE
ARG LITELLM_BUILD_IMAGE=python:3.9
# Runtime image
ARG LITELLM_RUNTIME_IMAGE=python:3.9-slim
# allow users to specify, else use python 3.9
FROM $LITELLM_BUILD_IMAGE as builder
# Set the working directory to /app
WORKDIR /app
@ -16,7 +19,7 @@ RUN pip install --upgrade pip && \
    pip install build
# Copy the current directory contents into the container at /app
-COPY . /app
COPY requirements.txt .
# Build the package
RUN rm -rf dist/* && python -m build
@ -25,13 +28,27 @@ RUN rm -rf dist/* && python -m build
RUN pip install dist/*.whl
# Install any needed packages specified in requirements.txt
-RUN pip wheel --no-cache-dir --wheel-dir=wheels -r requirements.txt
-RUN pip install --no-cache-dir --find-links=wheels -r requirements.txt
RUN pip install wheel && \
    pip wheel --no-cache-dir --wheel-dir=/app/wheels -r requirements.txt
###############################################################################
FROM $LITELLM_RUNTIME_IMAGE as runtime
WORKDIR /app
# Copy the current directory contents into the container at /app
COPY . .
COPY --from=builder /app/wheels /app/wheels
RUN pip install --no-index --find-links=/app/wheels -r requirements.txt
# Trigger the Prisma CLI to be installed
RUN prisma -v
EXPOSE 4000/tcp
# Start the litellm proxy, using the `litellm` cli command https://docs.litellm.ai/docs/simple_proxy
# Start the litellm proxy with default options
CMD ["--port", "4000"]


@ -62,6 +62,22 @@ response = completion(model="command-nightly", messages=messages)
print(response)
```
## Async ([Docs](https://docs.litellm.ai/docs/completion/stream#async-completion))
```python
from litellm import acompletion
import asyncio
async def test_get_response():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
response = await acompletion(model="gpt-3.5-turbo", messages=messages)
return response
response = asyncio.run(test_get_response())
print(response)
```
## Streaming ([Docs](https://docs.litellm.ai/docs/completion/stream))
liteLLM supports streaming the model response back, pass `stream=True` to get a streaming iterator in response.
Streaming is supported for all models (Bedrock, Huggingface, TogetherAI, Azure, OpenAI, etc.)
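For example, a minimal streaming sketch (the model and prompt are illustrative):
```python
from litellm import completion

# pass stream=True to get the response back chunk by chunk
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    stream=True,
)
for chunk in response:
    print(chunk)
```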
@ -140,6 +156,7 @@ response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content
| [openrouter](https://docs.litellm.ai/docs/providers/openrouter) | ✅ | ✅ | ✅ | ✅ |
| [google - vertex_ai](https://docs.litellm.ai/docs/providers/vertex) | ✅ | ✅ | ✅ | ✅ |
| [google - palm](https://docs.litellm.ai/docs/providers/palm) | ✅ | ✅ | ✅ | ✅ |
| [mistral ai api](https://docs.litellm.ai/docs/providers/mistral) | ✅ | ✅ | ✅ | ✅ |
| [ai21](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | ✅ |
| [baseten](https://docs.litellm.ai/docs/providers/baseten) | ✅ | ✅ | ✅ | ✅ |
| [vllm](https://docs.litellm.ai/docs/providers/vllm) | ✅ | ✅ | ✅ | ✅ |

Binary dist files added (not shown): dist/litellm-1.12.5.dev1.tar.gz, dist/litellm-1.12.6.dev1.tar.gz, dist/litellm-1.12.6.dev2.tar.gz, dist/litellm-1.12.6.dev3.tar.gz, dist/litellm-1.12.6.dev4.tar.gz, dist/litellm-1.12.6.dev5.tar.gz, dist/litellm-1.14.0.dev1.tar.gz, dist/litellm-1.14.5.dev1.tar.gz, plus additional unnamed binary files.

@ -55,27 +55,76 @@ litellm.cache = cache # set litellm.cache to your cache
```
-### Detecting Cached Responses
-For responses that were returned as cache hit, the response includes a param `cache` = True
-:::info
-Only valid for OpenAI <= 0.28.1 [Let us know if you still need this](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=bug&projects=&template=bug_report.yml&title=%5BBug%5D%3A+)
-:::
-Example response with cache hit
-```python
-{
-    'cache': True,
-    'id': 'chatcmpl-7wggdzd6OXhgE2YhcLJHJNZsEWzZ2',
-    'created': 1694221467,
-    'model': 'gpt-3.5-turbo-0613',
-    'choices': [
-        {'index': 0, 'message': {'role': 'assistant', 'content': 'I\'m sorry, but I couldn\'t find any information about "litellm" or how many stars it has. It is possible that you may be referring to a specific product, service, or platform that I am not familiar with. Can you please provide more context or clarify your question?'}, 'finish_reason': 'stop'}
-    ],
-    'usage': {'prompt_tokens': 17, 'completion_tokens': 59, 'total_tokens': 76},
-}
-```
## Cache Initialization Parameters
#### `type` (str, optional)
The type of cache to initialize. It can be either "local" or "redis". Defaults to "local".
#### `host` (str, optional)
The host address for the Redis cache. This parameter is required if the `type` is set to "redis".
#### `port` (int, optional)
The port number for the Redis cache. This parameter is required if the `type` is set to "redis".
#### `password` (str, optional)
The password for the Redis cache. This parameter is required if the `type` is set to "redis".
#### `supported_call_types` (list, optional)
A list of call types to cache for. Defaults to caching for all call types. The available call types are:
- "completion"
- "acompletion"
- "embedding"
- "aembedding"
#### `**kwargs` (additional keyword arguments)
Additional keyword arguments are accepted for the initialization of the Redis cache using the `redis.Redis()` constructor. These arguments allow you to fine-tune the Redis cache configuration based on your specific needs.
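Putting these parameters together, a minimal initialization sketch (the host/port/password values are placeholders read from the environment):
```python
import os
import litellm
from litellm import Cache

# Redis-backed cache; any extra kwargs are forwarded to redis.Redis()
litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
    supported_call_types=["completion", "acompletion"],  # only cache these call types
)
```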
## Logging
Cache hits are logged in success events as `kwargs["cache_hit"]`.
Here's an example of accessing it:
```python
import os
import time
import asyncio
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm import completion, acompletion, Cache

# create custom callback for success_events
class MyCustomHandler(CustomLogger):
    def __init__(self):
        super().__init__()
        self.states = []  # track which callbacks fired, so they can be inspected below

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("async_success")
        print(f"On Success")
        print(f"Value of Cache hit: {kwargs['cache_hit']}")

async def test_async_completion_azure_caching():
    # set custom callback
    customHandler_caching = MyCustomHandler()
    litellm.callbacks = [customHandler_caching]
    # init cache
    litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
    unique_time = time.time()
    response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
                            messages=[{
                                "role": "user",
                                "content": f"Hi 👋 - i'm async azure {unique_time}"
                            }],
                            caching=True)
    await asyncio.sleep(1)
    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
    response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
                            messages=[{
                                "role": "user",
                                "content": f"Hi 👋 - i'm async azure {unique_time}"
                            }],
                            caching=True)
    await asyncio.sleep(1) # success callbacks are done in parallel

asyncio.run(test_async_completion_azure_caching())
```


@ -4,7 +4,9 @@
You can create a custom callback class to precisely log events as they occur in litellm.
```python
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm import completion, acompletion

class MyCustomHandler(CustomLogger):
    def log_pre_api_call(self, model, messages, kwargs):
@ -22,13 +24,37 @@ class MyCustomHandler(CustomLogger):
    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Failure")
    #### ASYNC #### - for acompletion/aembeddings

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Streaming")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Success")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Failure")
customHandler = MyCustomHandler()
litellm.callbacks = [customHandler]
## sync
response = completion(model="gpt-3.5-turbo", messages=[{ "role": "user", "content": "Hi 👋 - i'm openai"}],
                      stream=True)
for chunk in response:
    continue
## async
import asyncio

async def completion():
    response = await acompletion(model="gpt-3.5-turbo", messages=[{ "role": "user", "content": "Hi 👋 - i'm openai"}],
                                 stream=True)
    async for chunk in response:
        continue

asyncio.run(completion())
```
## Callback Functions
@ -87,6 +113,41 @@ print(response)
## Async Callback Functions
We recommend using the Custom Logger class for async.
```python
import asyncio
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm import acompletion

class MyCustomHandler(CustomLogger):
    #### ASYNC ####

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Streaming")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Success")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Failure")

customHandler = MyCustomHandler()
litellm.callbacks = [customHandler]

async def completion():
    response = await acompletion(model="gpt-3.5-turbo", messages=[{ "role": "user", "content": "Hi 👋 - i'm openai"}],
                                 stream=True)
    async for chunk in response:
        continue

asyncio.run(completion())
```
**Functions**
Use this if you just want to pass in an async function for logging.
LiteLLM currently supports just async success callback functions for async completion/embedding calls.
```python
@ -117,9 +178,6 @@ asyncio.run(test_chat_openai())
:::info
We're actively trying to expand this to other event types. [Tell us if you need this!](https://github.com/BerriAI/litellm/issues/1007)
:::
## What's in kwargs?
@ -170,6 +228,48 @@ Here's exactly what you can expect in the kwargs dictionary:
"end_time" = end_time # datetime object of when call was completed "end_time" = end_time # datetime object of when call was completed
``` ```
### Cache hits
Cache hits are logged in success events as `kwargs["cache_hit"]`.
Here's an example of accessing it:
```python
import os
import time
import asyncio
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm import completion, acompletion, Cache

class MyCustomHandler(CustomLogger):
    def __init__(self):
        super().__init__()
        self.errors = []
        self.states = []  # note: the assert below expects pre/post call hooks to append here as well

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("async_success")
        print(f"On Success")
        print(f"Value of Cache hit: {kwargs['cache_hit']}")

async def test_async_completion_azure_caching():
    customHandler_caching = MyCustomHandler()
    litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
    litellm.callbacks = [customHandler_caching]
    unique_time = time.time()
    response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
                            messages=[{
                                "role": "user",
                                "content": f"Hi 👋 - i'm async azure {unique_time}"
                            }],
                            caching=True)
    await asyncio.sleep(1)
    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
    response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
                            messages=[{
                                "role": "user",
                                "content": f"Hi 👋 - i'm async azure {unique_time}"
                            }],
                            caching=True)
    await asyncio.sleep(1) # success callbacks are done in parallel
    print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
    assert len(customHandler_caching.errors) == 0
    assert len(customHandler_caching.states) == 4 # pre, post, success, success
```
### Get complete streaming response
LiteLLM will pass you the complete streaming response in the final streaming chunk as part of the kwargs for your custom callback function.


@ -27,8 +27,8 @@ To get better visualizations on how your code behaves, you may want to annotate
## Exporting traces to other systems (e.g. Datadog, New Relic, and others)
-Since Traceloop SDK uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [Traceloop docs on exporters](https://traceloop.com/docs/python-sdk/exporters) for more information.
Since OpenLLMetry uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [OpenLLMetry docs on exporters](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for more information.
## Support
-For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://join.slack.com/t/traceloopcommunity/shared_invite/zt-1plpfpm6r-zOHKI028VkpcWdobX65C~g) or via [email](mailto:dev@traceloop.com).
For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).


@ -0,0 +1,21 @@
**A private and secure ChatGPT alternative that knows your business.**
Upload docs, ask questions --> get answers.
Leverage GenAI with your confidential documents to increase efficiency and collaboration.
OSS core, everything can run in your environment. An extensible platform you can build your GenAI strategy on. Supports a variety of popular LLMs, including embedded models for air-gapped use cases.
[![Static Badge][docs-shield]][docs-url]
[![Static Badge][github-shield]][github-url]
[![X (formerly Twitter) Follow][twitter-shield]][twitter-url]
<!-- MARKDOWN LINKS & IMAGES -->
<!-- https://www.markdownguide.org/basic-syntax/#reference-style-links -->
[docs-shield]: https://img.shields.io/badge/docs-site-black?logo=materialformkdocs
[docs-url]: https://docqai.github.io/docq/
[github-shield]: https://img.shields.io/badge/Github-repo-black?logo=github
[github-url]: https://github.com/docqai/docq/
[twitter-shield]: https://img.shields.io/twitter/follow/docqai?logo=x&style=flat
[twitter-url]: https://twitter.com/docqai


@ -0,0 +1,56 @@
# Mistral AI API
https://docs.mistral.ai/api/
## API Key
```python
# env variable
os.environ['MISTRAL_API_KEY']
```
## Sample Usage
```python
from litellm import completion
import os
os.environ['MISTRAL_API_KEY'] = ""
response = completion(
model="mistral/mistral-tiny"",
messages=[
{"role": "user", "content": "hello from litellm"}
],
)
print(response)
```
## Sample Usage - Streaming
```python
from litellm import completion
import os
os.environ['MISTRAL_API_KEY'] = ""
response = completion(
model="mistral/mistral-tiny",
messages=[
{"role": "user", "content": "hello from litellm"}
],
stream=True
)
for chunk in response:
print(chunk)
```
## Supported Models
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/c1b25538277206b9f00de5254d80d6a83bb19a29/model_prices_and_context_window.json).
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| mistral-tiny | `completion(model="mistral/mistral-tiny", messages)` |
| mistral-small | `completion(model="mistral/mistral-small", messages)` |
| mistral-medium | `completion(model="mistral/mistral-medium", messages)` |


@ -0,0 +1,47 @@
# OpenAI-Compatible Endpoints
To call models hosted behind an openai proxy, make 2 changes:
1. Put `openai/` in front of your model name, so litellm knows you're trying to call an openai-compatible endpoint.
2. **Do NOT** add anything additional to the base url e.g. `/v1/embedding`. LiteLLM uses the openai-client to make these calls, and that automatically adds the relevant endpoints.
## Usage
```python
import litellm
from litellm import embedding
litellm.set_verbose = True
import os
litellm_proxy_endpoint = "http://0.0.0.0:8000"
bearer_token = "sk-1234"
CHOSEN_LITE_LLM_EMBEDDING_MODEL = "openai/GPT-J 6B - Sagemaker Text Embedding (Internal)"
litellm.set_verbose = False
print(litellm_proxy_endpoint)
response = embedding(
model = CHOSEN_LITE_LLM_EMBEDDING_MODEL, # add `openai/` prefix to model so litellm knows to route to OpenAI
api_key=bearer_token,
api_base=litellm_proxy_endpoint, # set API Base of your Custom OpenAI Endpoint
input=["good morning from litellm"],
api_version='2023-07-01-preview'
)
print('================================================')
print(len(response.data[0]['embedding']))
```
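The same two rules apply to chat completions against an OpenAI-compatible endpoint. A minimal sketch (the model name, key, and base URL below are placeholders):
```python
from litellm import completion

response = completion(
    model="openai/my-custom-model",        # `openai/` prefix so litellm routes via the OpenAI client
    api_key="sk-1234",                     # key expected by your endpoint
    api_base="http://0.0.0.0:8000",        # base URL only - litellm adds the relevant route
    messages=[{"role": "user", "content": "good morning from litellm"}],
)
print(response)
```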


@ -1,4 +1,4 @@
-# VertexAI - Google
# VertexAI - Google [Gemini]
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_VertextAI_Example.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@ -10,6 +10,16 @@
* run `gcloud auth application-default login` See [Google Cloud Docs](https://cloud.google.com/docs/authentication/external/set-up-adc)
* Alternatively you can set `application_default_credentials.json`
## Sample Usage
```python
import litellm
from litellm import completion
litellm.vertex_project = "hardy-device-38811" # Your Project ID
litellm.vertex_location = "us-central1" # proj location
response = completion(model="gemini-pro", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
```
## Set Vertex Project & Vertex Location
All calls using Vertex AI require the following parameters:
* Your Project ID
@ -37,13 +47,50 @@ os.environ["VERTEXAI_LOCATION"] = "us-central1 # Your Location
litellm.vertex_location = "us-central1" # Your Location
```
-## Sample Usage
## Gemini Pro
| Model Name | Function Call |
|------------------|--------------------------------------|
| gemini-pro | `completion('gemini-pro', messages)` |
## Gemini Pro Vision
| Model Name | Function Call |
|------------------|--------------------------------------|
| gemini-pro-vision | `completion('gemini-pro-vision', messages)` |
#### Using Gemini Pro Vision
Call `gemini-pro-vision` in the same input/output format as OpenAI [`gpt-4-vision`](https://docs.litellm.ai/docs/providers/openai#openai-vision-models)
LiteLLM Supports the following image types passed in `url`
- Images with Cloud Storage URIs - gs://cloud-samples-data/generative-ai/image/boats.jpeg
- Images with direct links - https://storage.googleapis.com/github-repo/img/gemini/intro/landmark3.jpg
- Videos with Cloud Storage URIs - https://storage.googleapis.com/github-repo/img/gemini/multimodality_usecases_overview/pixel8.mp4
**Example Request**
```python
import litellm
-litellm.vertex_project = "hardy-device-38811" # Your Project ID
-litellm.vertex_location = "us-central1" # proj location
-response = completion(model="chat-bison", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
response = litellm.completion(
model = "vertex_ai/gemini-pro-vision",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
}
]
}
],
)
print(response)
```
## Chat Models


@ -1,20 +1,24 @@
# Caching
Cache LLM Responses
## Quick Start
Caching can be enabled by adding the `cache` key in the `config.yaml`
-#### Step 1: Add `cache` to the config.yaml
### Step 1: Add `cache` to the config.yaml
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
  - model_name: text-embedding-ada-002
    litellm_params:
      model: text-embedding-ada-002
litellm_settings:
  set_verbose: True
  cache: True          # set cache responses to True, litellm defaults to using a redis cache
```
-#### Step 2: Add Redis Credentials to .env
### Step 2: Add Redis Credentials to .env
Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching.
```shell
@ -32,12 +36,12 @@ REDIS_<redis-kwarg-name> = ""
```
[**See how it's read from the environment**](https://github.com/BerriAI/litellm/blob/4d7ff1b33b9991dcf38d821266290631d9bcd2dd/litellm/_redis.py#L40)
-#### Step 3: Run proxy with config
### Step 3: Run proxy with config
```shell
$ litellm --config /path/to/config.yaml
```
-#### Using Caching
## Using Caching - /chat/completions
Send the same request twice:
```shell
curl http://0.0.0.0:8000/v1/chat/completions \
@ -57,9 +61,51 @@ curl http://0.0.0.0:8000/v1/chat/completions \
}'
```
-#### Control caching per completion request
## Using Caching - /embeddings
Send the same request twice:
```shell
curl --location 'http://0.0.0.0:8000/embeddings' \
--header 'Content-Type: application/json' \
--data ' {
"model": "text-embedding-ada-002",
"input": ["write a litellm poem"]
}'
curl --location 'http://0.0.0.0:8000/embeddings' \
--header 'Content-Type: application/json' \
--data ' {
"model": "text-embedding-ada-002",
"input": ["write a litellm poem"]
}'
```
## Advanced
### Set Cache Params on config.yaml
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
- model_name: text-embedding-ada-002
litellm_params:
model: text-embedding-ada-002
litellm_settings:
set_verbose: True
cache: True # set cache responses to True, litellm defaults to using a redis cache
cache_params: # cache_params are optional
type: "redis" # The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
host: "localhost" # The host address for the Redis cache. Required if type is "redis".
port: 6379 # The port number for the Redis cache. Required if type is "redis".
password: "your_password" # The password for the Redis cache. Required if type is "redis".
# Optional configurations
supported_call_types: ["acompletion", "completion", "embedding", "aembedding"] # defaults to all litellm call types
```
### Override caching per `chat/completions` request
Caching can be switched on/off per `/chat/completions` request
-- Caching **on** for completion - pass `caching=True`:
- Caching **on** for individual completion - pass `caching=True`:
```shell
curl http://0.0.0.0:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
@ -70,7 +116,7 @@ Caching can be switched on/off per `/chat/completions` request
"caching": true "caching": true
}' }'
``` ```
-- Caching **off** for completion - pass `caching=False`:
- Caching **off** for individual completion - pass `caching=False`:
```shell
curl http://0.0.0.0:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
@ -81,3 +127,28 @@ Caching can be switched on/off per `/chat/completions` request
"caching": false "caching": false
}' }'
``` ```
### Override caching per `/embeddings` request
Caching can be switched on/off per `/embeddings` request
- Caching **on** for embedding - pass `caching=True`:
```shell
curl --location 'http://0.0.0.0:8000/embeddings' \
--header 'Content-Type: application/json' \
--data ' {
"model": "text-embedding-ada-002",
"input": ["write a litellm poem"],
"caching": true
}'
```
- Caching **off** for embedding - pass `caching=False`:
```shell
curl --location 'http://0.0.0.0:8000/embeddings' \
--header 'Content-Type: application/json' \
--data ' {
"model": "text-embedding-ada-002",
"input": ["write a litellm poem"],
"caching": false
}'
```


@ -0,0 +1,78 @@
# Modify Incoming Data
Modify data just before making litellm completion calls on the proxy
See a complete example with our [parallel request rate limiter](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/parallel_request_limiter.py)
## Quick Start
1. In your Custom Handler add a new `async_pre_call_hook` function
This function is called just before a litellm completion call is made, and allows you to modify the data going into the litellm call [**See Code**](https://github.com/BerriAI/litellm/blob/589a6ca863000ba8e92c897ba0f776796e7a5904/litellm/proxy/proxy_server.py#L1000)
```python
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache  # types used in the hook signature below
from typing import Literal
import litellm
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes
def __init__(self):
pass
#### ASYNC ####
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
pass
async def async_log_pre_api_call(self, model, messages, kwargs):
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
#### CALL HOOKS - proxy only ####
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
data["model"] = "my-new-model"
return data
proxy_handler_instance = MyCustomHandler()
```
2. Add this file to your proxy config
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
```
3. Start the server + test the request
```shell
$ litellm --config /path/to/config.yaml
```
```shell
curl --location 'http://0.0.0.0:8000/chat/completions' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "good morning good sir"
}
],
"user": "ishaan-app",
"temperature": 0.2
}'
```


@ -1,4 +1,11 @@
-# Deploying LiteLLM Proxy
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 🐳 Docker, Deploying LiteLLM Proxy
## Dockerfile
You can find the Dockerfile to build litellm proxy [here](https://github.com/BerriAI/litellm/blob/main/Dockerfile)
## Quick Start Docker Image: Github Container Registry
@ -7,12 +14,12 @@ See the latest available ghcr docker image here:
https://github.com/berriai/litellm/pkgs/container/litellm
```shell
-docker pull ghcr.io/berriai/litellm:main-v1.10.1
docker pull ghcr.io/berriai/litellm:main-v1.12.3
```
### Run the Docker Image
```shell
-docker run ghcr.io/berriai/litellm:main-v1.10.0
docker run ghcr.io/berriai/litellm:main-v1.12.3
```
#### Run the Docker Image with LiteLLM CLI args
@ -21,12 +28,12 @@ See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli):
Here's how you can run the docker image and pass your config to `litellm`
```shell
-docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml
docker run ghcr.io/berriai/litellm:main-v1.12.3 --config your_config.yaml
```
Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8`
```shell
-docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8
docker run ghcr.io/berriai/litellm:main-v1.12.3 --port 8002 --num_workers 8
```
#### Run the Docker Image using docker compose
@ -42,6 +49,10 @@ Here's an example `docker-compose.yml` file
version: "3.9" version: "3.9"
services: services:
litellm: litellm:
build:
context: .
args:
target: runtime
image: ghcr.io/berriai/litellm:main image: ghcr.io/berriai/litellm:main
ports: ports:
- "8000:8000" # Map the container port to the host, change the host port if necessary - "8000:8000" # Map the container port to the host, change the host port if necessary
@ -74,6 +85,26 @@ Your LiteLLM container should be running now on the defined port e.g. `8000`.
<iframe width="840" height="500" src="https://www.loom.com/embed/805964b3c8384b41be180a61442389a3" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
## Deploy on Google Cloud Run
**Click the button** to deploy to Google Cloud Run
[![Deploy](https://deploy.cloud.run/button.svg)](https://deploy.cloud.run/?git_repo=https://github.com/BerriAI/litellm)
#### Testing your deployed proxy
**Assuming the required keys are set as Environment Variables**
https://litellm-7yjrj3ha2q-uc.a.run.app is our example proxy, substitute it with your deployed cloud run app
```shell
curl https://litellm-7yjrj3ha2q-uc.a.run.app/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Say this is a test!"}],
"temperature": 0.7
}'
```
## LiteLLM Proxy Performance
LiteLLM proxy has been load tested to handle 1500 req/s.


@ -0,0 +1,244 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Embeddings - `/embeddings`
See supported Embedding Providers & Models [here](https://docs.litellm.ai/docs/embedding/supported_embedding)
## Quick start
Here's how to route between GPT-J embedding (sagemaker endpoint), Amazon Titan embedding (Bedrock) and Azure OpenAI embedding on the proxy server:
1. Set models in your config.yaml
```yaml
model_list:
- model_name: sagemaker-embeddings
litellm_params:
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
- model_name: amazon-embeddings
litellm_params:
model: "bedrock/amazon.titan-embed-text-v1"
- model_name: azure-embeddings
litellm_params:
model: "azure/azure-embedding-model"
api_base: "os.environ/AZURE_API_BASE" # os.getenv("AZURE_API_BASE")
api_key: "os.environ/AZURE_API_KEY" # os.getenv("AZURE_API_KEY")
api_version: "2023-07-01-preview"
general_settings:
master_key: sk-1234 # [OPTIONAL] if set all calls to proxy will require either this key or a valid generated token
```
2. Start the proxy
```shell
$ litellm --config /path/to/config.yaml
```
3. Test the embedding call
```shell
curl --location 'http://0.0.0.0:8000/v1/embeddings' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"input": "The food was delicious and the waiter..",
"model": "sagemaker-embeddings",
}'
```
## `/embeddings` Request Format
Input, Output and Exceptions are mapped to the OpenAI format for all supported models
<Tabs>
<TabItem value="Curl" label="Curl Request">
```shell
curl --location 'http://0.0.0.0:8000/embeddings' \
--header 'Content-Type: application/json' \
--data ' {
"model": "text-embedding-ada-002",
"input": ["write a litellm poem"]
}'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
from openai import OpenAI
# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:8000")
response = client.embeddings.create(
input=["hello from litellm"],
model="text-embedding-ada-002"
)
print(response)
```
</TabItem>
<TabItem value="langchain-embedding" label="Langchain Embeddings">
```python
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(model="sagemaker-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
text = "This is a test document."
query_result = embeddings.embed_query(text)
print(f"SAGEMAKER EMBEDDINGS")
print(query_result[:5])
embeddings = OpenAIEmbeddings(model="bedrock-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
text = "This is a test document."
query_result = embeddings.embed_query(text)
print(f"BEDROCK EMBEDDINGS")
print(query_result[:5])
embeddings = OpenAIEmbeddings(model="bedrock-titan-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
text = "This is a test document."
query_result = embeddings.embed_query(text)
print(f"TITAN EMBEDDINGS")
print(query_result[:5])
```
</TabItem>
</Tabs>
## `/embeddings` Response Format
```json
{
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [
0.0023064255,
-0.009327292,
....
-0.0028842222,
],
"index": 0
}
],
"model": "text-embedding-ada-002",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
}
}
```
## Supported Models
See supported Embedding Providers & Models [here](https://docs.litellm.ai/docs/embedding/supported_embedding)
#### Create Config.yaml
<Tabs>
<TabItem value="Hugging Face emb" label="Hugging Face Embeddings">
LiteLLM Proxy supports all <a href="https://huggingface.co/models?pipeline_tag=feature-extraction">Feature-Extraction Embedding models</a>.
```yaml
model_list:
- model_name: deployed-codebert-base
litellm_params:
# send request to deployed hugging face inference endpoint
model: huggingface/microsoft/codebert-base # add huggingface prefix so it routes to hugging face
api_key: hf_LdS # api key for hugging face inference endpoint
api_base: https://uysneno1wv2wd4lw.us-east-1.aws.endpoints.huggingface.cloud # your hf inference endpoint
- model_name: codebert-base
litellm_params:
# no api_base set, sends request to hugging face free inference api https://api-inference.huggingface.co/models/
model: huggingface/microsoft/codebert-base # add huggingface prefix so it routes to hugging face
api_key: hf_LdS # api key for hugging face
```
</TabItem>
<TabItem value="azure" label="Azure OpenAI Embeddings">
```yaml
model_list:
- model_name: azure-embedding-model # model group
litellm_params:
model: azure/azure-embedding-model # model name for litellm.embedding(model=azure/azure-embedding-model) call
api_base: your-azure-api-base
api_key: your-api-key
api_version: 2023-07-01-preview
```
</TabItem>
<TabItem value="openai" label="OpenAI Embeddings">
```yaml
model_list:
- model_name: text-embedding-ada-002 # model group
litellm_params:
model: text-embedding-ada-002 # model name for litellm.embedding(model=text-embedding-ada-002)
api_key: your-api-key-1
- model_name: text-embedding-ada-002
litellm_params:
model: text-embedding-ada-002
api_key: your-api-key-2
```
</TabItem>
<TabItem value="openai emb" label="OpenAI Compatible Embeddings">
<p>Use this for calling <a href="https://github.com/xorbitsai/inference">/embedding endpoints on OpenAI Compatible Servers</a>.</p>
**Note add `openai/` prefix to `litellm_params`: `model` so litellm knows to route to OpenAI**
```yaml
model_list:
- model_name: text-embedding-ada-002 # model group
litellm_params:
model: openai/<your-model-name> # model name for litellm.embedding(model=text-embedding-ada-002)
api_base: <model-api-base>
```
</TabItem>
</Tabs>
#### Start Proxy
```shell
litellm --config config.yaml
```
#### Make Request
Sends Request to `deployed-codebert-base`
```shell
curl --location 'http://0.0.0.0:8000/embeddings' \
--header 'Content-Type: application/json' \
--data ' {
"model": "deployed-codebert-base",
"input": ["write a litellm poem"]
}'
```


@ -0,0 +1,62 @@
# Health Checks
Use this to health check all LLMs defined in your config.yaml
## Summary
The proxy exposes:
* a /health endpoint which returns the health of the LLM APIs
* a /test endpoint which makes a ping to the litellm server
#### Request
Make a GET Request to `/health` on the proxy
```shell
curl --location 'http://0.0.0.0:8000/health'
```
You can also run `litellm --health`, which makes a `GET` request to `http://0.0.0.0:8000/health` for you
```
litellm --health
```
#### Response
```shell
{
"healthy_endpoints": [
{
"model": "azure/gpt-35-turbo",
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com/"
},
{
"model": "azure/gpt-35-turbo",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/"
}
],
"unhealthy_endpoints": [
{
"model": "azure/gpt-35-turbo",
"api_base": "https://openai-france-1234.openai.azure.com/"
}
]
}
```
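If you're scripting against the proxy, a quick sketch of calling `/health` from Python (host and port match the defaults used above):
```python
import requests

# same endpoint as the curl example above
resp = requests.get("http://0.0.0.0:8000/health")
health = resp.json()
print("healthy:", health["healthy_endpoints"])
print("unhealthy:", health["unhealthy_endpoints"])
```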
## Background Health Checks
You can enable model health checks being run in the background, to prevent each model from being queried too frequently via `/health`.
Here's how to use it:
1. in the config.yaml add:
```
general_settings:
background_health_checks: True # enable background health checks
health_check_interval: 300 # frequency of background health checks
```
2. Start server
```
$ litellm --config /path/to/config.yaml
```
3. Query health endpoint:
```
curl --location 'http://0.0.0.0:8000/health'
```


@ -72,128 +72,28 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
'
```
## Router settings on config - routing_strategy, model_group_alias
litellm.Router() settings can be set under `router_settings`. You can set `model_group_alias`, `routing_strategy`, `num_retries`, `timeout`. See all Router supported params [here](https://github.com/BerriAI/litellm/blob/1b942568897a48f014fa44618ec3ce54d7570a46/litellm/router.py#L64)
Example config with `router_settings` is shown further below.
-## Fallbacks + Cooldowns + Retries + Timeouts
-If a call fails after num_retries, fall back to another model group.
If the error is a context window exceeded error, fall back to a larger model group (if given).
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
**Set via config**
```yaml
model_list:
- model_name: zephyr-beta
litellm_params:
model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: http://0.0.0.0:8001
- model_name: zephyr-beta
litellm_params:
model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: http://0.0.0.0:8002
- model_name: zephyr-beta
litellm_params:
model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: http://0.0.0.0:8003
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
api_key: <my-openai-key>
- model_name: gpt-3.5-turbo-16k
litellm_params:
model: gpt-3.5-turbo-16k
api_key: <my-openai-key>
litellm_settings:
num_retries: 3 # retry call 3 times on each model_name (e.g. zephyr-beta)
request_timeout: 10 # raise Timeout error if call takes longer than 10s. Sets litellm.request_timeout
fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo"]}] # fallback to gpt-3.5-turbo if call fails num_retries
context_window_fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}] # fallback to gpt-3.5-turbo-16k if context window error
allowed_fails: 3 # cooldown model if it fails > 1 call in a minute.
```
**Set dynamically**
```bash
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "zephyr-beta",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
"fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
"context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
"num_retries": 2,
"timeout": 10
}
'
```
## Custom Timeouts, Stream Timeouts - Per Model
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
-      model: azure/gpt-turbo-small-eu
-      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
-      api_key: <your-key>
-      timeout: 0.1 # timeout in (seconds)
-      stream_timeout: 0.01 # timeout for stream requests (seconds)
-      max_retries: 5
      model: azure/<your-deployment-name>
      api_base: <your-azure-endpoint>
      api_key: <your-azure-api-key>
      rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-turbo-small-ca
      api_base: https://my-endpoint-canada-berri992.openai.azure.com/
-      api_key:
-      timeout: 0.1 # timeout in (seconds)
-      stream_timeout: 0.01 # timeout for stream requests (seconds)
-      max_retries: 5
-```
      api_key: <your-azure-api-key>
      rpm: 6
router_settings:
  model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
  routing_strategy: least-busy # Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"]
  num_retries: 2
  timeout: 30 # 30 seconds
#### Start Proxy
```shell
$ litellm --config /path/to/config.yaml
```
## Health Check LLMs on Proxy
Use this to health check all LLMs defined in your config.yaml
#### Request
Make a GET Request to `/health` on the proxy
```shell
curl --location 'http://0.0.0.0:8000/health'
```
You can also run `litellm -health` it makes a `get` request to `http://0.0.0.0:8000/health` for you
```
litellm --health
```
#### Response
```shell
{
"healthy_endpoints": [
{
"model": "azure/gpt-35-turbo",
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com/"
},
{
"model": "azure/gpt-35-turbo",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/"
}
],
"unhealthy_endpoints": [
{
"model": "azure/gpt-35-turbo",
"api_base": "https://openai-france-1234.openai.azure.com/"
}
]
}
```
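These `router_settings` map onto litellm.Router() arguments. A rough Python sketch of the equivalent (the deployment entry and values are illustrative, not the exact proxy internals):
```python
import os
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
]

router = Router(
    model_list=model_list,
    model_group_alias={"gpt-4": "gpt-3.5-turbo"},  # requests for `gpt-4` routed to the gpt-3.5-turbo group
    routing_strategy="least-busy",
    num_retries=2,
    timeout=30,  # seconds
)
```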


@ -1,5 +1,8 @@
-# Logging - Custom Callbacks, OpenTelemetry, Langfuse
-Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry
import Image from '@theme/IdealImage';

# Logging - Custom Callbacks, OpenTelemetry, Langfuse, Sentry
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, DynamoDB
## Custom Callback Class [Async]
Use this when you want to run custom callbacks in `python`
@ -486,3 +489,166 @@ litellm --test
Expected output on Langfuse
<Image img={require('../../img/langfuse_small.png')} />
## Logging Proxy Input/Output - DynamoDB
We will use the `--config` to set
- `litellm.success_callback = ["dynamodb"]`
- `litellm.dynamodb_table_name = "your-table-name"`
This will log all successful LLM calls to DynamoDB
**Step 1** Set AWS Credentials in .env
```shell
AWS_ACCESS_KEY_ID = ""
AWS_SECRET_ACCESS_KEY = ""
AWS_REGION_NAME = ""
```
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
success_callback: ["dynamodb"]
dynamodb_table_name: your-table-name
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --debug
```
Test Request
```shell
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "Azure OpenAI GPT-4 East",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}'
```
Your logs should be available on DynamoDB
#### Data Logged to DynamoDB /chat/completions
```json
{
"id": {
"S": "chatcmpl-8W15J4480a3fAQ1yQaMgtsKJAicen"
},
"call_type": {
"S": "acompletion"
},
"endTime": {
"S": "2023-12-15 17:25:58.424118"
},
"messages": {
"S": "[{'role': 'user', 'content': 'This is a test'}]"
},
"metadata": {
"S": "{}"
},
"model": {
"S": "gpt-3.5-turbo"
},
"modelParameters": {
"S": "{'temperature': 0.7, 'max_tokens': 100, 'user': 'ishaan-2'}"
},
"response": {
"S": "ModelResponse(id='chatcmpl-8W15J4480a3fAQ1yQaMgtsKJAicen', choices=[Choices(finish_reason='stop', index=0, message=Message(content='Great! What can I assist you with?', role='assistant'))], created=1702641357, model='gpt-3.5-turbo-0613', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=9, prompt_tokens=11, total_tokens=20))"
},
"startTime": {
"S": "2023-12-15 17:25:56.047035"
},
"usage": {
"S": "Usage(completion_tokens=9, prompt_tokens=11, total_tokens=20)"
},
"user": {
"S": "ishaan-2"
}
}
```
#### Data logged to DynamoDB /embeddings
```json
{
"id": {
"S": "4dec8d4d-4817-472d-9fc6-c7a6153eb2ca"
},
"call_type": {
"S": "aembedding"
},
"endTime": {
"S": "2023-12-15 17:25:59.890261"
},
"messages": {
"S": "['hi']"
},
"metadata": {
"S": "{}"
},
"model": {
"S": "text-embedding-ada-002"
},
"modelParameters": {
"S": "{'user': 'ishaan-2'}"
},
"response": {
"S": "EmbeddingResponse(model='text-embedding-ada-002-v2', data=[{'embedding': [-0.03503197431564331, -0.020601635798811913, -0.015375726856291294,
}
}
```
## Logging Proxy Input/Output - Sentry
If api calls fail (llm/database) you can log those to Sentry:
**Step 1** Install Sentry
```shell
pip install --upgrade sentry-sdk
```
**Step 2**: Save your Sentry_DSN and add `litellm_settings`: `failure_callback`
```shell
export SENTRY_DSN="your-sentry-dsn"
```
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
# other settings
failure_callback: ["sentry"]
general_settings:
database_url: "my-bad-url" # set a fake url to trigger a sentry exception
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --debug
```
Test Request
```
litellm --test
```


@ -0,0 +1,89 @@
# Fallbacks, Retries, Timeouts, Cooldowns
If a call fails after num_retries, fall back to another model group.
If the error is a context window exceeded error, fall back to a larger model group (if given).
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
**Set via config**
```yaml
model_list:
- model_name: zephyr-beta
litellm_params:
model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: http://0.0.0.0:8001
- model_name: zephyr-beta
litellm_params:
model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: http://0.0.0.0:8002
- model_name: zephyr-beta
litellm_params:
model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: http://0.0.0.0:8003
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
api_key: <my-openai-key>
- model_name: gpt-3.5-turbo-16k
litellm_params:
model: gpt-3.5-turbo-16k
api_key: <my-openai-key>
litellm_settings:
num_retries: 3 # retry call 3 times on each model_name (e.g. zephyr-beta)
request_timeout: 10 # raise Timeout error if call takes longer than 10s. Sets litellm.request_timeout
fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo"]}] # fallback to gpt-3.5-turbo if call fails num_retries
context_window_fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}] # fallback to gpt-3.5-turbo-16k if context window error
allowed_fails: 3 # cooldown model if it fails > 3 calls in a minute.
```
**Set dynamically**
```bash
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "zephyr-beta",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
"fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
"context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
"num_retries": 2,
"timeout": 10
}
'
```
## Custom Timeouts, Stream Timeouts - Per Model
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-eu
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key: <your-key>
timeout: 0.1 # timeout in (seconds)
stream_timeout: 0.01 # timeout for stream requests (seconds)
max_retries: 5
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key:
timeout: 0.1 # timeout in (seconds)
stream_timeout: 0.01 # timeout for stream requests (seconds)
max_retries: 5
```
#### Start Proxy
```shell
$ litellm --config /path/to/config.yaml
```


@ -1,13 +1,13 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
-# [OLD PROXY 👉 [**NEW** proxy here](./simple_proxy.md)] Local OpenAI Proxy Server
# [OLD PROXY 👉 [**NEW** proxy here](./simple_proxy)] Local OpenAI Proxy Server
A fast and lightweight OpenAI-compatible server to call 100+ LLM APIs.
:::info
-Docs outdated. New docs 👉 [here](./simple_proxy.md)
Docs outdated. New docs 👉 [here](./simple_proxy)
:::


@ -366,6 +366,63 @@ router = Router(model_list: Optional[list] = None,
cache_responses=True)
```
## Caching across model groups
If you want to cache across 2 different model groups (e.g. azure deployments, and openai), use caching groups.
```python
import litellm, asyncio, time, os, traceback
from litellm import Router
# set os env
os.environ["OPENAI_API_KEY"] = ""
os.environ["AZURE_API_KEY"] = ""
os.environ["AZURE_API_BASE"] = ""
os.environ["AZURE_API_VERSION"] = ""
async def test_acompletion_caching_on_router_caching_groups():
# tests acompletion + caching on router
try:
litellm.set_verbose = True
model_list = [
{
"model_name": "openai-gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "azure-gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION")
},
}
]
messages = [
{"role": "user", "content": f"write a one sentence poem {time.time()}?"}
]
start_time = time.time()
router = Router(model_list=model_list,
cache_responses=True,
caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")])
response1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1)
print(f"response1: {response1}")
await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
response2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1)
assert response1.id == response2.id
assert len(response1.choices[0].message.content) > 0
assert response1.choices[0].message.content == response2.choices[0].message.content
except Exception as e:
traceback.print_exc()
asyncio.run(test_acompletion_caching_on_router_caching_groups())
```
#### Default litellm.completion/embedding params
You can also set default params for litellm completion/embedding calls. Here's how to do that:
@ -391,200 +448,3 @@ print(f"response: {response}")
## Deploy Router
If you want a server to load balance across different LLM APIs, use our [OpenAI Proxy Server](./simple_proxy#load-balancing---multiple-instances-of-1-model)
## Queuing (Beta)
**Never fail a request due to rate limits**
The LiteLLM Queuing endpoints can handle 100+ req/s. We use Celery workers to process requests.
:::info
This is pretty new, and might have bugs. Any contributions to improving our implementation are welcome
:::
[**See Code**](https://github.com/BerriAI/litellm/blob/fbf9cab5b9e35df524e2c9953180c58d92e4cd97/litellm/proxy/proxy_server.py#L589)
### Quick Start
1. Add Redis credentials in a .env file
```python
REDIS_HOST="my-redis-endpoint"
REDIS_PORT="my-redis-port"
REDIS_PASSWORD="my-redis-password" # [OPTIONAL] if self-hosted
REDIS_USERNAME="default" # [OPTIONAL] if self-hosted
```
2. Start litellm server with your model config
```bash
$ litellm --config /path/to/config.yaml --use_queue
```
Here's an example config for `gpt-3.5-turbo`
**config.yaml** (This will load balance between OpenAI + Azure endpoints)
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
api_key:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2 # actual model name
api_key:
api_version: 2023-07-01-preview
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
```
3. Test (in another window) → sends 100 simultaneous requests to the queue
```bash
$ litellm --test_async --num_requests 100
```
### Available Endpoints
- `/queue/request` - Queues a /chat/completions request. Returns a job id.
- `/queue/response/{id}` - Returns the status of a job. If completed, returns the response as well. Possible statuses are: `queued` and `finished` (a usage sketch follows below).
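For illustration, here is a rough sketch of hitting these two endpoints from Python against a locally running `--use_queue` proxy; the base URL and the `url`/`result` response fields are assumptions, mirrored from the hosted example below.
```python
import time
import requests

base_url = "http://0.0.0.0:8000"  # assumption: address of your local `litellm --use_queue` proxy

# queue a /chat/completions request; the proxy returns a job id and a polling url
job = requests.post(
    f"{base_url}/queue/request",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    },
).json()

# poll the job until it reports `finished`
while True:
    status = requests.get(f"{base_url}{job['url']}").json()
    if status.get("status") == "finished":
        print(status["result"])
        break
    time.sleep(0.5)
```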
## Hosted Request Queuing api.litellm.ai
Queue your LLM API requests to ensure you stay under your rate limits
- Step 1: Add a config to the proxy, generate a temp key
- Step 2: Queue a request to the proxy, using your generated_key
- Step 3: Poll the request
### Step 1: Add a config to the proxy, generate a temp key
```python
import requests
import time
import os
# Set the base URL as needed
base_url = "https://api.litellm.ai"
# Step 1 Add a config to the proxy, generate a temp key
# use the same model_name to load balance
config = {
"model_list": [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.environ['OPENAI_API_KEY'],
}
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": "",
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_version": "2023-07-01-preview"
}
}
]
}
response = requests.post(
url=f"{base_url}/key/generate",
json={
"config": config,
"duration": "30d" # default to 30d, set it to 30m if you want a temp 30 minute key
},
headers={
"Authorization": "Bearer sk-hosted-litellm" # this is the key to use api.litellm.ai
}
)
print("\nresponse from generating key", response.text)
print("\n json response from gen key", response.json())
generated_key = response.json()["key"]
print("\ngenerated key for proxy", generated_key)
```
#### Output
```shell
response from generating key {"key":"sk-...,"expires":"2023-12-22T03:43:57.615000+00:00"}
```
### Step 2: Queue a request to the proxy, using your generated_key
```python
print("Creating a job on the proxy")
job_response = requests.post(
url=f"{base_url}/queue/request",
json={
'model': 'gpt-3.5-turbo',
'messages': [
{'role': 'system', 'content': f'You are a helpful assistant. What is your name'},
],
},
headers={
"Authorization": f"Bearer {generated_key}"
}
)
print(job_response.status_code)
print(job_response.text)
print("\nResponse from creating job", job_response.text)
job_response = job_response.json()
job_id = job_response["id"]
polling_url = job_response["url"]
polling_url = f"{base_url}{polling_url}"
print("\nCreated Job, Polling Url", polling_url)
```
#### Output
```shell
Response from creating job
{"id":"0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7","url":"/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7","eta":5,"status":"queued"}
```
### Step 3: Poll the request
```python
while True:
try:
print("\nPolling URL", polling_url)
polling_response = requests.get(
url=polling_url,
headers={
"Authorization": f"Bearer {generated_key}"
}
)
print("\nResponse from polling url", polling_response.text)
polling_response = polling_response.json()
status = polling_response.get("status", None)
if status == "finished":
llm_response = polling_response["result"]
print("LLM Response")
print(llm_response)
break
time.sleep(0.5)
except Exception as e:
print("got exception in polling", e)
break
```
#### Output
```shell
Polling URL https://api.litellm.ai/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7
Response from polling url {"status":"queued"}
Polling URL https://api.litellm.ai/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7
Response from polling url {"status":"queued"}
Polling URL https://api.litellm.ai/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7
Response from polling url
{"status":"finished","result":{"id":"chatcmpl-8NYRce4IeI4NzYyodT3NNp8fk5cSW","choices":[{"finish_reason":"stop","index":0,"message":{"content":"I am an AI assistant and do not have a physical presence or personal identity. You can simply refer to me as \"Assistant.\" How may I assist you today?","role":"assistant"}}],"created":1700624639,"model":"gpt-3.5-turbo-0613","object":"chat.completion","system_fingerprint":null,"usage":{"completion_tokens":33,"prompt_tokens":17,"total_tokens":50}}}
```

View file

@ -61,11 +61,13 @@ const sidebars = {
}, },
items: [ items: [
"providers/openai", "providers/openai",
"providers/openai_compatible",
"providers/azure", "providers/azure",
"providers/huggingface", "providers/huggingface",
"providers/ollama", "providers/ollama",
"providers/vertex", "providers/vertex",
"providers/palm", "providers/palm",
"providers/mistral",
"providers/anthropic", "providers/anthropic",
"providers/aws_sagemaker", "providers/aws_sagemaker",
"providers/bedrock", "providers/bedrock",
@ -97,9 +99,13 @@ const sidebars = {
items: [ items: [
"proxy/quick_start", "proxy/quick_start",
"proxy/configs", "proxy/configs",
"proxy/embedding",
"proxy/load_balancing", "proxy/load_balancing",
"proxy/virtual_keys", "proxy/virtual_keys",
"proxy/model_management", "proxy/model_management",
"proxy/reliability",
"proxy/health",
"proxy/call_hooks",
"proxy/caching", "proxy/caching",
"proxy/logging", "proxy/logging",
"proxy/cli", "proxy/cli",
@ -189,6 +195,7 @@ const sidebars = {
slug: '/project', slug: '/project',
}, },
items: [ items: [
"projects/Docq.AI",
"projects/OpenInterpreter", "projects/OpenInterpreter",
"projects/FastREPL", "projects/FastREPL",
"projects/PROMPTMETHEUS", "projects/PROMPTMETHEUS",

File diff suppressed because it is too large Load diff

View file

@ -10,7 +10,7 @@ success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = [] callbacks: List[Callable] = []
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. _async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. _async_success_callback: List[Union[str, Callable]] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. _async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = [] pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = [] post_call_rules: List[Callable] = []
@ -48,6 +48,8 @@ cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.
model_alias_map: Dict[str, str] = {} model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {} model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers max_budget: float = 0.0 # set the max budget across all providers
_openai_completion_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
_litellm_completion_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
_current_cost = 0 # private variable, used if max budget is set _current_cost = 0 # private variable, used if max budget is set
error_logs: Dict = {} error_logs: Dict = {}
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
@ -56,6 +58,7 @@ aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
suppress_debug_info = False suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
#### RELIABILITY #### #### RELIABILITY ####
request_timeout: Optional[float] = 6000 request_timeout: Optional[float] = 6000
num_retries: Optional[int] = None num_retries: Optional[int] = None
@ -107,6 +110,8 @@ open_ai_text_completion_models: List = []
cohere_models: List = [] cohere_models: List = []
anthropic_models: List = [] anthropic_models: List = []
openrouter_models: List = [] openrouter_models: List = []
vertex_language_models: List = []
vertex_vision_models: List = []
vertex_chat_models: List = [] vertex_chat_models: List = []
vertex_code_chat_models: List = [] vertex_code_chat_models: List = []
vertex_text_models: List = [] vertex_text_models: List = []
@ -133,6 +138,10 @@ for key, value in model_cost.items():
vertex_text_models.append(key) vertex_text_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-code-text-models': elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
vertex_code_text_models.append(key) vertex_code_text_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-language-models':
vertex_language_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-vision-models':
vertex_vision_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-chat-models': elif value.get('litellm_provider') == 'vertex_ai-chat-models':
vertex_chat_models.append(key) vertex_chat_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-code-chat-models': elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
@ -154,7 +163,16 @@ for key, value in model_cost.items():
openai_compatible_endpoints: List = [ openai_compatible_endpoints: List = [
"api.perplexity.ai", "api.perplexity.ai",
"api.endpoints.anyscale.com/v1", "api.endpoints.anyscale.com/v1",
"api.deepinfra.com/v1/openai" "api.deepinfra.com/v1/openai",
"api.mistral.ai/v1"
]
# this is maintained for Exception Mapping
openai_compatible_providers: List = [
"anyscale",
"mistral",
"deepinfra",
"perplexity"
] ]
@ -266,6 +284,7 @@ model_list = (
provider_list: List = [ provider_list: List = [
"openai", "openai",
"custom_openai", "custom_openai",
"text-completion-openai",
"cohere", "cohere",
"anthropic", "anthropic",
"replicate", "replicate",
@ -287,6 +306,7 @@ provider_list: List = [
"deepinfra", "deepinfra",
"perplexity", "perplexity",
"anyscale", "anyscale",
"mistral",
"maritalk", "maritalk",
"custom", # custom apis "custom", # custom apis
] ]
@ -396,6 +416,7 @@ from .exceptions import (
AuthenticationError, AuthenticationError,
InvalidRequestError, InvalidRequestError,
BadRequestError, BadRequestError,
NotFoundError,
RateLimitError, RateLimitError,
ServiceUnavailableError, ServiceUnavailableError,
OpenAIError, OpenAIError,
@ -404,7 +425,8 @@ from .exceptions import (
APIError, APIError,
Timeout, Timeout,
APIConnectionError, APIConnectionError,
APIResponseValidationError APIResponseValidationError,
UnprocessableEntityError
) )
from .budget_manager import BudgetManager from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server from .proxy.proxy_cli import run_server

View file

@ -10,19 +10,7 @@
import litellm import litellm
import time, logging import time, logging
import json, traceback, ast import json, traceback, ast
from typing import Optional from typing import Optional, Literal, List
def get_prompt(*args, **kwargs):
# make this safe checks, it should not throw any exceptions
if len(args) > 1:
messages = args[1]
prompt = " ".join(message["content"] for message in messages)
return prompt
if "messages" in kwargs:
messages = kwargs["messages"]
prompt = " ".join(message["content"] for message in messages)
return prompt
return None
def print_verbose(print_statement): def print_verbose(print_statement):
try: try:
@ -174,34 +162,36 @@ class DualCache(BaseCache):
if self.redis_cache is not None: if self.redis_cache is not None:
self.redis_cache.flush_cache() self.redis_cache.flush_cache()
#### LiteLLM.Completion Cache #### #### LiteLLM.Completion / Embedding Cache ####
class Cache: class Cache:
def __init__( def __init__(
self, self,
type="local", type: Optional[Literal["local", "redis"]] = "local",
host=None, host: Optional[str] = None,
port=None, port: Optional[str] = None,
password=None, password: Optional[str] = None,
supported_call_types: Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"],
**kwargs **kwargs
): ):
""" """
Initializes the cache based on the given type. Initializes the cache based on the given type.
Args: Args:
type (str, optional): The type of cache to initialize. Defaults to "local". type (str, optional): The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
host (str, optional): The host address for the Redis cache. Required if type is "redis". host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis". port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis". password (str, optional): The password for the Redis cache. Required if type is "redis".
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
**kwargs: Additional keyword arguments for redis.Redis() cache **kwargs: Additional keyword arguments for redis.Redis() cache
Raises: Raises:
ValueError: If an invalid cache type is provided. ValueError: If an invalid cache type is provided.
Returns: Returns:
None None. Cache is set as a litellm param
""" """
if type == "redis": if type == "redis":
self.cache = RedisCache(host, port, password, **kwargs) self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
if type == "local": if type == "local":
self.cache = InMemoryCache() self.cache = InMemoryCache()
if "cache" not in litellm.input_callback: if "cache" not in litellm.input_callback:
@ -210,6 +200,7 @@ class Cache:
litellm.success_callback.append("cache") litellm.success_callback.append("cache")
if "cache" not in litellm._async_success_callback: if "cache" not in litellm._async_success_callback:
litellm._async_success_callback.append("cache") litellm._async_success_callback.append("cache")
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
def get_cache_key(self, *args, **kwargs): def get_cache_key(self, *args, **kwargs):
""" """
@ -222,29 +213,55 @@ class Cache:
Returns: Returns:
str: The cache key generated from the arguments, or None if no cache key could be generated. str: The cache key generated from the arguments, or None if no cache key could be generated.
""" """
cache_key ="" cache_key = ""
print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
# for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
print_verbose(f"\nReturning preset cache key: {cache_key}")
return kwargs.get("litellm_params", {}).get("preset_cache_key", None)
# sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4] # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"] completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]
for param in completion_kwargs: embedding_only_kwargs = ["input", "encoding_format"] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
# combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
combined_kwargs = completion_kwargs + embedding_only_kwargs
for param in combined_kwargs:
# ignore litellm params here # ignore litellm params here
if param in kwargs: if param in kwargs:
# check if param == model and model_group is passed in, then override model with model_group # check if param == model and model_group is passed in, then override model with model_group
if param == "model": if param == "model":
model_group = None model_group = None
caching_group = None
metadata = kwargs.get("metadata", None) metadata = kwargs.get("metadata", None)
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
if metadata is not None: if metadata is not None:
model_group = metadata.get("model_group") model_group = metadata.get("model_group")
model_group = metadata.get("model_group", None)
caching_groups = metadata.get("caching_groups", None)
if caching_groups:
for group in caching_groups:
if model_group in group:
caching_group = group
break
if litellm_params is not None: if litellm_params is not None:
metadata = litellm_params.get("metadata", None) metadata = litellm_params.get("metadata", None)
if metadata is not None: if metadata is not None:
model_group = metadata.get("model_group", None) model_group = metadata.get("model_group", None)
param_value = model_group or kwargs[param] # use model_group if it exists, else use kwargs["model"] caching_groups = metadata.get("caching_groups", None)
if caching_groups:
for group in caching_groups:
if model_group in group:
caching_group = group
break
param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"]
else: else:
if kwargs[param] is None: if kwargs[param] is None:
continue # ignore None params continue # ignore None params
param_value = kwargs[param] param_value = kwargs[param]
cache_key+= f"{str(param)}: {str(param_value)}" cache_key+= f"{str(param)}: {str(param_value)}"
print_verbose(f"\nCreated cache key: {cache_key}")
return cache_key return cache_key
def generate_streaming_content(self, content): def generate_streaming_content(self, content):
@ -297,4 +314,9 @@ class Cache:
result = result.model_dump_json() result = result.model_dump_json()
self.cache.set_cache(cache_key, result, **kwargs) self.cache.set_cache(cache_key, result, **kwargs)
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
traceback.print_exc()
pass pass
async def _async_add_cache(self, result, *args, **kwargs):
self.add_cache(result, *args, **kwargs)

View file

@ -12,16 +12,19 @@
from openai import ( from openai import (
AuthenticationError, AuthenticationError,
BadRequestError, BadRequestError,
NotFoundError,
RateLimitError, RateLimitError,
APIStatusError, APIStatusError,
OpenAIError, OpenAIError,
APIError, APIError,
APITimeoutError, APITimeoutError,
APIConnectionError, APIConnectionError,
APIResponseValidationError APIResponseValidationError,
UnprocessableEntityError
) )
import httpx import httpx
class AuthenticationError(AuthenticationError): # type: ignore class AuthenticationError(AuthenticationError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 401 self.status_code = 401
@ -34,6 +37,20 @@ class AuthenticationError(AuthenticationError): # type: ignore
body=None body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# raise when invalid models passed, example gpt-8
class NotFoundError(NotFoundError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 404
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message,
response=response,
body=None
) # Call the base class constructor with the parameters it needs
class BadRequestError(BadRequestError): # type: ignore class BadRequestError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 400 self.status_code = 400
@ -46,6 +63,18 @@ class BadRequestError(BadRequestError): # type: ignore
body=None body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 422
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message,
response=response,
body=None
) # Call the base class constructor with the parameters it needs
class Timeout(APITimeoutError): # type: ignore class Timeout(APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
self.status_code = 408 self.status_code = 408

View file

@ -2,8 +2,9 @@
# On success, logs events to Promptlayer # On success, logs events to Promptlayer
import dotenv, os import dotenv, os
import requests import requests
import requests from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -28,6 +29,11 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
def log_failure_event(self, kwargs, response_obj, start_time, end_time): def log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass pass
#### ASYNC ####
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
pass
async def async_log_pre_api_call(self, model, messages, kwargs): async def async_log_pre_api_call(self, model, messages, kwargs):
pass pass
@ -37,6 +43,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass pass
#### CALL HOOKS - proxy only ####
"""
Control how incoming / outgoing data is modified before calling the model
"""
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
pass
async def async_post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
pass
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func): def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):

View file

@ -0,0 +1,82 @@
#### What this does ####
# On success + failure, log events to Supabase
import dotenv, os
import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose
class DyanmoDBLogger:
# Class variables or attributes
def __init__(self):
# Instance variables
import boto3
self.dynamodb = boto3.resource('dynamodb', region_name=os.environ["AWS_REGION_NAME"])
if litellm.dynamodb_table_name is None:
raise ValueError("LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`")
self.table_name = litellm.dynamodb_table_name
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
try:
print_verbose(
f"DynamoDB Logging - Enters logging function for model {kwargs}"
)
# construct payload to send to DynamoDB
# follows the same params as langfuse.py
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
messages = kwargs.get("messages")
optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion")
usage = response_obj["usage"]
id = response_obj.get("id", str(uuid.uuid4()))
# Build the initial payload
payload = {
"id": id,
"call_type": call_type,
"startTime": start_time,
"endTime": end_time,
"model": kwargs.get("model", ""),
"user": kwargs.get("user", ""),
"modelParameters": optional_params,
"messages": messages,
"response": response_obj,
"usage": usage,
"metadata": metadata
}
# Ensure everything in the payload is converted to str
for key, value in payload.items():
try:
payload[key] = str(value)
except:
# non blocking if it can't cast to a str
pass
print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
# put data in dyanmo DB
table = self.dynamodb.Table(self.table_name)
# Assuming log_data is a dictionary with log information
response = table.put_item(Item=payload)
print_verbose(f"Response from DynamoDB:{str(response)}")
print_verbose(
f"DynamoDB Layer Logging - final response object: {response_obj}"
)
return response
except:
traceback.print_exc()
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
pass

View file

@ -58,7 +58,7 @@ class LangFuseLogger:
model=kwargs['model'], model=kwargs['model'],
modelParameters=optional_params, modelParameters=optional_params,
prompt=prompt, prompt=prompt,
completion=response_obj['choices'][0]['message'], completion=response_obj['choices'][0]['message'].json(),
usage=Usage( usage=Usage(
prompt_tokens=response_obj['usage']['prompt_tokens'], prompt_tokens=response_obj['usage']['prompt_tokens'],
completion_tokens=response_obj['usage']['completion_tokens'] completion_tokens=response_obj['usage']['completion_tokens']
@ -70,6 +70,9 @@ class LangFuseLogger:
f"Langfuse Layer Logging - final response object: {response_obj}" f"Langfuse Layer Logging - final response object: {response_obj}"
) )
except: except:
# traceback.print_exc() traceback.print_exc()
print_verbose(f"Langfuse Layer Error - {traceback.format_exc()}") print_verbose(f"Langfuse Layer Error - {traceback.format_exc()}")
pass pass
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)

View file

@ -58,7 +58,7 @@ class LangsmithLogger:
"inputs": { "inputs": {
**new_kwargs **new_kwargs
}, },
"outputs": response_obj, "outputs": response_obj.json(),
"session_name": project_name, "session_name": project_name,
"start_time": start_time, "start_time": start_time,
"end_time": end_time, "end_time": end_time,

View file

@ -1,7 +1,8 @@
class TraceloopLogger: class TraceloopLogger:
def __init__(self): def __init__(self):
from traceloop.sdk.tracing.tracing import TracerWrapper from traceloop.sdk.tracing.tracing import TracerWrapper
from traceloop.sdk import Traceloop
Traceloop.init(app_name="Litellm-Server", disable_batch=True)
self.tracer_wrapper = TracerWrapper() self.tracer_wrapper = TracerWrapper()
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):

View file

@ -196,8 +196,19 @@ class AzureChatCompletion(BaseLLM):
else: else:
azure_client = client azure_client = client
response = azure_client.chat.completions.create(**data) # type: ignore response = azure_client.chat.completions.create(**data) # type: ignore
response.model = "azure/" + str(response.model) stringified_response = response.model_dump_json()
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response) ## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=stringified_response,
additional_args={
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except AzureOpenAIError as e: except AzureOpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
@ -318,7 +329,10 @@ class AzureChatCompletion(BaseLLM):
data: dict, data: dict,
model_response: ModelResponse, model_response: ModelResponse,
azure_client_params: dict, azure_client_params: dict,
api_key: str,
input: list,
client=None, client=None,
logging_obj=None
): ):
response = None response = None
try: try:
@ -327,8 +341,23 @@ class AzureChatCompletion(BaseLLM):
else: else:
openai_aclient = client openai_aclient = client
response = await openai_aclient.embeddings.create(**data) response = await openai_aclient.embeddings.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
except Exception as e: except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e raise e
def embedding(self, def embedding(self,
@ -372,13 +401,7 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if aembedding == True:
response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
return response
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
@ -391,6 +414,14 @@ class AzureChatCompletion(BaseLLM):
} }
}, },
) )
if aembedding == True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
return response
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
## COMPLETION CALL ## COMPLETION CALL
response = azure_client.embeddings.create(**data) # type: ignore response = azure_client.embeddings.create(**data) # type: ignore
## LOGGING ## LOGGING

View file

@ -482,7 +482,7 @@ def completion(
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key="", api_key="",
original_response=response_body, original_response=json.dumps(response_body),
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response}") print_verbose(f"raw model_response: {response}")
@ -552,6 +552,7 @@ def _embedding_func_single(
## FORMAT EMBEDDING INPUT ## ## FORMAT EMBEDDING INPUT ##
provider = model.split(".")[0] provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
inference_params.pop("user", None) # make sure user is not passed in for bedrock call
if provider == "amazon": if provider == "amazon":
input = input.replace(os.linesep, " ") input = input.replace(os.linesep, " ")
data = {"inputText": input, **inference_params} data = {"inputText": input, **inference_params}
@ -587,7 +588,7 @@ def _embedding_func_single(
input=input, input=input,
api_key="", api_key="",
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=response_body, original_response=json.dumps(response_body),
) )
if provider == "cohere": if provider == "cohere":
response = response_body.get("embeddings") response = response_body.get("embeddings")
@ -651,13 +652,4 @@ def embedding(
) )
model_response.usage = usage model_response.usage = usage
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": {"model": model,
"texts": input}},
original_response=embeddings,
)
return model_response return model_response

View file

@ -542,7 +542,7 @@ class Huggingface(BaseLLM):
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data, "headers": headers, "api_base": embed_url},
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
@ -584,6 +584,14 @@ class Huggingface(BaseLLM):
"embedding": embedding # flatten list returned from hf "embedding": embedding # flatten list returned from hf
} }
) )
elif isinstance(embedding, list) and isinstance(embedding[0], float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
}
)
else: else:
output_data.append( output_data.append(
{ {

View file

@ -1,10 +1,9 @@
import requests, types import requests, types, time
import json import json
import traceback import traceback
from typing import Optional from typing import Optional
import litellm import litellm
import httpx import httpx, aiohttp, asyncio
try: try:
from async_generator import async_generator, yield_ # optional dependency from async_generator import async_generator, yield_ # optional dependency
async_generator_imported = True async_generator_imported = True
@ -115,6 +114,9 @@ def get_ollama_response_stream(
prompt="Why is the sky blue?", prompt="Why is the sky blue?",
optional_params=None, optional_params=None,
logging_obj=None, logging_obj=None,
acompletion: bool = False,
model_response=None,
encoding=None
): ):
if api_base.endswith("/api/generate"): if api_base.endswith("/api/generate"):
url = api_base url = api_base
@ -136,8 +138,19 @@ def get_ollama_response_stream(
logging_obj.pre_call( logging_obj.pre_call(
input=None, input=None,
api_key=None, api_key=None,
additional_args={"api_base": url, "complete_input_dict": data}, additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
) )
if acompletion is True:
if optional_params.get("stream", False):
response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
else:
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
return response
else:
return ollama_completion_stream(url=url, data=data)
def ollama_completion_stream(url, data):
session = requests.Session() session = requests.Session()
with session.post(url, json=data, stream=True) as resp: with session.post(url, json=data, stream=True) as resp:
@ -169,41 +182,38 @@ def get_ollama_response_stream(
traceback.print_exc() traceback.print_exc()
session.close() session.close()
if async_generator_imported:
# ollama implementation
@async_generator
async def async_get_ollama_response_stream(
api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
):
url = f"{api_base}/api/generate"
## Load Config async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
config=litellm.OllamaConfig.get_config() try:
for k, v in config.items(): client = httpx.AsyncClient()
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in async with client.stream(
optional_params[k] = v url=f"{url}",
json=data,
method="POST",
timeout=litellm.request_timeout
) as response:
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
data = { streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
"model": model, async for transformed_chunk in streamwrapper:
"prompt": prompt, yield transformed_chunk
**optional_params except Exception as e:
} traceback.print_exc()
## LOGGING
logging_obj.pre_call(
input=None,
api_key=None,
additional_args={"api_base": url, "complete_input_dict": data},
)
session = requests.Session()
with session.post(url, json=data, stream=True) as resp: async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
if resp.status_code != 200: data["stream"] = False
raise OllamaError(status_code=resp.status_code, message=resp.text) try:
for line in resp.iter_lines(): timeout = aiohttp.ClientTimeout(total=600) # 10 minutes
async with aiohttp.ClientSession(timeout=timeout) as session:
resp = await session.post(url, json=data)
if resp.status != 200:
text = await resp.text()
raise OllamaError(status_code=resp.status, message=text)
completion_string = ""
async for line in resp.content.iter_any():
if line: if line:
try: try:
json_chunk = line.decode("utf-8") json_chunk = line.decode("utf-8")
@ -217,15 +227,24 @@ if async_generator_imported:
"content": "", "content": "",
"error": j "error": j
} }
await yield_({"choices": [{"delta": completion_obj}]}) raise Exception(f"OllamError - {chunk}")
if "response" in j: if "response" in j:
completion_obj = { completion_obj = {
"role": "assistant", "role": "assistant",
"content": "", "content": j["response"],
} }
completion_obj["content"] = j["response"] completion_string = completion_string + completion_obj["content"]
await yield_({"choices": [{"delta": completion_obj}]})
except Exception as e: except Exception as e:
import logging traceback.print_exc()
logging.debug(f"Error decoding JSON: {e}")
session.close() ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
model_response["choices"][0]["message"]["content"] = completion_string
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data['model']
prompt_tokens = len(encoding.encode(data['prompt'])) # type: ignore
completion_tokens = len(encoding.encode(completion_string))
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
return model_response
except Exception as e:
traceback.print_exc()

View file

@ -195,6 +195,16 @@ class OpenAIChatCompletion(BaseLLM):
**optional_params **optional_params
} }
try:
max_retries = data.pop("max_retries", 2)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
else:
return self.acompletion(data=data, headers=headers, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
else:
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=messages, input=messages,
@ -202,16 +212,6 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data}, additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
) )
try:
max_retries = data.pop("max_retries", 2)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
else:
return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
else:
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int") raise OpenAIError(status_code=422, message="max retries must be an int")
if client is None: if client is None:
@ -219,13 +219,14 @@ class OpenAIChatCompletion(BaseLLM):
else: else:
openai_client = client openai_client = client
response = openai_client.chat.completions.create(**data) # type: ignore response = openai_client.chat.completions.create(**data) # type: ignore
stringified_response = response.model_dump_json()
logging_obj.post_call( logging_obj.post_call(
input=None, input=messages,
api_key=api_key, api_key=api_key,
original_response=response, original_response=stringified_response,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response) return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except Exception as e: except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e): if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
@ -259,6 +260,8 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str]=None, api_base: Optional[str]=None,
client=None, client=None,
max_retries=None, max_retries=None,
logging_obj=None,
headers=None
): ):
response = None response = None
try: try:
@ -266,16 +269,23 @@ class OpenAIChatCompletion(BaseLLM):
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
else: else:
openai_aclient = client openai_aclient = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
api_key=openai_aclient.api_key,
additional_args={"headers": {"Authorization": f"Bearer {openai_aclient.api_key}"}, "api_base": openai_aclient._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
)
response = await openai_aclient.chat.completions.create(**data) response = await openai_aclient.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response) stringified_response = response.model_dump_json()
logging_obj.post_call(
input=data['messages'],
api_key=api_key,
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except Exception as e: except Exception as e:
if response and hasattr(response, "text"): raise e
raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
else:
if type(e).__name__ == "ReadTimeout":
raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
else:
raise OpenAIError(status_code=500, message=f"{str(e)}")
def streaming(self, def streaming(self,
logging_obj, logging_obj,
@ -285,12 +295,19 @@ class OpenAIChatCompletion(BaseLLM):
api_key: Optional[str]=None, api_key: Optional[str]=None,
api_base: Optional[str]=None, api_base: Optional[str]=None,
client = None, client = None,
max_retries=None max_retries=None,
headers=None
): ):
if client is None: if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else: else:
openai_client = client openai_client = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": False, "complete_input_dict": data},
)
response = openai_client.chat.completions.create(**data) response = openai_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
return streamwrapper return streamwrapper
@ -304,6 +321,7 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str]=None, api_base: Optional[str]=None,
client=None, client=None,
max_retries=None, max_retries=None,
headers=None
): ):
response = None response = None
try: try:
@ -311,6 +329,13 @@ class OpenAIChatCompletion(BaseLLM):
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
else: else:
openai_aclient = client openai_aclient = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
)
response = await openai_aclient.chat.completions.create(**data) response = await openai_aclient.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
@ -325,6 +350,7 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(status_code=500, message=f"{str(e)}") raise OpenAIError(status_code=500, message=f"{str(e)}")
async def aembedding( async def aembedding(
self, self,
input: list,
data: dict, data: dict,
model_response: ModelResponse, model_response: ModelResponse,
timeout: float, timeout: float,
@ -332,6 +358,7 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str]=None, api_base: Optional[str]=None,
client=None, client=None,
max_retries=None, max_retries=None,
logging_obj=None
): ):
response = None response = None
try: try:
@ -340,9 +367,24 @@ class OpenAIChatCompletion(BaseLLM):
else: else:
openai_aclient = client openai_aclient = client
response = await openai_aclient.embeddings.create(**data) # type: ignore response = await openai_aclient.embeddings.create(**data) # type: ignore
return response stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
except Exception as e: except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
)
raise e raise e
def embedding(self, def embedding(self,
model: str, model: str,
input: list, input: list,
@ -367,13 +409,6 @@ class OpenAIChatCompletion(BaseLLM):
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int") raise OpenAIError(status_code=422, message="max retries must be an int")
if aembedding == True:
response = self.aembedding(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else:
openai_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
@ -381,6 +416,14 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data, "api_base": api_base}, additional_args={"complete_input_dict": data, "api_base": api_base},
) )
if aembedding == True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else:
openai_client = client
## COMPLETION CALL ## COMPLETION CALL
response = openai_client.embeddings.create(**data) # type: ignore response = openai_client.embeddings.create(**data) # type: ignore
## LOGGING ## LOGGING
@ -472,12 +515,14 @@ class OpenAITextCompletion(BaseLLM):
else: else:
prompt = " ".join([message["content"] for message in messages]) # type: ignore prompt = " ".join([message["content"] for message in messages]) # type: ignore
# don't send max retries to the api, if set
optional_params.pop("max_retries", None)
data = { data = {
"model": model, "model": model,
"prompt": prompt, "prompt": prompt,
**optional_params **optional_params
} }
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=messages, input=messages,

View file

@ -73,8 +73,27 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
final_prompt_value="### Response:", final_prompt_value="### Response:",
messages=messages messages=messages
) )
elif "llava" in model:
prompt = ""
images = []
for message in messages:
if isinstance(message["content"], str):
prompt += message["content"]
elif isinstance(message["content"], list):
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
for element in message["content"]:
if isinstance(element, dict):
if element["type"] == "text":
prompt += element["text"]
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
images.append(image_url)
return {
"prompt": prompt,
"images": images
}
else: else:
prompt = "".join(m["content"] for m in messages) prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages)
return prompt return prompt
def mistral_instruct_pt(messages): def mistral_instruct_pt(messages):
@ -161,6 +180,8 @@ def phind_codellama_pt(messages):
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None): def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
## get the tokenizer config from huggingface ## get the tokenizer config from huggingface
bos_token = ""
eos_token = ""
if chat_template is None: if chat_template is None:
def _get_tokenizer_config(hf_model_name): def _get_tokenizer_config(hf_model_name):
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json" url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
@ -187,7 +208,10 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
# Create a template object from the template text # Create a template object from the template text
env = Environment() env = Environment()
env.globals['raise_exception'] = raise_exception env.globals['raise_exception'] = raise_exception
try:
template = env.from_string(chat_template) template = env.from_string(chat_template)
except Exception as e:
raise e
def _is_system_in_template(): def _is_system_in_template():
try: try:
@ -227,8 +251,8 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
new_messages.append(reformatted_messages[-1]) new_messages.append(reformatted_messages[-1])
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages) rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages)
return rendered_text return rendered_text
except: except Exception as e:
raise Exception("Error rendering template") raise Exception(f"Error rendering template - {str(e)}")
# Anthropic template # Anthropic template
def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
@ -266,6 +290,7 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
### TOGETHER AI ### TOGETHER AI
def get_model_info(token, model): def get_model_info(token, model):
try:
headers = { headers = {
'Authorization': f'Bearer {token}' 'Authorization': f'Bearer {token}'
} }
@ -278,8 +303,13 @@ def get_model_info(token, model):
return None, None return None, None
else: else:
return None, None return None, None
except Exception as e: # safely fail a prompt template request
return None, None
def format_prompt_togetherai(messages, prompt_format, chat_template): def format_prompt_togetherai(messages, prompt_format, chat_template):
if prompt_format is None:
return default_pt(messages)
human_prompt, assistant_prompt = prompt_format.split('{prompt}') human_prompt, assistant_prompt = prompt_format.split('{prompt}')
if chat_template is not None: if chat_template is not None:

View file

@ -232,7 +232,8 @@ def completion(
if system_prompt is not None: if system_prompt is not None:
input_data = { input_data = {
"prompt": prompt, "prompt": prompt,
"system_prompt": system_prompt "system_prompt": system_prompt,
**optional_params
} }
# Otherwise, use the prompt as is # Otherwise, use the prompt as is
else: else:

View file

@ -158,6 +158,7 @@ def completion(
) )
except Exception as e: except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}") raise SagemakerError(status_code=500, message=f"{str(e)}")
response = response["Body"].read().decode("utf8") response = response["Body"].read().decode("utf8")
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -171,10 +172,17 @@ def completion(
completion_response = json.loads(response) completion_response = json.loads(response)
try: try:
completion_response_choices = completion_response[0] completion_response_choices = completion_response[0]
completion_output = ""
if "generation" in completion_response_choices: if "generation" in completion_response_choices:
model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"] completion_output += completion_response_choices["generation"]
elif "generated_text" in completion_response_choices: elif "generated_text" in completion_response_choices:
model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"] completion_output += completion_response_choices["generated_text"]
# check if the prompt template is part of output, if so - filter it out
if completion_output.startswith(prompt) and "<s>" in prompt:
completion_output = completion_output.replace(prompt, "", 1)
model_response["choices"][0]["message"]["content"] = completion_output
except: except:
raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500) raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)

View file

@ -173,10 +173,11 @@ def completion(
message=json.dumps(completion_response["output"]), status_code=response.status_code message=json.dumps(completion_response["output"]), status_code=response.status_code
) )
if len(completion_response["output"]["choices"][0]["text"]) > 0: if len(completion_response["output"]["choices"][0]["text"]) >= 0:
model_response["choices"][0]["message"]["content"] = completion_response["output"]["choices"][0]["text"] model_response["choices"][0]["message"]["content"] = completion_response["output"]["choices"][0]["text"]
## CALCULATING USAGE ## CALCULATING USAGE
print_verbose(f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}")
prompt_tokens = len(encoding.encode(prompt)) prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))

View file

@ -4,7 +4,7 @@ from enum import Enum
import requests import requests
import time import time
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm import litellm
import httpx import httpx
@ -57,6 +57,108 @@ class VertexAIConfig():
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None} and v is not None}
def _get_image_bytes_from_url(image_url: str) -> bytes:
try:
response = requests.get(image_url)
response.raise_for_status() # Raise an error for bad responses (4xx and 5xx)
image_bytes = response.content
return image_bytes
except requests.exceptions.RequestException as e:
# Handle any request exceptions (e.g., connection error, timeout)
return b'' # Return an empty bytes object or handle the error as needed
def _load_image_from_url(image_url: str):
"""
Loads an image from a URL.
Args:
image_url (str): The URL of the image.
Returns:
Image: The loaded image.
"""
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image
image_bytes = _get_image_bytes_from_url(image_url)
return Image.from_bytes(image_bytes)
def _gemini_vision_convert_messages(
messages: list
):
"""
Converts given messages for GPT-4 Vision to Gemini format.
Args:
messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
- If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
- If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.
Returns:
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
Raises:
VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
Exception: If any other exception occurs during the execution of the function.
Note:
This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub.
The supported MIME types for images include 'image/png' and 'image/jpeg'.
Examples:
>>> messages = [
... {"content": "Hello, world!"},
... {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": "example.com/image.png"}]},
... ]
>>> _gemini_vision_convert_messages(messages)
('Hello, world!This is a text message.', [<Part object>, <Part object>])
"""
try:
import vertexai
except:
raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`")
try:
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image
# given messages for gpt-4 vision, convert them for gemini
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
prompt = ""
images = []
for message in messages:
if isinstance(message["content"], str):
prompt += message["content"]
elif isinstance(message["content"], list):
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
for element in message["content"]:
if isinstance(element, dict):
if element["type"] == "text":
prompt += element["text"]
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
images.append(image_url)
# processing images passed to gemini
processed_images = []
for img in images:
if "gs://" in img:
# Case 1: Images with Cloud Storage URIs
# The supported MIME types for images include image/png and image/jpeg.
part_mime = "image/png" if "png" in img else "image/jpeg"
google_cloud_part = Part.from_uri(img, mime_type=part_mime)
processed_images.append(google_cloud_part)
elif "https:/" in img:
# Case 2: Images with direct links
image = _load_image_from_url(img)
processed_images.append(image)
elif ".mp4" in img and "gs://" in img:
# Case 3: Videos with Cloud Storage URIs
part_mime = "video/mp4"
google_cloud_part = Part.from_uri(img, mime_type=part_mime)
processed_images.append(google_cloud_part)
return prompt, processed_images
except Exception as e:
raise e
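
A dependency-free sketch of the text/image splitting performed by `_gemini_vision_convert_messages` above; `split_gpt4v_messages` is illustrative and returns raw image URLs instead of vertexai `Part`/`Image` objects.

```python
def split_gpt4v_messages(messages: list):
    # mirrors the loop above: concatenate text pieces, collect image URLs
    prompt, image_urls = "", []
    for message in messages:
        content = message["content"]
        if isinstance(content, str):
            prompt += content
        elif isinstance(content, list):
            for element in content:
                if isinstance(element, dict):
                    if element.get("type") == "text":
                        prompt += element["text"]
                    elif element.get("type") == "image_url":
                        image_urls.append(element["image_url"]["url"])
    return prompt, image_urls

# split_gpt4v_messages([{"content": [{"type": "text", "text": "What is in this image?"},
#                                    {"type": "image_url", "image_url": {"url": "gs://bucket/cat.png"}}]}])
# -> ("What is in this image?", ["gs://bucket/cat.png"])
```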
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -69,6 +171,7 @@ def completion(
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
acompletion: bool=False
): ):
try: try:
import vertexai import vertexai
@ -77,6 +180,8 @@ def completion(
try: try:
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
from vertexai.language_models import TextGenerationModel, CodeGenerationModel from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
vertexai.init( vertexai.init(
project=vertex_project, location=vertex_location project=vertex_project, location=vertex_location
@ -90,34 +195,94 @@ def completion(
# vertexai does not use an API key, it looks for credentials.json in the environment # vertexai does not use an API key, it looks for credentials.json in the environment
prompt = " ".join([message["content"] for message in messages]) prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)])
mode = "" mode = ""
request_str = "" request_str = ""
if model in litellm.vertex_chat_models: response_obj = None
chat_model = ChatModel.from_pretrained(model) if model in litellm.vertex_language_models:
llm_model = GenerativeModel(model)
mode = ""
request_str += f"llm_model = GenerativeModel({model})\n"
elif model in litellm.vertex_vision_models:
llm_model = GenerativeModel(model)
request_str += f"llm_model = GenerativeModel({model})\n"
mode = "vision"
elif model in litellm.vertex_chat_models:
llm_model = ChatModel.from_pretrained(model)
mode = "chat" mode = "chat"
request_str += f"chat_model = ChatModel.from_pretrained({model})\n" request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
elif model in litellm.vertex_text_models: elif model in litellm.vertex_text_models:
text_model = TextGenerationModel.from_pretrained(model) llm_model = TextGenerationModel.from_pretrained(model)
mode = "text" mode = "text"
request_str += f"text_model = TextGenerationModel.from_pretrained({model})\n" request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
elif model in litellm.vertex_code_text_models: elif model in litellm.vertex_code_text_models:
text_model = CodeGenerationModel.from_pretrained(model) llm_model = CodeGenerationModel.from_pretrained(model)
mode = "text" mode = "text"
request_str += f"text_model = CodeGenerationModel.from_pretrained({model})\n" request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
else: # vertex_code_chat_models else: # vertex_code_llm_models
chat_model = CodeChatModel.from_pretrained(model) llm_model = CodeChatModel.from_pretrained(model)
mode = "chat" mode = "chat"
request_str += f"chat_model = CodeChatModel.from_pretrained({model})\n" request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
if mode == "chat": if acompletion == True: # [TODO] expand support to vertex ai chat + text models
chat = chat_model.start_chat() if optional_params.get("stream", False) is True:
request_str+= f"chat = chat_model.start_chat()\n" # async streaming
return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, messages=messages, print_verbose=print_verbose, **optional_params)
return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, messages=messages,print_verbose=print_verbose,**optional_params)
if mode == "":
chat = llm_model.start_chat()
request_str+= f"chat = llm_model.start_chat()\n"
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
optional_params["stream"] = True
return model_response
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params})).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params))
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = llm_model.generate_content(
contents=content,
generation_config=GenerationConfig(**optional_params),
stream=True
)
optional_params["stream"] = True
return model_response
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
## LLM Call
response = llm_model.generate_content(
contents=content,
generation_config=GenerationConfig(**optional_params)
)
completion_response = response.text
response_obj = response._raw_response
elif mode == "chat":
chat = llm_model.start_chat()
request_str+= f"chat = llm_model.start_chat()\n"
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
# NOTE: VertexAI does not accept stream=True as a param and raises an error, # NOTE: VertexAI does not accept stream=True as a param and raises an error,
@ -125,27 +290,30 @@ def completion(
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n" request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = chat.send_message_streaming(prompt, **optional_params) model_response = chat.send_message_streaming(prompt, **optional_params)
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n" request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
completion_response = chat.send_message(prompt, **optional_params).text completion_response = chat.send_message(prompt, **optional_params).text
elif mode == "text": elif mode == "text":
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
request_str += f"text_model.predict_streaming({prompt}, **{optional_params})\n" request_str += f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = text_model.predict_streaming(prompt, **optional_params) model_response = llm_model.predict_streaming(prompt, **optional_params)
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"text_model.predict({prompt}, **{optional_params}).text\n" request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
completion_response = text_model.predict(prompt, **optional_params).text completion_response = llm_model.predict(prompt, **optional_params).text
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -161,6 +329,12 @@ def completion(
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
## CALCULATING USAGE ## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count)
else:
prompt_tokens = len( prompt_tokens = len(
encoding.encode(prompt) encoding.encode(prompt)
) )
@ -177,6 +351,140 @@ def completion(
except Exception as e: except Exception as e:
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))
async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, messages = None, print_verbose = None, **optional_params):
"""
Add support for acompletion calls for gemini-pro
"""
try:
from vertexai.preview.generative_models import GenerationConfig
if mode == "":
# gemini-pro
chat = llm_model.start_chat()
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
## LLM Call
response = await llm_model._generate_content_async(
contents=content,
generation_config=GenerationConfig(**optional_params)
)
completion_response = response.text
response_obj = response._raw_response
elif mode == "chat":
# chat-bison etc.
chat = llm_model.start_chat()
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await chat.send_message_async(prompt, **optional_params)
completion_response = response_obj.text
elif mode == "text":
# gecko etc.
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await llm_model.predict_async(prompt, **optional_params)
completion_response = response_obj.text
## LOGGING
logging_obj.post_call(
input=prompt, api_key=None, original_response=completion_response
)
## RESPONSE OBJECT
if len(str(completion_response)) > 0:
model_response["choices"][0]["message"][
"content"
] = str(completion_response)
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count)
else:
prompt_tokens = len(
encoding.encode(prompt)
)
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
model_response.usage = usage
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, messages = None, print_verbose = None, **optional_params):
"""
Add support for async streaming calls for gemini-pro
"""
from vertexai.preview.generative_models import GenerationConfig
if mode == "":
# gemini-pro
chat = llm_model.start_chat()
stream = optional_params.pop("stream")
request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
optional_params["stream"] = True
elif mode == "vision":
stream = optional_params.pop("stream")
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response = llm_model._generate_content_streaming_async(
contents=content,
generation_config=GenerationConfig(**optional_params),
stream=True
)
optional_params["stream"] = True
elif mode == "chat":
chat = llm_model.start_chat()
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response = chat.send_message_streaming_async(prompt, **optional_params)
optional_params["stream"] = True
elif mode == "text":
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response = llm_model.predict_streaming_async(prompt, **optional_params)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="vertex_ai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
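
A hedged end-to-end example of the async Gemini path these helpers enable, going through `litellm.acompletion`; it assumes Google Cloud credentials are configured and `google-cloud-aiplatform` is installed.

```python
import asyncio
import litellm

async def main():
    response = await litellm.acompletion(
        model="gemini-pro",
        messages=[{"role": "user", "content": "Write a haiku about tracing requests."}],
        stream=True,  # routed to async_streaming(); drop it to hit async_completion()
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())
```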
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls

View file

@ -14,6 +14,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy from copy import deepcopy
import httpx import httpx
import litellm import litellm
from litellm import ( # type: ignore from litellm import ( # type: ignore
client, client,
exception_type, exception_type,
@ -31,7 +32,8 @@ from litellm.utils import (
mock_completion_streaming_obj, mock_completion_streaming_obj,
convert_to_model_response_object, convert_to_model_response_object,
token_counter, token_counter,
Usage Usage,
get_optional_params_embeddings
) )
from .llms import ( from .llms import (
anthropic, anthropic,
@ -171,11 +173,14 @@ async def acompletion(*args, **kwargs):
or custom_llm_provider == "azure" or custom_llm_provider == "azure"
or custom_llm_provider == "custom_openai" or custom_llm_provider == "custom_openai"
or custom_llm_provider == "anyscale" or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral"
or custom_llm_provider == "openrouter" or custom_llm_provider == "openrouter"
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"
or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "text-completion-openai"
or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all. or custom_llm_provider == "huggingface"
or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
if kwargs.get("stream", False): if kwargs.get("stream", False):
response = completion(*args, **kwargs) response = completion(*args, **kwargs)
else: else:
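
A short sketch exercising one of the providers newly added to the async allow-list above; it assumes a local Ollama server on the default port with the `llama2` model pulled.

```python
import asyncio
from litellm import acompletion

async def ask_ollama():
    return await acompletion(
        model="ollama/llama2",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        api_base="http://localhost:11434",
    )

print(asyncio.run(ask_ollama()))
```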
@ -200,9 +205,12 @@ async def acompletion(*args, **kwargs):
async def _async_streaming(response, model, custom_llm_provider, args): async def _async_streaming(response, model, custom_llm_provider, args):
try: try:
print_verbose(f"received response in _async_streaming: {response}")
async for line in response: async for line in response:
print_verbose(f"line in async streaming: {line}")
yield line yield line
except Exception as e: except Exception as e:
print_verbose(f"error raised _async_streaming: {traceback.format_exc()}")
raise exception_type( raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
) )
@ -278,7 +286,7 @@ def completion(
# Optional liteLLM function params # Optional liteLLM function params
**kwargs, **kwargs,
) -> ModelResponse: ) -> Union[ModelResponse, CustomStreamWrapper]:
""" """
Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
Parameters: Parameters:
@ -319,7 +327,6 @@ def completion(
######### unpacking kwargs ##################### ######### unpacking kwargs #####################
args = locals() args = locals()
api_base = kwargs.get('api_base', None) api_base = kwargs.get('api_base', None)
return_async = kwargs.get('return_async', False)
mock_response = kwargs.get('mock_response', None) mock_response = kwargs.get('mock_response', None)
force_timeout= kwargs.get('force_timeout', 600) ## deprecated force_timeout= kwargs.get('force_timeout', 600) ## deprecated
logger_fn = kwargs.get('logger_fn', None) logger_fn = kwargs.get('logger_fn', None)
@ -344,13 +351,14 @@ def completion(
final_prompt_value = kwargs.get("final_prompt_value", None) final_prompt_value = kwargs.get("final_prompt_value", None)
bos_token = kwargs.get("bos_token", None) bos_token = kwargs.get("bos_token", None)
eos_token = kwargs.get("eos_token", None) eos_token = kwargs.get("eos_token", None)
preset_cache_key = kwargs.get("preset_cache_key", None)
hf_model_name = kwargs.get("hf_model_name", None) hf_model_name = kwargs.get("hf_model_name", None)
### ASYNC CALLS ### ### ASYNC CALLS ###
acompletion = kwargs.get("acompletion", False) acompletion = kwargs.get("acompletion", False)
client = kwargs.get("client", None) client = kwargs.get("client", None)
######## end of unpacking kwargs ########### ######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"] openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request"] litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key", "caching_groups"]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response: if mock_response:
@ -384,7 +392,6 @@ def completion(
model=deployment_id model=deployment_id
custom_llm_provider="azure" custom_llm_provider="azure"
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key) model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None: if input_cost_per_token is not None and output_cost_per_token is not None:
litellm.register_model({ litellm.register_model({
@ -448,7 +455,6 @@ def completion(
# For logging - save the values of the litellm-specific params passed in # For logging - save the values of the litellm-specific params passed in
litellm_params = get_litellm_params( litellm_params = get_litellm_params(
acompletion=acompletion, acompletion=acompletion,
return_async=return_async,
api_key=api_key, api_key=api_key,
force_timeout=force_timeout, force_timeout=force_timeout,
logger_fn=logger_fn, logger_fn=logger_fn,
@ -460,7 +466,8 @@ def completion(
completion_call_id=id, completion_call_id=id,
metadata=metadata, metadata=metadata,
model_info=model_info, model_info=model_info,
proxy_server_request=proxy_server_request proxy_server_request=proxy_server_request,
preset_cache_key=preset_cache_key
) )
logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params) logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params)
if custom_llm_provider == "azure": if custom_llm_provider == "azure":
@ -524,6 +531,7 @@ def completion(
client=client # pass AsyncAzureOpenAI, AzureOpenAI client client=client # pass AsyncAzureOpenAI, AzureOpenAI client
) )
if optional_params.get("stream", False) or acompletion == True:
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=messages, input=messages,
@ -541,6 +549,7 @@ def completion(
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"
or custom_llm_provider == "anyscale" or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral"
or custom_llm_provider == "openai" or custom_llm_provider == "openai"
or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo
): # allow user to make an openai call with a custom base ): # allow user to make an openai call with a custom base
@ -604,6 +613,7 @@ def completion(
) )
raise e raise e
if optional_params.get("stream", False):
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=messages, input=messages,
@ -616,7 +626,6 @@ def completion(
or "ft:babbage-002" in model or "ft:babbage-002" in model
or "ft:davinci-002" in model # support for finetuned completion models or "ft:davinci-002" in model # support for finetuned completion models
): ):
# print("calling custom openai provider")
openai.api_type = "openai" openai.api_type = "openai"
api_base = ( api_base = (
@ -655,17 +664,6 @@ def completion(
prompt = messages[0]["content"] prompt = messages[0]["content"]
else: else:
prompt = " ".join([message["content"] for message in messages]) # type: ignore prompt = " ".join([message["content"] for message in messages]) # type: ignore
## LOGGING
logging.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"openai_organization": litellm.organization,
"headers": headers,
"api_base": api_base,
"api_type": openai.api_type,
},
)
## COMPLETION CALL ## COMPLETION CALL
model_response = openai_text_completions.completion( model_response = openai_text_completions.completion(
model=model, model=model,
@ -681,9 +679,14 @@ def completion(
logger_fn=logger_fn logger_fn=logger_fn
) )
# if "stream" in optional_params and optional_params["stream"] == True: if optional_params.get("stream", False) or acompletion == True:
# response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging) ## LOGGING
# return response logging.post_call(
input=messages,
api_key=api_key,
original_response=model_response,
additional_args={"headers": headers},
)
response = model_response response = model_response
elif ( elif (
"replicate" in model or "replicate" in model or
@ -728,8 +731,16 @@ def completion(
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
return response
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=replicate_key,
original_response=model_response,
)
response = model_response response = model_response
elif custom_llm_provider=="anthropic": elif custom_llm_provider=="anthropic":
@ -749,7 +760,7 @@ def completion(
custom_prompt_dict custom_prompt_dict
or litellm.custom_prompt_dict or litellm.custom_prompt_dict
) )
model_response = anthropic.completion( response = anthropic.completion(
model=model, model=model,
messages=messages, messages=messages,
api_base=api_base, api_base=api_base,
@ -765,9 +776,16 @@ def completion(
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper(model_response, model, custom_llm_provider="anthropic", logging_obj=logging) response = CustomStreamWrapper(response, model, custom_llm_provider="anthropic", logging_obj=logging)
return response
response = model_response if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
)
response = response
elif custom_llm_provider == "nlp_cloud": elif custom_llm_provider == "nlp_cloud":
nlp_cloud_key = ( nlp_cloud_key = (
api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key
@ -780,7 +798,7 @@ def completion(
or "https://api.nlpcloud.io/v1/gpu/" or "https://api.nlpcloud.io/v1/gpu/"
) )
model_response = nlp_cloud.completion( response = nlp_cloud.completion(
model=model, model=model,
messages=messages, messages=messages,
api_base=api_base, api_base=api_base,
@ -796,9 +814,17 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper(model_response, model, custom_llm_provider="nlp_cloud", logging_obj=logging) response = CustomStreamWrapper(response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
return response
response = model_response if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
)
response = response
elif custom_llm_provider == "aleph_alpha": elif custom_llm_provider == "aleph_alpha":
aleph_alpha_key = ( aleph_alpha_key = (
api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key
@ -1100,7 +1126,7 @@ def completion(
) )
return response return response
response = model_response response = model_response
elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models or model in litellm.vertex_text_models or model in litellm.vertex_code_text_models: elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (litellm.vertex_project vertex_ai_project = (litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")) or get_secret("VERTEXAI_PROJECT"))
vertex_ai_location = (litellm.vertex_location vertex_ai_location = (litellm.vertex_location
@ -1117,10 +1143,11 @@ def completion(
encoding=encoding, encoding=encoding,
vertex_location=vertex_ai_location, vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project, vertex_project=vertex_ai_project,
logging_obj=logging logging_obj=logging,
acompletion=acompletion
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True and acompletion == False:
response = CustomStreamWrapper( response = CustomStreamWrapper(
model_response, model, custom_llm_provider="vertex_ai", logging_obj=logging model_response, model, custom_llm_provider="vertex_ai", logging_obj=logging
) )
@ -1186,6 +1213,7 @@ def completion(
# "SageMaker is currently not supporting streaming responses." # "SageMaker is currently not supporting streaming responses."
# fake streaming for sagemaker # fake streaming for sagemaker
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
resp_string = model_response["choices"][0]["message"]["content"] resp_string = model_response["choices"][0]["message"]["content"]
response = CustomStreamWrapper( response = CustomStreamWrapper(
resp_string, model, custom_llm_provider="sagemaker", logging_obj=logging resp_string, model, custom_llm_provider="sagemaker", logging_obj=logging
@ -1200,7 +1228,7 @@ def completion(
custom_prompt_dict custom_prompt_dict
or litellm.custom_prompt_dict or litellm.custom_prompt_dict
) )
model_response = bedrock.completion( response = bedrock.completion(
model=model, model=model,
messages=messages, messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict, custom_prompt_dict=litellm.custom_prompt_dict,
@ -1218,16 +1246,24 @@ def completion(
# don't try to access stream object, # don't try to access stream object,
if "ai21" in model: if "ai21" in model:
response = CustomStreamWrapper( response = CustomStreamWrapper(
model_response, model, custom_llm_provider="bedrock", logging_obj=logging response, model, custom_llm_provider="bedrock", logging_obj=logging
) )
else: else:
response = CustomStreamWrapper( response = CustomStreamWrapper(
iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging iter(response), model, custom_llm_provider="bedrock", logging_obj=logging
) )
return response
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=response,
)
## RESPONSE OBJECT ## RESPONSE OBJECT
response = model_response response = response
elif custom_llm_provider == "vllm": elif custom_llm_provider == "vllm":
model_response = vllm.completion( model_response = vllm.completion(
model=model, model=model,
@ -1273,14 +1309,18 @@ def completion(
) )
else: else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider) prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
## LOGGING if isinstance(prompt, dict):
if kwargs.get('acompletion', False) == True: # for multimode models - ollama/llava prompt_factory returns a dict {
if optional_params.get("stream", False) == True: # "prompt": prompt,
# assume all ollama responses are streamed # "images": images
async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging) # }
return async_generator prompt, images = prompt["prompt"], prompt["images"]
optional_params["images"] = images
generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging) ## LOGGING
generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
if acompletion is True:
return generator
if optional_params.get("stream", False) == True: if optional_params.get("stream", False) == True:
# assume all ollama responses are streamed # assume all ollama responses are streamed
response = CustomStreamWrapper( response = CustomStreamWrapper(
@ -1716,8 +1756,7 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "anyscale" or custom_llm_provider == "anyscale"
or custom_llm_provider == "openrouter" or custom_llm_provider == "openrouter"
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"): # currently implemented aiohttp calls for just azure and openai, soon all.
or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally # Await normally
init_response = await loop.run_in_executor(None, func_with_context) init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
@ -1781,22 +1820,21 @@ def embedding(
rpm = kwargs.pop("rpm", None) rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None) tpm = kwargs.pop("tpm", None)
model_info = kwargs.get("model_info", None) model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
encoding_format = kwargs.get("encoding_format", None)
proxy_server_request = kwargs.get("proxy_server_request", None) proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.pop("aembedding", None) aembedding = kwargs.get("aembedding", None)
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"] openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info"] litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
optional_params = {}
for param in non_default_params:
optional_params[param] = kwargs[param]
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key) model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
optional_params = get_optional_params_embeddings(user=user, encoding_format=encoding_format, custom_llm_provider=custom_llm_provider, **non_default_params)
try: try:
response = None response = None
logging = litellm_logging_obj logging = litellm_logging_obj
logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info}) logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding, "preset_cache_key": None, "stream_response": {}})
if azure == True or custom_llm_provider == "azure": if azure == True or custom_llm_provider == "azure":
# azure configs # azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure" api_type = get_secret("AZURE_API_TYPE") or "azure"
@ -1936,7 +1974,7 @@ def embedding(
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=input, input=input,
api_key=openai.api_key, api_key=api_key,
original_response=str(e), original_response=str(e),
) )
## Map to OpenAI Exception ## Map to OpenAI Exception
@ -1948,6 +1986,59 @@ def embedding(
###### Text Completion ################ ###### Text Completion ################
async def atext_completion(*args, **kwargs):
"""
Implemented to handle async streaming for the text completion endpoint
"""
loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"]
### PASS ARGS TO COMPLETION ###
kwargs["acompletion"] = True
custom_llm_provider = None
try:
# Use a partial function to pass your keyword arguments
func = partial(text_completion, *args, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
if (custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral"
or custom_llm_provider == "openrouter"
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "text-completion-openai"
or custom_llm_provider == "huggingface"
or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
if kwargs.get("stream", False):
response = text_completion(*args, **kwargs)
else:
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
if kwargs.get("stream", False): # return an async generator
return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
else:
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
)
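
A hedged usage sketch for the new async text-completion wrapper above; the model name and prompt are illustrative and an `OPENAI_API_KEY` is assumed.

```python
import asyncio
from litellm import atext_completion

async def main():
    response = await atext_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="Say this is a test",
        max_tokens=16,
    )
    print(response["choices"][0]["text"])

asyncio.run(main())
```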
def text_completion( def text_completion(
prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]], # Required: The prompt(s) to generate completions for. prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]], # Required: The prompt(s) to generate completions for.
model: Optional[str]=None, # Optional: either `model` or `engine` can be set model: Optional[str]=None, # Optional: either `model` or `engine` can be set
@ -2079,7 +2170,7 @@ def text_completion(
*args, *args,
**all_params, **all_params,
) )
#print(response)
text_completion_response["id"] = response.get("id", None) text_completion_response["id"] = response.get("id", None)
text_completion_response["object"] = "text_completion" text_completion_response["object"] = "text_completion"
text_completion_response["created"] = response.get("created", None) text_completion_response["created"] = response.get("created", None)
@ -2294,6 +2385,8 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
completion_output = combined_content completion_output = combined_content
elif len(combined_arguments) > 0: elif len(combined_arguments) > 0:
completion_output = combined_arguments completion_output = combined_arguments
else:
completion_output = ""
# # Update usage information if needed # # Update usage information if needed
try: try:
response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages) response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages)

View file

@ -41,6 +41,20 @@
"litellm_provider": "openai", "litellm_provider": "openai",
"mode": "chat" "mode": "chat"
}, },
"gpt-4-1106-preview": {
"max_tokens": 128000,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"litellm_provider": "openai",
"mode": "chat"
},
"gpt-4-vision-preview": {
"max_tokens": 128000,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"litellm_provider": "openai",
"mode": "chat"
},
"gpt-3.5-turbo": { "gpt-3.5-turbo": {
"max_tokens": 4097, "max_tokens": 4097,
"input_cost_per_token": 0.0000015, "input_cost_per_token": 0.0000015,
@ -62,6 +76,13 @@
"litellm_provider": "openai", "litellm_provider": "openai",
"mode": "chat" "mode": "chat"
}, },
"gpt-3.5-turbo-1106": {
"max_tokens": 16385,
"input_cost_per_token": 0.0000010,
"output_cost_per_token": 0.0000020,
"litellm_provider": "openai",
"mode": "chat"
},
"gpt-3.5-turbo-16k": { "gpt-3.5-turbo-16k": {
"max_tokens": 16385, "max_tokens": 16385,
"input_cost_per_token": 0.000003, "input_cost_per_token": 0.000003,
@ -76,6 +97,62 @@
"litellm_provider": "openai", "litellm_provider": "openai",
"mode": "chat" "mode": "chat"
}, },
"ft:gpt-3.5-turbo": {
"max_tokens": 4097,
"input_cost_per_token": 0.000012,
"output_cost_per_token": 0.000016,
"litellm_provider": "openai",
"mode": "chat"
},
"text-embedding-ada-002": {
"max_tokens": 8191,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.000000,
"litellm_provider": "openai",
"mode": "embedding"
},
"azure/gpt-4-1106-preview": {
"max_tokens": 128000,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"litellm_provider": "azure",
"mode": "chat"
},
"azure/gpt-4-32k": {
"max_tokens": 8192,
"input_cost_per_token": 0.00006,
"output_cost_per_token": 0.00012,
"litellm_provider": "azure",
"mode": "chat"
},
"azure/gpt-4": {
"max_tokens": 16385,
"input_cost_per_token": 0.00003,
"output_cost_per_token": 0.00006,
"litellm_provider": "azure",
"mode": "chat"
},
"azure/gpt-3.5-turbo-16k": {
"max_tokens": 16385,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000004,
"litellm_provider": "azure",
"mode": "chat"
},
"azure/gpt-3.5-turbo": {
"max_tokens": 4097,
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000002,
"litellm_provider": "azure",
"mode": "chat"
},
"azure/text-embedding-ada-002": {
"max_tokens": 8191,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.000000,
"litellm_provider": "azure",
"mode": "embedding"
},
"text-davinci-003": { "text-davinci-003": {
"max_tokens": 4097, "max_tokens": 4097,
"input_cost_per_token": 0.000002, "input_cost_per_token": 0.000002,
@ -127,6 +204,7 @@
}, },
"claude-instant-1": { "claude-instant-1": {
"max_tokens": 100000, "max_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00000163, "input_cost_per_token": 0.00000163,
"output_cost_per_token": 0.00000551, "output_cost_per_token": 0.00000551,
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
@ -134,15 +212,25 @@
}, },
"claude-instant-1.2": { "claude-instant-1.2": {
"max_tokens": 100000, "max_tokens": 100000,
"input_cost_per_token": 0.00000163, "max_output_tokens": 8191,
"output_cost_per_token": 0.00000551, "input_cost_per_token": 0.000000163,
"output_cost_per_token": 0.000000551,
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
"mode": "chat" "mode": "chat"
}, },
"claude-2": { "claude-2": {
"max_tokens": 100000, "max_tokens": 100000,
"input_cost_per_token": 0.00001102, "max_output_tokens": 8191,
"output_cost_per_token": 0.00003268, "input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "anthropic",
"mode": "chat"
},
"claude-2.1": {
"max_tokens": 200000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
"mode": "chat" "mode": "chat"
}, },
@ -227,9 +315,51 @@
"max_tokens": 32000, "max_tokens": 32000,
"input_cost_per_token": 0.000000125, "input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125, "output_cost_per_token": 0.000000125,
"litellm_provider": "vertex_ai-chat-models", "litellm_provider": "vertex_ai-code-chat-models",
"mode": "chat" "mode": "chat"
}, },
"palm/chat-bison": {
"max_tokens": 4096,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "chat"
},
"palm/chat-bison-001": {
"max_tokens": 4096,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "chat"
},
"palm/text-bison": {
"max_tokens": 8196,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "completion"
},
"palm/text-bison-001": {
"max_tokens": 8196,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "completion"
},
"palm/text-bison-safety-off": {
"max_tokens": 8196,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "completion"
},
"palm/text-bison-safety-recitation-off": {
"max_tokens": 8196,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
"litellm_provider": "palm",
"mode": "completion"
},
"command-nightly": { "command-nightly": {
"max_tokens": 4096, "max_tokens": 4096,
"input_cost_per_token": 0.000015, "input_cost_per_token": 0.000015,
@ -267,6 +397,8 @@
}, },
"replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1": { "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1": {
"max_tokens": 4096, "max_tokens": 4096,
"input_cost_per_token": 0.0000,
"output_cost_per_token": 0.0000,
"litellm_provider": "replicate", "litellm_provider": "replicate",
"mode": "chat" "mode": "chat"
}, },
@ -293,6 +425,7 @@
}, },
"openrouter/anthropic/claude-instant-v1": { "openrouter/anthropic/claude-instant-v1": {
"max_tokens": 100000, "max_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00000163, "input_cost_per_token": 0.00000163,
"output_cost_per_token": 0.00000551, "output_cost_per_token": 0.00000551,
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
@ -300,6 +433,7 @@
}, },
"openrouter/anthropic/claude-2": { "openrouter/anthropic/claude-2": {
"max_tokens": 100000, "max_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00001102, "input_cost_per_token": 0.00001102,
"output_cost_per_token": 0.00003268, "output_cost_per_token": 0.00003268,
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
@ -496,20 +630,31 @@
}, },
"anthropic.claude-v1": { "anthropic.claude-v1": {
"max_tokens": 100000, "max_tokens": 100000,
"input_cost_per_token": 0.00001102, "max_output_tokens": 8191,
"output_cost_per_token": 0.00003268, "input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat" "mode": "chat"
}, },
"anthropic.claude-v2": { "anthropic.claude-v2": {
"max_tokens": 100000, "max_tokens": 100000,
"input_cost_per_token": 0.00001102, "max_output_tokens": 8191,
"output_cost_per_token": 0.00003268, "input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "bedrock",
"mode": "chat"
},
"anthropic.claude-v2:1": {
"max_tokens": 200000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024,
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat" "mode": "chat"
}, },
"anthropic.claude-instant-v1": { "anthropic.claude-instant-v1": {
"max_tokens": 100000, "max_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00000163, "input_cost_per_token": 0.00000163,
"output_cost_per_token": 0.00000551, "output_cost_per_token": 0.00000551,
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
@ -529,26 +674,80 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat" "mode": "chat"
}, },
"meta.llama2-70b-chat-v1": {
"max_tokens": 4096,
"input_cost_per_token": 0.00000195,
"output_cost_per_token": 0.00000256,
"litellm_provider": "bedrock",
"mode": "chat"
},
"sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "completion"
},
"sagemaker/meta-textgeneration-llama-2-7b-f": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "chat"
},
"sagemaker/meta-textgeneration-llama-2-13b": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "completion"
},
"sagemaker/meta-textgeneration-llama-2-13b-f": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "chat"
},
"sagemaker/meta-textgeneration-llama-2-70b": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "completion"
},
"sagemaker/meta-textgeneration-llama-2-70b-b-f": {
"max_tokens": 4096,
"input_cost_per_token": 0.000,
"output_cost_per_token": 0.000,
"litellm_provider": "sagemaker",
"mode": "chat"
},
"together-ai-up-to-3b": { "together-ai-up-to-3b": {
"input_cost_per_token": 0.0000001, "input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.0000001 "output_cost_per_token": 0.0000001,
"litellm_provider": "together_ai"
}, },
"together-ai-3.1b-7b": { "together-ai-3.1b-7b": {
"input_cost_per_token": 0.0000002, "input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002 "output_cost_per_token": 0.0000002,
"litellm_provider": "together_ai"
}, },
"together-ai-7.1b-20b": { "together-ai-7.1b-20b": {
"max_tokens": 1000, "max_tokens": 1000,
"input_cost_per_token": 0.0000004, "input_cost_per_token": 0.0000004,
"output_cost_per_token": 0.0000004 "output_cost_per_token": 0.0000004,
"litellm_provider": "together_ai"
}, },
"together-ai-20.1b-40b": { "together-ai-20.1b-40b": {
"input_cost_per_token": 0.000001, "input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.000001 "output_cost_per_token": 0.0000008,
"litellm_provider": "together_ai"
}, },
"together-ai-40.1b-70b": { "together-ai-40.1b-70b": {
"input_cost_per_token": 0.000003, "input_cost_per_token": 0.0000009,
"output_cost_per_token": 0.000003 "output_cost_per_token": 0.0000009,
"litellm_provider": "together_ai"
}, },
"ollama/llama2": { "ollama/llama2": {
"max_tokens": 4096, "max_tokens": 4096,
@ -578,10 +777,38 @@
"litellm_provider": "ollama", "litellm_provider": "ollama",
"mode": "completion" "mode": "completion"
}, },
"ollama/mistral": {
"max_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/codellama": {
"max_tokens": 4096,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/orca-mini": {
"max_tokens": 4096,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/vicuna": {
"max_tokens": 2048,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "completion"
},
"deepinfra/meta-llama/Llama-2-70b-chat-hf": { "deepinfra/meta-llama/Llama-2-70b-chat-hf": {
"max_tokens": 6144, "max_tokens": 4096,
"input_cost_per_token": 0.000001875, "input_cost_per_token": 0.000000700,
"output_cost_per_token": 0.000001875, "output_cost_per_token": 0.000000950,
"litellm_provider": "deepinfra", "litellm_provider": "deepinfra",
"mode": "chat" "mode": "chat"
}, },
@ -619,5 +846,103 @@
"output_cost_per_token": 0.00000095, "output_cost_per_token": 0.00000095,
"litellm_provider": "deepinfra", "litellm_provider": "deepinfra",
"mode": "chat" "mode": "chat"
},
"perplexity/pplx-7b-chat": {
"max_tokens": 8192,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/pplx-70b-chat": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/pplx-7b-online": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.0005,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/pplx-70b-online": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.0005,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-2-13b-chat": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/llama-2-70b-chat": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/mistral-7b-instruct": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"perplexity/replit-code-v1.5-3b": {
"max_tokens": 4096,
"input_cost_per_token": 0.0000000,
"output_cost_per_token": 0.000000,
"litellm_provider": "perplexity",
"mode": "chat"
},
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
"max_tokens": 16384,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/HuggingFaceH4/zephyr-7b-beta": {
"max_tokens": 16384,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/meta-llama/Llama-2-7b-chat-hf": {
"max_tokens": 4096,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/meta-llama/Llama-2-13b-chat-hf": {
"max_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000025,
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/meta-llama/Llama-2-70b-chat-hf": {
"max_tokens": 4096,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/codellama/CodeLlama-34b-Instruct-hf": {
"max_tokens": 16384,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "anyscale",
"mode": "chat"
} }
} }
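The pricing entries above drive LiteLLM's cost tracking. A minimal sketch, assuming an ANTHROPIC_API_KEY is configured, of how a priced model feeds `litellm.completion_cost` (the same helper the proxy's cost callback uses further down):

```python
import litellm

# claude-2.1 is priced above at 0.000008 (input) / 0.000024 (output) per token
response = litellm.completion(
    model="claude-2.1",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)

# completion_cost multiplies the token counts in the response by those per-token rates
cost = litellm.completion_cost(completion_response=response)
print(f"request cost: ${cost:.6f}")
```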

View file

@ -0,0 +1,4 @@
def my_custom_rule(input): # receives the model response
    if len(input) < 5: # trigger fallback if the model response is too short
        return False
    return True
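A sketch of how a rule like this can be wired up; `my_rules` is a hypothetical module name here, and the proxy does the equivalent when `post_call_rules` is set in config.yaml (see the config-loader change further down):

```python
import litellm

# hypothetical module name for the rule file above
from my_rules import my_custom_rule

# each post-call rule receives the model's response text; returning False triggers fallback
litellm.post_call_rules = [my_custom_rule]
```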

View file

@ -2,8 +2,21 @@ from pydantic import BaseModel, Extra, Field, root_validator
from typing import Optional, List, Union, Dict, Literal from typing import Optional, List, Union, Dict, Literal
from datetime import datetime from datetime import datetime
import uuid, json import uuid, json
class LiteLLMBase(BaseModel):
"""
Implements default functions that all pydantic objects should have.
"""
def json(self, **kwargs):
try:
return self.model_dump() # noqa
except:
# if using pydantic v1
return self.dict()
######### Request Class Definition ###### ######### Request Class Definition ######
class ProxyChatCompletionRequest(BaseModel): class ProxyChatCompletionRequest(LiteLLMBase):
model: str model: str
messages: List[Dict[str, str]] messages: List[Dict[str, str]]
temperature: Optional[float] = None temperature: Optional[float] = None
@ -38,16 +51,16 @@ class ProxyChatCompletionRequest(BaseModel):
class Config: class Config:
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs) extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
class ModelInfoDelete(BaseModel): class ModelInfoDelete(LiteLLMBase):
id: Optional[str] id: Optional[str]
class ModelInfo(BaseModel): class ModelInfo(LiteLLMBase):
id: Optional[str] id: Optional[str]
mode: Optional[Literal['embedding', 'chat', 'completion']] mode: Optional[Literal['embedding', 'chat', 'completion']]
input_cost_per_token: Optional[float] input_cost_per_token: Optional[float] = 0.0
output_cost_per_token: Optional[float] output_cost_per_token: Optional[float] = 0.0
max_tokens: Optional[int] max_tokens: Optional[int] = 2048 # assume 2048 if not set
# for azure models we need users to specify the base model, on azure you can call deployments - azure/my-random-model # for azure models we need users to specify the base model, on azure you can call deployments - azure/my-random-model
# we look up the base model in model_prices_and_context_window.json # we look up the base model in model_prices_and_context_window.json
@ -66,37 +79,40 @@ class ModelInfo(BaseModel):
extra = Extra.allow # Allow extra fields extra = Extra.allow # Allow extra fields
protected_namespaces = () protected_namespaces = ()
# @root_validator(pre=True)
# def set_model_info(cls, values): @root_validator(pre=True)
# if values.get("id") is None: def set_model_info(cls, values):
# values.update({"id": str(uuid.uuid4())}) if values.get("id") is None:
# if values.get("mode") is None: values.update({"id": str(uuid.uuid4())})
# values.update({"mode": str(uuid.uuid4())}) if values.get("mode") is None:
# return values values.update({"mode": None})
if values.get("input_cost_per_token") is None:
values.update({"input_cost_per_token": None})
if values.get("output_cost_per_token") is None:
values.update({"output_cost_per_token": None})
if values.get("max_tokens") is None:
values.update({"max_tokens": None})
if values.get("base_model") is None:
values.update({"base_model": None})
return values
class ModelParams(BaseModel): class ModelParams(LiteLLMBase):
model_name: str model_name: str
litellm_params: dict litellm_params: dict
model_info: Optional[ModelInfo]=None model_info: ModelInfo
# def __init__(self, model_name: str, litellm_params: dict, model_info: Optional[ModelInfo] = None):
# self.model_name = model_name
# self.litellm_params = litellm_params
# self.model_info = model_info if model_info else ModelInfo()
# super.__init__(model_name=self.model_name, litellm_params=self.litellm_params, model_info=self.model_info)
class Config: class Config:
protected_namespaces = () protected_namespaces = ()
# @root_validator(pre=True) @root_validator(pre=True)
# def set_model_info(cls, values): def set_model_info(cls, values):
# if values.get("model_info") is None: if values.get("model_info") is None:
# values.update({"model_info": ModelInfo()}) values.update({"model_info": ModelInfo()})
# return values return values
class GenerateKeyRequest(BaseModel): class GenerateKeyRequest(LiteLLMBase):
duration: Optional[str] = "1h" duration: Optional[str] = "1h"
models: Optional[list] = [] models: Optional[list] = []
aliases: Optional[dict] = {} aliases: Optional[dict] = {}
@ -105,26 +121,32 @@ class GenerateKeyRequest(BaseModel):
user_id: Optional[str] = None user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
def json(self, **kwargs): class UpdateKeyRequest(LiteLLMBase):
try: key: str
return self.model_dump() # noqa duration: Optional[str] = None
except: models: Optional[list] = None
# if using pydantic v1 aliases: Optional[dict] = None
return self.dict() config: Optional[dict] = None
spend: Optional[float] = None
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
class GenerateKeyResponse(BaseModel): class GenerateKeyResponse(LiteLLMBase):
key: str key: str
expires: datetime expires: datetime
user_id: str user_id: str
class _DeleteKeyObject(BaseModel):
class _DeleteKeyObject(LiteLLMBase):
key: str key: str
class DeleteKeyRequest(BaseModel): class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject] keys: List[_DeleteKeyObject]
class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
""" """
Return the row in the db Return the row in the db
""" """
@ -137,7 +159,7 @@ class UserAPIKeyAuth(BaseModel): # the expected response object for user api key
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
duration: str = "1h" duration: str = "1h"
class ConfigGeneralSettings(BaseModel): class ConfigGeneralSettings(LiteLLMBase):
""" """
Documents all the fields supported by `general_settings` in config.yaml Documents all the fields supported by `general_settings` in config.yaml
""" """
@ -153,10 +175,12 @@ class ConfigGeneralSettings(BaseModel):
health_check_interval: int = Field(300, description="background health check interval in seconds") health_check_interval: int = Field(300, description="background health check interval in seconds")
class ConfigYAML(BaseModel): class ConfigYAML(LiteLLMBase):
""" """
Documents all the fields supported by the config.yaml Documents all the fields supported by the config.yaml
""" """
model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs") model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs")
litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache") litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache")
general_settings: Optional[ConfigGeneralSettings] = None general_settings: Optional[ConfigGeneralSettings] = None
class Config:
protected_namespaces = ()
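Assuming these classes live in `litellm.proxy._types` (the import in the parallel-request handler below suggests they do), a short sketch of the `LiteLLMBase.json()` behavior — it returns a plain dict on both pydantic v1 (`.dict()`) and v2 (`.model_dump()`):

```python
from litellm.proxy._types import GenerateKeyRequest

req = GenerateKeyRequest(duration="2h", models=["gpt-3.5-turbo"])
print(req.json())  # plain dict, regardless of the installed pydantic major version
```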

View file

@ -1,3 +1,11 @@
import sys, os, traceback
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
import litellm import litellm
import inspect import inspect
@ -37,8 +45,11 @@ class MyCustomHandler(CustomLogger):
def log_success_event(self, kwargs, response_obj, start_time, end_time): def log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose("On Success!") print_verbose("On Success!")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"On Async Success!") print_verbose(f"On Async Success!")
response_cost = litellm.completion_cost(completion_response=response_obj)
assert response_cost > 0.0
return return
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
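A condensed sketch of the same pattern outside the test file — a custom logger that prices every successful call, following the class and hook names above:

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger

class CostLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # same call the test above asserts returns a value > 0.0
        cost = litellm.completion_cost(completion_response=response_obj)
        print(f"call cost: ${cost:.6f}")

litellm.callbacks = [CostLogger()]  # same mechanism as the proxy's `callbacks` config key
```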

View file

@ -69,7 +69,6 @@ async def _perform_health_check(model_list: list):
for model in model_list: for model in model_list:
litellm_params = model["litellm_params"] litellm_params = model["litellm_params"]
model_info = model.get("model_info", {}) model_info = model.get("model_info", {})
litellm_params["model"] = litellm.utils.remove_model_id(litellm_params["model"])
litellm_params["messages"] = _get_random_llm_message() litellm_params["messages"] = _get_random_llm_message()
prepped_params.append(litellm_params) prepped_params.append(litellm_params)

View file

@ -1,6 +1,7 @@
from typing import Optional from typing import Optional
import litellm import litellm
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException from fastapi import HTTPException
@ -14,24 +15,28 @@ class MaxParallelRequestsHandler(CustomLogger):
print(print_statement) # noqa print(print_statement) # noqa
async def max_parallel_request_allow_request(self, max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache): async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_parallel_requests = user_api_key_dict.max_parallel_requests
if api_key is None: if api_key is None:
return return
if max_parallel_requests is None: if max_parallel_requests is None:
return return
self.user_api_key_cache = user_api_key_cache # save the api key cache for updating the value self.user_api_key_cache = cache # save the api key cache for updating the value
# CHECK IF REQUEST ALLOWED # CHECK IF REQUEST ALLOWED
request_count_api_key = f"{api_key}_request_count" request_count_api_key = f"{api_key}_request_count"
current = user_api_key_cache.get_cache(key=request_count_api_key) current = cache.get_cache(key=request_count_api_key)
self.print_verbose(f"current: {current}") self.print_verbose(f"current: {current}")
if current is None: if current is None:
user_api_key_cache.set_cache(request_count_api_key, 1) cache.set_cache(request_count_api_key, 1)
elif int(current) < max_parallel_requests: elif int(current) < max_parallel_requests:
# Increase count for this token # Increase count for this token
user_api_key_cache.set_cache(request_count_api_key, int(current) + 1) cache.set_cache(request_count_api_key, int(current) + 1)
else: else:
raise HTTPException(status_code=429, detail="Max parallel request limit reached.") raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
@ -55,11 +60,19 @@ class MaxParallelRequestsHandler(CustomLogger):
except Exception as e: except Exception as e:
self.print_verbose(e) # noqa self.print_verbose(e) # noqa
async def async_log_failure_call(self, api_key, user_api_key_cache): async def async_log_failure_call(self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception):
try: try:
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
api_key = user_api_key_dict.api_key
if api_key is None: if api_key is None:
return return
## decrement call count if call failed
if (hasattr(original_exception, "status_code")
and original_exception.status_code == 429
and "Max parallel request limit reached" in str(original_exception)):
pass # ignore failed calls due to max limit being reached
else:
request_count_api_key = f"{api_key}_request_count" request_count_api_key = f"{api_key}_request_count"
# Decrease count for this token # Decrease count for this token
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1 current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
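A condensed, standalone sketch of the counter the handler keeps in `DualCache` (names follow the code above; the real pre-call hook raises an HTTP 429 via `HTTPException`):

```python
from litellm.caching import DualCache

cache = DualCache()
key = "sk-example_request_count"   # f"{api_key}_request_count" in the handler
max_parallel_requests = 10         # read off the user's key in the real hook

current = cache.get_cache(key=key)
if current is None:
    cache.set_cache(key, 1)                   # first in-flight request for this key
elif int(current) < max_parallel_requests:
    cache.set_cache(key, int(current) + 1)    # room left, bump the counter
else:
    raise Exception("Max parallel request limit reached.")
```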

View file

@ -3,6 +3,7 @@ import subprocess, traceback, json
import os, sys import os, sys
import random, appdirs import random, appdirs
from datetime import datetime from datetime import datetime
import importlib
from dotenv import load_dotenv from dotenv import load_dotenv
import operator import operator
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
@ -76,13 +77,14 @@ def is_port_in_use(port):
@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`') @click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`')
@click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`') @click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`')
@click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`') @click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
@click.option('--version', '-v', default=False, is_flag=True, type=bool, help='Print LiteLLM version')
@click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.') @click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.')
@click.option('--health', flag_value=True, help='Make a chat/completions request to all llms in config.yaml') @click.option('--health', flag_value=True, help='Make a chat/completions request to all llms in config.yaml')
@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to') @click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
@click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response') @click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response')
@click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with') @click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with')
@click.option('--local', is_flag=True, default=False, help='for local debugging') @click.option('--local', is_flag=True, default=False, help='for local debugging')
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health): def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health, version):
global feature_telemetry global feature_telemetry
args = locals() args = locals()
if local: if local:
@ -113,6 +115,10 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
except: except:
raise Exception("LiteLLM: No logs saved!") raise Exception("LiteLLM: No logs saved!")
return return
if version == True:
pkg_version = importlib.metadata.version("litellm")
click.echo(f'\nLiteLLM: Current Version = {pkg_version}\n')
return
if model and "ollama" in model and api_base is None: if model and "ollama" in model and api_base is None:
run_ollama_serve() run_ollama_serve()
if test_async is True: if test_async is True:

View file

@ -11,8 +11,10 @@ model_list:
output_cost_per_token: 0.00003 output_cost_per_token: 0.00003
max_tokens: 4096 max_tokens: 4096
base_model: gpt-3.5-turbo base_model: gpt-3.5-turbo
- model_name: BEDROCK_GROUP
- model_name: openai-gpt-3.5 litellm_params:
model: bedrock/cohere.command-text-v14
- model_name: Azure OpenAI GPT-4 Canada-East (External)
litellm_params: litellm_params:
model: gpt-3.5-turbo model: gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY api_key: os.environ/OPENAI_API_KEY
@ -41,11 +43,12 @@ model_list:
mode: completion mode: completion
litellm_settings: litellm_settings:
# cache: True
# setting callback class # setting callback class
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance] # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
model_group_alias_map: {"gpt-4": "openai-gpt-3.5"} # all requests with the gpt-4 model_name get sent to openai-gpt-3.5
general_settings: general_settings:
environment_variables:
# otel: True # OpenTelemetry Logger # otel: True # OpenTelemetry Logger
# master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234) # master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)

View file

@ -195,8 +195,10 @@ prisma_client: Optional[PrismaClient] = None
user_api_key_cache = DualCache() user_api_key_cache = DualCache()
user_custom_auth = None user_custom_auth = None
use_background_health_checks = None use_background_health_checks = None
use_queue = False
health_check_interval = None health_check_interval = None
health_check_results = {} health_check_results = {}
queue: List = []
### INITIALIZE GLOBAL LOGGING OBJECT ### ### INITIALIZE GLOBAL LOGGING OBJECT ###
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
### REDIS QUEUE ### ### REDIS QUEUE ###
@ -252,23 +254,27 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if api_key is None: # only require api key if master key is set if api_key is None: # only require api key if master key is set
raise Exception(f"No api key passed in.") raise Exception(f"No api key passed in.")
route = request.url.path route: str = request.url.path
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
is_master_key_valid = secrets.compare_digest(api_key, master_key) is_master_key_valid = secrets.compare_digest(api_key, master_key)
if is_master_key_valid: if is_master_key_valid:
return UserAPIKeyAuth(api_key=master_key) return UserAPIKeyAuth(api_key=master_key)
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid: if route.startswith("/key/") and not is_master_key_valid:
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys") raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys")
if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.")
if prisma_client:
## check for cache hit (In-Memory Cache) ## check for cache hit (In-Memory Cache)
valid_token = user_api_key_cache.get_cache(key=api_key) valid_token = user_api_key_cache.get_cache(key=api_key)
print(f"valid_token from cache: {valid_token}") print(f"valid_token from cache: {valid_token}")
if valid_token is None: if valid_token is None:
## check db ## check db
print(f"api key: {api_key}")
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow()) valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
print(f"valid token from prisma: {valid_token}")
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None: elif valid_token is not None:
print(f"API Key Cache Hit!") print(f"API Key Cache Hit!")
@ -285,7 +291,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
valid_token_dict.pop("token", None) valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else: else:
try:
data = await request.json() data = await request.json()
except json.JSONDecodeError:
data = {} # Provide a default value, such as an empty dictionary
model = data.get("model", None) model = data.get("model", None)
if model in litellm.model_alias_map: if model in litellm.model_alias_map:
model = litellm.model_alias_map[model] model = litellm.model_alias_map[model]
@ -310,24 +319,12 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
def prisma_setup(database_url: Optional[str]): def prisma_setup(database_url: Optional[str]):
global prisma_client, proxy_logging_obj, user_api_key_cache global prisma_client, proxy_logging_obj, user_api_key_cache
proxy_logging_obj._init_litellm_callbacks()
if database_url is not None: if database_url is not None:
try: try:
prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj) prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
except Exception as e: except Exception as e:
print("Error when initializing prisma, Ensure you run pip install prisma", e) print("Error when initializing prisma, Ensure you run pip install prisma", e)
def celery_setup(use_queue: bool):
global celery_fn, celery_app_conn, async_result
if use_queue:
from litellm.proxy.queue.celery_worker import start_worker
from litellm.proxy.queue.celery_app import celery_app, process_job
from celery.result import AsyncResult
start_worker(os.getcwd())
celery_fn = process_job
async_result = AsyncResult
celery_app_conn = celery_app
def load_from_azure_key_vault(use_azure_key_vault: bool = False): def load_from_azure_key_vault(use_azure_key_vault: bool = False):
if use_azure_key_vault is False: if use_azure_key_vault is False:
return return
@ -380,30 +377,14 @@ async def track_cost_callback(
if "complete_streaming_response" in kwargs: if "complete_streaming_response" in kwargs:
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
completion_response=kwargs["complete_streaming_response"] completion_response=kwargs["complete_streaming_response"]
input_text = kwargs["messages"] response_cost = litellm.completion_cost(completion_response=completion_response)
output_text = completion_response["choices"][0]["message"]["content"]
response_cost = litellm.completion_cost(
model = kwargs["model"],
messages = input_text,
completion=output_text
)
print("streaming response_cost", response_cost) print("streaming response_cost", response_cost)
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None) user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
if user_api_key and prisma_client: if user_api_key and prisma_client:
await update_prisma_database(token=user_api_key, response_cost=response_cost) await update_prisma_database(token=user_api_key, response_cost=response_cost)
elif kwargs["stream"] == False: # for non streaming responses elif kwargs["stream"] == False: # for non streaming responses
input_text = kwargs.get("messages", "") response_cost = litellm.completion_cost(completion_response=completion_response)
print(f"type of input_text: {type(input_text)}")
if isinstance(input_text, list):
response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
elif isinstance(input_text, str):
response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
print(f"received completion response: {completion_response}")
print(f"regular response_cost: {response_cost}")
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None) user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
if user_api_key and prisma_client: if user_api_key and prisma_client:
await update_prisma_database(token=user_api_key, response_cost=response_cost) await update_prisma_database(token=user_api_key, response_cost=response_cost)
except Exception as e: except Exception as e:
@ -459,7 +440,7 @@ async def _run_background_health_check():
await asyncio.sleep(health_check_interval) await asyncio.sleep(health_check_interval)
def load_router_config(router: Optional[litellm.Router], config_file_path: str): def load_router_config(router: Optional[litellm.Router], config_file_path: str):
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
config = {} config = {}
try: try:
if os.path.exists(config_file_path): if os.path.exists(config_file_path):
@ -504,6 +485,18 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
cache_port = litellm.get_secret("REDIS_PORT", None) cache_port = litellm.get_secret("REDIS_PORT", None)
cache_password = litellm.get_secret("REDIS_PASSWORD", None) cache_password = litellm.get_secret("REDIS_PASSWORD", None)
cache_params = {
"type": cache_type,
"host": cache_host,
"port": cache_port,
"password": cache_password
}
if "cache_params" in litellm_settings:
cache_params_in_config = litellm_settings["cache_params"]
# overwrite cache_params with cache_params_in_config
cache_params.update(cache_params_in_config)
# Assuming cache_type, cache_host, cache_port, and cache_password are strings # Assuming cache_type, cache_host, cache_port, and cache_password are strings
print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}") print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}")
print(f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}") print(f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}")
@ -513,15 +506,15 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
litellm.cache = Cache( litellm.cache = Cache(
type=cache_type, **cache_params
host=cache_host,
port=cache_port,
password=cache_password
) )
print(f"{blue_color_code}Set Cache on LiteLLM Proxy: {litellm.cache.cache}{reset_color_code} {cache_password}") print(f"{blue_color_code}Set Cache on LiteLLM Proxy: {litellm.cache.cache}{reset_color_code} {cache_password}")
elif key == "callbacks": elif key == "callbacks":
litellm.callbacks = [get_instance_fn(value=value, config_file_path=config_file_path)] litellm.callbacks = [get_instance_fn(value=value, config_file_path=config_file_path)]
print_verbose(f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}") print_verbose(f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}")
elif key == "post_call_rules":
litellm.post_call_rules = [get_instance_fn(value=value, config_file_path=config_file_path)]
print(f"litellm.post_call_rules: {litellm.post_call_rules}")
elif key == "success_callback": elif key == "success_callback":
litellm.success_callback = [] litellm.success_callback = []
@ -533,10 +526,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
# these are litellm callbacks - "langfuse", "sentry", "wandb" # these are litellm callbacks - "langfuse", "sentry", "wandb"
else: else:
litellm.success_callback.append(callback) litellm.success_callback.append(callback)
if callback == "traceloop":
from traceloop.sdk import Traceloop
print_verbose(f"{blue_color_code} Initializing Traceloop SDK - \nRunning:`Traceloop.init(app_name='Litellm-Server', disable_batch=True)`")
Traceloop.init(app_name="Litellm-Server", disable_batch=True)
print_verbose(f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}") print_verbose(f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}")
elif key == "failure_callback": elif key == "failure_callback":
litellm.failure_callback = [] litellm.failure_callback = []
@ -550,6 +539,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
else: else:
litellm.failure_callback.append(callback) litellm.failure_callback.append(callback)
print_verbose(f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}") print_verbose(f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}")
elif key == "cache_params":
# this is set in the cache branch
# see usage here: https://docs.litellm.ai/docs/proxy/caching
pass
else: else:
setattr(litellm, key, value) setattr(litellm, key, value)
@ -572,7 +565,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
cost_tracking() cost_tracking()
### START REDIS QUEUE ### ### START REDIS QUEUE ###
use_queue = general_settings.get("use_queue", False) use_queue = general_settings.get("use_queue", False)
celery_setup(use_queue=use_queue)
### MASTER KEY ### ### MASTER KEY ###
master_key = general_settings.get("master_key", None) master_key = general_settings.get("master_key", None)
if master_key and master_key.startswith("os.environ/"): if master_key and master_key.startswith("os.environ/"):
@ -683,6 +675,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id} return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
async def delete_verification_token(tokens: List): async def delete_verification_token(tokens: List):
global prisma_client global prisma_client
try: try:
@ -761,8 +755,6 @@ def initialize(
if max_budget: # litellm-specific param if max_budget: # litellm-specific param
litellm.max_budget = max_budget litellm.max_budget = max_budget
dynamic_config["general"]["max_budget"] = max_budget dynamic_config["general"]["max_budget"] = max_budget
if use_queue:
celery_setup(use_queue=use_queue)
if experimental: if experimental:
pass pass
user_telemetry = telemetry user_telemetry = telemetry
@ -798,48 +790,12 @@ def data_generator(response):
async def async_data_generator(response, user_api_key_dict): async def async_data_generator(response, user_api_key_dict):
print_verbose("inside generator") print_verbose("inside generator")
async for chunk in response: async for chunk in response:
# try:
# await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=None, call_type="completion")
# except Exception as e:
# print(f"An exception occurred - {str(e)}")
print_verbose(f"returned chunk: {chunk}") print_verbose(f"returned chunk: {chunk}")
try: try:
yield f"data: {json.dumps(chunk.dict())}\n\n" yield f"data: {json.dumps(chunk.dict())}\n\n"
except: except:
yield f"data: {json.dumps(chunk)}\n\n" yield f"data: {json.dumps(chunk)}\n\n"
def litellm_completion(*args, **kwargs):
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
call_type = kwargs.pop("call_type")
# override with user settings, these are params passed via cli
if user_temperature:
kwargs["temperature"] = user_temperature
if user_request_timeout:
kwargs["request_timeout"] = user_request_timeout
if user_max_tokens:
kwargs["max_tokens"] = user_max_tokens
if user_api_base:
kwargs["api_base"] = user_api_base
## ROUTE TO CORRECT ENDPOINT ##
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
try:
if llm_router is not None and kwargs["model"] in router_model_names: # model in router model list
if call_type == "chat_completion":
response = llm_router.completion(*args, **kwargs)
elif call_type == "text_completion":
response = llm_router.text_completion(*args, **kwargs)
else:
if call_type == "chat_completion":
response = litellm.completion(*args, **kwargs)
elif call_type == "text_completion":
response = litellm.text_completion(*args, **kwargs)
except Exception as e:
raise e
if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(data_generator(response), media_type='text/event-stream')
return response
def get_litellm_model_info(model: dict = {}): def get_litellm_model_info(model: dict = {}):
model_info = model.get("model_info", {}) model_info = model.get("model_info", {})
model_to_lookup = model.get("litellm_params", {}).get("model", None) model_to_lookup = model.get("litellm_params", {}).get("model", None)
@ -870,6 +826,8 @@ async def startup_event():
initialize(**worker_config) initialize(**worker_config)
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
if use_background_health_checks: if use_background_health_checks:
asyncio.create_task(_run_background_health_check()) # start the background health check coroutine. asyncio.create_task(_run_background_health_check()) # start the background health check coroutine.
@ -881,16 +839,6 @@ async def startup_event():
# add master key to db # add master key to db
await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key) await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key)
@router.on_event("shutdown")
async def shutdown_event():
global prisma_client, master_key, user_custom_auth
if prisma_client:
print("Disconnecting from Prisma")
await prisma_client.disconnect()
## RESET CUSTOM VARIABLES ##
master_key = None
user_custom_auth = None
#### API ENDPOINTS #### #### API ENDPOINTS ####
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)]) @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
@ -929,7 +877,8 @@ def model_list():
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)): async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
try: try:
body = await request.body() body = await request.body()
body_str = body.decode() body_str = body.decode()
@ -938,7 +887,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
except: except:
data = json.loads(body_str) data = json.loads(body_str)
data["user"] = user_api_key_dict.user_id data["user"] = data.get("user", user_api_key_dict.user_id)
data["model"] = ( data["model"] = (
general_settings.get("completion_model", None) # server default general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args or user_model # model name passed via cli args
@ -947,17 +896,44 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
) )
if user_model: if user_model:
data["model"] = user_model data["model"] = user_model
data["call_type"] = "text_completion"
if "metadata" in data: if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["user_api_key"] = user_api_key_dict.api_key
else: else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
return litellm_completion( # override with user settings, these are params passed via cli
**data if user_temperature:
) data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
### ROUTE THE REQUEST ###
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.atext_completion(**data)
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
response = await llm_router.atext_completion(**data, specific_deployment = True)
elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
response = await llm_router.atext_completion(**data)
else: # router is not set
response = await litellm.atext_completion(**data)
print(f"final response: {response}")
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
return response
except Exception as e: except Exception as e:
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`") print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
traceback.print_exc()
error_traceback = traceback.format_exc() error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}" error_msg = f"{str(e)}\n\n{error_traceback}"
try: try:
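A usage sketch for the text-completion route above, assuming the proxy is running locally on port 4000 with the sample master key from the config:

```python
import requests

resp = requests.post(
    "http://0.0.0.0:4000/v1/completions",
    headers={"Authorization": "Bearer sk-1234"},   # sample master key from the config above
    json={"model": "gpt-3.5-turbo", "prompt": "Say hello", "max_tokens": 16},
)
print(resp.json())
```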
@ -995,7 +971,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
) )
# users can pass in 'user' param to /chat/completions. Don't override it # users can pass in 'user' param to /chat/completions. Don't override it
if data.get("user", None) is None: if data.get("user", None) is None and user_api_key_dict.user_id is not None:
# if users are using user_api_key_auth, set `user` in `data` # if users are using user_api_key_auth, set `user` in `data`
data["user"] = user_api_key_dict.user_id data["user"] = user_api_key_dict.user_id
@ -1027,7 +1003,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
response = await llm_router.acompletion(**data) response = await llm_router.acompletion(**data)
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
response = await llm_router.acompletion(**data, specific_deployment = True) response = await llm_router.acompletion(**data, specific_deployment = True)
elif llm_router is not None and litellm.model_group_alias_map is not None and data["model"] in litellm.model_group_alias_map: # model set in model_group_alias_map elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
response = await llm_router.acompletion(**data) response = await llm_router.acompletion(**data)
else: # router is not set else: # router is not set
response = await litellm.acompletion(**data) response = await litellm.acompletion(**data)
@ -1088,7 +1064,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
"body": copy.copy(data) # use copy instead of deepcopy "body": copy.copy(data) # use copy instead of deepcopy
} }
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id data["user"] = user_api_key_dict.user_id
data["model"] = ( data["model"] = (
general_settings.get("embedding_model", None) # server default general_settings.get("embedding_model", None) # server default
or user_model # model name passed via cli args or user_model # model name passed via cli args
@ -1098,10 +1076,11 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
data["model"] = user_model data["model"] = user_model
if "metadata" in data: if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers)
else: else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
print(f"received data: {data['input']}")
if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in
# check if non-openai/azure model called - e.g. for langchain integration # check if non-openai/azure model called - e.g. for langchain integration
if llm_model_list is not None and data["model"] in router_model_names: if llm_model_list is not None and data["model"] in router_model_names:
@ -1119,12 +1098,13 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
### CALL HOOKS ### - modify incoming data / reject request before calling the model ### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings") data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
## ROUTE TO CORRECT ENDPOINT ## ## ROUTE TO CORRECT ENDPOINT ##
if llm_router is not None and data["model"] in router_model_names: # model in router model list if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.aembedding(**data) response = await llm_router.aembedding(**data)
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
response = await llm_router.aembedding(**data, specific_deployment = True) response = await llm_router.aembedding(**data, specific_deployment = True)
elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
response = await llm_router.aembedding(**data) # ensure this goes the llm_router, router will do the correct alias mapping
else: else:
response = await litellm.aembedding(**data) response = await litellm.aembedding(**data)
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
@ -1133,7 +1113,19 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
except Exception as e: except Exception as e:
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e) await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
traceback.print_exc() traceback.print_exc()
if isinstance(e, HTTPException):
raise e raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
try:
status = e.status_code # type: ignore
except:
status = 500
raise HTTPException(
status_code=status,
detail=error_msg
)
#### KEY MANAGEMENT #### #### KEY MANAGEMENT ####
@ -1162,6 +1154,30 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
response = await generate_key_helper_fn(**data_json) response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"]) return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
@router.post("/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def update_key_fn(request: Request, data: UpdateKeyRequest):
"""
Update an existing key
"""
global prisma_client
try:
data_json: dict = data.json()
key = data_json.pop("key")
# get the row from db
if prisma_client is None:
raise Exception("Not connected to DB!")
non_default_values = {k: v for k, v in data_json.items() if v is not None}
print(f"non_default_values: {non_default_values}")
response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key})
return {"key": key, **non_default_values}
# update based on remaining passed in values
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
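A minimal client-side sketch of calling the new `/key/update` endpoint above. The proxy address, master key, and the `models` field are assumptions for illustration; only `key` is required by the handler shown here, and `None` fields are dropped before the DB update.

```python
# Hypothetical call to the /key/update endpoint defined above.
# URL, auth header, and the "models" field are illustrative assumptions.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/update",                   # assumed proxy address
    headers={"Authorization": "Bearer sk-master-key"},   # assumed master key
    json={
        "key": "sk-generated-key",     # token previously returned by /key/generate
        "models": ["gpt-3.5-turbo"],   # example field; None values are filtered out server-side
    },
)
print(resp.status_code, resp.json())
```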
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) @router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request, data: DeleteKeyRequest): async def delete_key_fn(request: Request, data: DeleteKeyRequest):
try: try:
@ -1207,10 +1223,12 @@ async def add_new_model(model_params: ModelParams):
print_verbose(f"Loaded config: {config}") print_verbose(f"Loaded config: {config}")
# Add the new model to the config # Add the new model to the config
model_info = model_params.model_info.json()
model_info = {k: v for k, v in model_info.items() if v is not None}
config['model_list'].append({ config['model_list'].append({
'model_name': model_params.model_name, 'model_name': model_params.model_name,
'litellm_params': model_params.litellm_params, 'litellm_params': model_params.litellm_params,
'model_info': model_params.model_info 'model_info': model_info
}) })
# Save the updated config # Save the updated config
@ -1228,7 +1246,7 @@ async def add_new_model(model_params: ModelParams):
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info #### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)]) @router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
async def model_info_v1(request: Request): async def model_info_v1(request: Request):
global llm_model_list, general_settings, user_config_file_path global llm_model_list, general_settings, user_config_file_path
# Load existing config # Load existing config
@ -1256,7 +1274,7 @@ async def model_info_v1(request: Request):
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933
@router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)]) @router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
async def model_info(request: Request): async def model_info(request: Request):
global llm_model_list, general_settings, user_config_file_path global llm_model_list, general_settings, user_config_file_path
# Load existing config # Load existing config
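After the route swap above, `/model/info` serves the v1 handler and `/v1/model/info` the original one. A hedged sketch of querying the beta endpoint; the proxy address and key are placeholders, not values from this repo.

```python
# Hedged sketch: list model info from the (beta) /model/info endpoint.
import requests

resp = requests.get(
    "http://0.0.0.0:8000/model/info",                    # assumed proxy address
    headers={"Authorization": "Bearer sk-master-key"},    # assumed master key
)
print(resp.json())
```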
@ -1341,46 +1359,107 @@ async def delete_model(model_info: ModelInfoDelete):
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
#### EXPERIMENTAL QUEUING #### #### EXPERIMENTAL QUEUING ####
@router.post("/queue/request", dependencies=[Depends(user_api_key_auth)]) async def _litellm_chat_completions_worker(data, user_api_key_dict):
async def async_queue_request(request: Request): """
global celery_fn, llm_model_list worker to make litellm completions calls
if celery_fn is not None: """
body = await request.body() while True:
body_str = body.decode()
try: try:
data = ast.literal_eval(body_str) ### CALL HOOKS ### - modify incoming data before calling the model
except: data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
data = json.loads(body_str)
print(f"_litellm_chat_completions_worker started")
### ROUTE THE REQUEST ###
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.acompletion(**data)
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
response = await llm_router.acompletion(**data, specific_deployment = True)
elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
response = await llm_router.acompletion(**data)
else: # router is not set
response = await litellm.acompletion(**data)
print(f"final response: {response}")
return response
except HTTPException as e:
print(f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}")
if e.status_code == 429 and "Max parallel request limit reached" in e.detail:
print(f"Max parallel request limit reached!")
timeout = litellm._calculate_retry_after(remaining_retries=3, max_retries=3, min_timeout=1)
await asyncio.sleep(timeout)
else:
raise e
@router.post("/queue/chat/completions", tags=["experimental"], dependencies=[Depends(user_api_key_auth)])
async def async_queue_request(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global general_settings, user_debug, proxy_logging_obj
"""
v2 attempt at a background worker to handle queuing.
Just supports /chat/completion calls currently.
(Here "/chat/completion" refers to the /chat/completions route.)
Now using a FastAPI background task + /chat/completions compatible endpoint
"""
try:
data = {}
data = await request.json() # type: ignore
# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data) # use copy instead of deepcopy
}
print_verbose(f"receiving data: {data}")
data["model"] = ( data["model"] = (
general_settings.get("completion_model", None) # server default general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args or user_model # model name passed via cli args
or model # for azure deployments
or data["model"] # default passed in http request or data["model"] # default passed in http request
) )
data["llm_model_list"] = llm_model_list
print(f"data: {data}") # users can pass in 'user' param to /chat/completions. Don't override it
job = celery_fn.apply_async(kwargs=data) if data.get("user", None) is None and user_api_key_dict.user_id is not None:
return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"} # if users are using user_api_key_auth, set `user` in `data`
data["user"] = user_api_key_dict.user_id
if "metadata" in data:
print(f'received metadata: {data["metadata"]}')
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers)
else: else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base
response = await asyncio.wait_for(_litellm_chat_completions_worker(data=data, user_api_key_dict=user_api_key_dict), timeout=litellm.request_timeout)
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Queue not initialized"}, detail={"error": str(e)},
) )
@router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
async def async_queue_response(request: Request, task_id: str):
global celery_app_conn, async_result
try:
if celery_app_conn is not None and async_result is not None:
job = async_result(task_id, app=celery_app_conn)
if job.ready():
return {"status": "finished", "result": job.result}
else:
return {'status': 'queued'}
else:
raise Exception()
except Exception as e:
return {"status": "finished", "result": str(e)}
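Because the new `/queue/chat/completions` route above is /chat/completions-compatible, an OpenAI client pointed at the queue path should work. A hedged sketch; the base_url and api_key are assumptions, and the SDK appends `/chat/completions` to the base URL.

```python
# Minimal sketch of calling /queue/chat/completions via the OpenAI SDK.
import openai

client = openai.OpenAI(
    api_key="sk-proxy-key",                  # a key issued by the proxy (placeholder)
    base_url="http://0.0.0.0:8000/queue",    # so requests hit /queue/chat/completions
)
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "this is a queued test request"}],
)
print(response)
```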
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)]) @router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
async def retrieve_server_log(request: Request): async def retrieve_server_log(request: Request):
@ -1411,8 +1490,18 @@ async def config_yaml_endpoint(config_info: ConfigYAML):
return {"hello": "world"} return {"hello": "world"}
@router.get("/test") @router.get("/test", tags=["health"])
async def test_endpoint(request: Request): async def test_endpoint(request: Request):
"""
A test endpoint that pings the proxy server to check if it's healthy.
Parameters:
request (Request): The incoming request.
Returns:
dict: A dictionary containing the route of the request URL.
"""
# ping the proxy server to check if its healthy
return {"route": request.url.path} return {"route": request.url.path}
@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)]) @router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)])
@ -1470,4 +1559,27 @@ async def get_routes():
return {"routes": routes} return {"routes": routes}
@router.on_event("shutdown")
async def shutdown_event():
global prisma_client, master_key, user_custom_auth
if prisma_client:
print("Disconnecting from Prisma")
await prisma_client.disconnect()
## RESET CUSTOM VARIABLES ##
cleanup_router_config_variables()
def cleanup_router_config_variables():
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval
# Set all variables to None
master_key = None
user_config_file_path = None
otel_logging = None
user_custom_auth = None
user_custom_auth_path = None
use_background_health_checks = None
health_check_interval = None
app.include_router(router) app.include_router(router)

@ -1,23 +0,0 @@
import openai
client = openai.OpenAI(
api_key="anything",
# base_url="http://0.0.0.0:8000",
)
try:
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
},
])
print(response)
# except openai.APITimeoutError:
# print("Got openai Timeout Exception. Good job. The proxy mapped to OpenAI exceptions")
except Exception as e:
print("\n the proxy did not map to OpenAI exception. Instead got", e)
print(e.type) # type: ignore
print(e.message) # type: ignore
print(e.code) # type: ignore

@ -1,13 +1,13 @@
from typing import Optional, List, Any, Literal from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib, asyncio import os, subprocess, hashlib, importlib, asyncio, copy
import litellm, backoff import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.integrations.custom_logger import CustomLogger
def print_verbose(print_statement): def print_verbose(print_statement):
if litellm.set_verbose: if litellm.set_verbose:
print(print_statement) # noqa print(f"LiteLLM Proxy: {print_statement}") # noqa
### LOGGING ### ### LOGGING ###
class ProxyLogging: class ProxyLogging:
""" """
@ -26,7 +26,7 @@ class ProxyLogging:
pass pass
def _init_litellm_callbacks(self): def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_parallel_request_limiter)
for callback in litellm.callbacks: for callback in litellm.callbacks:
if callback not in litellm.input_callback: if callback not in litellm.input_callback:
@ -65,17 +65,13 @@ class ProxyLogging:
2. /embeddings 2. /embeddings
""" """
try: try:
self.call_details["data"] = data for callback in litellm.callbacks:
self.call_details["call_type"] = call_type if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__):
response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type)
## check if max parallel requests set if response is not None:
if user_api_key_dict.max_parallel_requests is not None: data = response
## if set, check if request allowed
await self.max_parallel_request_limiter.max_parallel_request_allow_request(
max_parallel_requests=user_api_key_dict.max_parallel_requests,
api_key=user_api_key_dict.api_key,
user_api_key_cache=self.call_details["user_api_key_cache"])
print_verbose(f'final data being sent to {call_type} call: {data}')
return data return data
except Exception as e: except Exception as e:
raise e raise e
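The rewritten pre_call_hook above now just iterates `litellm.callbacks` and delegates to any `CustomLogger` that defines `async_pre_call_hook`. A hedged sketch of such a handler follows; the class name and the mutation it performs are made up for illustration, while the hook signature mirrors the call above.

```python
# Hypothetical CustomLogger picked up by the pre_call_hook loop above.
from litellm.integrations.custom_logger import CustomLogger

class ExampleProxyHandler(CustomLogger):
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        # Example mutation: stamp the proxy user onto the request payload.
        if data.get("user") is None and user_api_key_dict.user_id is not None:
            data["user"] = user_api_key_dict.user_id
        return data  # returning a dict replaces `data`; returning None keeps it unchanged
```

An instance would be registered by appending it to `litellm.callbacks`, the same way `_init_litellm_callbacks` registers the parallel-request limiter.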
@ -103,17 +99,13 @@ class ProxyLogging:
1. /chat/completions 1. /chat/completions
2. /embeddings 2. /embeddings
""" """
# check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None: for callback in litellm.callbacks:
## decrement call count if call failed try:
if (hasattr(original_exception, "status_code") if isinstance(callback, CustomLogger):
and original_exception.status_code == 429 await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception)
and "Max parallel request limit reached" in str(original_exception)): except Exception as e:
pass # ignore failed calls due to max limit being reached raise e
else:
await self.max_parallel_request_limiter.async_log_failure_call(
api_key=user_api_key_dict.api_key,
user_api_key_cache=self.call_details["user_api_key_cache"])
return return
@ -165,19 +157,20 @@ class PrismaClient:
async def get_data(self, token: str, expires: Optional[Any]=None): async def get_data(self, token: str, expires: Optional[Any]=None):
try: try:
# check if plain text or hash # check if plain text or hash
hashed_token = token
if token.startswith("sk-"): if token.startswith("sk-"):
token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
if expires: if expires:
response = await self.db.litellm_verificationtoken.find_first( response = await self.db.litellm_verificationtoken.find_first(
where={ where={
"token": token, "token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired "expires": {"gte": expires} # Check if the token is not expired
} }
) )
else: else:
response = await self.db.litellm_verificationtoken.find_unique( response = await self.db.litellm_verificationtoken.find_unique(
where={ where={
"token": token "token": hashed_token
} }
) )
return response return response
@ -200,18 +193,18 @@ class PrismaClient:
try: try:
token = data["token"] token = data["token"]
hashed_token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
data["token"] = hashed_token db_data = copy.deepcopy(data)
db_data["token"] = hashed_token
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={ where={
'token': hashed_token, 'token': hashed_token,
}, },
data={ data={
"create": {**data}, #type: ignore "create": {**db_data}, #type: ignore
"update": {} # don't do anything if it already exists "update": {} # don't do anything if it already exists
} }
) )
return new_verification_token return new_verification_token
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
@ -235,15 +228,16 @@ class PrismaClient:
if token.startswith("sk-"): if token.startswith("sk-"):
token = self.hash_token(token=token) token = self.hash_token(token=token)
data["token"] = token db_data = copy.deepcopy(data)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update( response = await self.db.litellm_verificationtoken.update(
where={ where={
"token": token "token": token
}, },
data={**data} # type: ignore data={**db_data} # type: ignore
) )
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
return {"token": token, "data": data} return {"token": token, "data": db_data}
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m") print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
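The deepcopy changes above exist so the caller's dict keeps the plaintext token while only the DB payload carries the hashed value. A small sketch of the pattern; the sha256 body of `hash_token` is an assumption for illustration, not copied from the repo.

```python
# Sketch of the copy-before-hash pattern used in the upsert/update code above.
import copy, hashlib

def hash_token(token: str) -> str:
    # assumed implementation for illustration
    return hashlib.sha256(token.encode()).hexdigest()

data = {"token": "sk-plaintext", "spend": 0.0}
db_data = copy.deepcopy(data)            # mutate a copy, not the caller's dict
db_data["token"] = hash_token(db_data["token"])
assert data["token"] == "sk-plaintext"   # caller still sees the original key
```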

@ -7,6 +7,7 @@
# #
# Thank you ! We ❤️ you! - Krrish & Ishaan # Thank you ! We ❤️ you! - Krrish & Ishaan
import copy
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback, uuid import random, threading, time, traceback, uuid
@ -17,6 +18,7 @@ import inspect, concurrent
from openai import AsyncOpenAI from openai import AsyncOpenAI
from collections import defaultdict from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
import copy
class Router: class Router:
""" """
Example usage: Example usage:
@ -68,6 +70,7 @@ class Router:
redis_password: Optional[str] = None, redis_password: Optional[str] = None,
cache_responses: Optional[bool] = False, cache_responses: Optional[bool] = False,
cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py) cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
caching_groups: Optional[List[tuple]] = None, # if you want to cache across model groups
## RELIABILITY ## ## RELIABILITY ##
num_retries: int = 0, num_retries: int = 0,
timeout: Optional[float] = None, timeout: Optional[float] = None,
@ -76,11 +79,13 @@ class Router:
fallbacks: List = [], fallbacks: List = [],
allowed_fails: Optional[int] = None, allowed_fails: Optional[int] = None,
context_window_fallbacks: List = [], context_window_fallbacks: List = [],
model_group_alias: Optional[dict] = {},
routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None: routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
self.set_verbose = set_verbose self.set_verbose = set_verbose
self.deployment_names: List = [] # names of models under litellm_params. ex. azure/chatgpt-v-2 self.deployment_names: List = [] # names of models under litellm_params. ex. azure/chatgpt-v-2
if model_list: if model_list:
model_list = copy.deepcopy(model_list)
self.set_model_list(model_list) self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list self.healthy_deployments: List = self.model_list
self.deployment_latency_map = {} self.deployment_latency_map = {}
@ -99,6 +104,7 @@ class Router:
self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model
self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model
self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call) self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call)
self.model_group_alias: dict = model_group_alias or {} # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
# make Router.chat.completions.create compatible for openai.chat.completions.create # make Router.chat.completions.create compatible for openai.chat.completions.create
self.chat = litellm.Chat(params=default_litellm_params) self.chat = litellm.Chat(params=default_litellm_params)
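A hedged usage sketch for the new `model_group_alias` constructor argument above: requests for the alias name are routed to the aliased model group. The API key is a placeholder.

```python
# Sketch: route "gpt-4" requests to the "gpt-3.5-turbo" deployment group.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},  # placeholder key
        }
    ],
    model_group_alias={"gpt-4": "gpt-3.5-turbo"},
)
response = router.completion(
    model="gpt-4",  # resolved to the gpt-3.5-turbo group by get_available_deployment
    messages=[{"role": "user", "content": "hi"}],
)
```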
@ -107,9 +113,10 @@ class Router:
self.default_litellm_params = default_litellm_params self.default_litellm_params = default_litellm_params
self.default_litellm_params.setdefault("timeout", timeout) self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0) self.default_litellm_params.setdefault("max_retries", 0)
self.default_litellm_params.setdefault("metadata", {}).update({"caching_groups": caching_groups})
### CACHING ### ### CACHING ###
cache_type = "local" # default to an in-memory cache cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
redis_cache = None redis_cache = None
cache_config = {} cache_config = {}
if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None): if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None):
@ -133,7 +140,7 @@ class Router:
if cache_responses: if cache_responses:
if litellm.cache is None: if litellm.cache is None:
# the cache can be initialized on the proxy server. We should not overwrite it # the cache can be initialized on the proxy server. We should not overwrite it
litellm.cache = litellm.Cache(type=cache_type, **cache_config) litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore
self.cache_responses = cache_responses self.cache_responses = cache_responses
self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc. self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
### ROUTING SETUP ### ### ROUTING SETUP ###
@ -198,19 +205,10 @@ class Router:
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {}) kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
elif k == "metadata":
########## remove -ModelID-XXXX from model ############## kwargs[k].update(v)
original_model_string = data["model"]
# Find the index of "ModelID" in the string
self.print_verbose(f"completion model: {original_model_string}")
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = self._get_client(deployment=deployment, kwargs=kwargs) model_client = self._get_client(deployment=deployment, kwargs=kwargs)
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs}) return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
except Exception as e: except Exception as e:
@ -241,31 +239,25 @@ class Router:
**kwargs): **kwargs):
try: try:
self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}") self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
original_model_string = None # set a default for this variable
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None)) deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {}) kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
########## remove -ModelID-XXXX from model ############## elif k == "metadata":
original_model_string = data["model"] kwargs[k].update(v)
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[original_model_string] +=1 self.total_calls[model_name] +=1
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs}) response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[original_model_string] +=1 self.success_calls[model_name] +=1
return response return response
except Exception as e: except Exception as e:
if original_model_string is not None: if model_name is not None:
self.fail_calls[original_model_string] +=1 self.fail_calls[model_name] +=1
raise e raise e
def text_completion(self, def text_completion(self,
@ -283,8 +275,43 @@ class Router:
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
# call via litellm.completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
except Exception as e:
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_exception"] = e
kwargs["original_function"] = self.completion
return self.function_with_retries(**kwargs)
else:
raise e
async def atext_completion(self,
model: str,
prompt: str,
is_retry: Optional[bool] = False,
is_fallback: Optional[bool] = False,
is_async: Optional[bool] = False,
**kwargs):
try:
kwargs.setdefault("metadata", {}).update({"model_group": model})
messages=[{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
########## remove -ModelID-XXXX from model ############## ########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"] original_model_string = data["model"]
# Find the index of "ModelID" in the string # Find the index of "ModelID" in the string
@ -294,8 +321,9 @@ class Router:
data["model"] = original_model_string[:index_of_model_id] data["model"] = original_model_string[:index_of_model_id]
else: else:
data["model"] = original_model_string data["model"] = original_model_string
# call via litellm.completion() # call via litellm.atext_completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
return response
except Exception as e: except Exception as e:
if self.num_retries > 0: if self.num_retries > 0:
kwargs["model"] = model kwargs["model"] = model
@ -313,21 +341,14 @@ class Router:
**kwargs) -> Union[List[float], None]: **kwargs) -> Union[List[float], None]:
# pick the one that is available (lowest TPM/RPM) # pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None)) deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) kwargs.setdefault("model_info", {})
kwargs["model_info"] = deployment.get("model_info", {}) kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
########## remove -ModelID-XXXX from model ############## elif k == "metadata":
original_model_string = data["model"] kwargs[k].update(v)
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = self._get_client(deployment=deployment, kwargs=kwargs) model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# call via litellm.embedding() # call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs}) return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@ -339,21 +360,15 @@ class Router:
**kwargs) -> Union[List[float], None]: **kwargs) -> Union[List[float], None]:
# pick the one that is available (lowest TPM/RPM) # pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None)) deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]})
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {}) kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params if k not in kwargs: # prioritize model-specific params > default router params
data[k] = v kwargs[k] = v
########## remove -ModelID-XXXX from model ############## elif k == "metadata":
original_model_string = data["model"] kwargs[k].update(v)
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs}) return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@ -371,7 +386,7 @@ class Router:
self.print_verbose(f'Async Response: {response}') self.print_verbose(f'Async Response: {response}')
return response return response
except Exception as e: except Exception as e:
self.print_verbose(f"An exception occurs: {e}") self.print_verbose(f"An exception occurs: {e}\n\n Traceback{traceback.format_exc()}")
original_exception = e original_exception = e
try: try:
self.print_verbose(f"Trying to fallback b/w models") self.print_verbose(f"Trying to fallback b/w models")
@ -637,9 +652,10 @@ class Router:
model_name = kwargs.get('model', None) # i.e. gpt35turbo model_name = kwargs.get('model', None) # i.e. gpt35turbo
custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
metadata = kwargs.get("litellm_params", {}).get('metadata', None) metadata = kwargs.get("litellm_params", {}).get('metadata', None)
deployment_id = kwargs.get("litellm_params", {}).get("model_info").get("id")
self._set_cooldown_deployments(deployment_id) # setting deployment_id in cooldown deployments
if metadata: if metadata:
deployment = metadata.get("deployment", None) deployment = metadata.get("deployment", None)
self._set_cooldown_deployments(deployment)
deployment_exceptions = self.model_exception_map.get(deployment, []) deployment_exceptions = self.model_exception_map.get(deployment, [])
deployment_exceptions.append(exception_str) deployment_exceptions.append(exception_str)
self.model_exception_map[deployment] = deployment_exceptions self.model_exception_map[deployment] = deployment_exceptions
@ -877,7 +893,7 @@ class Router:
return chosen_item return chosen_item
def set_model_list(self, model_list: list): def set_model_list(self, model_list: list):
self.model_list = model_list self.model_list = copy.deepcopy(model_list)
# we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
import os import os
for model in self.model_list: for model in self.model_list:
@ -889,23 +905,26 @@ class Router:
model["model_info"] = model_info model["model_info"] = model_info
#### for OpenAI / Azure we need to initialize the Client for High Traffic ######## #### for OpenAI / Azure we need to initialize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider") custom_llm_provider = litellm_params.get("custom_llm_provider")
if custom_llm_provider is None: custom_llm_provider = custom_llm_provider or model_name.split("/",1)[0] or ""
custom_llm_provider = model_name.split("/",1)[0] default_api_base = None
default_api_key = None
if custom_llm_provider in litellm.openai_compatible_providers:
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(model=model_name)
default_api_base = api_base
default_api_key = api_key
if ( if (
model_name in litellm.open_ai_chat_completion_models model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider == "custom_openai" or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "openai"
or custom_llm_provider == "azure" or custom_llm_provider == "azure"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "openai"
or "ft:gpt-3.5-turbo" in model_name or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models or model_name in litellm.open_ai_embedding_models
): ):
# glorified / complicated reading of configs # glorified / complicated reading of configs
# user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env # user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key # we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") api_key = litellm_params.get("api_key") or default_api_key
if api_key and api_key.startswith("os.environ/"): if api_key and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "") api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name) api_key = litellm.get_secret(api_key_env_name)
@ -913,7 +932,7 @@ class Router:
api_base = litellm_params.get("api_base") api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url") base_url = litellm_params.get("base_url")
api_base = api_base or base_url # allow users to pass in `api_base` or `base_url` for azure api_base = api_base or base_url or default_api_base # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"): if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "") api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name) api_base = litellm.get_secret(api_base_env_name)
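For context, a hedged `model_list` snippet using the `os.environ/...` convention resolved above; the deployment and variable names are illustrative.

```python
# Example model_list entry whose secrets are resolved via litellm.get_secret
# when the router initializes Azure/OpenAI clients (names are illustrative).
model_list = [
    {
        "model_name": "azure-gpt",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": "os.environ/AZURE_API_KEY",    # read from the environment at init
            "api_base": "os.environ/AZURE_API_BASE",
            "api_version": "os.environ/AZURE_API_VERSION",
        },
    }
]
```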
@ -1049,12 +1068,6 @@ class Router:
############ End of initializing Clients for OpenAI/Azure ################### ############ End of initializing Clients for OpenAI/Azure ###################
self.deployment_names.append(model["litellm_params"]["model"]) self.deployment_names.append(model["litellm_params"]["model"])
model_id = ""
for key in model["litellm_params"]:
if key != "api_key" and key != "metadata":
model_id+= str(model["litellm_params"][key])
model["litellm_params"]["model"] += "-ModelID-" + model_id
self.print_verbose(f"\n Initialized Model List {self.model_list}") self.print_verbose(f"\n Initialized Model List {self.model_list}")
############ Users can either pass tpm/rpm as a litellm_param or a router param ########### ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
@ -1115,38 +1128,41 @@ class Router:
if specific_deployment == True: if specific_deployment == True:
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
for deployment in self.model_list: for deployment in self.model_list:
cleaned_model = litellm.utils.remove_model_id(deployment.get("litellm_params").get("model")) deployment_model = deployment.get("litellm_params").get("model")
if cleaned_model == model: if deployment_model == model:
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2 # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
# return the first deployment where the `model` matches the specified deployment name # return the first deployment where the `model` matches the specified deployment name
return deployment return deployment
raise ValueError(f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}") raise ValueError(f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}")
# check if aliases set on litellm model alias map # check if aliases set on litellm model alias map
if model in litellm.model_group_alias_map: if model in self.model_group_alias:
self.print_verbose(f"Using a model alias. Got Request for {model}, sending requests to {litellm.model_group_alias_map.get(model)}") self.print_verbose(f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}")
model = litellm.model_group_alias_map[model] model = self.model_group_alias[model]
## get healthy deployments ## get healthy deployments
### get all deployments ### get all deployments
### filter out the deployments currently cooling down
healthy_deployments = [m for m in self.model_list if m["model_name"] == model] healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
if len(healthy_deployments) == 0: if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead # check if the user sent in a deployment name instead
healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model] healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
self.print_verbose(f"initial list of deployments: {healthy_deployments}") self.print_verbose(f"initial list of deployments: {healthy_deployments}")
# filter out the deployments currently cooling down
deployments_to_remove = [] deployments_to_remove = []
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
cooldown_deployments = self._get_cooldown_deployments() cooldown_deployments = self._get_cooldown_deployments()
self.print_verbose(f"cooldown deployments: {cooldown_deployments}") self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
### FIND UNHEALTHY DEPLOYMENTS # Find deployments in model_list whose model_id is cooling down
for deployment in healthy_deployments: for deployment in healthy_deployments:
deployment_name = deployment["litellm_params"]["model"] deployment_id = deployment["model_info"]["id"]
if deployment_name in cooldown_deployments: if deployment_id in cooldown_deployments:
deployments_to_remove.append(deployment) deployments_to_remove.append(deployment)
### FILTER OUT UNHEALTHY DEPLOYMENTS # remove unhealthy deployments from healthy deployments
for deployment in deployments_to_remove: for deployment in deployments_to_remove:
healthy_deployments.remove(deployment) healthy_deployments.remove(deployment)
self.print_verbose(f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}") self.print_verbose(f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}")
if len(healthy_deployments) == 0: if len(healthy_deployments) == 0:
raise ValueError("No models available") raise ValueError("No models available")
@ -1222,11 +1238,14 @@ class Router:
raise ValueError("No models available.") raise ValueError("No models available.")
def flush_cache(self): def flush_cache(self):
litellm.cache = None
self.cache.flush_cache() self.cache.flush_cache()
def reset(self): def reset(self):
## clean up on close ## clean up on close
litellm.success_callback = [] litellm.success_callback = []
litellm.__async_success_callback = []
litellm.failure_callback = [] litellm.failure_callback = []
litellm._async_failure_callback = []
self.flush_cache() self.flush_cache()

litellm/tests/conftest.py (new file, +34 lines)
@ -0,0 +1,34 @@
# conftest.py
import pytest, sys, os
import importlib
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
"""
This fixture reloads litellm before every test function, to speed up testing by removing chained callbacks.
"""
curr_dir = os.getcwd() # Get the current working directory
sys.path.insert(0, os.path.abspath("../..")) # Adds the project directory to the system path
import litellm
importlib.reload(litellm)
print(litellm)
# from litellm import Router, completion, aembedding, acompletion, embedding
yield
def pytest_collection_modifyitems(config, items):
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
custom_logger_tests = [item for item in items if 'custom_logger' in item.parent.name]
other_tests = [item for item in items if 'custom_logger' not in item.parent.name]
# Sort tests based on their names
custom_logger_tests.sort(key=lambda x: x.name)
other_tests.sort(key=lambda x: x.name)
# Reorder the items list
items[:] = custom_logger_tests + other_tests

@ -0,0 +1,7 @@
model_list:
- model_name: "openai-model"
litellm_params:
model: "gpt-3.5-turbo"
litellm_settings:
cache: True

@ -0,0 +1,10 @@
model_list:
- model_name: "openai-model"
litellm_params:
model: "gpt-3.5-turbo"
litellm_settings:
cache: True
cache_params:
supported_call_types: ["embedding", "aembedding"]
host: "localhost"

@ -0,0 +1,4 @@
uploading batch of 2 items
successfully uploaded batch of 2 items
uploading batch of 2 items
successfully uploaded batch of 2 items

@ -118,6 +118,7 @@ def test_cooldown_same_model_name():
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": "BAD_API_BASE", "api_base": "BAD_API_BASE",
"tpm": 90
}, },
}, },
{ {
@ -126,7 +127,8 @@ def test_cooldown_same_model_name():
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE") "api_base": os.getenv("AZURE_API_BASE"),
"tpm": 0.000001
}, },
}, },
] ]
@ -151,13 +153,14 @@ def test_cooldown_same_model_name():
] ]
) )
print(router.model_list) print(router.model_list)
litellm_model_names = [] model_ids = []
for model in router.model_list: for model in router.model_list:
litellm_model_names.append(model["litellm_params"]["model"]) model_ids.append(model["model_info"]["id"])
print("\n litellm model names ", litellm_model_names) print("\n litellm model ids ", model_ids)
# example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960'] # example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960']
assert litellm_model_names[0] != litellm_model_names[1] # ensure both models have a uuid added, and they have different names assert model_ids[0] != model_ids[1] # ensure both models have a uuid added, and they have different names
print("\ngot response\n", response) print("\ngot response\n", response)
except Exception as e: except Exception as e:
pytest.fail(f"Got unexpected exception on router! - {e}") pytest.fail(f"Got unexpected exception on router! - {e}")

@ -9,9 +9,9 @@ import os, io
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest import pytest, asyncio
import litellm import litellm
from litellm import embedding, completion, completion_cost, Timeout from litellm import embedding, completion, completion_cost, Timeout, acompletion
from litellm import RateLimitError from litellm import RateLimitError
import json import json
import os import os
@ -63,6 +63,27 @@ def load_vertex_ai_credentials():
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name) os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
@pytest.mark.asyncio
async def get_response():
load_vertex_ai_credentials()
prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
try:
response = await acompletion(
model="gemini-pro",
messages=[
{
"role": "system",
"content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
},
{"role": "user", "content": prompt},
],
)
return response
except litellm.UnprocessableEntityError as e:
pass
except Exception as e:
pytest.fail(f"An error occurred - {str(e)}")
def test_vertex_ai(): def test_vertex_ai():
import random import random
@ -72,14 +93,15 @@ def test_vertex_ai():
litellm.set_verbose=False litellm.set_verbose=False
litellm.vertex_project = "hardy-device-386718" litellm.vertex_project = "hardy-device-386718"
test_models = random.sample(test_models, 4) test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models: for model in test_models:
try: try:
if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]: if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
# our account does not have access to this model # our account does not have access to this model
continue continue
print("making request", model) print("making request", model)
response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}]) response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}], temperature=0.7)
print("\nModel Response", response) print("\nModel Response", response)
print(response) print(response)
assert type(response.choices[0].message.content) == str assert type(response.choices[0].message.content) == str
@ -95,10 +117,11 @@ def test_vertex_ai_stream():
import random import random
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
test_models = random.sample(test_models, 4) test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models: for model in test_models:
try: try:
if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]: if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
# our account does not have access to this model # our account does not have access to this model
continue continue
print("making request", model) print("making request", model)
@ -115,3 +138,199 @@ def test_vertex_ai_stream():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_vertex_ai_stream() # test_vertex_ai_stream()
@pytest.mark.asyncio
async def test_async_vertexai_response():
import random
load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models:
print(f'model being tested in async call: {model}')
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
# our account does not have access to this model
continue
try:
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
response = await acompletion(model=model, messages=messages, temperature=0.7, timeout=5)
print(f"response: {response}")
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
# asyncio.run(test_async_vertexai_response())
@pytest.mark.asyncio
async def test_async_vertexai_streaming_response():
import random
load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models:
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
# our account does not have access to this model
continue
try:
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5, stream=True)
print(f"response: {response}")
complete_response = ""
async for chunk in response:
print(f"chunk: {chunk}")
complete_response += chunk.choices[0].delta.content
print(f"complete_response: {complete_response}")
assert len(complete_response) > 0
except litellm.Timeout as e:
pass
except Exception as e:
print(e)
pytest.fail(f"An exception occurred: {e}")
# asyncio.run(test_async_vertexai_streaming_response())
def test_gemini_pro_vision():
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
litellm.num_retries=0
resp = litellm.completion(
model = "vertex_ai/gemini-pro-vision",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
}
}
]
}
],
)
print(resp)
except Exception as e:
import traceback
traceback.print_exc()
raise e
# test_gemini_pro_vision()
# Extra gemini Vision tests for completion + stream, async, async + stream
# if we run into issues with gemini, we will also add these to our ci/cd pipeline
# def test_gemini_pro_vision_stream():
# try:
# litellm.set_verbose = False
# litellm.num_retries=0
# print("streaming response from gemini-pro-vision")
# resp = litellm.completion(
# model = "vertex_ai/gemini-pro-vision",
# messages=[
# {
# "role": "user",
# "content": [
# {
# "type": "text",
# "text": "Whats in this image?"
# },
# {
# "type": "image_url",
# "image_url": {
# "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
# }
# }
# ]
# }
# ],
# stream=True
# )
# print(resp)
# for chunk in resp:
# print(chunk)
# except Exception as e:
# import traceback
# traceback.print_exc()
# raise e
# test_gemini_pro_vision_stream()
# def test_gemini_pro_vision_async():
# try:
# litellm.set_verbose = True
# litellm.num_retries=0
# async def test():
# resp = await litellm.acompletion(
# model = "vertex_ai/gemini-pro-vision",
# messages=[
# {
# "role": "user",
# "content": [
# {
# "type": "text",
# "text": "Whats in this image?"
# },
# {
# "type": "image_url",
# "image_url": {
# "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
# }
# }
# ]
# }
# ],
# )
# print("async response gemini pro vision")
# print(resp)
# asyncio.run(test())
# except Exception as e:
# import traceback
# traceback.print_exc()
# raise e
# test_gemini_pro_vision_async()
# def test_gemini_pro_vision_async_stream():
# try:
# litellm.set_verbose = True
# litellm.num_retries=0
# async def test():
# resp = await litellm.acompletion(
# model = "vertex_ai/gemini-pro-vision",
# messages=[
# {
# "role": "user",
# "content": [
# {
# "type": "text",
# "text": "Whats in this image?"
# },
# {
# "type": "image_url",
# "image_url": {
# "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
# }
# }
# ]
# }
# ],
# stream=True
# )
# print("async response gemini pro vision")
# print(resp)
# for chunk in resp:
# print(chunk)
# asyncio.run(test())
# except Exception as e:
# import traceback
# traceback.print_exc()
# raise e
# test_gemini_pro_vision_async()


@@ -29,16 +29,19 @@ def generate_random_word(length=4):
messages = [{"role": "user", "content": "who is ishaan 5222"}]
def test_caching_v2(): # test in memory cache
try:
litellm.set_verbose=True
litellm.cache = Cache()
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print(f"response1: {response1}")
print(f"response2: {response2}")
litellm.cache = None # disable cache
litellm.success_callback = []
litellm._async_success_callback = []
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
print(f"response1: {response1}")
print(f"response2: {response2}")
- pytest.fail(f"Error occurred: {e}")
+ pytest.fail(f"Error occurred:")
except Exception as e:
print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}")
@@ -58,6 +61,8 @@ def test_caching_with_models_v2():
print(f"response2: {response2}")
print(f"response3: {response3}")
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
# if models are different, it should not return cached response
print(f"response2: {response2}")
@@ -91,6 +96,8 @@ def test_embedding_caching():
print(f"Embedding 2 response time: {end_time - start_time} seconds")
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
print(f"embedding1: {embedding1}")
@@ -145,6 +152,8 @@ def test_embedding_caching_azure():
print(f"Embedding 2 response time: {end_time - start_time} seconds")
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
print(f"embedding1: {embedding1}")
@@ -175,6 +184,8 @@ def test_redis_cache_completion():
print("\nresponse 3", response3)
print("\nresponse 4", response4)
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
"""
1 & 2 should be exactly the same
@@ -226,6 +237,8 @@ def test_redis_cache_completion_stream():
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.success_callback = []
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
except Exception as e:
print(e)
litellm.success_callback = []
@@ -271,11 +284,53 @@ def test_redis_cache_acompletion_stream():
print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
except Exception as e:
print(e)
raise e
# test_redis_cache_acompletion_stream()
def test_redis_cache_acompletion_stream_bedrock():
import asyncio
try:
litellm.set_verbose = True
random_word = generate_random_word()
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
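# both streamed calls below use the same prompt, so the second one should be served from the redis cache and the assembled responses must match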
print("test for caching, streaming + completion")
response_1_content = ""
response_2_content = ""
async def call1():
nonlocal response_1_content
response1 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
async for chunk in response1:
print(chunk)
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
asyncio.run(call1())
time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2():
nonlocal response_2_content
response2 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
async for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content)
asyncio.run(call2())
print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
except Exception as e:
print(e)
raise e
# test_redis_cache_acompletion_stream_bedrock()
# redis cache with custom keys
def custom_get_cache_key(*args, **kwargs):
# return key to use for your cache:
@@ -312,9 +367,44 @@ def test_custom_redis_cache_with_key():
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
pytest.fail(f"Error occurred:")
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
# test_custom_redis_cache_with_key()
def test_cache_override():
# test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set
# in this case it should not return cached responses
litellm.cache = Cache()
print("Testing cache override")
litellm.set_verbose=True
# test embedding
response1 = embedding(
model = "text-embedding-ada-002",
input=[
"hello who are you"
],
caching = False
)
start_time = time.time()
response2 = embedding(
model = "text-embedding-ada-002",
input=[
"hello who are you"
],
caching = False
)
end_time = time.time()
print(f"Embedding 2 response time: {end_time - start_time} seconds")
assert end_time - start_time > 0.1 # ensure 2nd response comes in over 0.1s. This should not be cached.
# test_cache_override()
def test_custom_redis_cache_params():
# test if we can init redis with **kwargs
@@ -333,6 +423,8 @@ def test_custom_redis_cache_params():
print(litellm.cache.cache.redis_client)
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
except Exception as e:
pytest.fail(f"Error occurred:", e)
@@ -340,15 +432,58 @@ def test_custom_redis_cache_params():
def test_get_cache_key():
from litellm.caching import Cache
try:
print("Testing get_cache_key")
cache_instance = Cache()
cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
)
cache_key_2 = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
)
assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
assert cache_key == cache_key_2, f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
embedding_cache_key = cache_instance.get_cache_key(
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
'api_key': '', 'api_version': '2023-07-01-preview',
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
'caching': True,
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>"
}
)
print(embedding_cache_key)
assert embedding_cache_key == "model: azure/azure-embedding-modelinput: ['hi who is ishaan']", f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
# Proxy - embedding cache, test if embedding key, gets model_group and not model
embedding_cache_key_2 = cache_instance.get_cache_key(
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
'api_key': '', 'api_version': '2023-07-01-preview',
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
'caching': True,
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
'proxy_server_request': {'url': 'http://0.0.0.0:8000/embeddings',
'method': 'POST',
'headers':
{'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json',
'content-length': '80'},
'body': {'model': 'azure-embedding-model', 'input': ['hi who is ishaan']}},
'user': None,
'metadata': {'user_api_key': None,
'headers': {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', 'content-length': '80'},
'model_group': 'EMBEDDING_MODEL_GROUP',
'deployment': 'azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview'},
'model_info': {'mode': 'embedding', 'base_model': 'text-embedding-ada-002', 'id': '20b2b515-f151-4dd5-a74f-2231e2f54e29'},
'litellm_call_id': '2642e009-b3cd-443d-b5dd-bb7d56123b0e', 'litellm_logging_obj': '<litellm.utils.Logging object at 0x12f1bddb0>'}
)
print(embedding_cache_key_2)
assert embedding_cache_key_2 == "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
print("passed!")
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred:", e)
- # test_get_cache_key()
+ test_get_cache_key()
# test_custom_redis_cache_params()


@@ -21,6 +21,13 @@ messages = [{"content": user_message, "role": "user"}]
def logger_fn(user_model_dict):
print(f"user_model_dict: {user_model_dict}")
@pytest.fixture(autouse=True)
def reset_callbacks():
print("\npytest fixture - resetting callbacks")
litellm.success_callback = []
litellm._async_success_callback = []
litellm.failure_callback = []
litellm.callbacks = []
def test_completion_custom_provider_model_name():
try:
@@ -61,6 +68,25 @@ def test_completion_claude():
# test_completion_claude()
def test_completion_mistral_api():
try:
litellm.set_verbose=True
response = completion(
model="mistral/mistral-tiny",
messages=[
{
"role": "user",
"content": "Hey, how's it going?",
}
],
safe_mode = True
)
# Add any assertions here to check the response
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_mistral_api()
def test_completion_claude2_1():
try:
print("claude2.1 test request")
@@ -287,7 +313,7 @@ def hf_test_completion_tgi():
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
- hf_test_completion_tgi()
+ # hf_test_completion_tgi()
# ################### Hugging Face Conversational models ########################
# def hf_test_completion_conv():
@@ -611,7 +637,7 @@ def test_completion_azure_key_completion_arg():
os.environ.pop("AZURE_API_KEY", None)
try:
print("azure gpt-3.5 test\n\n")
- litellm.set_verbose=False
+ litellm.set_verbose=True
## Test azure call
response = completion(
model="azure/chatgpt-v-2",
@@ -696,6 +722,7 @@ def test_completion_azure():
print(response)
cost = completion_cost(completion_response=response)
assert cost > 0.0
print("Cost for azure completion request", cost)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -1013,15 +1040,56 @@ def test_completion_together_ai():
# Add any assertions here to check the response
print(response)
cost = completion_cost(completion_response=response)
assert cost > 0.0
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_together_ai_mixtral():
model_name = "together_ai/DiscoResearch/DiscoLM-mixtral-8x7b-v2"
try:
messages =[
{"role": "user", "content": "Who are you"},
{"role": "assistant", "content": "I am your helpful assistant."},
{"role": "user", "content": "Tell me a joke"},
]
response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn)
# Add any assertions here to check the response
print(response)
cost = completion_cost(completion_response=response)
assert cost > 0.0
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_completion_together_ai_mixtral()
def test_completion_together_ai_yi_chat():
model_name = "together_ai/zero-one-ai/Yi-34B-Chat"
try:
messages =[
{"role": "user", "content": "What llm are you?"},
]
response = completion(model=model_name, messages=messages)
# Add any assertions here to check the response
print(response)
cost = completion_cost(completion_response=response)
assert cost > 0.0
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_together_ai_yi_chat()
# test_completion_together_ai()
def test_customprompt_together_ai():
try:
litellm.set_verbose = False
litellm.num_retries = 0
print("in test_customprompt_together_ai")
print(litellm.success_callback)
print(litellm._async_success_callback)
response = completion(
model="together_ai/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
@@ -1030,7 +1098,6 @@ def test_customprompt_together_ai():
print(response)
except litellm.exceptions.Timeout as e:
print(f"Timeout Error")
- litellm.num_retries = 3 # reset retries
pass
except Exception as e:
print(f"ERROR TYPE {type(e)}")


@@ -2,7 +2,7 @@ from litellm.integrations.custom_logger import CustomLogger
import inspect
import litellm
- class MyCustomHandler(CustomLogger):
+ class testCustomCallbackProxy(CustomLogger):
def __init__(self):
self.success: bool = False # type: ignore
self.failure: bool = False # type: ignore
@@ -55,8 +55,11 @@ class MyCustomHandler(CustomLogger):
self.async_success = True
print("Value of async success: ", self.async_success)
print("\n kwargs: ", kwargs)
- if kwargs.get("model") == "azure-embedding-model":
+ if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada":
print("Got an embedding model", kwargs.get("model"))
print("Setting embedding success to True")
self.async_success_embedding = True
print("Value of async success embedding: ", self.async_success_embedding)
self.async_embedding_kwargs = kwargs
self.async_embedding_response = response_obj
if kwargs.get("stream") == True:
@@ -79,6 +82,9 @@ class MyCustomHandler(CustomLogger):
# tokens used in response
usage = response_obj["usage"]
print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
print(
f"""
Model: {model},
@@ -104,4 +110,4 @@ class MyCustomHandler(CustomLogger):
self.async_completion_kwargs_fail = kwargs
- my_custom_logger = MyCustomHandler()
+ my_custom_logger = testCustomCallbackProxy()


@@ -0,0 +1,16 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
api_key: bad-key
model: gpt-3.5-turbo
- model_name: azure-gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: bad-key
- model_name: azure-embedding
litellm_params:
model: azure/azure-embedding-model
api_base: os.environ/AZURE_API_BASE
api_key: bad-key
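# note: every api_key above is intentionally "bad-key" - this config appears to be meant for tests that exercise auth-failure / fallback handling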


@@ -19,3 +19,63 @@ model_list:
model_info:
description: this is a test openai model
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 56f1bd94-3b54-4b67-9ea2-7c70e9a3a709
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 4d1ee26c-abca-450c-8744-8e87fd6755e9
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 00e19c0f-b63d-42bb-88e9-016fb0c60764
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 79fc75bf-8e1b-47d5-8d24-9365a854af03
model_name: test_openai_models
- litellm_params:
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: 2023-07-01-preview
model: azure/azure-embedding-model
model_name: azure-embedding-model
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 55848c55-4162-40f9-a6e2-9a722b9ef404
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 34339b1e-e030-4bcc-a531-c48559f10ce4
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: f6f74e14-ac64-4403-9365-319e584dcdc5
model_name: test_openai_models
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 9b1ef341-322c-410a-8992-903987fef439
model_name: test_openai_models
- model_name: amazon-embeddings
litellm_params:
model: "bedrock/amazon.titan-embed-text-v1"
- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)"
litellm_params:
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"


@@ -0,0 +1,631 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional, Literal, List, Union
from litellm import completion, embedding, Cache
import litellm
from litellm.integrations.custom_logger import CustomLogger
# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## 5. Caching
# Test models
## 1. OpenAI
## 2. Azure OpenAI
## 3. Non-OpenAI/Azure - e.g. Bedrock
# Test interfaces
## 1. litellm.completion() + litellm.embeddings()
## refer to test_custom_callback_input_router.py for the router + proxy tests
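## note: each test below follows the same pattern - instantiate the handler, register it via
## `litellm.callbacks = [handler]`, run the call, then assert that no hook recorded an error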
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
"""
The set of expected inputs to a custom handler for a litellm completion / embedding call
"""
# Class variables or attributes
def __init__(self):
self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
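# every hook appends its name to `states` and any failed assertion lands in `errors`,
# so tests can check both the callback order and that no input had an unexpected type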
def log_pre_api_call(self, model, messages, kwargs):
try:
self.states.append("sync_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("post_api_call")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert end_time == None
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_stream")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_pre_api_call(self, model, messages, kwargs):
try:
self.states.append("async_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list) and isinstance(messages[0], dict)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, str, dict))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
# COMPLETION
## Test OpenAI + sync
def test_chat_openai_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = litellm.completion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync openai"
}])
## test streaming
response = litellm.completion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
stream=True)
for chunk in response:
continue
## test failure callback
try:
response = litellm.completion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
api_key="my-bad-key",
stream=True)
for chunk in response:
continue
except:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_openai_stream()
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_chat_openai_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
## test streaming
response = await litellm.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
stream=True)
async for chunk in response:
continue
## test failure callback
try:
response = await litellm.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
api_key="my-bad-key",
stream=True)
async for chunk in response:
continue
except:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_openai_stream())
## Test Azure + sync
def test_chat_azure_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = litellm.completion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync azure"
}])
# test streaming
response = litellm.completion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync azure"
}],
stream=True)
for chunk in response:
continue
# test failure callback
try:
response = litellm.completion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync azure"
}],
api_key="my-bad-key",
stream=True)
for chunk in response:
continue
except:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_azure_stream()
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_chat_azure_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async azure"
}])
## test streaming
response = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async azure"
}],
stream=True)
async for chunk in response:
continue
## test failure callback
try:
response = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async azure"
}],
api_key="my-bad-key",
stream=True)
async for chunk in response:
continue
except:
pass
await asyncio.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_azure_stream())
## Test Bedrock + sync
def test_chat_bedrock_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = litellm.completion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync bedrock"
}])
# test streaming
response = litellm.completion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync bedrock"
}],
stream=True)
for chunk in response:
continue
# test failure callback
try:
response = litellm.completion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync bedrock"
}],
aws_region_name="my-bad-region",
stream=True)
for chunk in response:
continue
except:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_bedrock_stream()
## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_chat_bedrock_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async bedrock"
}])
# test streaming
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async bedrock"
}],
stream=True)
print(f"response: {response}")
async for chunk in response:
print(f"chunk: {chunk}")
continue
## test failure callback
try:
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async bedrock"
}],
aws_region_name="my-bad-key",
stream=True)
async for chunk in response:
continue
except:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_bedrock_stream())
# EMBEDDING
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_embedding_openai():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = await litellm.aembedding(model="azure/azure-embedding-model",
input=["good morning from litellm"])
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(model="text-embedding-ada-002",
input=["good morning from litellm"],
api_key="my-bad-key")
except:
pass
await asyncio.sleep(1)
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_openai())
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_embedding_azure():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = await litellm.aembedding(model="azure/azure-embedding-model",
input=["good morning from litellm"])
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(model="azure/azure-embedding-model",
input=["good morning from litellm"],
api_key="my-bad-key")
except:
pass
await asyncio.sleep(1)
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, success
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_azure())
## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_embedding_bedrock():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
litellm.set_verbose = True
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2")
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"],
aws_region_name="my-bad-region")
except:
pass
await asyncio.sleep(1)
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, success
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_bedrock())
# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
async def test_async_completion_azure_caching():
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
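# the timestamp makes the prompt unique per run, so call 1 is a guaranteed cache miss and call 2 a cache hit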
response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
await asyncio.sleep(1) # success callbacks are done in parallel
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success
@pytest.mark.asyncio
async def test_async_embedding_azure_caching():
print("Testing custom callback input - Azure Caching")
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
response1 = await litellm.aembedding(model="azure/azure-embedding-model",
input=[f"good morning from litellm1 {unique_time}"],
caching=True)
await asyncio.sleep(1) # set cache is async for aembedding()
response2 = await litellm.aembedding(model="azure/azure-embedding-model",
input=[f"good morning from litellm1 {unique_time}"],
caching=True)
await asyncio.sleep(1) # success callbacks are done in parallel
print(customHandler_caching.states)
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success
# asyncio.run(
# test_async_embedding_azure_caching()
# )


@@ -0,0 +1,488 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional, Literal, List
from litellm import Router, Cache
import litellm
from litellm.integrations.custom_logger import CustomLogger
# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## fallbacks
## retries
# Test cases
## 1. Simple Azure OpenAI acompletion + streaming call
## 2. Simple Azure OpenAI aembedding call
## 3. Azure OpenAI acompletion + streaming call with retries
## 4. Azure OpenAI aembedding call with retries
## 5. Azure OpenAI acompletion + streaming call with fallbacks
## 6. Azure OpenAI aembedding call with fallbacks
# Test interfaces
## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings
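## note: the router reuses the same CustomLogger hooks; the extra assertions below just check the
## router-specific fields it injects into kwargs["litellm_params"] (metadata model_group / deployment, model_info id, etc.)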
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
"""
The set of expected inputs to a custom handler for a router-based completion / embedding call
"""
# Class variables or attributes
def __init__(self):
self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
def log_pre_api_call(self, model, messages, kwargs):
try:
print(f'received kwargs in pre-input: {kwargs}')
self.states.append("sync_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("post_api_call")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert end_time == None
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_stream")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_pre_api_call(self, model, messages, kwargs):
try:
"""
No-op.
Not implemented yet.
"""
pass
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
print(f"received original response: {kwargs['original_response']}")
self.states.append("async_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, str, dict))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
# Simple Azure OpenAI call
## COMPLETION
@pytest.mark.asyncio
async def test_async_chat_azure():
try:
customHandler_completion_azure_router = CompletionCustomHandler()
customHandler_streaming_azure_router = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_completion_azure_router]
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
router = Router(model_list=model_list) # type: ignore
response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
await asyncio.sleep(2)
assert len(customHandler_completion_azure_router.errors) == 0
assert len(customHandler_completion_azure_router.states) == 3 # pre, post, success
# streaming
litellm.callbacks = [customHandler_streaming_azure_router]
router2 = Router(model_list=model_list) # type: ignore
response = await router2.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
stream=True)
async for chunk in response:
print(f"async azure router chunk: {chunk}")
continue
await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
assert len(customHandler_streaming_azure_router.errors) == 0
assert len(customHandler_streaming_azure_router.states) >= 4 # pre, post, stream (multiple times), success
# failure
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore
try:
response = await router3.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
print(f"response in router3 acompletion: {response}")
except:
pass
await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure
assert "async_failure" in customHandler_failure.states
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure())
## EMBEDDING
@pytest.mark.asyncio
async def test_async_embedding_azure():
try:
customHandler = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler]
model_list = [
{
"model_name": "azure-embedding-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/azure-embedding-model",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
router = Router(model_list=model_list) # type: ignore
response = await router.aembedding(model="azure-embedding-model",
input=["hello from litellm!"])
await asyncio.sleep(2)
assert len(customHandler.errors) == 0
assert len(customHandler.states) == 3 # pre, post, success
# failure
model_list = [
{
"model_name": "azure-embedding-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/azure-embedding-model",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore
try:
response = await router3.aembedding(model="azure-embedding-model",
input=["hello from litellm!"])
print(f"response in router3 aembedding: {response}")
except:
pass
await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure
assert "async_failure" in customHandler_failure.states
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_embedding_azure())
# Azure OpenAI call w/ Fallbacks
## COMPLETION
@pytest.mark.asyncio
async def test_async_chat_azure_with_fallbacks():
try:
customHandler_fallbacks = CompletionCustomHandler()
litellm.callbacks = [customHandler_fallbacks]
# with fallbacks
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo-16k",
"litellm_params": {
"model": "gpt-3.5-turbo-16k",
},
"tpm": 240000,
"rpm": 1800
}
]
router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore
response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
await asyncio.sleep(2)
print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}")
assert len(customHandler_fallbacks.errors) == 0
assert len(customHandler_fallbacks.states) == 6 # pre, post, failure, pre, post, success
litellm.callbacks = []
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure_with_fallbacks())
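# The 6 expected states above are the failed attempt on the bad-key deployment
# (pre, post, failure) followed by the successful fallback call on gpt-3.5-turbo-16k
# (pre, post, success).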
# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
async def test_async_completion_azure_caching():
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo-16k",
"litellm_params": {
"model": "gpt-3.5-turbo-16k",
},
"tpm": 240000,
"rpm": 1800
}
]
router = Router(model_list=model_list) # type: ignore
response1 = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
await asyncio.sleep(1) # success callbacks are done in parallel
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success
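# Cache-hit expectation: the first acompletion() populates the Redis cache and fires
# pre/post/success; the second identical request is served from cache and only appends
# a success callback, hence exactly 4 states and no errors.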

View file

@ -1,5 +1,5 @@
### What this tests #### ### What this tests ####
import sys, os, time, inspect, asyncio import sys, os, time, inspect, asyncio, traceback
import pytest import pytest
sys.path.insert(0, os.path.abspath('../..')) sys.path.insert(0, os.path.abspath('../..'))
@ -7,9 +7,8 @@ from litellm import completion, embedding
import litellm import litellm
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
async_success = False
complete_streaming_response_in_callback = ""
class MyCustomHandler(CustomLogger): class MyCustomHandler(CustomLogger):
complete_streaming_response_in_callback = ""
def __init__(self): def __init__(self):
self.success: bool = False # type: ignore self.success: bool = False # type: ignore
self.failure: bool = False # type: ignore self.failure: bool = False # type: ignore
@ -27,9 +26,12 @@ class MyCustomHandler(CustomLogger):
self.stream_collected_response = None # type: ignore self.stream_collected_response = None # type: ignore
self.sync_stream_collected_response = None # type: ignore self.sync_stream_collected_response = None # type: ignore
self.user = None # type: ignore
self.data_sent_to_api: dict = {}
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):
print(f"Pre-API Call") print(f"Pre-API Call")
self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})
def log_post_api_call(self, kwargs, response_obj, start_time, end_time): def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print(f"Post-API Call") print(f"Post-API Call")
@ -50,9 +52,8 @@ class MyCustomHandler(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async success") print(f"On Async success")
print(f"received kwargs user: {kwargs['user']}")
self.async_success = True self.async_success = True
print("Value of async success: ", self.async_success)
print("\n kwargs: ", kwargs)
if kwargs.get("model") == "text-embedding-ada-002": if kwargs.get("model") == "text-embedding-ada-002":
self.async_success_embedding = True self.async_success_embedding = True
self.async_embedding_kwargs = kwargs self.async_embedding_kwargs = kwargs
@ -60,31 +61,32 @@ class MyCustomHandler(CustomLogger):
if kwargs.get("stream") == True: if kwargs.get("stream") == True:
self.stream_collected_response = response_obj self.stream_collected_response = response_obj
self.async_completion_kwargs = kwargs self.async_completion_kwargs = kwargs
self.user = kwargs.get("user", None)
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Failure") print(f"On Async Failure")
self.async_failure = True self.async_failure = True
print("Value of async failure: ", self.async_failure)
print("\n kwargs: ", kwargs)
if kwargs.get("model") == "text-embedding-ada-002": if kwargs.get("model") == "text-embedding-ada-002":
self.async_failure_embedding = True self.async_failure_embedding = True
self.async_embedding_kwargs_fail = kwargs self.async_embedding_kwargs_fail = kwargs
self.async_completion_kwargs_fail = kwargs self.async_completion_kwargs_fail = kwargs
async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time): class TmpFunction:
global async_success, complete_streaming_response_in_callback complete_streaming_response_in_callback = ""
async_success: bool = False
async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_time):
print(f"ON ASYNC LOGGING") print(f"ON ASYNC LOGGING")
async_success = True self.async_success = True
print("\nKWARGS", kwargs) print(f'kwargs.get("complete_streaming_response"): {kwargs.get("complete_streaming_response")}')
complete_streaming_response_in_callback = kwargs.get("complete_streaming_response") self.complete_streaming_response_in_callback = kwargs.get("complete_streaming_response")
def test_async_chat_openai_stream(): def test_async_chat_openai_stream():
try: try:
global complete_streaming_response_in_callback tmp_function = TmpFunction()
# litellm.set_verbose = True # litellm.set_verbose = True
litellm.success_callback = [async_test_logging_fn] litellm.success_callback = [tmp_function.async_test_logging_fn]
complete_streaming_response = "" complete_streaming_response = ""
async def call_gpt(): async def call_gpt():
nonlocal complete_streaming_response nonlocal complete_streaming_response
@ -98,12 +100,16 @@ def test_async_chat_openai_stream():
complete_streaming_response += chunk["choices"][0]["delta"]["content"] or "" complete_streaming_response += chunk["choices"][0]["delta"]["content"] or ""
print(complete_streaming_response) print(complete_streaming_response)
asyncio.run(call_gpt()) asyncio.run(call_gpt())
assert complete_streaming_response_in_callback["choices"][0]["message"]["content"] == complete_streaming_response complete_streaming_response = complete_streaming_response.strip("'")
assert async_success == True response1 = tmp_function.complete_streaming_response_in_callback["choices"][0]["message"]["content"]
response2 = complete_streaming_response
# assert [ord(c) for c in response1] == [ord(c) for c in response2]
assert response1 == response2
assert tmp_function.async_success == True
except Exception as e: except Exception as e:
print(e) print(e)
pytest.fail(f"An error occurred - {str(e)}") pytest.fail(f"An error occurred - {str(e)}")
test_async_chat_openai_stream() # test_async_chat_openai_stream()
def test_completion_azure_stream_moderation_failure(): def test_completion_azure_stream_moderation_failure():
try: try:
@ -205,38 +211,15 @@ def test_azure_completion_stream():
assert response_in_success_handler == complete_streaming_response assert response_in_success_handler == complete_streaming_response
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
test_azure_completion_stream()
def test_async_custom_handler(): @pytest.mark.asyncio
async def test_async_custom_handler_completion():
try: try:
customHandler2 = MyCustomHandler() customHandler_success = MyCustomHandler()
litellm.callbacks = [customHandler2] customHandler_failure = MyCustomHandler()
litellm.set_verbose = True # success
messages = [ assert customHandler_success.async_success == False
{"role": "system", "content": "You are a helpful assistant."}, litellm.callbacks = [customHandler_success]
{
"role": "user",
"content": "how do i kill someone",
},
]
async def test_1():
try:
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
api_key="test",
)
except:
pass
assert customHandler2.async_failure == False
asyncio.run(test_1())
assert customHandler2.async_failure == True, "async failure is not set to True even after failure"
assert customHandler2.async_completion_kwargs_fail.get("model") == "gpt-3.5-turbo"
assert len(str(customHandler2.async_completion_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
print("Passed setting async failure")
async def test_2():
response = await litellm.acompletion( response = await litellm.acompletion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{ messages=[{
@ -244,52 +227,113 @@ def test_async_custom_handler():
"content": "hello from litellm test", "content": "hello from litellm test",
}] }]
) )
print("\n response", response) await asyncio.sleep(1)
assert customHandler2.async_success == False assert customHandler_success.async_success == True, "async success is not set to True even after success"
asyncio.run(test_2()) assert customHandler_success.async_completion_kwargs.get("model") == "gpt-3.5-turbo"
assert customHandler2.async_success == True, "async success is not set to True even after success" # failure
assert customHandler2.async_completion_kwargs.get("model") == "gpt-3.5-turbo" litellm.callbacks = [customHandler_failure]
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "how do i kill someone",
},
]
assert customHandler_failure.async_failure == False
try:
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
api_key="my-bad-key",
)
except:
pass
assert customHandler_failure.async_failure == True, "async failure is not set to True even after failure"
assert customHandler_failure.async_completion_kwargs_fail.get("model") == "gpt-3.5-turbo"
assert len(str(customHandler_failure.async_completion_kwargs_fail.get("exception"))) > 10 # expect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
litellm.callbacks = []
print("Passed setting async failure")
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_custom_handler_completion())
async def test_3(): @pytest.mark.asyncio
async def test_async_custom_handler_embedding():
try:
customHandler_embedding = MyCustomHandler()
litellm.callbacks = [customHandler_embedding]
# success
assert customHandler_embedding.async_success_embedding == False
response = await litellm.aembedding( response = await litellm.aembedding(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input = ["hello world"], input = ["hello world"],
) )
print("\n response", response) await asyncio.sleep(1)
assert customHandler2.async_success_embedding == False assert customHandler_embedding.async_success_embedding == True, "async_success_embedding is not set to True even after success"
asyncio.run(test_3()) assert customHandler_embedding.async_embedding_kwargs.get("model") == "text-embedding-ada-002"
assert customHandler2.async_success_embedding == True, "async_success_embedding is not set to True even after success" assert customHandler_embedding.async_embedding_response["usage"]["prompt_tokens"] ==2
assert customHandler2.async_embedding_kwargs.get("model") == "text-embedding-ada-002"
assert customHandler2.async_embedding_response["usage"]["prompt_tokens"] ==2
print("Passed setting async success: Embedding") print("Passed setting async success: Embedding")
# failure
assert customHandler_embedding.async_failure_embedding == False
print("Testing custom failure callback for embedding")
async def test_4():
try: try:
response = await litellm.aembedding( response = await litellm.aembedding(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input = ["hello world"], input = ["hello world"],
api_key="test", api_key="my-bad-key",
) )
except: except:
pass pass
assert customHandler_embedding.async_failure_embedding == True, "async failure embedding is not set to True even after failure"
assert customHandler2.async_failure_embedding == False assert customHandler_embedding.async_embedding_kwargs_fail.get("model") == "text-embedding-ada-002"
asyncio.run(test_4()) assert len(str(customHandler_embedding.async_embedding_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
assert customHandler2.async_failure_embedding == True, "async failure embedding is not set to True even after failure"
assert customHandler2.async_embedding_kwargs_fail.get("model") == "text-embedding-ada-002"
assert len(str(customHandler2.async_embedding_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
print("Passed setting async failure")
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"An exception occurred - {str(e)}")
# test_async_custom_handler() # asyncio.run(test_async_custom_handler_embedding())
@pytest.mark.asyncio
async def test_async_custom_handler_embedding_optional_param():
"""
Tests that the OpenAI optional params for embedding (user, encoding_format)
are logged.
"""
customHandler_optional_params = MyCustomHandler()
litellm.callbacks = [customHandler_optional_params]
response = await litellm.aembedding(
model="azure/azure-embedding-model",
input = ["hello world"],
user = "John"
)
await asyncio.sleep(1) # success callback is async
assert customHandler_optional_params.user == "John"
assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"]
# asyncio.run(test_async_custom_handler_embedding_optional_param())
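# For OpenAI/Azure embeddings the optional `user` param is both logged by the callback
# and forwarded to the API (it appears in data_sent_to_api above); the bedrock test
# below checks the complementary case where it must be dropped.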
@pytest.mark.asyncio
async def test_async_custom_handler_embedding_optional_param_bedrock():
"""
Tests that the OpenAI optional params for embedding (user, encoding_format) are logged,
but are not sent to non-OpenAI/Azure endpoints, where they would raise errors.
"""
litellm.drop_params = True
litellm.set_verbose = True
customHandler_optional_params = MyCustomHandler()
litellm.callbacks = [customHandler_optional_params]
response = await litellm.aembedding(
model="bedrock/amazon.titan-embed-text-v1",
input = ["hello world"],
user = "John"
)
await asyncio.sleep(1) # success callback is async
assert customHandler_optional_params.user == "John"
assert "user" not in customHandler_optional_params.data_sent_to_api
from litellm import Cache
def test_redis_cache_completion_stream(): def test_redis_cache_completion_stream():
from litellm import Cache
# Important Test - This tests if we can add to streaming cache, when custom callbacks are set # Important Test - This tests if we can add to streaming cache, when custom callbacks are set
import random import random
try: try:
@ -316,13 +360,10 @@ def test_redis_cache_completion_stream():
print("\nresponse 2", response_2_content) print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = []
litellm.cache = None litellm.cache = None
except Exception as e: except Exception as e:
print(e) print(e)
litellm.success_callback = [] litellm.success_callback = []
raise e raise e
"""
1 & 2 should be exactly the same
"""
# test_redis_cache_completion_stream() # test_redis_cache_completion_stream()

View file

@ -0,0 +1,120 @@
import sys
import os
import io, asyncio
# import logging
# logging.basicConfig(level=logging.DEBUG)
sys.path.insert(0, os.path.abspath('../..'))
from litellm import completion
import litellm
litellm.num_retries = 3
import time, random
import pytest
def pre_request():
file_name = f"dynamo.log"
log_file = open(file_name, "a+")
# Clear the contents of the file by truncating it
log_file.truncate(0)
# Save the original stdout so that we can restore it later
original_stdout = sys.stdout
# Redirect stdout to the file
sys.stdout = log_file
return original_stdout, log_file, file_name
import re
def verify_log_file(log_file_path):
with open(log_file_path, 'r') as log_file:
log_content = log_file.read()
print(f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content)
# Define the pattern to search for in the log file
pattern = r"Response from DynamoDB:{.*?}"
# Find all matches in the log content
matches = re.findall(pattern, log_content)
# Print the DynamoDB success log matches
print("DynamoDB Success Log Matches:")
for match in matches:
print(match)
# Print the total count of lines containing the specified response
print(f"Total occurrences of specified response: {len(matches)}")
# Count the occurrences of successful responses (status code 200 or 201)
success_count = sum(1 for match in matches if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match)
# Print the count of successful responses
print(f"Count of successful responses from DynamoDB: {success_count}")
assert success_count == 3 # Expect 3 success logs from dynamoDB
def test_dynamo_logging():
# all dynamodb requests need to be in one test function
# since we are modifying stdout, and pytest runs tests in parallel
try:
# pre
# redirect stdout to log_file
litellm.success_callback = ["dynamodb"]
litellm.dynamodb_table_name = "litellm-logs-1"
litellm.set_verbose = True
original_stdout, log_file, file_name = pre_request()
print("Testing async dynamoDB logging")
async def _test():
return await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content":"This is a test"}],
max_tokens=100,
temperature=0.7,
user = "ishaan-2"
)
response = asyncio.run(_test())
print(f"response: {response}")
# streaming + async
async def _test2():
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content":"This is a test"}],
max_tokens=10,
temperature=0.7,
user = "ishaan-2",
stream=True
)
async for chunk in response:
pass
asyncio.run(_test2())
# aembedding()
async def _test3():
return await litellm.aembedding(
model="text-embedding-ada-002",
input = ["hi"],
user = "ishaan-2"
)
response = asyncio.run(_test3())
time.sleep(1)
except Exception as e:
pytest.fail(f"An exception occurred - {e}")
finally:
# post, close log file and verify
# Reset stdout to the original value
sys.stdout = original_stdout
# Close the file
log_file.close()
verify_log_file(file_name)
print("Passed! Testing async dynamoDB logging")
# test_dynamo_logging_async()
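# Verification sketch for the test above: pre_request() redirects stdout to dynamo.log,
# the three calls (acompletion, streaming acompletion, aembedding) each log a
# "Response from DynamoDB" line, and verify_log_file() asserts exactly 3 successful
# (HTTP 200/201) writes were recorded.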

View file

@ -164,7 +164,7 @@ def test_bedrock_embedding_titan():
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats" assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_bedrock_embedding_titan() test_bedrock_embedding_titan()
def test_bedrock_embedding_cohere(): def test_bedrock_embedding_cohere():
try: try:

View file

@ -21,6 +21,7 @@ from concurrent.futures import ThreadPoolExecutor
import pytest import pytest
litellm.vertex_project = "pathrise-convert-1606954137718" litellm.vertex_project = "pathrise-convert-1606954137718"
litellm.vertex_location = "us-central1" litellm.vertex_location = "us-central1"
litellm.num_retries=0
# litellm.failure_callback = ["sentry"] # litellm.failure_callback = ["sentry"]
#### What this tests #### #### What this tests ####
@ -38,10 +39,11 @@ models = ["command-nightly"]
# Test 1: Context Window Errors # Test 1: Context Window Errors
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
def test_context_window(model): def test_context_window(model):
print("Testing context window error")
sample_text = "Say error 50 times" * 1000000 sample_text = "Say error 50 times" * 1000000
messages = [{"content": sample_text, "role": "user"}] messages = [{"content": sample_text, "role": "user"}]
try: try:
litellm.set_verbose = False litellm.set_verbose = True
response = completion(model=model, messages=messages) response = completion(model=model, messages=messages)
print(f"response: {response}") print(f"response: {response}")
print("FAILED!") print("FAILED!")
@ -176,7 +178,7 @@ def test_completion_azure_exception():
try: try:
import openai import openai
print("azure gpt-3.5 test\n\n") print("azure gpt-3.5 test\n\n")
litellm.set_verbose=False litellm.set_verbose=True
## Test azure call ## Test azure call
old_azure_key = os.environ["AZURE_API_KEY"] old_azure_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "good morning" os.environ["AZURE_API_KEY"] = "good morning"
@ -189,6 +191,7 @@ def test_completion_azure_exception():
} }
], ],
) )
os.environ["AZURE_API_KEY"] = old_azure_key
print(f"response: {response}") print(f"response: {response}")
print(response) print(response)
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
@ -196,14 +199,14 @@ def test_completion_azure_exception():
print("good job got the correct error for azure when key not set") print("good job got the correct error for azure when key not set")
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
test_completion_azure_exception() # test_completion_azure_exception()
async def asynctest_completion_azure_exception(): async def asynctest_completion_azure_exception():
try: try:
import openai import openai
import litellm import litellm
print("azure gpt-3.5 test\n\n") print("azure gpt-3.5 test\n\n")
litellm.set_verbose=False litellm.set_verbose=True
## Test azure call ## Test azure call
old_azure_key = os.environ["AZURE_API_KEY"] old_azure_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "good morning" os.environ["AZURE_API_KEY"] = "good morning"
@ -226,19 +229,75 @@ async def asynctest_completion_azure_exception():
print("Got wrong exception") print("Got wrong exception")
print("exception", e) print("exception", e)
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# import asyncio # import asyncio
# asyncio.run( # asyncio.run(
# asynctest_completion_azure_exception() # asynctest_completion_azure_exception()
# ) # )
def asynctest_completion_openai_exception_bad_model():
try:
import openai
import litellm, asyncio
print("azure exception bad model\n\n")
litellm.set_verbose=True
## Test azure call
async def test():
response = await litellm.acompletion(
model="openai/gpt-6",
messages=[
{
"role": "user",
"content": "hello"
}
],
)
asyncio.run(test())
except openai.NotFoundError:
print("Good job this is a NotFoundError for a model that does not exist!")
print("Passed")
except Exception as e:
print("Raised wrong type of exception", type(e))
assert isinstance(e, openai.BadRequestError)
pytest.fail(f"Error occurred: {e}")
# asynctest_completion_openai_exception_bad_model()
def asynctest_completion_azure_exception_bad_model():
try:
import openai
import litellm, asyncio
print("azure exception bad model\n\n")
litellm.set_verbose=True
## Test azure call
async def test():
response = await litellm.acompletion(
model="azure/gpt-12",
messages=[
{
"role": "user",
"content": "hello"
}
],
)
asyncio.run(test())
except openai.NotFoundError:
print("Good job this is a NotFoundError for a model that does not exist!")
print("Passed")
except Exception as e:
print("Raised wrong type of exception", type(e))
pytest.fail(f"Error occurred: {e}")
# asynctest_completion_azure_exception_bad_model()
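# Both helpers above pin the same mapping: requesting a model that does not exist
# (openai/gpt-6, azure/gpt-12) should surface as openai.NotFoundError; any other
# exception type fails the test.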
def test_completion_openai_exception(): def test_completion_openai_exception():
# test if openai:gpt raises openai.AuthenticationError # test if openai:gpt raises openai.AuthenticationError
try: try:
import openai import openai
print("openai gpt-3.5 test\n\n") print("openai gpt-3.5 test\n\n")
litellm.set_verbose=False litellm.set_verbose=True
## Test azure call ## Test azure call
old_azure_key = os.environ["OPENAI_API_KEY"] old_azure_key = os.environ["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = "good morning" os.environ["OPENAI_API_KEY"] = "good morning"
@ -255,11 +314,38 @@ def test_completion_openai_exception():
print(response) print(response)
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
os.environ["OPENAI_API_KEY"] = old_azure_key os.environ["OPENAI_API_KEY"] = old_azure_key
print("good job got the correct error for openai when key not set") print("OpenAI: good job got the correct error for openai when key not set")
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_openai_exception() # test_completion_openai_exception()
def test_completion_mistral_exception():
# test if mistral/mistral-tiny raises openai.AuthenticationError
try:
import openai
print("Testing mistral ai exception mapping")
litellm.set_verbose=True
## Test azure call
old_azure_key = os.environ["MISTRAL_API_KEY"]
os.environ["MISTRAL_API_KEY"] = "good morning"
response = completion(
model="mistral/mistral-tiny",
messages=[
{
"role": "user",
"content": "hello"
}
],
)
print(f"response: {response}")
print(response)
except openai.AuthenticationError as e:
os.environ["MISTRAL_API_KEY"] = old_azure_key
print("good job got the correct error for openai when key not set")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_mistral_exception()

View file

@ -9,33 +9,107 @@ from litellm import completion
import litellm import litellm
litellm.num_retries = 3 litellm.num_retries = 3
litellm.success_callback = ["langfuse"] litellm.success_callback = ["langfuse"]
# litellm.set_verbose = True os.environ["LANGFUSE_DEBUG"] = "True"
import time import time
import pytest import pytest
def search_logs(log_file_path):
"""
Searches the given log file for logs containing the "/api/public" string.
Parameters:
- log_file_path (str): The path to the log file to be searched.
Returns:
- None
Raises:
- Exception: If there are any bad logs found in the log file.
"""
import re
print("\n searching logs")
bad_logs = []
good_logs = []
all_logs = []
try:
with open(log_file_path, 'r') as log_file:
lines = log_file.readlines()
print(f"searching logslines: {lines}")
for line in lines:
all_logs.append(line.strip())
if "/api/public" in line:
print("Found log with /api/public:")
print(line.strip())
print("\n\n")
match = re.search(r'receive_response_headers.complete return_value=\(b\'HTTP/1.1\', (\d+),', line)
if match:
status_code = int(match.group(1))
if status_code != 200 and status_code != 201:
print("got a BAD log")
bad_logs.append(line.strip())
else:
good_logs.append(line.strip())
print("\nBad Logs")
print(bad_logs)
if len(bad_logs)>0:
raise Exception(f"bad logs, Bad logs = {bad_logs}")
print("\nGood Logs")
print(good_logs)
if len(good_logs) <= 0:
raise Exception(f"There were no Good Logs from Langfuse. No logs with /api/public status 200. \nAll logs:{all_logs}")
except Exception as e:
raise e
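# Usage sketch (assumes the httpx/langfuse debug output was first captured to a file,
# e.g. via pre_langfuse_setup() below):
#
#   search_logs("langfuse.log")  # raises if any /api/public request returned a non-2xx status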
def pre_langfuse_setup():
"""
Route debug logging to langfuse.log so the test can inspect Langfuse HTTP responses.
"""
# sends logs to langfuse.log
import logging
# Configure the logging to write to a file
logging.basicConfig(filename="langfuse.log", level=logging.DEBUG)
logger = logging.getLogger()
# Add a FileHandler to the logger
file_handler = logging.FileHandler("langfuse.log", mode='w')
file_handler.setLevel(logging.DEBUG)
logger.addHandler(file_handler)
return
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_async(): def test_langfuse_logging_async():
try: try:
pre_langfuse_setup()
litellm.set_verbose = True litellm.set_verbose = True
async def _test_langfuse(): async def _test_langfuse():
return await litellm.acompletion( return await litellm.acompletion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content":"This is a test"}], messages=[{"role": "user", "content":"This is a test"}],
max_tokens=1000, max_tokens=100,
temperature=0.7, temperature=0.7,
timeout=5, timeout=5,
) )
response = asyncio.run(_test_langfuse()) response = asyncio.run(_test_langfuse())
print(f"response: {response}") print(f"response: {response}")
# time.sleep(2)
# # check langfuse.log to see if there was a failed response
# search_logs("langfuse.log")
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred - {e}") pytest.fail(f"An exception occurred - {e}")
# test_langfuse_logging_async() test_langfuse_logging_async()
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging(): def test_langfuse_logging():
try: try:
# litellm.set_verbose = True pre_langfuse_setup()
litellm.set_verbose = True
response = completion(model="claude-instant-1.2", response = completion(model="claude-instant-1.2",
messages=[{ messages=[{
"role": "user", "role": "user",
@ -43,17 +117,20 @@ def test_langfuse_logging():
}], }],
max_tokens=10, max_tokens=10,
temperature=0.2, temperature=0.2,
metadata={"langfuse/key": "foo"}
) )
print(response) print(response)
# time.sleep(5)
# # check langfuse.log to see if there was a failed response
# search_logs("langfuse.log")
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
except Exception as e: except Exception as e:
print(e) pytest.fail(f"An exception occurred - {e}")
test_langfuse_logging() test_langfuse_logging()
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_stream(): def test_langfuse_logging_stream():
try: try:
litellm.set_verbose=True litellm.set_verbose=True
@ -77,6 +154,7 @@ def test_langfuse_logging_stream():
# test_langfuse_logging_stream() # test_langfuse_logging_stream()
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_custom_generation_name(): def test_langfuse_logging_custom_generation_name():
try: try:
litellm.set_verbose=True litellm.set_verbose=True
@ -99,8 +177,8 @@ def test_langfuse_logging_custom_generation_name():
pytest.fail(f"An exception occurred - {e}") pytest.fail(f"An exception occurred - {e}")
print(e) print(e)
test_langfuse_logging_custom_generation_name() # test_langfuse_logging_custom_generation_name()
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_function_calling(): def test_langfuse_logging_function_calling():
function1 = [ function1 = [
{ {

View file

@ -17,10 +17,10 @@ model_alias_map = {
"good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf" "good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf"
} }
litellm.model_alias_map = model_alias_map
def test_model_alias_map(): def test_model_alias_map():
try: try:
litellm.model_alias_map = model_alias_map
response = completion( response = completion(
"good-model", "good-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}], messages=[{"role": "user", "content": "Hey, how's it going?"}],

View file

@ -1,100 +1,37 @@
##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ###### import sys, os
# import aiohttp import traceback
# import json from dotenv import load_dotenv
# import asyncio
# import requests
#
# async def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
# session = aiohttp.ClientSession()
# url = f'{api_base}/api/generate'
# data = {
# "model": model,
# "prompt": prompt,
# }
# response = "" load_dotenv()
import os, io
# try: sys.path.insert(
# async with session.post(url, json=data) as resp: 0, os.path.abspath("../..")
# async for line in resp.content.iter_any(): ) # Adds the parent directory to the system path
# if line: import pytest
# try: import litellm
# json_chunk = line.decode("utf-8")
# chunks = json_chunk.split("\n")
# for chunk in chunks:
# if chunk.strip() != "":
# j = json.loads(chunk)
# if "response" in j:
# print(j["response"])
# yield {
# "role": "assistant",
# "content": j["response"]
# }
# # self.responses.append(j["response"])
# # yield "blank"
# except Exception as e:
# print(f"Error decoding JSON: {e}")
# finally:
# await session.close()
# async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
# generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
# response = ""
# async for elem in generator:
# print(elem)
# response += elem["content"]
# return response
# #generator = get_ollama_response_stream()
# result = asyncio.run(get_ollama_response_no_stream())
# print(result)
# # return this generator to the client for streaming requests
# async def get_response(): ## for ollama we can't test making the completion call
# global generator from litellm.utils import get_optional_params, get_llm_provider
# async for elem in generator:
# print(elem)
# asyncio.run(get_response()) def test_get_ollama_params():
try:
converted_params = get_optional_params(custom_llm_provider="ollama", model="llama2", max_tokens=20, temperature=0.5, stream=True)
print("Converted params", converted_params)
assert converted_params == {'num_predict': 20, 'stream': True, 'temperature': 0.5}, f"{converted_params} != {'num_predict': 20, 'stream': True, 'temperature': 0.5}"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_get_ollama_params()
def test_get_ollama_model():
try:
model, custom_llm_provider, _, _ = get_llm_provider("ollama/code-llama-22")
print("Model", "custom_llm_provider", model, custom_llm_provider)
assert custom_llm_provider == "ollama", f"{custom_llm_provider} != ollama"
assert model == "code-llama-22", f"{model} != code-llama-22"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
##### latest implementation of making raw http post requests to local ollama server # test_get_ollama_model()
# import requests
# import json
# def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
# url = f"{api_base}/api/generate"
# data = {
# "model": model,
# "prompt": prompt,
# }
# session = requests.Session()
# with session.post(url, json=data, stream=True) as resp:
# for line in resp.iter_lines():
# if line:
# try:
# json_chunk = line.decode("utf-8")
# chunks = json_chunk.split("\n")
# for chunk in chunks:
# if chunk.strip() != "":
# j = json.loads(chunk)
# if "response" in j:
# completion_obj = {
# "role": "assistant",
# "content": "",
# }
# completion_obj["content"] = j["response"]
# yield {"choices": [{"delta": completion_obj}]}
# except Exception as e:
# print(f"Error decoding JSON: {e}")
# session.close()
# response = get_ollama_response_stream()
# for chunk in response:
# print(chunk['choices'][0]['delta'])

View file

@ -16,6 +16,19 @@
# user_message = "respond in 20 words. who are you?" # user_message = "respond in 20 words. who are you?"
# messages = [{ "content": user_message,"role": "user"}] # messages = [{ "content": user_message,"role": "user"}]
# async def test_async_ollama_streaming():
# try:
# litellm.set_verbose = True
# response = await litellm.acompletion(model="ollama/mistral-openorca",
# messages=[{"role": "user", "content": "Hey, how's it going?"}],
# stream=True)
# async for chunk in response:
# print(chunk)
# except Exception as e:
# print(e)
# asyncio.run(test_async_ollama_streaming())
# def test_completion_ollama(): # def test_completion_ollama():
# try: # try:
# response = completion( # response = completion(
@ -29,7 +42,7 @@
# except Exception as e: # except Exception as e:
# pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")
# test_completion_ollama() # # test_completion_ollama()
# def test_completion_ollama_with_api_base(): # def test_completion_ollama_with_api_base():
# try: # try:
@ -42,7 +55,7 @@
# except Exception as e: # except Exception as e:
# pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")
# test_completion_ollama_with_api_base() # # test_completion_ollama_with_api_base()
# def test_completion_ollama_custom_prompt_template(): # def test_completion_ollama_custom_prompt_template():
@ -72,7 +85,7 @@
# traceback.print_exc() # traceback.print_exc()
# pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")
# test_completion_ollama_custom_prompt_template() # # test_completion_ollama_custom_prompt_template()
# async def test_completion_ollama_async_stream(): # async def test_completion_ollama_async_stream():
# user_message = "what is the weather" # user_message = "what is the weather"
@ -98,8 +111,8 @@
# except Exception as e: # except Exception as e:
# pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")
# import asyncio # # import asyncio
# asyncio.run(test_completion_ollama_async_stream()) # # asyncio.run(test_completion_ollama_async_stream())
@ -154,8 +167,35 @@
# pass # pass
# pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")
# test_completion_expect_error() # # test_completion_expect_error()
# if __name__ == "__main__":
# import asyncio # def test_ollama_llava():
# asyncio.run(main()) # litellm.set_verbose=True
# # same params as gpt-4 vision
# response = completion(
# model = "ollama/llava",
# messages=[
# {
# "role": "user",
# "content": [
# {
# "type": "text",
# "text": "What is in this picture"
# },
# {
# "type": "image_url",
# "image_url": {
# "url": "iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF16
9a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"
# }
# }
# ]
# }
# ],
# )
# print("Response from ollama/llava")
# print(response)
# test_ollama_llava()
# # PROCESSED CHUNK PRE CHUNK CREATOR

View file

@ -0,0 +1,27 @@
#### What this tests ####
# This tests if get_optional_params works as expected
import sys, os, time, inspect, asyncio, traceback
import pytest
sys.path.insert(0, os.path.abspath('../..'))
import litellm
from litellm.utils import get_optional_params_embeddings
## get_optional_params_embeddings
### Models: OpenAI, Azure, Bedrock
### Scenarios: w/ optional params + litellm.drop_params = True
def test_bedrock_optional_params_embeddings():
litellm.drop_params = True
optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="bedrock")
assert len(optional_params) == 0
def test_openai_optional_params_embeddings():
litellm.drop_params = True
optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="openai")
assert len(optional_params) == 1
assert optional_params["user"] == "John"
def test_azure_optional_params_embeddings():
litellm.drop_params = True
optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="azure")
assert len(optional_params) == 1
assert optional_params["user"] == "John"

View file

@ -19,21 +19,23 @@ from litellm import RateLimitError
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from fastapi import FastAPI from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
app = FastAPI()
app.include_router(router) # Include your router in the test app
@app.on_event("startup")
async def wrapper_startup_event():
initialize(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
# Here you create a fixture that will be used by your tests # Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app) # Make sure the fixture returns TestClient(app)
@pytest.fixture(autouse=True) @pytest.fixture(scope="function")
def client(): def client():
with TestClient(app) as client: from litellm.proxy.proxy_server import cleanup_router_config_variables
yield client cleanup_router_config_variables()
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
# initialize can get run in parallel; it sets specific variables for the FastAPI app, and since it gets run in parallel, different tests can end up using the wrong variables
app = FastAPI()
initialize(config=config_fp)
app.include_router(router) # Include your router in the test app
return TestClient(app)
def test_custom_auth(client): def test_custom_auth(client):
try: try:

View file

@ -3,7 +3,7 @@ import traceback
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
import os, io import os, io, asyncio
# this file is to test litellm/proxy # this file is to test litellm/proxy
@ -21,21 +21,24 @@ from fastapi.testclient import TestClient
from fastapi import FastAPI from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__)) filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
python_file_path = f"{filepath}/test_configs/custom_callbacks.py" python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
app = FastAPI()
app.include_router(router) # Include your router in the test app
@app.on_event("startup")
async def wrapper_startup_event():
initialize(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=True, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
# Here you create a fixture that will be used by your tests # @app.on_event("startup")
# Make sure the fixture returns TestClient(app) # async def wrapper_startup_event():
@pytest.fixture(autouse=True) # initialize(config=config_fp)
# Use the app fixture in your client fixture
@pytest.fixture
def client(): def client():
with TestClient(app) as client: filepath = os.path.dirname(os.path.abspath(__file__))
yield client config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
initialize(config=config_fp)
app = FastAPI()
app.include_router(router) # Include your router in the test app
return TestClient(app)
# Your bearer token # Your bearer token
token = os.getenv("PROXY_MASTER_KEY") token = os.getenv("PROXY_MASTER_KEY")
@ -45,15 +48,76 @@ headers = {
} }
def test_chat_completion(client): print("Testing proxy custom logger")
def test_embedding(client):
try: try:
litellm.set_verbose=False
from litellm.proxy.utils import get_instance_fn
my_custom_logger = get_instance_fn(
value = "custom_callbacks.my_custom_logger",
config_file_path=python_file_path
)
print("id of initialized custom logger", id(my_custom_logger))
litellm.callbacks = [my_custom_logger]
# Your test data # Your test data
print("initialized proxy") print("initialized proxy")
# import the initialized custom logger # import the initialized custom logger
print(litellm.callbacks) print(litellm.callbacks)
assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback # assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
my_custom_logger = litellm.callbacks[0] print("my_custom_logger", my_custom_logger)
assert my_custom_logger.async_success_embedding == False
test_data = {
"model": "azure-embedding-model",
"input": ["hello"]
}
response = client.post("/embeddings", json=test_data, headers=headers)
print("made request", response.status_code, response.text)
print("vars my custom logger /embeddings", vars(my_custom_logger), "id", id(my_custom_logger))
assert my_custom_logger.async_success_embedding == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
assert my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model" # checks if kwargs passed to async_log_success_event are correct
kwargs = my_custom_logger.async_embedding_kwargs
litellm_params = kwargs.get("litellm_params")
metadata = litellm_params.get("metadata", None)
print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
assert metadata is not None
assert "user_api_key" in metadata
assert "headers" in metadata
proxy_server_request = litellm_params.get("proxy_server_request")
model_info = litellm_params.get("model_info")
assert proxy_server_request == {'url': 'http://testserver/embeddings', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '54', 'content-type': 'application/json'}, 'body': {'model': 'azure-embedding-model', 'input': ['hello']}}
assert model_info == {'input_cost_per_token': 0.002, 'mode': 'embedding', 'id': 'hello'}
result = response.json()
print(f"Received response: {result}")
print("Passed Embedding custom logger on proxy!")
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
def test_chat_completion(client):
try:
# Your test data
print("initialized proxy")
litellm.set_verbose=False
from litellm.proxy.utils import get_instance_fn
my_custom_logger = get_instance_fn(
value = "custom_callbacks.my_custom_logger",
config_file_path=python_file_path
)
print("id of initialized custom logger", id(my_custom_logger))
litellm.callbacks = [my_custom_logger]
# import the initialized custom logger
print(litellm.callbacks)
# assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
print("LiteLLM Callbacks", litellm.callbacks)
print("my_custom_logger", my_custom_logger)
assert my_custom_logger.async_success == False assert my_custom_logger.async_success == False
test_data = { test_data = {
@ -61,7 +125,7 @@ def test_chat_completion(client):
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
"content": "hi" "content": "write a litellm poem"
}, },
], ],
"max_tokens": 10, "max_tokens": 10,
@ -70,33 +134,53 @@ def test_chat_completion(client):
response = client.post("/chat/completions", json=test_data, headers=headers) response = client.post("/chat/completions", json=test_data, headers=headers)
print("made request", response.status_code, response.text) print("made request", response.status_code, response.text)
print("LiteLLM Callbacks", litellm.callbacks)
asyncio.run(asyncio.sleep(1)) # sleep while waiting for the async callback to run (a bare asyncio.sleep() in this sync test is never awaited)
print("my_custom_logger in /chat/completions", my_custom_logger, "id", id(my_custom_logger))
print("vars my custom logger, ", vars(my_custom_logger))
assert my_custom_logger.async_success == True # checks if the status of async_success is True, only the async_log_success_event can set this to true assert my_custom_logger.async_success == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
assert my_custom_logger.async_completion_kwargs["model"] == "chatgpt-v-2" # checks if kwargs passed to async_log_success_event are correct assert my_custom_logger.async_completion_kwargs["model"] == "chatgpt-v-2" # checks if kwargs passed to async_log_success_event are correct
print("\n\n Custom Logger Async Completion args", my_custom_logger.async_completion_kwargs) print("\n\n Custom Logger Async Completion args", my_custom_logger.async_completion_kwargs)
litellm_params = my_custom_logger.async_completion_kwargs.get("litellm_params") litellm_params = my_custom_logger.async_completion_kwargs.get("litellm_params")
metadata = litellm_params.get("metadata", None)
print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
assert metadata is not None
assert "user_api_key" in metadata
assert "headers" in metadata
config_model_info = litellm_params.get("model_info") config_model_info = litellm_params.get("model_info")
proxy_server_request_object = litellm_params.get("proxy_server_request") proxy_server_request_object = litellm_params.get("proxy_server_request")
assert config_model_info == {'id': 'gm', 'input_cost_per_token': 0.0002, 'mode': 'chat'} assert config_model_info == {'id': 'gm', 'input_cost_per_token': 0.0002, 'mode': 'chat'}
assert proxy_server_request_object == {'url': 'http://testserver/chat/completions', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '105', 'content-type': 'application/json'}, 'body': {'model': 'Azure OpenAI GPT-4 Canada', 'messages': [{'role': 'user', 'content': 'hi'}], 'max_tokens': 10}} assert proxy_server_request_object == {'url': 'http://testserver/chat/completions', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '123', 'content-type': 'application/json'}, 'body': {'model': 'Azure OpenAI GPT-4 Canada', 'messages': [{'role': 'user', 'content': 'write a litellm poem'}], 'max_tokens': 10}}
result = response.json() result = response.json()
print(f"Received response: {result}") print(f"Received response: {result}")
print("\nPassed /chat/completions with Custom Logger!") print("\nPassed /chat/completions with Custom Logger!")
except Exception as e: except Exception as e:
pytest.fail("LiteLLM Proxy test failed. Exception", e) pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
def test_chat_completion_stream(client): def test_chat_completion_stream(client):
try: try:
# Your test data # Your test data
litellm.set_verbose=False
from litellm.proxy.utils import get_instance_fn
my_custom_logger = get_instance_fn(
value = "custom_callbacks.my_custom_logger",
config_file_path=python_file_path
)
print("id of initialized custom logger", id(my_custom_logger))
litellm.callbacks = [my_custom_logger]
import json import json
print("initialized proxy") print("initialized proxy")
# import the initialized custom logger # import the initialized custom logger
print(litellm.callbacks) print(litellm.callbacks)
assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
my_custom_logger = litellm.callbacks[0] print("LiteLLM Callbacks", litellm.callbacks)
print("my_custom_logger", my_custom_logger)
assert my_custom_logger.streaming_response_obj == None # no streaming response obj is set pre call assert my_custom_logger.streaming_response_obj == None # no streaming response obj is set pre call
@ -148,37 +232,5 @@ def test_chat_completion_stream(client):
assert complete_response == streamed_response["choices"][0]["message"]["content"] assert complete_response == streamed_response["choices"][0]["message"]["content"]
except Exception as e: except Exception as e:
pytest.fail("LiteLLM Proxy test failed. Exception", e) pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
def test_embedding(client):
try:
# Your test data
print("initialized proxy")
# import the initialized custom logger
print(litellm.callbacks)
assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
my_custom_logger = litellm.callbacks[0]
assert my_custom_logger.async_success_embedding == False
test_data = {
"model": "azure-embedding-model",
"input": ["hello"]
}
response = client.post("/embeddings", json=test_data, headers=headers)
print("made request", response.status_code, response.text)
assert my_custom_logger.async_success_embedding == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
assert my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model" # checks if kwargs passed to async_log_success_event are correct
kwargs = my_custom_logger.async_embedding_kwargs
litellm_params = kwargs.get("litellm_params")
proxy_server_request = litellm_params.get("proxy_server_request")
model_info = litellm_params.get("model_info")
assert proxy_server_request == {'url': 'http://testserver/embeddings', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '54', 'content-type': 'application/json'}, 'body': {'model': 'azure-embedding-model', 'input': ['hello']}}
assert model_info == {'input_cost_per_token': 0.002, 'mode': 'embedding', 'id': 'hello'}
result = response.json()
print(f"Received response: {result}")
except Exception as e:
pytest.fail("LiteLLM Proxy test failed. Exception", e)
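The custom_callbacks.py module these tests load is not part of this diff; judging from the attributes asserted on (async_success, async_completion_kwargs, async_success_embedding, async_embedding_kwargs), a minimal compatible logger would look roughly like this (the import path and the call_type check are assumptions):

from litellm.integrations.custom_logger import CustomLogger  # assumed import path

class MyCustomLogger(CustomLogger):
    def __init__(self):
        self.async_success = False
        self.async_completion_kwargs = None
        self.async_success_embedding = False
        self.async_embedding_kwargs = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # only this async hook flips the flags the tests assert on
        if kwargs.get("call_type") in ("embedding", "aembedding"):   # "call_type" key is assumed
            self.async_success_embedding = True
            self.async_embedding_kwargs = kwargs
        else:
            self.async_success = True
            self.async_completion_kwargs = kwargs

my_custom_logger = MyCustomLogger()   # what get_instance_fn("custom_callbacks.my_custom_logger") resolves to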

View file

@ -0,0 +1,177 @@
# test that the proxy actually does exception mapping to the OpenAI format
import sys, os
from dotenv import load_dotenv
load_dotenv()
import os, io, asyncio
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm, openai
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
@pytest.fixture
def client():
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_bad_config.yaml"
initialize(config=config_fp)
app = FastAPI()
app.include_router(router) # Include your router in the test app
return TestClient(app)
# raise openai.AuthenticationError
def test_chat_completion_exception(client):
try:
# Your test data
test_data = {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "hi"
},
],
"max_tokens": 10,
}
response = client.post("/chat/completions", json=test_data)
# make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything")
openai_exception = openai_client._make_status_error_from_response(response=response)
assert isinstance(openai_exception, openai.AuthenticationError)
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
# raise openai.AuthenticationError
def test_chat_completion_exception_azure(client):
try:
# Your test data
test_data = {
"model": "azure-gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "hi"
},
],
"max_tokens": 10,
}
response = client.post("/chat/completions", json=test_data)
# make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything")
openai_exception = openai_client._make_status_error_from_response(response=response)
assert isinstance(openai_exception, openai.AuthenticationError)
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
# raise openai.AuthenticationError
def test_embedding_auth_exception_azure(client):
try:
# Your test data
test_data = {
"model": "azure-embedding",
"input": ["hi"]
}
response = client.post("/embeddings", json=test_data)
print("Response from proxy=", response)
# make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything")
openai_exception = openai_client._make_status_error_from_response(response=response)
print("Exception raised=", openai_exception)
assert isinstance(openai_exception, openai.AuthenticationError)
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
# raise openai.BadRequestError
# chat/completions openai
def test_exception_openai_bad_model(client):
try:
# Your test data
test_data = {
"model": "azure/GPT-12",
"messages": [
{
"role": "user",
"content": "hi"
},
],
"max_tokens": 10,
}
response = client.post("/chat/completions", json=test_data)
# make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything")
openai_exception = openai_client._make_status_error_from_response(response=response)
print("Type of exception=", type(openai_exception))
assert isinstance(openai_exception, openai.NotFoundError)
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
# chat/completions any model
def test_chat_completion_exception_any_model(client):
try:
# Your test data
test_data = {
"model": "Lite-GPT-12",
"messages": [
{
"role": "user",
"content": "hi"
},
],
"max_tokens": 10,
}
response = client.post("/chat/completions", json=test_data)
# make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything")
openai_exception = openai_client._make_status_error_from_response(response=response)
print("Exception raised=", openai_exception)
assert isinstance(openai_exception, openai.NotFoundError)
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
# embeddings any model
def test_embedding_exception_any_model(client):
try:
# Your test data
test_data = {
"model": "Lite-GPT-12",
"input": ["hi"]
}
response = client.post("/embeddings", json=test_data)
print("Response from proxy=", response)
# make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything")
openai_exception = openai_client._make_status_error_from_response(response=response)
print("Exception raised=", openai_exception)
assert isinstance(openai_exception, openai.NotFoundError)
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
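Each of these tests repeats the same three steps; a small helper capturing the pattern might look like this (the helper itself is illustrative, the OpenAI call is the one already used above):

import openai

def assert_proxy_error(client, route, payload, expected_exc):
    # post to the proxy test client and check the error body maps back to a typed OpenAI exception;
    # _make_status_error_from_response is the private OpenAI-client helper used in the tests above
    response = client.post(route, json=payload)
    openai_client = openai.OpenAI(api_key="anything")
    exc = openai_client._make_status_error_from_response(response=response)
    assert isinstance(exc, expected_exc), f"expected {expected_exc}, got {type(exc)}"

# e.g. assert_proxy_error(client, "/chat/completions",
#                         {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
#                         openai.AuthenticationError)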

View file

@ -24,30 +24,29 @@ logging.basicConfig(
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from fastapi import FastAPI from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
app = FastAPI()
app.include_router(router) # Include your router in the test app
@app.on_event("startup")
async def wrapper_startup_event():
initialize(config=config_fp)
# Your bearer token # Your bearer token
token = os.getenv("PROXY_MASTER_KEY") token = ""
headers = { headers = {
"Authorization": f"Bearer {token}" "Authorization": f"Bearer {token}"
} }
# Here you create a fixture that will be used by your tests @pytest.fixture(scope="function")
# Make sure the fixture returns TestClient(app) def client_no_auth():
@pytest.fixture(autouse=True) # Assuming litellm.proxy.proxy_server is an object
def client(): from litellm.proxy.proxy_server import cleanup_router_config_variables
with TestClient(app) as client: cleanup_router_config_variables()
yield client filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
# initialize can get run in parallel; it sets specific variables for the FastAPI app, so when it runs in parallel different tests can pick up the wrong variables
initialize(config=config_fp)
app = FastAPI()
app.include_router(router) # Include your router in the test app
def test_chat_completion(client): return TestClient(app)
def test_chat_completion(client_no_auth):
global headers global headers
try: try:
# Your test data # Your test data
@ -62,8 +61,8 @@ def test_chat_completion(client):
"max_tokens": 10, "max_tokens": 10,
} }
print("testing proxy server") print("testing proxy server with chat completions")
response = client.post("/v1/chat/completions", json=test_data, headers=headers) response = client_no_auth.post("/v1/chat/completions", json=test_data)
print(f"response - {response.text}") print(f"response - {response.text}")
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
@ -73,7 +72,8 @@ def test_chat_completion(client):
# Run the test # Run the test
def test_chat_completion_azure(client): def test_chat_completion_azure(client_no_auth):
global headers global headers
try: try:
# Your test data # Your test data
@ -88,8 +88,8 @@ def test_chat_completion_azure(client):
"max_tokens": 10, "max_tokens": 10,
} }
print("testing proxy server with Azure Request") print("testing proxy server with Azure Request /chat/completions")
response = client.post("/v1/chat/completions", json=test_data, headers=headers) response = client_no_auth.post("/v1/chat/completions", json=test_data)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
@ -102,15 +102,55 @@ def test_chat_completion_azure(client):
# test_chat_completion_azure() # test_chat_completion_azure()
def test_embedding(client): def test_embedding(client_no_auth):
global headers global headers
from litellm.proxy.proxy_server import user_custom_auth
try: try:
test_data = { test_data = {
"model": "azure/azure-embedding-model", "model": "azure/azure-embedding-model",
"input": ["good morning from litellm"], "input": ["good morning from litellm"],
} }
print("testing proxy server with OpenAI embedding")
response = client.post("/v1/embeddings", json=test_data, headers=headers) response = client_no_auth.post("/v1/embeddings", json=test_data)
assert response.status_code == 200
result = response.json()
print(len(result["data"][0]["embedding"]))
assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
def test_bedrock_embedding(client_no_auth):
global headers
from litellm.proxy.proxy_server import user_custom_auth
try:
test_data = {
"model": "amazon-embeddings",
"input": ["good morning from litellm"],
}
response = client_no_auth.post("/v1/embeddings", json=test_data)
assert response.status_code == 200
result = response.json()
print(len(result["data"][0]["embedding"]))
assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
def test_sagemaker_embedding(client_no_auth):
global headers
from litellm.proxy.proxy_server import user_custom_auth
try:
test_data = {
"model": "GPT-J 6B - Sagemaker Text Embedding (Internal)",
"input": ["good morning from litellm"],
}
response = client_no_auth.post("/v1/embeddings", json=test_data)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
@ -122,8 +162,8 @@ def test_embedding(client):
# Run the test # Run the test
# test_embedding() # test_embedding()
@pytest.mark.skip(reason="hitting yaml load issues on circle-ci") # @pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
def test_add_new_model(client): def test_add_new_model(client_no_auth):
global headers global headers
try: try:
test_data = { test_data = {
@ -135,15 +175,15 @@ def test_add_new_model(client):
"description": "this is a test openai model" "description": "this is a test openai model"
} }
} }
client.post("/model/new", json=test_data, headers=headers) client_no_auth.post("/model/new", json=test_data, headers=headers)
response = client.get("/model/info", headers=headers) response = client_no_auth.get("/model/info", headers=headers)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
print(f"response: {result}") print(f"response: {result}")
model_info = None model_info = None
for m in result["data"]: for m in result["data"]:
if m["id"]["model_name"] == "test_openai_models": if m["model_name"] == "test_openai_models":
model_info = m["id"]["model_info"] model_info = m["model_info"]
assert model_info["description"] == "this is a test openai model" assert model_info["description"] == "this is a test openai model"
except Exception as e: except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
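The switch from m["id"]["model_name"] to m["model_name"] implies /model/info entries are no longer wrapped in an "id" object; a sketch of the shape the test now expects (only the asserted fields are grounded in the diff, the rest is illustrative):

# expected shape of an entry from GET /model/info after this change
entry = {
    "model_name": "test_openai_models",
    "model_info": {"description": "this is a test openai model"},
}
assert entry["model_info"]["description"] == "this is a test openai model"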
@ -164,10 +204,9 @@ class MyCustomHandler(CustomLogger):
customHandler = MyCustomHandler() customHandler = MyCustomHandler()
def test_chat_completion_optional_params(client): def test_chat_completion_optional_params(client_no_auth):
# [PROXY: PROD TEST] - DO NOT DELETE # [PROXY: PROD TEST] - DO NOT DELETE
# This tests if all the /chat/completion params are passed to litellm # This tests if all the /chat/completion params are passed to litellm
try: try:
# Your test data # Your test data
litellm.set_verbose=True litellm.set_verbose=True
@ -185,7 +224,7 @@ def test_chat_completion_optional_params(client):
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
print("testing proxy server: optional params") print("testing proxy server: optional params")
response = client.post("/v1/chat/completions", json=test_data, headers=headers) response = client_no_auth.post("/v1/chat/completions", json=test_data)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
print(f"Received response: {result}") print(f"Received response: {result}")
@ -217,6 +256,29 @@ def test_load_router_config():
print(result) print(result)
assert len(result[1]) == 2 assert len(result[1]) == 2
# tests for litellm.cache set from config
print("testing reading proxy config for cache")
litellm.cache = None
load_router_config(
router=None,
config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml"
)
assert litellm.cache is not None
assert "redis_client" in vars(litellm.cache.cache) # it should default to redis on proxy
assert litellm.cache.supported_call_types == ['completion', 'acompletion', 'embedding', 'aembedding'] # init with all call types
print("testing reading proxy config for cache with params")
load_router_config(
router=None,
config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml"
)
assert litellm.cache is not None
print(litellm.cache)
print(litellm.cache.supported_call_types)
print(vars(litellm.cache.cache))
assert "redis_client" in vars(litellm.cache.cache) # it should default to redis on proxy
assert litellm.cache.supported_call_types == ['embedding', 'aembedding'] # init with all call types
except Exception as e: except Exception as e:
pytest.fail("Proxy: Got exception reading config", e) pytest.fail("Proxy: Got exception reading config", e)
# test_load_router_config() # test_load_router_config()
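Condensing the new cache assertions, the behaviour under test is roughly the following (import location assumed; the referenced YAML files are not shown in this diff):

import litellm
from litellm.proxy.proxy_server import load_router_config   # assumed import, as called in the test above

# loading a config with a `cache` block initializes litellm.cache (Redis-backed on the proxy);
# `supported_call_types` narrows which call types get cached
load_router_config(
    router=None,
    config_file_path="example_config_yaml/cache_with_params.yaml",   # path relative to the tests directory
)
assert litellm.cache is not None
assert "redis_client" in vars(litellm.cache.cache)
assert litellm.cache.supported_call_types == ["embedding", "aembedding"]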

View file

@ -37,6 +37,8 @@ async def wrapper_startup_event():
# Make sure the fixture returns TestClient(app) # Make sure the fixture returns TestClient(app)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def client(): def client():
from litellm.proxy.proxy_server import cleanup_router_config_variables
cleanup_router_config_variables()
with TestClient(app) as client: with TestClient(app) as client:
yield client yield client
@ -69,6 +71,38 @@ def test_add_new_key(client):
except Exception as e: except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}") pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
def test_update_new_key(client):
try:
# Your test data
test_data = {
"models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"],
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": "20m"
}
print("testing proxy server")
# Your bearer token
token = os.getenv("PROXY_MASTER_KEY")
headers = {
"Authorization": f"Bearer {token}"
}
response = client.post("/key/generate", json=test_data, headers=headers)
print(f"response: {response.text}")
assert response.status_code == 200
result = response.json()
assert result["key"].startswith("sk-")
def _post_data():
json_data = {'models': ['bedrock-models'], "key": result["key"]}
response = client.post("/key/update", json=json_data, headers=headers)
print(f"response text: {response.text}")
assert response.status_code == 200
return response
_post_data()
print(f"Received response: {result}")
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
# # Run the test - only runs via pytest # # Run the test - only runs via pytest
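Outside the test client, the same /key/generate then /key/update flow against a running proxy would look roughly like this (base URL and master key are placeholders):

import requests

BASE_URL = "http://0.0.0.0:4000"                 # placeholder proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}    # placeholder master key

# 1. generate a key scoped to some models
gen = requests.post(
    f"{BASE_URL}/key/generate",
    json={"models": ["gpt-3.5-turbo", "azure-model"], "duration": "20m"},
    headers=HEADERS,
)
key = gen.json()["key"]                          # returned keys start with "sk-"

# 2. update that key to a different model list
upd = requests.post(f"{BASE_URL}/key/update", json={"key": key, "models": ["bedrock-models"]}, headers=HEADERS)
assert upd.status_code == 200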

View file

@ -366,69 +366,12 @@ def test_function_calling():
} }
] ]
router = Router(model_list=model_list, routing_strategy="latency-based-routing") router = Router(model_list=model_list)
response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions) response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions)
router.reset() router.reset()
print(response) print(response)
def test_acompletion_on_router(): # test_acompletion_on_router()
# tests acompletion + caching on router
try:
litellm.set_verbose = True
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 10000,
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION")
},
"tpm": 100000,
"rpm": 10000,
}
]
messages = [
{"role": "user", "content": f"write a one sentence poem {time.time()}?"}
]
start_time = time.time()
router = Router(model_list=model_list,
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
cache_responses=True,
timeout=30,
routing_strategy="simple-shuffle")
async def get_response():
print("Testing acompletion + caching on router")
response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
print(f"response1: {response1}")
response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
print(f"response2: {response2}")
assert response1.id == response2.id
assert len(response1.choices[0].message.content) > 0
assert response1.choices[0].message.content == response2.choices[0].message.content
asyncio.run(get_response())
router.reset()
except litellm.Timeout as e:
end_time = time.time()
print(f"timeout error occurred: {end_time - start_time}")
pass
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
test_acompletion_on_router()
def test_function_calling_on_router(): def test_function_calling_on_router():
try: try:
@ -507,7 +450,6 @@ def test_aembedding_on_router():
model="text-embedding-ada-002", model="text-embedding-ada-002",
input=["good morning from litellm 2"], input=["good morning from litellm 2"],
) )
print("sync embedding response: ", response)
router.reset() router.reset()
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
@ -591,6 +533,30 @@ def test_bedrock_on_router():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_bedrock_on_router() # test_bedrock_on_router()
# test openai-compatible endpoint
@pytest.mark.asyncio
async def test_mistral_on_router():
litellm.set_verbose = True
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "mistral/mistral-medium",
},
},
]
router = Router(model_list=model_list)
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": "hello from litellm test",
}
]
)
print(response)
asyncio.run(test_mistral_on_router())
def test_openai_completion_on_router(): def test_openai_completion_on_router():
# [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream

View file

@ -0,0 +1,127 @@
#### What this tests ####
# This tests caching on the router
import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import Router
## Scenarios
## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo",
## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group
@pytest.mark.asyncio
async def test_acompletion_caching_on_router():
# tests acompletion + caching on router
try:
litellm.set_verbose = True
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 10000,
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION")
},
"tpm": 100000,
"rpm": 10000,
}
]
messages = [
{"role": "user", "content": f"write a one sentence poem {time.time()}?"}
]
start_time = time.time()
router = Router(model_list=model_list,
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
cache_responses=True,
timeout=30,
routing_strategy="simple-shuffle")
response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
print(f"response1: {response1}")
await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
print(f"response2: {response2}")
assert response1.id == response2.id
assert len(response1.choices[0].message.content) > 0
assert response1.choices[0].message.content == response2.choices[0].message.content
router.reset()
except litellm.Timeout as e:
end_time = time.time()
print(f"timeout error occurred: {end_time - start_time}")
pass
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_acompletion_caching_on_router_caching_groups():
# tests acompletion + caching on router
try:
litellm.set_verbose = True
model_list = [
{
"model_name": "openai-gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 10000,
},
{
"model_name": "azure-gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION")
},
"tpm": 100000,
"rpm": 10000,
}
]
messages = [
{"role": "user", "content": f"write a one sentence poem {time.time()}?"}
]
start_time = time.time()
router = Router(model_list=model_list,
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
cache_responses=True,
timeout=30,
routing_strategy="simple-shuffle",
caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")])
response1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1)
print(f"response1: {response1}")
await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
response2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1)
print(f"response2: {response2}")
assert response1.id == response2.id
assert len(response1.choices[0].message.content) > 0
assert response1.choices[0].message.content == response2.choices[0].message.content
router.reset()
except litellm.Timeout as e:
end_time = time.time()
print(f"timeout error occurred: {end_time - start_time}")
pass
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
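The second scenario is the new piece: caching_groups lets two differently named model groups share cached responses. Trimmed down, the pattern is (deployment params as in the test; Redis env vars assumed to be set):

import os, asyncio
from litellm import Router

model_list = [
    {"model_name": "openai-gpt-3.5-turbo",
     "litellm_params": {"model": "gpt-3.5-turbo-0613", "api_key": os.getenv("OPENAI_API_KEY")}},
    {"model_name": "azure-gpt-3.5-turbo",
     "litellm_params": {"model": "azure/chatgpt-v-2",
                        "api_key": os.getenv("AZURE_API_KEY"),
                        "api_base": os.getenv("AZURE_API_BASE"),
                        "api_version": os.getenv("AZURE_API_VERSION")}},
]

router = Router(
    model_list=model_list,
    redis_host=os.environ["REDIS_HOST"],
    redis_password=os.environ["REDIS_PASSWORD"],
    redis_port=os.environ["REDIS_PORT"],
    cache_responses=True,
    # the two model groups share a cache, so a hit on one is served for the other
    caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
)

async def main():
    messages = [{"role": "user", "content": "write a one sentence poem"}]
    r1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1)
    await asyncio.sleep(1)                       # cache write is async; give it a moment
    r2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1)
    assert r1.id == r2.id                        # second call is a cache hit across groups

# asyncio.run(main())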

View file

@ -21,21 +21,36 @@ class MyCustomHandler(CustomLogger):
print(f"Pre-API Call") print(f"Pre-API Call")
def log_post_api_call(self, kwargs, response_obj, start_time, end_time): def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print(f"Post-API Call") print(f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}")
def log_stream_event(self, kwargs, response_obj, start_time, end_time): def log_stream_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Stream") print(f"On Stream")
def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Stream")
def log_success_event(self, kwargs, response_obj, start_time, end_time): def log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}") print(f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}")
self.previous_models += len(kwargs["litellm_params"]["metadata"]["previous_models"]) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]} self.previous_models += len(kwargs["litellm_params"]["metadata"]["previous_models"]) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
print(f"self.previous_models: {self.previous_models}") print(f"self.previous_models: {self.previous_models}")
print(f"On Success") print(f"On Success")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}")
self.previous_models += len(kwargs["litellm_params"]["metadata"]["previous_models"]) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
print(f"self.previous_models: {self.previous_models}")
print(f"On Success")
def log_failure_event(self, kwargs, response_obj, start_time, end_time): def log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Failure") print(f"On Failure")
model_list = [
kwargs = {"model": "azure/gpt-3.5-turbo", "messages": [{"role": "user", "content":"Hey, how's it going?"}]}
def test_sync_fallbacks():
try:
model_list = [
{ # list of model deployments { # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name "model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call "litellm_params": { # params for litellm completion/embedding call
@ -87,14 +102,8 @@ model_list = [
"tpm": 1000000, "tpm": 1000000,
"rpm": 9000 "rpm": 9000
} }
] ]
kwargs = {"model": "azure/gpt-3.5-turbo", "messages": [{"role": "user", "content":"Hey, how's it going?"}]}
def test_sync_fallbacks():
try:
litellm.set_verbose = True litellm.set_verbose = True
customHandler = MyCustomHandler() customHandler = MyCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
@ -106,18 +115,74 @@ def test_sync_fallbacks():
print(f"response: {response}") print(f"response: {response}")
time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback assert customHandler.previous_models == 1 # 0 retries, 1 fallback
print("Passed ! Test router_fallbacks: test_sync_fallbacks()")
router.reset() router.reset()
except Exception as e: except Exception as e:
print(e) print(e)
# test_sync_fallbacks() # test_sync_fallbacks()
def test_async_fallbacks(): @pytest.mark.asyncio
async def test_async_fallbacks():
litellm.set_verbose = False litellm.set_verbose = False
model_list = [
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
]
router = Router(model_list=model_list, router = Router(model_list=model_list,
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}], fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}], context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
set_verbose=False) set_verbose=False)
async def test_get_response():
customHandler = MyCustomHandler() customHandler = MyCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
@ -125,7 +190,7 @@ def test_async_fallbacks():
try: try:
response = await router.acompletion(**kwargs) response = await router.acompletion(**kwargs)
print(f"customHandler.previous_models: {customHandler.previous_models}") print(f"customHandler.previous_models: {customHandler.previous_models}")
time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread await asyncio.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback assert customHandler.previous_models == 1 # 0 retries, 1 fallback
router.reset() router.reset()
except litellm.Timeout as e: except litellm.Timeout as e:
@ -134,34 +199,9 @@ def test_async_fallbacks():
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
finally: finally:
router.reset() router.reset()
asyncio.run(test_get_response())
# test_async_fallbacks() # test_async_fallbacks()
## COMMENTING OUT as the context size exceeds both gpt-3.5-turbo and gpt-3.5-turbo-16k, need a better message here
# def test_sync_context_window_fallbacks():
# try:
# customHandler = MyCustomHandler()
# litellm.callbacks = [customHandler]
# sample_text = "Say error 50 times" * 10000
# kwargs["model"] = "azure/gpt-3.5-turbo-context-fallback"
# kwargs["messages"] = [{"role": "user", "content": sample_text}]
# router = Router(model_list=model_list,
# fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
# context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
# set_verbose=False)
# response = router.completion(**kwargs)
# print(f"response: {response}")
# time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
# assert customHandler.previous_models == 1 # 0 retries, 1 fallback
# router.reset()
# except Exception as e:
# print(f"An exception occurred - {e}")
# finally:
# router.reset()
# test_sync_context_window_fallbacks()
def test_dynamic_fallbacks_sync(): def test_dynamic_fallbacks_sync():
""" """
Allow setting the fallback in the router.completion() call. Allow setting the fallback in the router.completion() call.
@ -169,6 +209,60 @@ def test_dynamic_fallbacks_sync():
try: try:
customHandler = MyCustomHandler() customHandler = MyCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
model_list = [
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
]
router = Router(model_list=model_list, set_verbose=True) router = Router(model_list=model_list, set_verbose=True)
kwargs = {} kwargs = {}
kwargs["model"] = "azure/gpt-3.5-turbo" kwargs["model"] = "azure/gpt-3.5-turbo"
@ -184,12 +278,71 @@ def test_dynamic_fallbacks_sync():
# test_dynamic_fallbacks_sync() # test_dynamic_fallbacks_sync()
def test_dynamic_fallbacks_async(): @pytest.mark.asyncio
async def test_dynamic_fallbacks_async():
""" """
Allow setting the fallback in the router.completion() call. Allow setting the fallback in the router.completion() call.
""" """
async def test_get_response():
try: try:
model_list = [
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
]
print()
print()
print()
print()
print(f"STARTING DYNAMIC ASYNC")
customHandler = MyCustomHandler() customHandler = MyCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
router = Router(model_list=model_list, set_verbose=True) router = Router(model_list=model_list, set_verbose=True)
@ -198,12 +351,10 @@ def test_dynamic_fallbacks_async():
kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}] kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}] kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}]
response = await router.acompletion(**kwargs) response = await router.acompletion(**kwargs)
print(f"response: {response}") print(f"RESPONSE: {response}")
time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread await asyncio.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback assert customHandler.previous_models == 1 # 0 retries, 1 fallback
router.reset() router.reset()
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred - {e}") pytest.fail(f"An exception occurred - {e}")
asyncio.run(test_get_response()) # asyncio.run(test_dynamic_fallbacks_async())
# test_dynamic_fallbacks_async()
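For reference, the fallback wiring all of these tests repeat, condensed (deployment list trimmed to the deployments that matter; the bad api_key on the primary is what forces the fallback):

import os
from litellm import Router

model_list = [
    {"model_name": "azure/gpt-3.5-turbo",        # primary deployment, deliberately broken
     "litellm_params": {"model": "azure/chatgpt-v-2", "api_key": "bad-key",
                        "api_base": os.getenv("AZURE_API_BASE"),
                        "api_version": os.getenv("AZURE_API_VERSION")}},
    {"model_name": "gpt-3.5-turbo",              # fallback target
     "litellm_params": {"model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY")}},
    {"model_name": "gpt-3.5-turbo-16k",          # larger-context fallback target
     "litellm_params": {"model": "gpt-3.5-turbo-16k", "api_key": os.getenv("OPENAI_API_KEY")}},
]

router = Router(
    model_list=model_list,
    # on errors, requests routed to azure/gpt-3.5-turbo retry on gpt-3.5-turbo
    fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
    # on context-window errors, retry on the 16k deployment
    context_window_fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
)

response = router.completion(
    model="azure/gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
# the custom handler's previous_models counter ends at 1: zero retries, one fallback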

Some files were not shown because too many files have changed in this diff.