fix(azure.py): adding support for aiohttp calls on azure + openai

Krrish Dholakia 2023-11-09 10:40:26 -08:00
parent 2c67bda137
commit 86ef2a02f7
7 changed files with 93 additions and 30 deletions
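For context, this is roughly how the new async path is meant to be exercised from user code — a minimal sketch; the model names and message content are placeholders taken from the tests below:

import asyncio
import litellm

async def main():
    # With this change, acompletion() drives OpenAI and Azure calls over aiohttp
    # instead of pushing the blocking client onto a thread-pool executor.
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",  # or an "azure/<deployment>" model
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response)

asyncio.run(main())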


@@ -4,9 +4,7 @@ from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message
 from typing import Callable, Optional
 from litellm import OpenAIConfig
+import aiohttp
-# This file just has the openai config classes.
-# For implementation check out completion() in main.py

 class AzureOpenAIError(Exception):
     def __init__(self, status_code, message):
@@ -116,6 +114,7 @@ class AzureChatCompletion(BaseLLM):
                optional_params,
                litellm_params,
                logger_fn,
+               acompletion: bool = False,
                headers: Optional[dict]=None):
         super().completion()
         exception_mapping_worked = False
@@ -157,6 +156,8 @@ class AzureChatCompletion(BaseLLM):
                 ## RESPONSE OBJECT
                 return response.iter_lines()
+            elif acompletion is True:
+                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
             else:
                 response = self._client_session.post(
                     url=api_base,
@@ -178,6 +179,18 @@ class AzureChatCompletion(BaseLLM):
             import traceback
             raise AzureOpenAIError(status_code=500, message=traceback.format_exc())

+    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(api_base, json=data, headers=headers) as response:
+                response_json = await response.json()
+
+                if response.status != 200:
+                    raise AzureOpenAIError(status_code=response.status, message=response.text)
+
+                ## RESPONSE OBJECT
+                return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+
     def embedding(self,
                 model: str,
                 input: list,
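For reference, a minimal standalone sketch of the aiohttp request/response pattern the new acompletion() helper follows — the function name and endpoint here are hypothetical, not part of litellm:

import asyncio
import aiohttp

async def post_json(url: str, data: dict, headers: dict) -> dict:
    # Open a short-lived session, POST the payload as JSON, and parse the reply.
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data, headers=headers) as response:
            if response.status != 200:
                raise RuntimeError(f"request failed with status {response.status}")
            return await response.json()

# Example (hypothetical endpoint):
# asyncio.run(post_json("https://example.com/v1/chat/completions", {"messages": []}, {}))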


@@ -3,9 +3,8 @@ import types, requests
 from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message
 from typing import Callable, Optional
+import aiohttp
-# This file just has the openai config classes.
-# For implementation check out completion() in main.py

 class OpenAIError(Exception):
     def __init__(self, status_code, message):
@@ -184,22 +183,24 @@ class OpenAIChatCompletion(BaseLLM):
             OpenAIError(status_code=500, message="Invalid response object.")

     def completion(self,
-                model: Optional[str]=None,
-                messages: Optional[list]=None,
-                model_response: Optional[ModelResponse]=None,
+                model_response: ModelResponse,
+                model: Optional[str]=None,
+                messages: Optional[list]=None,
                 print_verbose: Optional[Callable]=None,
                 api_key: Optional[str]=None,
                 api_base: Optional[str]=None,
-                logging_obj=None,
-                optional_params=None,
-                litellm_params=None,
-                logger_fn=None,
-                headers: Optional[dict]=None):
+                acompletion: bool = False,
+                logging_obj=None,
+                optional_params=None,
+                litellm_params=None,
+                logger_fn=None,
+                headers: Optional[dict]=None):
         super().completion()
         exception_mapping_worked = False
         try:
             if headers is None:
                 headers = self.validate_environment(api_key=api_key)
+            api_base = f"{api_base}/chat/completions"
             if model is None or messages is None:
                 raise OpenAIError(status_code=422, message=f"Missing model or messages")
@@ -214,13 +215,13 @@ class OpenAIChatCompletion(BaseLLM):
             logging_obj.pre_call(
                 input=messages,
                 api_key=api_key,
-                additional_args={"headers": headers, "api_base": api_base},
+                additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "data": data},
             )
             try:
                 if "stream" in optional_params and optional_params["stream"] == True:
                     response = self._client_session.post(
-                        url=f"{api_base}/chat/completions",
+                        url=api_base,
                         json=data,
                         headers=headers,
                         stream=optional_params["stream"]
@@ -230,9 +231,11 @@ class OpenAIChatCompletion(BaseLLM):
                     ## RESPONSE OBJECT
                     return response.iter_lines()
+                elif acompletion is True:
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
                 else:
                     response = self._client_session.post(
-                        url=f"{api_base}/chat/completions",
+                        url=api_base,
                         json=data,
                         headers=headers,
                     )
@@ -270,6 +273,17 @@ class OpenAIChatCompletion(BaseLLM):
             import traceback
             raise OpenAIError(status_code=500, message=traceback.format_exc())

+    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(api_base, json=data, headers=headers) as response:
+                response_json = await response.json()
+
+                if response.status != 200:
+                    raise OpenAIError(status_code=response.status, message=response.text)
+
+                ## RESPONSE OBJECT
+                return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+
     def embedding(self,
                 model: str,
                 input: list,
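A toy sketch (hypothetical class and method names, not litellm's) of the pattern both handlers now share: when async execution is requested, the synchronous completion() entry point returns the coroutine from its async helper without awaiting it, and the async caller awaits it:

import asyncio

class ToyHandler:
    def completion(self, acompletion: bool = False):
        if acompletion:
            # Return the coroutine unawaited; the async caller awaits it.
            return self.acompletion()
        return "sync result"

    async def acompletion(self):
        await asyncio.sleep(0)  # stands in for the aiohttp request
        return "async result"

async def main():
    handler = ToyHandler()
    print(await handler.completion(acompletion=True))

asyncio.run(main())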


@@ -7,7 +7,7 @@
 #
 # Thank you ! We ❤️ you! - Krrish & Ishaan

-import os, openai, sys, json, inspect
+import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
@@ -77,7 +77,7 @@ openai_text_completions = OpenAITextCompletion()
 azure_chat_completions = AzureChatCompletion()

 ####### COMPLETION ENDPOINTS ################

-async def acompletion(*args, **kwargs):
+async def acompletion(model: str, messages: List = [], *args, **kwargs):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -118,16 +118,31 @@ async def acompletion(*args, **kwargs):
     """
     loop = asyncio.get_event_loop()
-    # Use a partial function to pass your keyword arguments
+    ### INITIALIZE LOGGING OBJECT ###
+    kwargs["litellm_call_id"] = str(uuid.uuid4())
+    start_time = datetime.datetime.now()
+    logging_obj = Logging(model=model, messages=messages, stream=kwargs.get("stream", False), litellm_call_id=kwargs["litellm_call_id"], function_id=kwargs.get("id", None), call_type="completion", start_time=start_time)
+
+    ### PASS ARGS TO COMPLETION ###
+    kwargs["litellm_logging_obj"] = logging_obj
     kwargs["acompletion"] = True
+    kwargs["model"] = model
+    kwargs["messages"] = messages
+    # Use a partial function to pass your keyword arguments
     func = partial(completion, *args, **kwargs)
     # Add the context to the function
     ctx = contextvars.copy_context()
     func_with_context = partial(ctx.run, func)
-    # Call the synchronous function using run_in_executor
-    response = await loop.run_in_executor(None, func_with_context)
+
+    _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+
+    if custom_llm_provider == "openai" or custom_llm_provider == "azure": # currently implemented aiohttp calls for just azure and openai, soon all.
+        # Await normally
+        response = await completion(*args, **kwargs)
+    else:
+        # Call the synchronous function using run_in_executor
+        response = await loop.run_in_executor(None, func_with_context)
     if kwargs.get("stream", False): # return an async generator
         # do not change this
         # for stream = True, always return an async generator
@@ -137,6 +152,16 @@ async def acompletion(*args, **kwargs):
             async for line in response
         )
     else:
+        end_time = datetime.datetime.now()
+        # [OPTIONAL] ADD TO CACHE
+        if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
+            litellm.cache.add_cache(response, *args, **kwargs)
+
+        # LOG SUCCESS
+        logging_obj.success_handler(response, start_time, end_time)
+
+        # RETURN RESULT
+        response._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
         return response

 def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
@@ -268,6 +293,7 @@ def completion(
     final_prompt_value = kwargs.get("final_prompt_value", None)
     bos_token = kwargs.get("bos_token", None)
     eos_token = kwargs.get("eos_token", None)
+    acompletion = kwargs.get("acompletion", False)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
     litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
@@ -409,6 +435,7 @@ def completion(
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 logging_obj=logging,
+                acompletion=acompletion
             )
             if "stream" in optional_params and optional_params["stream"] == True:
                 response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
@@ -472,6 +499,7 @@ def completion(
                 print_verbose=print_verbose,
                 api_key=api_key,
                 api_base=api_base,
+                acompletion=acompletion,
                 logging_obj=logging,
                 optional_params=optional_params,
                 litellm_params=litellm_params,
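The routing added to acompletion() boils down to the following shape — a simplified sketch with hypothetical names; the real code uses get_llm_provider() and litellm's own completion():

import asyncio
import contextvars
from functools import partial

async def route_call(provider: str, call, *args, **kwargs):
    # Providers with native aiohttp support are awaited directly (their sync
    # entry point returns a coroutine when acompletion=True); everything else
    # runs in the default thread-pool executor so the event loop is not blocked.
    if provider in ("openai", "azure"):
        return await call(*args, **kwargs)
    loop = asyncio.get_event_loop()
    ctx = contextvars.copy_context()
    func = partial(ctx.run, partial(call, *args, **kwargs))
    return await loop.run_in_executor(None, func)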


@@ -25,16 +25,19 @@ def test_sync_response():

 def test_async_response():
     import asyncio
     async def test_get_response():
+        litellm.set_verbose = True
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
             response = await acompletion(model="gpt-3.5-turbo", messages=messages)
+            print(f"response: {response}")
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
+            print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")

     response = asyncio.run(test_get_response())
     # print(response)
+test_async_response()

 def test_get_response_streaming():
     import asyncio
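For reference, the same smoke test could also exercise both provider paths concurrently on a single event loop — illustrative only; the model names are taken from the test above:

import asyncio
from litellm import acompletion

async def smoke_test():
    messages = [{"role": "user", "content": "Hello, how are you?"}]
    # Kick off the OpenAI and Azure calls together and wait for both.
    return await asyncio.gather(
        acompletion(model="gpt-3.5-turbo", messages=messages),
        acompletion(model="azure/chatgpt-v-2", messages=messages),
    )

# responses = asyncio.run(smoke_test())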


@@ -1,6 +1,7 @@
 from litellm import completion, stream_chunk_builder
 import litellm
 import os, dotenv
+import pytest
 dotenv.load_dotenv()

 user_message = "What is the current weather in Boston?"
@@ -23,6 +24,7 @@ function_schema = {
     },
 }

+@pytest.mark.skip
 def test_stream_chunk_builder():
     litellm.set_verbose = False
     litellm.api_key = os.environ["OPENAI_API_KEY"]


@@ -763,7 +763,8 @@ class Logging:
                         )
                     elif isinstance(callback, CustomLogger): # custom logger class
                         callback.log_failure_event(
-                            model=self.model,
+                            start_time=start_time,
+                            end_time=end_time,
                             messages=self.messages,
                             kwargs=self.model_call_details,
                         )
@@ -908,7 +909,7 @@ def client(original_function):
     def wrapper(*args, **kwargs):
         start_time = datetime.datetime.now()
         result = None
-        logging_obj = None
+        logging_obj = kwargs.get("litellm_logging_obj", None)

         # only set litellm_call_id if its not in kwargs
         if "litellm_call_id" not in kwargs:
@@ -919,7 +920,8 @@ def client(original_function):
             raise ValueError("model param not passed in.")

         try:
-            logging_obj = function_setup(start_time, *args, **kwargs)
+            if logging_obj is None:
+                logging_obj = function_setup(start_time, *args, **kwargs)
             kwargs["litellm_logging_obj"] = logging_obj

             # [OPTIONAL] CHECK BUDGET
@@ -956,7 +958,8 @@ def client(original_function):
                     return litellm.stream_chunk_builder(chunks)
                 else:
                     return result
+            elif "acompletion" in kwargs and kwargs["acompletion"] == True:
+                return result

             # [OPTIONAL] ADD TO CACHE
             if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
@@ -1014,7 +1017,6 @@ def client(original_function):
             raise e
     return wrapper

 ####### USAGE CALCULATOR ################
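The wrapper change reduces to this shape — a simplified sketch, not the actual client() decorator: when the call originated from acompletion(), the wrapped function has returned an unawaited coroutine, so the wrapper hands it straight back instead of running synchronous post-processing on it:

import functools

def client(original_function):
    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        result = original_function(*args, **kwargs)
        if kwargs.get("acompletion") is True:
            # result is a coroutine here; the async caller awaits it and
            # acompletion() itself handles caching and success logging.
            return result
        # ... synchronous post-processing (cache, logging) would run here ...
        return result
    return wrapper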


@@ -17,6 +17,7 @@ click = "*"
 jinja2 = "^3.1.2"
 certifi = "^2023.7.22"
 appdirs = "^1.4.4"
+aiohttp = "*"

 [tool.poetry.scripts]
 litellm = 'litellm:run_server'