forked from phoenix/litellm-mirror
fix(azure.py): adding support for aiohttp calls on azure + openai
parent 2c67bda137
commit 86ef2a02f7

7 changed files with 93 additions and 30 deletions
@@ -7,7 +7,7 @@
 #
 # Thank you ! We ❤️ you! - Krrish & Ishaan

-import os, openai, sys, json, inspect
+import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
@@ -77,7 +77,7 @@ openai_text_completions = OpenAITextCompletion()
 azure_chat_completions = AzureChatCompletion()
 ####### COMPLETION ENDPOINTS ################

-async def acompletion(*args, **kwargs):
+async def acompletion(model: str, messages: List = [], *args, **kwargs):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)

@@ -118,16 +118,31 @@ async def acompletion(*args, **kwargs):
     """
     loop = asyncio.get_event_loop()

-    # Use a partial function to pass your keyword arguments
+    ### INITIALIZE LOGGING OBJECT ###
+    kwargs["litellm_call_id"] = str(uuid.uuid4())
+    start_time = datetime.datetime.now()
+    logging_obj = Logging(model=model, messages=messages, stream=kwargs.get("stream", False), litellm_call_id=kwargs["litellm_call_id"], function_id=kwargs.get("id", None), call_type="completion", start_time=start_time)
+
+    ### PASS ARGS TO COMPLETION ###
+    kwargs["litellm_logging_obj"] = logging_obj
+    kwargs["acompletion"] = True
+    kwargs["model"] = model
+    kwargs["messages"] = messages
+    # Use a partial function to pass your keyword arguments
     func = partial(completion, *args, **kwargs)

     # Add the context to the function
     ctx = contextvars.copy_context()
     func_with_context = partial(ctx.run, func)

-    # Call the synchronous function using run_in_executor
-    response = await loop.run_in_executor(None, func_with_context)
+    _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+
+    if custom_llm_provider == "openai" or custom_llm_provider == "azure": # currently implemented aiohttp calls for just azure and openai, soon all.
+        # Await normally
+        response = await completion(*args, **kwargs)
+    else:
+        # Call the synchronous function using run_in_executor
+        response = await loop.run_in_executor(None, func_with_context)
     if kwargs.get("stream", False): # return an async generator
         # do not change this
         # for stream = True, always return an async generator
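Note: for providers other than openai and azure, acompletion still falls back to running the synchronous completion() in the default thread executor while preserving contextvars. A minimal, self-contained sketch of that fallback pattern (sync_completion and run_blocking_call below are illustrative stand-ins, not litellm code):

import asyncio
import contextvars
from functools import partial

def sync_completion(**kwargs):
    # stand-in for a blocking provider call
    return {"echo": kwargs}

async def run_blocking_call(**kwargs):
    loop = asyncio.get_event_loop()
    # bind the keyword arguments to the blocking function
    func = partial(sync_completion, **kwargs)
    # copy the current context so contextvars survive the hop to the executor thread
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    # run the blocking call in the default executor without blocking the event loop
    return await loop.run_in_executor(None, func_with_context)

print(asyncio.run(run_blocking_call(model="command-nightly")))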
@@ -137,6 +152,16 @@ async def acompletion(*args, **kwargs):
             async for line in response
         )
     else:
+        end_time = datetime.datetime.now()
+        # [OPTIONAL] ADD TO CACHE
+        if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
+            litellm.cache.add_cache(response, *args, **kwargs)
+
+        # LOG SUCCESS
+        logging_obj.success_handler(response, start_time, end_time)
+        # RETURN RESULT
+        response._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
+
         return response

 def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
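With the new signature, callers pass model and messages, await the result, and (on the non-streaming path) get the request latency attached as _response_ms before the response is returned. A usage sketch; the model name and message content are placeholders:

import asyncio
import litellm

async def main():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)
    # latency in ms, set by acompletion on the non-streaming path
    print(getattr(response, "_response_ms", None))

asyncio.run(main())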
@@ -268,6 +293,7 @@ def completion(
     final_prompt_value = kwargs.get("final_prompt_value", None)
     bos_token = kwargs.get("bos_token", None)
     eos_token = kwargs.get("eos_token", None)
+    acompletion = kwargs.get("acompletion", False)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
     litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
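The two lists above separate OpenAI-style parameters from litellm-internal ones, presumably so that kwargs outside both can be handled as provider-specific extras. A rough illustration of that kind of split; split_kwargs is a hypothetical helper, not litellm code, and the lists are truncated here:

openai_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens"]
litellm_params = ["metadata", "acompletion", "caching", "litellm_logging_obj", "litellm_call_id"]

def split_kwargs(kwargs):
    # illustrative: anything not recognized as an openai or litellm param is treated as provider-specific
    provider_specific = {
        k: v for k, v in kwargs.items()
        if k not in openai_params and k not in litellm_params
    }
    return provider_specific

print(split_kwargs({"temperature": 0.2, "acompletion": True, "top_k": 40}))
# -> {'top_k': 40}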
@@ -409,6 +435,7 @@ def completion(
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 logging_obj=logging,
+                acompletion=acompletion
             )
             if "stream" in optional_params and optional_params["stream"] == True:
                 response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
@@ -472,6 +499,7 @@ def completion(
                 print_verbose=print_verbose,
                 api_key=api_key,
                 api_base=api_base,
+                acompletion=acompletion,
                 logging_obj=logging,
                 optional_params=optional_params,
                 litellm_params=litellm_params,
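The acompletion flag is threaded into the provider handlers (the openai and azure paths in this commit) so they can switch from the blocking HTTP client to an aiohttp call that the caller awaits. A minimal sketch of that general pattern, assuming a chat-completions style endpoint; the URL path, auth header, payload shape, and function names are illustrative rather than the actual azure.py/openai handler code:

import aiohttp

async def async_chat_completion(api_base: str, api_key: str, data: dict) -> dict:
    # illustrative async path: POST the request with aiohttp instead of a blocking client
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{api_base}/chat/completions", json=data, headers=headers) as resp:
            resp.raise_for_status()
            return await resp.json()

def completion_handler(acompletion: bool, api_base: str, api_key: str, data: dict):
    if acompletion:
        # return the coroutine so the caller can `await` it, mirroring `response = await completion(...)` above
        return async_chat_completion(api_base, api_key, data)
    # ... otherwise fall through to the existing synchronous request path
    raise NotImplementedError("sync path omitted in this sketch")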