fix(azure.py): add support for aiohttp calls on azure + openai

Krrish Dholakia 2023-11-09 10:40:26 -08:00
parent 2c67bda137
commit 86ef2a02f7
7 changed files with 93 additions and 30 deletions

View file

@@ -4,9 +4,7 @@ from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message
from typing import Callable, Optional
from litellm import OpenAIConfig
# This file just has the openai config classes.
# For implementation check out completion() in main.py
import aiohttp
class AzureOpenAIError(Exception):
def __init__(self, status_code, message):
@@ -116,6 +114,7 @@ class AzureChatCompletion(BaseLLM):
optional_params,
litellm_params,
logger_fn,
acompletion: bool = False,
headers: Optional[dict]=None):
super().completion()
exception_mapping_worked = False
@@ -157,6 +156,8 @@ class AzureChatCompletion(BaseLLM):
## RESPONSE OBJECT
return response.iter_lines()
elif acompletion is True:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
else:
response = self._client_session.post(
url=api_base,
@@ -178,6 +179,18 @@ class AzureChatCompletion(BaseLLM):
import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
async with aiohttp.ClientSession() as session:
async with session.post(api_base, json=data, headers=headers) as response:
response_json = await response.json()
if response.status != 200:
raise AzureOpenAIError(status_code=response.status, message=await response.text())
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
def embedding(self,
model: str,
input: list,
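
The new acompletion helper above is a thin wrapper over a single aiohttp round trip. Below is a minimal standalone sketch of that pattern, assuming placeholder endpoint, payload, and header values (not real configuration); the error body is read with await response.text() because aiohttp exposes it as a coroutine.

import asyncio
import aiohttp

async def post_chat_completion(api_base: str, data: dict, headers: dict) -> dict:
    # One short-lived session per call, mirroring the acompletion method above.
    async with aiohttp.ClientSession() as session:
        async with session.post(api_base, json=data, headers=headers) as response:
            if response.status != 200:
                # The response body accessor is a coroutine in aiohttp, so await it.
                raise RuntimeError(f"{response.status}: {await response.text()}")
            return await response.json()

# Hypothetical usage; URL, payload, and key below are placeholders.
# asyncio.run(post_chat_completion(
#     "https://<resource>.openai.azure.com/openai/deployments/<deployment>/chat/completions?api-version=2023-07-01-preview",
#     {"messages": [{"role": "user", "content": "Hello"}]},
#     {"api-key": "<key>"},
# ))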

View file

@@ -3,9 +3,8 @@ import types, requests
from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message
from typing import Callable, Optional
import aiohttp
# This file just has the openai config classes.
# For implementation check out completion() in main.py
class OpenAIError(Exception):
def __init__(self, status_code, message):
@@ -184,22 +183,24 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(status_code=500, message="Invalid response object.")
def completion(self,
model: Optional[str]=None,
messages: Optional[list]=None,
model_response: Optional[ModelResponse]=None,
print_verbose: Optional[Callable]=None,
api_key: Optional[str]=None,
api_base: Optional[str]=None,
logging_obj=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict]=None):
model_response: ModelResponse,
model: Optional[str]=None,
messages: Optional[list]=None,
print_verbose: Optional[Callable]=None,
api_key: Optional[str]=None,
api_base: Optional[str]=None,
acompletion: bool = False,
logging_obj=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict]=None):
super().completion()
exception_mapping_worked = False
try:
if headers is None:
headers = self.validate_environment(api_key=api_key)
api_base = f"{api_base}/chat/completions"
if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages")
@@ -214,13 +215,13 @@ class OpenAIChatCompletion(BaseLLM):
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base},
additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "data": data},
)
try:
if "stream" in optional_params and optional_params["stream"] == True:
response = self._client_session.post(
url=f"{api_base}/chat/completions",
url=api_base,
json=data,
headers=headers,
stream=optional_params["stream"]
@@ -230,9 +231,11 @@ class OpenAIChatCompletion(BaseLLM):
## RESPONSE OBJECT
return response.iter_lines()
elif acompletion is True:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
else:
response = self._client_session.post(
url=f"{api_base}/chat/completions",
url=api_base,
json=data,
headers=headers,
)
@@ -270,6 +273,17 @@ class OpenAIChatCompletion(BaseLLM):
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
async with aiohttp.ClientSession() as session:
async with session.post(api_base, json=data, headers=headers) as response:
response_json = await response.json()
if response.status != 200:
raise OpenAIError(status_code=response.status, message=await response.text())
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
def embedding(self,
model: str,
input: list,
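
The design shared by both handlers is that the synchronous completion entry point can now hand back an un-awaited coroutine when acompletion=True, and the async caller awaits it. A stripped-down sketch of that routing idea follows; the class and field names are illustrative, not the library's API.

import asyncio

class ChatHandler:
    def completion(self, data: dict, acompletion: bool = False):
        # Sync path returns a finished response; async path hands back a coroutine.
        if acompletion:
            return self.acompletion(data)
        return {"object": "chat.completion", "echo": data}

    async def acompletion(self, data: dict) -> dict:
        await asyncio.sleep(0)  # stand-in for the aiohttp round trip
        return {"object": "chat.completion", "echo": data}

async def main():
    handler = ChatHandler()
    # completion() is called synchronously; the coroutine it returns is then awaited.
    response = await handler.completion({"messages": []}, acompletion=True)
    print(response)

asyncio.run(main())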

View file

@@ -7,7 +7,7 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import os, openai, sys, json, inspect
import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
@@ -77,7 +77,7 @@ openai_text_completions = OpenAITextCompletion()
azure_chat_completions = AzureChatCompletion()
####### COMPLETION ENDPOINTS ################
async def acompletion(*args, **kwargs):
async def acompletion(model: str, messages: List = [], *args, **kwargs):
"""
Asynchronously executes a litellm.completion() call for any of litellm's supported LLMs (e.g., gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -118,16 +118,31 @@ async def acompletion(*args, **kwargs):
"""
loop = asyncio.get_event_loop()
# Use a partial function to pass your keyword arguments
### INITIALIZE LOGGING OBJECT ###
kwargs["litellm_call_id"] = str(uuid.uuid4())
start_time = datetime.datetime.now()
logging_obj = Logging(model=model, messages=messages, stream=kwargs.get("stream", False), litellm_call_id=kwargs["litellm_call_id"], function_id=kwargs.get("id", None), call_type="completion", start_time=start_time)
### PASS ARGS TO COMPLETION ###
kwargs["litellm_logging_obj"] = logging_obj
kwargs["acompletion"] = True
kwargs["model"] = model
kwargs["messages"] = messages
# Use a partial function to pass your keyword arguments
func = partial(completion, *args, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
if custom_llm_provider == "openai" or custom_llm_provider == "azure": # aiohttp calls are currently implemented only for azure and openai; other providers fall back to the executor below.
# Await normally
response = await completion(*args, **kwargs)
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
if kwargs.get("stream", False): # return an async generator
# do not change this
# for stream = True, always return an async generator
@@ -137,6 +152,16 @@ async def acompletion(*args, **kwargs):
async for line in response
)
else:
end_time = datetime.datetime.now()
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(response, *args, **kwargs)
# LOG SUCCESS
logging_obj.success_handler(response, start_time, end_time)
# RETURN RESULT
response._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
return response
def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
@@ -268,6 +293,7 @@ def completion(
final_prompt_value = kwargs.get("final_prompt_value", None)
bos_token = kwargs.get("bos_token", None)
eos_token = kwargs.get("eos_token", None)
acompletion = kwargs.get("acompletion", False)
######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
@@ -409,6 +435,7 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
logging_obj=logging,
acompletion=acompletion
)
if "stream" in optional_params and optional_params["stream"] == True:
response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
@@ -472,6 +499,7 @@ def completion(
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
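
From the caller's perspective the acompletion surface does not change; for the openai and azure providers the call now runs natively on aiohttp, while other providers still fall back to run_in_executor. A hedged usage sketch follows (the model name matches the tests below; an OPENAI_API_KEY is assumed to be set in the environment).

import asyncio
import litellm

async def main():
    # Routed through the new aiohttp path because the provider resolves to "openai".
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response)

asyncio.run(main())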

View file

@@ -25,16 +25,19 @@ def test_sync_response():
def test_async_response():
import asyncio
async def test_get_response():
litellm.set_verbose = True
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="gpt-3.5-turbo", messages=messages)
print(f"response: {response}")
response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
print(f"response: {response}")
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
response = asyncio.run(test_get_response())
# print(response)
test_async_response()
def test_get_response_streaming():
import asyncio
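
The updated test_async_response above drives the coroutine with asyncio.run inside a synchronous test function. An equivalent sketch using the pytest-asyncio plugin would look like the following; that plugin is an assumption and is not part of this change.

import pytest
from litellm import acompletion

@pytest.mark.asyncio  # requires the pytest-asyncio plugin (not added by this commit)
async def test_async_response_openai():
    response = await acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    assert response is not None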

View file

@@ -1,6 +1,7 @@
from litellm import completion, stream_chunk_builder
import litellm
import os, dotenv
import pytest
dotenv.load_dotenv()
user_message = "What is the current weather in Boston?"
@@ -23,6 +24,7 @@ function_schema = {
},
}
@pytest.mark.skip
def test_stream_chunk_builder():
litellm.set_verbose = False
litellm.api_key = os.environ["OPENAI_API_KEY"]

View file

@@ -763,7 +763,8 @@ class Logging:
)
elif isinstance(callback, CustomLogger): # custom logger class
callback.log_failure_event(
model=self.model,
start_time=start_time,
end_time=end_time,
messages=self.messages,
kwargs=self.model_call_details,
)
@@ -908,7 +909,7 @@ def client(original_function):
def wrapper(*args, **kwargs):
start_time = datetime.datetime.now()
result = None
logging_obj = None
logging_obj = kwargs.get("litellm_logging_obj", None)
# only set litellm_call_id if its not in kwargs
if "litellm_call_id" not in kwargs:
@@ -919,7 +920,8 @@ def client(original_function):
raise ValueError("model param not passed in.")
try:
logging_obj = function_setup(start_time, *args, **kwargs)
if logging_obj is None:
logging_obj = function_setup(start_time, *args, **kwargs)
kwargs["litellm_logging_obj"] = logging_obj
# [OPTIONAL] CHECK BUDGET
@@ -956,7 +958,8 @@ def client(original_function):
return litellm.stream_chunk_builder(chunks)
else:
return result
elif "acompletion" in kwargs and kwargs["acompletion"] == True:
return result
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
@@ -1014,7 +1017,6 @@ def client(original_function):
raise e
return wrapper
####### USAGE CALCULATOR ################
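
The client wrapper change is what makes the routing above work: when the call originated from acompletion, the wrapped completion may hand back an un-awaited coroutine, so the wrapper returns it untouched and leaves caching and success logging to acompletion itself. A minimal decorator sketch of that early return, illustrative only and not the library's actual wrapper:

import functools

def client(original_function):
    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        result = original_function(*args, **kwargs)
        if kwargs.get("acompletion") is True:
            # Async path: result may be a coroutine; skip sync post-processing here.
            return result
        # ... synchronous post-processing (caching, success logging) would run here ...
        return result
    return wrapper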

View file

@@ -17,6 +17,7 @@ click = "*"
jinja2 = "^3.1.2"
certifi = "^2023.7.22"
appdirs = "^1.4.4"
aiohttp = "*"
[tool.poetry.scripts]
litellm = 'litellm:run_server'