Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
feat(main.py): add support for image generation endpoint
This commit is contained in:
parent 7847ae1e23
commit 13d088b72e
7 changed files with 366 additions and 9 deletions
litellm/utils.py (117 changed lines)
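Only the litellm/utils.py portion of the diff is reproduced below; the commit message indicates the endpoint itself is wired up in main.py, which is not shown here. As a rough, hypothetical usage sketch (assuming the entry point is exposed as litellm.image_generation with OpenAI-style parameters), a call against the new endpoint might look like this:

import litellm

# Hypothetical sketch only - the actual entry point lives in main.py, which is not
# part of the utils.py diff shown below.
response = litellm.image_generation(
    prompt="a watercolor painting of a lighthouse at dusk",
    model="dall-e-3",  # assumed to resolve to the "openai" provider
    n=1,
    size="1024x1024",
)
print(response["data"])  # the new ImageResponse class supports dictionary-style access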
@@ -545,6 +545,52 @@ class TextCompletionResponse(OpenAIObject):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
 
+
+class ImageResponse(OpenAIObject):
+    created: Optional[int] = None
+
+    data: Optional[list] = None
+
+    def __init__(self, created=None, data=None, response_ms=None):
+        if response_ms:
+            _response_ms = response_ms
+        else:
+            _response_ms = None
+        if data:
+            data = data
+        else:
+            data = None
+
+        if created:
+            created = created
+        else:
+            created = None
+
+        super().__init__(data=data, created=created)
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 ############################################################
 def print_verbose(print_statement):
     try:
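The ImageResponse class above mirrors the dictionary-style helpers already present on the other response objects. A small sketch of how those helpers behave, based only on the methods added in this hunk (sample values are made up):

# Sketch based on the ImageResponse methods added above.
img_resp = ImageResponse(created=1701234567, data=[{"url": "https://example.com/img.png"}])

print("data" in img_resp)          # True  -> __contains__ delegates to hasattr
print(img_resp["created"])         # 1701234567 -> __getitem__ delegates to getattr
print(img_resp.get("missing", 0))  # 0 -> get() falls back to the default
img_resp["data"] = []              # __setitem__ delegates to setattr
print(img_resp.json())             # model_dump() on pydantic v2, dict() on pydantic v1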
@@ -561,6 +607,8 @@ class CallTypes(Enum):
     completion = 'completion'
     acompletion = 'acompletion'
     aembedding = 'aembedding'
+    image_generation = 'image_generation'
+    aimage_generation = 'aimage_generation'
 
 # Logging function -> log the exact model details + what's being sent | Non-Blocking
 class Logging:
@@ -1499,7 +1547,7 @@ def client(original_function):
             # CRASH REPORTING TELEMETRY
             crash_reporting(*args, **kwargs)
             # INIT LOGGER - for user-specified integrations
-            model = args[0] if len(args) > 0 else kwargs["model"]
+            model = args[0] if len(args) > 0 else kwargs.get("model", None)
             call_type = original_function.__name__
             if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
                 messages = None
@@ -1512,6 +1560,8 @@ def client(original_function):
                     rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model)
             elif call_type == CallTypes.embedding.value or call_type == CallTypes.aembedding.value:
                 messages = args[1] if len(args) > 1 else kwargs["input"]
+            elif call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value:
+                messages = args[0] if len(args) > 0 else kwargs["prompt"]
             stream = True if "stream" in kwargs and kwargs["stream"] == True else False
             logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
             return logging_obj
@@ -1560,7 +1610,9 @@ def client(original_function):
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
-            raise ValueError("model param not passed in.")
+            call_type = original_function.__name__
+            if call_type != CallTypes.image_generation.value:
+                raise ValueError("model param not passed in.")
 
         try:
             if logging_obj is None:
@@ -1614,7 +1666,7 @@ def client(original_function):
                 return result
 
             ### POST-CALL RULES ###
-            post_call_processing(original_response=result, model=model)
+            post_call_processing(original_response=result, model=model or None)
 
             # [OPTIONAL] ADD TO CACHE
             if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types:
@@ -2207,6 +2259,47 @@ def get_litellm_params(
 
     return litellm_params
 
+def get_optional_params_image_gen(
+    n: Optional[int]=None,
+    quality: Optional[str]=None,
+    response_format: Optional[str]=None,
+    size: Optional[str]=None,
+    style: Optional[str]=None,
+    user: Optional[str]=None,
+    custom_llm_provider: Optional[str]=None,
+    **kwargs
+):
+    # retrieve all parameters passed to the function
+    passed_params = locals()
+    custom_llm_provider = passed_params.pop("custom_llm_provider")
+    special_params = passed_params.pop("kwargs")
+    for k, v in special_params.items():
+        passed_params[k] = v
+
+    default_params = {
+        "n": None,
+        "quality" : None,
+        "response_format" : None,
+        "size": None,
+        "style": None,
+        "user": None,
+    }
+
+    non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
+    ## raise exception if non-default value passed for non-openai/azure embedding calls
+    if custom_llm_provider != "openai" and custom_llm_provider != "azure":
+        if len(non_default_params.keys()) > 0:
+            if litellm.drop_params is True: # drop the unsupported non-default values
+                keys = list(non_default_params.keys())
+                for k in keys:
+                    non_default_params.pop(k, None)
+                return non_default_params
+            raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.")
+
+    final_params = {**non_default_params, **kwargs}
+    return final_params
+
+
 def get_optional_params_embeddings(
     # 2 optional params
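To illustrate the filtering in get_optional_params_image_gen above: for the "openai" and "azure" providers all non-default values are passed through, while for any other provider they are either dropped (when litellm.drop_params is True) or rejected with UnsupportedParamsError. A rough sketch of the expected behaviour ("bedrock" is just a stand-in for a non-OpenAI provider):

# Illustrative only - mirrors the branches of get_optional_params_image_gen above.
params = get_optional_params_image_gen(n=2, size="512x512", custom_llm_provider="openai")
# -> {"n": 2, "size": "512x512"}   non-default values are kept for openai/azure

litellm.drop_params = True
params = get_optional_params_image_gen(n=2, custom_llm_provider="bedrock")
# -> {}   unsupported non-default values are silently dropped

litellm.drop_params = False
# get_optional_params_image_gen(n=2, custom_llm_provider="bedrock")
# -> raises UnsupportedParamsError, asking the caller to set litellm.drop_params = True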
@@ -2854,7 +2947,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
 
         # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
         ## openai - chatcompletion + text completion
-        if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model:
+        if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model or model in litellm.openai_image_generation_models:
             custom_llm_provider = "openai"
         elif model in litellm.open_ai_text_completion_models:
             custom_llm_provider = "text-completion-openai"
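With the extra check above, models registered in litellm.openai_image_generation_models now resolve to the "openai" provider. A hedged sketch, assuming "dall-e-3" is in that list and that get_llm_provider still returns a (model, custom_llm_provider, dynamic_api_key, api_base) tuple:

# Assumes "dall-e-3" appears in litellm.openai_image_generation_models.
model, provider, dynamic_api_key, api_base = get_llm_provider(model="dall-e-3")
print(provider)  # "openai"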
@@ -3801,7 +3894,7 @@ def convert_to_streaming_response(response_object: Optional[dict]=None):
         yield model_response_object
 
 
-def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion", stream = False):
+def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse, ImageResponse]]=None, response_type: Literal["completion", "embedding", "image_generation"] = "completion", stream = False):
     try:
         if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)):
             if response_object is None or model_response_object is None:
@@ -3863,6 +3956,20 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
                 model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
 
 
             return model_response_object
+        elif response_type == "image_generation" and (model_response_object is None or isinstance(model_response_object, ImageResponse)):
+            if response_object is None:
+                raise Exception("Error in response object format")
+
+            if model_response_object is None:
+                model_response_object = EmbeddingResponse()
+
+            if "created" in response_object:
+                model_response_object.created = response_object["created"]
+
+            if "data" in response_object:
+                model_response_object.data = response_object["data"]
+
+            return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {e}")
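The new image_generation branch in convert_to_model_response_object above only copies the created and data fields from the raw provider response. A minimal sketch of mapping an OpenAI-style images payload (sample values are made up):

# Sketch of the "image_generation" branch added above.
raw_response = {
    "created": 1701234567,
    "data": [{"url": "https://example.com/generated.png"}],
}
image_response = convert_to_model_response_object(
    response_object=raw_response,
    model_response_object=ImageResponse(),
    response_type="image_generation",
)
print(image_response.created)  # 1701234567
print(image_response.data)     # [{"url": "https://example.com/generated.png"}]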