fixed acompletion import bug

Krrish Dholakia 2023-08-03 14:06:32 -07:00
parent 82a75a9d92
commit 123de53475
18 changed files with 53 additions and 24 deletions


@@ -2,11 +2,18 @@ import os, openai, cohere, replicate, sys
from typing import Any
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
import traceback
from functools import partial
import dotenv
import traceback
import litellm
from litellm import client, logging, exception_type, timeout, success_callback, failure_callback
import random
import asyncio
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
) # for exponential backoff
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
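
The import fix named in the commit title is visible in this hunk: partial and asyncio are what the new acompletion wrapper below relies on, so without them the async path would fail with a NameError at call time. A minimal, self-contained sketch of the same offloading pattern (names are illustrative, not litellm's):

import asyncio
from functools import partial

def slow_sync_call(x, factor=2):
    # stand-in for a blocking SDK call such as completion()
    return x * factor

async def async_wrapper(*args, **kwargs):
    loop = asyncio.get_event_loop()
    # bind positional/keyword args so the executor can invoke a zero-arg callable
    func = partial(slow_sync_call, *args, **kwargs)
    # run the blocking call in a thread instead of blocking the event loop
    return await loop.run_in_executor(None, func)

print(asyncio.run(async_wrapper(3, factor=10)))  # prints 30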
@@ -24,6 +31,7 @@ def get_optional_params(
    frequency_penalty = 0,
    logit_bias = {},
    user = "",
    deployment_id = None
):
    optional_params = {}
    if functions != []:
@@ -50,27 +58,39 @@ def get_optional_params(
        optional_params["logit_bias"] = logit_bias
    if user != "":
        optional_params["user"] = user
    if deployment_id != None:
        optional_params["deployment_id"] = deployment_id
    return optional_params
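
The new deployment_id argument is threaded into the returned dict alongside the other non-default OpenAI params. A hedged sketch of the effect, assuming every other argument keeps its default and using a hypothetical deployment name:

params = get_optional_params(deployment_id="my-azure-gpt35-deployment")  # hypothetical name
print(params)  # expected: {'deployment_id': 'my-azure-gpt35-deployment'}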
####### COMPLETION ENDPOINTS ################
#############################################
async def acompletion(*args, **kwargs):
    loop = asyncio.get_event_loop()
    # Use a partial function to pass your keyword arguments
    func = partial(completion, *args, **kwargs)
    # Call the synchronous function using run_in_executor
    return await loop.run_in_executor(None, func)
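
acompletion simply offloads the synchronous completion call to the default thread executor, so callers await it like any other coroutine. A usage sketch, assuming acompletion is re-exported at the package level; model and message are illustrative:

import asyncio
from litellm import acompletion  # assumes package-level export

async def main():
    response = await acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)

asyncio.run(main())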
@client
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry)) # retry call, turn this off by setting `litellm.retry = False`
@timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
def completion(
    model, messages, # required params
    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
    functions=[], function_call="", # optional params
    temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
    presence_penalty=0, frequency_penalty=0, logit_bias={}, user="",
    presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
    # Optional liteLLM function params
    *, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False
    *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False
):
    try:
        # check if user passed in any of the OpenAI optional params
        optional_params = get_optional_params(
            functions=functions, function_call=function_call,
            temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
            presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user
            presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id
        )
        if azure == True:
            # azure configs
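
With deployment_id now forwarded through get_optional_params, an Azure-targeted call might look roughly like the following; the deployment name is hypothetical, and whatever Azure configuration the azure branch sets up past the point where this hunk cuts off is still required:

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello from Azure"}],
    azure=True,
    deployment_id="my-azure-gpt35-deployment",  # hypothetical Azure deployment name
)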
@@ -247,7 +267,6 @@ def completion(
        ## Map to OpenAI Exception
        raise exception_type(model=model, original_exception=e)
### EMBEDDING ENDPOINTS ####################
@client
@timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
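
The force_timeout escape hatch mentioned in the decorator comments pairs with the keyword of the same name in completion's signature above. A hedged example of raising the cap for a slow call, with an illustrative value:

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Summarize this very long document ..."}],
    force_timeout=120,  # override the 60s default set by @timeout(60)
)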