forked from phoenix/litellm-mirror
feat(main.py): add support for image generation endpoint
This commit is contained in:
parent
7847ae1e23
commit
13d088b72e
7 changed files with 366 additions and 9 deletions
|
@ -132,8 +132,7 @@ for key, value in model_cost.items():
|
|||
elif value.get('litellm_provider') == 'anthropic':
|
||||
anthropic_models.append(key)
|
||||
elif value.get('litellm_provider') == 'openrouter':
|
||||
split_string = key.split('/', 1)
|
||||
openrouter_models.append(split_string[1])
|
||||
openrouter_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-text-models':
|
||||
vertex_text_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
|
||||
|
@ -366,6 +365,13 @@ bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-en
|
|||
|
||||
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
|
||||
|
||||
####### IMAGE GENERATION MODELS ###################
|
||||
openai_image_generation_models = [
|
||||
"dall-e-2",
|
||||
"dall-e-3"
|
||||
]
|
||||
|
||||
|
||||
from .timeout import timeout
|
||||
from .utils import (
|
||||
client,
|
||||
|
|
|
@ -456,6 +456,67 @@ class AzureChatCompletion(BaseLLM):
|
|||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
raise e
|
||||
else:
|
||||
import traceback
|
||||
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
|
||||
|
||||
def image_generation(self,
|
||||
prompt: list,
|
||||
timeout: float,
|
||||
model: Optional[str]=None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
model_response: Optional[litellm.utils.ImageResponse] = None,
|
||||
logging_obj=None,
|
||||
optional_params=None,
|
||||
client=None,
|
||||
aimg_generation=None,
|
||||
):
|
||||
exception_mapping_worked = False
|
||||
try:
|
||||
model = model
|
||||
data = {
|
||||
# "model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params
|
||||
}
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
|
||||
|
||||
# if aembedding == True:
|
||||
# response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||
# return response
|
||||
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
response = azure_client.images.generate(**data) # type: ignore
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
# return response
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
|
||||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
raise e
|
||||
|
|
|
@ -445,6 +445,66 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
import traceback
|
||||
raise OpenAIError(status_code=500, message=traceback.format_exc())
|
||||
|
||||
def image_generation(self,
|
||||
model: Optional[str],
|
||||
prompt: str,
|
||||
timeout: float,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
model_response: Optional[litellm.utils.ImageResponse] = None,
|
||||
logging_obj=None,
|
||||
optional_params=None,
|
||||
client=None,
|
||||
aimg_generation=None,
|
||||
):
|
||||
exception_mapping_worked = False
|
||||
try:
|
||||
model = model
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params
|
||||
}
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise OpenAIError(status_code=422, message="max retries must be an int")
|
||||
|
||||
# if aembedding == True:
|
||||
# response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||
# return response
|
||||
|
||||
if client is None:
|
||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=openai_client.api_key,
|
||||
additional_args={"headers": {"Authorization": f"Bearer {openai_client.api_key}"}, "api_base": openai_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
response = openai_client.images.generate(**data) # type: ignore
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
# return response
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
|
||||
except OpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
raise e
|
||||
else:
|
||||
import traceback
|
||||
raise OpenAIError(status_code=500, message=traceback.format_exc())
|
||||
|
||||
class OpenAITextCompletion(BaseLLM):
|
||||
_client_session: httpx.Client
|
||||
|
|
|
@ -33,7 +33,8 @@ from litellm.utils import (
|
|||
convert_to_model_response_object,
|
||||
token_counter,
|
||||
Usage,
|
||||
get_optional_params_embeddings
|
||||
get_optional_params_embeddings,
|
||||
get_optional_params_image_gen
|
||||
)
|
||||
from .llms import (
|
||||
anthropic,
|
||||
|
@ -2237,6 +2238,91 @@ def moderation(input: str, api_key: Optional[str]=None):
|
|||
response = openai.moderations.create(input=input)
|
||||
return response
|
||||
|
||||
##### Image Generation #######################
|
||||
@client
|
||||
def image_generation(prompt: str,
|
||||
model: Optional[str]=None,
|
||||
n: Optional[int]=None,
|
||||
quality: Optional[str]=None,
|
||||
response_format: Optional[str]=None,
|
||||
size: Optional[str]=None,
|
||||
style: Optional[str]=None,
|
||||
user: Optional[str]=None,
|
||||
timeout=600, # default to 10 minutes
|
||||
api_key: Optional[str]=None,
|
||||
api_base: Optional[str]=None,
|
||||
api_version: Optional[str] = None,
|
||||
litellm_logging_obj=None,
|
||||
custom_llm_provider=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Maps the https://api.openai.com/v1/images/generations endpoint.
|
||||
|
||||
Currently supports just Azure + OpenAI.
|
||||
"""
|
||||
litellm_call_id = kwargs.get("litellm_call_id", None)
|
||||
logger_fn = kwargs.get("logger_fn", None)
|
||||
proxy_server_request = kwargs.get('proxy_server_request', None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
metadata = kwargs.get("metadata", {})
|
||||
|
||||
model_response = litellm.utils.ImageResponse()
|
||||
if model is not None or custom_llm_provider is not None:
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
|
||||
else:
|
||||
model = "dall-e-2"
|
||||
custom_llm_provider = "openai" # default to dall-e-2 on openai
|
||||
openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"]
|
||||
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
|
||||
optional_params = get_optional_params_image_gen(n=n,
|
||||
quality=quality,
|
||||
response_format=response_format,
|
||||
size=size,
|
||||
style=style,
|
||||
user=user,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**non_default_params)
|
||||
logging = litellm_logging_obj
|
||||
logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": False, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "preset_cache_key": None, "stream_response": {}})
|
||||
|
||||
if custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||
|
||||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret("AZURE_API_BASE")
|
||||
)
|
||||
|
||||
api_version = (
|
||||
api_version or
|
||||
litellm.api_version or
|
||||
get_secret("AZURE_API_VERSION")
|
||||
)
|
||||
|
||||
api_key = (
|
||||
api_key or
|
||||
litellm.api_key or
|
||||
litellm.azure_key or
|
||||
get_secret("AZURE_OPENAI_API_KEY") or
|
||||
get_secret("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
azure_ad_token = (
|
||||
optional_params.pop("azure_ad_token", None) or
|
||||
get_secret("AZURE_AD_TOKEN")
|
||||
)
|
||||
|
||||
# model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
|
||||
pass
|
||||
elif custom_llm_provider == "openai":
|
||||
model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
|
||||
|
||||
return model_response
|
||||
|
||||
####### HELPER FUNCTIONS ################
|
||||
## Set verbose to true -> ```litellm.set_verbose = True```
|
||||
def print_verbose(print_statement):
|
||||
|
|
|
@ -650,7 +650,7 @@ def test_completion_azure_key_completion_arg():
|
|||
except Exception as e:
|
||||
os.environ["AZURE_API_KEY"] = old_key
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_azure_key_completion_arg()
|
||||
test_completion_azure_key_completion_arg()
|
||||
|
||||
|
||||
async def test_re_use_azure_async_client():
|
||||
|
|
37
litellm/tests/test_image_generation.py
Normal file
37
litellm/tests/test_image_generation.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# What this tests?
|
||||
## This tests the litellm support for the openai /generations endpoint
|
||||
|
||||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
|
||||
def test_image_generation_openai():
|
||||
litellm.set_verbose = True
|
||||
response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
|
||||
print(f"response: {response}")
|
||||
|
||||
# test_image_generation_openai()
|
||||
|
||||
# def test_image_generation_azure():
|
||||
# response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure")
|
||||
# print(f"response: {response}")
|
||||
# test_image_generation_azure()
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_async_image_generation_openai():
|
||||
# response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
|
||||
# print(f"response: {response}")
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_async_image_generation_azure():
|
||||
# response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3")
|
||||
# print(f"response: {response}")
|
117
litellm/utils.py
117
litellm/utils.py
|
@ -545,6 +545,52 @@ class TextCompletionResponse(OpenAIObject):
|
|||
# Allow dictionary-style assignment of attributes
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class ImageResponse(OpenAIObject):
|
||||
created: Optional[int] = None
|
||||
|
||||
data: Optional[list] = None
|
||||
|
||||
def __init__(self, created=None, data=None, response_ms=None):
|
||||
if response_ms:
|
||||
_response_ms = response_ms
|
||||
else:
|
||||
_response_ms = None
|
||||
if data:
|
||||
data = data
|
||||
else:
|
||||
data = None
|
||||
|
||||
if created:
|
||||
created = created
|
||||
else:
|
||||
created = None
|
||||
|
||||
super().__init__(data=data, created=created)
|
||||
|
||||
def __contains__(self, key):
|
||||
# Define custom behavior for the 'in' operator
|
||||
return hasattr(self, key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
|
||||
return getattr(self, key, default)
|
||||
|
||||
def __getitem__(self, key):
|
||||
# Allow dictionary-style access to attributes
|
||||
return getattr(self, key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# Allow dictionary-style assignment of attributes
|
||||
setattr(self, key, value)
|
||||
|
||||
def json(self, **kwargs):
|
||||
try:
|
||||
return self.model_dump() # noqa
|
||||
except:
|
||||
# if using pydantic v1
|
||||
return self.dict()
|
||||
|
||||
############################################################
|
||||
def print_verbose(print_statement):
|
||||
try:
|
||||
|
@ -561,6 +607,8 @@ class CallTypes(Enum):
|
|||
completion = 'completion'
|
||||
acompletion = 'acompletion'
|
||||
aembedding = 'aembedding'
|
||||
image_generation = 'image_generation'
|
||||
aimage_generation = 'aimage_generation'
|
||||
|
||||
# Logging function -> log the exact model details + what's being sent | Non-Blocking
|
||||
class Logging:
|
||||
|
@ -1499,7 +1547,7 @@ def client(original_function):
|
|||
# CRASH REPORTING TELEMETRY
|
||||
crash_reporting(*args, **kwargs)
|
||||
# INIT LOGGER - for user-specified integrations
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
model = args[0] if len(args) > 0 else kwargs.get("model", None)
|
||||
call_type = original_function.__name__
|
||||
if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
|
||||
messages = None
|
||||
|
@ -1512,6 +1560,8 @@ def client(original_function):
|
|||
rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model)
|
||||
elif call_type == CallTypes.embedding.value or call_type == CallTypes.aembedding.value:
|
||||
messages = args[1] if len(args) > 1 else kwargs["input"]
|
||||
elif call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value:
|
||||
messages = args[0] if len(args) > 0 else kwargs["prompt"]
|
||||
stream = True if "stream" in kwargs and kwargs["stream"] == True else False
|
||||
logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
|
||||
return logging_obj
|
||||
|
@ -1560,7 +1610,9 @@ def client(original_function):
|
|||
try:
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
except:
|
||||
raise ValueError("model param not passed in.")
|
||||
call_type = original_function.__name__
|
||||
if call_type != CallTypes.image_generation.value:
|
||||
raise ValueError("model param not passed in.")
|
||||
|
||||
try:
|
||||
if logging_obj is None:
|
||||
|
@ -1614,7 +1666,7 @@ def client(original_function):
|
|||
return result
|
||||
|
||||
### POST-CALL RULES ###
|
||||
post_call_processing(original_response=result, model=model)
|
||||
post_call_processing(original_response=result, model=model or None)
|
||||
|
||||
# [OPTIONAL] ADD TO CACHE
|
||||
if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types:
|
||||
|
@ -2207,6 +2259,47 @@ def get_litellm_params(
|
|||
|
||||
return litellm_params
|
||||
|
||||
def get_optional_params_image_gen(
|
||||
n: Optional[int]=None,
|
||||
quality: Optional[str]=None,
|
||||
response_format: Optional[str]=None,
|
||||
size: Optional[str]=None,
|
||||
style: Optional[str]=None,
|
||||
user: Optional[str]=None,
|
||||
custom_llm_provider: Optional[str]=None,
|
||||
**kwargs
|
||||
):
|
||||
# retrieve all parameters passed to the function
|
||||
passed_params = locals()
|
||||
custom_llm_provider = passed_params.pop("custom_llm_provider")
|
||||
special_params = passed_params.pop("kwargs")
|
||||
for k, v in special_params.items():
|
||||
passed_params[k] = v
|
||||
|
||||
default_params = {
|
||||
"n": None,
|
||||
"quality" : None,
|
||||
"response_format" : None,
|
||||
"size": None,
|
||||
"style": None,
|
||||
"user": None,
|
||||
}
|
||||
|
||||
non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
|
||||
## raise exception if non-default value passed for non-openai/azure embedding calls
|
||||
if custom_llm_provider != "openai" and custom_llm_provider != "azure":
|
||||
if len(non_default_params.keys()) > 0:
|
||||
if litellm.drop_params is True: # drop the unsupported non-default values
|
||||
keys = list(non_default_params.keys())
|
||||
for k in keys:
|
||||
non_default_params.pop(k, None)
|
||||
return non_default_params
|
||||
raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.")
|
||||
|
||||
final_params = {**non_default_params, **kwargs}
|
||||
return final_params
|
||||
|
||||
|
||||
|
||||
def get_optional_params_embeddings(
|
||||
# 2 optional params
|
||||
|
@ -2854,7 +2947,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
|
|||
|
||||
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
|
||||
## openai - chatcompletion + text completion
|
||||
if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model:
|
||||
if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model or model in litellm.openai_image_generation_models:
|
||||
custom_llm_provider = "openai"
|
||||
elif model in litellm.open_ai_text_completion_models:
|
||||
custom_llm_provider = "text-completion-openai"
|
||||
|
@ -3801,7 +3894,7 @@ def convert_to_streaming_response(response_object: Optional[dict]=None):
|
|||
yield model_response_object
|
||||
|
||||
|
||||
def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion", stream = False):
|
||||
def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse, ImageResponse]]=None, response_type: Literal["completion", "embedding", "image_generation"] = "completion", stream = False):
|
||||
try:
|
||||
if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)):
|
||||
if response_object is None or model_response_object is None:
|
||||
|
@ -3863,6 +3956,20 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
|
|||
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
|
||||
|
||||
|
||||
return model_response_object
|
||||
elif response_type == "image_generation" and (model_response_object is None or isinstance(model_response_object, ImageResponse)):
|
||||
if response_object is None:
|
||||
raise Exception("Error in response object format")
|
||||
|
||||
if model_response_object is None:
|
||||
model_response_object = EmbeddingResponse()
|
||||
|
||||
if "created" in response_object:
|
||||
model_response_object.created = response_object["created"]
|
||||
|
||||
if "data" in response_object:
|
||||
model_response_object.data = response_object["data"]
|
||||
|
||||
return model_response_object
|
||||
except Exception as e:
|
||||
raise Exception(f"Invalid response object {e}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue