forked from phoenix/litellm-mirror
fix(azure.py): adding support for aiohttp calls on azure + openai
This commit is contained in:
parent 2c67bda137
commit 86ef2a02f7
7 changed files with 93 additions and 30 deletions
@@ -3,9 +3,8 @@ import types, requests
 from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message
 from typing import Callable, Optional
+import aiohttp

 # This file just has the openai config classes.
 # For implementation check out completion() in main.py

 class OpenAIError(Exception):
     def __init__(self, status_code, message):
@@ -184,22 +183,24 @@ class OpenAIChatCompletion(BaseLLM):
             OpenAIError(status_code=500, message="Invalid response object.")

     def completion(self,
-                model: Optional[str]=None,
-                messages: Optional[list]=None,
-                model_response: Optional[ModelResponse]=None,
-                print_verbose: Optional[Callable]=None,
-                api_key: Optional[str]=None,
-                api_base: Optional[str]=None,
-                logging_obj=None,
-                optional_params=None,
-                litellm_params=None,
-                logger_fn=None,
-                headers: Optional[dict]=None):
+                model_response: ModelResponse,
+                model: Optional[str]=None,
+                messages: Optional[list]=None,
+                print_verbose: Optional[Callable]=None,
+                api_key: Optional[str]=None,
+                api_base: Optional[str]=None,
+                acompletion: bool = False,
+                logging_obj=None,
+                optional_params=None,
+                litellm_params=None,
+                logger_fn=None,
+                headers: Optional[dict]=None):
         super().completion()
         exception_mapping_worked = False
         try:
             if headers is None:
                 headers = self.validate_environment(api_key=api_key)
+            api_base = f"{api_base}/chat/completions"
             if model is None or messages is None:
                 raise OpenAIError(status_code=422, message=f"Missing model or messages")

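The signature change makes model_response a required positional argument and threads a new acompletion flag through. Note that api_base is rewritten in place to carry the /chat/completions suffix, which is why the later posts can use url=api_base directly; as written, a caller that already passes a fully qualified /chat/completions URL would get the suffix doubled. A small runnable sketch of a defensive variant of that normalization (the endswith guard is our assumption, not in the diff):

def normalize(api_base: str) -> str:
    # Append the path suffix exactly once; the guard is not in the commit.
    if not api_base.endswith("/chat/completions"):
        api_base = f"{api_base}/chat/completions"
    return api_base

assert normalize("https://api.openai.com/v1") == "https://api.openai.com/v1/chat/completions"
assert normalize("https://api.openai.com/v1/chat/completions") == "https://api.openai.com/v1/chat/completions"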
@@ -214,13 +215,13 @@ class OpenAIChatCompletion(BaseLLM):
             logging_obj.pre_call(
                 input=messages,
                 api_key=api_key,
-                additional_args={"headers": headers, "api_base": api_base},
+                additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "data": data},
             )

             try:
                 if "stream" in optional_params and optional_params["stream"] == True:
                     response = self._client_session.post(
-                        url=f"{api_base}/chat/completions",
+                        url=api_base,
                         json=data,
                         headers=headers,
                         stream=optional_params["stream"]
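The pre-call hook now also receives the request payload and the async flag. A hypothetical minimal stand-in for logging_obj (litellm's real Logging class does considerably more) showing what the extended additional_args exposes:

class StubLogging:
    # Illustrative stub only; litellm's actual pre_call hook is richer.
    def pre_call(self, input, api_key, additional_args=None):
        args = additional_args or {}
        print(f"POST {args.get('api_base')} async={args.get('acompletion')}")
        print(f"payload: {args.get('data')}")

StubLogging().pre_call(
    input=[{"role": "user", "content": "Hey"}],
    api_key="sk-...",  # placeholder key
    additional_args={
        "headers": {"Authorization": "Bearer sk-..."},
        "api_base": "https://api.openai.com/v1/chat/completions",
        "acompletion": False,
        "data": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey"}]},
    },
)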
@@ -230,9 +231,11 @@ class OpenAIChatCompletion(BaseLLM):

                     ## RESPONSE OBJECT
                     return response.iter_lines()
+                elif acompletion is True:
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
                 else:
                     response = self._client_session.post(
-                        url=f"{api_base}/chat/completions",
+                        url=api_base,
                         json=data,
                         headers=headers,
                     )
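Because completion() itself stays synchronous, the acompletion branch returns a coroutine object that the caller is expected to await; only the streaming and blocking branches return concrete values. A minimal self-contained sketch of that dispatch pattern, with all names illustrative:

import asyncio

class Handler:
    def completion(self, data: dict, acompletion: bool = False):
        if acompletion:
            # Returns a coroutine object; the caller must await it.
            return self.acompletion(data)
        return {"echo": data}  # stand-in for the blocking HTTP call

    async def acompletion(self, data: dict):
        await asyncio.sleep(0)  # stand-in for the aiohttp round trip
        return {"echo": data}

print(Handler().completion({"model": "x"}))                                   # sync path
print(asyncio.run(Handler().completion({"model": "x"}, acompletion=True)))    # async path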
@@ -270,6 +273,17 @@ class OpenAIChatCompletion(BaseLLM):
             import traceback
             raise OpenAIError(status_code=500, message=traceback.format_exc())

+    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(api_base, json=data, headers=headers) as response:
+                response_json = await response.json()
+                if response.status != 200:
+                    raise OpenAIError(status_code=response.status, message=response.text)
+
+
+                ## RESPONSE OBJECT
+                return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+
     def embedding(self,
                 model: str,
                 input: list,
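One caveat in the new coroutine: in aiohttp, ClientResponse.text is a coroutine method, so message=response.text attaches a bound method rather than the response body to the error. A standalone sketch of the likely intent, reading the body before raising; the stand-in OpenAIError mirrors the class defined earlier in this file, and acompletion_fixed is our name, not the commit's:

import aiohttp

class OpenAIError(Exception):  # minimal stand-in for the class in this file
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(message)

async def acompletion_fixed(api_base: str, data: dict, headers: dict) -> dict:
    async with aiohttp.ClientSession() as session:
        async with session.post(api_base, json=data, headers=headers) as response:
            # Read the body while the connection is open; aiohttp caches it,
            # so the later json() call reuses the same bytes.
            body = await response.text()
            if response.status != 200:
                raise OpenAIError(status_code=response.status, message=body)
            return await response.json()

From synchronous code this would be driven with asyncio.run(acompletion_fixed(api_base, data, headers)).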