Merge branch 'main' into litellm_redis_team_object

commit c2086300b7
Krish Dholakia, 2024-07-25 19:31:52 -07:00 (committed by GitHub)
34 changed files with 1182 additions and 232 deletions


@@ -813,6 +813,7 @@ from .utils import (
)
from .types.utils import ImageObject
from .llms.custom_llm import CustomLLM
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig
@@ -909,3 +910,12 @@ from .cost_calculator import response_cost_calculator, cost_per_token
from .types.adapter import AdapterItem
adapters: List[AdapterItem] = []
### CUSTOM LLMs ###
from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk
custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
[]
) # internal helper util, used to track names of custom providers


@@ -1864,6 +1864,23 @@ class AzureChatCompletion(BaseLLM):
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(pwd, "../tests/gettysburg.wav")
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
else:
raise Exception("mode not set")
response = {}


@@ -78,6 +78,8 @@ BEDROCK_CONVERSE_MODELS = [
"ai21.jamba-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-405b-instruct-v1:0",
"mistral.mistral-large-2407-v1:0",
]
@@ -1315,6 +1317,7 @@ class AmazonConverseConfig:
model.startswith("anthropic")
or model.startswith("mistral")
or model.startswith("cohere")
or model.startswith("meta.llama3-1")
):
supported_params.append("tools")

litellm/llms/custom_llm.py (new file)

@@ -0,0 +1,161 @@
# What is this?
## Handler file for a Custom Chat LLM
"""
- completion
- acompletion
- streaming
- async_streaming
"""
import copy
import json
import os
import time
import types
from enum import Enum
from functools import partial
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Coroutine,
Iterator,
List,
Literal,
Optional,
Tuple,
Union,
)
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import GenericStreamingChunk, ProviderField
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class CustomLLMError(Exception): # use this for all your exceptions
def __init__(
self,
status_code,
message,
):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class CustomLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[HTTPHandler] = None,
) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def streaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[HTTPHandler] = None,
) -> Iterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def acompletion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def astreaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> AsyncIterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def custom_chat_llm_router(
async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM
):
"""
Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
Validates if response is in expected format
"""
if async_fn:
if stream:
return custom_llm.astreaming
return custom_llm.acompletion
if stream:
return custom_llm.streaming
return custom_llm.completion
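
Taken together with the example handler and tests added later in this commit, a minimal end-to-end sketch of the new interface might look like the following (the provider name "my-custom-llm" and the MyCustomLLM subclass are illustrative, not part of the diff):

import litellm
from litellm import CustomLLM, completion

class MyCustomLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        # a real handler would call its own backend here; this returns a canned reply
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

# registering the handler is what makes "my-custom-llm/..." routable;
# custom_llm_setup() (run during function_setup) also adds the provider name
# to litellm.provider_list and litellm._custom_providers
litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": MyCustomLLM()}
]

resp = completion(
    model="my-custom-llm/my-model",
    messages=[{"role": "user", "content": "Hello world!"}],
)
print(resp.choices[0].message.content)  # -> "Hi!"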


@@ -1,5 +1,6 @@
import hashlib
import json
import os
import time
import traceback
import types
@@ -1870,8 +1871,25 @@ class OpenAIChatCompletion(BaseLLM):
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(pwd, "../tests/gettysburg.wav")
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
else:
raise Exception("mode not set")
raise ValueError("mode not set, passed in mode: " + mode)
response = {}
if completion is None or not hasattr(completion, "headers"):


@@ -387,7 +387,7 @@ def process_response(
result = " "
## Building RESPONSE OBJECT
if len(result) > 1:
if len(result) >= 1:
model_response.choices[0].message.content = result # type: ignore
# Calculate usage


@@ -107,6 +107,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
@@ -381,6 +382,7 @@ async def acompletion(
or custom_llm_provider == "clarifai"
or custom_llm_provider == "watsonx"
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider in litellm._custom_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
@@ -2690,6 +2692,54 @@ def completion(
model_response.created = int(time.time())
model_response.model = model
response = model_response
elif (
custom_llm_provider in litellm._custom_providers
): # Assume custom LLM provider
# Get the Custom Handler
custom_handler: Optional[CustomLLM] = None
for item in litellm.custom_provider_map:
if item["provider"] == custom_llm_provider:
custom_handler = item["custom_handler"]
if custom_handler is None:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
)
## ROUTE LLM CALL ##
handler_fn = custom_chat_llm_router(
async_fn=acompletion, stream=stream, custom_llm=custom_handler
)
headers = headers or litellm.headers
## CALL FUNCTION
response = handler_fn(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
encoding=encoding,
)
if stream is True:
return CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging,
)
else:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
@@ -3833,7 +3883,7 @@ def text_completion(
optional_params["custom_llm_provider"] = custom_llm_provider
# get custom_llm_provider
_, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
_model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
if custom_llm_provider == "huggingface":
# if echo == True, for TGI llms we need to set top_n_tokens to 3
@@ -3916,10 +3966,12 @@ def text_completion(
kwargs.pop("prompt", None)
if model is not None and model.startswith(
"openai/"
if (
_model is not None and custom_llm_provider == "openai"
): # for openai compatible endpoints - e.g. vllm, call the native /v1/completions endpoint for text completion calls
model = model.replace("openai/", "text-completion-openai/")
if _model not in litellm.open_ai_chat_completion_models:
model = "text-completion-openai/" + _model
optional_params.pop("custom_llm_provider", None)
kwargs["text_completion"] = True
response = completion(


@@ -893,11 +893,11 @@
"mode": "chat"
},
"mistral/mistral-large-latest": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000004,
"output_cost_per_token": 0.000012,
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 128000,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000009,
"litellm_provider": "mistral",
"mode": "chat",
"supports_function_calling": true
@@ -912,6 +912,16 @@
"mode": "chat",
"supports_function_calling": true
},
"mistral/mistral-large-2407": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 128000,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000009,
"litellm_provider": "mistral",
"mode": "chat",
"supports_function_calling": true
},
"mistral/open-mistral-7b": {
"max_tokens": 8191,
"max_input_tokens": 32000,
@@ -1094,6 +1104,36 @@
"mode": "chat",
"supports_function_calling": true
},
"groq/llama-3.1-8b-instant": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000059,
"output_cost_per_token": 0.00000079,
"litellm_provider": "groq",
"mode": "chat",
"supports_function_calling": true
},
"groq/llama-3.1-70b-versatile": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000059,
"output_cost_per_token": 0.00000079,
"litellm_provider": "groq",
"mode": "chat",
"supports_function_calling": true
},
"groq/llama-3.1-405b-reasoning": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000059,
"output_cost_per_token": 0.00000079,
"litellm_provider": "groq",
"mode": "chat",
"supports_function_calling": true
},
"groq/mixtral-8x7b-32768": {
"max_tokens": 32768,
"max_input_tokens": 32768,
@@ -2956,6 +2996,15 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
"mistral.mistral-large-2407-v1:0": {
"max_tokens": 8191,
"max_input_tokens": 128000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000009,
"litellm_provider": "bedrock",
"mode": "chat"
},
"bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": {
"max_tokens": 8191,
"max_input_tokens": 32000,
@@ -3691,6 +3740,15 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
"meta.llama3-1-405b-instruct-v1:0": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000532,
"output_cost_per_token": 0.000016,
"litellm_provider": "bedrock",
"mode": "chat"
},
"512-x-512/50-steps/stability.stable-diffusion-xl-v0": {
"max_tokens": 77,
"max_input_tokens": 77,


@@ -1,9 +1,11 @@
model_list:
- model_name: "*" # all requests where model not in your config go to this deployment
- model_name: "test-model"
litellm_params:
model: "openai/*" # passes our validation check that a real provider is given
api_key: ""
model: "openai/text-embedding-ada-002"
- model_name: "my-custom-model"
litellm_params:
model: "my-custom-llm/my-model"
litellm_settings:
cache: True
custom_provider_map:
- {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}


@@ -0,0 +1,21 @@
import litellm
from litellm import CustomLLM, completion, get_llm_provider
class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
my_custom_llm = MyCustomLLM()
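
Once the proxy loads this handler through custom_provider_map, requests for the mapped model_name are served by it. A rough client-side sketch, assuming the proxy from the config above runs locally on its default port (4000) with master key "sk-1234" (both assumptions, not part of the diff):

import openai

# hypothetical setup: `litellm --config proxy_config.yaml` running on localhost:4000
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

resp = client.chat.completions.create(
    model="my-custom-model",  # the proxy maps this to my-custom-llm/my-model -> MyCustomLLM
    messages=[{"role": "user", "content": "Hello world!"}],
)
print(resp.choices[0].message.content)  # "Hi!" from the mock handler above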


@@ -27,6 +27,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.key_management_endpoints import (
_duration_in_seconds,
generate_key_helper_fn,
)
from litellm.proxy.management_helpers.utils import (
@@ -486,6 +487,13 @@ async def user_update(
): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v
if "budget_duration" in non_default_values:
duration_s = _duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = user_reset_at
## ADD USER, IF NEW ##
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
if data.user_id is not None and len(data.user_id) > 0:


@@ -8,6 +8,12 @@ model_list:
litellm_params:
model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
api_key: "os.environ/FIREWORKS"
- model_name: tts
litellm_params:
model: openai/tts-1
api_key: "os.environ/OPENAI_API_KEY"
model_info:
mode: audio_speech
general_settings:
master_key: sk-1234
alerting: ["slack"]


@@ -1507,6 +1507,21 @@ class ProxyConfig:
verbose_proxy_logger.debug(
f"litellm.post_call_rules: {litellm.post_call_rules}"
)
elif key == "custom_provider_map":
from litellm.utils import custom_llm_setup
litellm.custom_provider_map = [
{
"provider": item["provider"],
"custom_handler": get_instance_fn(
value=item["custom_handler"],
config_file_path=config_file_path,
),
}
for item in value
]
custom_llm_setup()
elif key == "success_callback":
litellm.success_callback = []
@@ -3334,6 +3349,7 @@ async def embeddings(
if (
"input" in data
and isinstance(data["input"], list)
and len(data["input"]) > 0
and isinstance(data["input"][0], list)
and isinstance(data["input"][0][0], int)
): # check if array of tokens passed in
@@ -3464,8 +3480,8 @@ async def embeddings(
litellm_debug_info,
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.embeddings(): Exception occured - {}".format(
str(e)
"litellm.proxy.proxy_server.embeddings(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
verbose_proxy_logger.debug(traceback.format_exc())


@@ -263,7 +263,9 @@ class Router:
) # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
### CACHING ###
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
"local" # default to an in-memory cache
)
redis_cache = None
cache_config = {}
self.client_ttl = client_ttl


@@ -2573,21 +2573,17 @@ def test_completion_azure_extra_headers():
http_client = Client()
with patch.object(http_client, "send", new=MagicMock()) as mock_client:
client = AzureOpenAI(
azure_endpoint=os.getenv("AZURE_API_BASE"),
api_version=litellm.AZURE_DEFAULT_API_VERSION,
api_key=os.getenv("AZURE_API_KEY"),
http_client=http_client,
)
litellm.client_session = http_client
try:
response = completion(
model="azure/chatgpt-v-2",
messages=messages,
client=client,
api_base=os.getenv("AZURE_API_BASE"),
api_version="2023-07-01-preview",
api_key=os.getenv("AZURE_API_KEY"),
extra_headers={
"Authorization": "my-bad-key",
"Ocp-Apim-Subscription-Key": "hello-world-testing",
"api-key": "my-bad-key",
},
)
print(response)
@@ -2603,8 +2599,10 @@ def test_completion_azure_extra_headers():
print(request.url) # This will print the full URL
print(request.headers) # This will print the full URL
auth_header = request.headers.get("Authorization")
apim_key = request.headers.get("Ocp-Apim-Subscription-Key")
print(auth_header)
assert auth_header == "my-bad-key"
assert apim_key == "hello-world-testing"
def test_completion_azure_ad_token():
@@ -2613,18 +2611,37 @@ def test_completion_azure_ad_token():
# If you want to remove it, speak to Ishaan!
# Ishaan will be very disappointed if this test is removed -> this is a standard way to pass api_key + the router + proxy use this
from httpx import Client
from openai import AzureOpenAI
from litellm import completion
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
response = completion(
model="azure/chatgpt-v-2",
messages=messages,
# api_key="my-fake-ad-token",
azure_ad_token=os.getenv("AZURE_API_KEY"),
)
print(response)
litellm.set_verbose = True
old_key = os.environ["AZURE_API_KEY"]
os.environ.pop("AZURE_API_KEY", None)
http_client = Client()
with patch.object(http_client, "send", new=MagicMock()) as mock_client:
litellm.client_session = http_client
try:
response = completion(
model="azure/chatgpt-v-2",
messages=messages,
azure_ad_token="my-special-token",
)
print(response)
except Exception as e:
pass
finally:
os.environ["AZURE_API_KEY"] = old_key
mock_client.assert_called_once()
request = mock_client.call_args[0][0]
print(request.method) # This will print 'POST'
print(request.url) # This will print the full URL
print(request.headers) # This will print the full URL
auth_header = request.headers.get("Authorization")
assert auth_header == "Bearer my-special-token"
def test_completion_azure_key_completion_arg():


@@ -0,0 +1,302 @@
# What is this?
## Unit tests for the CustomLLM class
import asyncio
import os
import sys
import time
import traceback
import openai
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Coroutine,
Iterator,
Optional,
Union,
)
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
from dotenv import load_dotenv
import litellm
from litellm import (
ChatCompletionDeltaChunk,
ChatCompletionUsageBlock,
CustomLLM,
GenericStreamingChunk,
ModelResponse,
acompletion,
completion,
get_llm_provider,
)
from litellm.utils import ModelResponseIterator
class CustomModelResponseIterator:
def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
self.streaming_response = streaming_response
def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
return GenericStreamingChunk(
text="hello world",
tool_use=None,
is_finished=True,
finish_reason="stop",
usage=ChatCompletionUsageBlock(
prompt_tokens=10, completion_tokens=20, total_tokens=30
),
index=0,
)
# Sync iterator
def __iter__(self):
return self
def __next__(self) -> GenericStreamingChunk:
try:
chunk: Any = self.streaming_response.__next__() # type: ignore
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.chunk_parser(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__() # type: ignore
return self.streaming_response
async def __anext__(self) -> GenericStreamingChunk:
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.chunk_parser(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
class MyCustomLLM(CustomLLM):
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.HTTPHandler] = None,
) -> ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
async def acompletion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.AsyncHTTPHandler] = None,
) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
def streaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.HTTPHandler] = None,
) -> Iterator[GenericStreamingChunk]:
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": "Hello world",
"tool_use": None,
"usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
}
completion_stream = ModelResponseIterator(
model_response=generic_streaming_chunk # type: ignore
)
custom_iterator = CustomModelResponseIterator(
streaming_response=completion_stream
)
return custom_iterator
async def astreaming( # type: ignore
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.AsyncHTTPHandler] = None,
) -> AsyncIterator[GenericStreamingChunk]: # type: ignore
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": "Hello world",
"tool_use": None,
"usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
}
yield generic_streaming_chunk # type: ignore
def test_get_llm_provider():
""""""
from litellm.utils import custom_llm_setup
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
custom_llm_setup()
model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")
assert provider == "custom_llm"
def test_simple_completion():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = completion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
)
assert resp.choices[0].message.content == "Hi!"
@pytest.mark.asyncio
async def test_simple_acompletion():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = await acompletion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
)
assert resp.choices[0].message.content == "Hi!"
def test_simple_completion_streaming():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = completion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
stream=True,
)
for chunk in resp:
print(chunk)
if chunk.choices[0].finish_reason is None:
assert isinstance(chunk.choices[0].delta.content, str)
else:
assert chunk.choices[0].finish_reason == "stop"
@pytest.mark.asyncio
async def test_simple_completion_async_streaming():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = await litellm.acompletion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
stream=True,
)
async for chunk in resp:
print(chunk)
if chunk.choices[0].finish_reason is None:
assert isinstance(chunk.choices[0].delta.content, str)
else:
assert chunk.choices[0].finish_reason == "stop"


@@ -206,6 +206,9 @@ def test_openai_azure_embedding_with_oidc_and_cf():
os.environ["AZURE_TENANT_ID"] = "17c0a27a-1246-4aa1-a3b6-d294e80e783c"
os.environ["AZURE_CLIENT_ID"] = "4faf5422-b2bd-45e8-a6d7-46543a38acd0"
old_key = os.environ["AZURE_API_KEY"]
os.environ.pop("AZURE_API_KEY", None)
try:
response = embedding(
model="azure/text-embedding-ada-002",
@@ -218,6 +221,8 @@ def test_openai_azure_embedding_with_oidc_and_cf():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
finally:
os.environ["AZURE_API_KEY"] = old_key
def test_openai_azure_embedding_optional_arg(mocker):
@@ -673,17 +678,3 @@ async def test_databricks_embeddings(sync_mode):
# print(response)
# local_proxy_embeddings()
def test_embedding_azure_ad_token():
# this tests if we can pass api_key to completion, when it's not in the env.
# DO NOT REMOVE THIS TEST. No MATTER WHAT Happens!
# If you want to remove it, speak to Ishaan!
# Ishaan will be very disappointed if this test is removed -> this is a standard way to pass api_key + the router + proxy use this
response = embedding(
model="azure/azure-embedding-model",
input=["good morning from litellm"],
azure_ad_token=os.getenv("AZURE_API_KEY"),
)
print(response)


@@ -1,14 +1,18 @@
import sys, os
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import os, io
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
@@ -21,6 +25,12 @@ def test_get_llm_provider():
# test_get_llm_provider()
def test_get_llm_provider_gpt_instruct():
_, response, _, _ = litellm.get_llm_provider(model="gpt-3.5-turbo-instruct-0914")
assert response == "text-completion-openai"
def test_get_llm_provider_mistral_custom_api_base():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="mistral/mistral-large-fr",


@@ -3840,7 +3840,26 @@ def test_completion_chatgpt_prompt():
try:
print("\n gpt3.5 test\n")
response = text_completion(
model="gpt-3.5-turbo", prompt="What's the weather in SF?"
model="openai/gpt-3.5-turbo", prompt="What's the weather in SF?"
)
print(response)
response_str = response["choices"][0]["text"]
print("\n", response.choices)
print("\n", response.choices[0])
# print(response.choices[0].text)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_chatgpt_prompt()
def test_completion_gpt_instruct():
try:
response = text_completion(
model="gpt-3.5-turbo-instruct-0914",
prompt="What's the weather in SF?",
custom_llm_provider="openai",
)
print(response)
response_str = response["choices"][0]["text"]


@@ -0,0 +1,10 @@
from typing import List
from typing_extensions import Dict, Required, TypedDict, override
from litellm.llms.custom_llm import CustomLLM
class CustomLLMItem(TypedDict):
provider: str
custom_handler: CustomLLM
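
The hunks that follow wire these entries up: custom_llm_setup() registers each provider name and function_setup() invokes it on every call, so get_llm_provider() can resolve the prefix. A small sketch of that resolution, mirroring test_get_llm_provider above (the provider name "custom_llm" is illustrative):

import litellm
from litellm import CustomLLM, get_llm_provider
from litellm.utils import custom_llm_setup

# a bare CustomLLM instance is enough to demonstrate resolution; a real handler
# would override completion/acompletion/streaming/astreaming
litellm.custom_provider_map = [
    {"provider": "custom_llm", "custom_handler": CustomLLM()}
]
custom_llm_setup()  # adds "custom_llm" to litellm.provider_list and litellm._custom_providers

model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")
assert provider == "custom_llm"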


@@ -330,6 +330,18 @@ class Rules:
####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def custom_llm_setup():
"""
Add custom_llm provider to provider list
"""
for custom_llm in litellm.custom_provider_map:
if custom_llm["provider"] not in litellm.provider_list:
litellm.provider_list.append(custom_llm["provider"])
if custom_llm["provider"] not in litellm._custom_providers:
litellm._custom_providers.append(custom_llm["provider"])
def function_setup(
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -341,6 +353,10 @@ def function_setup(
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
## CUSTOM LLM SETUP ##
custom_llm_setup()
## LOGGING SETUP
function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
@@ -2774,7 +2790,7 @@ def get_optional_params(
tool_function["parameters"] = new_parameters
def _check_valid_arg(supported_params):
verbose_logger.debug(
verbose_logger.info(
f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}"
)
verbose_logger.debug(
@@ -3121,7 +3137,19 @@ def get_optional_params(
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
if "ai21" in model:
if model in litellm.BEDROCK_CONVERSE_MODELS:
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif "ai21" in model:
_check_valid_arg(supported_params=supported_params)
# params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
# https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
@@ -3143,17 +3171,6 @@ def get_optional_params(
optional_params=optional_params,
)
)
elif model in litellm.BEDROCK_CONVERSE_MODELS:
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
@@ -4486,7 +4503,11 @@ def get_llm_provider(
or get_secret("TOGETHER_AI_TOKEN")
)
elif custom_llm_provider == "friendliai":
api_base = "https://inference.friendli.ai/v1"
api_base = (
api_base
or get_secret("FRIENDLI_API_BASE")
or "https://inference.friendli.ai/v1"
)
dynamic_api_key = (
api_key
or get_secret("FRIENDLIAI_API_KEY")
@@ -9242,7 +9263,10 @@ class CustomStreamWrapper:
try:
# return this for all models
completion_obj = {"content": ""}
if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
if self.custom_llm_provider and (
self.custom_llm_provider == "anthropic"
or self.custom_llm_provider in litellm._custom_providers
):
from litellm.types.utils import GenericStreamingChunk as GChunk
if self.received_finish_reason is not None:
@@ -10109,6 +10133,7 @@ class CustomStreamWrapper:
try:
if self.completion_stream is None:
await self.fetch_stream()
if (
self.custom_llm_provider == "openai"
or self.custom_llm_provider == "azure"
@@ -10133,6 +10158,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "triton"
or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
or self.custom_llm_provider in litellm._custom_providers
):
async for chunk in self.completion_stream:
print_verbose(f"value of async chunk: {chunk}")
@@ -10961,3 +10987,8 @@ class ModelResponseIterator:
raise StopAsyncIteration
self.is_done = True
return self.model_response
class CustomModelResponseIterator(Iterable):
def __init__(self) -> None:
super().__init__()