forked from phoenix/litellm-mirror
Merge branch 'main' into bedrock-llama3.1-405b
Commit b6ca4406b6
22 changed files with 888 additions and 159 deletions
168  docs/my-website/docs/providers/custom_llm_server.md  Normal file
@ -0,0 +1,168 @@
# Custom API Server (Custom Format)

LiteLLM allows you to call your custom endpoint in the OpenAI ChatCompletion format

:::info

For calling an openai-compatible endpoint, [go here](./openai_compatible.md)

:::

## Quick Start

```python
import litellm
from litellm import CustomLLM, completion, get_llm_provider


class MyCustomLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore


my_custom_llm = MyCustomLLM()

litellm.custom_provider_map = [  # 👈 KEY STEP - REGISTER HANDLER
    {"provider": "my-custom-llm", "custom_handler": my_custom_llm}
]

resp = completion(
    model="my-custom-llm/my-fake-model",
    messages=[{"role": "user", "content": "Hello world!"}],
)

assert resp.choices[0].message.content == "Hi!"
```
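
Streaming goes through the same `completion()` entry point once the handler also implements `streaming` (see the Custom Handler Spec below). A hedged sketch of the calling side, assuming such a handler has already been registered under `my-custom-llm`:

```python
from litellm import completion

# Assumes a CustomLLM handler that implements `streaming` is already registered
# under the "my-custom-llm" provider name (see the registration step above).
for chunk in completion(
    model="my-custom-llm/my-fake-model",
    messages=[{"role": "user", "content": "Hello world!"}],
    stream=True,
):
    print(chunk)
```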

## OpenAI Proxy Usage

1. Setup your `custom_handler.py` file

```python
import litellm
from litellm import CustomLLM, completion, get_llm_provider


class MyCustomLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore


my_custom_llm = MyCustomLLM()
```

2. Add to `config.yaml`

In the config below, we pass

- python_filename: `custom_handler.py`
- custom_handler_instance_name: `my_custom_llm`. This is defined in Step 1.
- custom_handler: `custom_handler.my_custom_llm`

```yaml
model_list:
  - model_name: "test-model"
    litellm_params:
      model: "openai/text-embedding-ada-002"
  - model_name: "my-custom-model"
    litellm_params:
      model: "my-custom-llm/my-model"

litellm_settings:
  custom_provider_map:
    - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
```

```bash
litellm --config /path/to/config.yaml
```

3. Test it!

```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "my-custom-model",
    "messages": [{"role": "user", "content": "Say \"this is a test\" in JSON!"}]
}'
```

Expected Response

```
{
    "id": "chatcmpl-06f1b9cd-08bc-43f7-9814-a69173921216",
    "choices": [
        {
            "finish_reason": "stop",
            "index": 0,
            "message": {
                "content": "Hi!",
                "role": "assistant",
                "tool_calls": null,
                "function_call": null
            }
        }
    ],
    "created": 1721955063,
    "model": "gpt-3.5-turbo",
    "object": "chat.completion",
    "system_fingerprint": null,
    "usage": {
        "prompt_tokens": 10,
        "completion_tokens": 20,
        "total_tokens": 30
    }
}
```

## Custom Handler Spec

```python
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from typing import Iterator, AsyncIterator
from litellm.llms.base import BaseLLM


class CustomLLMError(Exception):  # use this for all your exceptions
    def __init__(
        self,
        status_code,
        message,
    ):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class CustomLLM(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def completion(self, *args, **kwargs) -> ModelResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")
```
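
As a rough illustration of how the spec is meant to be filled in, here is a minimal streaming-handler sketch; the chunk values are illustrative, and the field names mirror the `GenericStreamingChunk` shape used by the unit tests added in this commit:

```python
from typing import Iterator

from litellm import CustomLLM
from litellm.types.utils import GenericStreamingChunk


class MyStreamingLLM(CustomLLM):
    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        # A real handler would yield one chunk per token/segment from your backend.
        yield GenericStreamingChunk(
            text="Hello world",
            tool_use=None,
            is_finished=True,
            finish_reason="stop",
            usage={"completion_tokens": 20, "prompt_tokens": 10, "total_tokens": 30},
            index=0,
        )
```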
@ -1,129 +0,0 @@
# Custom API Server (OpenAI Format)

LiteLLM allows you to call your custom endpoint in the OpenAI ChatCompletion format

## API KEYS
No api keys required

## Set up your Custom API Server
Your server should have the following Endpoints:

Here's an example OpenAI proxy server with routes: https://replit.com/@BerriAI/openai-proxy#main.py

### Required Endpoints
- POST `/chat/completions` - chat completions endpoint

### Optional Endpoints
- POST `/completions` - completions endpoint
- GET `/models` - available models on server
- POST `/embeddings` - creates an embedding vector representing the input text.

## Example Usage

### Call `/chat/completions`
In order to use your custom OpenAI Chat Completion proxy with LiteLLM, ensure you set

* `api_base` to your proxy url, example "https://openai-proxy.berriai.repl.co"
* `custom_llm_provider` to `openai` - this ensures litellm uses the `openai.ChatCompletion` to your api_base

```python
import os
from litellm import completion

## set ENV variables
os.environ["OPENAI_API_KEY"] = "anything"  # key is not used for proxy

messages = [{ "content": "Hello, how are you?","role": "user"}]

response = completion(
    model="command-nightly",
    messages=[{ "content": "Hello, how are you?","role": "user"}],
    api_base="https://openai-proxy.berriai.repl.co",
    custom_llm_provider="openai"  # litellm will use the openai.ChatCompletion to make the request
)
print(response)
```

#### Response
```json
{
    "object": "chat.completion",
    "choices": [{
        "finish_reason": "stop",
        "index": 0,
        "message": {
            "content": "The sky, a canvas of blue,\nA work of art, pure and true,\nA",
            "role": "assistant"
        }
    }],
    "id": "chatcmpl-7fbd6077-de10-4cb4-a8a4-3ef11a98b7c8",
    "created": 1699290237.408061,
    "model": "togethercomputer/llama-2-70b-chat",
    "usage": {
        "completion_tokens": 18,
        "prompt_tokens": 14,
        "total_tokens": 32
    }
}
```

### Call `/completions`
In order to use your custom OpenAI Completion proxy with LiteLLM, ensure you set

* `api_base` to your proxy url, example "https://openai-proxy.berriai.repl.co"
* `custom_llm_provider` to `text-completion-openai` - this ensures litellm uses the `openai.Completion` to your api_base

```python
import os
from litellm import completion

## set ENV variables
os.environ["OPENAI_API_KEY"] = "anything"  # key is not used for proxy

messages = [{ "content": "Hello, how are you?","role": "user"}]

response = completion(
    model="command-nightly",
    messages=[{ "content": "Hello, how are you?","role": "user"}],
    api_base="https://openai-proxy.berriai.repl.co",
    custom_llm_provider="text-completion-openai"  # litellm will use the openai.Completion to make the request
)
print(response)
```

#### Response
```json
{
    "warning": "This model version is deprecated. Migrate before January 4, 2024 to avoid disruption of service. Learn more https://platform.openai.com/docs/deprecations",
    "id": "cmpl-8HxHqF5dymQdALmLplS0dWKZVFe3r",
    "object": "text_completion",
    "created": 1699290166,
    "model": "text-davinci-003",
    "choices": [{
        "text": "\n\nThe weather in San Francisco varies depending on what time of year and time",
        "index": 0,
        "logprobs": None,
        "finish_reason": "length"
    }],
    "usage": {
        "prompt_tokens": 7,
        "completion_tokens": 16,
        "total_tokens": 23
    }
}
```
@ -41,28 +41,6 @@ litellm --health
}
```

### Background Health Checks

You can enable model health checks being run in the background, to prevent each model from being queried too frequently via `/health`.

Here's how to use it:
1. In the config.yaml add:
```
general_settings:
  background_health_checks: True # enable background health checks
  health_check_interval: 300 # frequency of background health checks
```

2. Start server
```
$ litellm /path/to/config.yaml
```

3. Query health endpoint:
```
curl --location 'http://0.0.0.0:4000/health'
```

### Embedding Models

We need some way to know whether a model is an embedding model when running checks; if you set the `mode` for the model in your config, the health check for it is run as an embedding health check.
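
For illustration only, such a config entry might look like the sketch below (the model name is made up, and `mode: embedding` is assumed to be the value used for embedding health checks):

```yaml
model_list:
  - model_name: my-embedding-model            # illustrative name
    litellm_params:
      model: openai/text-embedding-ada-002
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      mode: embedding   # assumed value - marks this deployment for an embedding health check
```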
@ -124,6 +102,41 @@ model_list:
      mode: audio_transcription
```

### Text to Speech Models

```yaml
# OpenAI Text to Speech Models
  - model_name: tts
    litellm_params:
      model: openai/tts-1
      api_key: "os.environ/OPENAI_API_KEY"
    model_info:
      mode: audio_speech
```

## Background Health Checks

You can enable model health checks being run in the background, to prevent each model from being queried too frequently via `/health`.

Here's how to use it:
1. In the config.yaml add:
```
general_settings:
  background_health_checks: True # enable background health checks
  health_check_interval: 300 # frequency of background health checks
```

2. Start server
```
$ litellm /path/to/config.yaml
```

3. Query health endpoint:
```
curl --location 'http://0.0.0.0:4000/health'
```

### Hide details

The health check response contains details like endpoint URLs, error messages,
@ -31,8 +31,19 @@ model_list:
      api_base: https://openai-france-1234.openai.azure.com/
      api_key: <your-azure-api-key>
      rpm: 1440
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
num_retries: 2
timeout: 30 # 30 seconds
redis_host: <your redis host> # set this when using multiple litellm proxy deployments, load balancing state stored in redis
redis_password: <your redis password>
redis_port: 1992
```

:::info
Detailed information about [routing strategies can be found here](../routing)
:::

#### Step 2: Start Proxy with config

```shell
@ -175,7 +175,8 @@ const sidebars = {
"providers/aleph_alpha",
"providers/baseten",
"providers/openrouter",
"providers/custom_openai_proxy",
// "providers/custom_openai_proxy",
"providers/custom_llm_server",
"providers/petals",

],
@ -813,6 +813,7 @@ from .utils import (
)

from .types.utils import ImageObject
from .llms.custom_llm import CustomLLM
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig
@ -909,3 +910,12 @@ from .cost_calculator import response_cost_calculator, cost_per_token
from .types.adapter import AdapterItem

adapters: List[AdapterItem] = []

### CUSTOM LLMs ###
from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk

custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
    []
)  # internal helper util, used to track names of custom providers
@ -1864,6 +1864,23 @@ class AzureChatCompletion(BaseLLM):
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_transcription":
            # Get the current directory of the file being run
            pwd = os.path.dirname(os.path.realpath(__file__))
            file_path = os.path.join(pwd, "../tests/gettysburg.wav")
            audio_file = open(file_path, "rb")
            completion = await client.audio.transcriptions.with_raw_response.create(
                file=audio_file,
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_speech":
            # Get the current directory of the file being run
            completion = await client.audio.speech.with_raw_response.create(
                model=model,  # type: ignore
                input=prompt,  # type: ignore
                voice="alloy",
            )
        else:
            raise Exception("mode not set")
        response = {}
@ -79,6 +79,7 @@ BEDROCK_CONVERSE_MODELS = [
    "meta.llama3-1-8b-instruct-v1:0",
    "meta.llama3-1-70b-instruct-v1:0",
    "meta.llama3-1-405b-instruct-v1:0",
    "mistral.mistral-large-2407-v1:0",
]
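
Since the commit adds `mistral.mistral-large-2407-v1:0` to `BEDROCK_CONVERSE_MODELS`, it should be callable like any other Bedrock Converse model; a hedged usage sketch, assuming AWS credentials are configured via the usual environment variables:

```python
import litellm

# Assumes AWS credentials/region are set in the environment (e.g. AWS_ACCESS_KEY_ID,
# AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME), as for any other Bedrock model.
response = litellm.completion(
    model="bedrock/mistral.mistral-large-2407-v1:0",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response)
```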
161  litellm/llms/custom_llm.py  Normal file
@ -0,0 +1,161 @@
# What is this?
## Handler file for a Custom Chat LLM

"""
- completion
- acompletion
- streaming
- async_streaming
"""

import copy
import json
import os
import time
import types
from enum import Enum
from functools import partial
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Coroutine,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import httpx  # type: ignore
import requests  # type: ignore

import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import GenericStreamingChunk, ProviderField
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage

from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory


class CustomLLMError(Exception):  # use this for all your exceptions
    def __init__(
        self,
        status_code,
        message,
    ):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class CustomLLM(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[HTTPHandler] = None,
    ) -> ModelResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    def streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[HTTPHandler] = None,
    ) -> Iterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def acompletion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ModelResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def astreaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> AsyncIterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")


def custom_chat_llm_router(
    async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM
):
    """
    Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type

    Validates if response is in expected format
    """
    if async_fn:
        if stream:
            return custom_llm.astreaming
        return custom_llm.acompletion
    if stream:
        return custom_llm.streaming
    return custom_llm.completion
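
For context, a small sketch of how this router is meant to be used by a caller; names are illustrative, and `completion()` in `main.py` does essentially this with the handler it finds in `litellm.custom_provider_map`:

```python
from litellm.llms.custom_llm import CustomLLM, custom_chat_llm_router

my_handler = CustomLLM()  # in practice, a subclass that overrides completion/streaming/etc.

# Pick the bound method that matches the call type (sync + streaming here):
handler_fn = custom_chat_llm_router(async_fn=False, stream=True, custom_llm=my_handler)
assert handler_fn == my_handler.streaming
```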
@ -1,5 +1,6 @@
import hashlib
import json
import os
import time
import traceback
import types
@ -1870,8 +1871,25 @@ class OpenAIChatCompletion(BaseLLM):
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_transcription":
            # Get the current directory of the file being run
            pwd = os.path.dirname(os.path.realpath(__file__))
            file_path = os.path.join(pwd, "../tests/gettysburg.wav")
            audio_file = open(file_path, "rb")
            completion = await client.audio.transcriptions.with_raw_response.create(
                file=audio_file,
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_speech":
            # Get the current directory of the file being run
            completion = await client.audio.speech.with_raw_response.create(
                model=model,  # type: ignore
                input=prompt,  # type: ignore
                voice="alloy",
            )
        else:
            raise Exception("mode not set")
            raise ValueError("mode not set, passed in mode: " + mode)
        response = {}

        if completion is None or not hasattr(completion, "headers"):
@ -387,7 +387,7 @@ def process_response(
        result = " "

    ## Building RESPONSE OBJECT
    if len(result) > 1:
    if len(result) >= 1:
        model_response.choices[0].message.content = result  # type: ignore

    # Calculate usage
@ -107,6 +107,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
@ -381,6 +382,7 @@ async def acompletion(
            or custom_llm_provider == "clarifai"
            or custom_llm_provider == "watsonx"
            or custom_llm_provider in litellm.openai_compatible_providers
            or custom_llm_provider in litellm._custom_providers
        ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
            init_response = await loop.run_in_executor(None, func_with_context)
            if isinstance(init_response, dict) or isinstance(
@ -2690,6 +2692,54 @@ def completion(
            model_response.created = int(time.time())
            model_response.model = model
            response = model_response
        elif (
            custom_llm_provider in litellm._custom_providers
        ):  # Assume custom LLM provider
            # Get the Custom Handler
            custom_handler: Optional[CustomLLM] = None
            for item in litellm.custom_provider_map:
                if item["provider"] == custom_llm_provider:
                    custom_handler = item["custom_handler"]

            if custom_handler is None:
                raise ValueError(
                    f"Unable to map your input to a model. Check your input - {args}"
                )

            ## ROUTE LLM CALL ##
            handler_fn = custom_chat_llm_router(
                async_fn=acompletion, stream=stream, custom_llm=custom_handler
            )

            headers = headers or litellm.headers

            ## CALL FUNCTION
            response = handler_fn(
                model=model,
                messages=messages,
                headers=headers,
                model_response=model_response,
                print_verbose=print_verbose,
                api_key=api_key,
                api_base=api_base,
                acompletion=acompletion,
                logging_obj=logging,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                timeout=timeout,  # type: ignore
                custom_prompt_dict=custom_prompt_dict,
                client=client,  # pass AsyncOpenAI, OpenAI client
                encoding=encoding,
            )
            if stream is True:
                return CustomStreamWrapper(
                    completion_stream=response,
                    model=model,
                    custom_llm_provider=custom_llm_provider,
                    logging_obj=logging,
                )

        else:
            raise ValueError(
                f"Unable to map your input to a model. Check your input - {args}"
@ -2996,6 +2996,15 @@
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "mistral.mistral-large-2407-v1:0": {
        "max_tokens": 8191,
        "max_input_tokens": 128000,
        "max_output_tokens": 8191,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000009,
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": {
        "max_tokens": 8191,
        "max_input_tokens": 32000,
@ -1,4 +1,11 @@
model_list:
  - model_name: "test-model"
    litellm_params:
      model: "openai/gpt-3.5-turbo-instruct-0914"
      model: "openai/text-embedding-ada-002"
  - model_name: "my-custom-model"
    litellm_params:
      model: "my-custom-llm/my-model"

litellm_settings:
  custom_provider_map:
    - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
21  litellm/proxy/custom_handler.py  Normal file
@ -0,0 +1,21 @@
import litellm
from litellm import CustomLLM, completion, get_llm_provider


class MyCustomLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore


my_custom_llm = MyCustomLLM()
@ -8,6 +8,12 @@ model_list:
    litellm_params:
      model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
      api_key: "os.environ/FIREWORKS"
  - model_name: tts
    litellm_params:
      model: openai/tts-1
      api_key: "os.environ/OPENAI_API_KEY"
    model_info:
      mode: audio_speech
general_settings:
  master_key: sk-1234
  alerting: ["slack"]
@ -1507,6 +1507,21 @@ class ProxyConfig:
                    verbose_proxy_logger.debug(
                        f"litellm.post_call_rules: {litellm.post_call_rules}"
                    )
                elif key == "custom_provider_map":
                    from litellm.utils import custom_llm_setup

                    litellm.custom_provider_map = [
                        {
                            "provider": item["provider"],
                            "custom_handler": get_instance_fn(
                                value=item["custom_handler"],
                                config_file_path=config_file_path,
                            ),
                        }
                        for item in value
                    ]

                    custom_llm_setup()
                elif key == "success_callback":
                    litellm.success_callback = []
@ -3334,6 +3349,7 @@ async def embeddings(
        if (
            "input" in data
            and isinstance(data["input"], list)
            and len(data["input"]) > 0
            and isinstance(data["input"][0], list)
            and isinstance(data["input"][0][0], int)
        ):  # check if array of tokens passed in
@ -3464,8 +3480,8 @@ async def embeddings(
            litellm_debug_info,
        )
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.embeddings(): Exception occured - {}".format(
            "litellm.proxy.proxy_server.embeddings(): Exception occured - {}\n{}".format(
                str(e)
                str(e), traceback.format_exc()
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
@ -263,7 +263,9 @@ class Router:
        )  # names of models under litellm_params. ex. azure/chatgpt-v-2
        self.deployment_latency_map = {}
        ### CACHING ###
        cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
        cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
            "local"  # default to an in-memory cache
        )
        redis_cache = None
        cache_config = {}
        self.client_ttl = client_ttl
302  litellm/tests/test_custom_llm.py  Normal file
@ -0,0 +1,302 @@
# What is this?
## Unit tests for the CustomLLM class


import asyncio
import os
import sys
import time
import traceback

import openai
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Coroutine,
    Iterator,
    Optional,
    Union,
)
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
from dotenv import load_dotenv

import litellm
from litellm import (
    ChatCompletionDeltaChunk,
    ChatCompletionUsageBlock,
    CustomLLM,
    GenericStreamingChunk,
    ModelResponse,
    acompletion,
    completion,
    get_llm_provider,
)
from litellm.utils import ModelResponseIterator


class CustomModelResponseIterator:
    def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
        self.streaming_response = streaming_response

    def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
        return GenericStreamingChunk(
            text="hello world",
            tool_use=None,
            is_finished=True,
            finish_reason="stop",
            usage=ChatCompletionUsageBlock(
                prompt_tokens=10, completion_tokens=20, total_tokens=30
            ),
            index=0,
        )

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self) -> GenericStreamingChunk:
        try:
            chunk: Any = self.streaming_response.__next__()  # type: ignore
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.chunk_parser(chunk=chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
        return self.streaming_response

    async def __anext__(self) -> GenericStreamingChunk:
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.chunk_parser(chunk=chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


class MyCustomLLM(CustomLLM):
    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.HTTPHandler] = None,
    ) -> ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    async def acompletion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.AsyncHTTPHandler] = None,
    ) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    def streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.HTTPHandler] = None,
    ) -> Iterator[GenericStreamingChunk]:
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }

        completion_stream = ModelResponseIterator(
            model_response=generic_streaming_chunk  # type: ignore
        )
        custom_iterator = CustomModelResponseIterator(
            streaming_response=completion_stream
        )
        return custom_iterator

    async def astreaming(  # type: ignore
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.AsyncHTTPHandler] = None,
    ) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }

        yield generic_streaming_chunk  # type: ignore


def test_get_llm_provider():
    """"""
    from litellm.utils import custom_llm_setup

    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]

    custom_llm_setup()

    model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")

    assert provider == "custom_llm"


def test_simple_completion():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = completion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )

    assert resp.choices[0].message.content == "Hi!"


@pytest.mark.asyncio
async def test_simple_acompletion():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = await acompletion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )

    assert resp.choices[0].message.content == "Hi!"


def test_simple_completion_streaming():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = completion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )

    for chunk in resp:
        print(chunk)
        if chunk.choices[0].finish_reason is None:
            assert isinstance(chunk.choices[0].delta.content, str)
        else:
            assert chunk.choices[0].finish_reason == "stop"


@pytest.mark.asyncio
async def test_simple_completion_async_streaming():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = await litellm.acompletion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )

    async for chunk in resp:
        print(chunk)
        if chunk.choices[0].finish_reason is None:
            assert isinstance(chunk.choices[0].delta.content, str)
        else:
            assert chunk.choices[0].finish_reason == "stop"
10  litellm/types/llms/custom_llm.py  Normal file
@ -0,0 +1,10 @@
from typing import List

from typing_extensions import Dict, Required, TypedDict, override

from litellm.llms.custom_llm import CustomLLM


class CustomLLMItem(TypedDict):
    provider: str
    custom_handler: CustomLLM
@ -330,6 +330,18 @@ class Rules:

####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def custom_llm_setup():
    """
    Add custom_llm provider to provider list
    """
    for custom_llm in litellm.custom_provider_map:
        if custom_llm["provider"] not in litellm.provider_list:
            litellm.provider_list.append(custom_llm["provider"])

        if custom_llm["provider"] not in litellm._custom_providers:
            litellm._custom_providers.append(custom_llm["provider"])


def function_setup(
    original_function: str, rules_obj, start_time, *args, **kwargs
):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@ -341,6 +353,10 @@ def function_setup(
    try:
        global callback_list, add_breadcrumb, user_logger_fn, Logging

        ## CUSTOM LLM SETUP ##
        custom_llm_setup()

        ## LOGGING SETUP
        function_id = kwargs["id"] if "id" in kwargs else None

        if len(litellm.callbacks) > 0:
@ -9247,7 +9263,10 @@ class CustomStreamWrapper:
        try:
            # return this for all models
            completion_obj = {"content": ""}
            if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
            if self.custom_llm_provider and (
                self.custom_llm_provider == "anthropic"
                or self.custom_llm_provider in litellm._custom_providers
            ):
                from litellm.types.utils import GenericStreamingChunk as GChunk

                if self.received_finish_reason is not None:
@ -10114,6 +10133,7 @@ class CustomStreamWrapper:
        try:
            if self.completion_stream is None:
                await self.fetch_stream()

            if (
                self.custom_llm_provider == "openai"
                or self.custom_llm_provider == "azure"
@ -10138,6 +10158,7 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "triton"
                or self.custom_llm_provider == "watsonx"
                or self.custom_llm_provider in litellm.openai_compatible_endpoints
                or self.custom_llm_provider in litellm._custom_providers
            ):
                async for chunk in self.completion_stream:
                    print_verbose(f"value of async chunk: {chunk}")
@ -10966,3 +10987,8 @@ class ModelResponseIterator:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response


class CustomModelResponseIterator(Iterable):
    def __init__(self) -> None:
        super().__init__()
@ -2996,6 +2996,15 @@
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "mistral.mistral-large-2407-v1:0": {
        "max_tokens": 8191,
        "max_input_tokens": 128000,
        "max_output_tokens": 8191,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000009,
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": {
        "max_tokens": 8191,
        "max_input_tokens": 32000,