fixed acompletion import bug

Krrish Dholakia 2023-08-03 14:06:32 -07:00
parent 82a75a9d92
commit 123de53475
18 changed files with 53 additions and 24 deletions

build/lib/litellm/__init__.py

@@ -3,13 +3,22 @@ failure_callback = []
 set_verbose=False
 telemetry=True
 max_tokens = 256 # OpenAI Defaults
+retry = True # control tenacity retries.
 ####### PROXY PARAMS ################### configurable params if you use proxy models like Helicone
 api_base = None
 headers = None
 ####### COMPLETION MODELS ###################
 open_ai_chat_completion_models = [
+  "gpt-4",
+  "gpt-4-0613",
+  "gpt-4-32k",
+  "gpt-4-32k-0613",
+  #################
+  "gpt-3.5-turbo",
+  "gpt-3.5-turbo-16k",
+  "gpt-3.5-turbo-0613",
+  "gpt-3.5-turbo-16k-0613",
   'gpt-3.5-turbo',
-  'gpt-4',
   'gpt-3.5-turbo-16k-0613',
   'gpt-3.5-turbo-16k'
 ]
@@ -32,7 +41,6 @@ model_list = open_ai_chat_completion_models + open_ai_text_completion_models + c
 open_ai_embedding_models = [
   'text-embedding-ada-002'
 ]
-
 from .timeout import timeout
 from .utils import client, logging, exception_type # Import all the symbols from main.py
 from .main import * # Import all the symbols from main.py
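A note on the new `retry` flag: it is read by the tenacity decorator added in main.py below, but decorator arguments are evaluated once at import time, so a flag flipped later is most reliably honored by checking it at call time. A minimal sketch of that pattern using tenacity's Retrying object; the call_with_optional_retry helper is hypothetical, not part of litellm's API:

import litellm
from tenacity import Retrying, stop_after_attempt, wait_random_exponential

def call_with_optional_retry(fn, *args, **kwargs):
    # Hypothetical helper: consult litellm.retry per call, since decorator
    # arguments bound at import time cannot see later changes to the flag.
    attempts = 2 if litellm.retry else 1
    for attempt in Retrying(
        stop=stop_after_attempt(attempts),
        wait=wait_random_exponential(min=1, max=60),
        reraise=True,  # re-raise the original exception, not a RetryError
    ):
        with attempt:
            return fn(*args, **kwargs)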

build/lib/litellm/main.py

@@ -2,11 +2,18 @@ import os, openai, cohere, replicate, sys
 from typing import Any
 from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
 import traceback
+from functools import partial
 import dotenv
 import traceback
 import litellm
 from litellm import client, logging, exception_type, timeout, success_callback, failure_callback
 import random
+import asyncio
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_random_exponential,
+)  # for exponential backoff

 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
@@ -24,6 +31,7 @@ def get_optional_params(
     frequency_penalty = 0,
     logit_bias = {},
     user = "",
+    deployment_id = None
 ):
     optional_params = {}
     if functions != []:
@@ -50,27 +58,39 @@ def get_optional_params(
         optional_params["logit_bias"] = logit_bias
     if user != "":
         optional_params["user"] = user
+    if deployment_id != None:
+        optional_params["deployment_id"] = deployment_id
     return optional_params

 ####### COMPLETION ENDPOINTS ################
 #############################################
+async def acompletion(*args, **kwargs):
+    loop = asyncio.get_event_loop()
+    # Use a partial function to pass your keyword arguments
+    func = partial(completion, *args, **kwargs)
+    # Call the synchronous function using run_in_executor
+    return await loop.run_in_executor(None, func)
+
 @client
+@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry)) # retry call, turn this off by setting `litellm.retry = False`
 @timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def completion(
     model, messages, # required params
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
     functions=[], function_call="", # optional params
     temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
-    presence_penalty=0, frequency_penalty=0, logit_bias={}, user="",
+    presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
-    *, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False
+    *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False
 ):
     try:
         # check if user passed in any of the OpenAI optional params
         optional_params = get_optional_params(
             functions=functions, function_call=function_call,
             temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
-            presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user
+            presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id
         )
         if azure == True:
             # azure configs
@@ -247,7 +267,6 @@ def completion(
         ## Map to OpenAI Exception
         raise exception_type(model=model, original_exception=e)
-
 ### EMBEDDING ENDPOINTS ####################
 @client
 @timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
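The acompletion wrapper added above offloads the blocking completion() call to the default thread-pool executor, so a running event loop stays responsive while the request is in flight. A minimal usage sketch; the model name and message are illustrative:

import asyncio
from litellm import acompletion

async def main():
    # completion() runs in a worker thread via run_in_executor,
    # so other tasks on the loop keep running while we await
    response = await acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)

asyncio.run(main())

Inside a coroutine, asyncio.get_running_loop() has been the preferred way to obtain the loop since Python 3.7; asyncio.get_event_loop() still works here but behaves more strictly in later versions.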

build/lib/litellm/timeout.py

@@ -37,26 +37,29 @@ def timeout(
             thread = _LoopWrapper()
             thread.start()
             future = asyncio.run_coroutine_threadsafe(async_func(), thread.loop)
+            local_timeout_duration = timeout_duration
+            if "force_timeout" in kwargs:
+                local_timeout_duration = kwargs["force_timeout"]
             try:
-                local_timeout_duration = timeout_duration
-                if "force_timeout" in kwargs:
-                    local_timeout_duration = kwargs["force_timeout"]
                 result = future.result(timeout=local_timeout_duration)
             except futures.TimeoutError:
                 thread.stop_loop()
-                raise exception_to_raise()
+                raise exception_to_raise(f"A timeout error occurred. The function call took longer than {local_timeout_duration} second(s).")
             thread.stop_loop()
             return result

         @wraps(func)
         async def async_wrapper(*args, **kwargs):
+            local_timeout_duration = timeout_duration
+            if "force_timeout" in kwargs:
+                local_timeout_duration = kwargs["force_timeout"]
             try:
                 value = await asyncio.wait_for(
-                    func(*args, **kwargs), timeout=timeout_duration
+                    func(*args, **kwargs), timeout=local_timeout_duration
                 )
                 return value
             except asyncio.TimeoutError:
-                raise exception_to_raise()
+                raise exception_to_raise(f"A timeout error occurred. The function call took longer than {local_timeout_duration} second(s).")

         if iscoroutinefunction(func):
             return async_wrapper
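With local_timeout_duration now computed before the try block, both wrappers can report the effective limit in the timeout message, including a per-call force_timeout override. A usage sketch, assuming only the force_timeout keyword already documented in the decorator comment:

from litellm import completion

# Override the decorator's 60s default for this single call;
# the wrapper reads force_timeout out of kwargs.
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
    force_timeout=10,
)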

Binary file not shown.

Binary file not shown.

BIN dist/litellm-0.1.229-py3-none-any.whl (vendored, new file; binary not shown)

BIN dist/litellm-0.1.229.tar.gz (vendored, new file; binary not shown)

BIN dist/litellm-0.1.2291-py3-none-any.whl (vendored, new file; binary not shown)

BIN dist/litellm-0.1.2291.tar.gz (vendored, new file; binary not shown)

litellm.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: litellm
-Version: 0.1.226
+Version: 0.1.2291
 Summary: Library to easily interface with LLM API providers
 Author: BerriAI
 License-File: LICENSE

litellm.egg-info/requires.txt

@@ -5,3 +5,4 @@ anthropic
 replicate
 python-dotenv
 openai[datalib]
+tenacity
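tenacity enters the dependency list for the new retry decorator. For context, wait_random_exponential(min=1, max=60) sleeps a random interval whose upper bound grows exponentially per attempt, capped at 60 seconds, and stop_after_attempt(2) permits a single retry. A standalone sketch of the same decorator arguments on a made-up function:

import random
from tenacity import retry, stop_after_attempt, wait_random_exponential

@retry(wait=wait_random_exponential(min=1, max=60),
       stop=stop_after_attempt(2), reraise=True)
def flaky():
    # Illustrative transient failure (~50% of calls), triggering one retry;
    # with reraise=True, a second failure re-raises ConnectionError.
    if random.random() < 0.5:
        raise ConnectionError("transient failure")
    return "ok"

print(flaky())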

litellm/__init__.py

@@ -41,7 +41,6 @@ model_list = open_ai_chat_completion_models + open_ai_text_completion_models + c
 open_ai_embedding_models = [
   'text-embedding-ada-002'
 ]
-
 from .timeout import timeout
 from .utils import client, logging, exception_type # Import all the symbols from main.py
 from .main import * # Import all the symbols from main.py

litellm/main.py

@@ -64,6 +64,15 @@ def get_optional_params(
 ####### COMPLETION ENDPOINTS ################
 #############################################
+async def acompletion(*args, **kwargs):
+    loop = asyncio.get_event_loop()
+    # Use a partial function to pass your keyword arguments
+    func = partial(completion, *args, **kwargs)
+    # Call the synchronous function using run_in_executor
+    return await loop.run_in_executor(None, func)
+
 @client
 @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry)) # retry call, turn this off by setting `litellm.retry = False`
 @timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
@@ -258,16 +267,6 @@ def completion(
         ## Map to OpenAI Exception
         raise exception_type(model=model, original_exception=e)
-
-async def acompletion(*args, **kwargs):
-    loop = asyncio.get_event_loop()
-    # Use a partial function to pass your keyword arguments
-    func = partial(completion, *args, **kwargs)
-    # Call the synchronous function using run_in_executor
-    return await loop.run_in_executor(None, func)
-
 ### EMBEDDING ENDPOINTS ####################
 @client
 @timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`

setup.py

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 setup(
     name='litellm',
-    version='0.1.228',
+    version='0.1.2291',
     description='Library to easily interface with LLM API providers',
     author='BerriAI',
     packages=[