forked from phoenix/litellm-mirror
Merge remote-tracking branch 'upstream/main' into patch-1
This commit is contained in:
commit
cb13018a28
121 changed files with 23783 additions and 1997 deletions
|
@ -29,12 +29,16 @@ jobs:
|
|||
pip install pytest-asyncio
|
||||
pip install mypy
|
||||
pip install -q google-generativeai
|
||||
pip install google-cloud-aiplatform
|
||||
pip install "boto3>=1.28.57"
|
||||
pip install appdirs
|
||||
pip install langchain
|
||||
pip install langfuse
|
||||
pip install numpydoc
|
||||
pip install traceloop-sdk==0.0.69
|
||||
pip install openai
|
||||
pip install prisma
|
||||
pip install langfuse
|
||||
- save_cache:
|
||||
paths:
|
||||
- ./venv
|
||||
|
@ -44,7 +48,7 @@ jobs:
|
|||
command: |
|
||||
cd litellm
|
||||
python -m pip install types-requests types-setuptools types-redis
|
||||
if ! python -m mypy . --ignore-missing-imports --explicit-package-bases; then
|
||||
if ! python -m mypy . --ignore-missing-imports; then
|
||||
echo "mypy detected errors"
|
||||
exit 1
|
||||
fi
|
||||
|
@ -57,7 +61,7 @@ jobs:
|
|||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -vv litellm/tests/ -x --junitxml=test-results/junit.xml --durations=5
|
||||
python -m pytest -vv litellm/tests/ -x --junitxml=test-results/junit.xml --durations=5
|
||||
no_output_timeout: 120m
|
||||
|
||||
# Store test results
|
||||
|
@ -74,6 +78,11 @@ jobs:
|
|||
|
||||
steps:
|
||||
- checkout
|
||||
|
||||
- run:
|
||||
name: Copy model_prices_and_context_window File to model_prices_and_context_window_backup
|
||||
command: |
|
||||
cp model_prices_and_context_window.json litellm/model_prices_and_context_window_backup.json
|
||||
|
||||
- run:
|
||||
name: Check if litellm dir was updated or if pyproject.toml was modified
|
||||
|
|
|
@ -10,4 +10,5 @@ anthropic
|
|||
boto3
|
||||
appdirs
|
||||
orjson
|
||||
pydantic
|
||||
pydantic
|
||||
google-cloud-aiplatform
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -19,3 +19,6 @@ litellm/proxy/_secret_config.yaml
|
|||
litellm/tests/aiologs.log
|
||||
litellm/tests/exception_data.txt
|
||||
litellm/tests/config_*.yaml
|
||||
litellm/tests/langfuse.log
|
||||
litellm/tests/test_custom_logger.py
|
||||
litellm/tests/langfuse.log
|
||||
|
|
|
@ -3,6 +3,15 @@ repos:
|
|||
rev: 3.8.4 # The version of flake8 to use
|
||||
hooks:
|
||||
- id: flake8
|
||||
exclude: ^litellm/tests/|^litellm/proxy/|^litellm/integrations/
|
||||
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/
|
||||
additional_dependencies: [flake8-print]
|
||||
files: litellm/.*\.py
|
||||
files: litellm/.*\.py
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy
|
||||
name: mypy
|
||||
entry: python3 -m mypy --ignore-missing-imports
|
||||
language: system
|
||||
types: [python]
|
||||
files: ^litellm/
|
||||
exclude: ^litellm/tests/
|
|
@ -34,6 +34,9 @@ COPY --from=builder /app/wheels /app/wheels
|
|||
|
||||
RUN pip install --no-index --find-links=/app/wheels -r requirements.txt
|
||||
|
||||
# Trigger the Prisma CLI to be installed
|
||||
RUN prisma -v
|
||||
|
||||
EXPOSE 4000/tcp
|
||||
|
||||
# Start the litellm proxy, using the `litellm` cli command https://docs.litellm.ai/docs/simple_proxy
|
||||
|
|
18
README.md
18
README.md
|
@ -5,7 +5,7 @@
|
|||
<p align="center">Call all LLM APIs using the OpenAI format [Bedrock, Huggingface, Cohere, TogetherAI, Azure, OpenAI, etc.]
|
||||
<br>
|
||||
</p>
|
||||
<h4 align="center"><a href="https://github.com/BerriAI/litellm/tree/main/litellm/proxy" target="_blank">OpenAI-Compatible Server</a></h4>
|
||||
<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">OpenAI Proxy Server</a></h4>
|
||||
<h4 align="center">
|
||||
<a href="https://pypi.org/project/litellm/" target="_blank">
|
||||
<img src="https://img.shields.io/pypi/v/litellm.svg" alt="PyPI Version">
|
||||
|
@ -62,6 +62,22 @@ response = completion(model="command-nightly", messages=messages)
|
|||
print(response)
|
||||
```
|
||||
|
||||
## Async ([Docs](https://docs.litellm.ai/docs/completion/stream#async-completion))
|
||||
|
||||
```python
|
||||
from litellm import acompletion
|
||||
import asyncio
|
||||
|
||||
async def test_get_response():
|
||||
user_message = "Hello, how are you?"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
response = await acompletion(model="gpt-3.5-turbo", messages=messages)
|
||||
return response
|
||||
|
||||
response = asyncio.run(test_get_response())
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Streaming ([Docs](https://docs.litellm.ai/docs/completion/stream))
|
||||
liteLLM supports streaming the model response back, pass `stream=True` to get a streaming iterator in response.
|
||||
Streaming is supported for all models (Bedrock, Huggingface, TogetherAI, Azure, OpenAI, etc.)
|
||||
|
|
BIN
dist/litellm-1.12.5.dev1-py3-none-any.whl
vendored
Normal file
BIN
dist/litellm-1.12.5.dev1-py3-none-any.whl
vendored
Normal file
Binary file not shown.
BIN
dist/litellm-1.12.5.dev1.tar.gz
vendored
Normal file
BIN
dist/litellm-1.12.5.dev1.tar.gz
vendored
Normal file
Binary file not shown.
12
docker-compose.example.yml
Normal file
12
docker-compose.example.yml
Normal file
|
@ -0,0 +1,12 @@
|
|||
version: "3.9"
|
||||
services:
|
||||
litellm:
|
||||
image: ghcr.io/berriai/litellm:main
|
||||
ports:
|
||||
- "8000:8000" # Map the container port to the host, change the host port if necessary
|
||||
volumes:
|
||||
- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
|
||||
# You can change the port or number of workers as per your requirements or pass any new supported CLI augument. Make sure the port passed here matches with the container port defined above in `ports` value
|
||||
command: [ "--config", "/app/config.yaml", "--port", "8000", "--num_workers", "8" ]
|
||||
|
||||
# ...rest of your docker-compose config if any
|
|
@ -1,11 +1,14 @@
|
|||
# Redis Cache
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/4d7ff1b33b9991dcf38d821266290631d9bcd2dd/litellm/caching.py#L71)
|
||||
|
||||
### Pre-requisites
|
||||
Install redis
|
||||
```
|
||||
pip install redis
|
||||
```
|
||||
For the hosted version you can setup your own Redis DB here: https://app.redislabs.com/
|
||||
### Usage
|
||||
### Quick Start
|
||||
```python
|
||||
import litellm
|
||||
from litellm import completion
|
||||
|
@ -55,6 +58,11 @@ litellm.cache = cache # set litellm.cache to your cache
|
|||
### Detecting Cached Responses
|
||||
For resposes that were returned as cache hit, the response includes a param `cache` = True
|
||||
|
||||
:::info
|
||||
|
||||
Only valid for OpenAI <= 0.28.1 [Let us know if you still need this](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=bug&projects=&template=bug_report.yml&title=%5BBug%5D%3A+)
|
||||
:::
|
||||
|
||||
Example response with cache hit
|
||||
```python
|
||||
{
|
||||
|
|
|
@ -6,7 +6,7 @@ import TabItem from '@theme/TabItem';
|
|||
## Common Params
|
||||
LiteLLM accepts and translates the [OpenAI Chat Completion params](https://platform.openai.com/docs/api-reference/chat/create) across all providers.
|
||||
|
||||
### usage
|
||||
### Usage
|
||||
```python
|
||||
import litellm
|
||||
|
||||
|
@ -23,7 +23,7 @@ response = litellm.completion(
|
|||
print(response)
|
||||
```
|
||||
|
||||
### translated OpenAI params
|
||||
### Translated OpenAI params
|
||||
This is a list of openai params we translate across providers.
|
||||
|
||||
This list is constantly being updated.
|
||||
|
@ -40,7 +40,7 @@ This list is constantly being updated.
|
|||
|AI21| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
|
||||
|VertexAI| ✅ | ✅ | | ✅ | | | | | | |
|
||||
|Bedrock| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
||||
|Sagemaker| ✅ | ✅ | | ✅ | | | | | | |
|
||||
|Sagemaker| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|
||||
|TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
||||
|AlephAlpha| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|
||||
|Palm| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
|
||||
|
@ -185,6 +185,25 @@ def completion(
|
|||
|
||||
- `metadata`: *dict (optional)* - Any additional data you want to be logged when the call is made (sent to logging integrations, eg. promptlayer and accessible via custom callback function)
|
||||
|
||||
**CUSTOM MODEL COST**
|
||||
- `input_cost_per_token`: *float (optional)* - The cost per input token for the completion call
|
||||
|
||||
- `output_cost_per_token`: *float (optional)* - The cost per output token for the completion call
|
||||
|
||||
**CUSTOM PROMPT TEMPLATE** (See [prompt formatting for more info](./prompt_formatting.md#format-prompt-yourself))
|
||||
- `initial_prompt_value`: *string (optional)* - Initial string applied at the start of the input messages
|
||||
|
||||
- `roles`: *dict (optional)* - Dictionary specifying how to format the prompt based on the role + message passed in via `messages`.
|
||||
|
||||
- `final_prompt_value`: *string (optional)* - Final string applied at the end of the input messages
|
||||
|
||||
- `bos_token`: *string (optional)* - Initial string applied at the start of a sequence
|
||||
|
||||
- `eos_token`: *string (optional)* - Initial string applied at the end of a sequence
|
||||
|
||||
- `hf_model_name`: *string (optional)* - [Sagemaker Only] The corresponding huggingface name of the model, used to pull the right chat template for the model.
|
||||
|
||||
|
||||
## Provider-specific Params
|
||||
Providers might offer params not supported by OpenAI (e.g. top_k). You can pass those in 2 ways:
|
||||
- via completion(): We'll pass the non-openai param, straight to the provider as part of the request body.
|
||||
|
|
|
@ -142,6 +142,8 @@ print(response)
|
|||
| Model Name | Function Call |
|
||||
|----------------------|---------------------------------------------|
|
||||
| Titan Embeddings - G1 | `embedding(model="amazon.titan-embed-text-v1", input=input)` |
|
||||
| Cohere Embeddings - English | `embedding(model="cohere.embed-english-v3", input=input)` |
|
||||
| Cohere Embeddings - Multilingual | `embedding(model="cohere.embed-multilingual-v3", input=input)` |
|
||||
|
||||
|
||||
## Cohere Embedding Models
|
||||
|
@ -182,6 +184,17 @@ response = embedding(
|
|||
input=["good morning from litellm"]
|
||||
)
|
||||
```
|
||||
### Usage - Custom API Base
|
||||
```python
|
||||
from litellm import embedding
|
||||
import os
|
||||
os.environ['HUGGINGFACE_API_KEY'] = ""
|
||||
response = embedding(
|
||||
model='huggingface/microsoft/codebert-base',
|
||||
input=["good morning from litellm"],
|
||||
api_base = "https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud"
|
||||
)
|
||||
```
|
||||
|
||||
| Model Name | Function Call | Required OS Variables |
|
||||
|-----------------------|--------------------------------------------------------------|-------------------------------------------------|
|
||||
|
|
|
@ -85,6 +85,43 @@ print(response)
|
|||
|
||||
```
|
||||
|
||||
## Async Callback Functions
|
||||
|
||||
LiteLLM currently supports just async success callback functions for async completion/embedding calls.
|
||||
|
||||
```python
|
||||
import asyncio, litellm
|
||||
|
||||
async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time):
|
||||
print(f"On Async Success!")
|
||||
|
||||
async def test_chat_openai():
|
||||
try:
|
||||
# litellm.set_verbose = True
|
||||
litellm.success_callback = [async_test_logging_fn]
|
||||
response = await litellm.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pytest.fail(f"An error occurred - {str(e)}")
|
||||
|
||||
asyncio.run(test_chat_openai())
|
||||
```
|
||||
|
||||
:::info
|
||||
|
||||
We're actively trying to expand this to other event types. [Tell us if you need this!](https://github.com/BerriAI/litellm/issues/1007)
|
||||
|
||||
|
||||
|
||||
:::
|
||||
|
||||
## What's in kwargs?
|
||||
|
||||
Notice we pass in a kwargs argument to custom callback.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# AWS Sagemaker
|
||||
LiteLLM supports Llama2 on Sagemaker
|
||||
LiteLLM supports All Sagemaker Huggingface Jumpstart Models
|
||||
|
||||
### API KEYS
|
||||
```python
|
||||
|
@ -42,6 +42,28 @@ response = completion(
|
|||
)
|
||||
```
|
||||
|
||||
### Applying Prompt Templates
|
||||
To apply the correct prompt template for your sagemaker deployment, pass in it's hf model name as well.
|
||||
|
||||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = ""
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
|
||||
os.environ["AWS_REGION_NAME"] = ""
|
||||
|
||||
response = completion(
|
||||
model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
max_tokens=80,
|
||||
hf_model_name="meta-llama/Llama-2-7b",
|
||||
)
|
||||
```
|
||||
|
||||
You can also pass in your own [custom prompt template](../completion/prompt_formatting.md#format-prompt-yourself)
|
||||
|
||||
### Usage - Streaming
|
||||
Sagemaker currently does not support streaming - LiteLLM fakes streaming by returning chunks of the response string
|
||||
|
||||
|
@ -64,14 +86,32 @@ for chunk in response:
|
|||
print(chunk)
|
||||
```
|
||||
|
||||
### AWS Sagemaker Models
|
||||
### Completion Models
|
||||
Here's an example of using a sagemaker model with LiteLLM
|
||||
|
||||
| Model Name | Function Call |
|
||||
|-------------------------------|-------------------------------------------------------------------------------------------|
|
||||
| Your Custom Huggingface Model | `completion(model='sagemaker/<your-deployment-name>', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']`
|
||||
| Meta Llama 2 7B | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Meta Llama 2 7B (Chat/Fine-tuned) | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Meta Llama 2 13B | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-13b', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Meta Llama 2 13B (Chat/Fine-tuned) | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-13b-f', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Meta Llama 2 70B | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-70b', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Meta Llama 2 70B (Chat/Fine-tuned) | `completion(model='sagemaker/jumpstart-dft-meta-textgeneration-llama-2-70b-b-f', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
|
||||
### Embedding Models
|
||||
|
||||
LiteLLM supports all Sagemaker Jumpstart Huggingface Embedding models. Here's how to call it:
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = ""
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
|
||||
os.environ["AWS_REGION_NAME"] = ""
|
||||
|
||||
response = litellm.embedding(model="sagemaker/<your-deployment-name>", input=["good morning from litellm", "this is another item"])
|
||||
print(f"response: {response}")
|
||||
```
|
||||
|
||||
|
||||
|
|
|
@ -174,3 +174,5 @@ print(response)
|
|||
| Model Name | Function Call |
|
||||
|----------------------|---------------------------------------------|
|
||||
| Titan Embeddings - G1 | `embedding(model="amazon.titan-embed-text-v1", input=input)` |
|
||||
| Cohere Embeddings - English | `embedding(model="cohere.embed-english-v3", input=input)` |
|
||||
| Cohere Embeddings - Multilingual | `embedding(model="cohere.embed-multilingual-v3", input=input)` |
|
||||
|
|
|
@ -49,8 +49,8 @@ Below are examples on how to call replicate LLMs using liteLLM
|
|||
|
||||
Model Name | Function Call | Required OS Variables |
|
||||
-----------------------------|----------------------------------------------------------------|--------------------------------------|
|
||||
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` |
|
||||
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` |
|
||||
replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
|
|
|
@ -11,18 +11,27 @@ model_list:
|
|||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
cache: # init cache
|
||||
type: redis # tell litellm to use redis caching
|
||||
cache: True # set cache responses to True, litellm defaults to using a redis cache
|
||||
```
|
||||
|
||||
#### Step 2: Add Redis Credentials to .env
|
||||
LiteLLM requires the following REDIS credentials in your env to enable caching
|
||||
Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching.
|
||||
|
||||
```shell
|
||||
REDIS_URL = "" # REDIS_URL='redis://username:password@hostname:port/database'
|
||||
## OR ##
|
||||
REDIS_HOST = "" # REDIS_HOST='redis-18841.c274.us-east-1-3.ec2.cloud.redislabs.com'
|
||||
REDIS_PORT = "" # REDIS_PORT='18841'
|
||||
REDIS_PASSWORD = "" # REDIS_PASSWORD='liteLlmIsAmazing'
|
||||
```
|
||||
|
||||
**Additional kwargs**
|
||||
You can pass in any additional redis.Redis arg, by storing the variable + value in your os environment, like this:
|
||||
```shell
|
||||
REDIS_<redis-kwarg-name> = ""
|
||||
```
|
||||
|
||||
[**See how it's read from the environment**](https://github.com/BerriAI/litellm/blob/4d7ff1b33b9991dcf38d821266290631d9bcd2dd/litellm/_redis.py#L40)
|
||||
#### Step 3: Run proxy with config
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
|
|
|
@ -1,91 +1,86 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Proxy Config.yaml
|
||||
Set model list, `api_base`, `api_key`, `temperature` & proxy server settings (`master-key`) on the config.yaml.
|
||||
|
||||
| Param Name | Description |
|
||||
|----------------------|---------------------------------------------------------------|
|
||||
| `model_list` | List of supported models on the server, with model-specific configs |
|
||||
| `litellm_settings` | litellm Module settings, example `litellm.drop_params=True`, `litellm.set_verbose=True`, `litellm.api_base`, `litellm.cache` |
|
||||
| `router_settings` | litellm Router settings, example `routing_strategy="least-busy"` [**see all**](https://github.com/BerriAI/litellm/blob/6ef0e8485e0e720c0efa6f3075ce8119f2f62eea/litellm/router.py#L64)|
|
||||
| `litellm_settings` | litellm Module settings, example `litellm.drop_params=True`, `litellm.set_verbose=True`, `litellm.api_base`, `litellm.cache` [**see all**](https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py)|
|
||||
| `general_settings` | Server settings, example setting `master_key: sk-my_special_key` |
|
||||
| `environment_variables` | Environment Variables example, `REDIS_HOST`, `REDIS_PORT` |
|
||||
|
||||
#### Example Config
|
||||
**Complete List:** Check the Swagger UI docs on `<your-proxy-url>/#/config.yaml` (e.g. http://0.0.0.0:8000/#/config.yaml), for everything you can pass in the config.yaml.
|
||||
|
||||
|
||||
## Quick Start
|
||||
|
||||
Set a model alias for your deployments.
|
||||
|
||||
In the `config.yaml` the model_name parameter is the user-facing name to use for your deployment.
|
||||
|
||||
In the config below requests with:
|
||||
- `model=vllm-models` will route to `openai/facebook/opt-125m`.
|
||||
- `model=gpt-3.5-turbo` will load balance between `azure/gpt-turbo-small-eu` and `azure/gpt-turbo-small-ca`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
- model_name: gpt-3.5-turbo # user-facing model alias
|
||||
litellm_params: # all params accepted by litellm.completion() - https://docs.litellm.ai/docs/completion/input
|
||||
model: azure/gpt-turbo-small-eu
|
||||
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||
api_key:
|
||||
api_key: "os.environ/AZURE_API_KEY_EU" # does os.getenv("AZURE_API_KEY_EU")
|
||||
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
|
||||
- model_name: bedrock-claude-v1
|
||||
litellm_params:
|
||||
model: bedrock/anthropic.claude-instant-v1
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-ca
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
|
||||
api_key:
|
||||
api_key: "os.environ/AZURE_API_KEY_CA"
|
||||
rpm: 6
|
||||
- model_name: gpt-3.5-turbo
|
||||
- model_name: vllm-models
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-large
|
||||
api_base: https://openai-france-1234.openai.azure.com/
|
||||
api_key:
|
||||
model: openai/facebook/opt-125m # the `openai/` prefix tells litellm it's openai compatible
|
||||
api_base: http://0.0.0.0:8000
|
||||
rpm: 1440
|
||||
model_info:
|
||||
version: 2
|
||||
|
||||
litellm_settings:
|
||||
litellm_settings: # module level litellm settings - https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py
|
||||
drop_params: True
|
||||
set_verbose: True
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
|
||||
|
||||
|
||||
environment_variables:
|
||||
OPENAI_API_KEY: sk-123
|
||||
REPLICATE_API_KEY: sk-cohere-is-okay
|
||||
REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
|
||||
REDIS_PORT: "16337"
|
||||
REDIS_PASSWORD:
|
||||
```
|
||||
|
||||
### Config for Multiple Models - GPT-4, Claude-2
|
||||
|
||||
Here's how you can use multiple llms with one proxy `config.yaml`.
|
||||
|
||||
#### Step 1: Setup Config
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: zephyr-alpha # the 1st model is the default on the proxy
|
||||
litellm_params: # params for litellm.completion() - https://docs.litellm.ai/docs/completion/input#input---request-body
|
||||
model: huggingface/HuggingFaceH4/zephyr-7b-alpha
|
||||
api_base: http://0.0.0.0:8001
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: gpt-4
|
||||
api_key: sk-1233
|
||||
- model_name: claude-2
|
||||
litellm_params:
|
||||
model: claude-2
|
||||
api_key: sk-claude
|
||||
```
|
||||
|
||||
:::info
|
||||
|
||||
The proxy uses the first model in the config as the default model - in this config the default model is `zephyr-alpha`
|
||||
:::
|
||||
|
||||
|
||||
#### Step 2: Start Proxy with config
|
||||
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
#### Step 3: Use proxy
|
||||
Curl Command
|
||||
|
||||
### Using Proxy - Curl Request, OpenAI Package, Langchain, Langchain JS
|
||||
Calling a model group
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="Curl" label="Curl Request">
|
||||
|
||||
Sends request to model where `model_name=gpt-3.5-turbo` on config.yaml.
|
||||
|
||||
If multiple with `model_name=gpt-3.5-turbo` does [Load Balancing](https://docs.litellm.ai/docs/proxy/load_balancing)
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "zephyr-alpha",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
|
@ -95,33 +90,109 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
|
|||
}
|
||||
'
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
### Config for Embedding Models - xorbitsai/inference
|
||||
<TabItem value="Curl2" label="Curl Request: Bedrock">
|
||||
|
||||
Here's how you can use multiple llms with one proxy `config.yaml`.
|
||||
Here is how [LiteLLM calls OpenAI Compatible Embedding models](https://docs.litellm.ai/docs/embedding/supported_embedding#openai-compatible-embedding-models)
|
||||
Sends this request to model where `model_name=bedrock-claude-v1` on config.yaml
|
||||
|
||||
#### Config
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: custom_embedding_model
|
||||
litellm_params:
|
||||
model: openai/custom_embedding # the `openai/` prefix tells litellm it's openai compatible
|
||||
api_base: http://0.0.0.0:8000/
|
||||
- model_name: custom_embedding_model
|
||||
litellm_params:
|
||||
model: openai/custom_embedding # the `openai/` prefix tells litellm it's openai compatible
|
||||
api_base: http://0.0.0.0:8001/
|
||||
```
|
||||
|
||||
Run the proxy using this config
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "bedrock-claude-v1",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}
|
||||
'
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="openai" label="OpenAI v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:8000"
|
||||
)
|
||||
|
||||
# Sends request to model where `model_name=gpt-3.5-turbo` on config.yaml.
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
|
||||
# Sends this request to model where `model_name=bedrock-claude-v1` on config.yaml
|
||||
response = client.chat.completions.create(model="bedrock-claude-v1", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
|
||||
```
|
||||
|
||||
### Save Model-specific params (API Base, API Keys, Temperature, Headers etc.)
|
||||
</TabItem>
|
||||
<TabItem value="langchain" label="Langchain Python">
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant that im using to make a test request to."
|
||||
),
|
||||
HumanMessage(
|
||||
content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||
),
|
||||
]
|
||||
|
||||
# Sends request to model where `model_name=gpt-3.5-turbo` on config.yaml.
|
||||
chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:8000", # set openai base to the proxy
|
||||
model = "gpt-3.5-turbo",
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
response = chat(messages)
|
||||
print(response)
|
||||
|
||||
# Sends request to model where `model_name=bedrock-claude-v1` on config.yaml.
|
||||
claude_chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:8000", # set openai base to the proxy
|
||||
model = "bedrock-claude-v1",
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
response = claude_chat(messages)
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
## Save Model-specific params (API Base, API Keys, Temperature, Headers etc.)
|
||||
You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc.
|
||||
|
||||
[**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1)
|
||||
|
||||
**Step 1**: Create a `config.yaml` file
|
||||
```yaml
|
||||
model_list:
|
||||
|
@ -152,9 +223,11 @@ model_list:
|
|||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
### Load API Keys from Vault
|
||||
## Load API Keys
|
||||
|
||||
If you have secrets saved in Azure Vault, etc. and don't want to expose them in the config.yaml, here's how to load model-specific keys from the environment.
|
||||
### Load API Keys from Environment
|
||||
|
||||
If you have secrets saved in your environment, and don't want to expose them in the config.yaml, here's how to load model-specific keys from the environment.
|
||||
|
||||
```python
|
||||
os.environ["AZURE_NORTH_AMERICA_API_KEY"] = "your-azure-api-key"
|
||||
|
@ -174,30 +247,42 @@ model_list:
|
|||
|
||||
s/o to [@David Manouchehri](https://www.linkedin.com/in/davidmanouchehri/) for helping with this.
|
||||
|
||||
### Config for setting Model Aliases
|
||||
### Load API Keys from Azure Vault
|
||||
|
||||
Set a model alias for your deployments.
|
||||
1. Install Proxy dependencies
|
||||
```bash
|
||||
$ pip install litellm[proxy] litellm[extra_proxy]
|
||||
```
|
||||
|
||||
In the `config.yaml` the model_name parameter is the user-facing name to use for your deployment.
|
||||
|
||||
In the config below requests with `model=gpt-4` will route to `ollama/llama2`
|
||||
2. Save Azure details in your environment
|
||||
```bash
|
||||
export["AZURE_CLIENT_ID"]="your-azure-app-client-id"
|
||||
export["AZURE_CLIENT_SECRET"]="your-azure-app-client-secret"
|
||||
export["AZURE_TENANT_ID"]="your-azure-tenant-id"
|
||||
export["AZURE_KEY_VAULT_URI"]="your-azure-key-vault-uri"
|
||||
```
|
||||
|
||||
3. Add to proxy config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: text-davinci-003
|
||||
litellm_params:
|
||||
model: ollama/zephyr
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: ollama/llama2
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: ollama/llama2
|
||||
model_list:
|
||||
- model_name: "my-azure-models" # model alias
|
||||
litellm_params:
|
||||
model: "azure/<your-deployment-name>"
|
||||
api_key: "os.environ/AZURE-API-KEY" # reads from key vault - get_secret("AZURE_API_KEY")
|
||||
api_base: "os.environ/AZURE-API-BASE" # reads from key vault - get_secret("AZURE_API_BASE")
|
||||
|
||||
general_settings:
|
||||
use_azure_key_vault: True
|
||||
```
|
||||
|
||||
You can now test this by starting your proxy:
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
### Set Custom Prompt Templates
|
||||
|
||||
LiteLLM by default checks if a model has a [prompt template and applies it](./completion/prompt_formatting.md) (e.g. if a huggingface model has a saved chat template in it's tokenizer_config.json). However, you can also set a custom prompt template on your proxy in the `config.yaml`:
|
||||
LiteLLM by default checks if a model has a [prompt template and applies it](../completion/prompt_formatting.md) (e.g. if a huggingface model has a saved chat template in it's tokenizer_config.json). However, you can also set a custom prompt template on your proxy in the `config.yaml`:
|
||||
|
||||
**Step 1**: Save your prompt template in a `config.yaml`
|
||||
```yaml
|
||||
|
@ -220,4 +305,41 @@ model_list:
|
|||
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
```
|
||||
|
||||
## Router Settings
|
||||
|
||||
Use this to configure things like routing strategy.
|
||||
|
||||
```yaml
|
||||
router_settings:
|
||||
routing_strategy: "least-busy"
|
||||
|
||||
model_list: # will route requests to the least busy ollama model
|
||||
- model_name: ollama-models
|
||||
litellm_params:
|
||||
model: "ollama/mistral"
|
||||
api_base: "http://127.0.0.1:8001"
|
||||
- model_name: ollama-models
|
||||
litellm_params:
|
||||
model: "ollama/codellama"
|
||||
api_base: "http://127.0.0.1:8002"
|
||||
- model_name: ollama-models
|
||||
litellm_params:
|
||||
model: "ollama/llama2"
|
||||
api_base: "http://127.0.0.1:8003"
|
||||
```
|
||||
|
||||
## Max Parallel Requests
|
||||
|
||||
To rate limit a user based on the number of parallel requests, e.g.:
|
||||
if user's parallel requests > x, send a 429 error
|
||||
if user's parallel requests <= x, let them use the API freely.
|
||||
|
||||
set the max parallel request limit on the config.yaml (note: this expects the user to be passing in an api key).
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
max_parallel_requests: 100 # max parallel requests for a user = 100
|
||||
```
|
||||
|
||||
|
|
|
@ -1,5 +1,87 @@
|
|||
# Deploying LiteLLM Proxy
|
||||
|
||||
### Deploy on Render https://render.com/
|
||||
## Quick Start Docker Image: Github Container Registry
|
||||
|
||||
### Pull the litellm ghcr docker image
|
||||
See the latest available ghcr docker image here:
|
||||
https://github.com/berriai/litellm/pkgs/container/litellm
|
||||
|
||||
```shell
|
||||
docker pull ghcr.io/berriai/litellm:main-v1.10.1
|
||||
```
|
||||
|
||||
### Run the Docker Image
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0
|
||||
```
|
||||
|
||||
#### Run the Docker Image with LiteLLM CLI args
|
||||
|
||||
See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli):
|
||||
|
||||
Here's how you can run the docker image and pass your config to `litellm`
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml
|
||||
```
|
||||
|
||||
Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8`
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8
|
||||
```
|
||||
|
||||
#### Run the Docker Image using docker compose
|
||||
|
||||
**Step 1**
|
||||
|
||||
- (Recommended) Use the example file `docker-compose.example.yml` given in the project root. e.g. https://github.com/BerriAI/litellm/blob/main/docker-compose.example.yml
|
||||
|
||||
- Rename the file `docker-compose.example.yml` to `docker-compose.yml`.
|
||||
|
||||
Here's an example `docker-compose.yml` file
|
||||
```yaml
|
||||
version: "3.9"
|
||||
services:
|
||||
litellm:
|
||||
image: ghcr.io/berriai/litellm:main
|
||||
ports:
|
||||
- "8000:8000" # Map the container port to the host, change the host port if necessary
|
||||
volumes:
|
||||
- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
|
||||
# You can change the port or number of workers as per your requirements or pass any new supported CLI augument. Make sure the port passed here matches with the container port defined above in `ports` value
|
||||
command: [ "--config", "/app/config.yaml", "--port", "8000", "--num_workers", "8" ]
|
||||
|
||||
# ...rest of your docker-compose config if any
|
||||
```
|
||||
|
||||
**Step 2**
|
||||
|
||||
Create a `litellm-config.yaml` file with your LiteLLM config relative to your `docker-compose.yml` file.
|
||||
|
||||
Check the config doc [here](https://docs.litellm.ai/docs/proxy/configs)
|
||||
|
||||
**Step 3**
|
||||
|
||||
Run the command `docker-compose up` or `docker compose up` as per your docker installation.
|
||||
|
||||
> Use `-d` flag to run the container in detached mode (background) e.g. `docker compose up -d`
|
||||
|
||||
|
||||
Your LiteLLM container should be running now on the defined port e.g. `8000`.
|
||||
|
||||
|
||||
## Deploy on Render https://render.com/
|
||||
|
||||
<iframe width="840" height="500" src="https://www.loom.com/embed/805964b3c8384b41be180a61442389a3" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
|
||||
|
||||
|
||||
## LiteLLM Proxy Performance
|
||||
|
||||
LiteLLM proxy has been load tested to handle 1500 req/s.
|
||||
|
||||
### Throughput - 30% Increase
|
||||
LiteLLM proxy + Load Balancer gives **30% increase** in throughput compared to Raw OpenAI API
|
||||
<Image img={require('../../img/throughput.png')} />
|
||||
|
||||
### Latency Added - 0.00325 seconds
|
||||
LiteLLM proxy adds **0.00325 seconds** latency as compared to using the Raw OpenAI API
|
||||
<Image img={require('../../img/latency.png')} />
|
||||
|
|
|
@ -3,38 +3,39 @@
|
|||
Load balance multiple instances of the same model
|
||||
|
||||
The proxy will handle routing requests (using LiteLLM's Router). **Set `rpm` in the config if you want maximize throughput**
|
||||
## Quick Start - Load Balancing
|
||||
### Step 1 - Set deployments on config
|
||||
|
||||
#### Example config
|
||||
requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo`
|
||||
**Example config below**. Here requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo`
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-eu
|
||||
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||
api_key:
|
||||
model: azure/<your-deployment-name>
|
||||
api_base: <your-azure-endpoint>
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-ca
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
|
||||
api_key:
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 6
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-large
|
||||
api_base: https://openai-france-1234.openai.azure.com/
|
||||
api_key:
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 1440
|
||||
```
|
||||
|
||||
#### Step 2: Start Proxy with config
|
||||
### Step 2: Start Proxy with config
|
||||
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
#### Step 3: Use proxy
|
||||
### Step 3: Use proxy - Call a model group [Load Balancing]
|
||||
Curl Command
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||
|
@ -51,7 +52,28 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
|
|||
'
|
||||
```
|
||||
|
||||
### Fallbacks + Cooldowns + Retries + Timeouts
|
||||
### Usage - Call a specific model deployment
|
||||
If you want to call a specific model defined in the `config.yaml`, you can call the `litellm_params: model`
|
||||
|
||||
In this example it will call `azure/gpt-turbo-small-ca`. Defined in the config on Step 1
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "azure/gpt-turbo-small-ca",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
|
||||
## Fallbacks + Cooldowns + Retries + Timeouts
|
||||
|
||||
If a call fails after num_retries, fall back to another model group.
|
||||
|
||||
|
@ -85,7 +107,7 @@ model_list:
|
|||
|
||||
litellm_settings:
|
||||
num_retries: 3 # retry call 3 times on each model_name (e.g. zephyr-beta)
|
||||
request_timeout: 10 # raise Timeout error if call takes longer than 10s
|
||||
request_timeout: 10 # raise Timeout error if call takes longer than 10s. Sets litellm.request_timeout
|
||||
fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo"]}] # fallback to gpt-3.5-turbo if call fails num_retries
|
||||
context_window_fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}] # fallback to gpt-3.5-turbo-16k if context window error
|
||||
allowed_fails: 3 # cooldown model if it fails > 1 call in a minute.
|
||||
|
@ -107,7 +129,71 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
|
|||
"fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
|
||||
"context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
|
||||
"num_retries": 2,
|
||||
"request_timeout": 10
|
||||
"timeout": 10
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
## Custom Timeouts, Stream Timeouts - Per Model
|
||||
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-eu
|
||||
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||
api_key: <your-key>
|
||||
timeout: 0.1 # timeout in (seconds)
|
||||
stream_timeout: 0.01 # timeout for stream requests (seconds)
|
||||
max_retries: 5
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-ca
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
|
||||
api_key:
|
||||
timeout: 0.1 # timeout in (seconds)
|
||||
stream_timeout: 0.01 # timeout for stream requests (seconds)
|
||||
max_retries: 5
|
||||
|
||||
```
|
||||
|
||||
#### Start Proxy
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Health Check LLMs on Proxy
|
||||
Use this to health check all LLMs defined in your config.yaml
|
||||
#### Request
|
||||
Make a GET Request to `/health` on the proxy
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/health'
|
||||
```
|
||||
|
||||
You can also run `litellm -health` it makes a `get` request to `http://0.0.0.0:8000/health` for you
|
||||
```
|
||||
litellm --health
|
||||
```
|
||||
#### Response
|
||||
```shell
|
||||
{
|
||||
"healthy_endpoints": [
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com/"
|
||||
},
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/"
|
||||
}
|
||||
],
|
||||
"unhealthy_endpoints": [
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://openai-france-1234.openai.azure.com/"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
|
@ -1,8 +1,340 @@
|
|||
# Logging - OpenTelemetry, Langfuse, ElasticSearch
|
||||
Log Proxy Input, Output, Exceptions to Langfuse, OpenTelemetry
|
||||
## OpenTelemetry, ElasticSearch
|
||||
# Logging - Custom Callbacks, OpenTelemetry, Langfuse, Sentry
|
||||
|
||||
### Step 1 Start OpenTelemetry Collecter Docker Container
|
||||
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry
|
||||
|
||||
## Custom Callback Class [Async]
|
||||
Use this when you want to run custom callbacks in `python`
|
||||
|
||||
### Step 1 - Create your custom `litellm` callback class
|
||||
We use `litellm.integrations.custom_logger` for this, **more details about litellm custom callbacks [here](https://docs.litellm.ai/docs/observability/custom_callback)**
|
||||
|
||||
Define your custom callback class in a python file.
|
||||
|
||||
Here's an example custom logger for tracking `key, user, model, prompt, response, tokens, cost`. We create a file called `custom_callbacks.py` and initialize `proxy_handler_instance`
|
||||
|
||||
```python
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import litellm
|
||||
|
||||
# This file includes the custom callbacks for LiteLLM Proxy
|
||||
# Once defined, these can be passed in proxy_config.yaml
|
||||
class MyCustomHandler(CustomLogger):
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
print(f"Pre-API Call")
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"Post-API Call")
|
||||
|
||||
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Stream")
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print("On Success")
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Failure")
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Success!")
|
||||
# log: key, user, model, prompt, response, tokens, cost
|
||||
# Access kwargs passed to litellm.completion()
|
||||
model = kwargs.get("model", None)
|
||||
messages = kwargs.get("messages", None)
|
||||
user = kwargs.get("user", None)
|
||||
|
||||
# Access litellm_params passed to litellm.completion(), example access `metadata`
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here
|
||||
|
||||
# Calculate cost using litellm.completion_cost()
|
||||
cost = litellm.completion_cost(completion_response=response_obj)
|
||||
response = response_obj
|
||||
# tokens used in response
|
||||
usage = response_obj["usage"]
|
||||
|
||||
print(
|
||||
f"""
|
||||
Model: {model},
|
||||
Messages: {messages},
|
||||
User: {user},
|
||||
Usage: {usage},
|
||||
Cost: {cost},
|
||||
Response: {response}
|
||||
Proxy Metadata: {metadata}
|
||||
"""
|
||||
)
|
||||
return
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
print(f"On Async Failure !")
|
||||
print("\nkwargs", kwargs)
|
||||
# Access kwargs passed to litellm.completion()
|
||||
model = kwargs.get("model", None)
|
||||
messages = kwargs.get("messages", None)
|
||||
user = kwargs.get("user", None)
|
||||
|
||||
# Access litellm_params passed to litellm.completion(), example access `metadata`
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here
|
||||
|
||||
# Acess Exceptions & Traceback
|
||||
exception_event = kwargs.get("exception", None)
|
||||
traceback_event = kwargs.get("traceback_exception", None)
|
||||
|
||||
# Calculate cost using litellm.completion_cost()
|
||||
cost = litellm.completion_cost(completion_response=response_obj)
|
||||
print("now checking response obj")
|
||||
|
||||
print(
|
||||
f"""
|
||||
Model: {model},
|
||||
Messages: {messages},
|
||||
User: {user},
|
||||
Cost: {cost},
|
||||
Response: {response_obj}
|
||||
Proxy Metadata: {metadata}
|
||||
Exception: {exception_event}
|
||||
Traceback: {traceback_event}
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}")
|
||||
|
||||
proxy_handler_instance = MyCustomHandler()
|
||||
|
||||
# Set litellm.callbacks = [proxy_handler_instance] on the proxy
|
||||
# need to set litellm.callbacks = [proxy_handler_instance] # on the proxy
|
||||
```
|
||||
|
||||
### Step 2 - Pass your custom callback class in `config.yaml`
|
||||
We pass the custom callback class defined in **Step1** to the config.yaml.
|
||||
Set `callbacks` to `python_filename.logger_instance_name`
|
||||
|
||||
In the config below, we pass
|
||||
- python_filename: `custom_callbacks.py`
|
||||
- logger_instance_name: `proxy_handler_instance`. This is defined in Step 1
|
||||
|
||||
`callbacks: custom_callbacks.proxy_handler_instance`
|
||||
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
|
||||
litellm_settings:
|
||||
callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
|
||||
|
||||
```
|
||||
|
||||
### Step 3 - Start proxy + test request
|
||||
```shell
|
||||
litellm --config proxy_config.yaml
|
||||
```
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "good morning good sir"
|
||||
}
|
||||
],
|
||||
"user": "ishaan-app",
|
||||
"temperature": 0.2
|
||||
}'
|
||||
```
|
||||
|
||||
#### Resulting Log on Proxy
|
||||
```shell
|
||||
On Success
|
||||
Model: gpt-3.5-turbo,
|
||||
Messages: [{'role': 'user', 'content': 'good morning good sir'}],
|
||||
User: ishaan-app,
|
||||
Usage: {'completion_tokens': 10, 'prompt_tokens': 11, 'total_tokens': 21},
|
||||
Cost: 3.65e-05,
|
||||
Response: {'id': 'chatcmpl-8S8avKJ1aVBg941y5xzGMSKrYCMvN', 'choices': [{'finish_reason': 'stop', 'index': 0, 'message': {'content': 'Good morning! How can I assist you today?', 'role': 'assistant'}}], 'created': 1701716913, 'model': 'gpt-3.5-turbo-0613', 'object': 'chat.completion', 'system_fingerprint': None, 'usage': {'completion_tokens': 10, 'prompt_tokens': 11, 'total_tokens': 21}}
|
||||
Proxy Metadata: {'user_api_key': None, 'headers': Headers({'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'authorization': 'Bearer sk-1234', 'content-length': '199', 'content-type': 'application/x-www-form-urlencoded'}), 'model_group': 'gpt-3.5-turbo', 'deployment': 'gpt-3.5-turbo-ModelID-gpt-3.5-turbo'}
|
||||
```
|
||||
|
||||
### Logging Proxy Request Object, Header, Url
|
||||
|
||||
Here's how you can access the `url`, `headers`, `request body` sent to the proxy for each request
|
||||
|
||||
```python
|
||||
class MyCustomHandler(CustomLogger):
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Success!")
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", None)
|
||||
proxy_server_request = litellm_params.get("proxy_server_request")
|
||||
print(proxy_server_request)
|
||||
```
|
||||
|
||||
**Expected Output**
|
||||
|
||||
```shell
|
||||
{
|
||||
"url": "http://testserver/chat/completions",
|
||||
"method": "POST",
|
||||
"headers": {
|
||||
"host": "testserver",
|
||||
"accept": "*/*",
|
||||
"accept-encoding": "gzip, deflate",
|
||||
"connection": "keep-alive",
|
||||
"user-agent": "testclient",
|
||||
"authorization": "Bearer None",
|
||||
"content-length": "105",
|
||||
"content-type": "application/json"
|
||||
},
|
||||
"body": {
|
||||
"model": "Azure OpenAI GPT-4 Canada",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
}
|
||||
],
|
||||
"max_tokens": 10
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
### Logging `model_info` set in config.yaml
|
||||
|
||||
Here is how to log the `model_info` set in your proxy `config.yaml`. Information on setting `model_info` on [config.yaml](https://docs.litellm.ai/docs/proxy/configs)
|
||||
|
||||
```python
|
||||
class MyCustomHandler(CustomLogger):
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Success!")
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", None)
|
||||
model_info = litellm_params.get("model_info")
|
||||
print(model_info)
|
||||
```
|
||||
|
||||
**Expected Output**
|
||||
```json
|
||||
{'mode': 'embedding', 'input_cost_per_token': 0.002}
|
||||
```
|
||||
|
||||
### Logging responses from proxy
|
||||
Both `/chat/completions` and `/embeddings` responses are available as `response_obj`
|
||||
|
||||
**Note: for `/chat/completions`, both `stream=True` and `non stream` responses are available as `response_obj`**
|
||||
|
||||
```python
|
||||
class MyCustomHandler(CustomLogger):
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Success!")
|
||||
print(response_obj)
|
||||
|
||||
```
|
||||
|
||||
**Expected Output /chat/completion [for both `stream` and `non-stream` responses]**
|
||||
```json
|
||||
ModelResponse(
|
||||
id='chatcmpl-8Tfu8GoMElwOZuj2JlHBhNHG01PPo',
|
||||
choices=[
|
||||
Choices(
|
||||
finish_reason='stop',
|
||||
index=0,
|
||||
message=Message(
|
||||
content='As an AI language model, I do not have a physical body and therefore do not possess any degree or educational qualifications. My knowledge and abilities come from the programming and algorithms that have been developed by my creators.',
|
||||
role='assistant'
|
||||
)
|
||||
)
|
||||
],
|
||||
created=1702083284,
|
||||
model='chatgpt-v-2',
|
||||
object='chat.completion',
|
||||
system_fingerprint=None,
|
||||
usage=Usage(
|
||||
completion_tokens=42,
|
||||
prompt_tokens=5,
|
||||
total_tokens=47
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
**Expected Output /embeddings**
|
||||
```json
|
||||
{
|
||||
'model': 'ada',
|
||||
'data': [
|
||||
{
|
||||
'embedding': [
|
||||
-0.035126980394124985, -0.020624293014407158, -0.015343423001468182,
|
||||
-0.03980357199907303, -0.02750781551003456, 0.02111034281551838,
|
||||
-0.022069307044148445, -0.019442008808255196, -0.00955679826438427,
|
||||
-0.013143060728907585, 0.029583381488919258, -0.004725852981209755,
|
||||
-0.015198921784758568, -0.014069183729588985, 0.00897879246622324,
|
||||
0.01521205808967352,
|
||||
# ... (truncated for brevity)
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## OpenTelemetry - Traceloop
|
||||
|
||||
Traceloop allows you to log LLM Input/Output in the OpenTelemetry format
|
||||
|
||||
We will use the `--config` to set `litellm.success_callback = ["traceloop"]` this will log all successfull LLM calls to traceloop
|
||||
|
||||
**Step 1** Install traceloop-sdk and set Traceloop API key
|
||||
|
||||
```shell
|
||||
pip install traceloop-sdk -U
|
||||
```
|
||||
|
||||
Traceloop outputs standard OpenTelemetry data that can be connected to your observability stack. Send standard OpenTelemetry from LiteLLM Proxy to [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop), [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace), [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog)
|
||||
, [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic), [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb), [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana), [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk), [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
|
||||
|
||||
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
litellm_settings:
|
||||
success_callback: ["traceloop"]
|
||||
```
|
||||
|
||||
**Step 3**: Start the proxy, make a test request
|
||||
|
||||
Start proxy
|
||||
```shell
|
||||
litellm --config config.yaml --debug
|
||||
```
|
||||
|
||||
Test Request
|
||||
```
|
||||
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
|
||||
<!-- ### Step 1 Start OpenTelemetry Collecter Docker Container
|
||||
This container sends logs to your selected destination
|
||||
|
||||
#### Install OpenTelemetry Collecter Docker Image
|
||||
|
@ -113,48 +445,13 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
|
|||
On successfull logging you should be able to see this log on your `OpenTelemetry Collecter` Docker Container
|
||||
```shell
|
||||
Events:
|
||||
SpanEvent #0
|
||||
-> Name: LiteLLM: Request Input
|
||||
-> Timestamp: 2023-12-02 05:05:53.71063 +0000 UTC
|
||||
-> DroppedAttributesCount: 0
|
||||
-> Attributes::
|
||||
-> type: Str(http)
|
||||
-> asgi: Str({'version': '3.0', 'spec_version': '2.3'})
|
||||
-> http_version: Str(1.1)
|
||||
-> server: Str(('127.0.0.1', 8000))
|
||||
-> client: Str(('127.0.0.1', 62796))
|
||||
-> scheme: Str(http)
|
||||
-> method: Str(POST)
|
||||
-> root_path: Str()
|
||||
-> path: Str(/chat/completions)
|
||||
-> raw_path: Str(b'/chat/completions')
|
||||
-> query_string: Str(b'')
|
||||
-> headers: Str([(b'host', b'0.0.0.0:8000'), (b'user-agent', b'curl/7.88.1'), (b'accept', b'*/*'), (b'authorization', b'Bearer sk-1244'), (b'content-length', b'147'), (b'content-type', b'application/x-www-form-urlencoded')])
|
||||
-> state: Str({})
|
||||
-> app: Str(<fastapi.applications.FastAPI object at 0x1253dd960>)
|
||||
-> fastapi_astack: Str(<contextlib.AsyncExitStack object at 0x127c8b7c0>)
|
||||
-> router: Str(<fastapi.routing.APIRouter object at 0x1253dda50>)
|
||||
-> endpoint: Str(<function chat_completion at 0x1254383a0>)
|
||||
-> path_params: Str({})
|
||||
-> route: Str(APIRoute(path='/chat/completions', name='chat_completion', methods=['POST']))
|
||||
SpanEvent #1
|
||||
-> Name: LiteLLM: Request Headers
|
||||
-> Timestamp: 2023-12-02 05:05:53.710652 +0000 UTC
|
||||
-> DroppedAttributesCount: 0
|
||||
-> Attributes::
|
||||
-> host: Str(0.0.0.0:8000)
|
||||
-> user-agent: Str(curl/7.88.1)
|
||||
-> accept: Str(*/*)
|
||||
-> authorization: Str(Bearer sk-1244)
|
||||
-> content-length: Str(147)
|
||||
-> content-type: Str(application/x-www-form-urlencoded)
|
||||
SpanEvent #2
|
||||
|
||||
```
|
||||
|
||||
### View Log on Elastic Search
|
||||
Here's the log view on Elastic Search. You can see the request `input`, `output` and `headers`
|
||||
|
||||
<Image img={require('../../img/elastic_otel.png')} />
|
||||
<Image img={require('../../img/elastic_otel.png')} /> -->
|
||||
|
||||
## Logging Proxy Input/Output - Langfuse
|
||||
We will use the `--config` to set `litellm.success_callback = ["langfuse"]` this will log all successfull LLM calls to langfuse
|
||||
|
@ -190,3 +487,41 @@ litellm --test
|
|||
Expected output on Langfuse
|
||||
|
||||
<Image img={require('../../img/langfuse_small.png')} />
|
||||
|
||||
## Logging Proxy Input/Output - Sentry
|
||||
|
||||
If api calls fail (llm/database) you can log those to Sentry:
|
||||
|
||||
**Step 1** Install Sentry
|
||||
```shell
|
||||
pip install --upgrade sentry-sdk
|
||||
```
|
||||
|
||||
**Step 2**: Save your Sentry_DSN and add `litellm_settings`: `failure_callback`
|
||||
```shell
|
||||
export SENTRY_DSN="your-sentry-dsn"
|
||||
```
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
litellm_settings:
|
||||
# other settings
|
||||
failure_callback: ["sentry"]
|
||||
general_settings:
|
||||
database_url: "my-bad-url" # set a fake url to trigger a sentry exception
|
||||
```
|
||||
|
||||
**Step 3**: Start the proxy, make a test request
|
||||
|
||||
Start proxy
|
||||
```shell
|
||||
litellm --config config.yaml --debug
|
||||
```
|
||||
|
||||
Test Request
|
||||
```
|
||||
litellm --test
|
||||
```
|
||||
|
|
74
docs/my-website/docs/proxy/model_management.md
Normal file
74
docs/my-website/docs/proxy/model_management.md
Normal file
|
@ -0,0 +1,74 @@
|
|||
# Model Management
|
||||
Add new models + Get model info without restarting proxy.
|
||||
|
||||
## Get Model Information
|
||||
|
||||
Retrieve detailed information about each model listed in the `/models` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled the model_info you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes.
|
||||
|
||||
<Tabs
|
||||
defaultValue="curl"
|
||||
values={[
|
||||
{ label: 'cURL', value: 'curl', },
|
||||
]}>
|
||||
<TabItem value="curl">
|
||||
|
||||
```bash
|
||||
curl -X GET "http://0.0.0.0:8000/model/info" \
|
||||
-H "accept: application/json" \
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Add a New Model
|
||||
|
||||
Add a new model to the list in the `config.yaml` by providing the model parameters. This allows you to update the model list without restarting the proxy.
|
||||
|
||||
<Tabs
|
||||
defaultValue="curl"
|
||||
values={[
|
||||
{ label: 'cURL', value: 'curl', },
|
||||
]}>
|
||||
<TabItem value="curl">
|
||||
|
||||
```bash
|
||||
curl -X POST "http://0.0.0.0:8000/model/new" \
|
||||
-H "accept: application/json" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{ "model_name": "azure-gpt-turbo", "litellm_params": {"model": "azure/gpt-3.5-turbo", "api_key": "os.environ/AZURE_API_KEY", "api_base": "my-azure-api-base"} }'
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
### Model Parameters Structure
|
||||
|
||||
When adding a new model, your JSON payload should conform to the following structure:
|
||||
|
||||
- `model_name`: The name of the new model (required).
|
||||
- `litellm_params`: A dictionary containing parameters specific to the Litellm setup (required).
|
||||
- `model_info`: An optional dictionary to provide additional information about the model.
|
||||
|
||||
Here's an example of how to structure your `ModelParams`:
|
||||
|
||||
```json
|
||||
{
|
||||
"model_name": "my_awesome_model",
|
||||
"litellm_params": {
|
||||
"some_parameter": "some_value",
|
||||
"another_parameter": "another_value"
|
||||
},
|
||||
"model_info": {
|
||||
"author": "Your Name",
|
||||
"version": "1.0",
|
||||
"description": "A brief description of the model."
|
||||
}
|
||||
}
|
||||
```
|
||||
---
|
||||
|
||||
Keep in mind that as both endpoints are in [BETA], you may need to visit the associated GitHub issues linked in the API descriptions to check for updates or provide feedback:
|
||||
|
||||
- Get Model Information: [Issue #933](https://github.com/BerriAI/litellm/issues/933)
|
||||
- Add a New Model: [Issue #964](https://github.com/BerriAI/litellm/issues/964)
|
||||
|
||||
Feedback on the beta endpoints is valuable and helps improve the API for all users.
|
|
@ -43,7 +43,7 @@ litellm --test
|
|||
|
||||
This will now automatically route any requests for gpt-3.5-turbo to bigcode starcoder, hosted on huggingface inference endpoints.
|
||||
|
||||
### Using LiteLLM Proxy - Curl Request, OpenAI Package
|
||||
### Using LiteLLM Proxy - Curl Request, OpenAI Package, Langchain, Langchain JS
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="Curl" label="Curl Request">
|
||||
|
@ -84,72 +84,75 @@ print(response)
|
|||
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="langchain" label="Langchain">
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:8000", # set openai_api_base to the LiteLLM Proxy
|
||||
model = "gpt-3.5-turbo",
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant that im using to make a test request to."
|
||||
),
|
||||
HumanMessage(
|
||||
content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||
),
|
||||
]
|
||||
response = chat(messages)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="langchain-embedding" label="Langchain Embeddings">
|
||||
|
||||
```python
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
embeddings = OpenAIEmbeddings(model="sagemaker-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
|
||||
|
||||
|
||||
text = "This is a test document."
|
||||
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
print(f"SAGEMAKER EMBEDDINGS")
|
||||
print(query_result[:5])
|
||||
|
||||
embeddings = OpenAIEmbeddings(model="bedrock-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
|
||||
|
||||
text = "This is a test document."
|
||||
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
print(f"BEDROCK EMBEDDINGS")
|
||||
print(query_result[:5])
|
||||
|
||||
embeddings = OpenAIEmbeddings(model="bedrock-titan-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
|
||||
|
||||
text = "This is a test document."
|
||||
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
print(f"TITAN EMBEDDINGS")
|
||||
print(query_result[:5])
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Quick Start - LiteLLM Proxy + Config.yaml
|
||||
The config allows you to create a model list and set `api_base`, `max_tokens` (all litellm params). See more details about the config [here](https://docs.litellm.ai/docs/proxy/configs)
|
||||
|
||||
### Create a Config for LiteLLM Proxy
|
||||
Example config
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/<your-deployment-name>
|
||||
api_base: <your-azure-api-endpoint>
|
||||
api_key: <your-azure-api-key>
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-ca
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
|
||||
api_key: <your-azure-api-key>
|
||||
```
|
||||
|
||||
### Run proxy with config
|
||||
|
||||
```shell
|
||||
litellm --config your_config.yaml
|
||||
```
|
||||
|
||||
## Quick Start Docker Image: Github Container Registry
|
||||
|
||||
### Pull the litellm ghcr docker image
|
||||
See the latest available ghcr docker image here:
|
||||
https://github.com/berriai/litellm/pkgs/container/litellm
|
||||
|
||||
```shell
|
||||
docker pull ghcr.io/berriai/litellm:main-v1.10.1
|
||||
```
|
||||
|
||||
### Run the Docker Image
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0
|
||||
```
|
||||
|
||||
#### Run the Docker Image with LiteLLM CLI args
|
||||
|
||||
See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli):
|
||||
|
||||
Here's how you can run the docker image and pass your config to `litellm`
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml
|
||||
```
|
||||
|
||||
Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8`
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8
|
||||
```
|
||||
|
||||
## Server Endpoints
|
||||
- POST `/chat/completions` - chat completions endpoint to call 100+ LLMs
|
||||
- POST `/completions` - completions endpoint
|
||||
- POST `/embeddings` - embedding endpoint for Azure, OpenAI, Huggingface endpoints
|
||||
- GET `/models` - available models on server
|
||||
- POST `/key/generate` - generate a key to access the proxy
|
||||
|
||||
## Supported LLMs
|
||||
### Supported LLMs
|
||||
All LiteLLM supported LLMs are supported on the Proxy. Seel all [supported llms](https://docs.litellm.ai/docs/providers)
|
||||
<Tabs>
|
||||
<TabItem value="bedrock" label="AWS Bedrock">
|
||||
|
@ -175,7 +178,7 @@ $ litellm --model azure/my-deployment-name
|
|||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="openai-proxy" label="OpenAI">
|
||||
<TabItem value="openai" label="OpenAI">
|
||||
|
||||
```shell
|
||||
$ export OPENAI_API_KEY=my-api-key
|
||||
|
@ -185,13 +188,23 @@ $ export OPENAI_API_KEY=my-api-key
|
|||
$ litellm --model gpt-3.5-turbo
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="openai-proxy" label="OpenAI Compatible Endpoint">
|
||||
|
||||
```shell
|
||||
$ export OPENAI_API_KEY=my-api-key
|
||||
```
|
||||
|
||||
```shell
|
||||
$ litellm --model openai/<your model name> --api_base <your-api-base> # e.g. http://0.0.0.0:3000
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="huggingface" label="Huggingface (TGI) Deployed">
|
||||
|
||||
```shell
|
||||
$ export HUGGINGFACE_API_KEY=my-api-key #[OPTIONAL]
|
||||
```
|
||||
```shell
|
||||
$ litellm --model huggingface/<your model name> --api_base https://k58ory32yinf1ly0.us-east-1.aws.endpoints.huggingface.cloud
|
||||
$ litellm --model huggingface/<your model name> --api_base <your-api-base> # e.g. http://0.0.0.0:3000
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
@ -301,6 +314,124 @@ $ litellm --model command-nightly
|
|||
</Tabs>
|
||||
|
||||
|
||||
|
||||
|
||||
## Quick Start - LiteLLM Proxy + Config.yaml
|
||||
The config allows you to create a model list and set `api_base`, `max_tokens` (all litellm params). See more details about the config [here](https://docs.litellm.ai/docs/proxy/configs)
|
||||
|
||||
### Create a Config for LiteLLM Proxy
|
||||
Example config
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo # user-facing model alias
|
||||
litellm_params: # all params accepted by litellm.completion() - https://docs.litellm.ai/docs/completion/input
|
||||
model: azure/<your-deployment-name>
|
||||
api_base: <your-azure-api-endpoint>
|
||||
api_key: <your-azure-api-key>
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-ca
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
|
||||
api_key: <your-azure-api-key>
|
||||
- model_name: vllm-model
|
||||
litellm_params:
|
||||
model: openai/<your-model-name>
|
||||
api_base: <your-api-base> # e.g. http://0.0.0.0:3000
|
||||
```
|
||||
|
||||
### Run proxy with config
|
||||
|
||||
```shell
|
||||
litellm --config your_config.yaml
|
||||
```
|
||||
|
||||
[**More Info**](./configs.md)
|
||||
|
||||
## Server Endpoints
|
||||
- POST `/chat/completions` - chat completions endpoint to call 100+ LLMs
|
||||
- POST `/completions` - completions endpoint
|
||||
- POST `/embeddings` - embedding endpoint for Azure, OpenAI, Huggingface endpoints
|
||||
- GET `/models` - available models on server
|
||||
- POST `/key/generate` - generate a key to access the proxy
|
||||
|
||||
## Gunicorn + Proxy
|
||||
|
||||
Command:
|
||||
```python
|
||||
cmd = f"gunicorn litellm.proxy.proxy_server:app --workers {num_workers} --worker-class uvicorn.workers.UvicornWorker --bind {host}:{port}"
|
||||
```
|
||||
|
||||
[**Code**](https://github.com/BerriAI/litellm/blob/077f6b1298101079b72396bdf04f8ca0cf737720/litellm/tests/test_proxy_gunicorn.py#L4)
|
||||
## Quick Start Docker Image: Github Container Registry
|
||||
|
||||
### Pull the litellm ghcr docker image
|
||||
See the latest available ghcr docker image here:
|
||||
https://github.com/berriai/litellm/pkgs/container/litellm
|
||||
|
||||
```shell
|
||||
docker pull ghcr.io/berriai/litellm:main-v1.10.1
|
||||
```
|
||||
|
||||
### Run the Docker Image
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0
|
||||
```
|
||||
|
||||
#### Run the Docker Image with LiteLLM CLI args
|
||||
|
||||
See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli):
|
||||
|
||||
Here's how you can run the docker image and pass your config to `litellm`
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml
|
||||
```
|
||||
|
||||
Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8`
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8
|
||||
```
|
||||
|
||||
#### Run the Docker Image using docker compose
|
||||
|
||||
**Step 1**
|
||||
|
||||
- (Recommended) Use the example file `docker-compose.example.yml` given in the project root. e.g. https://github.com/BerriAI/litellm/blob/main/docker-compose.example.yml
|
||||
|
||||
- Rename the file `docker-compose.example.yml` to `docker-compose.yml`.
|
||||
|
||||
Here's an example `docker-compose.yml` file
|
||||
```yaml
|
||||
version: "3.9"
|
||||
services:
|
||||
litellm:
|
||||
image: ghcr.io/berriai/litellm:main
|
||||
ports:
|
||||
- "8000:8000" # Map the container port to the host, change the host port if necessary
|
||||
volumes:
|
||||
- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
|
||||
# You can change the port or number of workers as per your requirements or pass any new supported CLI augument. Make sure the port passed here matches with the container port defined above in `ports` value
|
||||
command: [ "--config", "/app/config.yaml", "--port", "8000", "--num_workers", "8" ]
|
||||
|
||||
# ...rest of your docker-compose config if any
|
||||
```
|
||||
|
||||
**Step 2**
|
||||
|
||||
Create a `litellm-config.yaml` file with your LiteLLM config relative to your `docker-compose.yml` file.
|
||||
|
||||
Check the config doc [here](https://docs.litellm.ai/docs/proxy/configs)
|
||||
|
||||
**Step 3**
|
||||
|
||||
Run the command `docker-compose up` or `docker compose up` as per your docker installation.
|
||||
|
||||
> Use `-d` flag to run the container in detached mode (background) e.g. `docker compose up -d`
|
||||
|
||||
|
||||
Your LiteLLM container should be running now on the defined port e.g. `8000`.
|
||||
|
||||
|
||||
## Using with OpenAI compatible projects
|
||||
Set `base_url` to the LiteLLM Proxy server
|
||||
|
||||
|
@ -473,37 +604,4 @@ curl -X POST \
|
|||
https://api.openai.com/v1/chat/completions \
|
||||
-H 'content-type: application/json' -H 'Authorization: Bearer sk-qnWGUIW9****************************************' \
|
||||
-d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "this is a test request, write a short poem"}]}'
|
||||
```
|
||||
|
||||
## Health Check LLMs on Proxy
|
||||
Use this to health check all LLMs defined in your config.yaml
|
||||
#### Request
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/health'
|
||||
```
|
||||
|
||||
You can also run `litellm -health` it makes a `get` request to `http://0.0.0.0:8000/health` for you
|
||||
```
|
||||
litellm --health
|
||||
```
|
||||
#### Response
|
||||
```shell
|
||||
{
|
||||
"healthy_endpoints": [
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com/"
|
||||
},
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/"
|
||||
}
|
||||
],
|
||||
"unhealthy_endpoints": [
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://openai-france-1234.openai.azure.com/"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
|
@ -1,9 +1,10 @@
|
|||
|
||||
# Cost Tracking & Virtual Keys
|
||||
# Key Management
|
||||
Track Spend and create virtual keys for the proxy
|
||||
|
||||
Grant other's temporary access to your proxy, with keys that expire after a set duration.
|
||||
|
||||
## Quick Start
|
||||
|
||||
Requirements:
|
||||
|
||||
- Need to a postgres database (e.g. [Supabase](https://supabase.com/), [Neon](https://neon.tech/), etc)
|
||||
|
@ -55,7 +56,7 @@ Expected response:
|
|||
}
|
||||
```
|
||||
|
||||
### Managing Auth - Upgrade/Downgrade Models
|
||||
## Managing Auth - Upgrade/Downgrade Models
|
||||
|
||||
If a user is expected to use a given model (i.e. gpt3-5), and you want to:
|
||||
|
||||
|
@ -104,7 +105,7 @@ curl -X POST "https://0.0.0.0:8000/key/generate" \
|
|||
- **How to upgrade / downgrade request?** Change the alias mapping
|
||||
- **How are routing between diff keys/api bases done?** litellm handles this by shuffling between different models in the model list with the same model_name. [**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
|
||||
|
||||
### Managing Auth - Tracking Spend
|
||||
## Managing Auth - Tracking Spend
|
||||
|
||||
You can get spend for a key by using the `/key/info` endpoint.
|
||||
|
||||
|
@ -136,4 +137,54 @@ This is automatically updated (in USD) when calls are made to /completions, /cha
|
|||
"config": {}
|
||||
}
|
||||
}
|
||||
```
|
||||
```
|
||||
|
||||
## Custom Auth
|
||||
|
||||
You can now override the default api key auth.
|
||||
|
||||
Here's how:
|
||||
|
||||
### 1. Create a custom auth file.
|
||||
|
||||
Make sure the response type follows the `UserAPIKeyAuth` pydantic object. This is used by for logging usage specific to that user key.
|
||||
|
||||
```python
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
modified_master_key = "sk-my-master-key"
|
||||
if api_key == modified_master_key:
|
||||
return UserAPIKeyAuth(api_key=api_key)
|
||||
raise Exception
|
||||
except:
|
||||
raise Exception
|
||||
```
|
||||
|
||||
### 2. Pass the filepath (relative to the config.yaml)
|
||||
|
||||
Pass the filepath to the config.yaml
|
||||
|
||||
e.g. if they're both in the same dir - `./config.yaml` and `./custom_auth.py`, this is what it looks like:
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "openai-model"
|
||||
litellm_params:
|
||||
model: "gpt-3.5-turbo"
|
||||
|
||||
litellm_settings:
|
||||
drop_params: True
|
||||
set_verbose: True
|
||||
|
||||
general_settings:
|
||||
custom_auth: custom_auth.user_api_key_auth
|
||||
```
|
||||
|
||||
[**Implementation Code**](https://github.com/BerriAI/litellm/blob/caf2a6b279ddbe89ebd1d8f4499f65715d684851/litellm/proxy/utils.py#L122)
|
||||
|
||||
### 3. Start the proxy
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
|
|
|
@ -356,6 +356,16 @@ router = Router(model_list=model_list,
|
|||
|
||||
print(response)
|
||||
```
|
||||
|
||||
**Pass in Redis URL, additional kwargs**
|
||||
```python
|
||||
router = Router(model_list: Optional[list] = None,
|
||||
## CACHING ##
|
||||
redis_url=os.getenv("REDIS_URL")",
|
||||
cache_kwargs= {}, # additional kwargs to pass to RedisCache (see caching.py)
|
||||
cache_responses=True)
|
||||
```
|
||||
|
||||
#### Default litellm.completion/embedding params
|
||||
|
||||
You can also set default params for litellm completion/embedding calls. Here's how to do that:
|
||||
|
|
|
@ -68,7 +68,7 @@ You can now test this by starting your proxy:
|
|||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
[Quick Test Proxy](./simple_proxy.md#using-litellm-proxy---curl-request-openai-package)
|
||||
[Quick Test Proxy](./proxy/quick_start#using-litellm-proxy---curl-request-openai-package-langchain-langchain-js)
|
||||
|
||||
## Infisical Secret Manager
|
||||
Integrates with [Infisical's Secret Manager](https://infisical.com/) for secure storage and retrieval of API keys and sensitive data.
|
||||
|
|
8668
docs/my-website/package-lock.json
generated
8668
docs/my-website/package-lock.json
generated
File diff suppressed because it is too large
Load diff
|
@ -20,6 +20,7 @@
|
|||
"@docusaurus/preset-classic": "2.4.1",
|
||||
"@mdx-js/react": "^1.6.22",
|
||||
"clsx": "^1.2.1",
|
||||
"docusaurus": "^1.14.7",
|
||||
"docusaurus-lunr-search": "^2.4.1",
|
||||
"prism-react-renderer": "^1.3.5",
|
||||
"react": "^17.0.2",
|
||||
|
|
|
@ -99,6 +99,7 @@ const sidebars = {
|
|||
"proxy/configs",
|
||||
"proxy/load_balancing",
|
||||
"proxy/virtual_keys",
|
||||
"proxy/model_management",
|
||||
"proxy/caching",
|
||||
"proxy/logging",
|
||||
"proxy/cli",
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -2,15 +2,18 @@
|
|||
import threading, requests
|
||||
from typing import Callable, List, Optional, Dict, Union, Any
|
||||
from litellm.caching import Cache
|
||||
from litellm._logging import set_verbose
|
||||
import httpx
|
||||
|
||||
input_callback: List[Union[str, Callable]] = []
|
||||
success_callback: List[Union[str, Callable]] = []
|
||||
failure_callback: List[Union[str, Callable]] = []
|
||||
callbacks: List[Callable] = []
|
||||
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
||||
pre_call_rules: List[Callable] = []
|
||||
post_call_rules: List[Callable] = []
|
||||
set_verbose = False
|
||||
email: Optional[
|
||||
str
|
||||
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
|
@ -43,6 +46,7 @@ caching: bool = False # Not used anymore, will be removed in next MAJOR release
|
|||
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
||||
model_alias_map: Dict[str, str] = {}
|
||||
model_group_alias_map: Dict[str, str] = {}
|
||||
max_budget: float = 0.0 # set the max budget across all providers
|
||||
_current_cost = 0 # private variable, used if max budget is set
|
||||
error_logs: Dict = {}
|
||||
|
@ -338,7 +342,7 @@ cohere_embedding_models: List = [
|
|||
"embed-english-light-v2.0",
|
||||
"embed-multilingual-v2.0",
|
||||
]
|
||||
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1"]
|
||||
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
|
||||
|
||||
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
|
||||
|
||||
|
|
8
litellm/_logging.py
Normal file
8
litellm/_logging.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
set_verbose = False
|
||||
|
||||
def print_verbose(print_statement):
|
||||
try:
|
||||
if set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except:
|
||||
pass
|
93
litellm/_redis.py
Normal file
93
litellm/_redis.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | Give Feedback / Get Help |
|
||||
# | https://github.com/BerriAI/litellm/issues/new |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||
import os
|
||||
import inspect
|
||||
import redis, litellm
|
||||
from typing import List, Optional
|
||||
|
||||
def _get_redis_kwargs():
|
||||
arg_spec = inspect.getfullargspec(redis.Redis)
|
||||
|
||||
# Only allow primitive arguments
|
||||
exclude_args = {
|
||||
"self",
|
||||
"connection_pool",
|
||||
"retry",
|
||||
}
|
||||
|
||||
|
||||
include_args = [
|
||||
"url"
|
||||
]
|
||||
|
||||
available_args = [
|
||||
x for x in arg_spec.args if x not in exclude_args
|
||||
] + include_args
|
||||
|
||||
return available_args
|
||||
|
||||
def _get_redis_env_kwarg_mapping():
|
||||
PREFIX = "REDIS_"
|
||||
|
||||
return {
|
||||
f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()
|
||||
}
|
||||
|
||||
|
||||
def _redis_kwargs_from_environment():
|
||||
mapping = _get_redis_env_kwarg_mapping()
|
||||
|
||||
return_dict = {}
|
||||
for k, v in mapping.items():
|
||||
value = litellm.get_secret(k, default_value=None) # check os.environ/key vault
|
||||
if value is not None:
|
||||
return_dict[v] = value
|
||||
return return_dict
|
||||
|
||||
|
||||
def get_redis_url_from_environment():
|
||||
if "REDIS_URL" in os.environ:
|
||||
return os.environ["REDIS_URL"]
|
||||
|
||||
if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
|
||||
raise ValueError("Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis.")
|
||||
|
||||
if "REDIS_PASSWORD" in os.environ:
|
||||
redis_password = f":{os.environ['REDIS_PASSWORD']}@"
|
||||
else:
|
||||
redis_password = ""
|
||||
|
||||
return f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
|
||||
|
||||
def get_redis_client(**env_overrides):
|
||||
### check if "os.environ/<key-name>" passed in
|
||||
for k, v in env_overrides.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
v = v.replace("os.environ/", "")
|
||||
value = litellm.get_secret(v)
|
||||
env_overrides[k] = value
|
||||
|
||||
redis_kwargs = {
|
||||
**_redis_kwargs_from_environment(),
|
||||
**env_overrides,
|
||||
}
|
||||
|
||||
if "url" in redis_kwargs and redis_kwargs['url'] is not None:
|
||||
redis_kwargs.pop("host", None)
|
||||
redis_kwargs.pop("port", None)
|
||||
redis_kwargs.pop("db", None)
|
||||
redis_kwargs.pop("password", None)
|
||||
|
||||
return redis.Redis.from_url(**redis_kwargs)
|
||||
elif "host" not in redis_kwargs or redis_kwargs['host'] is None:
|
||||
raise ValueError("Either 'host' or 'url' must be specified for redis.")
|
||||
litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
|
||||
return redis.Redis(**redis_kwargs)
|
|
@ -13,9 +13,12 @@ class BudgetManager:
|
|||
self.load_data()
|
||||
|
||||
def print_verbose(self, print_statement):
|
||||
if litellm.set_verbose:
|
||||
import logging
|
||||
logging.info(print_statement)
|
||||
try:
|
||||
if litellm.set_verbose:
|
||||
import logging
|
||||
logging.info(print_statement)
|
||||
except:
|
||||
pass
|
||||
|
||||
def load_data(self):
|
||||
if self.client_type == "local":
|
||||
|
|
|
@ -25,8 +25,11 @@ def get_prompt(*args, **kwargs):
|
|||
return None
|
||||
|
||||
def print_verbose(print_statement):
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
try:
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except:
|
||||
pass
|
||||
|
||||
class BaseCache:
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
|
@ -58,8 +61,6 @@ class InMemoryCache(BaseCache):
|
|||
cached_response = json.loads(original_cached_response)
|
||||
except:
|
||||
cached_response = original_cached_response
|
||||
if isinstance(cached_response, dict):
|
||||
cached_response['cache'] = True # set cache-hit flag to True
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
|
@ -69,13 +70,26 @@ class InMemoryCache(BaseCache):
|
|||
|
||||
|
||||
class RedisCache(BaseCache):
|
||||
def __init__(self, host, port, password):
|
||||
def __init__(self, host=None, port=None, password=None, **kwargs):
|
||||
import redis
|
||||
# if users don't provider one, use the default litellm cache
|
||||
self.redis_client = redis.Redis(host=host, port=port, password=password)
|
||||
from ._redis import get_redis_client
|
||||
|
||||
redis_kwargs = {}
|
||||
if host is not None:
|
||||
redis_kwargs["host"] = host
|
||||
if port is not None:
|
||||
redis_kwargs["port"] = port
|
||||
if password is not None:
|
||||
redis_kwargs["password"] = password
|
||||
|
||||
redis_kwargs.update(kwargs)
|
||||
|
||||
self.redis_client = get_redis_client(**redis_kwargs)
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
ttl = kwargs.get("ttl", None)
|
||||
print_verbose(f"Set Redis Cache: key: {key}\nValue {value}")
|
||||
try:
|
||||
self.redis_client.set(name=key, value=str(value), ex=ttl)
|
||||
except Exception as e:
|
||||
|
@ -84,8 +98,9 @@ class RedisCache(BaseCache):
|
|||
|
||||
def get_cache(self, key, **kwargs):
|
||||
try:
|
||||
# TODO convert this to a ModelResponse object
|
||||
print_verbose(f"Get Redis Cache: key: {key}")
|
||||
cached_response = self.redis_client.get(key)
|
||||
print_verbose(f"Got Redis Cache: key: {key}, cached_response {cached_response}")
|
||||
if cached_response != None:
|
||||
# cached_response is in `b{} convert it to ModelResponse
|
||||
cached_response = cached_response.decode("utf-8") # Convert bytes to string
|
||||
|
@ -93,8 +108,6 @@ class RedisCache(BaseCache):
|
|||
cached_response = json.loads(cached_response) # Convert string to dictionary
|
||||
except:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
if isinstance(cached_response, dict):
|
||||
cached_response['cache'] = True # set cache-hit flag to True
|
||||
return cached_response
|
||||
except Exception as e:
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
|
@ -168,7 +181,8 @@ class Cache:
|
|||
type="local",
|
||||
host=None,
|
||||
port=None,
|
||||
password=None
|
||||
password=None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initializes the cache based on the given type.
|
||||
|
@ -178,6 +192,7 @@ class Cache:
|
|||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid cache type is provided.
|
||||
|
@ -186,13 +201,15 @@ class Cache:
|
|||
None
|
||||
"""
|
||||
if type == "redis":
|
||||
self.cache = RedisCache(host, port, password)
|
||||
self.cache = RedisCache(host, port, password, **kwargs)
|
||||
if type == "local":
|
||||
self.cache = InMemoryCache()
|
||||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
litellm.success_callback.append("cache")
|
||||
if "cache" not in litellm._async_success_callback:
|
||||
litellm._async_success_callback.append("cache")
|
||||
|
||||
def get_cache_key(self, *args, **kwargs):
|
||||
"""
|
||||
|
@ -205,16 +222,37 @@ class Cache:
|
|||
Returns:
|
||||
str: The cache key generated from the arguments, or None if no cache key could be generated.
|
||||
"""
|
||||
cache_key =""
|
||||
for param in kwargs:
|
||||
cache_key = ""
|
||||
print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
|
||||
|
||||
# for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
|
||||
if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
|
||||
print_verbose(f"\nReturning preset cache key: {cache_key}")
|
||||
return kwargs.get("litellm_params", {}).get("preset_cache_key", None)
|
||||
|
||||
# sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
|
||||
completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]
|
||||
for param in completion_kwargs:
|
||||
# ignore litellm params here
|
||||
if param in set(["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]):
|
||||
if param in kwargs:
|
||||
# check if param == model and model_group is passed in, then override model with model_group
|
||||
if param == "model" and kwargs.get("metadata", None) is not None and kwargs["metadata"].get("model_group", None) is not None:
|
||||
param_value = kwargs["metadata"].get("model_group", None) # for litellm.Router use model_group for caching over `model`
|
||||
if param == "model":
|
||||
model_group = None
|
||||
metadata = kwargs.get("metadata", None)
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
if metadata is not None:
|
||||
model_group = metadata.get("model_group")
|
||||
if litellm_params is not None:
|
||||
metadata = litellm_params.get("metadata", None)
|
||||
if metadata is not None:
|
||||
model_group = metadata.get("model_group", None)
|
||||
param_value = model_group or kwargs[param] # use model_group if it exists, else use kwargs["model"]
|
||||
else:
|
||||
if kwargs[param] is None:
|
||||
continue # ignore None params
|
||||
param_value = kwargs[param]
|
||||
cache_key+= f"{str(param)}: {str(param_value)}"
|
||||
print_verbose(f"\nCreated cache key: {cache_key}")
|
||||
return cache_key
|
||||
|
||||
def generate_streaming_content(self, content):
|
||||
|
@ -241,9 +279,6 @@ class Cache:
|
|||
cache_key = self.get_cache_key(*args, **kwargs)
|
||||
if cache_key is not None:
|
||||
cached_result = self.cache.get_cache(cache_key)
|
||||
if cached_result != None and 'stream' in kwargs and kwargs['stream'] == True:
|
||||
# if streaming is true and we got a cache hit, return a generator
|
||||
return self.generate_streaming_content(cached_result["choices"][0]['message']['content'])
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logging.debug(f"An exception occurred: {traceback.format_exc()}")
|
||||
|
|
|
@ -8,7 +8,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv
|
|||
import traceback
|
||||
|
||||
|
||||
class CustomLogger:
|
||||
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
pass
|
||||
|
@ -27,9 +27,20 @@ class CustomLogger:
|
|||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
pass
|
||||
|
||||
#### DEPRECATED ####
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
|
||||
|
||||
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
|
||||
try:
|
||||
|
@ -46,6 +57,22 @@ class CustomLogger:
|
|||
traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
|
||||
async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["log_event_type"] = "pre_api_call"
|
||||
await callback_func(
|
||||
kwargs,
|
||||
)
|
||||
print_verbose(
|
||||
f"Custom Logger - model call details: {kwargs}"
|
||||
)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
|
||||
# Method definition
|
||||
try:
|
||||
|
@ -63,3 +90,21 @@ class CustomLogger:
|
|||
# traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
||||
async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
|
||||
# Method definition
|
||||
try:
|
||||
kwargs["log_event_type"] = "post_api_call"
|
||||
await callback_func(
|
||||
kwargs, # kwargs to func
|
||||
response_obj,
|
||||
start_time,
|
||||
end_time,
|
||||
)
|
||||
print_verbose(
|
||||
f"Custom Logger - final response object: {response_obj}"
|
||||
)
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
pass
|
|
@ -37,7 +37,7 @@ class LangFuseLogger:
|
|||
f"Langfuse Logging - Enters logging function for model {kwargs}"
|
||||
)
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = litellm_params.get("metadata", {})
|
||||
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
|
||||
prompt = [kwargs.get('messages')]
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
|
||||
|
@ -70,6 +70,6 @@ class LangFuseLogger:
|
|||
f"Langfuse Layer Logging - final response object: {response_obj}"
|
||||
)
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
traceback.print_exc()
|
||||
print_verbose(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -196,8 +196,19 @@ class AzureChatCompletion(BaseLLM):
|
|||
else:
|
||||
azure_client = client
|
||||
response = azure_client.chat.completions.create(**data) # type: ignore
|
||||
response.model = "azure/" + str(response.model)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=stringified_response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_version": api_version,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
|
||||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
|
@ -318,7 +329,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
azure_client_params: dict,
|
||||
api_key: str,
|
||||
input: list,
|
||||
client=None,
|
||||
logging_obj=None
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -327,8 +341,23 @@ class AzureChatCompletion(BaseLLM):
|
|||
else:
|
||||
openai_aclient = client
|
||||
response = await openai_aclient.embeddings.create(**data)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
|
||||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
def embedding(self,
|
||||
|
@ -372,13 +401,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
|
||||
return response
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
|
@ -391,6 +414,14 @@ class AzureChatCompletion(BaseLLM):
|
|||
}
|
||||
},
|
||||
)
|
||||
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
|
||||
return response
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
## COMPLETION CALL
|
||||
response = azure_client.embeddings.create(**data) # type: ignore
|
||||
## LOGGING
|
||||
|
|
|
@ -2,7 +2,7 @@ import json, copy, types
|
|||
import os
|
||||
from enum import Enum
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Any
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, get_secret, Usage
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
@ -205,15 +205,25 @@ class AmazonLlamaConfig():
|
|||
|
||||
def init_bedrock_client(
|
||||
region_name = None,
|
||||
aws_access_key_id = None,
|
||||
aws_secret_access_key = None,
|
||||
aws_region_name=None,
|
||||
aws_bedrock_runtime_endpoint=None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] =None,
|
||||
aws_bedrock_runtime_endpoint: Optional[str]=None,
|
||||
):
|
||||
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME")
|
||||
standard_aws_region_name = get_secret("AWS_REGION")
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
# Define the list of parameters to check
|
||||
params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
for i, param in enumerate(params_to_check):
|
||||
if param and param.startswith('os.environ/'):
|
||||
params_to_check[i] = get_secret(param)
|
||||
# Assign updated values back to parameters
|
||||
aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check
|
||||
if region_name:
|
||||
pass
|
||||
elif aws_region_name:
|
||||
|
@ -472,7 +482,7 @@ def completion(
|
|||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response_body,
|
||||
original_response=json.dumps(response_body),
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response}")
|
||||
|
@ -533,37 +543,59 @@ def completion(
|
|||
def _embedding_func_single(
|
||||
model: str,
|
||||
input: str,
|
||||
client: Any,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
logging_obj=None,
|
||||
):
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = optional_params.pop(
|
||||
"aws_bedrock_client",
|
||||
# only pass variables that are not None
|
||||
init_bedrock_client(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
),
|
||||
## FORMAT EMBEDDING INPUT ##
|
||||
provider = model.split(".")[0]
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
if provider == "amazon":
|
||||
input = input.replace(os.linesep, " ")
|
||||
data = {"inputText": input, **inference_params}
|
||||
# data = json.dumps(data)
|
||||
elif provider == "cohere":
|
||||
inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
||||
data = {"texts": [input], **inference_params} # type: ignore
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_model(
|
||||
body={body},
|
||||
modelId={model},
|
||||
accept="*/*",
|
||||
contentType="application/json",
|
||||
)""" # type: ignore
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key="", # boto3 is used for init.
|
||||
additional_args={"complete_input_dict": {"model": model,
|
||||
"texts": input}, "request_str": request_str},
|
||||
)
|
||||
|
||||
input = input.replace(os.linesep, " ")
|
||||
body = json.dumps({"inputText": input})
|
||||
try:
|
||||
response = client.invoke_model(
|
||||
body=body,
|
||||
modelId=model,
|
||||
accept="application/json",
|
||||
accept="*/*",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
return response_body.get("embedding")
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=json.dumps(response_body),
|
||||
)
|
||||
if provider == "cohere":
|
||||
response = response_body.get("embeddings")
|
||||
# flatten list
|
||||
response = [item for sublist in response for item in sublist]
|
||||
return response
|
||||
elif provider == "amazon":
|
||||
return response_body.get("embedding")
|
||||
except Exception as e:
|
||||
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
|
||||
|
||||
|
@ -576,17 +608,21 @@ def embedding(
|
|||
optional_params=None,
|
||||
encoding=None,
|
||||
):
|
||||
### BOTO3 INIT ###
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": {"model": model,
|
||||
"texts": input}},
|
||||
)
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = init_bedrock_client(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
|
||||
## Embedding Call
|
||||
embeddings = [_embedding_func_single(model, i, optional_params) for i in input]
|
||||
embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls
|
||||
|
||||
|
||||
## Populate OpenAI compliant dictionary
|
||||
|
@ -614,14 +650,5 @@ def embedding(
|
|||
total_tokens=input_tokens + 0
|
||||
)
|
||||
model_response.usage = usage
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": {"model": model,
|
||||
"texts": input}},
|
||||
original_response=embeddings,
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
|
|
@ -170,6 +170,11 @@ class Huggingface(BaseLLM):
|
|||
"content"
|
||||
] = completion_response["generated_text"] # type: ignore
|
||||
elif task == "text-generation-inference":
|
||||
if (not isinstance(completion_response, list)
|
||||
or not isinstance(completion_response[0], dict)
|
||||
or "generated_text" not in completion_response[0]):
|
||||
raise HuggingfaceError(status_code=422, message=f"response is not in expected format - {completion_response}")
|
||||
|
||||
if len(completion_response[0]["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import requests, types
|
||||
import requests, types, time
|
||||
import json
|
||||
import traceback
|
||||
from typing import Optional
|
||||
import litellm
|
||||
import httpx
|
||||
|
||||
import httpx, aiohttp, asyncio
|
||||
try:
|
||||
from async_generator import async_generator, yield_ # optional dependency
|
||||
async_generator_imported = True
|
||||
|
@ -115,6 +114,9 @@ def get_ollama_response_stream(
|
|||
prompt="Why is the sky blue?",
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
acompletion: bool = False,
|
||||
model_response=None,
|
||||
encoding=None
|
||||
):
|
||||
if api_base.endswith("/api/generate"):
|
||||
url = api_base
|
||||
|
@ -136,8 +138,15 @@ def get_ollama_response_stream(
|
|||
logging_obj.pre_call(
|
||||
input=None,
|
||||
api_key=None,
|
||||
additional_args={"api_base": url, "complete_input_dict": data},
|
||||
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
|
||||
)
|
||||
if acompletion is True:
|
||||
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
|
||||
return response
|
||||
else:
|
||||
return ollama_completion_stream(url=url, data=data)
|
||||
|
||||
def ollama_completion_stream(url, data):
|
||||
session = requests.Session()
|
||||
|
||||
with session.post(url, json=data, stream=True) as resp:
|
||||
|
@ -169,6 +178,52 @@ def get_ollama_response_stream(
|
|||
traceback.print_exc()
|
||||
session.close()
|
||||
|
||||
async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
||||
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=600) # 10 minutes
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
resp = await session.post(url, json=data)
|
||||
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise OllamaError(status_code=resp.status, message=text)
|
||||
|
||||
async for line in resp.content.iter_any():
|
||||
if line:
|
||||
try:
|
||||
json_chunk = line.decode("utf-8")
|
||||
chunks = json_chunk.split("\n")
|
||||
completion_string = ""
|
||||
for chunk in chunks:
|
||||
if chunk.strip() != "":
|
||||
j = json.loads(chunk)
|
||||
if "error" in j:
|
||||
completion_obj = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"error": j
|
||||
}
|
||||
if "response" in j:
|
||||
completion_obj = {
|
||||
"role": "assistant",
|
||||
"content": j["response"],
|
||||
}
|
||||
completion_string += completion_obj["content"]
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
## RESPONSE OBJECT
|
||||
model_response["choices"][0]["finish_reason"] = "stop"
|
||||
model_response["choices"][0]["message"]["content"] = completion_string
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama/" + data['model']
|
||||
prompt_tokens = len(encoding.encode(data['prompt'])) # type: ignore
|
||||
completion_tokens = len(encoding.encode(completion_string))
|
||||
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
if async_generator_imported:
|
||||
# ollama implementation
|
||||
@async_generator
|
||||
|
|
|
@ -195,23 +195,23 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
**optional_params
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
|
||||
)
|
||||
|
||||
try:
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if acompletion is True:
|
||||
if optional_params.get("stream", False):
|
||||
return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
return self.async_streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
else:
|
||||
return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
return self.acompletion(data=data, headers=headers, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
elif optional_params.get("stream", False):
|
||||
return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
return self.streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
|
||||
)
|
||||
|
||||
if not isinstance(max_retries, int):
|
||||
raise OpenAIError(status_code=422, message="max retries must be an int")
|
||||
if client is None:
|
||||
|
@ -219,13 +219,14 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
else:
|
||||
openai_client = client
|
||||
response = openai_client.chat.completions.create(**data) # type: ignore
|
||||
stringified_response = response.model_dump_json()
|
||||
logging_obj.post_call(
|
||||
input=None,
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
original_response=stringified_response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
|
||||
except Exception as e:
|
||||
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
|
||||
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
|
||||
|
@ -259,6 +260,8 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
api_base: Optional[str]=None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
logging_obj=None,
|
||||
headers=None
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -266,8 +269,21 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_aclient = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data['messages'],
|
||||
api_key=api_key,
|
||||
additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
|
||||
)
|
||||
response = await openai_aclient.chat.completions.create(**data)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
stringified_response = response.model_dump_json()
|
||||
logging_obj.post_call(
|
||||
input=data['messages'],
|
||||
api_key=api_key,
|
||||
original_response=stringified_response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
|
||||
except Exception as e:
|
||||
if response and hasattr(response, "text"):
|
||||
raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
|
||||
|
@ -285,12 +301,19 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
api_key: Optional[str]=None,
|
||||
api_base: Optional[str]=None,
|
||||
client = None,
|
||||
max_retries=None
|
||||
max_retries=None,
|
||||
headers=None
|
||||
):
|
||||
if client is None:
|
||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_client = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data['messages'],
|
||||
api_key=api_key,
|
||||
additional_args={"headers": headers, "api_base": api_base, "acompletion": False, "complete_input_dict": data},
|
||||
)
|
||||
response = openai_client.chat.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
|
||||
return streamwrapper
|
||||
|
@ -304,6 +327,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
api_base: Optional[str]=None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
headers=None
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -311,6 +335,13 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_aclient = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data['messages'],
|
||||
api_key=api_key,
|
||||
additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
|
||||
)
|
||||
|
||||
response = await openai_aclient.chat.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
|
@ -325,6 +356,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
raise OpenAIError(status_code=500, message=f"{str(e)}")
|
||||
async def aembedding(
|
||||
self,
|
||||
input: list,
|
||||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
timeout: float,
|
||||
|
@ -332,6 +364,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
api_base: Optional[str]=None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
logging_obj=None
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -340,9 +373,24 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
else:
|
||||
openai_aclient = client
|
||||
response = await openai_aclient.embeddings.create(**data) # type: ignore
|
||||
return response
|
||||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
def embedding(self,
|
||||
model: str,
|
||||
input: list,
|
||||
|
@ -367,13 +415,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise OpenAIError(status_code=422, message="max retries must be an int")
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||
return response
|
||||
if client is None:
|
||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
|
@ -381,6 +423,14 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
additional_args={"complete_input_dict": data, "api_base": api_base},
|
||||
)
|
||||
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||
return response
|
||||
if client is None:
|
||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
## COMPLETION CALL
|
||||
response = openai_client.embeddings.create(**data) # type: ignore
|
||||
## LOGGING
|
||||
|
|
|
@ -2,7 +2,7 @@ from enum import Enum
|
|||
import requests, traceback
|
||||
import json
|
||||
from jinja2 import Template, exceptions, Environment, meta
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
|
||||
def default_pt(messages):
|
||||
return " ".join(message["content"] for message in messages)
|
||||
|
@ -74,7 +74,7 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
|
|||
messages=messages
|
||||
)
|
||||
else:
|
||||
prompt = "".join(m["content"] for m in messages)
|
||||
prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages)
|
||||
return prompt
|
||||
|
||||
def mistral_instruct_pt(messages):
|
||||
|
@ -159,26 +159,27 @@ def phind_codellama_pt(messages):
|
|||
prompt += "### Assistant\n" + message["content"] + "\n\n"
|
||||
return prompt
|
||||
|
||||
def hf_chat_template(model: str, messages: list):
|
||||
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
|
||||
## get the tokenizer config from huggingface
|
||||
def _get_tokenizer_config(hf_model_name):
|
||||
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
|
||||
# Make a GET request to fetch the JSON data
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
# Parse the JSON data
|
||||
tokenizer_config = json.loads(response.content)
|
||||
return {"status": "success", "tokenizer": tokenizer_config}
|
||||
else:
|
||||
return {"status": "failure"}
|
||||
tokenizer_config = _get_tokenizer_config(model)
|
||||
if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
|
||||
raise Exception("No chat template found")
|
||||
## read the bos token, eos token and chat template from the json
|
||||
tokenizer_config = tokenizer_config["tokenizer"]
|
||||
bos_token = tokenizer_config["bos_token"]
|
||||
eos_token = tokenizer_config["eos_token"]
|
||||
chat_template = tokenizer_config["chat_template"]
|
||||
if chat_template is None:
|
||||
def _get_tokenizer_config(hf_model_name):
|
||||
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
|
||||
# Make a GET request to fetch the JSON data
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
# Parse the JSON data
|
||||
tokenizer_config = json.loads(response.content)
|
||||
return {"status": "success", "tokenizer": tokenizer_config}
|
||||
else:
|
||||
return {"status": "failure"}
|
||||
tokenizer_config = _get_tokenizer_config(model)
|
||||
if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
|
||||
raise Exception("No chat template found")
|
||||
## read the bos token, eos token and chat template from the json
|
||||
tokenizer_config = tokenizer_config["tokenizer"]
|
||||
bos_token = tokenizer_config["bos_token"]
|
||||
eos_token = tokenizer_config["eos_token"]
|
||||
chat_template = tokenizer_config["chat_template"]
|
||||
|
||||
def raise_exception(message):
|
||||
raise Exception(f"Error message - {message}")
|
||||
|
@ -231,12 +232,20 @@ def hf_chat_template(model: str, messages: list):
|
|||
|
||||
# Anthropic template
|
||||
def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
|
||||
"""
|
||||
Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human:
|
||||
- you can't just pass a system message
|
||||
- you can't pass a system message and follow that with an assistant message
|
||||
if system message is passed in, you can only do system, human, assistant or system, human
|
||||
|
||||
if a system message is passed in and followed by an assistant message, insert a blank human message between them.
|
||||
"""
|
||||
class AnthropicConstants(Enum):
|
||||
HUMAN_PROMPT = "\n\nHuman: "
|
||||
AI_PROMPT = "\n\nAssistant: "
|
||||
|
||||
prompt = ""
|
||||
for idx, message in enumerate(messages): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "user":
|
||||
prompt += (
|
||||
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
|
||||
|
@ -245,15 +254,44 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
|
|||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
else:
|
||||
elif message["role"] == "assistant":
|
||||
if idx > 0 and messages[idx - 1]["role"] == "system":
|
||||
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message
|
||||
prompt += (
|
||||
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
|
||||
)
|
||||
if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: `
|
||||
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt
|
||||
prompt += f"{AnthropicConstants.AI_PROMPT.value}"
|
||||
prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
|
||||
return prompt
|
||||
|
||||
### TOGETHER AI
|
||||
|
||||
def get_model_info(token, model):
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}'
|
||||
}
|
||||
response = requests.get('https://api.together.xyz/models/info', headers=headers)
|
||||
if response.status_code == 200:
|
||||
model_info = response.json()
|
||||
for m in model_info:
|
||||
if m["name"].lower().strip() == model.strip():
|
||||
return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
|
||||
def format_prompt_togetherai(messages, prompt_format, chat_template):
|
||||
human_prompt, assistant_prompt = prompt_format.split('{prompt}')
|
||||
|
||||
if chat_template is not None:
|
||||
prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template)
|
||||
elif prompt_format is not None:
|
||||
prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt)
|
||||
else:
|
||||
prompt = default_pt(messages)
|
||||
return prompt
|
||||
|
||||
###
|
||||
|
||||
def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
|
||||
class AnthropicConstants(Enum):
|
||||
HUMAN_PROMPT = "\n\nHuman: "
|
||||
|
@ -320,7 +358,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
|
|||
prompt += final_prompt_value
|
||||
return prompt
|
||||
|
||||
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
|
||||
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
|
||||
original_model_name = model
|
||||
model = model.lower()
|
||||
if custom_llm_provider == "ollama":
|
||||
|
@ -330,7 +368,9 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
|
|||
return claude_2_1_pt(messages=messages)
|
||||
else:
|
||||
return anthropic_pt(messages=messages)
|
||||
|
||||
elif custom_llm_provider == "together_ai":
|
||||
prompt_format, chat_template = get_model_info(token=api_key, model=model)
|
||||
return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
|
||||
try:
|
||||
if "meta-llama/llama-2" in model and "chat" in model:
|
||||
return llama_2_chat_pt(messages=messages)
|
||||
|
@ -342,7 +382,7 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
|
|||
elif "mosaicml/mpt" in model:
|
||||
if "chat" in model:
|
||||
return mpt_chat_pt(messages=messages)
|
||||
elif "codellama/codellama" in model:
|
||||
elif "codellama/codellama" in model or "togethercomputer/codellama" in model:
|
||||
if "instruct" in model:
|
||||
return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
|
||||
elif "wizardlm/wizardcoder" in model:
|
||||
|
@ -357,4 +397,4 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
|
|||
return hf_chat_template(original_model_name, messages)
|
||||
except:
|
||||
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
|
||||
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
|
|||
logging_obj.pre_call(
|
||||
input=input_data["prompt"],
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": initial_prediction_data, "headers": headers},
|
||||
additional_args={"complete_input_dict": initial_prediction_data, "headers": headers, "api_base": base_url},
|
||||
)
|
||||
|
||||
response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
|
||||
|
@ -169,6 +169,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
|
|||
else:
|
||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||
print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")
|
||||
|
||||
|
||||
# Function to extract version ID from model string
|
||||
def model_to_version_id(model):
|
||||
|
@ -194,41 +195,48 @@ def completion(
|
|||
):
|
||||
# Start a prediction and get the prediction URL
|
||||
version_id = model_to_version_id(model)
|
||||
|
||||
## Load Config
|
||||
config = litellm.ReplicateConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
system_prompt = None
|
||||
if optional_params is not None and "supports_system_prompt" in optional_params:
|
||||
supports_sys_prompt = optional_params.pop("supports_system_prompt")
|
||||
else:
|
||||
supports_sys_prompt = False
|
||||
|
||||
if supports_sys_prompt:
|
||||
for i in range(len(messages)):
|
||||
if messages[i]["role"] == "system":
|
||||
first_sys_message = messages.pop(i)
|
||||
system_prompt = first_sys_message["content"]
|
||||
break
|
||||
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details.get("roles", {}),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
bos_token=model_prompt_details.get("bos_token", ""),
|
||||
eos_token=model_prompt_details.get("eos_token", ""),
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
if "meta/llama-2-13b-chat" in model:
|
||||
system_prompt = ""
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
system_prompt = message["content"]
|
||||
else:
|
||||
prompt += message["content"]
|
||||
# If system prompt is supported, and a system prompt is provided, use it
|
||||
if system_prompt is not None:
|
||||
input_data = {
|
||||
"system_prompt": system_prompt,
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt,
|
||||
**optional_params
|
||||
}
|
||||
# Otherwise, use the prompt as is
|
||||
else:
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details.get("roles", {}),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
bos_token=model_prompt_details.get("bos_token", ""),
|
||||
eos_token=model_prompt_details.get("eos_token", ""),
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
input_data = {
|
||||
"prompt": prompt,
|
||||
**optional_params
|
||||
|
|
|
@ -5,10 +5,11 @@ import requests
|
|||
import time
|
||||
from typing import Callable, Optional
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, get_secret, Usage
|
||||
from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
import httpx
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
class SagemakerError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
|
@ -61,6 +62,8 @@ def completion(
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
custom_prompt_dict={},
|
||||
hf_model_name=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
|
@ -107,19 +110,24 @@ def completion(
|
|||
inference_params[k] = v
|
||||
|
||||
model = model
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if "role" in message:
|
||||
if message["role"] == "user":
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
else:
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details.get("roles", None),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
messages=messages
|
||||
)
|
||||
else:
|
||||
if hf_model_name is None:
|
||||
if "llama-2" in model.lower(): # llama-2 model
|
||||
if "chat" in model.lower(): # apply llama2 chat template
|
||||
hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
|
||||
else: # apply regular llama2 template
|
||||
hf_model_name = "meta-llama/Llama-2-7b"
|
||||
hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
|
||||
prompt = prompt_factory(model=hf_model_name, messages=messages)
|
||||
|
||||
data = json.dumps({
|
||||
"inputs": prompt,
|
||||
|
@ -138,15 +146,18 @@ def completion(
|
|||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data, "request_str": request_str},
|
||||
additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = client.invoke_endpoint(
|
||||
EndpointName=model,
|
||||
ContentType="application/json",
|
||||
Body=data,
|
||||
CustomAttributes="accept_eula=true",
|
||||
)
|
||||
try:
|
||||
response = client.invoke_endpoint(
|
||||
EndpointName=model,
|
||||
ContentType="application/json",
|
||||
Body=data,
|
||||
CustomAttributes="accept_eula=true",
|
||||
)
|
||||
except Exception as e:
|
||||
raise SagemakerError(status_code=500, message=f"{str(e)}")
|
||||
response = response["Body"].read().decode("utf8")
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
@ -185,6 +196,133 @@ def completion(
|
|||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
def embedding(model: str,
|
||||
input: list,
|
||||
model_response: EmbeddingResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None):
|
||||
"""
|
||||
Supports Huggingface Jumpstart embeddings like GPT-6B
|
||||
"""
|
||||
### BOTO3 INIT
|
||||
import boto3
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
|
||||
if aws_access_key_id != None:
|
||||
# uses auth params passed to completion
|
||||
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||
client = boto3.client(
|
||||
service_name="sagemaker-runtime",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# boto3 automaticaly reads env variables
|
||||
|
||||
# we need to read region name from env
|
||||
# I assume majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME") or
|
||||
"us-west-2" # default to us-west-2 if user not specified
|
||||
)
|
||||
client = boto3.client(
|
||||
service_name="sagemaker-runtime",
|
||||
region_name=region_name,
|
||||
)
|
||||
|
||||
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
||||
inference_params = deepcopy(optional_params)
|
||||
inference_params.pop("stream", None)
|
||||
|
||||
## Load Config
|
||||
config = litellm.SagemakerConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
#### HF EMBEDDING LOGIC
|
||||
data = json.dumps({
|
||||
"text_inputs": input
|
||||
}).encode('utf-8')
|
||||
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_endpoint(
|
||||
EndpointName={model},
|
||||
ContentType="application/json",
|
||||
Body={data},
|
||||
CustomAttributes="accept_eula=true",
|
||||
)""" # type: ignore
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data, "request_str": request_str},
|
||||
)
|
||||
## EMBEDDING CALL
|
||||
try:
|
||||
response = client.invoke_endpoint(
|
||||
EndpointName=model,
|
||||
ContentType="application/json",
|
||||
Body=data,
|
||||
CustomAttributes="accept_eula=true",
|
||||
)
|
||||
except Exception as e:
|
||||
raise SagemakerError(status_code=500, message=f"{str(e)}")
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
|
||||
|
||||
response = json.loads(response["Body"].read().decode("utf8"))
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
original_response=response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response}")
|
||||
if "embedding" not in response:
|
||||
raise SagemakerError(status_code=500, message="embedding not found in response")
|
||||
embeddings = response['embedding']
|
||||
|
||||
if not isinstance(embeddings, list):
|
||||
raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}")
|
||||
|
||||
|
||||
output_data = []
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding
|
||||
}
|
||||
)
|
||||
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = output_data
|
||||
model_response["model"] = model
|
||||
|
||||
input_tokens = 0
|
||||
for text in input:
|
||||
input_tokens+=len(encoding.encode(text))
|
||||
|
||||
model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
|
||||
|
||||
return model_response
|
||||
|
|
|
@ -115,7 +115,7 @@ def completion(
|
|||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
|
|
|
@ -93,45 +93,58 @@ def completion(
|
|||
prompt = " ".join([message["content"] for message in messages])
|
||||
|
||||
mode = ""
|
||||
|
||||
request_str = ""
|
||||
if model in litellm.vertex_chat_models:
|
||||
chat_model = ChatModel.from_pretrained(model)
|
||||
mode = "chat"
|
||||
request_str += f"chat_model = ChatModel.from_pretrained({model})\n"
|
||||
elif model in litellm.vertex_text_models:
|
||||
text_model = TextGenerationModel.from_pretrained(model)
|
||||
mode = "text"
|
||||
request_str += f"text_model = TextGenerationModel.from_pretrained({model})\n"
|
||||
elif model in litellm.vertex_code_text_models:
|
||||
text_model = CodeGenerationModel.from_pretrained(model)
|
||||
mode = "text"
|
||||
request_str += f"text_model = CodeGenerationModel.from_pretrained({model})\n"
|
||||
else: # vertex_code_chat_models
|
||||
chat_model = CodeChatModel.from_pretrained(model)
|
||||
mode = "chat"
|
||||
request_str += f"chat_model = CodeChatModel.from_pretrained({model})\n"
|
||||
|
||||
if mode == "chat":
|
||||
chat = chat_model.start_chat()
|
||||
request_str+= f"chat = chat_model.start_chat()\n"
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params})
|
||||
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# NOTE: VertexAI does not accept stream=True as a param and raises an error,
|
||||
# we handle this by removing 'stream' from optional params and sending the request
|
||||
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
|
||||
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
|
||||
request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n"
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
model_response = chat.send_message_streaming(prompt, **optional_params)
|
||||
optional_params["stream"] = True
|
||||
return model_response
|
||||
|
||||
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
completion_response = chat.send_message(prompt, **optional_params).text
|
||||
elif mode == "text":
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
|
||||
request_str += f"text_model.predict_streaming({prompt}, **{optional_params})\n"
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
model_response = text_model.predict_streaming(prompt, **optional_params)
|
||||
optional_params["stream"] = True
|
||||
return model_response
|
||||
|
||||
request_str += f"text_model.predict({prompt}, **{optional_params}).text\n"
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
completion_response = text_model.predict(prompt, **optional_params).text
|
||||
|
||||
## LOGGING
|
||||
|
|
181
litellm/main.py
181
litellm/main.py
|
@ -99,17 +99,18 @@ class Chat():
|
|||
def __init__(self, params):
|
||||
self.params = params
|
||||
self.completions = Completions(self.params)
|
||||
|
||||
|
||||
class Completions():
|
||||
|
||||
def __init__(self, params):
|
||||
self.params = params
|
||||
|
||||
def create(self, model, messages, **kwargs):
|
||||
def create(self, messages, model=None, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
self.params[k] = v
|
||||
model = model or self.params.get('model')
|
||||
response = completion(model=model, messages=messages, **self.params)
|
||||
return response
|
||||
return response
|
||||
|
||||
@client
|
||||
async def acompletion(*args, **kwargs):
|
||||
|
@ -174,7 +175,8 @@ async def acompletion(*args, **kwargs):
|
|||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
or custom_llm_provider == "huggingface"
|
||||
or custom_llm_provider == "ollama"): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
if kwargs.get("stream", False):
|
||||
response = completion(*args, **kwargs)
|
||||
else:
|
||||
|
@ -318,7 +320,6 @@ def completion(
|
|||
######### unpacking kwargs #####################
|
||||
args = locals()
|
||||
api_base = kwargs.get('api_base', None)
|
||||
return_async = kwargs.get('return_async', False)
|
||||
mock_response = kwargs.get('mock_response', None)
|
||||
force_timeout= kwargs.get('force_timeout', 600) ## deprecated
|
||||
logger_fn = kwargs.get('logger_fn', None)
|
||||
|
@ -327,6 +328,8 @@ def completion(
|
|||
litellm_logging_obj = kwargs.get('litellm_logging_obj', None)
|
||||
id = kwargs.get('id', None)
|
||||
metadata = kwargs.get('metadata', None)
|
||||
model_info = kwargs.get('model_info', None)
|
||||
proxy_server_request = kwargs.get('proxy_server_request', None)
|
||||
fallbacks = kwargs.get('fallbacks', None)
|
||||
headers = kwargs.get("headers", None)
|
||||
num_retries = kwargs.get("num_retries", None) ## deprecated
|
||||
|
@ -341,11 +344,14 @@ def completion(
|
|||
final_prompt_value = kwargs.get("final_prompt_value", None)
|
||||
bos_token = kwargs.get("bos_token", None)
|
||||
eos_token = kwargs.get("eos_token", None)
|
||||
preset_cache_key = kwargs.get("preset_cache_key", None)
|
||||
hf_model_name = kwargs.get("hf_model_name", None)
|
||||
### ASYNC CALLS ###
|
||||
acompletion = kwargs.get("acompletion", False)
|
||||
client = kwargs.get("client", None)
|
||||
######## end of unpacking kwargs ###########
|
||||
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
|
||||
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token"]
|
||||
litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
|
||||
if mock_response:
|
||||
|
@ -442,7 +448,7 @@ def completion(
|
|||
|
||||
# For logging - save the values of the litellm-specific params passed in
|
||||
litellm_params = get_litellm_params(
|
||||
return_async=return_async,
|
||||
acompletion=acompletion,
|
||||
api_key=api_key,
|
||||
force_timeout=force_timeout,
|
||||
logger_fn=logger_fn,
|
||||
|
@ -452,7 +458,10 @@ def completion(
|
|||
litellm_call_id=kwargs.get('litellm_call_id', None),
|
||||
model_alias_map=litellm.model_alias_map,
|
||||
completion_call_id=id,
|
||||
metadata=metadata
|
||||
metadata=metadata,
|
||||
model_info=model_info,
|
||||
proxy_server_request=proxy_server_request,
|
||||
preset_cache_key=preset_cache_key
|
||||
)
|
||||
logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params)
|
||||
if custom_llm_provider == "azure":
|
||||
|
@ -516,17 +525,18 @@ def completion(
|
|||
client=client # pass AsyncAzureOpenAI, AzureOpenAI client
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_version": api_version,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_version": api_version,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
elif (
|
||||
model in litellm.open_ai_chat_completion_models
|
||||
or custom_llm_provider == "custom_openai"
|
||||
|
@ -596,13 +606,14 @@ def completion(
|
|||
)
|
||||
raise e
|
||||
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={"headers": headers},
|
||||
)
|
||||
if optional_params.get("stream", False):
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={"headers": headers},
|
||||
)
|
||||
elif (
|
||||
custom_llm_provider == "text-completion-openai"
|
||||
or "ft:babbage-002" in model
|
||||
|
@ -673,9 +684,14 @@ def completion(
|
|||
logger_fn=logger_fn
|
||||
)
|
||||
|
||||
# if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
|
||||
# return response
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=model_response,
|
||||
additional_args={"headers": headers},
|
||||
)
|
||||
response = model_response
|
||||
elif (
|
||||
"replicate" in model or
|
||||
|
@ -720,8 +736,16 @@ def completion(
|
|||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")
|
||||
return response
|
||||
model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=replicate_key,
|
||||
original_response=model_response,
|
||||
)
|
||||
|
||||
response = model_response
|
||||
|
||||
elif custom_llm_provider=="anthropic":
|
||||
|
@ -741,7 +765,7 @@ def completion(
|
|||
custom_prompt_dict
|
||||
or litellm.custom_prompt_dict
|
||||
)
|
||||
model_response = anthropic.completion(
|
||||
response = anthropic.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
|
@ -757,9 +781,16 @@ def completion(
|
|||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(model_response, model, custom_llm_provider="anthropic", logging_obj=logging)
|
||||
return response
|
||||
response = model_response
|
||||
response = CustomStreamWrapper(response, model, custom_llm_provider="anthropic", logging_obj=logging)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
)
|
||||
response = response
|
||||
elif custom_llm_provider == "nlp_cloud":
|
||||
nlp_cloud_key = (
|
||||
api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key
|
||||
|
@ -772,7 +803,7 @@ def completion(
|
|||
or "https://api.nlpcloud.io/v1/gpu/"
|
||||
)
|
||||
|
||||
model_response = nlp_cloud.completion(
|
||||
response = nlp_cloud.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
|
@ -788,9 +819,17 @@ def completion(
|
|||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(model_response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
|
||||
return response
|
||||
response = model_response
|
||||
response = CustomStreamWrapper(response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
)
|
||||
|
||||
response = response
|
||||
elif custom_llm_provider == "aleph_alpha":
|
||||
aleph_alpha_key = (
|
||||
api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key
|
||||
|
@ -1166,6 +1205,8 @@ def completion(
|
|||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
hf_model_name=hf_model_name,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging
|
||||
|
@ -1190,7 +1231,7 @@ def completion(
|
|||
custom_prompt_dict
|
||||
or litellm.custom_prompt_dict
|
||||
)
|
||||
model_response = bedrock.completion(
|
||||
response = bedrock.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
|
@ -1208,16 +1249,24 @@ def completion(
|
|||
# don't try to access stream object,
|
||||
if "ai21" in model:
|
||||
response = CustomStreamWrapper(
|
||||
model_response, model, custom_llm_provider="bedrock", logging_obj=logging
|
||||
response, model, custom_llm_provider="bedrock", logging_obj=logging
|
||||
)
|
||||
else:
|
||||
response = CustomStreamWrapper(
|
||||
iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
|
||||
iter(response), model, custom_llm_provider="bedrock", logging_obj=logging
|
||||
)
|
||||
return response
|
||||
|
||||
if optional_params.get("stream", False):
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=None,
|
||||
original_response=response,
|
||||
)
|
||||
|
||||
|
||||
## RESPONSE OBJECT
|
||||
response = model_response
|
||||
response = response
|
||||
elif custom_llm_provider == "vllm":
|
||||
model_response = vllm.completion(
|
||||
model=model,
|
||||
|
@ -1270,7 +1319,9 @@ def completion(
|
|||
async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
|
||||
return async_generator
|
||||
|
||||
generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
|
||||
generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
|
||||
if acompletion is True:
|
||||
return generator
|
||||
if optional_params.get("stream", False) == True:
|
||||
# assume all ollama responses are streamed
|
||||
response = CustomStreamWrapper(
|
||||
|
@ -1673,7 +1724,7 @@ def batch_completion_models_all_responses(*args, **kwargs):
|
|||
return responses
|
||||
|
||||
### EMBEDDING ENDPOINTS ####################
|
||||
|
||||
@client
|
||||
async def aembedding(*args, **kwargs):
|
||||
"""
|
||||
Asynchronously calls the `embedding` function with the given arguments and keyword arguments.
|
||||
|
@ -1770,17 +1821,24 @@ def embedding(
|
|||
client = kwargs.pop("client", None)
|
||||
rpm = kwargs.pop("rpm", None)
|
||||
tpm = kwargs.pop("tpm", None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
metadata = kwargs.get("metadata", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
aembedding = kwargs.pop("aembedding", None)
|
||||
|
||||
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
|
||||
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
|
||||
optional_params = {}
|
||||
for param in kwargs:
|
||||
if param != "metadata": # filter out metadata from optional_params
|
||||
optional_params[param] = kwargs[param]
|
||||
for param in non_default_params:
|
||||
optional_params[param] = kwargs[param]
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
|
||||
|
||||
|
||||
try:
|
||||
response = None
|
||||
logging = litellm_logging_obj
|
||||
logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
|
||||
logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding, "preset_cache_key": None, "stream_response": {}})
|
||||
if azure == True or custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||
|
@ -1899,9 +1957,19 @@ def embedding(
|
|||
input=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=kwargs,
|
||||
optional_params=optional_params,
|
||||
model_response= EmbeddingResponse()
|
||||
)
|
||||
elif custom_llm_provider == "sagemaker":
|
||||
response = sagemaker.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response= EmbeddingResponse(),
|
||||
print_verbose=print_verbose
|
||||
)
|
||||
else:
|
||||
args = locals()
|
||||
raise ValueError(f"No valid embedding model args passed in - {args}")
|
||||
|
@ -1910,7 +1978,7 @@ def embedding(
|
|||
## LOGGING
|
||||
logging.post_call(
|
||||
input=input,
|
||||
api_key=openai.api_key,
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
)
|
||||
## Map to OpenAI Exception
|
||||
|
@ -2123,8 +2191,11 @@ def moderation(input: str, api_key: Optional[str]=None):
|
|||
####### HELPER FUNCTIONS ################
|
||||
## Set verbose to true -> ```litellm.set_verbose = True```
|
||||
def print_verbose(print_statement):
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
try:
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except:
|
||||
pass
|
||||
|
||||
def config_completion(**kwargs):
|
||||
if litellm.config_path != None:
|
||||
|
|
|
@ -41,6 +41,20 @@
|
|||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"gpt-4-1106-preview": {
|
||||
"max_tokens": 128000,
|
||||
"input_cost_per_token": 0.00001,
|
||||
"output_cost_per_token": 0.00003,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"gpt-4-vision-preview": {
|
||||
"max_tokens": 128000,
|
||||
"input_cost_per_token": 0.00001,
|
||||
"output_cost_per_token": 0.00003,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
"max_tokens": 4097,
|
||||
"input_cost_per_token": 0.0000015,
|
||||
|
@ -62,6 +76,13 @@
|
|||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"gpt-3.5-turbo-1106": {
|
||||
"max_tokens": 16385,
|
||||
"input_cost_per_token": 0.0000010,
|
||||
"output_cost_per_token": 0.0000020,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
"max_tokens": 16385,
|
||||
"input_cost_per_token": 0.000003,
|
||||
|
@ -76,6 +97,62 @@
|
|||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ft:gpt-3.5-turbo": {
|
||||
"max_tokens": 4097,
|
||||
"input_cost_per_token": 0.000012,
|
||||
"output_cost_per_token": 0.000016,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"text-embedding-ada-002": {
|
||||
"max_tokens": 8191,
|
||||
"input_cost_per_token": 0.0000001,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"azure/gpt-4-1106-preview": {
|
||||
"max_tokens": 128000,
|
||||
"input_cost_per_token": 0.00001,
|
||||
"output_cost_per_token": 0.00003,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/gpt-4-32k": {
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.00006,
|
||||
"output_cost_per_token": 0.00012,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/gpt-4": {
|
||||
"max_tokens": 16385,
|
||||
"input_cost_per_token": 0.00003,
|
||||
"output_cost_per_token": 0.00006,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/gpt-3.5-turbo-16k": {
|
||||
"max_tokens": 16385,
|
||||
"input_cost_per_token": 0.000003,
|
||||
"output_cost_per_token": 0.000004,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/gpt-3.5-turbo": {
|
||||
"max_tokens": 4097,
|
||||
"input_cost_per_token": 0.0000015,
|
||||
"output_cost_per_token": 0.000002,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/text-embedding-ada-002": {
|
||||
"max_tokens": 8191,
|
||||
"input_cost_per_token": 0.0000001,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"text-davinci-003": {
|
||||
"max_tokens": 4097,
|
||||
"input_cost_per_token": 0.000002,
|
||||
|
@ -127,6 +204,7 @@
|
|||
},
|
||||
"claude-instant-1": {
|
||||
"max_tokens": 100000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.00000163,
|
||||
"output_cost_per_token": 0.00000551,
|
||||
"litellm_provider": "anthropic",
|
||||
|
@ -134,15 +212,25 @@
|
|||
},
|
||||
"claude-instant-1.2": {
|
||||
"max_tokens": 100000,
|
||||
"input_cost_per_token": 0.00000163,
|
||||
"output_cost_per_token": 0.00000551,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000000163,
|
||||
"output_cost_per_token": 0.000000551,
|
||||
"litellm_provider": "anthropic",
|
||||
"mode": "chat"
|
||||
},
|
||||
"claude-2": {
|
||||
"max_tokens": 100000,
|
||||
"input_cost_per_token": 0.00001102,
|
||||
"output_cost_per_token": 0.00003268,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "anthropic",
|
||||
"mode": "chat"
|
||||
},
|
||||
"claude-2.1": {
|
||||
"max_tokens": 200000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "anthropic",
|
||||
"mode": "chat"
|
||||
},
|
||||
|
@ -227,9 +315,51 @@
|
|||
"max_tokens": 32000,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "vertex_ai-chat-models",
|
||||
"litellm_provider": "vertex_ai-code-chat-models",
|
||||
"mode": "chat"
|
||||
},
|
||||
"palm/chat-bison": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "chat"
|
||||
},
|
||||
"palm/chat-bison-001": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "chat"
|
||||
},
|
||||
"palm/text-bison": {
|
||||
"max_tokens": 8196,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "completion"
|
||||
},
|
||||
"palm/text-bison-001": {
|
||||
"max_tokens": 8196,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "completion"
|
||||
},
|
||||
"palm/text-bison-safety-off": {
|
||||
"max_tokens": 8196,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "completion"
|
||||
},
|
||||
"palm/text-bison-safety-recitation-off": {
|
||||
"max_tokens": 8196,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "completion"
|
||||
},
|
||||
"command-nightly": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000015,
|
||||
|
@ -267,6 +397,8 @@
|
|||
},
|
||||
"replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000,
|
||||
"output_cost_per_token": 0.0000,
|
||||
"litellm_provider": "replicate",
|
||||
"mode": "chat"
|
||||
},
|
||||
|
@ -293,6 +425,7 @@
|
|||
},
|
||||
"openrouter/anthropic/claude-instant-v1": {
|
||||
"max_tokens": 100000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.00000163,
|
||||
"output_cost_per_token": 0.00000551,
|
||||
"litellm_provider": "openrouter",
|
||||
|
@ -300,6 +433,7 @@
|
|||
},
|
||||
"openrouter/anthropic/claude-2": {
|
||||
"max_tokens": 100000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.00001102,
|
||||
"output_cost_per_token": 0.00003268,
|
||||
"litellm_provider": "openrouter",
|
||||
|
@ -496,20 +630,31 @@
|
|||
},
|
||||
"anthropic.claude-v1": {
|
||||
"max_tokens": 100000,
|
||||
"input_cost_per_token": 0.00001102,
|
||||
"output_cost_per_token": 0.00003268,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anthropic.claude-v2": {
|
||||
"max_tokens": 100000,
|
||||
"input_cost_per_token": 0.00001102,
|
||||
"output_cost_per_token": 0.00003268,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anthropic.claude-v2:1": {
|
||||
"max_tokens": 200000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anthropic.claude-instant-v1": {
|
||||
"max_tokens": 100000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.00000163,
|
||||
"output_cost_per_token": 0.00000551,
|
||||
"litellm_provider": "bedrock",
|
||||
|
@ -529,26 +674,80 @@
|
|||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"meta.llama2-70b-chat-v1": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000195,
|
||||
"output_cost_per_token": 0.00000256,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-7b": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "completion"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-7b-f": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "chat"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-13b": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "completion"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-13b-f": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "chat"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-70b": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "completion"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-70b-b-f": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "chat"
|
||||
},
|
||||
"together-ai-up-to-3b": {
|
||||
"input_cost_per_token": 0.0000001,
|
||||
"output_cost_per_token": 0.0000001
|
||||
"output_cost_per_token": 0.0000001,
|
||||
"litellm_provider": "together_ai"
|
||||
},
|
||||
"together-ai-3.1b-7b": {
|
||||
"input_cost_per_token": 0.0000002,
|
||||
"output_cost_per_token": 0.0000002
|
||||
"output_cost_per_token": 0.0000002,
|
||||
"litellm_provider": "together_ai"
|
||||
},
|
||||
"together-ai-7.1b-20b": {
|
||||
"max_tokens": 1000,
|
||||
"input_cost_per_token": 0.0000004,
|
||||
"output_cost_per_token": 0.0000004
|
||||
"output_cost_per_token": 0.0000004,
|
||||
"litellm_provider": "together_ai"
|
||||
},
|
||||
"together-ai-20.1b-40b": {
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000001
|
||||
"input_cost_per_token": 0.0000008,
|
||||
"output_cost_per_token": 0.0000008,
|
||||
"litellm_provider": "together_ai"
|
||||
},
|
||||
"together-ai-40.1b-70b": {
|
||||
"input_cost_per_token": 0.000003,
|
||||
"output_cost_per_token": 0.000003
|
||||
"input_cost_per_token": 0.0000009,
|
||||
"output_cost_per_token": 0.0000009,
|
||||
"litellm_provider": "together_ai"
|
||||
},
|
||||
"ollama/llama2": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -578,10 +777,38 @@
|
|||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/mistral": {
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/codellama": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/orca-mini": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/vicuna": {
|
||||
"max_tokens": 2048,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"deepinfra/meta-llama/Llama-2-70b-chat-hf": {
|
||||
"max_tokens": 6144,
|
||||
"input_cost_per_token": 0.000001875,
|
||||
"output_cost_per_token": 0.000001875,
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000000700,
|
||||
"output_cost_per_token": 0.000000950,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat"
|
||||
},
|
||||
|
@ -619,5 +846,103 @@
|
|||
"output_cost_per_token": 0.00000095,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/pplx-7b-chat": {
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/pplx-70b-chat": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/pplx-7b-online": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.0005,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/pplx-70b-online": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.0005,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/llama-2-13b-chat": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/llama-2-70b-chat": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/mistral-7b-instruct": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/replit-code-v1.5-3b": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
|
||||
"max_tokens": 16384,
|
||||
"input_cost_per_token": 0.00000015,
|
||||
"output_cost_per_token": 0.00000015,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/HuggingFaceH4/zephyr-7b-beta": {
|
||||
"max_tokens": 16384,
|
||||
"input_cost_per_token": 0.00000015,
|
||||
"output_cost_per_token": 0.00000015,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/meta-llama/Llama-2-7b-chat-hf": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000015,
|
||||
"output_cost_per_token": 0.00000015,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/meta-llama/Llama-2-13b-chat-hf": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000025,
|
||||
"output_cost_per_token": 0.00000025,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/meta-llama/Llama-2-70b-chat-hf": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000001,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/codellama/CodeLlama-34b-Instruct-hf": {
|
||||
"max_tokens": 16384,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000001,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
}
|
||||
}
|
||||
|
|
173
litellm/proxy/_types.py
Normal file
173
litellm/proxy/_types.py
Normal file
|
@ -0,0 +1,173 @@
|
|||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from typing import Optional, List, Union, Dict, Literal
|
||||
from datetime import datetime
|
||||
import uuid, json
|
||||
|
||||
class LiteLLMBase(BaseModel):
|
||||
"""
|
||||
Implements default functions, all pydantic objects should have.
|
||||
"""
|
||||
def json(self, **kwargs):
|
||||
try:
|
||||
return self.model_dump() # noqa
|
||||
except:
|
||||
# if using pydantic v1
|
||||
return self.dict()
|
||||
|
||||
|
||||
######### Request Class Definition ######
|
||||
class ProxyChatCompletionRequest(LiteLLMBase):
|
||||
model: str
|
||||
messages: List[Dict[str, str]]
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[List[str]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
user: Optional[str] = None
|
||||
response_format: Optional[Dict[str, str]] = None
|
||||
seed: Optional[int] = None
|
||||
tools: Optional[List[str]] = None
|
||||
tool_choice: Optional[str] = None
|
||||
functions: Optional[List[str]] = None # soon to be deprecated
|
||||
function_call: Optional[str] = None # soon to be deprecated
|
||||
|
||||
# Optional LiteLLM params
|
||||
caching: Optional[bool] = None
|
||||
api_base: Optional[str] = None
|
||||
api_version: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
num_retries: Optional[int] = None
|
||||
context_window_fallback_dict: Optional[Dict[str, str]] = None
|
||||
fallbacks: Optional[List[str]] = None
|
||||
metadata: Optional[Dict[str, str]] = {}
|
||||
deployment_id: Optional[str] = None
|
||||
request_timeout: Optional[int] = None
|
||||
|
||||
class Config:
|
||||
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
|
||||
|
||||
class ModelInfoDelete(LiteLLMBase):
|
||||
id: Optional[str]
|
||||
|
||||
|
||||
class ModelInfo(LiteLLMBase):
|
||||
id: Optional[str]
|
||||
mode: Optional[Literal['embedding', 'chat', 'completion']]
|
||||
input_cost_per_token: Optional[float] = 0.0
|
||||
output_cost_per_token: Optional[float] = 0.0
|
||||
max_tokens: Optional[int] = 2048 # assume 2048 if not set
|
||||
|
||||
# for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
|
||||
# we look up the base model in model_prices_and_context_window.json
|
||||
base_model: Optional[Literal
|
||||
[
|
||||
'gpt-4-1106-preview',
|
||||
'gpt-4-32k',
|
||||
'gpt-4',
|
||||
'gpt-3.5-turbo-16k',
|
||||
'gpt-3.5-turbo',
|
||||
'text-embedding-ada-002',
|
||||
]
|
||||
]
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow # Allow extra fields
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
if values.get("id") is None:
|
||||
values.update({"id": str(uuid.uuid4())})
|
||||
if values.get("mode") is None:
|
||||
values.update({"mode": None})
|
||||
if values.get("input_cost_per_token") is None:
|
||||
values.update({"input_cost_per_token": None})
|
||||
if values.get("output_cost_per_token") is None:
|
||||
values.update({"output_cost_per_token": None})
|
||||
if values.get("max_tokens") is None:
|
||||
values.update({"max_tokens": None})
|
||||
if values.get("base_model") is None:
|
||||
values.update({"base_model": None})
|
||||
return values
|
||||
|
||||
|
||||
|
||||
class ModelParams(LiteLLMBase):
|
||||
model_name: str
|
||||
litellm_params: dict
|
||||
model_info: ModelInfo
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
if values.get("model_info") is None:
|
||||
values.update({"model_info": ModelInfo()})
|
||||
return values
|
||||
|
||||
class GenerateKeyRequest(LiteLLMBase):
|
||||
duration: Optional[str] = "1h"
|
||||
models: Optional[list] = []
|
||||
aliases: Optional[dict] = {}
|
||||
config: Optional[dict] = {}
|
||||
spend: Optional[float] = 0
|
||||
user_id: Optional[str] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
|
||||
class GenerateKeyResponse(LiteLLMBase):
|
||||
key: str
|
||||
expires: datetime
|
||||
user_id: str
|
||||
|
||||
class _DeleteKeyObject(LiteLLMBase):
|
||||
key: str
|
||||
|
||||
class DeleteKeyRequest(LiteLLMBase):
|
||||
keys: List[_DeleteKeyObject]
|
||||
|
||||
|
||||
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
|
||||
"""
|
||||
Return the row in the db
|
||||
"""
|
||||
api_key: Optional[str] = None
|
||||
models: list = []
|
||||
aliases: dict = {}
|
||||
config: dict = {}
|
||||
spend: Optional[float] = 0
|
||||
user_id: Optional[str] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
duration: str = "1h"
|
||||
|
||||
class ConfigGeneralSettings(LiteLLMBase):
|
||||
"""
|
||||
Documents all the fields supported by `general_settings` in config.yaml
|
||||
"""
|
||||
completion_model: Optional[str] = Field(None, description="proxy level default model for all chat completion calls")
|
||||
use_azure_key_vault: Optional[bool] = Field(None, description="load keys from azure key vault")
|
||||
master_key: Optional[str] = Field(None, description="require a key for all calls to proxy")
|
||||
database_url: Optional[str] = Field(None, description="connect to a postgres db - needed for generating temporary keys + tracking spend / key")
|
||||
otel: Optional[bool] = Field(None, description="[BETA] OpenTelemetry support - this might change, use with caution.")
|
||||
custom_auth: Optional[str] = Field(None, description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth")
|
||||
max_parallel_requests: Optional[int] = Field(None, description="maximum parallel requests for each api key")
|
||||
infer_model_from_keys: Optional[bool] = Field(None, description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)")
|
||||
background_health_checks: Optional[bool] = Field(None, description="run health checks in background")
|
||||
health_check_interval: int = Field(300, description="background health check interval in seconds")
|
||||
|
||||
|
||||
class ConfigYAML(LiteLLMBase):
|
||||
"""
|
||||
Documents all the fields supported by the config.yaml
|
||||
"""
|
||||
model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs")
|
||||
litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache")
|
||||
general_settings: Optional[ConfigGeneralSettings] = None
|
||||
class Config:
|
||||
protected_namespaces = ()
|
14
litellm/proxy/custom_auth.py
Normal file
14
litellm/proxy/custom_auth.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from fastapi import Request
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
|
||||
if api_key == modified_master_key:
|
||||
return UserAPIKeyAuth(api_key=api_key)
|
||||
raise Exception
|
||||
except:
|
||||
raise Exception
|
59
litellm/proxy/custom_callbacks.py
Normal file
59
litellm/proxy/custom_callbacks.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import litellm
|
||||
import inspect
|
||||
|
||||
# This file includes the custom callbacks for LiteLLM Proxy
|
||||
# Once defined, these can be passed in proxy_config.yaml
|
||||
def print_verbose(print_statement):
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
|
||||
class MyCustomHandler(CustomLogger):
|
||||
def __init__(self):
|
||||
blue_color_code = "\033[94m"
|
||||
reset_color_code = "\033[0m"
|
||||
print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger")
|
||||
try:
|
||||
print_verbose(f"Logger Initialized with following methods:")
|
||||
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
|
||||
|
||||
# Pretty print_verbose the methods
|
||||
for method in methods:
|
||||
print_verbose(f" - {method}")
|
||||
print_verbose(f"{reset_color_code}")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
print_verbose(f"Pre-API Call")
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
print_verbose(f"Post-API Call")
|
||||
|
||||
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print_verbose(f"On Stream")
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print_verbose("On Success!")
|
||||
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print_verbose(f"On Async Success!")
|
||||
response_cost = litellm.completion_cost(completion_response=response_obj)
|
||||
assert response_cost > 0.0
|
||||
return
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
print_verbose(f"On Async Failure !")
|
||||
except Exception as e:
|
||||
print_verbose(f"Exception: {e}")
|
||||
|
||||
|
||||
proxy_handler_instance = MyCustomHandler()
|
||||
|
||||
|
||||
# need to set litellm.callbacks = [customHandler] # on the proxy
|
||||
|
||||
# litellm.success_callback = [async_on_succes_logger]
|
|
@ -4,14 +4,18 @@ model_list:
|
|||
model: azure/chatgpt-v-2
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_version: "2023-05-15"
|
||||
azure_ad_token: eyJ0eXAiOiJ
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
tpm: 20_000
|
||||
timeout: 5 # 1 second timeout
|
||||
stream_timeout: 0.5 # 0.5 second timeout for streaming requests
|
||||
max_retries: 4
|
||||
- model_name: gpt-4-team2
|
||||
litellm_params:
|
||||
model: azure/gpt-4
|
||||
api_key: sk-123
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
|
||||
- model_name: gpt-4-team3
|
||||
litellm_params:
|
||||
model: azure/gpt-4
|
||||
api_key: sk-123
|
||||
tpm: 100_000
|
||||
timeout: 5 # 1 second timeout
|
||||
stream_timeout: 0.5 # 0.5 second timeout for streaming requests
|
||||
max_retries: 4
|
||||
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-35-1
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
|
||||
api_key: 73g
|
||||
tpm: 80_000
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-35-2
|
||||
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||
api_key: 9kj
|
||||
tpm: 80_000
|
|
@ -1,8 +0,0 @@
|
|||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
|
||||
general_settings:
|
||||
master_key: sk-hosted-litellm
|
||||
use_queue: True
|
||||
database_url: " # [OPTIONAL] use for token-based auth to proxy
|
|
@ -1,11 +0,0 @@
|
|||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2 # actual model name
|
||||
api_key:
|
||||
api_version: 2023-07-01-preview
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
4
litellm/proxy/example_config_yaml/simple_config.yaml
Normal file
4
litellm/proxy/example_config_yaml/simple_config.yaml
Normal file
|
@ -0,0 +1,4 @@
|
|||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
117
litellm/proxy/health_check.py
Normal file
117
litellm/proxy/health_check.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
# This file runs a health check for the LLM, used on litellm/proxy
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
import logging
|
||||
from litellm._logging import print_verbose
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ILLEGAL_DISPLAY_PARAMS = [
|
||||
"messages",
|
||||
"api_key"
|
||||
]
|
||||
|
||||
|
||||
def _get_random_llm_message():
|
||||
"""
|
||||
Get a random message from the LLM.
|
||||
"""
|
||||
messages = [
|
||||
"Hey how's it going?",
|
||||
"What's 1 + 1?"
|
||||
]
|
||||
|
||||
|
||||
return [
|
||||
{"role": "user", "content": random.choice(messages)}
|
||||
]
|
||||
|
||||
|
||||
def _clean_litellm_params(litellm_params: dict):
|
||||
"""
|
||||
Clean the litellm params for display to users.
|
||||
"""
|
||||
return {k: v for k, v in litellm_params.items() if k not in ILLEGAL_DISPLAY_PARAMS}
|
||||
|
||||
|
||||
async def _perform_health_check(model_list: list):
|
||||
"""
|
||||
Perform a health check for each model in the list.
|
||||
"""
|
||||
async def _check_embedding_model(model_params: dict):
|
||||
model_params.pop("messages", None)
|
||||
model_params["input"] = ["test from litellm"]
|
||||
try:
|
||||
await litellm.aembedding(**model_params)
|
||||
except Exception as e:
|
||||
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _check_model(model_params: dict):
|
||||
try:
|
||||
await litellm.acompletion(**model_params)
|
||||
except Exception as e:
|
||||
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
prepped_params = []
|
||||
tasks = []
|
||||
for model in model_list:
|
||||
litellm_params = model["litellm_params"]
|
||||
model_info = model.get("model_info", {})
|
||||
litellm_params["model"] = litellm.utils.remove_model_id(litellm_params["model"])
|
||||
litellm_params["messages"] = _get_random_llm_message()
|
||||
|
||||
prepped_params.append(litellm_params)
|
||||
if model_info.get("mode", None) == "embedding":
|
||||
# this is an embedding model
|
||||
tasks.append(_check_embedding_model(litellm_params))
|
||||
else:
|
||||
tasks.append(_check_model(litellm_params))
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
healthy_endpoints = []
|
||||
unhealthy_endpoints = []
|
||||
|
||||
for is_healthy, model in zip(results, model_list):
|
||||
cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])
|
||||
|
||||
if is_healthy:
|
||||
healthy_endpoints.append(cleaned_litellm_params)
|
||||
else:
|
||||
unhealthy_endpoints.append(cleaned_litellm_params)
|
||||
|
||||
return healthy_endpoints, unhealthy_endpoints
|
||||
|
||||
|
||||
|
||||
|
||||
async def perform_health_check(model_list: list, model: Optional[str] = None):
|
||||
"""
|
||||
Perform a health check on the system.
|
||||
|
||||
Returns:
|
||||
(bool): True if the health check passes, False otherwise.
|
||||
"""
|
||||
if not model_list:
|
||||
return [], []
|
||||
|
||||
if model is not None:
|
||||
model_list = [x for x in model_list if x["litellm_params"]["model"] == model]
|
||||
|
||||
healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list)
|
||||
|
||||
return healthy_endpoints, unhealthy_endpoints
|
||||
|
||||
|
1
litellm/proxy/hooks/__init__.py
Normal file
1
litellm/proxy/hooks/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from . import *
|
70
litellm/proxy/hooks/parallel_request_limiter.py
Normal file
70
litellm/proxy/hooks/parallel_request_limiter.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
from typing import Optional
|
||||
import litellm
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from fastapi import HTTPException
|
||||
|
||||
class MaxParallelRequestsHandler(CustomLogger):
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def print_verbose(self, print_statement):
|
||||
if litellm.set_verbose is True:
|
||||
print(print_statement) # noqa
|
||||
|
||||
|
||||
async def max_parallel_request_allow_request(self, max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
|
||||
if api_key is None:
|
||||
return
|
||||
|
||||
if max_parallel_requests is None:
|
||||
return
|
||||
|
||||
self.user_api_key_cache = user_api_key_cache # save the api key cache for updating the value
|
||||
|
||||
# CHECK IF REQUEST ALLOWED
|
||||
request_count_api_key = f"{api_key}_request_count"
|
||||
current = user_api_key_cache.get_cache(key=request_count_api_key)
|
||||
self.print_verbose(f"current: {current}")
|
||||
if current is None:
|
||||
user_api_key_cache.set_cache(request_count_api_key, 1)
|
||||
elif int(current) < max_parallel_requests:
|
||||
# Increase count for this token
|
||||
user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
|
||||
else:
|
||||
raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
|
||||
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
|
||||
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
||||
if user_api_key is None:
|
||||
return
|
||||
|
||||
request_count_api_key = f"{user_api_key}_request_count"
|
||||
# check if it has collected an entire stream response
|
||||
self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}")
|
||||
if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
|
||||
# Decrease count for this token
|
||||
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
|
||||
new_val = current - 1
|
||||
self.print_verbose(f"updated_value in success call: {new_val}")
|
||||
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
|
||||
except Exception as e:
|
||||
self.print_verbose(e) # noqa
|
||||
|
||||
async def async_log_failure_call(self, api_key, user_api_key_cache):
|
||||
try:
|
||||
if api_key is None:
|
||||
return
|
||||
|
||||
request_count_api_key = f"{api_key}_request_count"
|
||||
# Decrease count for this token
|
||||
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
|
||||
new_val = current - 1
|
||||
self.print_verbose(f"updated_value in failure call: {new_val}")
|
||||
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
|
||||
except Exception as e:
|
||||
self.print_verbose(f"An exception occurred - {str(e)}") # noqa
|
|
@ -26,7 +26,7 @@ def run_ollama_serve():
|
|||
except Exception as e:
|
||||
print(f"""
|
||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||
""")
|
||||
""") # noqa
|
||||
|
||||
def clone_subfolder(repo_url, subfolder, destination):
|
||||
# Clone the full repo
|
||||
|
@ -73,8 +73,7 @@ def is_port_in_use(port):
|
|||
@click.option('--request_timeout', default=600, type=int, help='Set timeout in seconds for completion calls')
|
||||
@click.option('--drop_params', is_flag=True, help='Drop any unmapped params')
|
||||
@click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt')
|
||||
@click.option('--config', '-c', default=None, help='Configure Litellm')
|
||||
@click.option('--file', '-f', help='Path to config file')
|
||||
@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`')
|
||||
@click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`')
|
||||
@click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
|
||||
@click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.')
|
||||
|
@ -83,7 +82,7 @@ def is_port_in_use(port):
|
|||
@click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response')
|
||||
@click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with')
|
||||
@click.option('--local', is_flag=True, default=False, help='for local debugging')
|
||||
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, file, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health):
|
||||
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health):
|
||||
global feature_telemetry
|
||||
args = locals()
|
||||
if local:
|
||||
|
@ -110,11 +109,11 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
|
|||
# get n recent logs
|
||||
recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]}
|
||||
|
||||
print(json.dumps(recent_logs, indent=4))
|
||||
print(json.dumps(recent_logs, indent=4)) # noqa
|
||||
except:
|
||||
print("LiteLLM: No logs saved!")
|
||||
raise Exception("LiteLLM: No logs saved!")
|
||||
return
|
||||
if model and "ollama" in model:
|
||||
if model and "ollama" in model and api_base is None:
|
||||
run_ollama_serve()
|
||||
if test_async is True:
|
||||
import requests, concurrent, time
|
||||
|
@ -141,7 +140,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
|
|||
if status == "finished":
|
||||
llm_response = polling_response["result"]
|
||||
break
|
||||
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")
|
||||
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa
|
||||
time.sleep(0.5)
|
||||
except Exception as e:
|
||||
print("got exception in polling", e)
|
||||
|
|
|
@ -1,8 +1,54 @@
|
|||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
- model_name: Azure OpenAI GPT-4 Canada-East (External)
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
model_info:
|
||||
mode: chat
|
||||
input_cost_per_token: 0.0.00006
|
||||
output_cost_per_token: 0.00003
|
||||
max_tokens: 4096
|
||||
base_model: gpt-3.5-turbo
|
||||
- model_name: BEDROCK_GROUP
|
||||
litellm_params:
|
||||
model: bedrock/cohere.command-text-v14
|
||||
- model_name: Azure OpenAI GPT-4 Canada-East (External)
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
mode: chat
|
||||
- model_name: azure-embedding-model
|
||||
litellm_params:
|
||||
model: azure/azure-embedding-model
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
model_info:
|
||||
mode: embedding
|
||||
base_model: text-embedding-ada-002
|
||||
- model_name: text-embedding-ada-002
|
||||
litellm_params:
|
||||
model: text-embedding-ada-002
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
mode: embedding
|
||||
- model_name: text-davinci-003
|
||||
litellm_params:
|
||||
model: text-davinci-003
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
mode: completion
|
||||
|
||||
litellm_settings:
|
||||
# cache: True
|
||||
# setting callback class
|
||||
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
|
||||
|
||||
general_settings:
|
||||
|
||||
environment_variables:
|
||||
# otel: True # OpenTelemetry Logger
|
||||
# master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -45,7 +45,7 @@ celery_app.conf.update(
|
|||
@celery_app.task(name='process_job', max_retries=3)
|
||||
def process_job(*args, **kwargs):
|
||||
try:
|
||||
llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list"))
|
||||
llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore
|
||||
response = llm_router.completion(*args, **kwargs) # type: ignore
|
||||
if isinstance(response, litellm.ModelResponse):
|
||||
response = response.model_dump_json()
|
||||
|
|
|
@ -16,4 +16,5 @@ model LiteLLM_VerificationToken {
|
|||
aliases Json @default("{}")
|
||||
config Json @default("{}")
|
||||
user_id String?
|
||||
max_parallel_requests Int?
|
||||
}
|
40
litellm/proxy/tests/test_langchain_request.py
Normal file
40
litellm/proxy/tests/test_langchain_request.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
## LOCAL TEST
|
||||
# from langchain.chat_models import ChatOpenAI
|
||||
# from langchain.prompts.chat import (
|
||||
# ChatPromptTemplate,
|
||||
# HumanMessagePromptTemplate,
|
||||
# SystemMessagePromptTemplate,
|
||||
# )
|
||||
# from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
# chat = ChatOpenAI(
|
||||
# openai_api_base="http://0.0.0.0:8000",
|
||||
# model = "gpt-3.5-turbo",
|
||||
# temperature=0.1
|
||||
# )
|
||||
|
||||
# messages = [
|
||||
# SystemMessage(
|
||||
# content="You are a helpful assistant that im using to make a test request to."
|
||||
# ),
|
||||
# HumanMessage(
|
||||
# content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||
# ),
|
||||
# ]
|
||||
# response = chat(messages)
|
||||
|
||||
# print(response)
|
||||
|
||||
# claude_chat = ChatOpenAI(
|
||||
# openai_api_base="http://0.0.0.0:8000",
|
||||
# model = "claude-v1",
|
||||
# temperature=0.1
|
||||
# )
|
||||
|
||||
# response = claude_chat(messages)
|
||||
|
||||
# print(response)
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,87 +1,342 @@
|
|||
from typing import Optional, List, Any
|
||||
import os, subprocess, hashlib
|
||||
from typing import Optional, List, Any, Literal
|
||||
import os, subprocess, hashlib, importlib, asyncio, copy
|
||||
import litellm, backoff
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
||||
|
||||
def print_verbose(print_statement):
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
### LOGGING ###
|
||||
class ProxyLogging:
|
||||
"""
|
||||
Logging/Custom Handlers for proxy.
|
||||
|
||||
Implemented mainly to:
|
||||
- log successful/failed db read/writes
|
||||
- support the max parallel request integration
|
||||
"""
|
||||
|
||||
def __init__(self, user_api_key_cache: DualCache):
|
||||
## INITIALIZE LITELLM CALLBACKS ##
|
||||
self.call_details: dict = {}
|
||||
self.call_details["user_api_key_cache"] = user_api_key_cache
|
||||
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
|
||||
pass
|
||||
|
||||
def _init_litellm_callbacks(self):
|
||||
|
||||
litellm.callbacks.append(self.max_parallel_request_limiter)
|
||||
for callback in litellm.callbacks:
|
||||
if callback not in litellm.input_callback:
|
||||
litellm.input_callback.append(callback)
|
||||
if callback not in litellm.success_callback:
|
||||
litellm.success_callback.append(callback)
|
||||
if callback not in litellm.failure_callback:
|
||||
litellm.failure_callback.append(callback)
|
||||
if callback not in litellm._async_success_callback:
|
||||
litellm._async_success_callback.append(callback)
|
||||
if callback not in litellm._async_failure_callback:
|
||||
litellm._async_failure_callback.append(callback)
|
||||
|
||||
if (
|
||||
len(litellm.input_callback) > 0
|
||||
or len(litellm.success_callback) > 0
|
||||
or len(litellm.failure_callback) > 0
|
||||
):
|
||||
callback_list = list(
|
||||
set(
|
||||
litellm.input_callback
|
||||
+ litellm.success_callback
|
||||
+ litellm.failure_callback
|
||||
)
|
||||
)
|
||||
litellm.utils.set_callbacks(
|
||||
callback_list=callback_list
|
||||
)
|
||||
|
||||
async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
|
||||
"""
|
||||
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
|
||||
|
||||
Covers:
|
||||
1. /chat/completions
|
||||
2. /embeddings
|
||||
"""
|
||||
try:
|
||||
self.call_details["data"] = data
|
||||
self.call_details["call_type"] = call_type
|
||||
## check if max parallel requests set
|
||||
if user_api_key_dict.max_parallel_requests is not None:
|
||||
## if set, check if request allowed
|
||||
await self.max_parallel_request_limiter.max_parallel_request_allow_request(
|
||||
max_parallel_requests=user_api_key_dict.max_parallel_requests,
|
||||
api_key=user_api_key_dict.api_key,
|
||||
user_api_key_cache=self.call_details["user_api_key_cache"])
|
||||
|
||||
return data
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def success_handler(self, *args, **kwargs):
|
||||
"""
|
||||
Log successful db read/writes
|
||||
"""
|
||||
pass
|
||||
|
||||
async def failure_handler(self, original_exception):
|
||||
"""
|
||||
Log failed db read/writes
|
||||
|
||||
Currently only logs exceptions to sentry
|
||||
"""
|
||||
if litellm.utils.capture_exception:
|
||||
litellm.utils.capture_exception(error=original_exception)
|
||||
|
||||
async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
|
||||
"""
|
||||
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
|
||||
|
||||
Covers:
|
||||
1. /chat/completions
|
||||
2. /embeddings
|
||||
"""
|
||||
# check if max parallel requests set
|
||||
if user_api_key_dict is not None and user_api_key_dict.max_parallel_requests is not None:
|
||||
## decrement call count if call failed
|
||||
if (hasattr(original_exception, "status_code")
|
||||
and original_exception.status_code == 429
|
||||
and "Max parallel request limit reached" in str(original_exception)):
|
||||
pass # ignore failed calls due to max limit being reached
|
||||
else:
|
||||
await self.max_parallel_request_limiter.async_log_failure_call(
|
||||
api_key=user_api_key_dict.api_key,
|
||||
user_api_key_cache=self.call_details["user_api_key_cache"])
|
||||
return
|
||||
|
||||
|
||||
### DB CONNECTOR ###
|
||||
# Define the retry decorator with backoff strategy
|
||||
# Function to be called whenever a retry is about to happen
|
||||
def on_backoff(details):
|
||||
# The 'tries' key in the details dictionary contains the number of completed tries
|
||||
print_verbose(f"Backing off... this was attempt #{details['tries']}")
|
||||
|
||||
class PrismaClient:
|
||||
def __init__(self, database_url: str):
|
||||
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||
print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||
## init logging object
|
||||
self.proxy_logging_obj = proxy_logging_obj
|
||||
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
subprocess.run(['prisma', 'generate'])
|
||||
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
|
||||
# Save the current working directory
|
||||
original_dir = os.getcwd()
|
||||
# set the working directory to where this script is
|
||||
abspath = os.path.abspath(__file__)
|
||||
dname = os.path.dirname(abspath)
|
||||
os.chdir(dname)
|
||||
|
||||
try:
|
||||
subprocess.run(['prisma', 'generate'])
|
||||
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
# Now you can import the Prisma Client
|
||||
from prisma import Client
|
||||
from prisma import Client # type: ignore
|
||||
self.db = Client() #Client to connect to Prisma db
|
||||
|
||||
|
||||
|
||||
def hash_token(self, token: str):
|
||||
# Hash the string using SHA-256
|
||||
hashed_token = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
return hashed_token
|
||||
|
||||
async def get_data(self, token: str, expires: Optional[Any]=None):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
if expires:
|
||||
response = await self.db.litellm_verificationtoken.find_first(
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def get_data(self, token: str, expires: Optional[Any]=None):
|
||||
try:
|
||||
# check if plain text or hash
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
if expires:
|
||||
response = await self.db.litellm_verificationtoken.find_first(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"expires": {"gte": expires} # Check if the token is not expired
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"expires": {"gte": expires} # Check if the token is not expired
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
return response
|
||||
return response
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
raise e
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def insert_data(self, data: dict):
|
||||
"""
|
||||
Add a key to the database. If it already exists, do nothing.
|
||||
"""
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
data["token"] = hashed_token
|
||||
print(f"passed in data: {data}; hashed_token: {hashed_token}")
|
||||
try:
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
db_data = copy.deepcopy(data)
|
||||
db_data["token"] = hashed_token
|
||||
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
'token': hashed_token,
|
||||
},
|
||||
data={
|
||||
"create": {**data}, #type: ignore
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
|
||||
return new_verification_token
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
'token': hashed_token,
|
||||
},
|
||||
data={
|
||||
"create": {**db_data}, #type: ignore
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
return new_verification_token
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
raise e
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def update_data(self, token: str, data: dict):
|
||||
"""
|
||||
Update existing data
|
||||
"""
|
||||
hashed_token = self.hash_token(token=token)
|
||||
data["token"] = hashed_token
|
||||
await self.db.litellm_verificationtoken.update(
|
||||
where={
|
||||
"token": hashed_token
|
||||
},
|
||||
data={**data} # type: ignore
|
||||
)
|
||||
return {"token": token, "data": data}
|
||||
try:
|
||||
print_verbose(f"token: {token}")
|
||||
# check if plain text or hash
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
|
||||
db_data = copy.deepcopy(data)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={
|
||||
"token": token
|
||||
},
|
||||
data={**db_data} # type: ignore
|
||||
)
|
||||
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
|
||||
return {"token": token, "data": db_data}
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
|
||||
raise e
|
||||
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def delete_data(self, tokens: List):
|
||||
"""
|
||||
Allow user to delete a key(s)
|
||||
"""
|
||||
hashed_tokens = [self.hash_token(token=token) for token in tokens]
|
||||
await self.db.litellm_verificationtoken.delete_many(
|
||||
where={"token": {"in": hashed_tokens}}
|
||||
)
|
||||
return {"deleted_keys": tokens}
|
||||
try:
|
||||
hashed_tokens = [self.hash_token(token=token) for token in tokens]
|
||||
await self.db.litellm_verificationtoken.delete_many(
|
||||
where={"token": {"in": hashed_tokens}}
|
||||
)
|
||||
return {"deleted_keys": tokens}
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
raise e
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def connect(self):
|
||||
await self.db.connect()
|
||||
try:
|
||||
await self.db.connect()
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
raise e
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def disconnect(self):
|
||||
await self.db.disconnect()
|
||||
try:
|
||||
await self.db.disconnect()
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
raise e
|
||||
|
||||
### CUSTOM FILE ###
|
||||
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
||||
try:
|
||||
print_verbose(f"value: {value}")
|
||||
# Split the path by dots to separate module from instance
|
||||
parts = value.split(".")
|
||||
|
||||
# The module path is all but the last part, and the instance_name is the last part
|
||||
module_name = ".".join(parts[:-1])
|
||||
instance_name = parts[-1]
|
||||
|
||||
# If config_file_path is provided, use it to determine the module spec and load the module
|
||||
if config_file_path is not None:
|
||||
directory = os.path.dirname(config_file_path)
|
||||
module_file_path = os.path.join(directory, *module_name.split('.'))
|
||||
module_file_path += '.py'
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_file_path)
|
||||
if spec is None:
|
||||
raise ImportError(f"Could not find a module specification for {module_file_path}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
else:
|
||||
# Dynamically import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Get the instance from the module
|
||||
instance = getattr(module, instance_name)
|
||||
|
||||
return instance
|
||||
except ImportError as e:
|
||||
# Re-raise the exception with a user-friendly message
|
||||
raise ImportError(f"Could not import {instance_name} from {module_name}") from e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
|
@ -7,16 +7,18 @@
|
|||
#
|
||||
# Thank you ! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union, Literal
|
||||
import random, threading, time, traceback
|
||||
from typing import Dict, List, Optional, Union, Literal, Any
|
||||
import random, threading, time, traceback, uuid
|
||||
import litellm, openai
|
||||
from litellm.caching import RedisCache, InMemoryCache, DualCache
|
||||
import logging, asyncio
|
||||
import inspect, concurrent
|
||||
from openai import AsyncOpenAI
|
||||
from collections import defaultdict
|
||||
|
||||
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
||||
import copy
|
||||
class Router:
|
||||
"""
|
||||
Example usage:
|
||||
|
@ -53,17 +55,22 @@ class Router:
|
|||
```
|
||||
"""
|
||||
model_names: List = []
|
||||
cache_responses: bool = False
|
||||
cache_responses: Optional[bool] = False
|
||||
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
|
||||
num_retries: int = 0
|
||||
tenacity = None
|
||||
leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
|
||||
|
||||
def __init__(self,
|
||||
model_list: Optional[list] = None,
|
||||
## CACHING ##
|
||||
redis_url: Optional[str] = None,
|
||||
redis_host: Optional[str] = None,
|
||||
redis_port: Optional[int] = None,
|
||||
redis_password: Optional[str] = None,
|
||||
cache_responses: bool = False,
|
||||
cache_responses: Optional[bool] = False,
|
||||
cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
|
||||
## RELIABILITY ##
|
||||
num_retries: int = 0,
|
||||
timeout: Optional[float] = None,
|
||||
default_litellm_params = {}, # default params for Router.chat.completion.create
|
||||
|
@ -74,7 +81,9 @@ class Router:
|
|||
routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
|
||||
|
||||
self.set_verbose = set_verbose
|
||||
self.deployment_names: List = [] # names of models under litellm_params. ex. azure/chatgpt-v-2
|
||||
if model_list:
|
||||
model_list = copy.deepcopy(model_list)
|
||||
self.set_model_list(model_list)
|
||||
self.healthy_deployments: List = self.model_list
|
||||
self.deployment_latency_map = {}
|
||||
|
@ -92,8 +101,8 @@ class Router:
|
|||
self.total_calls: defaultdict = defaultdict(int) # dict to store total calls made to each model
|
||||
self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model
|
||||
self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model
|
||||
|
||||
|
||||
self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call)
|
||||
|
||||
# make Router.chat.completions.create compatible for openai.chat.completions.create
|
||||
self.chat = litellm.Chat(params=default_litellm_params)
|
||||
|
||||
|
@ -102,28 +111,44 @@ class Router:
|
|||
self.default_litellm_params.setdefault("timeout", timeout)
|
||||
self.default_litellm_params.setdefault("max_retries", 0)
|
||||
|
||||
|
||||
### HEALTH CHECK THREAD ###
|
||||
if self.routing_strategy == "least-busy":
|
||||
self._start_health_check_thread()
|
||||
### CACHING ###
|
||||
cache_type = "local" # default to an in-memory cache
|
||||
redis_cache = None
|
||||
if redis_host is not None and redis_port is not None and redis_password is not None:
|
||||
cache_config = {
|
||||
'type': 'redis',
|
||||
'host': redis_host,
|
||||
'port': redis_port,
|
||||
'password': redis_password
|
||||
}
|
||||
redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password)
|
||||
else: # use an in-memory cache
|
||||
cache_config = {
|
||||
"type": "local"
|
||||
}
|
||||
cache_config = {}
|
||||
if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None):
|
||||
cache_type = "redis"
|
||||
|
||||
if redis_url is not None:
|
||||
cache_config['url'] = redis_url
|
||||
|
||||
if redis_host is not None:
|
||||
cache_config['host'] = redis_host
|
||||
|
||||
if redis_port is not None:
|
||||
cache_config['port'] = str(redis_port) # type: ignore
|
||||
|
||||
if redis_password is not None:
|
||||
cache_config['password'] = redis_password
|
||||
|
||||
# Add additional key-value pairs from cache_kwargs
|
||||
cache_config.update(cache_kwargs)
|
||||
redis_cache = RedisCache(**cache_config)
|
||||
if cache_responses:
|
||||
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
|
||||
if litellm.cache is None:
|
||||
# the cache can be initialized on the proxy server. We should not overwrite it
|
||||
litellm.cache = litellm.Cache(type=cache_type, **cache_config)
|
||||
self.cache_responses = cache_responses
|
||||
self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
|
||||
### ROUTING SETUP ###
|
||||
if routing_strategy == "least-busy":
|
||||
self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
|
||||
## add callback
|
||||
if isinstance(litellm.input_callback, list):
|
||||
litellm.input_callback.append(self.leastbusy_logger) # type: ignore
|
||||
else:
|
||||
litellm.input_callback = [self.leastbusy_logger] # type: ignore
|
||||
if isinstance(litellm.callbacks, list):
|
||||
litellm.callbacks.append(self.leastbusy_logger) # type: ignore
|
||||
## USAGE TRACKING ##
|
||||
if isinstance(litellm.success_callback, list):
|
||||
litellm.success_callback.append(self.deployment_callback)
|
||||
|
@ -171,9 +196,10 @@ class Router:
|
|||
|
||||
try:
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(model=model, messages=messages)
|
||||
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
|
||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||
data = deployment["litellm_params"].copy()
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
data[k] = v
|
||||
|
@ -188,7 +214,7 @@ class Router:
|
|||
data["model"] = original_model_string[:index_of_model_id]
|
||||
else:
|
||||
data["model"] = original_model_string
|
||||
model_client = deployment.get("client", None)
|
||||
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
|
||||
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
@ -219,8 +245,9 @@ class Router:
|
|||
try:
|
||||
self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
|
||||
original_model_string = None # set a default for this variable
|
||||
deployment = self.get_available_deployment(model=model, messages=messages)
|
||||
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
|
||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
data = deployment["litellm_params"].copy()
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
|
@ -234,7 +261,7 @@ class Router:
|
|||
data["model"] = original_model_string[:index_of_model_id]
|
||||
else:
|
||||
data["model"] = original_model_string
|
||||
model_client = deployment.get("async_client", None)
|
||||
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
|
||||
self.total_calls[original_model_string] +=1
|
||||
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
|
||||
self.success_calls[original_model_string] +=1
|
||||
|
@ -255,7 +282,7 @@ class Router:
|
|||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(model=model, messages=messages)
|
||||
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
|
||||
|
||||
data = deployment["litellm_params"].copy()
|
||||
for k, v in self.default_litellm_params.items():
|
||||
|
@ -288,8 +315,9 @@ class Router:
|
|||
is_async: Optional[bool] = False,
|
||||
**kwargs) -> Union[List[float], None]:
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(model=model, input=input)
|
||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||
deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
|
||||
kwargs.setdefault("model_info", {})
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks
|
||||
data = deployment["litellm_params"].copy()
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
|
@ -303,7 +331,7 @@ class Router:
|
|||
data["model"] = original_model_string[:index_of_model_id]
|
||||
else:
|
||||
data["model"] = original_model_string
|
||||
model_client = deployment.get("client", None)
|
||||
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
|
||||
# call via litellm.embedding()
|
||||
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
|
||||
|
||||
|
@ -313,9 +341,10 @@ class Router:
|
|||
is_async: Optional[bool] = True,
|
||||
**kwargs) -> Union[List[float], None]:
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(model=model, input=input)
|
||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||
deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]})
|
||||
data = deployment["litellm_params"].copy()
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
data[k] = v
|
||||
|
@ -328,7 +357,7 @@ class Router:
|
|||
data["model"] = original_model_string[:index_of_model_id]
|
||||
else:
|
||||
data["model"] = original_model_string
|
||||
model_client = deployment.get("async_client", None)
|
||||
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
|
||||
|
||||
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
|
||||
|
||||
|
@ -345,7 +374,7 @@ class Router:
|
|||
self.print_verbose(f'Async Response: {response}')
|
||||
return response
|
||||
except Exception as e:
|
||||
self.print_verbose(f"An exception occurs")
|
||||
self.print_verbose(f"An exception occurs: {e}")
|
||||
original_exception = e
|
||||
try:
|
||||
self.print_verbose(f"Trying to fallback b/w models")
|
||||
|
@ -380,6 +409,8 @@ class Router:
|
|||
Iterate through the model groups and try calling that deployment
|
||||
"""
|
||||
try:
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
kwargs["model"] = mg
|
||||
kwargs["metadata"]["model_group"] = mg
|
||||
response = await self.async_function_with_retries(*args, **kwargs)
|
||||
|
@ -423,6 +454,10 @@ class Router:
|
|||
else:
|
||||
raise original_exception
|
||||
|
||||
## LOGGING
|
||||
if num_retries > 0:
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
|
||||
for current_attempt in range(num_retries):
|
||||
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
|
||||
try:
|
||||
|
@ -433,6 +468,8 @@ class Router:
|
|||
return response
|
||||
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||
remaining_retries = num_retries - current_attempt
|
||||
if "No models available" in str(e):
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, min_timeout=1)
|
||||
|
@ -458,13 +495,12 @@ class Router:
|
|||
try:
|
||||
response = self.function_with_retries(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
original_exception = e
|
||||
self.print_verbose(f"An exception occurs {original_exception}")
|
||||
try:
|
||||
self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}")
|
||||
if isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None:
|
||||
self.print_verbose(f"inside context window fallbacks: {context_window_fallbacks}")
|
||||
fallback_model_group = None
|
||||
|
||||
for item in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
|
||||
|
@ -480,6 +516,8 @@ class Router:
|
|||
Iterate through the model groups and try calling that deployment
|
||||
"""
|
||||
try:
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
kwargs["model"] = mg
|
||||
response = self.function_with_fallbacks(*args, **kwargs)
|
||||
return response
|
||||
|
@ -501,11 +539,13 @@ class Router:
|
|||
Iterate through the model groups and try calling that deployment
|
||||
"""
|
||||
try:
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
kwargs["model"] = mg
|
||||
response = self.function_with_fallbacks(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
pass
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
raise original_exception
|
||||
|
@ -515,7 +555,6 @@ class Router:
|
|||
Try calling the model 3 times. Shuffle between available deployments.
|
||||
"""
|
||||
self.print_verbose(f"Inside function with retries: args - {args}; kwargs - {kwargs}")
|
||||
backoff_factor = 1
|
||||
original_function = kwargs.pop("original_function")
|
||||
num_retries = kwargs.pop("num_retries")
|
||||
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
|
||||
|
@ -531,6 +570,9 @@ class Router:
|
|||
if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
|
||||
or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
|
||||
raise original_exception
|
||||
## LOGGING
|
||||
if num_retries > 0:
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
### RETRY
|
||||
for current_attempt in range(num_retries):
|
||||
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}")
|
||||
|
@ -539,19 +581,19 @@ class Router:
|
|||
response = original_function(*args, **kwargs)
|
||||
return response
|
||||
|
||||
except openai.RateLimitError as e:
|
||||
if num_retries > 0:
|
||||
remaining_retries = num_retries - current_attempt
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
|
||||
# on RateLimitError we'll wait for an exponential time before trying again
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||
remaining_retries = num_retries - current_attempt
|
||||
if "No models available" in str(e):
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, min_timeout=1)
|
||||
time.sleep(timeout)
|
||||
elif hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
|
||||
if hasattr(e.response, "headers"):
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, response_headers=e.response.headers)
|
||||
else:
|
||||
timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries)
|
||||
time.sleep(timeout)
|
||||
else:
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
# for any other exception types, immediately retry
|
||||
if num_retries > 0:
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
raise original_exception
|
||||
|
@ -614,6 +656,27 @@ class Router:
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def log_retry(self, kwargs: dict, e: Exception) -> dict:
|
||||
"""
|
||||
When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing
|
||||
"""
|
||||
try:
|
||||
# Log failed model as the previous model
|
||||
previous_model = {"exception_type": type(e).__name__, "exception_string": str(e)}
|
||||
for k, v in kwargs.items(): # log everything in kwargs except the old previous_models value - prevent nesting
|
||||
if k != "metadata":
|
||||
previous_model[k] = v
|
||||
elif k == "metadata" and isinstance(v, dict):
|
||||
previous_model["metadata"] = {} # type: ignore
|
||||
for metadata_k, metadata_v in kwargs['metadata'].items():
|
||||
if metadata_k != "previous_models":
|
||||
previous_model[k][metadata_k] = metadata_v # type: ignore
|
||||
self.previous_models.append(previous_model)
|
||||
kwargs["metadata"]["previous_models"] = self.previous_models
|
||||
return kwargs
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _set_cooldown_deployments(self,
|
||||
deployment: str):
|
||||
"""
|
||||
|
@ -817,12 +880,16 @@ class Router:
|
|||
return chosen_item
|
||||
|
||||
def set_model_list(self, model_list: list):
|
||||
self.model_list = model_list
|
||||
self.model_list = copy.deepcopy(model_list)
|
||||
# we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
|
||||
import os
|
||||
for model in self.model_list:
|
||||
litellm_params = model.get("litellm_params", {})
|
||||
model_name = litellm_params.get("model")
|
||||
#### MODEL ID INIT ########
|
||||
model_info = model.get("model_info", {})
|
||||
model_info["id"] = model_info.get("id", str(uuid.uuid4()))
|
||||
model["model_info"] = model_info
|
||||
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
if custom_llm_provider is None:
|
||||
|
@ -845,6 +912,7 @@ class Router:
|
|||
if api_key and api_key.startswith("os.environ/"):
|
||||
api_key_env_name = api_key.replace("os.environ/", "")
|
||||
api_key = litellm.get_secret(api_key_env_name)
|
||||
litellm_params["api_key"] = api_key
|
||||
|
||||
api_base = litellm_params.get("api_base")
|
||||
base_url = litellm_params.get("base_url")
|
||||
|
@ -852,13 +920,35 @@ class Router:
|
|||
if api_base and api_base.startswith("os.environ/"):
|
||||
api_base_env_name = api_base.replace("os.environ/", "")
|
||||
api_base = litellm.get_secret(api_base_env_name)
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
api_version = litellm_params.get("api_version")
|
||||
if api_version and api_version.startswith("os.environ/"):
|
||||
api_version_env_name = api_version.replace("os.environ/", "")
|
||||
api_version = litellm.get_secret(api_version_env_name)
|
||||
self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
|
||||
litellm_params["api_version"] = api_version
|
||||
|
||||
timeout = litellm_params.pop("timeout", None)
|
||||
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
|
||||
timeout_env_name = timeout.replace("os.environ/", "")
|
||||
timeout = litellm.get_secret(timeout_env_name)
|
||||
litellm_params["timeout"] = timeout
|
||||
|
||||
stream_timeout = litellm_params.pop("stream_timeout", timeout) # if no stream_timeout is set, default to timeout
|
||||
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
|
||||
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
|
||||
stream_timeout = litellm.get_secret(stream_timeout_env_name)
|
||||
litellm_params["stream_timeout"] = stream_timeout
|
||||
|
||||
max_retries = litellm_params.pop("max_retries", 2)
|
||||
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
|
||||
max_retries_env_name = max_retries.replace("os.environ/", "")
|
||||
max_retries = litellm.get_secret(max_retries_env_name)
|
||||
litellm_params["max_retries"] = max_retries
|
||||
|
||||
if "azure" in model_name:
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required for Azure OpenAI. Set it on your config")
|
||||
if api_version is None:
|
||||
api_version = "2023-07-01-preview"
|
||||
if "gateway.ai.cloudflare.com" in api_base:
|
||||
|
@ -869,40 +959,107 @@ class Router:
|
|||
model["async_client"] = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
model["client"] = openai.AzureOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
# streaming clients can have diff timeouts
|
||||
model["stream_async_client"] = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
model["stream_client"] = openai.AzureOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
else:
|
||||
self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}")
|
||||
model["async_client"] = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
model["client"] = openai.AzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
# streaming clients should have diff timeouts
|
||||
model["stream_async_client"] = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
model["stream_client"] = openai.AzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
else:
|
||||
self.print_verbose(f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}")
|
||||
model["async_client"] = openai.AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
model["client"] = openai.OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
# streaming clients should have diff timeouts
|
||||
model["stream_async_client"] = openai.AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
# streaming clients should have diff timeouts
|
||||
model["stream_client"] = openai.OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
############ End of initializing Clients for OpenAI/Azure ###################
|
||||
self.deployment_names.append(model["litellm_params"]["model"])
|
||||
model_id = ""
|
||||
for key in model["litellm_params"]:
|
||||
if key != "api_key":
|
||||
if key != "api_key" and key != "metadata":
|
||||
model_id+= str(model["litellm_params"][key])
|
||||
model["litellm_params"]["model"] += "-ModelID-" + model_id
|
||||
|
||||
self.print_verbose(f"\n Initialized Model List {self.model_list}")
|
||||
|
||||
############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
|
||||
# for get_available_deployment, we use the litellm_param["rpm"]
|
||||
# in this snippet we also set rpm to be a litellm_param
|
||||
|
@ -916,17 +1073,63 @@ class Router:
|
|||
def get_model_names(self):
|
||||
return self.model_names
|
||||
|
||||
def _get_client(self, deployment, kwargs, client_type=None):
|
||||
"""
|
||||
Returns the appropriate client based on the given deployment, kwargs, and client_type.
|
||||
|
||||
Parameters:
|
||||
deployment (dict): The deployment dictionary containing the clients.
|
||||
kwargs (dict): The keyword arguments passed to the function.
|
||||
client_type (str): The type of client to return.
|
||||
|
||||
Returns:
|
||||
The appropriate client based on the given client_type and kwargs.
|
||||
"""
|
||||
if client_type == "async":
|
||||
if kwargs.get("stream") == True:
|
||||
return deployment.get("stream_async_client", None)
|
||||
else:
|
||||
return deployment.get("async_client", None)
|
||||
else:
|
||||
if kwargs.get("stream") == True:
|
||||
return deployment.get("stream_client", None)
|
||||
else:
|
||||
return deployment.get("client", None)
|
||||
|
||||
def print_verbose(self, print_statement):
|
||||
if self.set_verbose or litellm.set_verbose:
|
||||
print(f"LiteLLM.Router: {print_statement}") # noqa
|
||||
try:
|
||||
if self.set_verbose or litellm.set_verbose:
|
||||
print(f"LiteLLM.Router: {print_statement}") # noqa
|
||||
except:
|
||||
pass
|
||||
|
||||
def get_available_deployment(self,
|
||||
model: str,
|
||||
messages: Optional[List[Dict[str, str]]] = None,
|
||||
input: Optional[Union[str, List]] = None):
|
||||
input: Optional[Union[str, List]] = None,
|
||||
specific_deployment: Optional[bool] = False
|
||||
):
|
||||
"""
|
||||
Returns the deployment based on routing strategy
|
||||
"""
|
||||
|
||||
# users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
|
||||
# When this was no explicit we had several issues with fallbacks timing out
|
||||
if specific_deployment == True:
|
||||
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
|
||||
for deployment in self.model_list:
|
||||
cleaned_model = litellm.utils.remove_model_id(deployment.get("litellm_params").get("model"))
|
||||
if cleaned_model == model:
|
||||
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
|
||||
# return the first deployment where the `model` matches the specificed deployment name
|
||||
return deployment
|
||||
raise ValueError(f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}")
|
||||
|
||||
# check if aliases set on litellm model alias map
|
||||
if model in litellm.model_group_alias_map:
|
||||
self.print_verbose(f"Using a model alias. Got Request for {model}, sending requests to {litellm.model_group_alias_map.get(model)}")
|
||||
model = litellm.model_group_alias_map[model]
|
||||
|
||||
## get healthy deployments
|
||||
### get all deployments
|
||||
### filter out the deployments currently cooling down
|
||||
|
@ -934,6 +1137,7 @@ class Router:
|
|||
if len(healthy_deployments) == 0:
|
||||
# check if the user sent in a deployment name instead
|
||||
healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
|
||||
|
||||
self.print_verbose(f"initial list of deployments: {healthy_deployments}")
|
||||
deployments_to_remove = []
|
||||
cooldown_deployments = self._get_cooldown_deployments()
|
||||
|
@ -953,13 +1157,24 @@ class Router:
|
|||
model = litellm.model_alias_map[
|
||||
model
|
||||
] # update the model to the actual value if an alias has been passed in
|
||||
if self.routing_strategy == "least-busy":
|
||||
if len(self.healthy_deployments) > 0:
|
||||
for item in self.healthy_deployments:
|
||||
if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
|
||||
return item[0]
|
||||
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
|
||||
deployments = self.leastbusy_logger.get_available_deployments(model_group=model)
|
||||
# pick least busy deployment
|
||||
min_traffic = float('inf')
|
||||
min_deployment = None
|
||||
for k, v in deployments.items():
|
||||
if v < min_traffic:
|
||||
min_deployment = k
|
||||
############## No Available Deployments passed, we do a random pick #################
|
||||
if min_deployment is None:
|
||||
min_deployment = random.choice(healthy_deployments)
|
||||
############## Available Deployments passed, we find the relevant item #################
|
||||
else:
|
||||
raise ValueError("No models available.")
|
||||
for m in healthy_deployments:
|
||||
if m["model_info"]["id"] == min_deployment:
|
||||
return m
|
||||
min_deployment = random.choice(healthy_deployments)
|
||||
return min_deployment
|
||||
elif self.routing_strategy == "simple-shuffle":
|
||||
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
|
||||
############## Check if we can do a RPM/TPM based weighted pick #################
|
||||
|
@ -1010,11 +1225,14 @@ class Router:
|
|||
raise ValueError("No models available.")
|
||||
|
||||
def flush_cache(self):
|
||||
litellm.cache = None
|
||||
self.cache.flush_cache()
|
||||
|
||||
def reset(self):
|
||||
## clean up on close
|
||||
litellm.success_callback = []
|
||||
litellm.__async_success_callback = []
|
||||
litellm.failure_callback = []
|
||||
litellm._async_failure_callback = []
|
||||
self.flush_cache()
|
||||
|
96
litellm/router_strategy/least_busy.py
Normal file
96
litellm/router_strategy/least_busy.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
#### What this does ####
|
||||
# identifies least busy deployment
|
||||
# How is this achieved?
|
||||
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
|
||||
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
|
||||
# - use litellm.success + failure callbacks to log when a request completed
|
||||
# - in get_available_deployment, for a given model group name -> pick based on traffic
|
||||
|
||||
import dotenv, os, requests
|
||||
from typing import Optional
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
class LeastBusyLoggingHandler(CustomLogger):
|
||||
|
||||
def __init__(self, router_cache: DualCache):
|
||||
self.router_cache = router_cache
|
||||
self.mapping_deployment_to_id: dict = {}
|
||||
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
"""
|
||||
Log when a model is being used.
|
||||
|
||||
Caching based on model group.
|
||||
"""
|
||||
try:
|
||||
|
||||
if kwargs['litellm_params'].get('metadata') is None:
|
||||
pass
|
||||
else:
|
||||
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
|
||||
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
|
||||
id = kwargs['litellm_params'].get('model_info', {}).get('id', None)
|
||||
if deployment is None or model_group is None or id is None:
|
||||
return
|
||||
|
||||
# map deployment to id
|
||||
self.mapping_deployment_to_id[deployment] = id
|
||||
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
# update cache
|
||||
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
request_count_dict[deployment] = request_count_dict.get(deployment, 0) + 1
|
||||
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
if kwargs['litellm_params'].get('metadata') is None:
|
||||
pass
|
||||
else:
|
||||
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
|
||||
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
|
||||
if deployment is None or model_group is None:
|
||||
return
|
||||
|
||||
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
# decrement count in cache
|
||||
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
request_count_dict[deployment] = request_count_dict.get(deployment)
|
||||
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
if kwargs['litellm_params'].get('metadata') is None:
|
||||
pass
|
||||
else:
|
||||
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
|
||||
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
|
||||
if deployment is None or model_group is None:
|
||||
return
|
||||
|
||||
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
# decrement count in cache
|
||||
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
request_count_dict[deployment] = request_count_dict.get(deployment)
|
||||
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def get_available_deployments(self, model_group: str):
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
# map deployment to id
|
||||
return_dict = {}
|
||||
for key, value in request_count_dict.items():
|
||||
return_dict[self.mapping_deployment_to_id[key]] = value
|
||||
return return_dict
|
34
litellm/tests/conftest.py
Normal file
34
litellm/tests/conftest.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
# conftest.py
|
||||
|
||||
import pytest, sys, os
|
||||
import importlib
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def setup_and_teardown():
|
||||
"""
|
||||
This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
|
||||
"""
|
||||
curr_dir = os.getcwd() # Get the current working directory
|
||||
sys.path.insert(0, os.path.abspath("../..")) # Adds the project directory to the system path
|
||||
import litellm
|
||||
importlib.reload(litellm)
|
||||
print(litellm)
|
||||
# from litellm import Router, completion, aembedding, acompletion, embedding
|
||||
yield
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
|
||||
custom_logger_tests = [item for item in items if 'custom_logger' in item.parent.name]
|
||||
other_tests = [item for item in items if 'custom_logger' not in item.parent.name]
|
||||
|
||||
# Sort tests based on their names
|
||||
custom_logger_tests.sort(key=lambda x: x.name)
|
||||
other_tests.sort(key=lambda x: x.name)
|
||||
|
||||
# Reorder the items list
|
||||
items[:] = custom_logger_tests + other_tests
|
30
litellm/tests/example_config_yaml/aliases_config.yaml
Normal file
30
litellm/tests/example_config_yaml/aliases_config.yaml
Normal file
|
@ -0,0 +1,30 @@
|
|||
model_list:
|
||||
- model_name: text-davinci-003
|
||||
litellm_params:
|
||||
model: ollama/zephyr
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: ollama/llama2
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: ollama/llama2
|
||||
temperature: 0.1
|
||||
max_tokens: 20
|
||||
|
||||
|
||||
# request to gpt-4, response from ollama/llama2
|
||||
# curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||
# --header 'Content-Type: application/json' \
|
||||
# --data ' {
|
||||
# "model": "gpt-4",
|
||||
# "messages": [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "what llm are you"
|
||||
# }
|
||||
# ],
|
||||
# }
|
||||
# '
|
||||
#
|
||||
|
||||
# {"id":"chatcmpl-27c85cf0-ab09-4bcf-8cb1-0ee950520743","choices":[{"finish_reason":"stop","index":0,"message":{"content":" Hello! I'm just an AI, I don't have personal experiences or emotions like humans do. However, I can help you with any questions or tasks you may have! Is there something specific you'd like to know or discuss?","role":"assistant","_logprobs":null}}],"created":1700094955.373751,"model":"ollama/llama2","object":"chat.completion","system_fingerprint":null,"usage":{"prompt_tokens":12,"completion_tokens":47,"total_tokens":59},"_response_ms":8028.017999999999}%
|
15
litellm/tests/example_config_yaml/azure_config.yaml
Normal file
15
litellm/tests/example_config_yaml/azure_config.yaml
Normal file
|
@ -0,0 +1,15 @@
|
|||
model_list:
|
||||
- model_name: gpt-4-team1
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_version: "2023-05-15"
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
tpm: 20_000
|
||||
- model_name: gpt-4-team2
|
||||
litellm_params:
|
||||
model: azure/gpt-4
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
|
||||
tpm: 100_000
|
||||
|
7
litellm/tests/example_config_yaml/langfuse_config.yaml
Normal file
7
litellm/tests/example_config_yaml/langfuse_config.yaml
Normal file
|
@ -0,0 +1,7 @@
|
|||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
|
||||
litellm_settings:
|
||||
drop_params: True
|
||||
success_callback: ["langfuse"] # https://docs.litellm.ai/docs/observability/langfuse_integration
|
||||
|
28
litellm/tests/example_config_yaml/load_balancer.yaml
Normal file
28
litellm/tests/example_config_yaml/load_balancer.yaml
Normal file
|
@ -0,0 +1,28 @@
|
|||
litellm_settings:
|
||||
drop_params: True
|
||||
|
||||
# Model-specific settings
|
||||
model_list: # use the same model_name for using the litellm router. LiteLLM will use the router between gpt-3.5-turbo
|
||||
- model_name: gpt-3.5-turbo # litellm will
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key: sk-uj6F
|
||||
tpm: 20000 # [OPTIONAL] REPLACE with your openai tpm
|
||||
rpm: 3 # [OPTIONAL] REPLACE with your openai rpm
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key: sk-Imn
|
||||
tpm: 20000 # [OPTIONAL] REPLACE with your openai tpm
|
||||
rpm: 3 # [OPTIONAL] REPLACE with your openai rpm
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: openrouter/gpt-3.5-turbo
|
||||
- model_name: mistral-7b-instruct
|
||||
litellm_params:
|
||||
model: mistralai/mistral-7b-instruct
|
||||
|
||||
environment_variables:
|
||||
REDIS_HOST: localhost
|
||||
REDIS_PASSWORD:
|
||||
REDIS_PORT:
|
|
@ -0,0 +1,7 @@
|
|||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
|
||||
general_settings:
|
||||
otel: True # OpenTelemetry Logger this logs OTEL data to your collector
|
4
litellm/tests/example_config_yaml/simple_config.yaml
Normal file
4
litellm/tests/example_config_yaml/simple_config.yaml
Normal file
|
@ -0,0 +1,4 @@
|
|||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
0
litellm/tests/langfuse.log
Normal file
0
litellm/tests/langfuse.log
Normal file
117
litellm/tests/test_amazing_vertex_completion.py
Normal file
117
litellm/tests/test_amazing_vertex_completion.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
|
||||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os, io
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
litellm.num_retries = 3
|
||||
litellm.cache = None
|
||||
user_message = "Write a short poem about the sky"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
|
||||
|
||||
def load_vertex_ai_credentials():
|
||||
# Define the path to the vertex_key.json file
|
||||
print("loading vertex ai credentials")
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
vertex_key_path = filepath + '/vertex_key.json'
|
||||
|
||||
# Read the existing content of the file or create an empty dictionary
|
||||
try:
|
||||
with open(vertex_key_path, 'r') as file:
|
||||
# Read the file content
|
||||
print("Read vertexai file path")
|
||||
content = file.read()
|
||||
|
||||
# If the file is empty or not valid JSON, create an empty dictionary
|
||||
if not content or not content.strip():
|
||||
service_account_key_data = {}
|
||||
else:
|
||||
# Attempt to load the existing JSON content
|
||||
file.seek(0)
|
||||
service_account_key_data = json.load(file)
|
||||
except FileNotFoundError:
|
||||
# If the file doesn't exist, create an empty dictionary
|
||||
service_account_key_data = {}
|
||||
|
||||
# Update the service_account_key_data with environment variables
|
||||
private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
|
||||
private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
|
||||
private_key = private_key.replace("\\n", "\n")
|
||||
service_account_key_data["private_key_id"] = private_key_id
|
||||
service_account_key_data["private_key"] = private_key
|
||||
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
|
||||
# Write the updated content to the temporary file
|
||||
json.dump(service_account_key_data, temp_file, indent=2)
|
||||
|
||||
|
||||
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
|
||||
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
|
||||
|
||||
|
||||
def test_vertex_ai():
|
||||
import random
|
||||
|
||||
load_vertex_ai_credentials()
|
||||
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||
litellm.set_verbose=False
|
||||
litellm.vertex_project = "hardy-device-386718"
|
||||
|
||||
test_models = random.sample(test_models, 4)
|
||||
for model in test_models:
|
||||
try:
|
||||
if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
|
||||
# our account does not have access to this model
|
||||
continue
|
||||
print("making request", model)
|
||||
response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}])
|
||||
print("\nModel Response", response)
|
||||
print(response)
|
||||
assert type(response.choices[0].message.content) == str
|
||||
assert len(response.choices[0].message.content) > 1
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_vertex_ai()
|
||||
|
||||
def test_vertex_ai_stream():
|
||||
load_vertex_ai_credentials()
|
||||
litellm.set_verbose=False
|
||||
litellm.vertex_project = "hardy-device-386718"
|
||||
import random
|
||||
|
||||
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||
test_models = random.sample(test_models, 4)
|
||||
for model in test_models:
|
||||
try:
|
||||
if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
|
||||
# our account does not have access to this model
|
||||
continue
|
||||
print("making request", model)
|
||||
response = completion(model=model, messages=[{"role": "user", "content": "write 10 line code code for saying hi"}], stream=True)
|
||||
completed_str = ""
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
content = chunk.choices[0].delta.content or ""
|
||||
print("\n content", content)
|
||||
completed_str += content
|
||||
assert type(content) == str
|
||||
# pass
|
||||
assert len(completed_str) > 4
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_vertex_ai_stream()
|
|
@ -19,6 +19,13 @@ import random
|
|||
messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
||||
# comment
|
||||
|
||||
import random
|
||||
import string
|
||||
|
||||
def generate_random_word(length=4):
|
||||
letters = string.ascii_lowercase
|
||||
return ''.join(random.choice(letters) for _ in range(length))
|
||||
|
||||
messages = [{"role": "user", "content": "who is ishaan 5222"}]
|
||||
def test_caching_v2(): # test in memory cache
|
||||
try:
|
||||
|
@ -28,6 +35,8 @@ def test_caching_v2(): # test in memory cache
|
|||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
litellm.cache = None # disable cache
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
|
@ -51,6 +60,8 @@ def test_caching_with_models_v2():
|
|||
print(f"response2: {response2}")
|
||||
print(f"response3: {response3}")
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
|
||||
# if models are different, it should not return cached response
|
||||
print(f"response2: {response2}")
|
||||
|
@ -84,13 +95,15 @@ def test_embedding_caching():
|
|||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
||||
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
|
||||
print(f"embedding1: {embedding1}")
|
||||
print(f"embedding2: {embedding2}")
|
||||
pytest.fail("Error occurred: Embedding caching failed")
|
||||
|
||||
test_embedding_caching()
|
||||
# test_embedding_caching()
|
||||
|
||||
|
||||
def test_embedding_caching_azure():
|
||||
|
@ -138,6 +151,8 @@ def test_embedding_caching_azure():
|
|||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
||||
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
|
||||
print(f"embedding1: {embedding1}")
|
||||
|
@ -158,9 +173,9 @@ def test_redis_cache_completion():
|
|||
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
print("test2 for caching")
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=10, seed=1222)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=10, seed=1222)
|
||||
response3 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, temperature=1)
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
|
||||
response3 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5)
|
||||
response4 = completion(model="command-nightly", messages=messages, caching=True)
|
||||
|
||||
print("\nresponse 1", response1)
|
||||
|
@ -168,6 +183,8 @@ def test_redis_cache_completion():
|
|||
print("\nresponse 3", response3)
|
||||
print("\nresponse 4", response4)
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
|
||||
"""
|
||||
1 & 2 should be exactly the same
|
||||
|
@ -190,11 +207,132 @@ def test_redis_cache_completion():
|
|||
print(f"response4: {response4}")
|
||||
pytest.fail(f"Error occurred:")
|
||||
|
||||
test_redis_cache_completion()
|
||||
# test_redis_cache_completion()
|
||||
|
||||
def test_redis_cache_completion_stream():
|
||||
try:
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
litellm.callbacks = []
|
||||
litellm.set_verbose = True
|
||||
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache
|
||||
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
print("test for caching, streaming + completion")
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
|
||||
response_1_content = ""
|
||||
for chunk in response1:
|
||||
print(chunk)
|
||||
response_1_content += chunk.choices[0].delta.content or ""
|
||||
print(response_1_content)
|
||||
time.sleep(0.5)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
|
||||
response_2_content = ""
|
||||
for chunk in response2:
|
||||
print(chunk)
|
||||
response_2_content += chunk.choices[0].delta.content or ""
|
||||
print("\nresponse 1", response_1_content)
|
||||
print("\nresponse 2", response_2_content)
|
||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.success_callback = []
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
except Exception as e:
|
||||
print(e)
|
||||
litellm.success_callback = []
|
||||
raise e
|
||||
"""
|
||||
|
||||
1 & 2 should be exactly the same
|
||||
"""
|
||||
# test_redis_cache_completion_stream()
|
||||
|
||||
|
||||
def test_redis_cache_acompletion_stream():
|
||||
import asyncio
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
random_word = generate_random_word()
|
||||
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
print("test for caching, streaming + completion")
|
||||
response_1_content = ""
|
||||
response_2_content = ""
|
||||
|
||||
async def call1():
|
||||
nonlocal response_1_content
|
||||
response1 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||
async for chunk in response1:
|
||||
print(chunk)
|
||||
response_1_content += chunk.choices[0].delta.content or ""
|
||||
print(response_1_content)
|
||||
asyncio.run(call1())
|
||||
time.sleep(0.5)
|
||||
print("\n\n Response 1 content: ", response_1_content, "\n\n")
|
||||
|
||||
async def call2():
|
||||
nonlocal response_2_content
|
||||
response2 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||
async for chunk in response2:
|
||||
print(chunk)
|
||||
response_2_content += chunk.choices[0].delta.content or ""
|
||||
print(response_2_content)
|
||||
asyncio.run(call2())
|
||||
print("\nresponse 1", response_1_content)
|
||||
print("\nresponse 2", response_2_content)
|
||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
# test_redis_cache_acompletion_stream()
|
||||
|
||||
def test_redis_cache_acompletion_stream_bedrock():
|
||||
import asyncio
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
random_word = generate_random_word()
|
||||
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
print("test for caching, streaming + completion")
|
||||
response_1_content = ""
|
||||
response_2_content = ""
|
||||
|
||||
async def call1():
|
||||
nonlocal response_1_content
|
||||
response1 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||
async for chunk in response1:
|
||||
print(chunk)
|
||||
response_1_content += chunk.choices[0].delta.content or ""
|
||||
print(response_1_content)
|
||||
asyncio.run(call1())
|
||||
time.sleep(0.5)
|
||||
print("\n\n Response 1 content: ", response_1_content, "\n\n")
|
||||
|
||||
async def call2():
|
||||
nonlocal response_2_content
|
||||
response2 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||
async for chunk in response2:
|
||||
print(chunk)
|
||||
response_2_content += chunk.choices[0].delta.content or ""
|
||||
print(response_2_content)
|
||||
asyncio.run(call2())
|
||||
print("\nresponse 1", response_1_content)
|
||||
print("\nresponse 2", response_2_content)
|
||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
# test_redis_cache_acompletion_stream_bedrock()
|
||||
# redis cache with custom keys
|
||||
def custom_get_cache_key(*args, **kwargs):
|
||||
# return key to use for your cache:
|
||||
# return key to use for your cache:
|
||||
key = kwargs.get("model", "") + str(kwargs.get("messages", "")) + str(kwargs.get("temperature", "")) + str(kwargs.get("logit_bias", ""))
|
||||
return key
|
||||
|
||||
|
@ -228,9 +366,50 @@ def test_custom_redis_cache_with_key():
|
|||
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
|
||||
pytest.fail(f"Error occurred:")
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
|
||||
# test_custom_redis_cache_with_key()
|
||||
|
||||
|
||||
def test_custom_redis_cache_params():
|
||||
# test if we can init redis with **kwargs
|
||||
try:
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.environ['REDIS_HOST'],
|
||||
port=os.environ['REDIS_PORT'],
|
||||
password=os.environ['REDIS_PASSWORD'],
|
||||
db = 0,
|
||||
ssl=True,
|
||||
ssl_certfile="./redis_user.crt",
|
||||
ssl_keyfile="./redis_user_private.key",
|
||||
ssl_ca_certs="./redis_ca.pem",
|
||||
)
|
||||
|
||||
print(litellm.cache.cache.redis_client)
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred:", e)
|
||||
|
||||
|
||||
def test_get_cache_key():
|
||||
from litellm.caching import Cache
|
||||
try:
|
||||
cache_instance = Cache()
|
||||
cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
|
||||
)
|
||||
assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
pytest.fail(f"Error occurred:", e)
|
||||
|
||||
# test_get_cache_key()
|
||||
|
||||
# test_custom_redis_cache_params()
|
||||
|
||||
# def test_redis_cache_with_ttl():
|
||||
# cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
# sample_model_response_object_str = """{
|
||||
|
|
74
litellm/tests/test_caching_ssl.py
Normal file
74
litellm/tests/test_caching_ssl.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
#### What this tests ####
|
||||
# This tests using caching w/ litellm which requires SSL=True
|
||||
|
||||
import sys, os
|
||||
import time
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm import embedding, completion, Router
|
||||
from litellm.caching import Cache
|
||||
|
||||
messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]
|
||||
def test_caching_v2(): # test in memory cache
|
||||
try:
|
||||
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2")
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
litellm.cache = None # disable cache
|
||||
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
raise Exception()
|
||||
except Exception as e:
|
||||
print(f"error occurred: {traceback.format_exc()}")
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_caching_v2()
|
||||
|
||||
|
||||
def test_caching_router():
|
||||
"""
|
||||
Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit.
|
||||
"""
|
||||
try:
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
}
|
||||
]
|
||||
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2")
|
||||
router = Router(model_list=model_list,
|
||||
routing_strategy="simple-shuffle",
|
||||
set_verbose=False,
|
||||
num_retries=1) # type: ignore
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
litellm.cache = None # disable cache
|
||||
assert response2['choices'][0]['message']['content'] == response1['choices'][0]['message']['content']
|
||||
except Exception as e:
|
||||
print(f"error occurred: {traceback.format_exc()}")
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_caching_router()
|
|
@ -7,25 +7,33 @@ import os, io
|
|||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
litellm.num_retries = 3
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
user_message = "Write a short poem about the sky"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
|
||||
def logger_fn(user_model_dict):
|
||||
print(f"user_model_dict: {user_model_dict}")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_callbacks():
|
||||
print("\npytest fixture - resetting callbacks")
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
litellm.failure_callback = []
|
||||
litellm.callbacks = []
|
||||
|
||||
def test_completion_custom_provider_model_name():
|
||||
try:
|
||||
litellm.cache = None
|
||||
response = completion(
|
||||
model="together_ai/togethercomputer/llama-2-70b-chat",
|
||||
model="together_ai/mistralai/Mistral-7B-Instruct-v0.1",
|
||||
messages=messages,
|
||||
logger_fn=logger_fn,
|
||||
)
|
||||
|
@ -53,7 +61,7 @@ def test_completion_claude():
|
|||
print(response)
|
||||
print(response.usage)
|
||||
print(response.usage.completion_tokens)
|
||||
print(response["usage"]["completion_tokens"])
|
||||
print(response["usage"]["completion_tokens"])
|
||||
# print("new cost tracking")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
@ -63,6 +71,7 @@ def test_completion_claude():
|
|||
def test_completion_claude2_1():
|
||||
try:
|
||||
print("claude2.1 test request")
|
||||
messages=[{'role': 'system', 'content': 'Your goal is generate a joke on the topic user gives'}, {'role': 'assistant', 'content': 'Hi, how can i assist you today?'}, {'role': 'user', 'content': 'Generate a 3 liner joke for me'}]
|
||||
# test without max tokens
|
||||
response = completion(
|
||||
model="claude-2.1",
|
||||
|
@ -131,6 +140,7 @@ def test_completion_gpt4_turbo():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_gpt4_turbo()
|
||||
|
||||
@pytest.mark.skip(reason="this test is flaky")
|
||||
def test_completion_gpt4_vision():
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
|
@ -284,7 +294,7 @@ def hf_test_completion_tgi():
|
|||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# hf_test_completion_tgi()
|
||||
hf_test_completion_tgi()
|
||||
|
||||
# ################### Hugging Face Conversational models ########################
|
||||
# def hf_test_completion_conv():
|
||||
|
@ -442,9 +452,46 @@ def test_completion_text_openai():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_text_openai()
|
||||
|
||||
def custom_callback(
|
||||
kwargs, # kwargs to completion
|
||||
completion_response, # response from completion
|
||||
start_time, end_time # start/end time
|
||||
):
|
||||
# Your custom code here
|
||||
try:
|
||||
print("LITELLM: in custom callback function")
|
||||
print("\nkwargs\n", kwargs)
|
||||
model = kwargs["model"]
|
||||
messages = kwargs["messages"]
|
||||
user = kwargs.get("user")
|
||||
|
||||
#################################################
|
||||
|
||||
print(
|
||||
f"""
|
||||
Model: {model},
|
||||
Messages: {messages},
|
||||
User: {user},
|
||||
Seed: {kwargs["seed"]},
|
||||
temperature: {kwargs["temperature"]},
|
||||
"""
|
||||
)
|
||||
|
||||
assert kwargs["user"] == "ishaans app"
|
||||
assert kwargs["model"] == "gpt-3.5-turbo-1106"
|
||||
assert kwargs["seed"] == 12
|
||||
assert kwargs["temperature"] == 0.5
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
def test_completion_openai_with_optional_params():
|
||||
# [Proxy PROD TEST] WARNING: DO NOT DELETE THIS TEST
|
||||
# assert that `user` gets passed to the completion call
|
||||
# Note: This tests that we actually send the optional params to the completion call
|
||||
# We use custom callbacks to test this
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
litellm.success_callback = [custom_callback]
|
||||
response = completion(
|
||||
model="gpt-3.5-turbo-1106",
|
||||
messages=[
|
||||
|
@ -458,11 +505,13 @@ def test_completion_openai_with_optional_params():
|
|||
seed=12,
|
||||
response_format={ "type": "json_object" },
|
||||
logit_bias=None,
|
||||
user = "ishaans app"
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
|
||||
print(response)
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
litellm.success_callback = [] # unset callbacks
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -569,7 +618,7 @@ def test_completion_azure_key_completion_arg():
|
|||
os.environ.pop("AZURE_API_KEY", None)
|
||||
try:
|
||||
print("azure gpt-3.5 test\n\n")
|
||||
litellm.set_verbose=False
|
||||
litellm.set_verbose=True
|
||||
## Test azure call
|
||||
response = completion(
|
||||
model="azure/chatgpt-v-2",
|
||||
|
@ -654,11 +703,12 @@ def test_completion_azure():
|
|||
print(response)
|
||||
|
||||
cost = completion_cost(completion_response=response)
|
||||
assert cost > 0.0
|
||||
print("Cost for azure completion request", cost)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_completion_azure()
|
||||
test_completion_azure()
|
||||
|
||||
def test_azure_openai_ad_token():
|
||||
# this tests if the azure ad token is set in the request header
|
||||
|
@ -960,12 +1010,18 @@ def test_replicate_custom_prompt_dict():
|
|||
|
||||
######## Test TogetherAI ########
|
||||
def test_completion_together_ai():
|
||||
model_name = "together_ai/togethercomputer/llama-2-70b-chat"
|
||||
model_name = "together_ai/togethercomputer/CodeLlama-13b-Instruct"
|
||||
try:
|
||||
messages =[
|
||||
{"role": "user", "content": "Who are you"},
|
||||
{"role": "assistant", "content": "I am your helpful assistant."},
|
||||
{"role": "user", "content": "Tell me a joke"},
|
||||
]
|
||||
response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
cost = completion_cost(completion_response=response)
|
||||
assert cost > 0.0
|
||||
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
@ -975,15 +1031,17 @@ def test_customprompt_together_ai():
|
|||
try:
|
||||
litellm.set_verbose = False
|
||||
litellm.num_retries = 0
|
||||
print("in test_customprompt_together_ai")
|
||||
print(litellm.success_callback)
|
||||
print(litellm._async_success_callback)
|
||||
response = completion(
|
||||
model="together_ai/togethercomputer/llama-2-70b-chat",
|
||||
model="together_ai/mistralai/Mistral-7B-Instruct-v0.1",
|
||||
messages=messages,
|
||||
roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
|
||||
)
|
||||
print(response)
|
||||
except litellm.exceptions.Timeout as e:
|
||||
print(f"Timeout Error")
|
||||
litellm.num_retries = 3 # reset retries
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"ERROR TYPE {type(e)}")
|
||||
|
@ -996,7 +1054,7 @@ def test_completion_sagemaker():
|
|||
print("testing sagemaker")
|
||||
litellm.set_verbose=True
|
||||
response = completion(
|
||||
model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
|
||||
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
max_tokens=80,
|
||||
|
@ -1009,18 +1067,40 @@ def test_completion_sagemaker():
|
|||
|
||||
def test_completion_chat_sagemaker():
|
||||
try:
|
||||
print("testing sagemaker")
|
||||
messages = [{"role": "user", "content": "Hey, how's it going?"}]
|
||||
litellm.set_verbose=True
|
||||
response = completion(
|
||||
model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f",
|
||||
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
|
||||
messages=messages,
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
stream=True,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
# Add any assertions here to check the response
|
||||
complete_response = ""
|
||||
for chunk in response:
|
||||
complete_response += chunk.choices[0].delta.content or ""
|
||||
print(f"complete_response: {complete_response}")
|
||||
assert len(complete_response) > 0
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_chat_sagemaker()
|
||||
|
||||
def test_completion_chat_sagemaker_mistral():
|
||||
try:
|
||||
messages = [{"role": "user", "content": "Hey, how's it going?"}]
|
||||
|
||||
response = completion(
|
||||
model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-instruct",
|
||||
messages=messages,
|
||||
max_tokens=100,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"An error occurred: {str(e)}")
|
||||
|
||||
# test_completion_chat_sagemaker_mistral()
|
||||
def test_completion_bedrock_titan():
|
||||
try:
|
||||
response = completion(
|
||||
|
@ -1251,43 +1331,6 @@ def test_completion_bedrock_claude_completion_auth():
|
|||
|
||||
# test_completion_custom_api_base()
|
||||
|
||||
# def test_vertex_ai():
|
||||
# test_models = ["codechat-bison"] + litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||
# # test_models = ["chat-bison"]
|
||||
# for model in test_models:
|
||||
# try:
|
||||
# if model in ["code-gecko@001", "code-gecko@latest"]:
|
||||
# # our account does not have access to this model
|
||||
# continue
|
||||
# print("making request", model)
|
||||
# response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}])
|
||||
# print(response)
|
||||
|
||||
# print(response.usage.completion_tokens)
|
||||
# print(response['usage']['completion_tokens'])
|
||||
# assert type(response.choices[0].message.content) == str
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
# test_vertex_ai()
|
||||
|
||||
# def test_vertex_ai_stream():
|
||||
# litellm.set_verbose=False
|
||||
# test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||
# for model in test_models:
|
||||
# try:
|
||||
# if model in ["code-gecko@001", "code-gecko@latest"]:
|
||||
# # our account does not have access to this model
|
||||
# continue
|
||||
# print("making request", model)
|
||||
# response = completion(model=model, messages=[{"role": "user", "content": "write 100 line code code for saying hi"}], stream=True)
|
||||
# for chunk in response:
|
||||
# print(chunk)
|
||||
# # pass
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
# test_vertex_ai_stream()
|
||||
|
||||
|
||||
def test_completion_with_fallbacks():
|
||||
print(f"RUNNING TEST COMPLETION WITH FALLBACKS - test_completion_with_fallbacks")
|
||||
fallbacks = ["gpt-3.5-turbo", "gpt-3.5-turbo", "command-nightly"]
|
||||
|
@ -1337,7 +1380,7 @@ def test_azure_cloudflare_api():
|
|||
traceback.print_exc()
|
||||
pass
|
||||
|
||||
test_azure_cloudflare_api()
|
||||
# test_azure_cloudflare_api()
|
||||
|
||||
def test_completion_anyscale_2():
|
||||
try:
|
||||
|
@ -1567,7 +1610,7 @@ def test_completion_together_ai_stream():
|
|||
messages = [{ "content": user_message,"role": "user"}]
|
||||
try:
|
||||
response = completion(
|
||||
model="together_ai/togethercomputer/llama-2-70b-chat",
|
||||
model="together_ai/mistralai/Mistral-7B-Instruct-v0.1",
|
||||
messages=messages, stream=True,
|
||||
max_tokens=5
|
||||
)
|
||||
|
|
14
litellm/tests/test_configs/custom_auth.py
Normal file
14
litellm/tests/test_configs/custom_auth.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from fastapi import Request
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
print(f"api_key: {api_key}")
|
||||
if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
|
||||
return UserAPIKeyAuth(api_key=api_key)
|
||||
raise Exception
|
||||
except:
|
||||
raise Exception
|
113
litellm/tests/test_configs/custom_callbacks.py
Normal file
113
litellm/tests/test_configs/custom_callbacks.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import inspect
|
||||
import litellm
|
||||
|
||||
class testCustomCallbackProxy(CustomLogger):
|
||||
def __init__(self):
|
||||
self.success: bool = False # type: ignore
|
||||
self.failure: bool = False # type: ignore
|
||||
self.async_success: bool = False # type: ignore
|
||||
self.async_success_embedding: bool = False # type: ignore
|
||||
self.async_failure: bool = False # type: ignore
|
||||
self.async_failure_embedding: bool = False # type: ignore
|
||||
|
||||
self.async_completion_kwargs = None # type: ignore
|
||||
self.async_embedding_kwargs = None # type: ignore
|
||||
self.async_embedding_response = None # type: ignore
|
||||
|
||||
self.async_completion_kwargs_fail = None # type: ignore
|
||||
self.async_embedding_kwargs_fail = None # type: ignore
|
||||
|
||||
self.streaming_response_obj = None # type: ignore
|
||||
blue_color_code = "\033[94m"
|
||||
reset_color_code = "\033[0m"
|
||||
print(f"{blue_color_code}Initialized LiteLLM custom logger")
|
||||
try:
|
||||
print(f"Logger Initialized with following methods:")
|
||||
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
|
||||
|
||||
# Pretty print the methods
|
||||
for method in methods:
|
||||
print(f" - {method}")
|
||||
print(f"{reset_color_code}")
|
||||
except:
|
||||
pass
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
print(f"Pre-API Call")
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"Post-API Call")
|
||||
|
||||
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Stream")
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Success")
|
||||
self.success = True
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Failure")
|
||||
self.failure = True
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async success")
|
||||
self.async_success = True
|
||||
print("Value of async success: ", self.async_success)
|
||||
print("\n kwargs: ", kwargs)
|
||||
if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada":
|
||||
print("Got an embedding model", kwargs.get("model"))
|
||||
print("Setting embedding success to True")
|
||||
self.async_success_embedding = True
|
||||
print("Value of async success embedding: ", self.async_success_embedding)
|
||||
self.async_embedding_kwargs = kwargs
|
||||
self.async_embedding_response = response_obj
|
||||
if kwargs.get("stream") == True:
|
||||
self.streaming_response_obj = response_obj
|
||||
|
||||
|
||||
self.async_completion_kwargs = kwargs
|
||||
|
||||
model = kwargs.get("model", None)
|
||||
messages = kwargs.get("messages", None)
|
||||
user = kwargs.get("user", None)
|
||||
|
||||
# Access litellm_params passed to litellm.completion(), example access `metadata`
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here
|
||||
|
||||
# Calculate cost using litellm.completion_cost()
|
||||
cost = litellm.completion_cost(completion_response=response_obj)
|
||||
response = response_obj
|
||||
# tokens used in response
|
||||
usage = response_obj["usage"]
|
||||
|
||||
print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
|
||||
|
||||
|
||||
print(
|
||||
f"""
|
||||
Model: {model},
|
||||
Messages: {messages},
|
||||
User: {user},
|
||||
Usage: {usage},
|
||||
Cost: {cost},
|
||||
Response: {response}
|
||||
Proxy Metadata: {metadata}
|
||||
"""
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Failure")
|
||||
self.async_failure = True
|
||||
print("Value of async failure: ", self.async_failure)
|
||||
print("\n kwargs: ", kwargs)
|
||||
if kwargs.get("model") == "text-embedding-ada-002":
|
||||
self.async_failure_embedding = True
|
||||
self.async_embedding_kwargs_fail = kwargs
|
||||
|
||||
self.async_completion_kwargs_fail = kwargs
|
||||
|
||||
my_custom_logger = testCustomCallbackProxy()
|
28
litellm/tests/test_configs/test_config.yaml
Normal file
28
litellm/tests/test_configs/test_config.yaml
Normal file
|
@ -0,0 +1,28 @@
|
|||
general_settings:
|
||||
database_url: os.environ/PROXY_DATABASE_URL
|
||||
master_key: os.environ/PROXY_MASTER_KEY
|
||||
litellm_settings:
|
||||
drop_params: true
|
||||
success_callback: ["langfuse"]
|
||||
|
||||
model_list:
|
||||
- litellm_params:
|
||||
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||
api_key: os.environ/AZURE_EUROPE_API_KEY
|
||||
model: azure/gpt-35-turbo
|
||||
model_name: azure-model
|
||||
- litellm_params:
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com
|
||||
api_key: os.environ/AZURE_CANADA_API_KEY
|
||||
model: azure/gpt-35-turbo
|
||||
model_name: azure-model
|
||||
- litellm_params:
|
||||
api_base: https://openai-france-1234.openai.azure.com
|
||||
api_key: os.environ/AZURE_FRANCE_API_KEY
|
||||
model: azure/gpt-turbo
|
||||
model_name: azure-model
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
model_name: test_openai_models
|
11
litellm/tests/test_configs/test_config_custom_auth.yaml
Normal file
11
litellm/tests/test_configs/test_config_custom_auth.yaml
Normal file
|
@ -0,0 +1,11 @@
|
|||
model_list:
|
||||
- model_name: "openai-model"
|
||||
litellm_params:
|
||||
model: "gpt-3.5-turbo"
|
||||
|
||||
litellm_settings:
|
||||
drop_params: True
|
||||
set_verbose: True
|
||||
|
||||
general_settings:
|
||||
custom_auth: custom_auth.user_api_key_auth
|
75
litellm/tests/test_configs/test_config_no_auth.yaml
Normal file
75
litellm/tests/test_configs/test_config_no_auth.yaml
Normal file
|
@ -0,0 +1,75 @@
|
|||
model_list:
|
||||
- litellm_params:
|
||||
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||
api_key: os.environ/AZURE_EUROPE_API_KEY
|
||||
model: azure/gpt-35-turbo
|
||||
model_name: azure-model
|
||||
- litellm_params:
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com
|
||||
api_key: os.environ/AZURE_CANADA_API_KEY
|
||||
model: azure/gpt-35-turbo
|
||||
model_name: azure-model
|
||||
- litellm_params:
|
||||
api_base: https://openai-france-1234.openai.azure.com
|
||||
api_key: os.environ/AZURE_FRANCE_API_KEY
|
||||
model: azure/gpt-turbo
|
||||
model_name: azure-model
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 56f1bd94-3b54-4b67-9ea2-7c70e9a3a709
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 4d1ee26c-abca-450c-8744-8e87fd6755e9
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 00e19c0f-b63d-42bb-88e9-016fb0c60764
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 79fc75bf-8e1b-47d5-8d24-9365a854af03
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: 2023-07-01-preview
|
||||
model: azure/azure-embedding-model
|
||||
model_name: azure-embedding-model
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 55848c55-4162-40f9-a6e2-9a722b9ef404
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 34339b1e-e030-4bcc-a531-c48559f10ce4
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: f6f74e14-ac64-4403-9365-319e584dcdc5
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 9b1ef341-322c-410a-8992-903987fef439
|
||||
model_name: test_openai_models
|
26
litellm/tests/test_configs/test_custom_logger.yaml
Normal file
26
litellm/tests/test_configs/test_custom_logger.yaml
Normal file
|
@ -0,0 +1,26 @@
|
|||
model_list:
|
||||
- model_name: Azure OpenAI GPT-4 Canada
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
model_info:
|
||||
mode: chat
|
||||
input_cost_per_token: 0.0002
|
||||
id: gm
|
||||
- model_name: azure-embedding-model
|
||||
litellm_params:
|
||||
model: azure/azure-embedding-model
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
model_info:
|
||||
mode: embedding
|
||||
input_cost_per_token: 0.002
|
||||
id: hello
|
||||
|
||||
litellm_settings:
|
||||
drop_params: True
|
||||
set_verbose: True
|
||||
callbacks: custom_callbacks.my_custom_logger
|
579
litellm/tests/test_custom_callback_input.py
Normal file
579
litellm/tests/test_custom_callback_input.py
Normal file
|
@ -0,0 +1,579 @@
|
|||
### What this tests ####
|
||||
## This test asserts the type of data passed into each method of the custom callback handler
|
||||
import sys, os, time, inspect, asyncio, traceback
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
from typing import Optional, Literal, List, Union
|
||||
from litellm import completion, embedding
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
# Test Scenarios (test across completion, streaming, embedding)
|
||||
## 1: Pre-API-Call
|
||||
## 2: Post-API-Call
|
||||
## 3: On LiteLLM Call success
|
||||
## 4: On LiteLLM Call failure
|
||||
|
||||
# Test models
|
||||
## 1. OpenAI
|
||||
## 2. Azure OpenAI
|
||||
## 3. Non-OpenAI/Azure - e.g. Bedrock
|
||||
|
||||
# Test interfaces
|
||||
## 1. litellm.completion() + litellm.embeddings()
|
||||
## refer to test_custom_callback_input_router.py for the router + proxy tests
|
||||
|
||||
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
"""
|
||||
The set of expected inputs to a custom handler for a
|
||||
"""
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
self.errors = []
|
||||
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
self.states.append("sync_pre_api_call")
|
||||
## MODEL
|
||||
assert isinstance(model, str)
|
||||
## MESSAGES
|
||||
assert isinstance(messages, list)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("post_api_call")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert end_time == None
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_stream")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_success")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_failure")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
self.states.append("async_pre_api_call")
|
||||
## MODEL
|
||||
assert isinstance(model, str)
|
||||
## MESSAGES
|
||||
assert isinstance(messages, list) and isinstance(messages[0], dict)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_success")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_failure")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, str, dict))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
|
||||
# COMPLETION
|
||||
## Test OpenAI + sync
|
||||
def test_chat_openai_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = litellm.completion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync openai"
|
||||
}])
|
||||
## test streaming
|
||||
response = litellm.completion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
## test failure callback
|
||||
try:
|
||||
response = litellm.completion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
api_key="my-bad-key",
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# test_chat_openai_stream()
|
||||
|
||||
## Test OpenAI + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_openai_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = await litellm.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
## test streaming
|
||||
response = await litellm.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
## test failure callback
|
||||
try:
|
||||
response = await litellm.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
api_key="my-bad-key",
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_chat_openai_stream())
|
||||
|
||||
## Test Azure + sync
|
||||
def test_chat_azure_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = litellm.completion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync azure"
|
||||
}])
|
||||
# test streaming
|
||||
response = litellm.completion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync azure"
|
||||
}],
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
# test failure callback
|
||||
try:
|
||||
response = litellm.completion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync azure"
|
||||
}],
|
||||
api_key="my-bad-key",
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# test_chat_azure_stream()
|
||||
|
||||
## Test Azure + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_azure_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async azure"
|
||||
}])
|
||||
## test streaming
|
||||
response = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async azure"
|
||||
}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
## test failure callback
|
||||
try:
|
||||
response = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async azure"
|
||||
}],
|
||||
api_key="my-bad-key",
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_chat_azure_stream())
|
||||
|
||||
## Test Bedrock + sync
|
||||
def test_chat_bedrock_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = litellm.completion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync bedrock"
|
||||
}])
|
||||
# test streaming
|
||||
response = litellm.completion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync bedrock"
|
||||
}],
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
# test failure callback
|
||||
try:
|
||||
response = litellm.completion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync bedrock"
|
||||
}],
|
||||
aws_region_name="my-bad-region",
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# test_chat_bedrock_stream()
|
||||
|
||||
## Test Bedrock + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_bedrock_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async bedrock"
|
||||
}])
|
||||
# test streaming
|
||||
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async bedrock"
|
||||
}],
|
||||
stream=True)
|
||||
print(f"response: {response}")
|
||||
async for chunk in response:
|
||||
print(f"chunk: {chunk}")
|
||||
continue
|
||||
## test failure callback
|
||||
try:
|
||||
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async bedrock"
|
||||
}],
|
||||
aws_region_name="my-bad-key",
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_chat_bedrock_stream())
|
||||
|
||||
# EMBEDDING
|
||||
## Test OpenAI + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_openai():
|
||||
try:
|
||||
customHandler_success = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_success]
|
||||
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"])
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(model="text-embedding-ada-002",
|
||||
input=["good morning from litellm"],
|
||||
api_key="my-bad-key")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, success
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_embedding_openai())
|
||||
|
||||
## Test Azure + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_azure():
|
||||
try:
|
||||
customHandler_success = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_success]
|
||||
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"])
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"],
|
||||
api_key="my-bad-key")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, success
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_embedding_azure())
|
||||
|
||||
## Test Bedrock + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_bedrock():
|
||||
try:
|
||||
customHandler_success = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_success]
|
||||
litellm.set_verbose = True
|
||||
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
|
||||
input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2")
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
|
||||
input=["good morning from litellm"],
|
||||
aws_region_name="my-bad-region")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, success
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_embedding_bedrock())
|
437
litellm/tests/test_custom_callback_router.py
Normal file
437
litellm/tests/test_custom_callback_router.py
Normal file
|
@ -0,0 +1,437 @@
|
|||
### What this tests ####
|
||||
## This test asserts the type of data passed into each method of the custom callback handler
|
||||
import sys, os, time, inspect, asyncio, traceback
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
from typing import Optional, Literal, List
|
||||
from litellm import Router
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
# Test Scenarios (test across completion, streaming, embedding)
|
||||
## 1: Pre-API-Call
|
||||
## 2: Post-API-Call
|
||||
## 3: On LiteLLM Call success
|
||||
## 4: On LiteLLM Call failure
|
||||
## fallbacks
|
||||
## retries
|
||||
|
||||
# Test cases
|
||||
## 1. Simple Azure OpenAI acompletion + streaming call
|
||||
## 2. Simple Azure OpenAI aembedding call
|
||||
## 3. Azure OpenAI acompletion + streaming call with retries
|
||||
## 4. Azure OpenAI aembedding call with retries
|
||||
## 5. Azure OpenAI acompletion + streaming call with fallbacks
|
||||
## 6. Azure OpenAI aembedding call with fallbacks
|
||||
|
||||
# Test interfaces
|
||||
## 1. router.completion() + router.embeddings()
|
||||
## 2. proxy.completions + proxy.embeddings
|
||||
|
||||
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
"""
|
||||
The set of expected inputs to a custom handler for a
|
||||
"""
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
self.errors = []
|
||||
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
print(f'received kwargs in pre-input: {kwargs}')
|
||||
self.states.append("sync_pre_api_call")
|
||||
## MODEL
|
||||
assert isinstance(model, str)
|
||||
## MESSAGES
|
||||
assert isinstance(messages, list)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
### ROUTER-SPECIFIC KWARGS
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("post_api_call")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert end_time == None
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
### ROUTER-SPECIFIC KWARGS
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_stream")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_success")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_failure")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
"""
|
||||
No-op.
|
||||
Not implemented yet.
|
||||
"""
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_success")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
### ROUTER-SPECIFIC KWARGS
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
print(f"received original response: {kwargs['original_response']}")
|
||||
self.states.append("async_failure")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, str, dict))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
# Simple Azure OpenAI call
|
||||
## COMPLETION
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_azure():
|
||||
try:
|
||||
customHandler_completion_azure_router = CompletionCustomHandler()
|
||||
customHandler_streaming_azure_router = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_completion_azure_router]
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
await asyncio.sleep(2)
|
||||
assert len(customHandler_completion_azure_router.errors) == 0
|
||||
assert len(customHandler_completion_azure_router.states) == 3 # pre, post, success
|
||||
# streaming
|
||||
litellm.callbacks = [customHandler_streaming_azure_router]
|
||||
router2 = Router(model_list=model_list) # type: ignore
|
||||
response = await router2.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
print(f"async azure router chunk: {chunk}")
|
||||
continue
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
|
||||
assert len(customHandler_streaming_azure_router.errors) == 0
|
||||
assert len(customHandler_streaming_azure_router.states) >= 4 # pre, post, stream (multiple times), success
|
||||
# failure
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
router3 = Router(model_list=model_list) # type: ignore
|
||||
try:
|
||||
response = await router3.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
print(f"response in router3 acompletion: {response}")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
assert "async_failure" in customHandler_failure.states
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
# asyncio.run(test_async_chat_azure())
|
||||
## EMBEDDING
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_azure():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-embedding-model", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response = await router.aembedding(model="azure-embedding-model",
|
||||
input=["hello from litellm!"])
|
||||
await asyncio.sleep(2)
|
||||
assert len(customHandler.errors) == 0
|
||||
assert len(customHandler.states) == 3 # pre, post, success
|
||||
# failure
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-embedding-model", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
router3 = Router(model_list=model_list) # type: ignore
|
||||
try:
|
||||
response = await router3.aembedding(model="azure-embedding-model",
|
||||
input=["hello from litellm!"])
|
||||
print(f"response in router3 aembedding: {response}")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
assert "async_failure" in customHandler_failure.states
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
# asyncio.run(test_async_embedding_azure())
|
||||
# Azure OpenAI call w/ Fallbacks
|
||||
## COMPLETION
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_azure_with_fallbacks():
|
||||
try:
|
||||
customHandler_fallbacks = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_fallbacks]
|
||||
# with fallbacks
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-16k",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
}
|
||||
]
|
||||
router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore
|
||||
response = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
await asyncio.sleep(2)
|
||||
print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}")
|
||||
assert len(customHandler_fallbacks.errors) == 0
|
||||
assert len(customHandler_fallbacks.states) == 6 # pre, post, failure, pre, post, success
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
# asyncio.run(test_async_chat_azure_with_fallbacks())
|
|
@ -1,5 +1,5 @@
|
|||
### What this tests ####
|
||||
import sys, os, time
|
||||
import sys, os, time, inspect, asyncio, traceback
|
||||
import pytest
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
|
||||
|
@ -8,8 +8,24 @@ import litellm
|
|||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
class MyCustomHandler(CustomLogger):
|
||||
success: bool = False
|
||||
failure: bool = False
|
||||
complete_streaming_response_in_callback = ""
|
||||
def __init__(self):
|
||||
self.success: bool = False # type: ignore
|
||||
self.failure: bool = False # type: ignore
|
||||
self.async_success: bool = False # type: ignore
|
||||
self.async_success_embedding: bool = False # type: ignore
|
||||
self.async_failure: bool = False # type: ignore
|
||||
self.async_failure_embedding: bool = False # type: ignore
|
||||
|
||||
self.async_completion_kwargs = None # type: ignore
|
||||
self.async_embedding_kwargs = None # type: ignore
|
||||
self.async_embedding_response = None # type: ignore
|
||||
|
||||
self.async_completion_kwargs_fail = None # type: ignore
|
||||
self.async_embedding_kwargs_fail = None # type: ignore
|
||||
|
||||
self.stream_collected_response = None # type: ignore
|
||||
self.sync_stream_collected_response = None # type: ignore
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
print(f"Pre-API Call")
|
||||
|
@ -23,29 +39,78 @@ class MyCustomHandler(CustomLogger):
|
|||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Success")
|
||||
self.success = True
|
||||
if kwargs.get("stream") == True:
|
||||
self.sync_stream_collected_response = response_obj
|
||||
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Failure")
|
||||
self.failure = True
|
||||
|
||||
# def test_chat_openai():
|
||||
# try:
|
||||
# customHandler = MyCustomHandler()
|
||||
# litellm.callbacks = [customHandler]
|
||||
# response = completion(model="gpt-3.5-turbo",
|
||||
# messages=[{
|
||||
# "role": "user",
|
||||
# "content": "Hi 👋 - i'm openai"
|
||||
# }],
|
||||
# stream=True)
|
||||
# time.sleep(1)
|
||||
# assert customHandler.success == True
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"An error occurred - {str(e)}")
|
||||
# pass
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async success")
|
||||
self.async_success = True
|
||||
if kwargs.get("model") == "text-embedding-ada-002":
|
||||
self.async_success_embedding = True
|
||||
self.async_embedding_kwargs = kwargs
|
||||
self.async_embedding_response = response_obj
|
||||
if kwargs.get("stream") == True:
|
||||
self.stream_collected_response = response_obj
|
||||
self.async_completion_kwargs = kwargs
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Failure")
|
||||
self.async_failure = True
|
||||
if kwargs.get("model") == "text-embedding-ada-002":
|
||||
self.async_failure_embedding = True
|
||||
self.async_embedding_kwargs_fail = kwargs
|
||||
|
||||
self.async_completion_kwargs_fail = kwargs
|
||||
|
||||
class TmpFunction:
|
||||
complete_streaming_response_in_callback = ""
|
||||
async_success: bool = False
|
||||
async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_time):
|
||||
print(f"ON ASYNC LOGGING")
|
||||
self.async_success = True
|
||||
self.complete_streaming_response_in_callback = kwargs.get("complete_streaming_response")
|
||||
|
||||
|
||||
# test_chat_openai()
|
||||
def test_async_chat_openai_stream():
|
||||
try:
|
||||
tmp_function = TmpFunction()
|
||||
# litellm.set_verbose = True
|
||||
litellm.success_callback = [tmp_function.async_test_logging_fn]
|
||||
complete_streaming_response = ""
|
||||
async def call_gpt():
|
||||
nonlocal complete_streaming_response
|
||||
response = await litellm.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
complete_streaming_response += chunk["choices"][0]["delta"]["content"] or ""
|
||||
print(complete_streaming_response)
|
||||
asyncio.run(call_gpt())
|
||||
complete_streaming_response = complete_streaming_response.strip("'")
|
||||
print(f"complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content']}")
|
||||
print(f"type of complete_streaming_response_in_callback: {type(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
|
||||
print(f"hidden char complete_streaming_response_in_callback: {repr(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
|
||||
print(f"encoding complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'].encode('utf-8')}")
|
||||
print(f"complete_streaming_response: {complete_streaming_response}")
|
||||
print(f"type(complete_streaming_response): {type(complete_streaming_response)}")
|
||||
print(f"hidden char complete_streaming_response): {repr(complete_streaming_response)}")
|
||||
print(f"encoding complete_streaming_response): {repr(complete_streaming_response).encode('utf-8')}")
|
||||
response1 = tmp_function.complete_streaming_response_in_callback["choices"][0]["message"]["content"]
|
||||
response2 = complete_streaming_response
|
||||
assert [ord(c) for c in response1] == [ord(c) for c in response2]
|
||||
assert tmp_function.async_success == True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pytest.fail(f"An error occurred - {str(e)}")
|
||||
test_async_chat_openai_stream()
|
||||
|
||||
def test_completion_azure_stream_moderation_failure():
|
||||
try:
|
||||
|
@ -72,75 +137,192 @@ def test_completion_azure_stream_moderation_failure():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_completion_azure_stream_moderation_failure()
|
||||
|
||||
def test_async_custom_handler_stream():
|
||||
try:
|
||||
# [PROD Test] - Do not DELETE
|
||||
# checks if the model response available in the async + stream callbacks is equal to the received response
|
||||
customHandler2 = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler2]
|
||||
litellm.set_verbose = False
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "write 1 sentence about litellm being amazing",
|
||||
},
|
||||
]
|
||||
complete_streaming_response = ""
|
||||
async def test_1():
|
||||
nonlocal complete_streaming_response
|
||||
response = await litellm.acompletion(
|
||||
model="azure/chatgpt-v-2",
|
||||
messages=messages,
|
||||
stream=True
|
||||
)
|
||||
async for chunk in response:
|
||||
complete_streaming_response += chunk["choices"][0]["delta"]["content"] or ""
|
||||
print(complete_streaming_response)
|
||||
|
||||
asyncio.run(test_1())
|
||||
|
||||
response_in_success_handler = customHandler2.stream_collected_response
|
||||
response_in_success_handler = response_in_success_handler["choices"][0]["message"]["content"]
|
||||
print("\n\n")
|
||||
print("response_in_success_handler: ", response_in_success_handler)
|
||||
print("complete_streaming_response: ", complete_streaming_response)
|
||||
assert response_in_success_handler == complete_streaming_response
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_async_custom_handler_stream()
|
||||
|
||||
|
||||
# def custom_callback(
|
||||
# kwargs,
|
||||
# completion_response,
|
||||
# start_time,
|
||||
# end_time,
|
||||
# ):
|
||||
# print(
|
||||
# "in custom callback func"
|
||||
# )
|
||||
# print("kwargs", kwargs)
|
||||
# print(completion_response)
|
||||
# print(start_time)
|
||||
# print(end_time)
|
||||
# if "complete_streaming_response" in kwargs:
|
||||
# print("\n\n complete response\n\n")
|
||||
# complete_streaming_response = kwargs["complete_streaming_response"]
|
||||
# print(kwargs["complete_streaming_response"])
|
||||
# usage = complete_streaming_response["usage"]
|
||||
# print("usage", usage)
|
||||
# def send_slack_alert(
|
||||
# kwargs,
|
||||
# completion_response,
|
||||
# start_time,
|
||||
# end_time,
|
||||
# ):
|
||||
# print(
|
||||
# "in custom slack callback func"
|
||||
# )
|
||||
# import requests
|
||||
# import json
|
||||
def test_azure_completion_stream():
|
||||
# [PROD Test] - Do not DELETE
|
||||
# test if completion() + sync custom logger get the same complete stream response
|
||||
try:
|
||||
# checks if the model response available in the async + stream callbacks is equal to the received response
|
||||
customHandler2 = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler2]
|
||||
litellm.set_verbose = True
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "write 1 sentence about litellm being amazing",
|
||||
},
|
||||
]
|
||||
complete_streaming_response = ""
|
||||
|
||||
# # Define the Slack webhook URL
|
||||
# slack_webhook_url = os.environ['SLACK_WEBHOOK_URL'] # "https://hooks.slack.com/services/<>/<>/<>"
|
||||
response = litellm.completion(
|
||||
model="azure/chatgpt-v-2",
|
||||
messages=messages,
|
||||
stream=True
|
||||
)
|
||||
for chunk in response:
|
||||
complete_streaming_response += chunk["choices"][0]["delta"]["content"] or ""
|
||||
print(complete_streaming_response)
|
||||
|
||||
time.sleep(0.5) # wait 1/2 second before checking callbacks
|
||||
response_in_success_handler = customHandler2.sync_stream_collected_response
|
||||
response_in_success_handler = response_in_success_handler["choices"][0]["message"]["content"]
|
||||
print("\n\n")
|
||||
print("response_in_success_handler: ", response_in_success_handler)
|
||||
print("complete_streaming_response: ", complete_streaming_response)
|
||||
assert response_in_success_handler == complete_streaming_response
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# # Define the text payload, send data available in litellm custom_callbacks
|
||||
# text_payload = f"""LiteLLM Logging: kwargs: {str(kwargs)}\n\n, response: {str(completion_response)}\n\n, start time{str(start_time)} end time: {str(end_time)}
|
||||
# """
|
||||
# payload = {
|
||||
# "text": text_payload
|
||||
# }
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_custom_handler_completion():
|
||||
try:
|
||||
customHandler_success = MyCustomHandler()
|
||||
customHandler_failure = MyCustomHandler()
|
||||
# success
|
||||
assert customHandler_success.async_success == False
|
||||
litellm.callbacks = [customHandler_success]
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "hello from litellm test",
|
||||
}]
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
assert customHandler_success.async_success == True, "async success is not set to True even after success"
|
||||
assert customHandler_success.async_completion_kwargs.get("model") == "gpt-3.5-turbo"
|
||||
# failure
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how do i kill someone",
|
||||
},
|
||||
]
|
||||
|
||||
# # Set the headers
|
||||
# headers = {
|
||||
# "Content-type": "application/json"
|
||||
# }
|
||||
assert customHandler_failure.async_failure == False
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
api_key="my-bad-key",
|
||||
)
|
||||
except:
|
||||
pass
|
||||
assert customHandler_failure.async_failure == True, "async failure is not set to True even after failure"
|
||||
assert customHandler_failure.async_completion_kwargs_fail.get("model") == "gpt-3.5-turbo"
|
||||
assert len(str(customHandler_failure.async_completion_kwargs_fail.get("exception"))) > 10 # expect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
|
||||
litellm.callbacks = []
|
||||
print("Passed setting async failure")
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
# asyncio.run(test_async_custom_handler_completion())
|
||||
|
||||
# # Make the POST request
|
||||
# response = requests.post(slack_webhook_url, json=payload, headers=headers)
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_custom_handler_embedding():
|
||||
try:
|
||||
customHandler_embedding = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler_embedding]
|
||||
# success
|
||||
assert customHandler_embedding.async_success_embedding == False
|
||||
response = await litellm.aembedding(
|
||||
model="text-embedding-ada-002",
|
||||
input = ["hello world"],
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
assert customHandler_embedding.async_success_embedding == True, "async_success_embedding is not set to True even after success"
|
||||
assert customHandler_embedding.async_embedding_kwargs.get("model") == "text-embedding-ada-002"
|
||||
assert customHandler_embedding.async_embedding_response["usage"]["prompt_tokens"] ==2
|
||||
print("Passed setting async success: Embedding")
|
||||
# failure
|
||||
assert customHandler_embedding.async_failure_embedding == False
|
||||
try:
|
||||
response = await litellm.aembedding(
|
||||
model="text-embedding-ada-002",
|
||||
input = ["hello world"],
|
||||
api_key="my-bad-key",
|
||||
)
|
||||
except:
|
||||
pass
|
||||
assert customHandler_embedding.async_failure_embedding == True, "async failure embedding is not set to True even after failure"
|
||||
assert customHandler_embedding.async_embedding_kwargs_fail.get("model") == "text-embedding-ada-002"
|
||||
assert len(str(customHandler_embedding.async_embedding_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
asyncio.run(test_async_custom_handler_embedding())
|
||||
from litellm import Cache
|
||||
def test_redis_cache_completion_stream():
|
||||
# Important Test - This tests if we can add to streaming cache, when custom callbacks are set
|
||||
import random
|
||||
try:
|
||||
print("\nrunning test_redis_cache_completion_stream")
|
||||
litellm.set_verbose = True
|
||||
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache
|
||||
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
print("test for caching, streaming + completion")
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
|
||||
response_1_content = ""
|
||||
for chunk in response1:
|
||||
print(chunk)
|
||||
response_1_content += chunk.choices[0].delta.content or ""
|
||||
print(response_1_content)
|
||||
|
||||
# # Check the response status
|
||||
# if response.status_code == 200:
|
||||
# print("Message sent successfully to Slack!")
|
||||
# else:
|
||||
# print(f"Failed to send message to Slack. Status code: {response.status_code}")
|
||||
# print(response.json())
|
||||
|
||||
# def get_transformed_inputs(
|
||||
# kwargs,
|
||||
# ):
|
||||
# params_to_model = kwargs["additional_args"]["complete_input_dict"]
|
||||
# print("params to model", params_to_model)
|
||||
|
||||
# litellm.success_callback = [custom_callback, send_slack_alert]
|
||||
# litellm.failure_callback = [send_slack_alert]
|
||||
|
||||
|
||||
# litellm.set_verbose = False
|
||||
|
||||
# # litellm.input_callback = [get_transformed_inputs]
|
||||
time.sleep(0.1) # sleep for 0.1 seconds allow set cache to occur
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
|
||||
response_2_content = ""
|
||||
for chunk in response2:
|
||||
print(chunk)
|
||||
response_2_content += chunk.choices[0].delta.content or ""
|
||||
print("\nresponse 1", response_1_content)
|
||||
print("\nresponse 2", response_2_content)
|
||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
litellm.cache = None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
litellm.success_callback = []
|
||||
raise e
|
||||
test_redis_cache_completion_stream()
|
|
@ -151,16 +151,36 @@ def test_cohere_embedding3():
|
|||
|
||||
# test_cohere_embedding3()
|
||||
|
||||
def test_bedrock_embedding():
|
||||
def test_bedrock_embedding_titan():
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
response = embedding(
|
||||
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
|
||||
"lets test a second string for good measure"]
|
||||
)
|
||||
print(f"response:", response)
|
||||
assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list"
|
||||
print(f"type of first embedding:", type(response['data'][0]['embedding'][0]))
|
||||
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_bedrock_embedding()
|
||||
# test_bedrock_embedding_titan()
|
||||
|
||||
def test_bedrock_embedding_cohere():
|
||||
try:
|
||||
litellm.set_verbose=False
|
||||
response = embedding(
|
||||
model="cohere.embed-multilingual-v3", input=["good morning from litellm, attempting to embed data", "lets test a second string for good measure"],
|
||||
aws_region_name="os.environ/AWS_REGION_NAME_2"
|
||||
)
|
||||
assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list"
|
||||
print(f"type of first embedding:", type(response['data'][0]['embedding'][0]))
|
||||
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
|
||||
# print(f"response:", response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_bedrock_embedding_cohere()
|
||||
|
||||
# comment out hf tests - since hf endpoints are unstable
|
||||
def test_hf_embedding():
|
||||
|
@ -214,7 +234,14 @@ def test_aembedding_azure():
|
|||
|
||||
# test_aembedding_azure()
|
||||
|
||||
# def test_custom_openai_embedding():
|
||||
def test_sagemaker_embeddings():
|
||||
try:
|
||||
response = litellm.embedding(model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"])
|
||||
print(f"response: {response}")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_sagemaker_embeddings()
|
||||
# def local_proxy_embeddings():
|
||||
# litellm.set_verbose=True
|
||||
# response = embedding(
|
||||
# model="openai/custom_embedding",
|
||||
|
@ -222,4 +249,5 @@ def test_aembedding_azure():
|
|||
# api_base="http://0.0.0.0:8000/"
|
||||
# )
|
||||
# print(response)
|
||||
# test_custom_openai_embedding()
|
||||
|
||||
# local_proxy_embeddings()
|
||||
|
|
|
@ -189,6 +189,7 @@ def test_completion_azure_exception():
|
|||
}
|
||||
],
|
||||
)
|
||||
os.environ["AZURE_API_KEY"] = old_azure_key
|
||||
print(f"response: {response}")
|
||||
print(response)
|
||||
except openai.AuthenticationError as e:
|
||||
|
|
19
litellm/tests/test_get_llm_provider.py
Normal file
19
litellm/tests/test_get_llm_provider.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os, io
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
|
||||
def test_get_llm_provider():
|
||||
_, response, _, _ = litellm.get_llm_provider(model="anthropic.claude-v2:1")
|
||||
|
||||
assert response == "bedrock"
|
||||
|
||||
test_get_llm_provider()
|
|
@ -9,33 +9,107 @@ from litellm import completion
|
|||
import litellm
|
||||
litellm.num_retries = 3
|
||||
litellm.success_callback = ["langfuse"]
|
||||
# litellm.set_verbose = True
|
||||
os.environ["LANGFUSE_DEBUG"] = "True"
|
||||
import time
|
||||
import pytest
|
||||
|
||||
def search_logs(log_file_path):
|
||||
"""
|
||||
Searches the given log file for logs containing the "/api/public" string.
|
||||
|
||||
Parameters:
|
||||
- log_file_path (str): The path to the log file to be searched.
|
||||
|
||||
Returns:
|
||||
- None
|
||||
|
||||
Raises:
|
||||
- Exception: If there are any bad logs found in the log file.
|
||||
"""
|
||||
import re
|
||||
print("\n searching logs")
|
||||
bad_logs = []
|
||||
good_logs = []
|
||||
all_logs = []
|
||||
try:
|
||||
with open(log_file_path, 'r') as log_file:
|
||||
lines = log_file.readlines()
|
||||
print(f"searching logslines: {lines}")
|
||||
for line in lines:
|
||||
all_logs.append(line.strip())
|
||||
if "/api/public" in line:
|
||||
print("Found log with /api/public:")
|
||||
print(line.strip())
|
||||
print("\n\n")
|
||||
match = re.search(r'receive_response_headers.complete return_value=\(b\'HTTP/1.1\', (\d+),', line)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
if status_code != 200 and status_code != 201:
|
||||
print("got a BAD log")
|
||||
bad_logs.append(line.strip())
|
||||
else:
|
||||
|
||||
good_logs.append(line.strip())
|
||||
print("\nBad Logs")
|
||||
print(bad_logs)
|
||||
if len(bad_logs)>0:
|
||||
raise Exception(f"bad logs, Bad logs = {bad_logs}")
|
||||
|
||||
print("\nGood Logs")
|
||||
print(good_logs)
|
||||
if len(good_logs) <= 0:
|
||||
raise Exception(f"There were no Good Logs from Langfuse. No logs with /api/public status 200. \nAll logs:{all_logs}")
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def pre_langfuse_setup():
|
||||
"""
|
||||
Set up the logging for the 'pre_langfuse_setup' function.
|
||||
"""
|
||||
# sends logs to langfuse.log
|
||||
import logging
|
||||
# Configure the logging to write to a file
|
||||
logging.basicConfig(filename="langfuse.log", level=logging.DEBUG)
|
||||
logger = logging.getLogger()
|
||||
|
||||
# Add a FileHandler to the logger
|
||||
file_handler = logging.FileHandler("langfuse.log", mode='w')
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
logger.addHandler(file_handler)
|
||||
return
|
||||
|
||||
@pytest.mark.skip(reason="beta test - checking langfuse output")
|
||||
def test_langfuse_logging_async():
|
||||
try:
|
||||
pre_langfuse_setup()
|
||||
litellm.set_verbose = True
|
||||
async def _test_langfuse():
|
||||
return await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content":"This is a test"}],
|
||||
max_tokens=1000,
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
timeout=5,
|
||||
)
|
||||
response = asyncio.run(_test_langfuse())
|
||||
print(f"response: {response}")
|
||||
|
||||
# time.sleep(2)
|
||||
# # check langfuse.log to see if there was a failed response
|
||||
# search_logs("langfuse.log")
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {e}")
|
||||
|
||||
# test_langfuse_logging_async()
|
||||
test_langfuse_logging_async()
|
||||
|
||||
@pytest.mark.skip(reason="beta test - checking langfuse output")
|
||||
def test_langfuse_logging():
|
||||
try:
|
||||
# litellm.set_verbose = True
|
||||
pre_langfuse_setup()
|
||||
litellm.set_verbose = True
|
||||
response = completion(model="claude-instant-1.2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
|
@ -43,17 +117,20 @@ def test_langfuse_logging():
|
|||
}],
|
||||
max_tokens=10,
|
||||
temperature=0.2,
|
||||
metadata={"langfuse/key": "foo"}
|
||||
)
|
||||
print(response)
|
||||
# time.sleep(5)
|
||||
# # check langfuse.log to see if there was a failed response
|
||||
# search_logs("langfuse.log")
|
||||
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pytest.fail(f"An exception occurred - {e}")
|
||||
|
||||
test_langfuse_logging()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="beta test - checking langfuse output")
|
||||
def test_langfuse_logging_stream():
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
|
@ -77,6 +154,7 @@ def test_langfuse_logging_stream():
|
|||
|
||||
# test_langfuse_logging_stream()
|
||||
|
||||
@pytest.mark.skip(reason="beta test - checking langfuse output")
|
||||
def test_langfuse_logging_custom_generation_name():
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
|
@ -99,8 +177,8 @@ def test_langfuse_logging_custom_generation_name():
|
|||
pytest.fail(f"An exception occurred - {e}")
|
||||
print(e)
|
||||
|
||||
test_langfuse_logging_custom_generation_name()
|
||||
|
||||
# test_langfuse_logging_custom_generation_name()
|
||||
@pytest.mark.skip(reason="beta test - checking langfuse output")
|
||||
def test_langfuse_logging_function_calling():
|
||||
function1 = [
|
||||
{
|
||||
|
|
79
litellm/tests/test_least_busy_routing.py
Normal file
79
litellm/tests/test_least_busy_routing.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
# #### What this tests ####
|
||||
# # This tests the router's ability to identify the least busy deployment
|
||||
|
||||
# #
|
||||
# # How is this achieved?
|
||||
# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
|
||||
# # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
|
||||
# # - use litellm.success + failure callbacks to log when a request completed
|
||||
# # - in get_available_deployment, for a given model group name -> pick based on traffic
|
||||
|
||||
# import sys, os, asyncio, time
|
||||
# import traceback
|
||||
# from dotenv import load_dotenv
|
||||
|
||||
# load_dotenv()
|
||||
# import os
|
||||
|
||||
# sys.path.insert(
|
||||
# 0, os.path.abspath("../..")
|
||||
# ) # Adds the parent directory to the system path
|
||||
# import pytest
|
||||
# from litellm import Router
|
||||
# import litellm
|
||||
|
||||
# async def test_least_busy_routing():
|
||||
# model_list = [{
|
||||
# "model_name": "azure-model",
|
||||
# "litellm_params": {
|
||||
# "model": "azure/gpt-turbo",
|
||||
# "api_key": "os.environ/AZURE_FRANCE_API_KEY",
|
||||
# "api_base": "https://openai-france-1234.openai.azure.com",
|
||||
# "rpm": 1440,
|
||||
# }
|
||||
# }, {
|
||||
# "model_name": "azure-model",
|
||||
# "litellm_params": {
|
||||
# "model": "azure/gpt-35-turbo",
|
||||
# "api_key": "os.environ/AZURE_EUROPE_API_KEY",
|
||||
# "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
|
||||
# "rpm": 6
|
||||
# }
|
||||
# }, {
|
||||
# "model_name": "azure-model",
|
||||
# "litellm_params": {
|
||||
# "model": "azure/gpt-35-turbo",
|
||||
# "api_key": "os.environ/AZURE_CANADA_API_KEY",
|
||||
# "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
|
||||
# "rpm": 6
|
||||
# }
|
||||
# }]
|
||||
# router = Router(model_list=model_list,
|
||||
# routing_strategy="least-busy",
|
||||
# set_verbose=False,
|
||||
# num_retries=3) # type: ignore
|
||||
|
||||
# async def call_azure_completion():
|
||||
# try:
|
||||
# response = await router.acompletion(
|
||||
# model="azure-model",
|
||||
# messages=[
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "hello this request will pass"
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# print("\n response", response)
|
||||
# return response
|
||||
# except:
|
||||
# return None
|
||||
|
||||
# n = 1000
|
||||
# start_time = time.time()
|
||||
# tasks = [call_azure_completion() for _ in range(n)]
|
||||
# chat_completions = await asyncio.gather(*tasks)
|
||||
# successful_completions = [c for c in chat_completions if c is not None]
|
||||
# print(n, time.time() - start_time, len(successful_completions))
|
||||
|
||||
# asyncio.run(test_least_busy_routing())
|
|
@ -17,10 +17,10 @@ model_alias_map = {
|
|||
"good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf"
|
||||
}
|
||||
|
||||
litellm.model_alias_map = model_alias_map
|
||||
|
||||
def test_model_alias_map():
|
||||
try:
|
||||
litellm.model_alias_map = model_alias_map
|
||||
response = completion(
|
||||
"good-model",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
|
|
|
@ -395,7 +395,7 @@ def sagemaker_test_completion():
|
|||
try:
|
||||
# OVERRIDE WITH DYNAMIC MAX TOKENS
|
||||
response_1 = litellm.completion(
|
||||
model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
|
||||
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
|
||||
messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}],
|
||||
max_tokens=100
|
||||
)
|
||||
|
@ -404,7 +404,7 @@ def sagemaker_test_completion():
|
|||
|
||||
# USE CONFIG TOKENS
|
||||
response_2 = litellm.completion(
|
||||
model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
|
||||
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
|
||||
messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}],
|
||||
)
|
||||
response_2_text = response_2.choices[0].message.content
|
||||
|
|
65
litellm/tests/test_proxy_custom_auth.py
Normal file
65
litellm/tests/test_proxy_custom_auth.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os, io
|
||||
|
||||
# this file is to test litellm/proxy
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
|
||||
# test /chat/completion request to the proxy
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
|
||||
|
||||
|
||||
# Here you create a fixture that will be used by your tests
|
||||
# Make sure the fixture returns TestClient(app)
|
||||
@pytest.fixture(scope="function")
|
||||
def client():
|
||||
from litellm.proxy.proxy_server import cleanup_router_config_variables
|
||||
cleanup_router_config_variables()
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||
app = FastAPI()
|
||||
initialize(config=config_fp)
|
||||
|
||||
app.include_router(router) # Include your router in the test app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_custom_auth(client):
|
||||
try:
|
||||
# Your test data
|
||||
test_data = {
|
||||
"model": "openai-model",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
],
|
||||
"max_tokens": 10,
|
||||
}
|
||||
# Your bearer token
|
||||
token = os.getenv("PROXY_MASTER_KEY")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}"
|
||||
}
|
||||
response = client.post("/chat/completions", json=test_data, headers=headers)
|
||||
print(f"response: {response.text}")
|
||||
assert response.status_code == 401
|
||||
result = response.json()
|
||||
print(f"Received response: {result}")
|
||||
except Exception as e:
|
||||
pytest.fail("LiteLLM Proxy test failed. Exception", e)
|
236
litellm/tests/test_proxy_custom_logger.py
Normal file
236
litellm/tests/test_proxy_custom_logger.py
Normal file
|
@ -0,0 +1,236 @@
|
|||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os, io, asyncio
|
||||
|
||||
# this file is to test litellm/proxy
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
import importlib, inspect
|
||||
|
||||
# test /chat/completion request to the proxy
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
|
||||
|
||||
# @app.on_event("startup")
|
||||
# async def wrapper_startup_event():
|
||||
# initialize(config=config_fp)
|
||||
|
||||
# Use the app fixture in your client fixture
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
|
||||
initialize(config=config_fp)
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
|
||||
# Your bearer token
|
||||
token = os.getenv("PROXY_MASTER_KEY")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}"
|
||||
}
|
||||
|
||||
|
||||
print("Testing proxy custom logger")
|
||||
|
||||
def test_embedding(client):
|
||||
try:
|
||||
litellm.set_verbose=False
|
||||
from litellm.proxy.utils import get_instance_fn
|
||||
my_custom_logger = get_instance_fn(
|
||||
value = "custom_callbacks.my_custom_logger",
|
||||
config_file_path=python_file_path
|
||||
)
|
||||
print("id of initialized custom logger", id(my_custom_logger))
|
||||
litellm.callbacks = [my_custom_logger]
|
||||
# Your test data
|
||||
print("initialized proxy")
|
||||
# import the initialized custom logger
|
||||
print(litellm.callbacks)
|
||||
|
||||
# assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
|
||||
print("my_custom_logger", my_custom_logger)
|
||||
assert my_custom_logger.async_success_embedding == False
|
||||
|
||||
test_data = {
|
||||
"model": "azure-embedding-model",
|
||||
"input": ["hello"]
|
||||
}
|
||||
response = client.post("/embeddings", json=test_data, headers=headers)
|
||||
print("made request", response.status_code, response.text)
|
||||
print("vars my custom logger /embeddings", vars(my_custom_logger), "id", id(my_custom_logger))
|
||||
assert my_custom_logger.async_success_embedding == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
|
||||
assert my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model" # checks if kwargs passed to async_log_success_event are correct
|
||||
kwargs = my_custom_logger.async_embedding_kwargs
|
||||
litellm_params = kwargs.get("litellm_params")
|
||||
metadata = litellm_params.get("metadata", None)
|
||||
print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
|
||||
assert metadata is not None
|
||||
assert "user_api_key" in metadata
|
||||
assert "headers" in metadata
|
||||
proxy_server_request = litellm_params.get("proxy_server_request")
|
||||
model_info = litellm_params.get("model_info")
|
||||
assert proxy_server_request == {'url': 'http://testserver/embeddings', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '54', 'content-type': 'application/json'}, 'body': {'model': 'azure-embedding-model', 'input': ['hello']}}
|
||||
assert model_info == {'input_cost_per_token': 0.002, 'mode': 'embedding', 'id': 'hello'}
|
||||
result = response.json()
|
||||
print(f"Received response: {result}")
|
||||
print("Passed Embedding custom logger on proxy!")
|
||||
except Exception as e:
|
||||
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
||||
|
||||
|
||||
def test_chat_completion(client):
|
||||
try:
|
||||
# Your test data
|
||||
|
||||
print("initialized proxy")
|
||||
litellm.set_verbose=False
|
||||
from litellm.proxy.utils import get_instance_fn
|
||||
my_custom_logger = get_instance_fn(
|
||||
value = "custom_callbacks.my_custom_logger",
|
||||
config_file_path=python_file_path
|
||||
)
|
||||
|
||||
print("id of initialized custom logger", id(my_custom_logger))
|
||||
|
||||
litellm.callbacks = [my_custom_logger]
|
||||
# import the initialized custom logger
|
||||
print(litellm.callbacks)
|
||||
|
||||
# assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
|
||||
|
||||
print("LiteLLM Callbacks", litellm.callbacks)
|
||||
print("my_custom_logger", my_custom_logger)
|
||||
assert my_custom_logger.async_success == False
|
||||
|
||||
test_data = {
|
||||
"model": "Azure OpenAI GPT-4 Canada",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "write a litellm poem"
|
||||
},
|
||||
],
|
||||
"max_tokens": 10,
|
||||
}
|
||||
|
||||
|
||||
response = client.post("/chat/completions", json=test_data, headers=headers)
|
||||
print("made request", response.status_code, response.text)
|
||||
print("LiteLLM Callbacks", litellm.callbacks)
|
||||
asyncio.sleep(1) # sleep while waiting for callback to run
|
||||
|
||||
print("my_custom_logger in /chat/completions", my_custom_logger, "id", id(my_custom_logger))
|
||||
print("vars my custom logger, ", vars(my_custom_logger))
|
||||
assert my_custom_logger.async_success == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
|
||||
assert my_custom_logger.async_completion_kwargs["model"] == "chatgpt-v-2" # checks if kwargs passed to async_log_success_event are correct
|
||||
print("\n\n Custom Logger Async Completion args", my_custom_logger.async_completion_kwargs)
|
||||
litellm_params = my_custom_logger.async_completion_kwargs.get("litellm_params")
|
||||
metadata = litellm_params.get("metadata", None)
|
||||
print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
|
||||
assert metadata is not None
|
||||
assert "user_api_key" in metadata
|
||||
assert "headers" in metadata
|
||||
config_model_info = litellm_params.get("model_info")
|
||||
proxy_server_request_object = litellm_params.get("proxy_server_request")
|
||||
|
||||
assert config_model_info == {'id': 'gm', 'input_cost_per_token': 0.0002, 'mode': 'chat'}
|
||||
assert proxy_server_request_object == {'url': 'http://testserver/chat/completions', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '123', 'content-type': 'application/json'}, 'body': {'model': 'Azure OpenAI GPT-4 Canada', 'messages': [{'role': 'user', 'content': 'write a litellm poem'}], 'max_tokens': 10}}
|
||||
result = response.json()
|
||||
print(f"Received response: {result}")
|
||||
print("\nPassed /chat/completions with Custom Logger!")
|
||||
except Exception as e:
|
||||
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
||||
|
||||
|
||||
def test_chat_completion_stream(client):
|
||||
try:
|
||||
# Your test data
|
||||
litellm.set_verbose=False
|
||||
from litellm.proxy.utils import get_instance_fn
|
||||
my_custom_logger = get_instance_fn(
|
||||
value = "custom_callbacks.my_custom_logger",
|
||||
config_file_path=python_file_path
|
||||
)
|
||||
|
||||
print("id of initialized custom logger", id(my_custom_logger))
|
||||
|
||||
litellm.callbacks = [my_custom_logger]
|
||||
import json
|
||||
print("initialized proxy")
|
||||
# import the initialized custom logger
|
||||
print(litellm.callbacks)
|
||||
|
||||
|
||||
print("LiteLLM Callbacks", litellm.callbacks)
|
||||
print("my_custom_logger", my_custom_logger)
|
||||
|
||||
assert my_custom_logger.streaming_response_obj == None # no streaming response obj is set pre call
|
||||
|
||||
test_data = {
|
||||
"model": "Azure OpenAI GPT-4 Canada",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "write 1 line poem about LiteLLM"
|
||||
},
|
||||
],
|
||||
"max_tokens": 40,
|
||||
"stream": True # streaming call
|
||||
}
|
||||
|
||||
|
||||
response = client.post("/chat/completions", json=test_data, headers=headers)
|
||||
print("made request", response.status_code, response.text)
|
||||
complete_response = ""
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
# Process the streaming data line here
|
||||
print("\n\n Line", line)
|
||||
print(line)
|
||||
line = str(line)
|
||||
|
||||
json_data = line.replace('data: ', '')
|
||||
|
||||
# Parse the JSON string
|
||||
data = json.loads(json_data)
|
||||
|
||||
print("\n\n decode_data", data)
|
||||
|
||||
# Access the content of choices[0]['message']['content']
|
||||
content = data['choices'][0]['delta']['content'] or ""
|
||||
|
||||
# Process the content as needed
|
||||
print("Content:", content)
|
||||
|
||||
complete_response+= content
|
||||
|
||||
print("\n\nHERE is the complete streaming response string", complete_response)
|
||||
print("\n\nHERE IS the streaming Response from callback\n\n")
|
||||
print(my_custom_logger.streaming_response_obj)
|
||||
import time
|
||||
time.sleep(0.5)
|
||||
|
||||
streamed_response = my_custom_logger.streaming_response_obj
|
||||
assert complete_response == streamed_response["choices"][0]["message"]["content"]
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
||||
|
61
litellm/tests/test_proxy_gunicorn.py
Normal file
61
litellm/tests/test_proxy_gunicorn.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
# #### What this tests ####
|
||||
# # Allow the user to easily run the local proxy server with Gunicorn
|
||||
# # LOCAL TESTING ONLY
|
||||
# import sys, os, subprocess
|
||||
# import traceback
|
||||
# from dotenv import load_dotenv
|
||||
|
||||
# load_dotenv()
|
||||
# import os, io
|
||||
|
||||
# # this file is to test litellm/proxy
|
||||
|
||||
# sys.path.insert(
|
||||
# 0, os.path.abspath("../..")
|
||||
# ) # Adds the parent directory to the system path
|
||||
# import pytest
|
||||
# import litellm
|
||||
|
||||
# ### LOCAL Proxy Server INIT ###
|
||||
# from litellm.proxy.proxy_server import save_worker_config # Replace with the actual module where your FastAPI router is defined
|
||||
# filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
# config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
||||
# def get_openai_info():
|
||||
# return {
|
||||
# "api_key": os.getenv("AZURE_API_KEY"),
|
||||
# "api_base": os.getenv("AZURE_API_BASE"),
|
||||
# }
|
||||
|
||||
# def run_server(host="0.0.0.0",port=8008,num_workers=None):
|
||||
# if num_workers is None:
|
||||
# # Set it to min(8,cpu_count())
|
||||
# import multiprocessing
|
||||
# num_workers = min(4,multiprocessing.cpu_count())
|
||||
|
||||
# ### LOAD KEYS ###
|
||||
|
||||
# # Load the Azure keys. For now get them from openai-usage
|
||||
# azure_info = get_openai_info()
|
||||
# print(f"Azure info:{azure_info}")
|
||||
# os.environ["AZURE_API_KEY"] = azure_info['api_key']
|
||||
# os.environ["AZURE_API_BASE"] = azure_info['api_base']
|
||||
# os.environ["AZURE_API_VERSION"] = "2023-09-01-preview"
|
||||
|
||||
# ### SAVE CONFIG ###
|
||||
|
||||
# os.environ["WORKER_CONFIG"] = config_fp
|
||||
|
||||
# # In order for the app to behave well with signals, run it with gunicorn
|
||||
# # The first argument must be the "name of the command run"
|
||||
# cmd = f"gunicorn litellm.proxy.proxy_server:app --workers {num_workers} --worker-class uvicorn.workers.UvicornWorker --bind {host}:{port}"
|
||||
# cmd = cmd.split()
|
||||
# print(f"Running command: {cmd}")
|
||||
# import sys
|
||||
# sys.stdout.flush()
|
||||
# sys.stderr.flush()
|
||||
|
||||
# # Make sure to propage env variables
|
||||
# subprocess.run(cmd) # This line actually starts Gunicorn
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# run_server()
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue