feat(main.py): add support for image generation endpoint

Krrish Dholakia 2023-12-16 21:07:29 -08:00
parent 7847ae1e23
commit 13d088b72e
7 changed files with 366 additions and 9 deletions
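The new entrypoint mirrors OpenAI's POST /v1/images/generations. A minimal usage sketch, based on the test added in this commit (assumes this commit is installed and OPENAI_API_KEY is set):

import litellm

# "dall-e-3" routes to the OpenAI provider and returns a litellm.utils.ImageResponse
response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
print(response)
print(response["data"])  # list of generated images (url or b64_json entries, per response_format)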

View file

@@ -132,8 +132,7 @@ for key, value in model_cost.items():
elif value.get('litellm_provider') == 'anthropic':
anthropic_models.append(key)
elif value.get('litellm_provider') == 'openrouter':
split_string = key.split('/', 1)
openrouter_models.append(split_string[1])
openrouter_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-text-models':
vertex_text_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
@@ -366,6 +365,13 @@ bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-en
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
####### IMAGE GENERATION MODELS ###################
openai_image_generation_models = [
"dall-e-2",
"dall-e-3"
]
from .timeout import timeout
from .utils import (
client,

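For context: the openai_image_generation_models list added above is what get_llm_provider (changed in the utils.py hunk further down) checks to route bare DALL-E model names to the openai provider. A small sketch, assuming get_llm_provider is importable from the package root as it is elsewhere in the codebase:

import litellm
from litellm import get_llm_provider

print(litellm.openai_image_generation_models)  # ["dall-e-2", "dall-e-3"]
model, provider, dynamic_api_key, api_base = get_llm_provider(model="dall-e-3")
print(provider)  # "openai"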
View file

@@ -456,6 +456,67 @@ class AzureChatCompletion(BaseLLM):
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
def image_generation(self,
prompt: str,
timeout: float,
model: Optional[str]=None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
logging_obj=None,
optional_params=None,
client=None,
aimg_generation=None,
):
exception_mapping_worked = False
try:
model = model
data = {
# "model": model,
"prompt": prompt,
**optional_params
}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
# if aembedding == True:
# response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
# return response
if client is None:
azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) # type: ignore
else:
azure_client = client
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data},
)
## COMPLETION CALL
response = azure_client.images.generate(**data) # type: ignore
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
# return response
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())

View file

@@ -445,6 +445,66 @@ class OpenAIChatCompletion(BaseLLM):
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
def image_generation(self,
model: Optional[str],
prompt: str,
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
logging_obj=None,
optional_params=None,
client=None,
aimg_generation=None,
):
exception_mapping_worked = False
try:
model = model
data = {
"model": model,
"prompt": prompt,
**optional_params
}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
# if aembedding == True:
# response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
# return response
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else:
openai_client = client
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=openai_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {openai_client.api_key}"}, "api_base": openai_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
)
## COMPLETION CALL
response = openai_client.images.generate(**data) # type: ignore
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
# return response
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
except OpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
class OpenAITextCompletion(BaseLLM):
_client_session: httpx.Client

View file

@@ -33,7 +33,8 @@ from litellm.utils import (
convert_to_model_response_object,
token_counter,
Usage,
get_optional_params_embeddings
get_optional_params_embeddings,
get_optional_params_image_gen
)
from .llms import (
anthropic,
@@ -2237,6 +2238,91 @@ def moderation(input: str, api_key: Optional[str]=None):
response = openai.moderations.create(input=input)
return response
##### Image Generation #######################
@client
def image_generation(prompt: str,
model: Optional[str]=None,
n: Optional[int]=None,
quality: Optional[str]=None,
response_format: Optional[str]=None,
size: Optional[str]=None,
style: Optional[str]=None,
user: Optional[str]=None,
timeout=600, # default to 10 minutes
api_key: Optional[str]=None,
api_base: Optional[str]=None,
api_version: Optional[str] = None,
litellm_logging_obj=None,
custom_llm_provider=None,
**kwargs):
"""
Maps the https://api.openai.com/v1/images/generations endpoint.
Currently supports OpenAI; the Azure branch below is still a stub.
"""
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
proxy_server_request = kwargs.get('proxy_server_request', None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
model_response = litellm.utils.ImageResponse()
if model is not None or custom_llm_provider is not None:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
else:
model = "dall-e-2"
custom_llm_provider = "openai" # default to dall-e-2 on openai
openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"]
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
optional_params = get_optional_params_image_gen(n=n,
quality=quality,
response_format=response_format,
size=size,
style=style,
user=user,
custom_llm_provider=custom_llm_provider,
**non_default_params)
logging = litellm_logging_obj
logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": False, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "preset_cache_key": None, "stream_response": {}})
if custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
api_base = (
api_base
or litellm.api_base
or get_secret("AZURE_API_BASE")
)
api_version = (
api_version or
litellm.api_version or
get_secret("AZURE_API_VERSION")
)
api_key = (
api_key or
litellm.api_key or
litellm.azure_key or
get_secret("AZURE_OPENAI_API_KEY") or
get_secret("AZURE_API_KEY")
)
azure_ad_token = (
optional_params.pop("azure_ad_token", None) or
get_secret("AZURE_AD_TOKEN")
)
# model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
pass
elif custom_llm_provider == "openai":
model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
return model_response
####### HELPER FUNCTIONS ################
## Set verbose to true -> ```litellm.set_verbose = True```
def print_verbose(print_statement):
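When no model is passed, image_generation() falls back to "dall-e-2" on the openai provider, and the OpenAI-documented knobs (n, quality, response_format, size, style, user) are collected by get_optional_params_image_gen before being forwarded. A hedged sketch of that call shape (parameter values are illustrative only):

import litellm

response = litellm.image_generation(
    prompt="A watercolor map of the Pacific Ocean",
    n=1,                    # number of images to generate
    size="1024x1024",
    response_format="url",  # or "b64_json"
)
print(response.data)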

View file

@@ -650,7 +650,7 @@ def test_completion_azure_key_completion_arg():
except Exception as e:
os.environ["AZURE_API_KEY"] = old_key
pytest.fail(f"Error occurred: {e}")
# test_completion_azure_key_completion_arg()
test_completion_azure_key_completion_arg()
async def test_re_use_azure_async_client():

View file

@@ -0,0 +1,37 @@
# What this tests?
## This tests the litellm support for the openai /generations endpoint
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
def test_image_generation_openai():
litellm.set_verbose = True
response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
print(f"response: {response}")
# test_image_generation_openai()
# def test_image_generation_azure():
# response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure")
# print(f"response: {response}")
# test_image_generation_azure()
# @pytest.mark.asyncio
# async def test_async_image_generation_openai():
# response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
# print(f"response: {response}")
# @pytest.mark.asyncio
# async def test_async_image_generation_azure():
# response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3")
# print(f"response: {response}")

View file

@@ -545,6 +545,52 @@ class TextCompletionResponse(OpenAIObject):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
class ImageResponse(OpenAIObject):
created: Optional[int] = None
data: Optional[list] = None
def __init__(self, created=None, data=None, response_ms=None):
if response_ms:
_response_ms = response_ms
else:
_response_ms = None
if data:
data = data
else:
data = None
if created:
created = created
else:
created = None
super().__init__(data=data, created=created)
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
def json(self, **kwargs):
try:
return self.model_dump() # noqa
except:
# if using pydantic v1
return self.dict()
############################################################
def print_verbose(print_statement):
try:
@@ -561,6 +607,8 @@ class CallTypes(Enum):
completion = 'completion'
acompletion = 'acompletion'
aembedding = 'aembedding'
image_generation = 'image_generation'
aimage_generation = 'aimage_generation'
# Logging function -> log the exact model details + what's being sent | Non-Blocking
class Logging:
@@ -1499,7 +1547,7 @@ def client(original_function):
# CRASH REPORTING TELEMETRY
crash_reporting(*args, **kwargs)
# INIT LOGGER - for user-specified integrations
model = args[0] if len(args) > 0 else kwargs["model"]
model = args[0] if len(args) > 0 else kwargs.get("model", None)
call_type = original_function.__name__
if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
messages = None
@@ -1512,6 +1560,8 @@ def client(original_function):
rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model)
elif call_type == CallTypes.embedding.value or call_type == CallTypes.aembedding.value:
messages = args[1] if len(args) > 1 else kwargs["input"]
elif call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value:
messages = args[0] if len(args) > 0 else kwargs["prompt"]
stream = True if "stream" in kwargs and kwargs["stream"] == True else False
logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
return logging_obj
@@ -1560,7 +1610,9 @@ def client(original_function):
try:
model = args[0] if len(args) > 0 else kwargs["model"]
except:
raise ValueError("model param not passed in.")
call_type = original_function.__name__
if call_type != CallTypes.image_generation.value:
raise ValueError("model param not passed in.")
try:
if logging_obj is None:
@@ -1614,7 +1666,7 @@ def client(original_function):
return result
### POST-CALL RULES ###
post_call_processing(original_response=result, model=model)
post_call_processing(original_response=result, model=model or None)
# [OPTIONAL] ADD TO CACHE
if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types:
@@ -2207,6 +2259,47 @@ def get_litellm_params(
return litellm_params
def get_optional_params_image_gen(
n: Optional[int]=None,
quality: Optional[str]=None,
response_format: Optional[str]=None,
size: Optional[str]=None,
style: Optional[str]=None,
user: Optional[str]=None,
custom_llm_provider: Optional[str]=None,
**kwargs
):
# retrieve all parameters passed to the function
passed_params = locals()
custom_llm_provider = passed_params.pop("custom_llm_provider")
special_params = passed_params.pop("kwargs")
for k, v in special_params.items():
passed_params[k] = v
default_params = {
"n": None,
"quality" : None,
"response_format" : None,
"size": None,
"style": None,
"user": None,
}
non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
## raise exception if non-default value passed for non-openai/azure image generation calls
if custom_llm_provider != "openai" and custom_llm_provider != "azure":
if len(non_default_params.keys()) > 0:
if litellm.drop_params is True: # drop the unsupported non-default values
keys = list(non_default_params.keys())
for k in keys:
non_default_params.pop(k, None)
return non_default_params
raise UnsupportedParamsError(status_code=500, message=f"Setting image generation parameters is not supported by {custom_llm_provider}. To drop them from the call, set `litellm.drop_params = True`.")
final_params = {**non_default_params, **kwargs}
return final_params
def get_optional_params_embeddings(
# 2 optional params
@@ -2854,7 +2947,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, etc.)
## openai - chatcompletion + text completion
if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model:
if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model or model in litellm.openai_image_generation_models:
custom_llm_provider = "openai"
elif model in litellm.open_ai_text_completion_models:
custom_llm_provider = "text-completion-openai"
@@ -3801,7 +3894,7 @@ def convert_to_streaming_response(response_object: Optional[dict]=None):
yield model_response_object
def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion", stream = False):
def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse, ImageResponse]]=None, response_type: Literal["completion", "embedding", "image_generation"] = "completion", stream = False):
try:
if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)):
if response_object is None or model_response_object is None:
@@ -3863,6 +3956,20 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
return model_response_object
elif response_type == "image_generation" and (model_response_object is None or isinstance(model_response_object, ImageResponse)):
if response_object is None:
raise Exception("Error in response object format")
if model_response_object is None:
model_response_object = ImageResponse()
if "created" in response_object:
model_response_object.created = response_object["created"]
if "data" in response_object:
model_response_object.data = response_object["data"]
return model_response_object
except Exception as e:
raise Exception(f"Invalid response object {e}")