forked from phoenix/litellm-mirror
Merge branch 'main' into bedrock-llama3.1-405b
commit b6ca4406b6
22 changed files with 888 additions and 159 deletions
docs/my-website/docs/providers/custom_llm_server.md (new file, 168 lines)
@@ -0,0 +1,168 @@
# Custom API Server (Custom Format)

LiteLLM allows you to call your custom endpoint in the OpenAI ChatCompletion format.

:::info

For calling an OpenAI-compatible endpoint, [go here](./openai_compatible.md)
:::

## Quick Start

```python
import litellm
from litellm import CustomLLM, completion, get_llm_provider


class MyCustomLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        ) # type: ignore


my_custom_llm = MyCustomLLM()

litellm.custom_provider_map = [ # 👈 KEY STEP - REGISTER HANDLER
    {"provider": "my-custom-llm", "custom_handler": my_custom_llm}
]

resp = completion(
    model="my-custom-llm/my-fake-model",
    messages=[{"role": "user", "content": "Hello world!"}],
)

assert resp.choices[0].message.content == "Hi!"
```

## OpenAI Proxy Usage

1. Setup your `custom_handler.py` file

```python
import litellm
from litellm import CustomLLM, completion, get_llm_provider


class MyCustomLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        ) # type: ignore

    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        ) # type: ignore


my_custom_llm = MyCustomLLM()
```

2. Add to `config.yaml`

In the config below, we pass:

* python_filename: `custom_handler.py`
* custom_handler_instance_name: `my_custom_llm` (this is defined in Step 1)
* custom_handler: `custom_handler.my_custom_llm`

```yaml
model_list:
  - model_name: "test-model"
    litellm_params:
      model: "openai/text-embedding-ada-002"
  - model_name: "my-custom-model"
    litellm_params:
      model: "my-custom-llm/my-model"

litellm_settings:
  custom_provider_map:
    - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
```

```bash
litellm --config /path/to/config.yaml
```

3. Test it!

```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "my-custom-model",
    "messages": [{"role": "user", "content": "Say \"this is a test\" in JSON!"}]
}'
```

Expected Response

```
{
    "id": "chatcmpl-06f1b9cd-08bc-43f7-9814-a69173921216",
    "choices": [
        {
            "finish_reason": "stop",
            "index": 0,
            "message": {
                "content": "Hi!",
                "role": "assistant",
                "tool_calls": null,
                "function_call": null
            }
        }
    ],
    "created": 1721955063,
    "model": "gpt-3.5-turbo",
    "object": "chat.completion",
    "system_fingerprint": null,
    "usage": {
        "prompt_tokens": 10,
        "completion_tokens": 20,
        "total_tokens": 30
    }
}
```

## Custom Handler Spec

```python
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from typing import Iterator, AsyncIterator
from litellm.llms.base import BaseLLM

class CustomLLMError(Exception):  # use this for all your exceptions
    def __init__(
        self,
        status_code,
        message,
    ):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs

class CustomLLM(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def completion(self, *args, **kwargs) -> ModelResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")
```
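The spec above only stubs each method. The unit tests added later in this commit (`litellm/tests/test_custom_llm.py`) show one way to fill in the streaming path; below is a minimal, hedged sketch along those lines. The `MyStreamingLLM` class and `my-streaming-llm` provider name are made up for illustration, and the chunk fields follow the `GenericStreamingChunk` shape used in those tests.

```python
from typing import Iterator

import litellm
from litellm import CustomLLM, completion
from litellm.types.utils import GenericStreamingChunk


class MyStreamingLLM(CustomLLM):
    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        # A one-chunk stream; a real handler would yield one chunk per token/segment.
        yield {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }


litellm.custom_provider_map = [
    {"provider": "my-streaming-llm", "custom_handler": MyStreamingLLM()}
]

resp = completion(
    model="my-streaming-llm/anything",  # hypothetical model name
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in resp:  # chunks come back wrapped by LiteLLM's CustomStreamWrapper
    print(chunk.choices[0].delta.content)
```

With `stream=True`, the handler's generator is what `CustomStreamWrapper` iterates over, which is exactly the path the new `test_simple_completion_streaming` test exercises.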
@@ -1,129 +0,0 @@
# Custom API Server (OpenAI Format)

LiteLLM allows you to call your custom endpoint in the OpenAI ChatCompletion format

## API KEYS
No api keys required

## Set up your Custom API Server
Your server should have the following Endpoints:

Here's an example OpenAI proxy server with routes: https://replit.com/@BerriAI/openai-proxy#main.py

### Required Endpoints
- POST `/chat/completions` - chat completions endpoint

### Optional Endpoints
- POST `/completions` - completions endpoint
- GET `/models` - available models on server
- POST `/embeddings` - creates an embedding vector representing the input text.


## Example Usage

### Call `/chat/completions`
In order to use your custom OpenAI Chat Completion proxy with LiteLLM, ensure you set

* `api_base` to your proxy url, example "https://openai-proxy.berriai.repl.co"
* `custom_llm_provider` to `openai` - this ensures litellm routes requests through `openai.ChatCompletion` to your api_base

```python
import os
from litellm import completion

## set ENV variables
os.environ["OPENAI_API_KEY"] = "anything"  # key is not used for proxy

messages = [{"content": "Hello, how are you?", "role": "user"}]

response = completion(
    model="command-nightly",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    api_base="https://openai-proxy.berriai.repl.co",
    custom_llm_provider="openai"  # litellm will use the openai.ChatCompletion to make the request
)
print(response)
```

#### Response
```json
{
    "object": "chat.completion",
    "choices": [{
        "finish_reason": "stop",
        "index": 0,
        "message": {
            "content": "The sky, a canvas of blue,\nA work of art, pure and true,\nA",
            "role": "assistant"
        }
    }],
    "id": "chatcmpl-7fbd6077-de10-4cb4-a8a4-3ef11a98b7c8",
    "created": 1699290237.408061,
    "model": "togethercomputer/llama-2-70b-chat",
    "usage": {
        "completion_tokens": 18,
        "prompt_tokens": 14,
        "total_tokens": 32
    }
}
```


### Call `/completions`
In order to use your custom OpenAI Completion proxy with LiteLLM, ensure you set

* `api_base` to your proxy url, example "https://openai-proxy.berriai.repl.co"
* `custom_llm_provider` to `text-completion-openai` - this ensures litellm routes requests through `openai.Completion` to your api_base

```python
import os
from litellm import completion

## set ENV variables
os.environ["OPENAI_API_KEY"] = "anything"  # key is not used for proxy

messages = [{"content": "Hello, how are you?", "role": "user"}]

response = completion(
    model="command-nightly",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    api_base="https://openai-proxy.berriai.repl.co",
    custom_llm_provider="text-completion-openai"  # litellm will use the openai.Completion to make the request
)
print(response)
```

#### Response
```json
{
    "warning": "This model version is deprecated. Migrate before January 4, 2024 to avoid disruption of service. Learn more https://platform.openai.com/docs/deprecations",
    "id": "cmpl-8HxHqF5dymQdALmLplS0dWKZVFe3r",
    "object": "text_completion",
    "created": 1699290166,
    "model": "text-davinci-003",
    "choices": [{
        "text": "\n\nThe weather in San Francisco varies depending on what time of year and time",
        "index": 0,
        "logprobs": null,
        "finish_reason": "length"
    }],
    "usage": {
        "prompt_tokens": 7,
        "completion_tokens": 16,
        "total_tokens": 23
    }
}
```
@@ -41,28 +41,6 @@ litellm --health
}
```

### Background Health Checks

You can enable model health checks being run in the background, to prevent each model from being queried too frequently via `/health`.

Here's how to use it:
1. In the config.yaml add:
```
general_settings:
  background_health_checks: True # enable background health checks
  health_check_interval: 300 # frequency of background health checks
```

2. Start server
```
$ litellm /path/to/config.yaml
```

3. Query health endpoint:
```
curl --location 'http://0.0.0.0:4000/health'
```

### Embedding Models

Health checks need a way to know that a model is an embedding model; setting `mode` in that model's config makes `/health` run an embedding health check for it.
@@ -124,6 +102,41 @@ model_list:
      mode: audio_transcription
```


### Text to Speech Models

```yaml
# OpenAI Text to Speech Models
  - model_name: tts
    litellm_params:
      model: openai/tts-1
      api_key: "os.environ/OPENAI_API_KEY"
    model_info:
      mode: audio_speech
```

## Background Health Checks

You can enable model health checks being run in the background, to prevent each model from being queried too frequently via `/health`.

Here's how to use it:
1. In the config.yaml add:
```
general_settings:
  background_health_checks: True # enable background health checks
  health_check_interval: 300 # frequency of background health checks
```

2. Start server
```
$ litellm /path/to/config.yaml
```

3. Query health endpoint:
```
curl --location 'http://0.0.0.0:4000/health'
```

### Hide details

The health check response contains details like endpoint URLs, error messages,
@@ -31,8 +31,19 @@ model_list:
      api_base: https://openai-france-1234.openai.azure.com/
      api_key: <your-azure-api-key>
      rpm: 1440
  routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
  model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
  num_retries: 2
  timeout: 30 # 30 seconds
  redis_host: <your redis host> # set this when using multiple litellm proxy deployments, load balancing state stored in redis
  redis_password: <your redis password>
  redis_port: 1992
```

:::info
Detailed information about [routing strategies can be found here](../routing)
:::

#### Step 2: Start Proxy with config

```shell
@@ -175,7 +175,8 @@ const sidebars = {
            "providers/aleph_alpha",
            "providers/baseten",
            "providers/openrouter",
            "providers/custom_openai_proxy",
            // "providers/custom_openai_proxy",
            "providers/custom_llm_server",
            "providers/petals",

          ],
@@ -813,6 +813,7 @@ from .utils import (
)

from .types.utils import ImageObject
from .llms.custom_llm import CustomLLM
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig

@@ -909,3 +910,12 @@ from .cost_calculator import response_cost_calculator, cost_per_token
from .types.adapter import AdapterItem

adapters: List[AdapterItem] = []

### CUSTOM LLMs ###
from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk

custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
    []
)  # internal helper util, used to track names of custom providers
@@ -1864,6 +1864,23 @@ class AzureChatCompletion(BaseLLM):
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_transcription":
            # Get the current directory of the file being run
            pwd = os.path.dirname(os.path.realpath(__file__))
            file_path = os.path.join(pwd, "../tests/gettysburg.wav")
            audio_file = open(file_path, "rb")
            completion = await client.audio.transcriptions.with_raw_response.create(
                file=audio_file,
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_speech":
            # Get the current directory of the file being run
            completion = await client.audio.speech.with_raw_response.create(
                model=model,  # type: ignore
                input=prompt,  # type: ignore
                voice="alloy",
            )
        else:
            raise Exception("mode not set")
        response = {}
@@ -79,6 +79,7 @@ BEDROCK_CONVERSE_MODELS = [
    "meta.llama3-1-8b-instruct-v1:0",
    "meta.llama3-1-70b-instruct-v1:0",
    "meta.llama3-1-405b-instruct-v1:0",
    "mistral.mistral-large-2407-v1:0",
]

litellm/llms/custom_llm.py (new file, 161 lines)
@@ -0,0 +1,161 @@
# What is this?
## Handler file for a Custom Chat LLM

"""
- completion
- acompletion
- streaming
- async_streaming
"""

import copy
import json
import os
import time
import types
from enum import Enum
from functools import partial
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Coroutine,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import httpx  # type: ignore
import requests  # type: ignore

import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import GenericStreamingChunk, ProviderField
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage

from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory


class CustomLLMError(Exception):  # use this for all your exceptions
    def __init__(
        self,
        status_code,
        message,
    ):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class CustomLLM(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[HTTPHandler] = None,
    ) -> ModelResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    def streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[HTTPHandler] = None,
    ) -> Iterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def acompletion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ModelResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def astreaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> AsyncIterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")


def custom_chat_llm_router(
    async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM
):
    """
    Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type

    Validates if response is in expected format
    """
    if async_fn:
        if stream:
            return custom_llm.astreaming
        return custom_llm.acompletion
    if stream:
        return custom_llm.streaming
    return custom_llm.completion
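To illustrate the dispatch above: the router simply hands back the bound handler method matching the call type, which `litellm.main.completion()` then invokes (see the `completion()` hunk further down). A hedged sketch, with a hypothetical handler class:

```python
from litellm.llms.custom_llm import CustomLLM, custom_chat_llm_router


class MyHandler(CustomLLM):  # hypothetical handler; inherits the NotImplemented stubs
    pass


handler = MyHandler()

# sync, non-streaming -> handler.completion
fn = custom_chat_llm_router(async_fn=False, stream=False, custom_llm=handler)
assert fn == handler.completion

# async + streaming -> handler.astreaming
fn = custom_chat_llm_router(async_fn=True, stream=True, custom_llm=handler)
assert fn == handler.astreaming
```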
@@ -1,5 +1,6 @@
import hashlib
import json
import os
import time
import traceback
import types
@@ -1870,8 +1871,25 @@ class OpenAIChatCompletion(BaseLLM):
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_transcription":
            # Get the current directory of the file being run
            pwd = os.path.dirname(os.path.realpath(__file__))
            file_path = os.path.join(pwd, "../tests/gettysburg.wav")
            audio_file = open(file_path, "rb")
            completion = await client.audio.transcriptions.with_raw_response.create(
                file=audio_file,
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_speech":
            # Get the current directory of the file being run
            completion = await client.audio.speech.with_raw_response.create(
                model=model,  # type: ignore
                input=prompt,  # type: ignore
                voice="alloy",
            )
        else:
            raise Exception("mode not set")
            raise ValueError("mode not set, passed in mode: " + mode)
        response = {}

        if completion is None or not hasattr(completion, "headers"):
@@ -387,7 +387,7 @@ def process_response(
        result = " "

    ## Building RESPONSE OBJECT
    if len(result) > 1:
    if len(result) >= 1:
        model_response.choices[0].message.content = result  # type: ignore

    # Calculate usage
@@ -107,6 +107,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion

@@ -381,6 +382,7 @@ async def acompletion(
        or custom_llm_provider == "clarifai"
        or custom_llm_provider == "watsonx"
        or custom_llm_provider in litellm.openai_compatible_providers
        or custom_llm_provider in litellm._custom_providers
    ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
        init_response = await loop.run_in_executor(None, func_with_context)
        if isinstance(init_response, dict) or isinstance(
@@ -2690,6 +2692,54 @@ def completion(
            model_response.created = int(time.time())
            model_response.model = model
            response = model_response
        elif (
            custom_llm_provider in litellm._custom_providers
        ):  # Assume custom LLM provider
            # Get the Custom Handler
            custom_handler: Optional[CustomLLM] = None
            for item in litellm.custom_provider_map:
                if item["provider"] == custom_llm_provider:
                    custom_handler = item["custom_handler"]

            if custom_handler is None:
                raise ValueError(
                    f"Unable to map your input to a model. Check your input - {args}"
                )

            ## ROUTE LLM CALL ##
            handler_fn = custom_chat_llm_router(
                async_fn=acompletion, stream=stream, custom_llm=custom_handler
            )

            headers = headers or litellm.headers

            ## CALL FUNCTION
            response = handler_fn(
                model=model,
                messages=messages,
                headers=headers,
                model_response=model_response,
                print_verbose=print_verbose,
                api_key=api_key,
                api_base=api_base,
                acompletion=acompletion,
                logging_obj=logging,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                timeout=timeout,  # type: ignore
                custom_prompt_dict=custom_prompt_dict,
                client=client,  # pass AsyncOpenAI, OpenAI client
                encoding=encoding,
            )
            if stream is True:
                return CustomStreamWrapper(
                    completion_stream=response,
                    model=model,
                    custom_llm_provider=custom_llm_provider,
                    logging_obj=logging,
                )

        else:
            raise ValueError(
                f"Unable to map your input to a model. Check your input - {args}"
@@ -2996,6 +2996,15 @@
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "mistral.mistral-large-2407-v1:0": {
        "max_tokens": 8191,
        "max_input_tokens": 128000,
        "max_output_tokens": 8191,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000009,
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": {
        "max_tokens": 8191,
        "max_input_tokens": 32000,
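For orientation, the per-token prices in the new `mistral.mistral-large-2407-v1:0` entry translate to request cost as in the back-of-the-envelope sketch below. The token counts are made up; real billing goes through LiteLLM's cost calculator using these same fields.

```python
# Hypothetical request against "mistral.mistral-large-2407-v1:0", using the prices above.
input_cost_per_token = 0.000003   # USD per prompt token
output_cost_per_token = 0.000009  # USD per completion token

prompt_tokens = 1_000
completion_tokens = 500

total_cost = (
    prompt_tokens * input_cost_per_token
    + completion_tokens * output_cost_per_token
)
print(f"${total_cost:.4f}")  # $0.0075
```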
@@ -1,4 +1,11 @@
model_list:
  - model_name: "test-model"
    litellm_params:
      model: "openai/gpt-3.5-turbo-instruct-0914"
      model: "openai/text-embedding-ada-002"
  - model_name: "my-custom-model"
    litellm_params:
      model: "my-custom-llm/my-model"

litellm_settings:
  custom_provider_map:
    - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
litellm/proxy/custom_handler.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import litellm
from litellm import CustomLLM, completion, get_llm_provider


class MyCustomLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        ) # type: ignore

    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        ) # type: ignore


my_custom_llm = MyCustomLLM()
@@ -8,6 +8,12 @@ model_list:
    litellm_params:
      model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
      api_key: "os.environ/FIREWORKS"
  - model_name: tts
    litellm_params:
      model: openai/tts-1
      api_key: "os.environ/OPENAI_API_KEY"
    model_info:
      mode: audio_speech
general_settings:
  master_key: sk-1234
  alerting: ["slack"]
@@ -1507,6 +1507,21 @@ class ProxyConfig:
                    verbose_proxy_logger.debug(
                        f"litellm.post_call_rules: {litellm.post_call_rules}"
                    )
                elif key == "custom_provider_map":
                    from litellm.utils import custom_llm_setup

                    litellm.custom_provider_map = [
                        {
                            "provider": item["provider"],
                            "custom_handler": get_instance_fn(
                                value=item["custom_handler"],
                                config_file_path=config_file_path,
                            ),
                        }
                        for item in value
                    ]

                    custom_llm_setup()
                elif key == "success_callback":
                    litellm.success_callback = []

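For reference, the registration this config-loader branch performs is roughly equivalent to the hedged sketch below, assuming the `custom_handler.py` module from the docs example sits next to the config file so that `"custom_handler.my_custom_llm"` resolves to that instance.

```python
# Rough standalone equivalent of the "custom_provider_map" config block above.
import litellm
from litellm.utils import custom_llm_setup

from custom_handler import my_custom_llm  # module/instance names from the docs example

litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": my_custom_llm}
]
custom_llm_setup()  # registers the provider name so routing can find it
```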
@@ -3334,6 +3349,7 @@ async def embeddings(
        if (
            "input" in data
            and isinstance(data["input"], list)
            and len(data["input"]) > 0
            and isinstance(data["input"][0], list)
            and isinstance(data["input"][0][0], int)
        ):  # check if array of tokens passed in
@@ -3464,8 +3480,8 @@ async def embeddings(
            litellm_debug_info,
        )
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.embeddings(): Exception occured - {}".format(
                str(e)
            "litellm.proxy.proxy_server.embeddings(): Exception occured - {}\n{}".format(
                str(e), traceback.format_exc()
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
@@ -263,7 +263,9 @@ class Router:
        )  # names of models under litellm_params. ex. azure/chatgpt-v-2
        self.deployment_latency_map = {}
        ### CACHING ###
        cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
        cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
            "local"  # default to an in-memory cache
        )
        redis_cache = None
        cache_config = {}
        self.client_ttl = client_ttl
litellm/tests/test_custom_llm.py (new file, 302 lines)
@@ -0,0 +1,302 @@
# What is this?
## Unit tests for the CustomLLM class


import asyncio
import os
import sys
import time
import traceback

import openai
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Coroutine,
    Iterator,
    Optional,
    Union,
)
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
from dotenv import load_dotenv

import litellm
from litellm import (
    ChatCompletionDeltaChunk,
    ChatCompletionUsageBlock,
    CustomLLM,
    GenericStreamingChunk,
    ModelResponse,
    acompletion,
    completion,
    get_llm_provider,
)
from litellm.utils import ModelResponseIterator


class CustomModelResponseIterator:
    def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
        self.streaming_response = streaming_response

    def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
        return GenericStreamingChunk(
            text="hello world",
            tool_use=None,
            is_finished=True,
            finish_reason="stop",
            usage=ChatCompletionUsageBlock(
                prompt_tokens=10, completion_tokens=20, total_tokens=30
            ),
            index=0,
        )

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self) -> GenericStreamingChunk:
        try:
            chunk: Any = self.streaming_response.__next__()  # type: ignore
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.chunk_parser(chunk=chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
        return self.streaming_response

    async def __anext__(self) -> GenericStreamingChunk:
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.chunk_parser(chunk=chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


class MyCustomLLM(CustomLLM):
    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.HTTPHandler] = None,
    ) -> ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    async def acompletion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.AsyncHTTPHandler] = None,
    ) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    def streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.HTTPHandler] = None,
    ) -> Iterator[GenericStreamingChunk]:
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }

        completion_stream = ModelResponseIterator(
            model_response=generic_streaming_chunk  # type: ignore
        )
        custom_iterator = CustomModelResponseIterator(
            streaming_response=completion_stream
        )
        return custom_iterator

    async def astreaming(  # type: ignore
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.AsyncHTTPHandler] = None,
    ) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }

        yield generic_streaming_chunk  # type: ignore


def test_get_llm_provider():
    """"""
    from litellm.utils import custom_llm_setup

    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]

    custom_llm_setup()

    model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")

    assert provider == "custom_llm"


def test_simple_completion():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = completion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )

    assert resp.choices[0].message.content == "Hi!"


@pytest.mark.asyncio
async def test_simple_acompletion():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = await acompletion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )

    assert resp.choices[0].message.content == "Hi!"


def test_simple_completion_streaming():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = completion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )

    for chunk in resp:
        print(chunk)
        if chunk.choices[0].finish_reason is None:
            assert isinstance(chunk.choices[0].delta.content, str)
        else:
            assert chunk.choices[0].finish_reason == "stop"


@pytest.mark.asyncio
async def test_simple_completion_async_streaming():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = await litellm.acompletion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )

    async for chunk in resp:
        print(chunk)
        if chunk.choices[0].finish_reason is None:
            assert isinstance(chunk.choices[0].delta.content, str)
        else:
            assert chunk.choices[0].finish_reason == "stop"
litellm/types/llms/custom_llm.py (new file, 10 lines)
@@ -0,0 +1,10 @@
from typing import List

from typing_extensions import Dict, Required, TypedDict, override

from litellm.llms.custom_llm import CustomLLM


class CustomLLMItem(TypedDict):
    provider: str
    custom_handler: CustomLLM
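An entry in `litellm.custom_provider_map` is just one of these typed dicts. A hedged sketch of constructing one: the `EchoLLM` class and `echo-llm` provider name are made up, and a real handler would still need to implement `completion`/`streaming` to serve traffic.

```python
import litellm
from litellm import CustomLLM
from litellm.types.llms.custom_llm import CustomLLMItem


class EchoLLM(CustomLLM):  # hypothetical handler; inherits the NotImplemented stubs
    pass


item: CustomLLMItem = {"provider": "echo-llm", "custom_handler": EchoLLM()}
litellm.custom_provider_map = [item]
```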
@@ -330,6 +330,18 @@ class Rules:

####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def custom_llm_setup():
    """
    Add custom_llm provider to provider list
    """
    for custom_llm in litellm.custom_provider_map:
        if custom_llm["provider"] not in litellm.provider_list:
            litellm.provider_list.append(custom_llm["provider"])

        if custom_llm["provider"] not in litellm._custom_providers:
            litellm._custom_providers.append(custom_llm["provider"])


def function_setup(
    original_function: str, rules_obj, start_time, *args, **kwargs
):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
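A hedged sketch of what this setup enables, mirroring `test_get_llm_provider` from the new test file above (the `my-llm` provider name and `MyLLM` class are made up):

```python
import litellm
from litellm import CustomLLM, get_llm_provider
from litellm.utils import custom_llm_setup


class MyLLM(CustomLLM):  # hypothetical handler
    pass


litellm.custom_provider_map = [{"provider": "my-llm", "custom_handler": MyLLM()}]
custom_llm_setup()  # adds "my-llm" to litellm.provider_list and litellm._custom_providers

model, provider, _, _ = get_llm_provider(model="my-llm/my-fake-model")
assert provider == "my-llm"
```

In the normal flow this call happens automatically inside `function_setup()` (see the next hunk), so end users only have to populate `litellm.custom_provider_map`.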
@@ -341,6 +353,10 @@ def function_setup(
    try:
        global callback_list, add_breadcrumb, user_logger_fn, Logging

        ## CUSTOM LLM SETUP ##
        custom_llm_setup()

        ## LOGGING SETUP
        function_id = kwargs["id"] if "id" in kwargs else None

        if len(litellm.callbacks) > 0:

@@ -9247,7 +9263,10 @@ class CustomStreamWrapper:
        try:
            # return this for all models
            completion_obj = {"content": ""}
            if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
            if self.custom_llm_provider and (
                self.custom_llm_provider == "anthropic"
                or self.custom_llm_provider in litellm._custom_providers
            ):
                from litellm.types.utils import GenericStreamingChunk as GChunk

                if self.received_finish_reason is not None:

@@ -10114,6 +10133,7 @@ class CustomStreamWrapper:
        try:
            if self.completion_stream is None:
                await self.fetch_stream()

            if (
                self.custom_llm_provider == "openai"
                or self.custom_llm_provider == "azure"

@@ -10138,6 +10158,7 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "triton"
                or self.custom_llm_provider == "watsonx"
                or self.custom_llm_provider in litellm.openai_compatible_endpoints
                or self.custom_llm_provider in litellm._custom_providers
            ):
                async for chunk in self.completion_stream:
                    print_verbose(f"value of async chunk: {chunk}")

@@ -10966,3 +10987,8 @@ class ModelResponseIterator:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response


class CustomModelResponseIterator(Iterable):
    def __init__(self) -> None:
        super().__init__()
@@ -2996,6 +2996,15 @@
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "mistral.mistral-large-2407-v1:0": {
        "max_tokens": 8191,
        "max_input_tokens": 128000,
        "max_output_tokens": 8191,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000009,
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": {
        "max_tokens": 8191,
        "max_input_tokens": 32000,