commit 4fef15901b
Author: Christopher (committed via GitHub)
Date: 2025-04-24 00:56:18 -07:00
7 changed files with 684 additions and 13 deletions

@@ -140,6 +140,7 @@ from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
 from .llms.deprecated_providers import aleph_alpha, palm
+from .llms.diffusers.diffusers import DiffusersImageHandler
 from .llms.groq.chat.handler import GroqChatCompletion
 from .llms.huggingface.embedding.handler import HuggingFaceEmbedding
 from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
@@ -228,6 +229,7 @@ codestral_text_completions = CodestralTextCompletion()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 bedrock_embedding = BedrockEmbedding()
 bedrock_image_generation = BedrockImageGeneration()
+diffusers_image_generation = DiffusersImageHandler()
 vertex_chat_completion = VertexLLM()
 vertex_embedding = VertexEmbedding()
 vertex_multimodal_embedding = VertexMultimodalEmbedding()
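
The DiffusersImageHandler imported and instantiated above lives in litellm/llms/diffusers/diffusers.py, one of the 7 changed files but not shown in this diff. For orientation, here is a minimal sketch of what such a handler could look like, assuming the Hugging Face diffusers package; the class body is illustrative, not the committed implementation. The only contract taken from the diff is the generate_image(...) call signature used further down and a return object exposing .created and .data:

import base64
import io
import time
from types import SimpleNamespace
from typing import Optional

import torch
from diffusers import StableDiffusionPipeline


class DiffusersImageHandler:
    """Illustrative sketch: load a text-to-image pipeline per model and cache it."""

    def __init__(self):
        self._pipelines = {}  # (model, device) -> loaded pipeline

    def _get_pipeline(self, model: str, device: str):
        key = (model, device)
        if key not in self._pipelines:
            # Accepts a Hugging Face repo id or a local path
            pipe = StableDiffusionPipeline.from_pretrained(model)
            self._pipelines[key] = pipe.to(device)
        return self._pipelines[key]

    def generate_image(
        self,
        prompt: str,
        model: str,
        height: int = 512,
        width: int = 512,
        num_images_per_prompt: int = 1,
        device: Optional[str] = None,
        **kwargs,  # litellm forwards its remaining kwargs here; ignored in this sketch
    ):
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        pipe = self._get_pipeline(model, device)
        images = pipe(
            prompt,
            height=height,
            width=width,
            num_images_per_prompt=num_images_per_prompt,
        ).images  # list of PIL.Image
        data = []
        for img in images:
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            data.append({"b64_json": base64.b64encode(buf.getvalue()).decode("utf-8")})
        # Matches the two attributes image_generation() reads off the response
        return SimpleNamespace(created=int(time.time()), data=data)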
@@ -4564,7 +4566,7 @@ async def aimage_generation(*args, **kwargs) -> ImageResponse:
 @client
-def image_generation(  # noqa: PLR0915
+def image_generation(
     prompt: str,
     model: Optional[str] = None,
     n: Optional[int] = None,
@@ -4573,45 +4575,75 @@ def image_generation(  # noqa: PLR0915
     size: Optional[str] = None,
     style: Optional[str] = None,
     user: Optional[str] = None,
-    timeout=600,  # default to 10 minutes
+    timeout=600,
     api_key: Optional[str] = None,
     api_base: Optional[str] = None,
     api_version: Optional[str] = None,
     custom_llm_provider=None,
+    device: Optional[str] = None,
     **kwargs,
 ) -> ImageResponse:
     """
     Maps the https://api.openai.com/v1/images/generations endpoint.
-    Currently supports just Azure + OpenAI.
+    Handles image generation for various providers including local Diffusers models.
     """
     try:
         args = locals()
         aimg_generation = kwargs.get("aimg_generation", False)
         litellm_call_id = kwargs.get("litellm_call_id", None)
         logger_fn = kwargs.get("logger_fn", None)
-        mock_response: Optional[str] = kwargs.get("mock_response", None)  # type: ignore
+        mock_response = kwargs.get("mock_response", None)
         proxy_server_request = kwargs.get("proxy_server_request", None)
         azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
         model_info = kwargs.get("model_info", None)
         metadata = kwargs.get("metadata", {})
-        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
+        litellm_logging_obj = kwargs.get("litellm_logging_obj")
         client = kwargs.get("client", None)
         extra_headers = kwargs.get("extra_headers", None)
-        headers: dict = kwargs.get("headers", None) or {}
+        headers = kwargs.get("headers", None) or {}
         if extra_headers is not None:
            headers.update(extra_headers)
-        model_response: ImageResponse = litellm.utils.ImageResponse()
+        model_response = litellm.utils.ImageResponse()
+        # Get model provider info
         if model is not None or custom_llm_provider is not None:
             model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
-                model=model,  # type: ignore
+                model=model,
                 custom_llm_provider=custom_llm_provider,
                 api_base=api_base,
             )
         else:
             model = "dall-e-2"
-            custom_llm_provider = "openai"  # default to dall-e-2 on openai
+            custom_llm_provider = "openai"
         model_response._hidden_params["model"] = model
+        # Handle Diffusers/local models
+        if model.startswith("diffusers/") or custom_llm_provider == "diffusers":
+            from .llms.diffusers.diffusers import DiffusersImageHandler
+            model_path = model.replace("diffusers/", "")
+            width, height = (512, 512)
+            if size:
+                width, height = map(int, size.split("x"))
+            handler = DiffusersImageHandler()
+            diffusers_response = handler.generate_image(
+                prompt=prompt,
+                model=model_path,
+                height=height,
+                width=width,
+                num_images_per_prompt=n or 1,
+                device=device,  # Pass through device parameter
+                **kwargs,
+            )
+            model_response.created = diffusers_response.created
+            model_response.data = diffusers_response.data
+            return model_response
+        # Original provider handling remains the same
         openai_params = [
             "user",
             "request_timeout",
@@ -4629,11 +4661,12 @@ def image_generation(  # noqa: PLR0915
             "size",
             "style",
         ]
         litellm_params = all_litellm_params
         default_params = openai_params + litellm_params
         non_default_params = {
             k: v for k, v in kwargs.items() if k not in default_params
-        }  # model-specific params - pass them straight to the model/provider
+        }
         optional_params = get_optional_params_image_gen(
             model=model,
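
The non_default_params comprehension in this hunk is what lets provider-specific options ride through **kwargs: anything not in the OpenAI-style parameter list or litellm's internal all_litellm_params list is passed straight to the provider. A small worked example (the two extra keys are illustrative diffusers options, not taken from the diff):

kwargs = {"user": "u-123", "guidance_scale": 7.5, "num_inference_steps": 30}
default_params = ["user", "request_timeout", "size", "style"]  # abbreviated
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
# -> {"guidance_scale": 7.5, "num_inference_steps": 30}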
@@ -4649,7 +4682,7 @@ def image_generation(  # noqa: PLR0915
         litellm_params_dict = get_litellm_params(**kwargs)
-        logging: Logging = litellm_logging_obj
+        logging = litellm_logging_obj
         logging.update_environment_variables(
             model=model,
             user=user,
@@ -4667,6 +4700,7 @@ def image_generation(  # noqa: PLR0915
             },
             custom_llm_provider=custom_llm_provider,
         )
+        if "custom_llm_provider" not in logging.model_call_details:
+            logging.model_call_details["custom_llm_provider"] = custom_llm_provider
         if mock_response is not None: