feat(main.py): add support for image generation endpoint

Krrish Dholakia 2023-12-16 21:07:29 -08:00
parent 7847ae1e23
commit 13d088b72e
7 changed files with 366 additions and 9 deletions

litellm/main.py

@@ -33,7 +33,8 @@ from litellm.utils import (
     convert_to_model_response_object,
     token_counter,
     Usage,
-    get_optional_params_embeddings
+    get_optional_params_embeddings,
+    get_optional_params_image_gen
 )
 from .llms import (
     anthropic,
@@ -2237,6 +2238,91 @@ def moderation(input: str, api_key: Optional[str]=None):
     response = openai.moderations.create(input=input)
     return response
+
+##### Image Generation #######################
+@client
+def image_generation(prompt: str,
+                     model: Optional[str] = None,
+                     n: Optional[int] = None,
+                     quality: Optional[str] = None,
+                     response_format: Optional[str] = None,
+                     size: Optional[str] = None,
+                     style: Optional[str] = None,
+                     user: Optional[str] = None,
+                     timeout=600,  # default to 10 minutes
+                     api_key: Optional[str] = None,
+                     api_base: Optional[str] = None,
+                     api_version: Optional[str] = None,
+                     litellm_logging_obj=None,
+                     custom_llm_provider=None,
+                     **kwargs):
+    """
+    Maps the https://api.openai.com/v1/images/generations endpoint.
+
+    Currently supports just Azure + OpenAI.
+    """
+    litellm_call_id = kwargs.get("litellm_call_id", None)
+    logger_fn = kwargs.get("logger_fn", None)
+    proxy_server_request = kwargs.get("proxy_server_request", None)
+    model_info = kwargs.get("model_info", None)
+    metadata = kwargs.get("metadata", {})
+
+    model_response = litellm.utils.ImageResponse()
+    if model is not None or custom_llm_provider is not None:
+        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    else:
+        model = "dall-e-2"
+        custom_llm_provider = "openai"  # default to dall-e-2 on openai
+    openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"]
+    litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
+    default_params = openai_params + litellm_params
+    non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}  # model-specific params - pass them straight to the model/provider
+    optional_params = get_optional_params_image_gen(n=n,
+                                                    quality=quality,
+                                                    response_format=response_format,
+                                                    size=size,
+                                                    style=style,
+                                                    user=user,
+                                                    custom_llm_provider=custom_llm_provider,
+                                                    **non_default_params)
+    logging = litellm_logging_obj
+    logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": False, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "preset_cache_key": None, "stream_response": {}})
+
+    if custom_llm_provider == "azure":
+        # azure configs
+        api_type = get_secret("AZURE_API_TYPE") or "azure"
+
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("AZURE_API_BASE")
+        )
+
+        api_version = (
+            api_version or
+            litellm.api_version or
+            get_secret("AZURE_API_VERSION")
+        )
+
+        api_key = (
+            api_key or
+            litellm.api_key or
+            litellm.azure_key or
+            get_secret("AZURE_OPENAI_API_KEY") or
+            get_secret("AZURE_API_KEY")
+        )
+
+        azure_ad_token = (
+            optional_params.pop("azure_ad_token", None) or
+            get_secret("AZURE_AD_TOKEN")
+        )
+
+        # model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response=model_response)
+        pass
+    elif custom_llm_provider == "openai":
+        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response=model_response)
+
+    return model_response
 ####### HELPER FUNCTIONS ################
 ## Set verbose to true -> ```litellm.set_verbose = True```
 def print_verbose(print_statement):
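
Note that the azure branch resolves credentials but still ends in `pass`: the `azure_chat_completions.image_generation` call is left commented out, so Azure requests return the empty `ImageResponse` for now. A minimal usage sketch of the new OpenAI path (not part of this commit; placeholder key and made-up prompt, and it assumes `main.py`'s public functions are re-exported at the package top level as with `completion`/`embedding`):

```python
# Minimal sketch (not from this commit): exercising the new endpoint against OpenAI.
# Leaving `model`/`custom_llm_provider` unset falls through to the dall-e-2/openai default branch.
import os
import litellm

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key

response = litellm.image_generation(
    prompt="a watercolor lighthouse at dusk",  # made-up prompt
    n=1,
    size="256x256",
)
print(response.data)  # assumed to mirror OpenAI's images payload (url / b64_json entries)
```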
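
And a standalone sketch of the kwargs-splitting step: any caller kwarg not found in `openai_params` / `litellm_params` is treated as a model-specific param and forwarded through `get_optional_params_image_gen(**non_default_params)`. The lists are abbreviated here, and `sampler` is a made-up provider param:

```python
# Illustrative only: how non_default_params is carved out of **kwargs.
openai_params = ["user", "n", "quality", "size", "style"]        # abbreviated
litellm_params = ["metadata", "api_key", "custom_llm_provider"]  # abbreviated
default_params = openai_params + litellm_params

# `sampler` is a made-up provider-specific kwarg; `metadata` is a recognized litellm param
kwargs = {"metadata": {"trace_id": "abc"}, "sampler": "ddim"}
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}

print(non_default_params)  # {'sampler': 'ddim'} -> passed straight through to the provider
```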