fix(azure.py): adding support for aiohttp calls on azure + openai

Author: Krrish Dholakia, 2023-11-09 10:40:26 -08:00
parent bbc2cb43aa
commit 1d46891ceb
7 changed files with 93 additions and 30 deletions

@@ -7,7 +7,7 @@
 #
 # Thank you ! We ❤️ you! - Krrish & Ishaan
-import os, openai, sys, json, inspect
+import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
@@ -77,7 +77,7 @@ openai_text_completions = OpenAITextCompletion()
 azure_chat_completions = AzureChatCompletion()
 ####### COMPLETION ENDPOINTS ################
-async def acompletion(*args, **kwargs):
+async def acompletion(model: str, messages: List = [], *args, **kwargs):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
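
With this change, model and messages are explicit parameters of acompletion rather than values fished out of *args/**kwargs. A minimal usage sketch (model name and prompt are illustrative, not from the commit):

    import asyncio
    import litellm

    async def main():
        # model/messages are now named parameters of acompletion
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",  # illustrative model name
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )
        print(response)

    asyncio.run(main())
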
@@ -118,16 +118,31 @@ async def acompletion(*args, **kwargs):
     """
     loop = asyncio.get_event_loop()
-    # Use a partial function to pass your keyword arguments
+    ### INITIALIZE LOGGING OBJECT ###
+    kwargs["litellm_call_id"] = str(uuid.uuid4())
+    start_time = datetime.datetime.now()
+    logging_obj = Logging(model=model, messages=messages, stream=kwargs.get("stream", False), litellm_call_id=kwargs["litellm_call_id"], function_id=kwargs.get("id", None), call_type="completion", start_time=start_time)
+
+    ### PASS ARGS TO COMPLETION ###
+    kwargs["litellm_logging_obj"] = logging_obj
+    kwargs["acompletion"] = True
+    kwargs["model"] = model
+    kwargs["messages"] = messages
+    # Use a partial function to pass your keyword arguments
     func = partial(completion, *args, **kwargs)
     # Add the context to the function
     ctx = contextvars.copy_context()
     func_with_context = partial(ctx.run, func)
-    # Call the synchronous function using run_in_executor
-    response = await loop.run_in_executor(None, func_with_context)
+    _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+
+    if custom_llm_provider == "openai" or custom_llm_provider == "azure": # currently implemented aiohttp calls for just azure and openai, soon all.
+        # Await normally
+        response = await completion(*args, **kwargs)
+    else:
+        # Call the synchronous function using run_in_executor
+        response = await loop.run_in_executor(None, func_with_context)
     if kwargs.get("stream", False): # return an async generator
         # do not change this
         # for stream = True, always return an async generator
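
The branch above works because, with kwargs["acompletion"] = True, the synchronous completion() entrypoint hands back an awaitable for providers that have a native async (aiohttp) path, while every other provider still runs through the thread-pool executor. A stripped-down sketch of that dispatch; the handler and call names here are hypothetical stand-ins, not litellm internals:

    # hypothetical provider handler illustrating the effect of the acompletion flag
    async def async_http_call(**kwargs):
        ...  # non-blocking, aiohttp-based request

    def sync_http_call(**kwargs):
        ...  # blocking request

    def provider_completion(acompletion: bool = False, **kwargs):
        if acompletion:
            # returns a coroutine; acompletion() awaits it on the running loop
            return async_http_call(**kwargs)
        return sync_http_call(**kwargs)
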
@@ -137,6 +152,16 @@
             async for line in response
         )
     else:
+        end_time = datetime.datetime.now()
+        # [OPTIONAL] ADD TO CACHE
+        if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
+            litellm.cache.add_cache(response, *args, **kwargs)
+
+        # LOG SUCCESS
+        logging_obj.success_handler(response, start_time, end_time)
+
+        # RETURN RESULT
+        response._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
         return response

 def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
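
Note the stream=True contract established above: acompletion always returns an async generator in that case, and only the non-streaming branch attaches _response_ms and fires the success handler. A consumer-side sketch (model name and prompt illustrative):

    async def stream_it():
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Tell me a joke"}],
            stream=True,  # response is an async generator of chunks
        )
        async for chunk in response:
            print(chunk)
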
@@ -268,6 +293,7 @@ def completion(
     final_prompt_value = kwargs.get("final_prompt_value", None)
     bos_token = kwargs.get("bos_token", None)
     eos_token = kwargs.get("eos_token", None)
+    acompletion = kwargs.get("acompletion", False)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
     litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
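
Listing "acompletion" in litellm_params keeps the flag internal: keys recognized as litellm-only are separated from the kwargs forwarded to the provider, so the flag never leaks into the upstream API payload. An illustration of that intent only (the key sets and variable names here are simplified, not the actual filtering code):

    kwargs = {"temperature": 0.2, "acompletion": True, "litellm_call_id": "abc"}
    litellm_keys = {"acompletion", "litellm_call_id", "metadata", "caching"}
    # litellm-internal keys are dropped before building the provider request
    provider_kwargs = {k: v for k, v in kwargs.items() if k not in litellm_keys}
    # provider_kwargs == {"temperature": 0.2}
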
@@ -409,6 +435,7 @@ def completion(
             litellm_params=litellm_params,
             logger_fn=logger_fn,
             logging_obj=logging,
+            acompletion=acompletion
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
@@ -472,6 +499,7 @@ def completion(
             print_verbose=print_verbose,
             api_key=api_key,
             api_base=api_base,
+            acompletion=acompletion,
             logging_obj=logging,
             optional_params=optional_params,
             litellm_params=litellm_params,
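
The handler receiving acompletion here is what ultimately issues the non-blocking HTTP request. The commit title points at aiohttp for the azure + openai paths; a minimal sketch of such a call, assuming a plain JSON POST (the endpoint shape and auth header are assumptions, not the exact azure.py implementation):

    import aiohttp

    async def async_chat_completion(api_base: str, api_key: str, data: dict) -> dict:
        # one session per call keeps the sketch self-contained; a real handler
        # would likely reuse a session across requests
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{api_base}/chat/completions",
                json=data,
                headers={"Authorization": f"Bearer {api_key}"},  # assumed auth scheme
            ) as resp:
                resp.raise_for_status()
                return await resp.json()
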