Merge branch 'main' into bedrock-llama3.1-405b

Krish Dholakia 2024-07-25 19:29:10 -07:00 committed by GitHub
commit b6ca4406b6
22 changed files with 888 additions and 159 deletions

@@ -0,0 +1,168 @@
# Custom API Server (Custom Format)
LiteLLM allows you to call your custom endpoint in the OpenAI ChatCompletion format
:::info
For calling an openai-compatible endpoint, [go here](./openai_compatible.md)
:::
## Quick Start
```python
import litellm
from litellm import CustomLLM, completion, get_llm_provider
class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
        )  # type: ignore

my_custom_llm = MyCustomLLM()

litellm.custom_provider_map = [  # 👈 KEY STEP - REGISTER HANDLER
{"provider": "my-custom-llm", "custom_handler": my_custom_llm}
]
resp = completion(
model="my-custom-llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
)
assert resp.choices[0].message.content == "Hi!"
```
## OpenAI Proxy Usage
1. Setup your `custom_handler.py` file
```python
import litellm
from litellm import CustomLLM, completion, get_llm_provider
class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
my_custom_llm = MyCustomLLM()
```
2. Add to `config.yaml`
In the config below, we pass:
- `python_filename`: `custom_handler.py`
- `custom_handler_instance_name`: `my_custom_llm` (defined in Step 1)
- `custom_handler`: `custom_handler.my_custom_llm`
```yaml
model_list:
- model_name: "test-model"
litellm_params:
model: "openai/text-embedding-ada-002"
- model_name: "my-custom-model"
litellm_params:
model: "my-custom-llm/my-model"
litellm_settings:
custom_provider_map:
- {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
```
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "my-custom-model",
"messages": [{"role": "user", "content": "Say \"this is a test\" in JSON!"}],
}'
```
Expected Response
```json
{
"id": "chatcmpl-06f1b9cd-08bc-43f7-9814-a69173921216",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "Hi!",
"role": "assistant",
"tool_calls": null,
"function_call": null
}
}
],
"created": 1721955063,
"model": "gpt-3.5-turbo",
"object": "chat.completion",
"system_fingerprint": null,
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}
```
## Custom Handler Spec
```python
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from typing import Iterator, AsyncIterator
from litellm.llms.base import BaseLLM
class CustomLLMError(Exception): # use this for all your exceptions
def __init__(
self,
status_code,
message,
):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class CustomLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def completion(self, *args, **kwargs) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def acompletion(self, *args, **kwargs) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
```
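
The `streaming`/`astreaming` methods are expected to yield `GenericStreamingChunk`s. Here is a minimal sketch of a streaming handler, assuming the handler is registered via `custom_provider_map` as in the Quick Start; the hypothetical `MyStreamingLLM` and its single canned chunk are for illustration only:

```python
from typing import Iterator

from litellm import CustomLLM
from litellm.types.utils import GenericStreamingChunk


class MyStreamingLLM(CustomLLM):
    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        # A real handler would yield one chunk per token/segment
        # received from your endpoint.
        chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }
        yield chunk
```

With the handler registered, calling `completion(model="my-custom-llm/my-model", ..., stream=True)` wraps these chunks in LiteLLM's standard streaming envelope.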

@@ -1,129 +0,0 @@
# Custom API Server (OpenAI Format)
LiteLLM allows you to call your custom endpoint in the OpenAI ChatCompletion format
## API KEYS
No API keys are required.
## Set up your Custom API Server
Your server should expose the following endpoints:
Here's an example OpenAI proxy server with routes: https://replit.com/@BerriAI/openai-proxy#main.py
### Required Endpoints
- POST `/chat/completions` - chat completions endpoint
### Optional Endpoints
- POST `/completions` - completions endpoint
- GET `/models` - available models on the server
- POST `/embeddings` - creates an embedding vector representing the input text.
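
For reference, here is a minimal sketch of such a server. FastAPI is an assumption (any framework works), and the canned response is illustrative; the only requirement is that the endpoints speak the OpenAI request/response format:

```python
# minimal_proxy.py - hypothetical OpenAI-format server (FastAPI assumed)
import time

from fastapi import FastAPI

app = FastAPI()


@app.post("/chat/completions")
async def chat_completions(request: dict):
    # Return a response in the OpenAI ChatCompletion shape
    return {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.get("model", "my-model"),
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "Hello!"},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 9, "completion_tokens": 2, "total_tokens": 11},
    }
```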
## Example Usage
### Call `/chat/completions`
In order to use your custom OpenAI Chat Completion proxy with LiteLLM, ensure you set:
* `api_base` to your proxy URL, e.g. "https://openai-proxy.berriai.repl.co"
* `custom_llm_provider` to `openai`, so litellm uses `openai.ChatCompletion` to call your api_base
```python
import os
from litellm import completion
## set ENV variables
os.environ["OPENAI_API_KEY"] = "anything" #key is not used for proxy
messages = [{ "content": "Hello, how are you?","role": "user"}]
response = completion(
model="command-nightly",
messages=[{ "content": "Hello, how are you?","role": "user"}],
api_base="https://openai-proxy.berriai.repl.co",
custom_llm_provider="openai" # litellm will use the openai.ChatCompletion to make the request
)
print(response)
```
#### Response
```json
{
  "object": "chat.completion",
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "message": {
        "content": "The sky, a canvas of blue,\nA work of art, pure and true,\nA",
        "role": "assistant"
      }
    }
  ],
  "id": "chatcmpl-7fbd6077-de10-4cb4-a8a4-3ef11a98b7c8",
  "created": 1699290237.408061,
  "model": "togethercomputer/llama-2-70b-chat",
  "usage": {
    "completion_tokens": 18,
    "prompt_tokens": 14,
    "total_tokens": 32
  }
}
```
### Call `/completions`
In order to use your custom OpenAI Completion proxy with LiteLLM, ensure you set:
* `api_base` to your proxy URL, e.g. "https://openai-proxy.berriai.repl.co"
* `custom_llm_provider` to `text-completion-openai`, so litellm uses `openai.Completion` to call your api_base
```python
import os
from litellm import completion
## set ENV variables
os.environ["OPENAI_API_KEY"] = "anything" #key is not used for proxy
messages = [{ "content": "Hello, how are you?","role": "user"}]
response = completion(
model="command-nightly",
messages=[{ "content": "Hello, how are you?","role": "user"}],
api_base="https://openai-proxy.berriai.repl.co",
custom_llm_provider="text-completion-openai" # litellm will use the openai.Completion to make the request
)
print(response)
```
#### Response
```json
{
  "warning": "This model version is deprecated. Migrate before January 4, 2024 to avoid disruption of service. Learn more https://platform.openai.com/docs/deprecations",
  "id": "cmpl-8HxHqF5dymQdALmLplS0dWKZVFe3r",
  "object": "text_completion",
  "created": 1699290166,
  "model": "text-davinci-003",
  "choices": [
    {
      "text": "\n\nThe weather in San Francisco varies depending on what time of year and time",
      "index": 0,
      "logprobs": null,
      "finish_reason": "length"
    }
  ],
  "usage": {
    "prompt_tokens": 7,
    "completion_tokens": 16,
    "total_tokens": 23
  }
}
```

@@ -41,28 +41,6 @@ litellm --health
}
```
### Background Health Checks
You can run model health checks in the background, to prevent each model from being queried too frequently via `/health`.
Here's how to use it:
1. In your config.yaml, add:
```
general_settings:
background_health_checks: True # enable background health checks
health_check_interval: 300 # frequency of background health checks
```
2. Start server
```
$ litellm --config /path/to/config.yaml
```
3. Query health endpoint:
```
curl --location 'http://0.0.0.0:4000/health'
```
### Embedding Models
We need some way to know if a model is an embedding model when running health checks. If you specify `mode: embedding` for it in your config, the check runs as an embedding health check; see the sketch below.
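A sketch of such an entry, mirroring the shape of the audio examples in this section (the Azure model names are illustrative):

```yaml
model_list:
  - model_name: azure-embedding-model
    litellm_params:
      model: azure/azure-embedding-model
      api_base: "os.environ/AZURE_API_BASE"
      api_key: "os.environ/AZURE_API_KEY"
    model_info:
      mode: embedding  # 👈 run an embedding health check for this model
```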
@@ -124,6 +102,41 @@ model_list:
mode: audio_transcription
```
### Text to Speech Models
```yaml
# OpenAI Text to Speech Models
- model_name: tts
litellm_params:
model: openai/tts-1
api_key: "os.environ/OPENAI_API_KEY"
model_info:
mode: audio_speech
```
## Background Health Checks
You can run model health checks in the background, to prevent each model from being queried too frequently via `/health`.
Here's how to use it:
1. In your config.yaml, add:
```
general_settings:
background_health_checks: True # enable background health checks
health_check_interval: 300 # frequency of background health checks
```
2. Start server
```
$ litellm --config /path/to/config.yaml
```
3. Query health endpoint:
```
curl --location 'http://0.0.0.0:4000/health'
```
### Hide details
The health check response contains details like endpoint URLs, error messages,

@@ -31,8 +31,19 @@ model_list:
api_base: https://openai-france-1234.openai.azure.com/
api_key: <your-azure-api-key>
rpm: 1440
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
num_retries: 2
timeout: 30 # 30 seconds
redis_host: <your redis host> # set this when using multiple litellm proxy deployments, load balancing state stored in redis
redis_password: <your redis password>
redis_port: 1992
```
:::info
Detailed information about [routing strategies](../routing) can be found in the routing docs.
:::
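
The same settings map onto `litellm.Router` when it is used directly from Python. A minimal sketch (the Azure deployment values are placeholders):

```python
import os

from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/<your-deployment-name>",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "rpm": 1440,
        },
    }
]

router = Router(
    model_list=model_list,
    routing_strategy="simple-shuffle",  # the default
    num_retries=2,
    timeout=30,  # seconds
)

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
```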
#### Step 2: Start Proxy with config
```shell

@@ -175,7 +175,8 @@ const sidebars = {
"providers/aleph_alpha",
"providers/baseten",
"providers/openrouter",
"providers/custom_openai_proxy",
// "providers/custom_openai_proxy",
"providers/custom_llm_server",
"providers/petals",
],

@@ -813,6 +813,7 @@ from .utils import (
)
from .types.utils import ImageObject
from .llms.custom_llm import CustomLLM
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig
@@ -909,3 +910,12 @@ from .cost_calculator import response_cost_calculator, cost_per_token
from .types.adapter import AdapterItem
adapters: List[AdapterItem] = []
### CUSTOM LLMs ###
from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk
custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
[]
) # internal helper util, used to track names of custom providers

@@ -1864,6 +1864,23 @@ class AzureChatCompletion(BaseLLM):
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(pwd, "../tests/gettysburg.wav")
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
else:
raise Exception("mode not set")
response = {}

@@ -79,6 +79,7 @@ BEDROCK_CONVERSE_MODELS = [
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-405b-instruct-v1:0",
"mistral.mistral-large-2407-v1:0",
]
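
Models on this list are routed through Bedrock's Converse API. Calling one follows the usual `completion` pattern; a sketch, not part of this commit, which assumes AWS credentials are configured in the environment:

```python
import os

from litellm import completion

# placeholder credentials - use your own AWS setup
os.environ["AWS_ACCESS_KEY_ID"] = "<your-access-key-id>"
os.environ["AWS_SECRET_ACCESS_KEY"] = "<your-secret-access-key>"
os.environ["AWS_REGION_NAME"] = "us-west-2"

response = completion(
    model="bedrock/meta.llama3-1-405b-instruct-v1:0",
    messages=[{"role": "user", "content": "Hello!"}],
)
```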

litellm/llms/custom_llm.py (new file, 161 lines)
@@ -0,0 +1,161 @@
# What is this?
## Handler file for a Custom Chat LLM
"""
- completion
- acompletion
- streaming
- async_streaming
"""
import copy
import json
import os
import time
import types
from enum import Enum
from functools import partial
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Coroutine,
Iterator,
List,
Literal,
Optional,
Tuple,
Union,
)
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import GenericStreamingChunk, ProviderField
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class CustomLLMError(Exception): # use this for all your exceptions
def __init__(
self,
status_code,
message,
):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class CustomLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[HTTPHandler] = None,
) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def streaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[HTTPHandler] = None,
) -> Iterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def acompletion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def astreaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> AsyncIterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def custom_chat_llm_router(
async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM
):
"""
Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
Validates if response is in expected format
"""
if async_fn:
if stream:
return custom_llm.astreaming
return custom_llm.acompletion
if stream:
return custom_llm.streaming
return custom_llm.completion
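
A quick sketch of the routing behavior, not part of the file above; it assumes a `CustomLLM` subclass like the docs' `MyCustomLLM`:

```python
my_llm = MyCustomLLM()  # hypothetical CustomLLM subclass

# sync + stream=True -> the sync streaming handler
handler_fn = custom_chat_llm_router(async_fn=False, stream=True, custom_llm=my_llm)
assert handler_fn == my_llm.streaming

# async + stream unset -> the async completion handler
handler_fn = custom_chat_llm_router(async_fn=True, stream=None, custom_llm=my_llm)
assert handler_fn == my_llm.acompletion
```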

@@ -1,5 +1,6 @@
import hashlib
import json
import os
import time
import traceback
import types
@@ -1870,8 +1871,25 @@ class OpenAIChatCompletion(BaseLLM):
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(pwd, "../tests/gettysburg.wav")
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
else:
raise Exception("mode not set")
raise ValueError("mode not set, passed in mode: " + mode)
response = {}
if completion is None or not hasattr(completion, "headers"):

@@ -387,7 +387,7 @@ def process_response(
result = " "
## Building RESPONSE OBJECT
if len(result) > 1:
if len(result) >= 1:
model_response.choices[0].message.content = result # type: ignore
# Calculate usage

@@ -107,6 +107,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
@@ -381,6 +382,7 @@ async def acompletion(
or custom_llm_provider == "clarifai"
or custom_llm_provider == "watsonx"
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider in litellm._custom_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
@@ -2690,6 +2692,54 @@ def completion(
model_response.created = int(time.time())
model_response.model = model
response = model_response
elif (
custom_llm_provider in litellm._custom_providers
): # Assume custom LLM provider
# Get the Custom Handler
custom_handler: Optional[CustomLLM] = None
for item in litellm.custom_provider_map:
if item["provider"] == custom_llm_provider:
custom_handler = item["custom_handler"]
if custom_handler is None:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
)
## ROUTE LLM CALL ##
handler_fn = custom_chat_llm_router(
async_fn=acompletion, stream=stream, custom_llm=custom_handler
)
headers = headers or litellm.headers
## CALL FUNCTION
response = handler_fn(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
encoding=encoding,
)
if stream is True:
return CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging,
)
else:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"

@@ -2996,6 +2996,15 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
"mistral.mistral-large-2407-v1:0": {
"max_tokens": 8191,
"max_input_tokens": 128000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000009,
"litellm_provider": "bedrock",
"mode": "chat"
},
"bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": {
"max_tokens": 8191,
"max_input_tokens": 32000,

@@ -1,4 +1,11 @@
model_list:
- model_name: "test-model"
litellm_params:
model: "openai/gpt-3.5-turbo-instruct-0914"
model: "openai/text-embedding-ada-002"
- model_name: "my-custom-model"
litellm_params:
model: "my-custom-llm/my-model"
litellm_settings:
custom_provider_map:
- {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}

@@ -0,0 +1,21 @@
import litellm
from litellm import CustomLLM, completion, get_llm_provider
class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
my_custom_llm = MyCustomLLM()

@@ -8,6 +8,12 @@ model_list:
litellm_params:
model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
api_key: "os.environ/FIREWORKS"
- model_name: tts
litellm_params:
model: openai/tts-1
api_key: "os.environ/OPENAI_API_KEY"
model_info:
mode: audio_speech
general_settings:
master_key: sk-1234
alerting: ["slack"]

@@ -1507,6 +1507,21 @@ class ProxyConfig:
verbose_proxy_logger.debug(
f"litellm.post_call_rules: {litellm.post_call_rules}"
)
elif key == "custom_provider_map":
from litellm.utils import custom_llm_setup
litellm.custom_provider_map = [
{
"provider": item["provider"],
"custom_handler": get_instance_fn(
value=item["custom_handler"],
config_file_path=config_file_path,
),
}
for item in value
]
custom_llm_setup()
elif key == "success_callback":
litellm.success_callback = []
@@ -3334,6 +3349,7 @@ async def embeddings(
if (
"input" in data
and isinstance(data["input"], list)
and len(data["input"]) > 0
and isinstance(data["input"][0], list)
and isinstance(data["input"][0][0], int)
): # check if array of tokens passed in
@@ -3464,8 +3480,8 @@ async def embeddings(
litellm_debug_info,
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.embeddings(): Exception occured - {}".format(
str(e)
"litellm.proxy.proxy_server.embeddings(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
verbose_proxy_logger.debug(traceback.format_exc())

@@ -263,7 +263,9 @@ class Router:
) # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
### CACHING ###
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
"local" # default to an in-memory cache
)
redis_cache = None
cache_config = {}
self.client_ttl = client_ttl

@@ -0,0 +1,302 @@
# What is this?
## Unit tests for the CustomLLM class
import asyncio
import os
import sys
import time
import traceback
import openai
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Coroutine,
Iterator,
Optional,
Union,
)
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
from dotenv import load_dotenv
import litellm
from litellm import (
ChatCompletionDeltaChunk,
ChatCompletionUsageBlock,
CustomLLM,
GenericStreamingChunk,
ModelResponse,
acompletion,
completion,
get_llm_provider,
)
from litellm.utils import ModelResponseIterator
class CustomModelResponseIterator:
def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
self.streaming_response = streaming_response
def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
return GenericStreamingChunk(
text="hello world",
tool_use=None,
is_finished=True,
finish_reason="stop",
usage=ChatCompletionUsageBlock(
prompt_tokens=10, completion_tokens=20, total_tokens=30
),
index=0,
)
# Sync iterator
def __iter__(self):
return self
def __next__(self) -> GenericStreamingChunk:
try:
chunk: Any = self.streaming_response.__next__() # type: ignore
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.chunk_parser(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__() # type: ignore
return self.streaming_response
async def __anext__(self) -> GenericStreamingChunk:
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.chunk_parser(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
class MyCustomLLM(CustomLLM):
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.HTTPHandler] = None,
) -> ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
async def acompletion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.AsyncHTTPHandler] = None,
) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
def streaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.HTTPHandler] = None,
) -> Iterator[GenericStreamingChunk]:
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": "Hello world",
"tool_use": None,
"usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
}
completion_stream = ModelResponseIterator(
model_response=generic_streaming_chunk # type: ignore
)
custom_iterator = CustomModelResponseIterator(
streaming_response=completion_stream
)
return custom_iterator
async def astreaming( # type: ignore
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.AsyncHTTPHandler] = None,
) -> AsyncIterator[GenericStreamingChunk]: # type: ignore
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": "Hello world",
"tool_use": None,
"usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
}
yield generic_streaming_chunk # type: ignore
def test_get_llm_provider():
""""""
from litellm.utils import custom_llm_setup
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
custom_llm_setup()
model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")
assert provider == "custom_llm"
def test_simple_completion():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = completion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
)
assert resp.choices[0].message.content == "Hi!"
@pytest.mark.asyncio
async def test_simple_acompletion():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = await acompletion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
)
assert resp.choices[0].message.content == "Hi!"
def test_simple_completion_streaming():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = completion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
stream=True,
)
for chunk in resp:
print(chunk)
if chunk.choices[0].finish_reason is None:
assert isinstance(chunk.choices[0].delta.content, str)
else:
assert chunk.choices[0].finish_reason == "stop"
@pytest.mark.asyncio
async def test_simple_completion_async_streaming():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = await litellm.acompletion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
stream=True,
)
async for chunk in resp:
print(chunk)
if chunk.choices[0].finish_reason is None:
assert isinstance(chunk.choices[0].delta.content, str)
else:
assert chunk.choices[0].finish_reason == "stop"

@@ -0,0 +1,10 @@
from typing import List
from typing_extensions import Dict, Required, TypedDict, override
from litellm.llms.custom_llm import CustomLLM
class CustomLLMItem(TypedDict):
provider: str
custom_handler: CustomLLM

@@ -330,6 +330,18 @@ class Rules:
####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def custom_llm_setup():
"""
Add custom_llm provider to provider list
"""
for custom_llm in litellm.custom_provider_map:
if custom_llm["provider"] not in litellm.provider_list:
litellm.provider_list.append(custom_llm["provider"])
if custom_llm["provider"] not in litellm._custom_providers:
litellm._custom_providers.append(custom_llm["provider"])
def function_setup(
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -341,6 +353,10 @@ def function_setup(
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
## CUSTOM LLM SETUP ##
custom_llm_setup()
## LOGGING SETUP
function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
@@ -9247,7 +9263,10 @@ class CustomStreamWrapper:
try:
# return this for all models
completion_obj = {"content": ""}
if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
if self.custom_llm_provider and (
self.custom_llm_provider == "anthropic"
or self.custom_llm_provider in litellm._custom_providers
):
from litellm.types.utils import GenericStreamingChunk as GChunk
if self.received_finish_reason is not None:
@@ -10114,6 +10133,7 @@ class CustomStreamWrapper:
try:
if self.completion_stream is None:
await self.fetch_stream()
if (
self.custom_llm_provider == "openai"
or self.custom_llm_provider == "azure"
@@ -10138,6 +10158,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "triton"
or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
or self.custom_llm_provider in litellm._custom_providers
):
async for chunk in self.completion_stream:
print_verbose(f"value of async chunk: {chunk}")
@@ -10966,3 +10987,8 @@ class ModelResponseIterator:
raise StopAsyncIteration
self.is_done = True
return self.model_response
class CustomModelResponseIterator(Iterable):
def __init__(self) -> None:
super().__init__()

@@ -2996,6 +2996,15 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
"mistral.mistral-large-2407-v1:0": {
"max_tokens": 8191,
"max_input_tokens": 128000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000009,
"litellm_provider": "bedrock",
"mode": "chat"
},
"bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": {
"max_tokens": 8191,
"max_input_tokens": 32000,