Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 02:34:29 +00:00
fix(azure.py): adding support for aiohttp calls on azure + openai
This commit is contained in:
parent 2c67bda137
commit 86ef2a02f7
7 changed files with 93 additions and 30 deletions
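For context, a minimal usage sketch of what this change enables (the model name and call site are illustrative, not part of the diff): calling litellm.acompletion() for OpenAI or Azure models now awaits an aiohttp-backed request directly instead of pushing the synchronous completion() into a thread executor.

# Hedged usage sketch (not part of the diff); assumes OPENAI_API_KEY is set.
import asyncio
import litellm

async def main():
    messages = [{"role": "user", "content": "Hello, how are you?"}]
    # For the "openai" and "azure" providers this call now awaits an
    # aiohttp-backed request instead of a thread-executor wrapper.
    response = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages)
    print(response)

asyncio.run(main())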
@@ -4,9 +4,7 @@ from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message
from typing import Callable, Optional
from litellm import OpenAIConfig

# This file just has the openai config classes.
# For implementation check out completion() in main.py
import aiohttp

class AzureOpenAIError(Exception):
    def __init__(self, status_code, message):

@@ -116,6 +114,7 @@ class AzureChatCompletion(BaseLLM):
            optional_params,
            litellm_params,
            logger_fn,
            acompletion: bool = False,
            headers: Optional[dict]=None):
        super().completion()
        exception_mapping_worked = False

@@ -157,6 +156,8 @@ class AzureChatCompletion(BaseLLM):

                ## RESPONSE OBJECT
                return response.iter_lines()
            elif acompletion is True:
                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
            else:
                response = self._client_session.post(
                    url=api_base,

@@ -178,6 +179,18 @@ class AzureChatCompletion(BaseLLM):
            import traceback
            raise AzureOpenAIError(status_code=500, message=traceback.format_exc())

    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
        async with aiohttp.ClientSession() as session:
            async with session.post(api_base, json=data, headers=headers) as response:
                response_json = await response.json()
                if response.status != 200:
                    raise AzureOpenAIError(status_code=response.status, message=response.text)


                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)


    def embedding(self,
            model: str,
            input: list,

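For reference, the aiohttp request/response pattern the new acompletion methods rely on looks like this in isolation. This is a hedged standalone sketch, not litellm code: the URL, payload, and error type are placeholders. One detail worth noting is that aiohttp's response.text is an async method, so reading the error body as a string normally requires await response.text().

# Standalone sketch of the aiohttp call pattern (placeholder names throughout).
import aiohttp

async def post_chat_completion(api_base: str, data: dict, headers: dict) -> dict:
    async with aiohttp.ClientSession() as session:
        async with session.post(api_base, json=data, headers=headers) as response:
            if response.status != 200:
                # .text() is a coroutine in aiohttp, so it must be awaited.
                body = await response.text()
                raise RuntimeError(f"HTTP {response.status}: {body}")
            return await response.json()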
@@ -3,9 +3,8 @@ import types, requests
from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message
from typing import Callable, Optional
import aiohttp

# This file just has the openai config classes.
# For implementation check out completion() in main.py

class OpenAIError(Exception):
    def __init__(self, status_code, message):

@@ -184,12 +183,13 @@ class OpenAIChatCompletion(BaseLLM):
            OpenAIError(status_code=500, message="Invalid response object.")

    def completion(self,
            model_response: ModelResponse,
            model: Optional[str]=None,
            messages: Optional[list]=None,
            model_response: Optional[ModelResponse]=None,
            print_verbose: Optional[Callable]=None,
            api_key: Optional[str]=None,
            api_base: Optional[str]=None,
            acompletion: bool = False,
            logging_obj=None,
            optional_params=None,
            litellm_params=None,

@@ -200,6 +200,7 @@ class OpenAIChatCompletion(BaseLLM):
        try:
            if headers is None:
                headers = self.validate_environment(api_key=api_key)
            api_base = f"{api_base}/chat/completions"
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message=f"Missing model or messages")

@@ -214,13 +215,13 @@ class OpenAIChatCompletion(BaseLLM):
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={"headers": headers, "api_base": api_base},
                additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "data": data},
            )

            try:
                if "stream" in optional_params and optional_params["stream"] == True:
                    response = self._client_session.post(
                        url=f"{api_base}/chat/completions",
                        url=api_base,
                        json=data,
                        headers=headers,
                        stream=optional_params["stream"]

@@ -230,9 +231,11 @@ class OpenAIChatCompletion(BaseLLM):

                    ## RESPONSE OBJECT
                    return response.iter_lines()
                elif acompletion is True:
                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
                else:
                    response = self._client_session.post(
                        url=f"{api_base}/chat/completions",
                        url=api_base,
                        json=data,
                        headers=headers,
                    )

@@ -270,6 +273,17 @@ class OpenAIChatCompletion(BaseLLM):
            import traceback
            raise OpenAIError(status_code=500, message=traceback.format_exc())

    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
        async with aiohttp.ClientSession() as session:
            async with session.post(api_base, json=data, headers=headers) as response:
                response_json = await response.json()
                if response.status != 200:
                    raise OpenAIError(status_code=response.status, message=response.text)


                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)

    def embedding(self,
            model: str,
            input: list,

@@ -7,7 +7,7 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan

import os, openai, sys, json, inspect
import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars

@@ -77,7 +77,7 @@ openai_text_completions = OpenAITextCompletion()
azure_chat_completions = AzureChatCompletion()
####### COMPLETION ENDPOINTS ################

async def acompletion(*args, **kwargs):
async def acompletion(model: str, messages: List = [], *args, **kwargs):
    """
    Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)

@@ -118,14 +118,29 @@ async def acompletion(*args, **kwargs):
    """
    loop = asyncio.get_event_loop()

    # Use a partial function to pass your keyword arguments
    ### INITIALIZE LOGGING OBJECT ###
    kwargs["litellm_call_id"] = str(uuid.uuid4())
    start_time = datetime.datetime.now()
    logging_obj = Logging(model=model, messages=messages, stream=kwargs.get("stream", False), litellm_call_id=kwargs["litellm_call_id"], function_id=kwargs.get("id", None), call_type="completion", start_time=start_time)

    ### PASS ARGS TO COMPLETION ###
    kwargs["litellm_logging_obj"] = logging_obj
    kwargs["acompletion"] = True
    kwargs["model"] = model
    kwargs["messages"] = messages
    # Use a partial function to pass your keyword arguments
    func = partial(completion, *args, **kwargs)

    # Add the context to the function
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)

    _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))

    if custom_llm_provider == "openai" or custom_llm_provider == "azure": # currently implemented aiohttp calls for just azure and openai, soon all.
        # Await normally
        response = await completion(*args, **kwargs)
    else:
        # Call the synchronous function using run_in_executor
        response = await loop.run_in_executor(None, func_with_context)
    if kwargs.get("stream", False): # return an async generator
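The provider dispatch shown above is the core of the change: providers with a native aiohttp path ("openai", "azure") are awaited directly, while every other provider keeps the previous behavior of running the synchronous completion() in a thread executor with the current contextvars preserved. A minimal sketch of that pattern, with stand-in functions rather than litellm internals:

# Illustrative sketch of the dispatch (stand-in functions, not litellm internals).
import asyncio
import contextvars
from functools import partial

async def fetch_via_aiohttp(**kwargs):
    # stand-in for the native async (aiohttp) provider path
    return {"path": "native-async", **kwargs}

def sync_completion(**kwargs):
    # stand-in for the existing synchronous completion() path
    return {"path": "thread-executor", **kwargs}

async def acompletion_dispatch(provider: str, **kwargs):
    loop = asyncio.get_event_loop()
    if provider in ("openai", "azure"):
        return await fetch_via_aiohttp(**kwargs)
    # copy the current contextvars so the worker thread sees the same context
    func = partial(sync_completion, **kwargs)
    ctx = contextvars.copy_context()
    return await loop.run_in_executor(None, partial(ctx.run, func))

# Example: asyncio.run(acompletion_dispatch("openai", model="gpt-3.5-turbo"))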
@@ -137,6 +152,16 @@ async def acompletion(*args, **kwargs):
            async for line in response
        )
    else:
        end_time = datetime.datetime.now()
        # [OPTIONAL] ADD TO CACHE
        if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
            litellm.cache.add_cache(response, *args, **kwargs)

        # LOG SUCCESS
        logging_obj.success_handler(response, start_time, end_time)
        # RETURN RESULT
        response._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai

    return response

def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):

@@ -268,6 +293,7 @@ def completion(
    final_prompt_value = kwargs.get("final_prompt_value", None)
    bos_token = kwargs.get("bos_token", None)
    eos_token = kwargs.get("eos_token", None)
    acompletion = kwargs.get("acompletion", False)
    ######## end of unpacking kwargs ###########
    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]

@@ -409,6 +435,7 @@ def completion(
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                logging_obj=logging,
                acompletion=acompletion
            )
            if "stream" in optional_params and optional_params["stream"] == True:
                response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)

@@ -472,6 +499,7 @@ def completion(
                print_verbose=print_verbose,
                api_key=api_key,
                api_base=api_base,
                acompletion=acompletion,
                logging_obj=logging,
                optional_params=optional_params,
                litellm_params=litellm_params,

@@ -25,16 +25,19 @@ def test_sync_response():
def test_async_response():
    import asyncio
    async def test_get_response():
        litellm.set_verbose = True
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
            print(f"response: {response}")
            response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
            print(f"response: {response}")
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")

    response = asyncio.run(test_get_response())
    # print(response)
test_async_response()

def test_get_response_streaming():
    import asyncio

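The truncated test_get_response_streaming above hints at the streaming path; a hedged sketch of how that call is expected to look, with the model name and chunk handling as assumptions rather than anything taken from the diff:

# Hedged sketch of the streaming call; model name and chunk shape are assumptions.
import asyncio
from litellm import acompletion

async def stream_demo():
    messages = [{"role": "user", "content": "Hello, how are you?"}]
    response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
    async for chunk in response:  # acompletion returns an async generator when stream=True
        print(chunk)

# asyncio.run(stream_demo())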
@@ -1,6 +1,7 @@
from litellm import completion, stream_chunk_builder
import litellm
import os, dotenv
import pytest
dotenv.load_dotenv()

user_message = "What is the current weather in Boston?"

@@ -23,6 +24,7 @@ function_schema = {
    },
}

@pytest.mark.skip
def test_stream_chunk_builder():
    litellm.set_verbose = False
    litellm.api_key = os.environ["OPENAI_API_KEY"]

@@ -763,7 +763,8 @@ class Logging:
                    )
                elif isinstance(callback, CustomLogger): # custom logger class
                    callback.log_failure_event(
                        model=self.model,
                        start_time=start_time,
                        end_time=end_time,
                        messages=self.messages,
                        kwargs=self.model_call_details,
                    )

@@ -908,7 +909,7 @@ def client(original_function):
    def wrapper(*args, **kwargs):
        start_time = datetime.datetime.now()
        result = None
        logging_obj = None
        logging_obj = kwargs.get("litellm_logging_obj", None)

        # only set litellm_call_id if its not in kwargs
        if "litellm_call_id" not in kwargs:

@@ -919,6 +920,7 @@ def client(original_function):
            raise ValueError("model param not passed in.")

        try:
            if logging_obj is None:
                logging_obj = function_setup(start_time, *args, **kwargs)
            kwargs["litellm_logging_obj"] = logging_obj

@@ -956,7 +958,8 @@ def client(original_function):
                    return litellm.stream_chunk_builder(chunks)
                else:
                    return result

            elif "acompletion" in kwargs and kwargs["acompletion"] == True:
                return result

            # [OPTIONAL] ADD TO CACHE
            if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object

@@ -1014,7 +1017,6 @@ def client(original_function):
            raise e
    return wrapper


####### USAGE CALCULATOR ################

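The utils.py change above makes the client() wrapper hand the result straight back when acompletion is set, since at that point the result is still an un-awaited coroutine and caching plus success logging happen inside acompletion() itself. A toy sketch of that pass-through behavior, with illustrative names rather than the real decorator:

# Toy pass-through decorator sketch; names are illustrative, not litellm's.
import functools

def client(original_function):
    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        result = original_function(*args, **kwargs)
        if kwargs.get("acompletion") is True:
            # result is a coroutine here; return it untouched so the async
            # caller can await it (no sync caching/post-processing applies).
            return result
        # ... synchronous post-processing (caching, logging) would go here ...
        return result
    return wrapper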
@@ -17,6 +17,7 @@ click = "*"
jinja2 = "^3.1.2"
certifi = "^2023.7.22"
appdirs = "^1.4.4"
aiohttp = "*"

[tool.poetry.scripts]
litellm = 'litellm:run_server'