feat(main.py): add support for image generation endpoint

Krrish Dholakia 2023-12-16 21:07:29 -08:00
parent 7847ae1e23
commit 13d088b72e
7 changed files with 366 additions and 9 deletions


@@ -545,6 +545,52 @@ class TextCompletionResponse(OpenAIObject):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
 
+class ImageResponse(OpenAIObject):
+    created: Optional[int] = None
+    data: Optional[list] = None
+
+    def __init__(self, created=None, data=None, response_ms=None):
+        # normalize falsy inputs to None before handing them to the pydantic base
+        _response_ms = response_ms or None
+        data = data or None
+        created = created or None
+        super().__init__(data=data, created=created)
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except Exception:
+            # if using pydantic v1
+            return self.dict()
+
 ############################################################
 def print_verbose(print_statement):
     try:
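Taken together, the class gives callers both attribute-style and dict-style access to the image payload. A minimal usage sketch (values illustrative; assumes `ImageResponse` is importable from `litellm.utils`):

from litellm.utils import ImageResponse

resp = ImageResponse(created=1702783997, data=[{"url": "https://example.com/img.png"}])

assert "data" in resp                    # __contains__ delegates to hasattr
assert resp["created"] == resp.created   # __getitem__ mirrors attribute access
resp["data"] = []                        # __setitem__ mirrors attribute assignment
print(resp.get("usage", "n/a"))          # dict-style .get() with a default
print(resp.json())                       # model_dump() on pydantic v2, dict() on v1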
@@ -561,6 +607,8 @@ class CallTypes(Enum):
     completion = 'completion'
     acompletion = 'acompletion'
     aembedding = 'aembedding'
+    image_generation = 'image_generation'
+    aimage_generation = 'aimage_generation'
 
 # Logging function -> log the exact model details + what's being sent | Non-Blocking
 class Logging:
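The two new members mirror the sync/async endpoint function names, so the decorator can compare `original_function.__name__` against `CallTypes.<member>.value` directly. A standalone sketch of that dispatch (the `image_generation` stub below is hypothetical):

from enum import Enum

class CallTypes(Enum):
    completion = 'completion'
    acompletion = 'acompletion'
    aembedding = 'aembedding'
    image_generation = 'image_generation'
    aimage_generation = 'aimage_generation'

def image_generation(prompt, model=None):  # stand-in for the real endpoint function
    ...

call_type = image_generation.__name__
assert call_type == CallTypes.image_generation.value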
@@ -1499,7 +1547,7 @@ def client(original_function):
         # CRASH REPORTING TELEMETRY
         crash_reporting(*args, **kwargs)
         # INIT LOGGER - for user-specified integrations
-        model = args[0] if len(args) > 0 else kwargs["model"]
+        model = args[0] if len(args) > 0 else kwargs.get("model", None)
         call_type = original_function.__name__
         if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
             messages = None
@@ -1512,6 +1560,8 @@ def client(original_function):
                 rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model)
         elif call_type == CallTypes.embedding.value or call_type == CallTypes.aembedding.value:
             messages = args[1] if len(args) > 1 else kwargs["input"]
+        elif call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value:
+            messages = args[0] if len(args) > 0 else kwargs["prompt"]
         stream = True if "stream" in kwargs and kwargs["stream"] == True else False
         logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
         return logging_obj
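For image generation, the "messages" handed to the logger are just the prompt string, and the model may arrive via kwargs rather than a positional argument. A small sketch of the extraction above, with hypothetical call values:

args = ()
kwargs = {"prompt": "a watercolor fox", "model": "dall-e-3"}  # hypothetical call

# kwargs.get() keeps this from raising KeyError when no model is passed
model = args[0] if len(args) > 0 else kwargs.get("model", None)
messages = args[0] if len(args) > 0 else kwargs["prompt"]
print(model, "|", messages)  # dall-e-3 | a watercolor fox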
@@ -1560,7 +1610,9 @@ def client(original_function):
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
-            raise ValueError("model param not passed in.")
+            call_type = original_function.__name__
+            if call_type != CallTypes.image_generation.value:
+                raise ValueError("model param not passed in.")
 
         try:
             if logging_obj is None:
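The extra guard means a missing `model` param is now tolerated for image generation instead of raising immediately. A condensed sketch of that control flow, with plain strings standing in for the enum:

kwargs = {"prompt": "a watercolor fox"}   # note: no "model" key
call_type = "image_generation"            # i.e. original_function.__name__
try:
    model = kwargs["model"]
except KeyError:
    if call_type != "image_generation":
        raise ValueError("model param not passed in.")
    model = None  # image generation proceeds without an explicit model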
@@ -1614,7 +1666,7 @@ def client(original_function):
                 return result
 
             ### POST-CALL RULES ###
-            post_call_processing(original_response=result, model=model)
+            post_call_processing(original_response=result, model=model or None)
 
             # [OPTIONAL] ADD TO CACHE
             if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types:
@@ -2207,6 +2259,47 @@ def get_litellm_params(
     return litellm_params
 
+def get_optional_params_image_gen(
+    n: Optional[int] = None,
+    quality: Optional[str] = None,
+    response_format: Optional[str] = None,
+    size: Optional[str] = None,
+    style: Optional[str] = None,
+    user: Optional[str] = None,
+    custom_llm_provider: Optional[str] = None,
+    **kwargs,
+):
+    # retrieve all parameters passed to the function
+    passed_params = locals()
+    custom_llm_provider = passed_params.pop("custom_llm_provider")
+    special_params = passed_params.pop("kwargs")
+    for k, v in special_params.items():
+        passed_params[k] = v
+
+    default_params = {
+        "n": None,
+        "quality": None,
+        "response_format": None,
+        "size": None,
+        "style": None,
+        "user": None,
+    }
+
+    non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
+
+    ## raise exception if a non-default value is passed for a non-openai/azure image generation call
+    if custom_llm_provider != "openai" and custom_llm_provider != "azure":
+        if len(non_default_params.keys()) > 0:
+            if litellm.drop_params is True:  # drop the unsupported non-default values
+                keys = list(non_default_params.keys())
+                for k in keys:
+                    non_default_params.pop(k, None)
+                return non_default_params
+            raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.")
+
+    final_params = {**non_default_params, **kwargs}
+    return final_params
+
 def get_optional_params_embeddings(
     # 2 optional params
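A sketch of how the new helper behaves on its two paths (provider names here are just examples; any provider other than "openai"/"azure" takes the unsupported-params branch):

import litellm
from litellm.utils import get_optional_params_image_gen

# openai/azure: non-default values pass straight through
params = get_optional_params_image_gen(n=1, size="1024x1024", custom_llm_provider="openai")
# -> {"n": 1, "size": "1024x1024"}

# any other provider with litellm.drop_params = True: values are silently dropped
litellm.drop_params = True
params = get_optional_params_image_gen(n=1, custom_llm_provider="bedrock")
# -> {}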
@@ -2854,7 +2947,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
     # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, etc.)
     ## openai - chatcompletion + text completion
-    if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model:
+    if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model or model in litellm.openai_image_generation_models:
         custom_llm_provider = "openai"
     elif model in litellm.open_ai_text_completion_models:
         custom_llm_provider = "text-completion-openai"
@@ -3801,7 +3894,7 @@ def convert_to_streaming_response(response_object: Optional[dict]=None):
         yield model_response_object
 
-def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion", stream = False):
+def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse, ImageResponse]]=None, response_type: Literal["completion", "embedding", "image_generation"] = "completion", stream = False):
     try:
         if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)):
             if response_object is None or model_response_object is None:
@@ -3863,6 +3956,20 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
             model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)  # type: ignore
             return model_response_object
+        elif response_type == "image_generation" and (model_response_object is None or isinstance(model_response_object, ImageResponse)):
+            if response_object is None:
+                raise Exception("Error in response object format")
+            if model_response_object is None:
+                model_response_object = ImageResponse()
+            if "created" in response_object:
+                model_response_object.created = response_object["created"]
+            if "data" in response_object:
+                model_response_object.data = response_object["data"]
+            return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {e}")