(feat) support azure/gpt-instruct models

This commit is contained in:
ishaan-jaff 2024-03-12 09:30:15 -07:00
parent 10f5f342bd
commit b193b01f40
5 changed files with 606 additions and 0 deletions

View file

@ -455,6 +455,7 @@ provider_list: List = [
"ai21",
"baseten",
"azure",
"azure_text",
"sagemaker",
"bedrock",
"vllm",

511
litellm/llms/azure_text.py Normal file
View file

@ -0,0 +1,511 @@
from typing import Optional, Union, Any
import types, requests
from .base import BaseLLM
from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
TranscriptionResponse,
)
from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig
import litellm, json
import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
from ..llms.openai import OpenAITextCompletion
import uuid
from .prompt_templates.factory import prompt_factory, custom_prompt
openai_text_completion = OpenAITextCompletion()
class AzureOpenAIError(Exception):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request:
self.request = request
else:
self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
if response:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AzureOpenAIConfig(OpenAIConfig):
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
- `function_call` (string or object): This optional parameter controls how the model calls functions.
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
super().__init__(
frequency_penalty,
function_call,
functions,
logit_bias,
max_tokens,
n,
presence_penalty,
stop,
temperature,
top_p,
)
def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = {
# "api_version": api_version,
# "azure_endpoint": api_base,
# "azure_deployment": model,
# "http_client": litellm.client_session,
# "max_retries": max_retries,
# "timeout": timeout,
# }
azure_endpoint = azure_client_params.get("azure_endpoint", None)
if azure_endpoint is not None:
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
if "/openai/deployments" in azure_endpoint:
# this is base_url, not an azure_endpoint
azure_client_params["base_url"] = azure_endpoint
azure_client_params.pop("azure_endpoint")
return azure_client_params
class AzureTextCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def validate_environment(self, api_key, azure_ad_token):
headers = {
"content-type": "application/json",
}
if api_key is not None:
headers["api-key"] = api_key
elif azure_ad_token is not None:
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers
def completion(
self,
model: str,
messages: list,
model_response: ModelResponse,
api_key: str,
api_base: str,
api_version: str,
api_type: str,
azure_ad_token: str,
print_verbose: Callable,
timeout,
logging_obj,
optional_params,
litellm_params,
logger_fn,
acompletion: bool = False,
headers: Optional[dict] = None,
client=None,
):
super().completion()
exception_mapping_worked = False
try:
if model is None or messages is None:
raise AzureOpenAIError(
status_code=422, message=f"Missing model or messages"
)
max_retries = optional_params.pop("max_retries", 2)
prompt = prompt_factory(
messages=messages, model=model, custom_llm_provider="azure_text"
)
### CHECK IF CLOUDFLARE AI GATEWAY ###
### if so - set the model as part of the base url
if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name
if client is None:
if not api_base.endswith("/"):
api_base += "/"
api_base += f"{model}"
azure_client_params = {
"api_version": api_version,
"base_url": f"{api_base}",
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if acompletion is True:
client = AsyncAzureOpenAI(**azure_client_params)
else:
client = AzureOpenAI(**azure_client_params)
data = {"model": None, "prompt": prompt, **optional_params}
else:
data = {
"model": model, # type: ignore
"prompt": prompt,
**optional_params,
}
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
model=model,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
)
else:
return self.acompletion(
api_base=api_base,
data=data,
model_response=model_response,
api_key=api_key,
api_version=api_version,
model=model,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
logging_obj=logging_obj,
)
elif "stream" in optional_params and optional_params["stream"] == True:
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
model=model,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
)
else:
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"headers": {
"api_key": api_key,
"azure_ad_token": azure_ad_token,
},
"api_version": api_version,
"api_base": api_base,
"complete_input_dict": data,
},
)
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault(
"api-version", api_version
)
response = azure_client.completions.create(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=stringified_response,
additional_args={
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
)
return openai_text_completion.convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
)
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if hasattr(e, "status_code"):
raise AzureOpenAIError(status_code=e.status_code, message=str(e))
else:
raise AzureOpenAIError(status_code=500, message=str(e))
async def acompletion(
self,
api_key: str,
api_version: str,
model: str,
api_base: str,
data: dict,
timeout: Any,
model_response: ModelResponse,
azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI
logging_obj=None,
):
response = None
try:
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
# setting Azure client
if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING
logging_obj.pre_call(
input=data["prompt"],
api_key=azure_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await azure_client.completions.create(**data, timeout=timeout)
return openai_text_completion.convert_to_model_response_object(
response_object=response.model_dump(),
model_response_object=model_response,
)
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if hasattr(e, "status_code"):
raise e
else:
raise AzureOpenAIError(status_code=500, message=str(e))
def streaming(
self,
logging_obj,
api_base: str,
api_key: str,
api_version: str,
data: dict,
model: str,
timeout: Any,
azure_ad_token: Optional[str] = None,
client=None,
):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(azure_client._custom_query, dict):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING
logging_obj.pre_call(
input=data["prompt"],
api_key=azure_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = azure_client.completions.create(**data, timeout=timeout)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
return streamwrapper
async def async_streaming(
self,
logging_obj,
api_base: str,
api_key: str,
api_version: str,
data: dict,
model: str,
timeout: Any,
azure_ad_token: Optional[str] = None,
client=None,
):
try:
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": data.pop("max_retries", 2),
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING
logging_obj.pre_call(
input=data["prompt"],
api_key=azure_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await azure_client.completions.create(**data, timeout=timeout)
# return response
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
except Exception as e:
if hasattr(e, "status_code"):
raise AzureOpenAIError(status_code=e.status_code, message=str(e))
else:
raise AzureOpenAIError(status_code=500, message=str(e))

View file

@ -807,6 +807,20 @@ def gemini_text_image_pt(messages: list):
return content
def azure_text_pt(messages: list):
prompt = ""
for message in messages:
if isinstance(message["content"], str):
prompt += message["content"]
elif isinstance(message["content"], list):
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
for element in message["content"]:
if isinstance(element, dict):
if element["type"] == "text":
prompt += element["text"]
return prompt
# Function call template
def function_call_prompt(messages: list, functions: list):
function_prompt = (
@ -907,6 +921,8 @@ def prompt_factory(
for message in messages:
message.pop("name", None)
return messages
elif custom_llm_provider == "azure_text":
return azure_text_pt(messages=messages)
try:
if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages)

View file

@ -64,6 +64,7 @@ from .llms import (
)
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.prompt_templates.factory import (
prompt_factory,
@ -96,6 +97,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
####### COMPLETION ENDPOINTS ################
@ -254,6 +256,7 @@ async def acompletion(
if (
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral"
@ -800,6 +803,71 @@ def completion(
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
)
elif custom_llm_provider == "azure_text":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
)
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
)
azure_ad_token = optional_params.get("extra_body", {}).pop(
"azure_ad_token", None
) or get_secret("AZURE_AD_TOKEN")
headers = headers or litellm.headers
## LOAD CONFIG - if set
config = litellm.AzureOpenAIConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## COMPLETION CALL
response = azure_text_completions.completion(
model=model,
messages=messages,
headers=headers,
api_key=api_key,
api_base=api_base,
api_version=api_version,
api_type=api_type,
azure_ad_token=azure_ad_token,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(

View file

@ -1152,6 +1152,16 @@ def test_completion_azure_key_completion_arg():
# test_completion_azure_key_completion_arg()
def test_azure_instruct():
litellm.set_verbose = True
response = completion(
model="azure_text/instruct-model",
messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
max_tokens=10,
)
print("response", response)
async def test_re_use_azure_async_client():
try:
print("azure gpt-3.5 ASYNC with clie nttest\n\n")