Merge branch 'main' into litellm_proxy_support_all_providers

commit 079a41fbe1
Author: Ishaan Jaff
Date: 2024-07-25 20:15:37 -07:00 (committed by GitHub)
33 changed files with 1329 additions and 350 deletions


@ -166,6 +166,10 @@ $ litellm --model huggingface/bigcode/starcoder
### Step 2: Make ChatCompletions Request to Proxy
> [!IMPORTANT]
> 💡 [Use LiteLLM Proxy with Langchain (Python, JS), OpenAI SDK (Python, JS) Anthropic SDK, Mistral SDK, LlamaIndex, Instructor, Curl](https://docs.litellm.ai/docs/proxy/user_keys)
```python
import openai # openai v1.0.0+
client = openai.OpenAI(api_key="anything",base_url="http://0.0.0.0:4000") # set proxy to base_url


@ -0,0 +1,168 @@
# Custom API Server (Custom Format)
LiteLLM allows you to call your custom endpoint in the OpenAI ChatCompletion format
:::info
For calling an openai-compatible endpoint, [go here](./openai_compatible.md)
:::
## Quick Start
```python
import litellm
from litellm import CustomLLM, completion, get_llm_provider
class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [ # 👈 KEY STEP - REGISTER HANDLER
{"provider": "my-custom-llm", "custom_handler": my_custom_llm}
]
resp = completion(
model="my-custom-llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
)
assert resp.choices[0].message.content == "Hi!"
```
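As an aside (a sketch, not part of the docs added in this commit): the same registration also works with the async entrypoint, provided the handler implements `acompletion` as the proxy example below does; the unit tests added later in this diff exercise exactly this path.

```python
# Sketch: assumes litellm.custom_provider_map was set as in the Quick Start above,
# and that the registered handler also implements `acompletion`.
import asyncio
import litellm

async def main():
    resp = await litellm.acompletion(
        model="my-custom-llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )
    print(resp.choices[0].message.content)  # -> "Hi!"

asyncio.run(main())
```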
## OpenAI Proxy Usage
1. Setup your `custom_handler.py` file
```python
import litellm
from litellm import CustomLLM, completion, get_llm_provider
class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
my_custom_llm = MyCustomLLM()
```
2. Add to `config.yaml`
In the config below, we pass:
- python_filename: `custom_handler.py`
- custom_handler_instance_name: `my_custom_llm` (defined in Step 1)
- custom_handler: `custom_handler.my_custom_llm`
```yaml
model_list:
- model_name: "test-model"
litellm_params:
model: "openai/text-embedding-ada-002"
- model_name: "my-custom-model"
litellm_params:
model: "my-custom-llm/my-model"
litellm_settings:
custom_provider_map:
- {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
```
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "my-custom-model",
"messages": [{"role": "user", "content": "Say \"this is a test\" in JSON!"}],
}'
```
Expected Response
```
{
"id": "chatcmpl-06f1b9cd-08bc-43f7-9814-a69173921216",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "Hi!",
"role": "assistant",
"tool_calls": null,
"function_call": null
}
}
],
"created": 1721955063,
"model": "gpt-3.5-turbo",
"object": "chat.completion",
"system_fingerprint": null,
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}
```
## Custom Handler Spec
```python
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from typing import Iterator, AsyncIterator
from litellm.llms.base import BaseLLM
class CustomLLMError(Exception): # use this for all your exceptions
def __init__(
self,
status_code,
message,
):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class CustomLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def completion(self, *args, **kwargs) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def acompletion(self, *args, **kwargs) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
```
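The spec above only raises `CustomLLMError`; a minimal sketch of a streaming implementation, modeled on the unit tests added later in this commit (the provider name `my-custom-llm` and class name are illustrative, not from the docs above), might look like this — each yielded chunk follows the `GenericStreamingChunk` shape:

```python
# A sketch, not the canonical implementation: hard-codes a single final chunk.
from typing import Iterator

import litellm
from litellm import CustomLLM
from litellm.types.utils import GenericStreamingChunk


class MyStreamingLLM(CustomLLM):
    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        # one final chunk carrying the whole response
        yield {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }


litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": MyStreamingLLM()}
]

resp = litellm.completion(
    model="my-custom-llm/my-fake-model",
    messages=[{"role": "user", "content": "Hello world!"}],
    stream=True,
)
for chunk in resp:
    print(chunk)
```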


@ -1,129 +0,0 @@
# Custom API Server (OpenAI Format)
LiteLLM allows you to call your custom endpoint in the OpenAI ChatCompletion format
## API KEYS
No api keys required
## Set up your Custom API Server
Your server should have the following Endpoints:
Here's an example OpenAI proxy server with routes: https://replit.com/@BerriAI/openai-proxy#main.py
### Required Endpoints
- POST `/chat/completions` - chat completions endpoint
### Optional Endpoints
- POST `/completions` - completions endpoint
- Get `/models` - available models on server
- POST `/embeddings` - creates an embedding vector representing the input text.
## Example Usage
### Call `/chat/completions`
In order to use your custom OpenAI Chat Completion proxy with LiteLLM, ensure you set
* `api_base` to your proxy url, example "https://openai-proxy.berriai.repl.co"
* `custom_llm_provider` to `openai` this ensures litellm uses the `openai.ChatCompletion` to your api_base
```python
import os
from litellm import completion
## set ENV variables
os.environ["OPENAI_API_KEY"] = "anything" #key is not used for proxy
messages = [{ "content": "Hello, how are you?","role": "user"}]
response = completion(
model="command-nightly",
messages=[{ "content": "Hello, how are you?","role": "user"}],
api_base="https://openai-proxy.berriai.repl.co",
custom_llm_provider="openai" # litellm will use the openai.ChatCompletion to make the request
)
print(response)
```
#### Response
```json
{
"object":
"chat.completion",
"choices": [{
"finish_reason": "stop",
"index": 0,
"message": {
"content":
"The sky, a canvas of blue,\nA work of art, pure and true,\nA",
"role": "assistant"
}
}],
"id":
"chatcmpl-7fbd6077-de10-4cb4-a8a4-3ef11a98b7c8",
"created":
1699290237.408061,
"model":
"togethercomputer/llama-2-70b-chat",
"usage": {
"completion_tokens": 18,
"prompt_tokens": 14,
"total_tokens": 32
}
}
```
### Call `/completions`
In order to use your custom OpenAI Completion proxy with LiteLLM, ensure you set
* `api_base` to your proxy url, example "https://openai-proxy.berriai.repl.co"
* `custom_llm_provider` to `text-completion-openai` this ensures litellm uses the `openai.Completion` to your api_base
```python
import os
from litellm import completion
## set ENV variables
os.environ["OPENAI_API_KEY"] = "anything" #key is not used for proxy
messages = [{ "content": "Hello, how are you?","role": "user"}]
response = completion(
model="command-nightly",
messages=[{ "content": "Hello, how are you?","role": "user"}],
api_base="https://openai-proxy.berriai.repl.co",
custom_llm_provider="text-completion-openai" # litellm will use the openai.Completion to make the request
)
print(response)
```
#### Response
```json
{
"warning":
"This model version is deprecated. Migrate before January 4, 2024 to avoid disruption of service. Learn more https://platform.openai.com/docs/deprecations",
"id":
"cmpl-8HxHqF5dymQdALmLplS0dWKZVFe3r",
"object":
"text_completion",
"created":
1699290166,
"model":
"text-davinci-003",
"choices": [{
"text":
"\n\nThe weather in San Francisco varies depending on what time of year and time",
"index": 0,
"logprobs": None,
"finish_reason": "length"
}],
"usage": {
"prompt_tokens": 7,
"completion_tokens": 16,
"total_tokens": 23
}
}
```


@ -254,6 +254,15 @@ Your OpenAI proxy server is now running on `http://127.0.0.1:4000`.
**That's it ! That's the quick start to deploy litellm**
## Use with Langchain, OpenAI SDK, LlamaIndex, Instructor, Curl
:::info
💡 Go here 👉 [to make your first LLM API Request](user_keys)
LiteLLM is compatible with several SDKs - including OpenAI SDK, Anthropic SDK, Mistral SDK, LLamaIndex, Langchain (Js, Python)
:::
## Options to deploy LiteLLM
| Docs | When to Use |


@ -41,28 +41,6 @@ litellm --health
}
```
### Background Health Checks
You can enable model health checks being run in the background, to prevent each model from being queried too frequently via `/health`.
Here's how to use it:
1. in the config.yaml add:
```
general_settings:
background_health_checks: True # enable background health checks
health_check_interval: 300 # frequency of background health checks
```
2. Start server
```
$ litellm /path/to/config.yaml
```
3. Query health endpoint:
```
curl --location 'http://0.0.0.0:4000/health'
```
### Embedding Models
We need some way to know if the model is an embedding model when running checks; if you have this in your config, specifying the mode makes it an embedding health check
@ -124,6 +102,41 @@ model_list:
mode: audio_transcription
```
### Text to Speech Models
```yaml
# OpenAI Text to Speech Models
- model_name: tts
litellm_params:
model: openai/tts-1
api_key: "os.environ/OPENAI_API_KEY"
model_info:
mode: audio_speech
```
## Background Health Checks
You can enable model health checks being run in the background, to prevent each model from being queried too frequently via `/health`.
Here's how to use it:
1. in the config.yaml add:
```
general_settings:
background_health_checks: True # enable background health checks
health_check_interval: 300 # frequency of background health checks
```
2. Start server
```
$ litellm /path/to/config.yaml
```
3. Query health endpoint:
```
curl --location 'http://0.0.0.0:4000/health'
```
### Hide details
The health check response contains details like endpoint URLs, error messages,


@ -255,6 +255,12 @@ litellm --config your_config.yaml
## Using LiteLLM Proxy - Curl Request, OpenAI Package, Langchain
:::info
LiteLLM is compatible with several SDKs - including OpenAI SDK, Anthropic SDK, Mistral SDK, LLamaIndex, Langchain (Js, Python)
[More examples here](user_keys)
:::
<Tabs>
<TabItem value="Curl" label="Curl Request">
@ -382,6 +388,34 @@ print(response)
```
</TabItem>
<TabItem value="anthropic-py" label="Anthropic Python SDK">
```python
import os
from anthropic import Anthropic
client = Anthropic(
base_url="http://localhost:4000", # proxy endpoint
api_key="sk-s4xN1IiLTCytwtZFJaYQrA", # litellm proxy virtual key
)
message = client.messages.create(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Hello, Claude",
}
],
model="claude-3-opus-20240229",
)
print(message.content)
```
</TabItem>
</Tabs>
[**More Info**](./configs.md)
@ -396,165 +430,6 @@ print(response)
- POST `/key/generate` - generate a key to access the proxy
## Using with OpenAI compatible projects
Set `base_url` to the LiteLLM Proxy server
<Tabs>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
<TabItem value="librechat" label="LibreChat">
#### Start the LiteLLM proxy
```shell
litellm --model gpt-3.5-turbo
#INFO: Proxy running on http://0.0.0.0:4000
```
#### 1. Clone the repo
```shell
git clone https://github.com/danny-avila/LibreChat.git
```
#### 2. Modify Librechat's `docker-compose.yml`
LiteLLM Proxy is running on port `4000`, set `4000` as the proxy below
```yaml
OPENAI_REVERSE_PROXY=http://host.docker.internal:4000/v1/chat/completions
```
#### 3. Save fake OpenAI key in Librechat's `.env`
Copy Librechat's `.env.example` to `.env` and overwrite the default OPENAI_API_KEY (by default it requires the user to pass a key).
```env
OPENAI_API_KEY=sk-1234
```
#### 4. Run LibreChat:
```shell
docker compose up
```
</TabItem>
<TabItem value="continue-dev" label="ContinueDev">
Continue-Dev brings ChatGPT to VSCode. See how to [install it here](https://continue.dev/docs/quickstart).
In the [config.py](https://continue.dev/docs/reference/Models/openai) set this as your default model.
```python
default=OpenAI(
api_key="IGNORED",
model="fake-model-name",
context_length=2048, # customize if needed for your model
api_base="http://localhost:4000" # your proxy server url
),
```
Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
</TabItem>
<TabItem value="aider" label="Aider">
```shell
$ pip install aider
$ aider --openai-api-base http://0.0.0.0:4000 --openai-api-key fake-key
```
</TabItem>
<TabItem value="autogen" label="AutoGen">
```python
pip install pyautogen
```
```python
from autogen import AssistantAgent, UserProxyAgent, oai
config_list=[
{
"model": "my-fake-model",
"api_base": "http://localhost:4000", #litellm compatible endpoint
"api_type": "open_ai",
"api_key": "NULL", # just a placeholder
}
]
response = oai.Completion.create(config_list=config_list, prompt="Hi")
print(response) # works fine
llm_config={
"config_list": config_list,
}
assistant = AssistantAgent("assistant", llm_config=llm_config)
user_proxy = UserProxyAgent("user_proxy")
user_proxy.initiate_chat(assistant, message="Plot a chart of META and TESLA stock price change YTD.", config_list=config_list)
```
Credits [@victordibia](https://github.com/microsoft/autogen/issues/45#issuecomment-1749921972) for this tutorial.
</TabItem>
<TabItem value="guidance" label="guidance">
A guidance language for controlling large language models.
https://github.com/guidance-ai/guidance
**NOTE:** Guidance sends additional params like `stop_sequences` which can cause some models to fail if they don't support it.
**Fix**: Start your proxy using the `--drop_params` flag
```shell
litellm --model ollama/codellama --temperature 0.3 --max_tokens 2048 --drop_params
```
```python
import guidance
# set api_base to your proxy
# set api_key to anything
gpt4 = guidance.llms.OpenAI("gpt-4", api_base="http://0.0.0.0:4000", api_key="anything")
experts = guidance('''
{{#system~}}
You are a helpful and terse assistant.
{{~/system}}
{{#user~}}
I want a response to the following question:
{{query}}
Name 3 world-class experts (past or present) who would be great at answering this?
Don't answer the question yet.
{{~/user}}
{{#assistant~}}
{{gen 'expert_names' temperature=0 max_tokens=300}}
{{~/assistant}}
''', llm=gpt4)
result = experts(query='How can I be more productive?')
print(result)
```
</TabItem>
</Tabs>
## Debugging Proxy
Events that occur during normal operation


@ -31,8 +31,19 @@ model_list:
api_base: https://openai-france-1234.openai.azure.com/
api_key: <your-azure-api-key>
rpm: 1440
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
num_retries: 2
timeout: 30 # 30 seconds
redis_host: <your redis host> # set this when using multiple litellm proxy deployments, load balancing state stored in redis
redis_password: <your redis password>
redis_port: 1992
```
:::info
Detailed information about [routing strategies can be found here](../routing)
:::
#### Step 2: Start Proxy with config
```shell


@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Use with Langchain, OpenAI SDK, LlamaIndex, Instructor, Curl
# 💡 Use with Langchain, OpenAI SDK, LlamaIndex, Instructor, Curl
:::info
@ -234,6 +234,54 @@ main();
```
</TabItem>
<TabItem value="anthropic-py" label="Anthropic Python SDK">
```python
import os
from anthropic import Anthropic
client = Anthropic(
base_url="http://localhost:4000", # proxy endpoint
api_key="sk-s4xN1IiLTCytwtZFJaYQrA", # litellm proxy virtual key
)
message = client.messages.create(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Hello, Claude",
}
],
model="claude-3-opus-20240229",
)
print(message.content)
```
</TabItem>
<TabItem value="mistral-py" label="Mistral Python SDK">
```python
import os
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
client = MistralClient(api_key="sk-1234", endpoint="http://0.0.0.0:4000")
chat_response = client.chat(
model="mistral-small-latest",
messages=[
{"role": "user", "content": "this is a test request, write a short poem"}
],
)
print(chat_response.choices[0].message.content)
```
</TabItem>
<TabItem value="instructor" label="Instructor"> <TabItem value="instructor" label="Instructor">
```python ```python
@ -566,6 +614,166 @@ curl --location 'http://0.0.0.0:4000/moderations' \
```
## Using with OpenAI compatible projects
Set `base_url` to the LiteLLM Proxy server
<Tabs>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
<TabItem value="librechat" label="LibreChat">
#### Start the LiteLLM proxy
```shell
litellm --model gpt-3.5-turbo
#INFO: Proxy running on http://0.0.0.0:4000
```
#### 1. Clone the repo
```shell
git clone https://github.com/danny-avila/LibreChat.git
```
#### 2. Modify Librechat's `docker-compose.yml`
LiteLLM Proxy is running on port `4000`, set `4000` as the proxy below
```yaml
OPENAI_REVERSE_PROXY=http://host.docker.internal:4000/v1/chat/completions
```
#### 3. Save fake OpenAI key in Librechat's `.env`
Copy Librechat's `.env.example` to `.env` and overwrite the default OPENAI_API_KEY (by default it requires the user to pass a key).
```env
OPENAI_API_KEY=sk-1234
```
#### 4. Run LibreChat:
```shell
docker compose up
```
</TabItem>
<TabItem value="continue-dev" label="ContinueDev">
Continue-Dev brings ChatGPT to VSCode. See how to [install it here](https://continue.dev/docs/quickstart).
In the [config.py](https://continue.dev/docs/reference/Models/openai) set this as your default model.
```python
default=OpenAI(
api_key="IGNORED",
model="fake-model-name",
context_length=2048, # customize if needed for your model
api_base="http://localhost:4000" # your proxy server url
),
```
Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
</TabItem>
<TabItem value="aider" label="Aider">
```shell
$ pip install aider
$ aider --openai-api-base http://0.0.0.0:4000 --openai-api-key fake-key
```
</TabItem>
<TabItem value="autogen" label="AutoGen">
```python
pip install pyautogen
```
```python
from autogen import AssistantAgent, UserProxyAgent, oai
config_list=[
{
"model": "my-fake-model",
"api_base": "http://localhost:4000", #litellm compatible endpoint
"api_type": "open_ai",
"api_key": "NULL", # just a placeholder
}
]
response = oai.Completion.create(config_list=config_list, prompt="Hi")
print(response) # works fine
llm_config={
"config_list": config_list,
}
assistant = AssistantAgent("assistant", llm_config=llm_config)
user_proxy = UserProxyAgent("user_proxy")
user_proxy.initiate_chat(assistant, message="Plot a chart of META and TESLA stock price change YTD.", config_list=config_list)
```
Credits [@victordibia](https://github.com/microsoft/autogen/issues/45#issuecomment-1749921972) for this tutorial.
</TabItem>
<TabItem value="guidance" label="guidance">
A guidance language for controlling large language models.
https://github.com/guidance-ai/guidance
**NOTE:** Guidance sends additional params like `stop_sequences` which can cause some models to fail if they don't support it.
**Fix**: Start your proxy using the `--drop_params` flag
```shell
litellm --model ollama/codellama --temperature 0.3 --max_tokens 2048 --drop_params
```
```python
import guidance
# set api_base to your proxy
# set api_key to anything
gpt4 = guidance.llms.OpenAI("gpt-4", api_base="http://0.0.0.0:4000", api_key="anything")
experts = guidance('''
{{#system~}}
You are a helpful and terse assistant.
{{~/system}}
{{#user~}}
I want a response to the following question:
{{query}}
Name 3 world-class experts (past or present) who would be great at answering this?
Don't answer the question yet.
{{~/user}}
{{#assistant~}}
{{gen 'expert_names' temperature=0 max_tokens=300}}
{{~/assistant}}
''', llm=gpt4)
result = experts(query='How can I be more productive?')
print(result)
```
</TabItem>
</Tabs>
## Advanced
### (BETA) Batch Completions - pass multiple models


@ -175,7 +175,8 @@ const sidebars = {
"providers/aleph_alpha", "providers/aleph_alpha",
"providers/baseten", "providers/baseten",
"providers/openrouter", "providers/openrouter",
"providers/custom_openai_proxy", // "providers/custom_openai_proxy",
"providers/custom_llm_server",
"providers/petals", "providers/petals",
], ],


@ -813,6 +813,7 @@ from .utils import (
)
from .types.utils import ImageObject
from .llms.custom_llm import CustomLLM
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig
@ -909,3 +910,12 @@ from .cost_calculator import response_cost_calculator, cost_per_token
from .types.adapter import AdapterItem
adapters: List[AdapterItem] = []
### CUSTOM LLMs ###
from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk
custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
[]
) # internal helper util, used to track names of custom providers
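A small sketch (assumes a `CustomLLM` subclass named `MyCustomLLM`, as in the docs above) of how these two new module-level lists get populated by the `custom_llm_setup()` helper added to `litellm/utils.py` further down in this diff:

```python
import litellm
from litellm import CustomLLM
from litellm.utils import custom_llm_setup


class MyCustomLLM(CustomLLM):  # stub handler for illustration
    pass


litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": MyCustomLLM()}
]
custom_llm_setup()  # normally invoked for you inside function_setup()

assert "my-custom-llm" in litellm.provider_list
assert "my-custom-llm" in litellm._custom_providers
```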


@ -2,10 +2,6 @@
# On success + failure, log events to Logfire
import os
import dotenv
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import uuid
from enum import Enum


@ -1864,6 +1864,23 @@ class AzureChatCompletion(BaseLLM):
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(pwd, "../tests/gettysburg.wav")
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
else:
raise Exception("mode not set")
response = {}


@ -78,6 +78,8 @@ BEDROCK_CONVERSE_MODELS = [
"ai21.jamba-instruct-v1:0", "ai21.jamba-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0", "meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0", "meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-405b-instruct-v1:0",
"mistral.mistral-large-2407-v1:0",
]
@ -1315,6 +1317,7 @@ class AmazonConverseConfig:
model.startswith("anthropic") model.startswith("anthropic")
or model.startswith("mistral") or model.startswith("mistral")
or model.startswith("cohere") or model.startswith("cohere")
or model.startswith("meta.llama3-1")
):
supported_params.append("tools")

litellm/llms/custom_llm.py (new file, 161 lines)

@ -0,0 +1,161 @@
# What is this?
## Handler file for a Custom Chat LLM
"""
- completion
- acompletion
- streaming
- async_streaming
"""
import copy
import json
import os
import time
import types
from enum import Enum
from functools import partial
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Coroutine,
Iterator,
List,
Literal,
Optional,
Tuple,
Union,
)
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import GenericStreamingChunk, ProviderField
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class CustomLLMError(Exception): # use this for all your exceptions
def __init__(
self,
status_code,
message,
):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class CustomLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[HTTPHandler] = None,
) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def streaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[HTTPHandler] = None,
) -> Iterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def acompletion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def astreaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> AsyncIterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def custom_chat_llm_router(
async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM
):
"""
Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
Validates if response is in expected format
"""
if async_fn:
if stream:
return custom_llm.astreaming
return custom_llm.acompletion
if stream:
return custom_llm.streaming
return custom_llm.completion
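For context, a hedged sketch of how the router above is meant to be used (the subclass name here is illustrative):

```python
# Illustrative only: pick the handler method for an async, streaming call.
from litellm.llms.custom_llm import CustomLLM, custom_chat_llm_router


class MyCustomLLM(CustomLLM):  # stub subclass for illustration
    pass


handler = MyCustomLLM()
fn = custom_chat_llm_router(async_fn=True, stream=True, custom_llm=handler)
assert fn == handler.astreaming  # async + stream -> astreaming
```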


@ -1,5 +1,6 @@
import hashlib
import json
import os
import time
import traceback
import types
@ -1870,8 +1871,25 @@ class OpenAIChatCompletion(BaseLLM):
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(pwd, "../tests/gettysburg.wav")
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
else:
raise Exception("mode not set")
raise ValueError("mode not set, passed in mode: " + mode)
response = {}
if completion is None or not hasattr(completion, "headers"):


@ -387,7 +387,7 @@ def process_response(
result = " "
## Building RESPONSE OBJECT
if len(result) > 1:
if len(result) >= 1:
model_response.choices[0].message.content = result # type: ignore
# Calculate usage


@ -107,6 +107,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
@ -381,6 +382,7 @@ async def acompletion(
or custom_llm_provider == "clarifai"
or custom_llm_provider == "watsonx"
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider in litellm._custom_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
@ -2690,6 +2692,54 @@ def completion(
model_response.created = int(time.time())
model_response.model = model
response = model_response
elif (
custom_llm_provider in litellm._custom_providers
): # Assume custom LLM provider
# Get the Custom Handler
custom_handler: Optional[CustomLLM] = None
for item in litellm.custom_provider_map:
if item["provider"] == custom_llm_provider:
custom_handler = item["custom_handler"]
if custom_handler is None:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
)
## ROUTE LLM CALL ##
handler_fn = custom_chat_llm_router(
async_fn=acompletion, stream=stream, custom_llm=custom_handler
)
headers = headers or litellm.headers
## CALL FUNCTION
response = handler_fn(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
encoding=encoding,
)
if stream is True:
return CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging,
)
else:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"


@ -2996,6 +2996,15 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat" "mode": "chat"
}, },
"mistral.mistral-large-2407-v1:0": {
"max_tokens": 8191,
"max_input_tokens": 128000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000009,
"litellm_provider": "bedrock",
"mode": "chat"
},
"bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": { "bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": {
"max_tokens": 8191, "max_tokens": 8191,
"max_input_tokens": 32000, "max_input_tokens": 32000,
@ -3731,6 +3740,15 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat" "mode": "chat"
}, },
"meta.llama3-1-405b-instruct-v1:0": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000532,
"output_cost_per_token": 0.000016,
"litellm_provider": "bedrock",
"mode": "chat"
},
"512-x-512/50-steps/stability.stable-diffusion-xl-v0": { "512-x-512/50-steps/stability.stable-diffusion-xl-v0": {
"max_tokens": 77, "max_tokens": 77,
"max_input_tokens": 77, "max_input_tokens": 77,


@ -2,3 +2,10 @@ model_list:
- model_name: "test-model" - model_name: "test-model"
litellm_params: litellm_params:
model: "openai/text-embedding-ada-002" model: "openai/text-embedding-ada-002"
- model_name: "my-custom-model"
litellm_params:
model: "my-custom-llm/my-model"
litellm_settings:
custom_provider_map:
- {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}


@ -370,10 +370,17 @@ async def _cache_team_object(
team_id: str,
team_table: LiteLLM_TeamTable,
user_api_key_cache: DualCache,
proxy_logging_obj: Optional[ProxyLogging],
):
key = "team_id:{}".format(team_id)
await user_api_key_cache.async_set_cache(key=key, value=team_table)
## UPDATE REDIS CACHE ##
if proxy_logging_obj is not None:
await proxy_logging_obj.internal_usage_cache.async_set_cache(
key=key, value=team_table
)
@log_to_opentelemetry
async def get_team_object(
@ -395,7 +402,17 @@ async def get_team_object(
# check if in cache
key = "team_id:{}".format(team_id)
cached_team_obj: Optional[LiteLLM_TeamTable] = None
## CHECK REDIS CACHE ##
if proxy_logging_obj is not None:
cached_team_obj = await proxy_logging_obj.internal_usage_cache.async_get_cache(
key=key
)
if cached_team_obj is None:
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
if cached_team_obj is not None:
if isinstance(cached_team_obj, dict):
return LiteLLM_TeamTable(**cached_team_obj)
@ -413,7 +430,10 @@ async def get_team_object(
_response = LiteLLM_TeamTable(**response.dict())
# save the team object to cache
await _cache_team_object(
team_id=team_id, team_table=_response, user_api_key_cache=user_api_key_cache
team_id=team_id,
team_table=_response,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
return _response


@ -0,0 +1,21 @@
import litellm
from litellm import CustomLLM, completion, get_llm_provider
class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
my_custom_llm = MyCustomLLM()


@ -334,6 +334,7 @@ async def update_team(
create_audit_log_for_update,
litellm_proxy_admin_name,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
)
@ -380,6 +381,7 @@ async def update_team(
team_id=team_row.team_id,
team_table=team_row,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True


@ -15,6 +15,16 @@ model_list:
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY
- model_name: mistral-small-latest
litellm_params:
model: mistral/mistral-small-latest
api_key: "os.environ/MISTRAL_API_KEY"
- model_name: tts
litellm_params:
model: openai/tts-1
api_key: "os.environ/OPENAI_API_KEY"
model_info:
mode: audio_speech
general_settings:
master_key: sk-1234
alerting: ["slack"]


@ -1507,6 +1507,21 @@ class ProxyConfig:
verbose_proxy_logger.debug(
f"litellm.post_call_rules: {litellm.post_call_rules}"
)
elif key == "custom_provider_map":
from litellm.utils import custom_llm_setup
litellm.custom_provider_map = [
{
"provider": item["provider"],
"custom_handler": get_instance_fn(
value=item["custom_handler"],
config_file_path=config_file_path,
),
}
for item in value
]
custom_llm_setup()
elif key == "success_callback":
litellm.success_callback = []


@ -0,0 +1,13 @@
import os
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
client = MistralClient(api_key="sk-1234", endpoint="http://0.0.0.0:4000")
chat_response = client.chat(
model="mistral-small-latest",
messages=[
{"role": "user", "content": "this is a test request, write a short poem"}
],
)
print(chat_response.choices[0].message.content)


@ -862,7 +862,7 @@ class PrismaClient:
)
"""
)
if ret[0]['sum'] == 6:
if ret[0]["sum"] == 6:
print("All necessary views exist!") # noqa
return
except Exception:


@ -0,0 +1,302 @@
# What is this?
## Unit tests for the CustomLLM class
import asyncio
import os
import sys
import time
import traceback
import openai
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Coroutine,
Iterator,
Optional,
Union,
)
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
from dotenv import load_dotenv
import litellm
from litellm import (
ChatCompletionDeltaChunk,
ChatCompletionUsageBlock,
CustomLLM,
GenericStreamingChunk,
ModelResponse,
acompletion,
completion,
get_llm_provider,
)
from litellm.utils import ModelResponseIterator
class CustomModelResponseIterator:
def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
self.streaming_response = streaming_response
def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
return GenericStreamingChunk(
text="hello world",
tool_use=None,
is_finished=True,
finish_reason="stop",
usage=ChatCompletionUsageBlock(
prompt_tokens=10, completion_tokens=20, total_tokens=30
),
index=0,
)
# Sync iterator
def __iter__(self):
return self
def __next__(self) -> GenericStreamingChunk:
try:
chunk: Any = self.streaming_response.__next__() # type: ignore
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.chunk_parser(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__() # type: ignore
return self.streaming_response
async def __anext__(self) -> GenericStreamingChunk:
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.chunk_parser(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
class MyCustomLLM(CustomLLM):
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.HTTPHandler] = None,
) -> ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
async def acompletion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.AsyncHTTPHandler] = None,
) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
def streaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.HTTPHandler] = None,
) -> Iterator[GenericStreamingChunk]:
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": "Hello world",
"tool_use": None,
"usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
}
completion_stream = ModelResponseIterator(
model_response=generic_streaming_chunk # type: ignore
)
custom_iterator = CustomModelResponseIterator(
streaming_response=completion_stream
)
return custom_iterator
async def astreaming( # type: ignore
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.AsyncHTTPHandler] = None,
) -> AsyncIterator[GenericStreamingChunk]: # type: ignore
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": "Hello world",
"tool_use": None,
"usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
}
yield generic_streaming_chunk # type: ignore
def test_get_llm_provider():
""""""
from litellm.utils import custom_llm_setup
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
custom_llm_setup()
model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")
assert provider == "custom_llm"
def test_simple_completion():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = completion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
)
assert resp.choices[0].message.content == "Hi!"
@pytest.mark.asyncio
async def test_simple_acompletion():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = await acompletion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
)
assert resp.choices[0].message.content == "Hi!"
def test_simple_completion_streaming():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = completion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
stream=True,
)
for chunk in resp:
print(chunk)
if chunk.choices[0].finish_reason is None:
assert isinstance(chunk.choices[0].delta.content, str)
else:
assert chunk.choices[0].finish_reason == "stop"
@pytest.mark.asyncio
async def test_simple_completion_async_streaming():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = await litellm.acompletion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
stream=True,
)
async for chunk in resp:
print(chunk)
if chunk.choices[0].finish_reason is None:
assert isinstance(chunk.choices[0].delta.content, str)
else:
assert chunk.choices[0].finish_reason == "stop"


@ -731,3 +731,67 @@ def test_load_router_config(mock_cache, fake_env_vars):
# test_load_router_config()
@pytest.mark.asyncio
async def test_team_update_redis():
"""
Tests if team update, updates the redis cache if set
"""
from litellm.caching import DualCache, RedisCache
from litellm.proxy._types import LiteLLM_TeamTable
from litellm.proxy.auth.auth_checks import _cache_team_object
proxy_logging_obj: ProxyLogging = getattr(
litellm.proxy.proxy_server, "proxy_logging_obj"
)
proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()
with patch.object(
proxy_logging_obj.internal_usage_cache.redis_cache,
"async_set_cache",
new=MagicMock(),
) as mock_client:
await _cache_team_object(
team_id="1234",
team_table=LiteLLM_TeamTable(),
user_api_key_cache=DualCache(),
proxy_logging_obj=proxy_logging_obj,
)
mock_client.assert_called_once()
@pytest.mark.asyncio
async def test_get_team_redis(client_no_auth):
"""
Tests if get_team_object gets value from redis cache, if set
"""
from litellm.caching import DualCache, RedisCache
from litellm.proxy._types import LiteLLM_TeamTable
from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object
proxy_logging_obj: ProxyLogging = getattr(
litellm.proxy.proxy_server, "proxy_logging_obj"
)
proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()
with patch.object(
proxy_logging_obj.internal_usage_cache.redis_cache,
"async_get_cache",
new=AsyncMock(),
) as mock_client:
try:
await get_team_object(
team_id="1234",
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=proxy_logging_obj,
prisma_client=MagicMock(),
)
except Exception as e:
pass
mock_client.assert_called_once()


@ -1117,6 +1117,8 @@ async def test_aimg_gen_on_router():
assert len(response.data) > 0
router.reset()
except litellm.InternalServerError as e:
pass
except Exception as e:
if "Your task failed as a result of our safety system." in str(e):
pass


@ -3248,6 +3248,56 @@ def test_unit_test_custom_stream_wrapper():
assert freq == 1
def test_unit_test_custom_stream_wrapper_openai():
"""
Test if last streaming chunk ends with '?', if the message repeats itself.
"""
litellm.set_verbose = False
chunk = {
"id": "chatcmpl-9mWtyDnikZZoB75DyfUzWUxiiE2Pi",
"choices": [
litellm.utils.StreamingChoices(
delta=litellm.utils.Delta(
content=None, function_call=None, role=None, tool_calls=None
),
finish_reason="content_filter",
index=0,
logprobs=None,
)
],
"created": 1721353246,
"model": "gpt-3.5-turbo-0613",
"object": "chat.completion.chunk",
"system_fingerprint": None,
"usage": None,
}
chunk = litellm.ModelResponse(**chunk, stream=True)
completion_stream = ModelResponseIterator(model_response=chunk)
response = litellm.CustomStreamWrapper(
completion_stream=completion_stream,
model="gpt-3.5-turbo",
custom_llm_provider="azure",
logging_obj=litellm.Logging(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey"}],
stream=True,
call_type="completion",
start_time=time.time(),
litellm_call_id="12345",
function_id="1245",
),
)
stream_finish_reason: Optional[str] = None
for chunk in response:
assert chunk.choices[0].delta.content is None
if chunk.choices[0].finish_reason is not None:
stream_finish_reason = chunk.choices[0].finish_reason
assert stream_finish_reason == "content_filter"
def test_aamazing_unit_test_custom_stream_wrapper_n():
"""
Test if the translated output maps exactly to the received openai input


@ -0,0 +1,10 @@
from typing import List
from typing_extensions import Dict, Required, TypedDict, override
from litellm.llms.custom_llm import CustomLLM
class CustomLLMItem(TypedDict):
provider: str
custom_handler: CustomLLM


@ -330,6 +330,18 @@ class Rules:
####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def custom_llm_setup():
"""
Add custom_llm provider to provider list
"""
for custom_llm in litellm.custom_provider_map:
if custom_llm["provider"] not in litellm.provider_list:
litellm.provider_list.append(custom_llm["provider"])
if custom_llm["provider"] not in litellm._custom_providers:
litellm._custom_providers.append(custom_llm["provider"])
def function_setup(
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@ -341,6 +353,10 @@ def function_setup(
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
## CUSTOM LLM SETUP ##
custom_llm_setup()
## LOGGING SETUP
function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
@ -3121,7 +3137,19 @@ def get_optional_params(
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
if "ai21" in model:
if model in litellm.BEDROCK_CONVERSE_MODELS:
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif "ai21" in model:
_check_valid_arg(supported_params=supported_params)
# params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
# https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
@ -3143,17 +3171,6 @@ def get_optional_params(
optional_params=optional_params,
)
)
elif model in litellm.BEDROCK_CONVERSE_MODELS:
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
@ -8825,21 +8842,6 @@ class CustomStreamWrapper:
if str_line.choices[0].finish_reason:
is_finished = True
finish_reason = str_line.choices[0].finish_reason
if finish_reason == "content_filter":
if hasattr(str_line.choices[0], "content_filter_result"):
error_message = json.dumps(
str_line.choices[0].content_filter_result
)
else:
error_message = "{} Response={}".format(
self.custom_llm_provider, str(dict(str_line))
)
raise litellm.ContentPolicyViolationError(
message=error_message,
llm_provider=self.custom_llm_provider,
model=self.model,
)
# checking for logprobs
if (
@ -9248,7 +9250,10 @@ class CustomStreamWrapper:
try:
# return this for all models
completion_obj = {"content": ""}
if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
if self.custom_llm_provider and (
self.custom_llm_provider == "anthropic"
or self.custom_llm_provider in litellm._custom_providers
):
from litellm.types.utils import GenericStreamingChunk as GChunk
if self.received_finish_reason is not None:
@ -10115,6 +10120,7 @@ class CustomStreamWrapper:
try:
if self.completion_stream is None:
await self.fetch_stream()
if (
self.custom_llm_provider == "openai"
or self.custom_llm_provider == "azure"
@ -10139,6 +10145,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "triton"
or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
or self.custom_llm_provider in litellm._custom_providers
):
async for chunk in self.completion_stream:
print_verbose(f"value of async chunk: {chunk}")
@ -10967,3 +10974,8 @@ class ModelResponseIterator:
raise StopAsyncIteration
self.is_done = True
return self.model_response
class CustomModelResponseIterator(Iterable):
def __init__(self) -> None:
super().__init__()


@ -2996,6 +2996,15 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat" "mode": "chat"
}, },
"mistral.mistral-large-2407-v1:0": {
"max_tokens": 8191,
"max_input_tokens": 128000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000009,
"litellm_provider": "bedrock",
"mode": "chat"
},
"bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": { "bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": {
"max_tokens": 8191, "max_tokens": 8191,
"max_input_tokens": 32000, "max_input_tokens": 32000,
@ -3731,6 +3740,15 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat" "mode": "chat"
}, },
"meta.llama3-1-405b-instruct-v1:0": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000532,
"output_cost_per_token": 0.000016,
"litellm_provider": "bedrock",
"mode": "chat"
},
"512-x-512/50-steps/stability.stable-diffusion-xl-v0": { "512-x-512/50-steps/stability.stable-diffusion-xl-v0": {
"max_tokens": 77, "max_tokens": 77,
"max_input_tokens": 77, "max_input_tokens": 77,