refactor(azure.py): move Azure OpenAI calls to raw HTTP calls

Krrish Dholakia 2023-11-08 16:52:18 -08:00
parent 01a7660a12
commit 53abc31c27
7 changed files with 309 additions and 78 deletions


@@ -368,7 +368,8 @@ from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.maritalk import MaritTalkConfig
 from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig
-from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig, AzureOpenAIConfig
+from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
+from .llms.azure import AzureOpenAIConfig
 from .main import *  # type: ignore
 from .integrations import *
 from .exceptions import (
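Even though `AzureOpenAIConfig` moves to its own module, the hunk above keeps re-exporting it from the package root, so existing imports keep resolving. A quick illustrative check (not part of the diff):

import litellm
from litellm.llms.azure import AzureOpenAIConfig

# Both import paths point at the same class after this commit
assert litellm.AzureOpenAIConfig is AzureOpenAIConfig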

litellm/llms/azure.py (new file, 179 lines)

@@ -0,0 +1,179 @@
from typing import Optional, Union
import types, requests
from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message
from typing import Callable, Optional
from litellm import OpenAIConfig

# This file just has the openai config classes.
# For implementation check out completion() in main.py

class AzureOpenAIError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs

class AzureOpenAIConfig(OpenAIConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
    - `function_call` (string or object): This optional parameter controls how the model calls functions.
    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    def __init__(self,
                 frequency_penalty: Optional[int] = None,
                 function_call: Optional[Union[str, dict]] = None,
                 functions: Optional[list] = None,
                 logit_bias: Optional[dict] = None,
                 max_tokens: Optional[int] = None,
                 n: Optional[int] = None,
                 presence_penalty: Optional[int] = None,
                 stop: Optional[Union[str, list]] = None,
                 temperature: Optional[int] = None,
                 top_p: Optional[int] = None) -> None:
        super().__init__(frequency_penalty,
                         function_call,
                         functions,
                         logit_bias,
                         max_tokens,
                         n,
                         presence_penalty,
                         stop,
                         temperature,
                         top_p)
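Because the constructor just forwards everything to `OpenAIConfig`, `AzureOpenAIConfig` keeps the usual litellm config behavior: instantiating it records non-None values as provider-wide defaults, which `completion()` later merges into `optional_params` (see the `azure_config(top_k=3)` comment in the main.py hunk further down). A hedged sketch of that usage, assuming the inherited `get_config()` pattern (not part of azure.py itself):

import litellm

# Illustrative only: record an Azure-wide default. Assumes OpenAIConfig stores
# non-None constructor arguments as class attributes, per the shared config pattern.
litellm.AzureOpenAIConfig(max_tokens=100)

# get_config() is assumed to come from the OpenAIConfig base class
print(litellm.AzureOpenAIConfig.get_config())  # e.g. {"max_tokens": 100}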
class AzureChatCompletion(BaseLLM):
    _client_session: requests.Session

    def __init__(self) -> None:
        super().__init__()
        self._client_session = self.create_client_session()

    def validate_environment(self, api_key):
        headers = {
            "content-type": "application/json",
        }
        if api_key:
            headers["api-key"] = api_key
        return headers

    def convert_to_model_response_object(self, response_object: Optional[dict] = None, model_response_object: Optional[ModelResponse] = None):
        try:
            if response_object is None or model_response_object is None:
                raise AzureOpenAIError(status_code=500, message="Error in response object format")
            choice_list = []
            for idx, choice in enumerate(response_object["choices"]):
                message = Message(content=choice["message"]["content"], role=choice["message"]["role"])
                choice = Choices(finish_reason=choice["finish_reason"], index=idx, message=message)
                choice_list.append(choice)
            model_response_object.choices = choice_list
            if "usage" in response_object:
                model_response_object.usage = response_object["usage"]
            if "id" in response_object:
                model_response_object.id = response_object["id"]
            if "model" in response_object:
                model_response_object.model = response_object["model"]
            return model_response_object
        except:
            raise AzureOpenAIError(status_code=500, message="Invalid response object.")
    def completion(self,
                   model: str,
                   messages: list,
                   model_response: ModelResponse,
                   api_key: str,
                   api_base: str,
                   api_version: str,
                   api_type: str,
                   print_verbose: Callable,
                   logging_obj,
                   optional_params,
                   litellm_params,
                   logger_fn,
                   headers: Optional[dict] = None):
        super().completion()
        exception_mapping_worked = False
        try:
            if headers is None:
                headers = self.validate_environment(api_key=api_key)
            if model is None or messages is None:
                raise AzureOpenAIError(status_code=422, message=f"Missing model or messages")
            # Ensure api_base ends with a trailing slash
            if not api_base.endswith('/'):
                api_base += '/'
            api_base = api_base + f"openai/deployments/{model}/chat/completions?api-version={api_version}"
            data = {
                "messages": messages,
                **optional_params
            }
            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_version": api_version,
                    "api_base": api_base,
                },
            )
            if "stream" in optional_params and optional_params["stream"] == True:
                response = self._client_session.post(
                    url=api_base,
                    json=data,
                    headers=headers,
                    stream=optional_params["stream"]
                )
                if response.status_code != 200:
                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
                ## RESPONSE OBJECT
                return response.iter_lines()
            else:
                response = self._client_session.post(
                    url=api_base,
                    json=data,
                    headers=headers,
                )
                if response.status_code != 200:
                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
        except AzureOpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            if exception_mapping_worked:
                raise e
            else:
                import traceback
                raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
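The net effect of `completion()` above is a single HTTPS POST against the Azure deployment endpoint, authenticated with an `api-key` header, instead of going through the `openai` SDK. A rough standalone equivalent of the request it builds (sketch only; the resource name, deployment and api-version below are placeholders):

import requests

api_base = "https://my-resource.openai.azure.com/"  # placeholder Azure resource
deployment = "chatgpt-v-2"                           # placeholder deployment name
api_version = "2023-07-01-preview"                   # placeholder api-version

url = f"{api_base}openai/deployments/{deployment}/chat/completions?api-version={api_version}"
headers = {"content-type": "application/json", "api-key": "<AZURE_API_KEY>"}
data = {"messages": [{"role": "user", "content": "Hey, how's it going?"}], "max_tokens": 50}

resp = requests.post(url, json=data, headers=headers)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])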


@@ -145,56 +145,6 @@ class OpenAITextCompletionConfig():
                 and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
                 and v is not None}

-class AzureOpenAIConfig(OpenAIConfig):
-    """
-    Reference: https://platform.openai.com/docs/api-reference/chat/create
-
-    The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
-
-    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
-    - `function_call` (string or object): This optional parameter controls how the model calls functions.
-    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
-    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
-    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
-    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
-    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
-    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
-    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
-    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
-    """
-
-    def __init__(self,
-                 frequency_penalty: Optional[int] = None,
-                 function_call: Optional[Union[str, dict]] = None,
-                 functions: Optional[list] = None,
-                 logit_bias: Optional[dict] = None,
-                 max_tokens: Optional[int] = None,
-                 n: Optional[int] = None,
-                 presence_penalty: Optional[int] = None,
-                 stop: Optional[Union[str, list]] = None,
-                 temperature: Optional[int] = None,
-                 top_p: Optional[int] = None) -> None:
-        super().__init__(frequency_penalty,
-                         function_call,
-                         functions,
-                         logit_bias,
-                         max_tokens,
-                         n,
-                         presence_penalty,
-                         stop,
-                         temperature,
-                         top_p)

 class OpenAIChatCompletion(BaseLLM):
     _client_session: requests.Session


@@ -50,6 +50,7 @@ from .llms import (
     vertex_ai,
     maritalk)
 from .llms.openai import OpenAIChatCompletion
+from .llms.azure import AzureChatCompletion
 from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -71,7 +72,8 @@ from litellm.utils import (
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv()  # Loading env variables using dotenv
-openai_proxy_chat_completions = OpenAIChatCompletion()
+openai_chat_completions = OpenAIChatCompletion()
+azure_chat_completions = AzureChatCompletion()

 ####### COMPLETION ENDPOINTS ################
 async def acompletion(*args, **kwargs):
@@ -393,29 +395,24 @@ def completion(
                 if k not in optional_params:  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
                     optional_params[k] = v

-            ## LOGGING
-            logging.pre_call(
-                input=messages,
-                api_key=api_key,
-                additional_args={
-                    "headers": headers,
-                    "api_version": api_version,
-                    "api_base": api_base,
-                },
-            )
             ## COMPLETION CALL
-            response = openai.ChatCompletion.create(
-                engine=model,
+            response = azure_chat_completions.completion(
+                model=model,
                 messages=messages,
                 headers=headers,
                 api_key=api_key,
                 api_base=api_base,
                 api_version=api_version,
                 api_type=api_type,
-                **optional_params,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                logging_obj=logging,
             )
             if "stream" in optional_params and optional_params["stream"] == True:
-                response = CustomStreamWrapper(response, model, custom_llm_provider="openai", logging_obj=logging)
+                response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
                 return response
             ## LOGGING
             logging.post_call(
@@ -476,8 +473,7 @@ def completion(
         ## COMPLETION CALL
         try:
             if custom_llm_provider == "custom_openai":
-                print("making call using openai custom chat completion")
-                response = openai_proxy_chat_completions.completion(
+                response = openai_chat_completions.completion(
                     model=model,
                     messages=messages,
                     model_response=model_response,
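Caller code does not change with this refactor: a `completion()` call with an `azure/<deployment>` model string is now routed to `azure_chat_completions.completion()` instead of `openai.ChatCompletion.create()`. A minimal caller-side sketch (the Azure env var names below are the conventional litellm ones and should be treated as an assumption here):

import os
from litellm import completion

# Assumed configuration; adjust to however you provide Azure credentials.
os.environ["AZURE_API_KEY"] = "my-azure-key"
os.environ["AZURE_API_BASE"] = "https://my-resource.openai.azure.com/"
os.environ["AZURE_API_VERSION"] = "2023-07-01-preview"

response = completion(
    model="azure/chatgpt-v-2",  # "azure/<deployment-name>", as in the tests below
    messages=[{"role": "user", "content": "how does a court case get to the Supreme Court?"}],
    max_tokens=50,
)
print(response)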


@@ -62,7 +62,7 @@ def test_context_window_with_fallbacks(model):
 # for model in litellm.models_by_provider["bedrock"]:
 # test_context_window(model=model)
-# test_context_window(model="command-nightly")
+# test_context_window(model="azure/chatgpt-v-2")
 # test_context_window_with_fallbacks(model="command-nightly")

 # Test 2: InvalidAuth Errors
 @pytest.mark.parametrize("model", models)
@@ -80,7 +80,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
         os.environ["AWS_REGION_NAME"] = "bad-key"
         temporary_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
         os.environ["AWS_SECRET_ACCESS_KEY"] = "bad-key"
-    elif model == "chatgpt-test":
+    elif model == "azure/chatgpt-v-2":
         temporary_key = os.environ["AZURE_API_KEY"]
         os.environ["AZURE_API_KEY"] = "bad-key"
     elif model == "claude-instant-1":
@@ -156,8 +156,9 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
         os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
     return

-for model in litellm.models_by_provider["bedrock"]:
-    invalid_auth(model=model)
+# for model in litellm.models_by_provider["bedrock"]:
+#     invalid_auth(model=model)
+# invalid_auth(model="azure/chatgpt-v-2")

 # Test 3: Invalid Request Error
 @pytest.mark.parametrize("model", models)
@@ -167,6 +168,7 @@ def test_invalid_request_error(model):
     with pytest.raises(InvalidRequestError):
         completion(model=model, messages=messages, max_tokens="hello world")

+test_invalid_request_error(model="azure/chatgpt-v-2")
 # Test 3: Rate Limit Errors
 # def test_model_call(model):
 #     try:


@@ -403,6 +403,32 @@ def test_completion_cohere_stream_bad_key():
 # test_completion_hf_stream_bad_key()

+def test_completion_azure_stream():
+    try:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        response = completion(
+            model="azure/chatgpt-v-2", messages=messages, stream=True, max_tokens=50
+        )
+        complete_response = ""
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            if finished:
+                break
+            complete_response += chunk
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_completion_azure_stream()
+
 def test_completion_claude_stream():
     try:
         messages = [


@@ -2881,8 +2881,6 @@ def exception_type(
                     llm_provider="openrouter"
                 )
                 original_exception.llm_provider = "openrouter"
-            elif custom_llm_provider == "azure":
-                original_exception.llm_provider = "azure"
             else:
                 original_exception.llm_provider = "openai"
             if "This model's maximum context length is" in original_exception._message:
@@ -3478,6 +3476,9 @@ def exception_type(
                 raise original_exception
             raise original_exception
         elif custom_llm_provider == "ollama":
+            if "no attribute 'async_get_ollama_response_stream" in error_str:
+                exception_mapping_worked = True
+                raise ImportError("Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'")
             if isinstance(original_exception, dict):
                 error_str = original_exception.get("error", "")
             else:
@@ -3512,9 +3513,59 @@ def exception_type(
                     llm_provider="vllm",
                     model=model
                 )
-        elif custom_llm_provider == "ollama":
-            if "no attribute 'async_get_ollama_response_stream" in error_str:
-                raise ImportError("Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'")
+        elif custom_llm_provider == "azure":
+            if "This model's maximum context length is" in error_str:
+                exception_mapping_worked = True
+                raise ContextWindowExceededError(
+                    message=f"AzureException - {original_exception.message}",
+                    llm_provider="azure",
+                    model=model
+                )
+            elif "invalid_request_error" in error_str:
+                exception_mapping_worked = True
+                raise InvalidRequestError(
+                    message=f"AzureException - {original_exception.message}",
+                    llm_provider="azure",
+                    model=model
+                )
+            elif hasattr(original_exception, "status_code"):
+                exception_mapping_worked = True
+                if original_exception.status_code == 401:
+                    exception_mapping_worked = True
+                    raise AuthenticationError(
+                        message=f"AzureException - {original_exception.message}",
+                        llm_provider="azure",
+                        model=model
+                    )
+                elif original_exception.status_code == 408:
+                    exception_mapping_worked = True
+                    raise Timeout(
+                        message=f"AzureException - {original_exception.message}",
+                        model=model,
+                        llm_provider="azure"
+                    )
+                if original_exception.status_code == 422:
+                    exception_mapping_worked = True
+                    raise InvalidRequestError(
+                        message=f"AzureException - {original_exception.message}",
+                        model=model,
+                        llm_provider="azure",
+                    )
+                elif original_exception.status_code == 429:
+                    exception_mapping_worked = True
+                    raise RateLimitError(
+                        message=f"AzureException - {original_exception.message}",
+                        model=model,
+                        llm_provider="azure",
+                    )
+                else:
+                    exception_mapping_worked = True
+                    raise APIError(
+                        status_code=original_exception.status_code,
+                        message=f"AzureException - {original_exception.message}",
+                        llm_provider="azure",
+                        model=model
+                    )
         elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
             if hasattr(original_exception, "status_code"):
                 exception_mapping_worked = True
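With this mapping in place, a failing Azure call surfaces as one of litellm's typed exceptions carrying an `AzureException - ...` message rather than a raw HTTP error; that is what `test_invalid_request_error(model="azure/chatgpt-v-2")` earlier in this diff exercises. A hedged sketch of catching the mapped error (assuming `InvalidRequestError` is importable from `litellm.exceptions`, as the package's exception re-exports suggest):

from litellm import completion
from litellm.exceptions import InvalidRequestError

try:
    completion(
        model="azure/chatgpt-v-2",
        messages=[{"role": "user", "content": "hi"}],
        max_tokens="hello world",  # invalid type, Azure responds with a 4xx
    )
except InvalidRequestError as e:
    print(f"mapped Azure error: {e}")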
@@ -3853,6 +3904,26 @@ class CustomStreamWrapper:
         except:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")

+    def handle_azure_chunk(self, chunk):
+        chunk = chunk.decode("utf-8")
+        is_finished = False
+        finish_reason = ""
+        text = ""
+        if chunk.startswith("data:"):
+            data_json = json.loads(chunk[5:])  # chunk.startswith("data:"):
+            try:
+                text = data_json["choices"][0]["delta"].get("content", "")
+                if data_json["choices"][0].get("finish_reason", None):
+                    is_finished = True
+                    finish_reason = data_json["choices"][0]["finish_reason"]
+                return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
+            except:
+                raise ValueError(f"Unable to parse response. Original response: {chunk}")
+        elif "error" in chunk:
+            raise ValueError(f"Unable to parse response. Original response: {chunk}")
+        else:
+            return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
+
     def handle_replicate_chunk(self, chunk):
         try:
             text = ""
@@ -4013,6 +4084,12 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     model_response.choices[0].finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider and self.custom_llm_provider == "azure":
+                chunk = next(self.completion_stream)
+                response_obj = self.handle_azure_chunk(chunk)
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
                 chunk = next(self.completion_stream)
                 response_obj = self.handle_maritalk_chunk(chunk)
@@ -4187,7 +4264,7 @@ class TextCompletionStreamWrapper:
         except StopIteration:
             raise StopIteration
         except Exception as e:
-            print(f"got exception {e}")
+            print(f"got exception {e}")  # noqa
     async def __anext__(self):
         try:
             return next(self)