From 51cb16a01506e585a3384d76bb6522a12e1029ff Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 16 Dec 2023 21:07:29 -0800
Subject: [PATCH] feat(main.py): add support for image generation endpoint

---
 litellm/__init__.py                    |  10 ++-
 litellm/llms/azure.py                  |  61 +++++++++++++
 litellm/llms/openai.py                 |  60 +++++++++++++
 litellm/main.py                        |  88 ++++++++++++++++++-
 litellm/tests/test_completion.py       |   2 +-
 litellm/tests/test_image_generation.py |  37 ++++++++
 litellm/utils.py                       | 117 +++++++++++++++++++++++--
 7 files changed, 366 insertions(+), 9 deletions(-)
 create mode 100644 litellm/tests/test_image_generation.py

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 36ce31cfa1..bd7e6b11f9 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -132,8 +132,7 @@ for key, value in model_cost.items():
     elif value.get('litellm_provider') == 'anthropic':
         anthropic_models.append(key)
     elif value.get('litellm_provider') == 'openrouter':
-        split_string = key.split('/', 1)
-        openrouter_models.append(split_string[1])
+        openrouter_models.append(key)
     elif value.get('litellm_provider') == 'vertex_ai-text-models':
         vertex_text_models.append(key)
     elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
@@ -366,6 +365,13 @@ bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-en
 
 all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
 
+####### IMAGE GENERATION MODELS ###################
+openai_image_generation_models = [
+    "dall-e-2",
+    "dall-e-3"
+]
+
+
 from .timeout import timeout
 from .utils import (
     client,
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 2be62ebc1e..e785294902 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -456,6 +456,67 @@ class AzureChatCompletion(BaseLLM):
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
+        except Exception as e:
+            if exception_mapping_worked:
+                raise e
+            else:
+                import traceback
+                raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
+
+    def image_generation(self,
+                         prompt: str,
+                         timeout: float,
+                         model: Optional[str] = None,
+                         api_key: Optional[str] = None,
+                         api_base: Optional[str] = None,
+                         model_response: Optional[litellm.utils.ImageResponse] = None,
+                         logging_obj=None,
+                         optional_params=None,
+                         client=None,
+                         aimg_generation=None,
+                         ):
+        exception_mapping_worked = False
+        try:
+            model = model
+            data = {
+                # "model": model,
+                "prompt": prompt,
+                **optional_params
+            }
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise AzureOpenAIError(status_code=422, message="max retries must be an int")
+
+            # if aembedding == True:
+            #     response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
+            #     return response
+
+            if client is None:
+                azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)  # type: ignore
+            else:
+                azure_client = client
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=prompt,
+                api_key=azure_client.api_key,
+                additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data},
+            )
+
+            ## COMPLETION CALL
+            response = azure_client.images.generate(**data)  # type: ignore
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=response,
+            )
+            # return response
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation")  # type: ignore
+        except AzureOpenAIError as e:
+            exception_mapping_worked = True
+            raise e
         except Exception as e:
             if exception_mapping_worked:
                 raise e
             else:
                 import traceback
                 raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 4fdecce109..c923cbf2dd 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -445,6 +445,66 @@ class OpenAIChatCompletion(BaseLLM):
                 import traceback
                 raise OpenAIError(status_code=500, message=traceback.format_exc())
 
+    def image_generation(self,
+                         model: Optional[str],
+                         prompt: str,
+                         timeout: float,
+                         api_key: Optional[str] = None,
+                         api_base: Optional[str] = None,
+                         model_response: Optional[litellm.utils.ImageResponse] = None,
+                         logging_obj=None,
+                         optional_params=None,
+                         client=None,
+                         aimg_generation=None,
+                         ):
+        exception_mapping_worked = False
+        try:
+            model = model
+            data = {
+                "model": model,
+                "prompt": prompt,
+                **optional_params
+            }
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise OpenAIError(status_code=422, message="max retries must be an int")
+
+            # if aembedding == True:
+            #     response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
+            #     return response
+
+            if client is None:
+                openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
+            else:
+                openai_client = client
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=prompt,
+                api_key=openai_client.api_key,
+                additional_args={"headers": {"Authorization": f"Bearer {openai_client.api_key}"}, "api_base": openai_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data},
+            )
+
+            ## COMPLETION CALL
+            response = openai_client.images.generate(**data)  # type: ignore
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=response,
+            )
+            # return response
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation")  # type: ignore
+        except OpenAIError as e:
+            exception_mapping_worked = True
+            raise e
+        except Exception as e:
+            if exception_mapping_worked:
+                raise e
+            else:
+                import traceback
+                raise OpenAIError(status_code=500, message=traceback.format_exc())
 
 class OpenAITextCompletion(BaseLLM):
     _client_session: httpx.Client
diff --git a/litellm/main.py b/litellm/main.py
index d666cfb2c8..6bd1d08035 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -33,7 +33,8 @@ from litellm.utils import (
     convert_to_model_response_object,
     token_counter,
     Usage,
-    get_optional_params_embeddings
+    get_optional_params_embeddings,
+    get_optional_params_image_gen
 )
 from .llms import (
     anthropic,
@@ -2237,6 +2238,91 @@ def moderation(input: str, api_key: Optional[str]=None):
     response = openai.moderations.create(input=input)
     return response
 
+##### Image Generation #######################
+@client
+def image_generation(prompt: str,
+                     model: Optional[str]=None,
+                     n: Optional[int]=None,
+                     quality: Optional[str]=None,
+                     response_format: Optional[str]=None,
+                     size: Optional[str]=None,
+                     style: Optional[str]=None,
+                     user: Optional[str]=None,
+                     timeout=600,  # default to 10 minutes
+                     api_key: Optional[str]=None,
+                     api_base: Optional[str]=None,
+                     api_version: Optional[str] = None,
+                     litellm_logging_obj=None,
+                     custom_llm_provider=None,
+                     **kwargs):
+    """
+    Maps the https://api.openai.com/v1/images/generations endpoint.
+
+    Currently supports just Azure + OpenAI.
+    """
+    litellm_call_id = kwargs.get("litellm_call_id", None)
+    logger_fn = kwargs.get("logger_fn", None)
+    proxy_server_request = kwargs.get('proxy_server_request', None)
+    model_info = kwargs.get("model_info", None)
+    metadata = kwargs.get("metadata", {})
+
+    model_response = litellm.utils.ImageResponse()
+    if model is not None or custom_llm_provider is not None:
+        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    else:
+        model = "dall-e-2"
+        custom_llm_provider = "openai"  # default to dall-e-2 on openai
+    openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"]
+    litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
+    default_params = openai_params + litellm_params
+    non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}  # model-specific params - pass them straight to the model/provider
+    optional_params = get_optional_params_image_gen(n=n,
+                                                    quality=quality,
+                                                    response_format=response_format,
+                                                    size=size,
+                                                    style=style,
+                                                    user=user,
+                                                    custom_llm_provider=custom_llm_provider,
+                                                    **non_default_params)
+    logging = litellm_logging_obj
+    logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": False, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "preset_cache_key": None, "stream_response": {}})
+
+    if custom_llm_provider == "azure":
+        # azure configs
+        api_type = get_secret("AZURE_API_TYPE") or "azure"
+
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("AZURE_API_BASE")
+        )
+
+        api_version = (
+            api_version or
+            litellm.api_version or
+            get_secret("AZURE_API_VERSION")
+        )
+
+        api_key = (
+            api_key or
+            litellm.api_key or
+            litellm.azure_key or
+            get_secret("AZURE_OPENAI_API_KEY") or
+            get_secret("AZURE_API_KEY")
+        )
+
+        azure_ad_token = (
+            optional_params.pop("azure_ad_token", None) or
+            get_secret("AZURE_AD_TOKEN")
+        )
+
+        # model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response=model_response)
+        pass
+    elif custom_llm_provider == "openai":
+        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response=model_response)
+
+    return model_response
+
 ####### HELPER FUNCTIONS ################
 ## Set verbose to true -> ```litellm.set_verbose = True```
 def print_verbose(print_statement):
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index d8babc0ca2..1e41550624 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -650,7 +650,7 @@ def test_completion_azure_key_completion_arg():
     except Exception as e:
         os.environ["AZURE_API_KEY"] = old_key
         pytest.fail(f"Error occurred: {e}")
-# test_completion_azure_key_completion_arg()
+test_completion_azure_key_completion_arg()
 
 
 async def test_re_use_azure_async_client():
diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py
new file mode 100644
index 0000000000..a265c0f65e
--- /dev/null
+++ b/litellm/tests/test_image_generation.py
@@ -0,0 +1,37 @@
+# What this tests?
+## This tests the litellm support for the openai /generations endpoint
+
+import sys, os
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+
+def test_image_generation_openai():
+    litellm.set_verbose = True
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
+    print(f"response: {response}")
+
+# test_image_generation_openai()
+
+# def test_image_generation_azure():
+#     response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure")
+#     print(f"response: {response}")
+# test_image_generation_azure()
+
+# @pytest.mark.asyncio
+# async def test_async_image_generation_openai():
+#     response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
+#     print(f"response: {response}")
+
+# @pytest.mark.asyncio
+# async def test_async_image_generation_azure():
+#     response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3")
+#     print(f"response: {response}")
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index fce23ee720..ab22c200dc 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -545,6 +545,52 @@ class TextCompletionResponse(OpenAIObject):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
 
+
+class ImageResponse(OpenAIObject):
+    created: Optional[int] = None
+
+    data: Optional[list] = None
+
+    def __init__(self, created=None, data=None, response_ms=None):
+        if response_ms:
+            _response_ms = response_ms
+        else:
+            _response_ms = None
+        if data:
+            data = data
+        else:
+            data = None
+
+        if created:
+            created = created
+        else:
+            created = None
+
+        super().__init__(data=data, created=created)
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 ############################################################
 def print_verbose(print_statement):
     try:
@@ -561,6 +607,8 @@ class CallTypes(Enum):
     completion = 'completion'
     acompletion = 'acompletion'
     aembedding = 'aembedding'
+    image_generation = 'image_generation'
+    aimage_generation = 'aimage_generation'
 
 # Logging function -> log the exact model details + what's being sent | Non-Blocking
 class Logging:
@@ -1499,7 +1547,7 @@ def client(original_function):
             # CRASH REPORTING TELEMETRY
             crash_reporting(*args, **kwargs)
             # INIT LOGGER - for user-specified integrations
-            model = args[0] if len(args) > 0 else kwargs["model"]
+            model = args[0] if len(args) > 0 else kwargs.get("model", None)
             call_type = original_function.__name__
             if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
                 messages = None
@@ -1512,6 +1560,8 @@ def client(original_function):
                     rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model)
             elif call_type == CallTypes.embedding.value or call_type == CallTypes.aembedding.value:
                 messages = args[1] if len(args) > 1 else kwargs["input"]
+            elif call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value:
+                messages = args[0] if len(args) > 0 else kwargs["prompt"]
             stream = True if "stream" in kwargs and kwargs["stream"] == True else False
             logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
             return logging_obj
@@ -1560,7 +1610,9 @@ def client(original_function):
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
-            raise ValueError("model param not passed in.")
+            call_type = original_function.__name__
+            if call_type != CallTypes.image_generation.value:
+                raise ValueError("model param not passed in.")
 
         try:
             if logging_obj is None:
@@ -1614,7 +1666,7 @@ def client(original_function):
                 return result
 
             ### POST-CALL RULES ###
-            post_call_processing(original_response=result, model=model)
+            post_call_processing(original_response=result, model=model or None)
 
             # [OPTIONAL] ADD TO CACHE
             if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types:
@@ -2207,6 +2259,47 @@ def get_litellm_params(
 
     return litellm_params
 
+
+def get_optional_params_image_gen(
+    n: Optional[int]=None,
+    quality: Optional[str]=None,
+    response_format: Optional[str]=None,
+    size: Optional[str]=None,
+    style: Optional[str]=None,
+    user: Optional[str]=None,
+    custom_llm_provider: Optional[str]=None,
+    **kwargs
+):
+    # retrieve all parameters passed to the function
+    passed_params = locals()
+    custom_llm_provider = passed_params.pop("custom_llm_provider")
+    special_params = passed_params.pop("kwargs")
+    for k, v in special_params.items():
+        passed_params[k] = v
+
+    default_params = {
+        "n": None,
+        "quality": None,
+        "response_format": None,
+        "size": None,
+        "style": None,
+        "user": None,
+    }
+
+    non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
+    ## raise exception if non-default value passed for non-openai/azure image generation calls
+    if custom_llm_provider != "openai" and custom_llm_provider != "azure":
+        if len(non_default_params.keys()) > 0:
+            if litellm.drop_params is True:  # drop the unsupported non-default values
+                keys = list(non_default_params.keys())
+                for k in keys:
+                    non_default_params.pop(k, None)
+                return non_default_params
+            raise UnsupportedParamsError(status_code=500, message=f"Setting optional image generation params is not supported by {custom_llm_provider}. To drop them from the call, set `litellm.drop_params = True`.")
+
+    final_params = {**non_default_params, **kwargs}
+    return final_params
+
+
 def get_optional_params_embeddings(
     # 2 optional params
@@ -2854,7 +2947,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
         # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
         ## openai - chatcompletion + text completion
-        if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model:
+        if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model or model in litellm.openai_image_generation_models:
             custom_llm_provider = "openai"
         elif model in litellm.open_ai_text_completion_models:
             custom_llm_provider = "text-completion-openai"
@@ -3801,7 +3894,7 @@ def convert_to_streaming_response(response_object: Optional[dict]=None):
 
     yield model_response_object
 
-def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion", stream = False):
+def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse, ImageResponse]]=None, response_type: Literal["completion", "embedding", "image_generation"] = "completion", stream = False):
     try:
         if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)):
             if response_object is None or model_response_object is None:
@@ -3863,6 +3956,20 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
 
                 model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)  # type: ignore
 
+            return model_response_object
+        elif response_type == "image_generation" and (model_response_object is None or isinstance(model_response_object, ImageResponse)):
+            if response_object is None:
+                raise Exception("Error in response object format")
+
+            if model_response_object is None:
+                model_response_object = ImageResponse()
+
+            if "created" in response_object:
+                model_response_object.created = response_object["created"]
+
+            if "data" in response_object:
+                model_response_object.data = response_object["data"]
+
             return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {e}")
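For readers evaluating this patch, below is a minimal usage sketch of the new `litellm.image_generation()` entrypoint it introduces. This snippet is not part of the committed diff: it assumes `OPENAI_API_KEY` is set in the environment, and the `size`/`n` values are illustrative OpenAI image parameters forwarded through `get_optional_params_image_gen()`, not defaults chosen by litellm.

# Minimal sketch (not part of the patch): calling the image_generation() function added above.
# Assumes OPENAI_API_KEY is exported; if neither `model` nor `custom_llm_provider` is passed,
# the call defaults to "dall-e-2" on the OpenAI provider.
import litellm

response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="dall-e-3",      # resolved to the "openai" provider by get_llm_provider()
    size="1024x1024",      # illustrative OpenAI param, passed via get_optional_params_image_gen()
    n=1,
)

# ImageResponse mirrors the OpenAI /v1/images/generations payload: `created` + `data`
print(response.created)    # unix timestamp returned by the API
print(response.data)       # list of generated image objects (e.g. url or b64_json entries)

Providers other than "openai" and "azure" reject these optional params unless `litellm.drop_params = True`, in which case the unsupported values are silently dropped before the call.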