mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge d78ee3182d
into b82af5b826
This commit is contained in:
commit
4fef15901b
7 changed files with 684 additions and 13 deletions
|
@ -140,6 +140,7 @@ from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
|||
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
|
||||
from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
|
||||
from .llms.deprecated_providers import aleph_alpha, palm
|
||||
from .llms.diffusers.diffusers import DiffusersImageHandler
|
||||
from .llms.groq.chat.handler import GroqChatCompletion
|
||||
from .llms.huggingface.embedding.handler import HuggingFaceEmbedding
|
||||
from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
|
||||
|
@ -228,6 +229,7 @@ codestral_text_completions = CodestralTextCompletion()
|
|||
bedrock_converse_chat_completion = BedrockConverseLLM()
|
||||
bedrock_embedding = BedrockEmbedding()
|
||||
bedrock_image_generation = BedrockImageGeneration()
|
||||
diffusers_image_generation = DiffusersImageHandler()
|
||||
vertex_chat_completion = VertexLLM()
|
||||
vertex_embedding = VertexEmbedding()
|
||||
vertex_multimodal_embedding = VertexMultimodalEmbedding()
|
||||
|
@ -4564,7 +4566,7 @@ async def aimage_generation(*args, **kwargs) -> ImageResponse:
|
|||
|
||||
|
||||
@client
|
||||
def image_generation( # noqa: PLR0915
|
||||
def image_generation(
|
||||
prompt: str,
|
||||
model: Optional[str] = None,
|
||||
n: Optional[int] = None,
|
||||
|
@ -4573,45 +4575,75 @@ def image_generation( # noqa: PLR0915
|
|||
size: Optional[str] = None,
|
||||
style: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
timeout=600, # default to 10 minutes
|
||||
timeout=600,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
custom_llm_provider=None,
|
||||
device: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Maps the https://api.openai.com/v1/images/generations endpoint.
|
||||
|
||||
Currently supports just Azure + OpenAI.
|
||||
Handles image generation for various providers including local Diffusers models.
|
||||
"""
|
||||
try:
|
||||
args = locals()
|
||||
aimg_generation = kwargs.get("aimg_generation", False)
|
||||
litellm_call_id = kwargs.get("litellm_call_id", None)
|
||||
logger_fn = kwargs.get("logger_fn", None)
|
||||
mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore
|
||||
mock_response = kwargs.get("mock_response", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
metadata = kwargs.get("metadata", {})
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
|
||||
litellm_logging_obj = kwargs.get("litellm_logging_obj")
|
||||
client = kwargs.get("client", None)
|
||||
extra_headers = kwargs.get("extra_headers", None)
|
||||
headers: dict = kwargs.get("headers", None) or {}
|
||||
headers = kwargs.get("headers", None) or {}
|
||||
|
||||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
model_response: ImageResponse = litellm.utils.ImageResponse()
|
||||
|
||||
model_response = litellm.utils.ImageResponse()
|
||||
|
||||
# Get model provider info
|
||||
if model is not None or custom_llm_provider is not None:
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
|
||||
model=model, # type: ignore
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
)
|
||||
else:
|
||||
model = "dall-e-2"
|
||||
custom_llm_provider = "openai" # default to dall-e-2 on openai
|
||||
custom_llm_provider = "openai"
|
||||
|
||||
model_response._hidden_params["model"] = model
|
||||
|
||||
# Handle Diffusers/local models
|
||||
if model.startswith("diffusers/") or custom_llm_provider == "diffusers":
|
||||
from .llms.diffusers.diffusers import DiffusersImageHandler
|
||||
|
||||
model_path = model.replace("diffusers/", "")
|
||||
width, height = (512, 512)
|
||||
if size:
|
||||
width, height = map(int, size.split("x"))
|
||||
|
||||
handler = DiffusersImageHandler()
|
||||
diffusers_response = handler.generate_image(
|
||||
prompt=prompt,
|
||||
model=model_path,
|
||||
height=height,
|
||||
width=width,
|
||||
num_images_per_prompt=n or 1,
|
||||
device=device, # Pass through device parameter
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
model_response.created = diffusers_response.created
|
||||
model_response.data = diffusers_response.data
|
||||
return model_response
|
||||
|
||||
# Original provider handling remains the same
|
||||
openai_params = [
|
||||
"user",
|
||||
"request_timeout",
|
||||
|
@ -4629,11 +4661,12 @@ def image_generation( # noqa: PLR0915
|
|||
"size",
|
||||
"style",
|
||||
]
|
||||
|
||||
litellm_params = all_litellm_params
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {
|
||||
k: v for k, v in kwargs.items() if k not in default_params
|
||||
} # model-specific params - pass them straight to the model/provider
|
||||
}
|
||||
|
||||
optional_params = get_optional_params_image_gen(
|
||||
model=model,
|
||||
|
@ -4649,7 +4682,7 @@ def image_generation( # noqa: PLR0915
|
|||
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
|
||||
logging: Logging = litellm_logging_obj
|
||||
logging = litellm_logging_obj
|
||||
logging.update_environment_variables(
|
||||
model=model,
|
||||
user=user,
|
||||
|
@ -4667,6 +4700,7 @@ def image_generation( # noqa: PLR0915
|
|||
},
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
if "custom_llm_provider" not in logging.model_call_details:
|
||||
logging.model_call_details["custom_llm_provider"] = custom_llm_provider
|
||||
if mock_response is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue