Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Merge d78ee3182d into b82af5b826 (commit 4fef15901b)

7 changed files with 684 additions and 13 deletions

@@ -323,6 +323,8 @@ def get_llm_provider(  # noqa: PLR0915
             custom_llm_provider = "empower"
         elif model == "*":
             custom_llm_provider = "openai"
+        elif "diffusers" in model:
+            custom_llm_provider = "diffusers"
         if not custom_llm_provider:
             if litellm.suppress_debug_info is False:
                 print()  # noqa
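
The hunk above routes any model string containing "diffusers" to the new local provider. A minimal sketch of how that branch gets hit, assuming get_llm_provider remains importable from the top-level litellm package and using an example checkpoint id:

# Sketch, not part of the diff: substring routing for the new provider.
from litellm import get_llm_provider

model, provider, _, _ = get_llm_provider(
    model="diffusers/runwayml/stable-diffusion-v1-5"
)
print(provider)  # under this patch, expected to resolve to "diffusers"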

litellm/llms/diffusers/diffusers.py (new file)
@@ -0,0 +1,123 @@
from typing import Optional, Union, List, Dict

import io
import base64
import logging
import time

try:
    from PIL import Image
    from diffusers import StableDiffusionPipeline
except ModuleNotFoundError:
    pass
from pydantic import BaseModel

# Module-level logger used for the device-fallback warning below
verbose_logger = logging.getLogger(__name__)


class ImageResponse(BaseModel):
    created: int
    data: List[Dict[str, str]]  # List of dicts with "b64_json" or "url"


class DiffusersImageHandler:
    def __init__(self):
        self.pipeline_cache = {}  # Cache loaded pipelines
        self.device = self._get_default_device()

    def _get_default_device(self):
        """Determine the best available device"""
        import torch

        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():  # For Apple Silicon
            return "mps"
        else:
            return "cpu"

    def _load_pipeline(
        self, model: str, device: Optional[str] = None
    ) -> StableDiffusionPipeline:
        """Load and cache diffusion pipeline"""
        device = device or self.device

        if model not in self.pipeline_cache:
            try:
                pipe = StableDiffusionPipeline.from_pretrained(model)
                pipe = pipe.to(device)
                self.pipeline_cache[model] = pipe
            except RuntimeError as e:
                if "CUDA" in str(e):
                    # Fallback to CPU if CUDA fails
                    verbose_logger.warning(f"Falling back to CPU: {str(e)}")
                    pipe = pipe.to("cpu")
                    self.pipeline_cache[model] = pipe
                else:
                    raise

        return self.pipeline_cache[model]

    def _image_to_b64(self, image: Image.Image) -> str:
        """Convert PIL Image to base64 string"""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def generate_image(
        self, prompt: str, model: str, num_images_per_prompt: int = 1, **kwargs
    ) -> ImageResponse:
        # Get or create pipeline
        if model not in self.pipeline_cache:
            from diffusers import StableDiffusionPipeline

            self.pipeline_cache[model] = StableDiffusionPipeline.from_pretrained(model)

        pipe = self.pipeline_cache[model]

        # Generate images
        images = pipe(
            prompt=prompt, num_images_per_prompt=num_images_per_prompt, **kwargs
        ).images

        # Convert to base64
        image_data = []
        for img in images:
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            image_data.append(
                {"b64_json": base64.b64encode(buffered.getvalue()).decode("utf-8")}
            )

        return ImageResponse(created=int(time.time()), data=image_data)

    def generate_variation(
        self,
        image: Union[Image.Image, str, bytes],  # Accepts PIL, file path, or bytes
        prompt: Optional[str] = None,
        model: str = "runwayml/stable-diffusion-v1-5",
        strength: float = 0.8,
        **kwargs,
    ) -> ImageResponse:
        """
        Generate variation of input image

        Args:
            image: Input image (PIL, file path, or bytes)
            prompt: Optional text prompt to guide variation
            model: Diffusers model ID
            strength: Strength of variation (0-1)

        Returns:
            ImageResponse with base64 encoded images
        """
        # Convert input to PIL Image
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, bytes):
            image = Image.open(io.BytesIO(image))

        pipe = self._load_pipeline(model)

        # Generate variation
        # NOTE: StableDiffusionPipeline is text-to-image; passing image/strength
        # requires an image-to-image pipeline (e.g. StableDiffusionImg2ImgPipeline).
        result = pipe(prompt=prompt, image=image, strength=strength, **kwargs)

        # Convert to response format
        image_data = [{"b64_json": self._image_to_b64(result.images[0])}]

        return ImageResponse(created=int(time.time()), data=image_data)
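
For orientation, a minimal sketch of exercising the handler above directly; it assumes diffusers, torch, and Pillow are installed and uses an example checkpoint id that would be downloaded on first use:

# Sketch, not part of the diff: direct use of DiffusersImageHandler.
from litellm.llms.diffusers.diffusers import DiffusersImageHandler

handler = DiffusersImageHandler()
resp = handler.generate_image(
    prompt="a watercolor fox",
    model="runwayml/stable-diffusion-v1-5",  # example checkpoint
    num_images_per_prompt=1,
)
print(resp.created, len(resp.data))  # ImageResponse holding base64-encoded PNGs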

litellm/llms/diffusers/fine_tuning/handler.py (new file)
@@ -0,0 +1,211 @@
from typing import Any, Coroutine, Optional, Union, Dict, List
import logging
import time

try:
    from dataclasses import dataclass
    import torch
    from diffusers import UNet2DConditionModel
    from diffusers.optimization import get_scheduler
    from transformers import CLIPTextModel, CLIPTokenizer
except ImportError:
    pass

verbose_logger = logging.getLogger(__name__)


@dataclass
class FineTuningJob:
    id: str
    status: str
    model: str
    created_at: int
    hyperparameters: Dict[str, Any]
    result_files: List[str]


class DiffusersFineTuningAPI:
    """
    Diffusers implementation for fine-tuning stable diffusion models locally
    """

    def __init__(self) -> None:
        self.jobs: Dict[str, FineTuningJob] = {}
        super().__init__()

    async def _train_diffusers_model(
        self,
        training_data: str,
        base_model: str = "stabilityai/stable-diffusion-2",
        output_dir: str = "./fine_tuned_model",
        learning_rate: float = 5e-6,
        train_batch_size: int = 1,
        max_train_steps: int = 500,
        gradient_accumulation_steps: int = 1,
        mixed_precision: str = "fp16",
    ) -> FineTuningJob:
        """Actual training implementation for diffusers"""
        job_id = f"ftjob_{len(self.jobs)+1}"
        job = FineTuningJob(
            id=job_id,
            status="running",
            model=base_model,
            created_at=int(time.time()),
            hyperparameters={
                "learning_rate": learning_rate,
                "batch_size": train_batch_size,
                "steps": max_train_steps,
            },
            result_files=[output_dir],
        )
        self.jobs[job_id] = job

        try:
            # Load models and create pipeline
            tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(
                base_model, subfolder="text_encoder"
            )
            unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet")

            # Optimizer and scheduler
            optimizer = torch.optim.AdamW(
                unet.parameters(),
                lr=learning_rate,
            )

            lr_scheduler = get_scheduler(
                "linear",
                optimizer=optimizer,
                num_warmup_steps=0,
                num_training_steps=max_train_steps,
            )

            # Training loop would go here
            # This is simplified - actual implementation would need:
            # 1. Data loading from training_data path
            # 2. Proper training loop with forward/backward passes
            # 3. Saving checkpoints

            # Simulate training
            for step in range(max_train_steps):
                if step % 10 == 0:
                    verbose_logger.debug(f"Training step {step}/{max_train_steps}")

            # Save the trained model
            unet.save_pretrained(f"{output_dir}/unet")
            text_encoder.save_pretrained(f"{output_dir}/text_encoder")

            job.status = "succeeded"
            return job

        except Exception as e:
            job.status = "failed"
            verbose_logger.error(f"Training failed: {str(e)}")
            raise

    async def acreate_fine_tuning_job(
        self,
        create_fine_tuning_job_data: dict,
    ) -> FineTuningJob:
        """Create a fine-tuning job asynchronously"""
        return await self._train_diffusers_model(**create_fine_tuning_job_data)

    def create_fine_tuning_job(
        self,
        _is_async: bool,
        create_fine_tuning_job_data: dict,
        **kwargs,
    ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
        """Create a fine-tuning job (sync or async)"""
        if _is_async:
            return self.acreate_fine_tuning_job(create_fine_tuning_job_data)
        else:
            # Run async code synchronously
            import asyncio

            return asyncio.run(
                self.acreate_fine_tuning_job(create_fine_tuning_job_data)
            )

    async def alist_fine_tuning_jobs(
        self,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        """List fine-tuning jobs asynchronously"""
        jobs = list(self.jobs.values())
        if after:
            jobs = [j for j in jobs if j.id > after]
        if limit:
            jobs = jobs[:limit]
        return {"data": jobs}

    def list_fine_tuning_jobs(
        self,
        _is_async: bool,
        after: Optional[str] = None,
        limit: Optional[int] = None,
        **kwargs,
    ):
        """List fine-tuning jobs (sync or async)"""
        if _is_async:
            return self.alist_fine_tuning_jobs(after=after, limit=limit)
        else:
            # Run async code synchronously
            import asyncio

            return asyncio.run(self.alist_fine_tuning_jobs(after=after, limit=limit))

    async def aretrieve_fine_tuning_job(
        self,
        fine_tuning_job_id: str,
    ) -> FineTuningJob:
        """Retrieve a fine-tuning job asynchronously"""
        if fine_tuning_job_id not in self.jobs:
            raise ValueError(f"Job {fine_tuning_job_id} not found")
        return self.jobs[fine_tuning_job_id]

    def retrieve_fine_tuning_job(
        self,
        _is_async: bool,
        fine_tuning_job_id: str,
        **kwargs,
    ):
        """Retrieve a fine-tuning job (sync or async)"""
        if _is_async:
            return self.aretrieve_fine_tuning_job(fine_tuning_job_id)
        else:
            # Run async code synchronously
            import asyncio

            return asyncio.run(self.aretrieve_fine_tuning_job(fine_tuning_job_id))

    async def acancel_fine_tuning_job(
        self,
        fine_tuning_job_id: str,
    ) -> FineTuningJob:
        """Cancel a fine-tuning job asynchronously"""
        if fine_tuning_job_id not in self.jobs:
            raise ValueError(f"Job {fine_tuning_job_id} not found")

        job = self.jobs[fine_tuning_job_id]
        if job.status in ["succeeded", "failed", "cancelled"]:
            raise ValueError(f"Cannot cancel job in status {job.status}")

        job.status = "cancelled"
        return job

    def cancel_fine_tuning_job(
        self,
        _is_async: bool,
        fine_tuning_job_id: str,
        **kwargs,
    ):
        """Cancel a fine-tuning job (sync or async)"""
        if _is_async:
            return self.acancel_fine_tuning_job(fine_tuning_job_id)
        else:
            # Run async code synchronously
            import asyncio

            return asyncio.run(self.acancel_fine_tuning_job(fine_tuning_job_id))
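
A minimal sketch of driving the fine-tuning API above from synchronous code; the training data path is hypothetical, the base model is the default from the signature (it would be downloaded by from_pretrained), and the step count is lowered only to keep the example cheap:

# Sketch, not part of the diff: create and inspect a local fine-tuning job.
from litellm.llms.diffusers.fine_tuning.handler import DiffusersFineTuningAPI

api = DiffusersFineTuningAPI()
job = api.create_fine_tuning_job(
    _is_async=False,
    create_fine_tuning_job_data={
        "training_data": "./train_images",  # hypothetical local dataset path
        "base_model": "stabilityai/stable-diffusion-2",
        "max_train_steps": 10,
    },
)
status = api.retrieve_fine_tuning_job(_is_async=False, fine_tuning_job_id=job.id).status
print(job.id, status)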

litellm/llms/diffusers/image_variations/handler.py (new file)
@@ -0,0 +1,62 @@
from typing import Union
from PIL import Image
import io
import base64

try:
    from diffusers import StableDiffusionPipeline
except ImportError:
    pass


class DiffusersImageHandler:
    def __init__(self):
        self.pipeline_cache = {}  # Cache loaded models
        self.device = self._get_default_device()

    def _get_default_device(self):
        """Determine the best available device"""
        import torch

        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():  # For Apple Silicon
            return "mps"
        else:
            return "cpu"

    def _load_pipeline(self, model: str, device: str = "cuda"):
        """Load and cache diffusion pipeline"""
        if model not in self.pipeline_cache:
            self.pipeline_cache[model] = StableDiffusionPipeline.from_pretrained(
                model
            ).to(device)
        return self.pipeline_cache[model]

    def generate_image(
        self,
        prompt: str,
        model: str = "runwayml/stable-diffusion-v1-5",
        device: str = "cuda",
        **kwargs
    ) -> Image.Image:
        """Generate image from text prompt"""
        pipe = self._load_pipeline(model, device)
        return pipe(prompt, **kwargs).images[0]

    def generate_variation(
        self,
        image: Union[Image.Image, str, bytes],  # Accepts PIL, file path, or bytes
        model: str = "runwayml/stable-diffusion-v1-5",
        device: str = "cuda",
        **kwargs
    ) -> Image.Image:
        """Generate variation of input image"""
        # Convert input to PIL Image
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, bytes):
            image = Image.open(io.BytesIO(image))

        pipe = self._load_pipeline(model, device)
        # NOTE: an image-to-image pipeline is needed for image= to be accepted.
        return pipe(image=image, **kwargs).images[0]

    def generate_to_bytes(self, *args, **kwargs) -> bytes:
        """Generate image and return as bytes"""
        img = self.generate_image(*args, **kwargs)
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return buffered.getvalue()

    def generate_to_b64(self, *args, **kwargs) -> str:
        """Generate image and return as base64"""
        return base64.b64encode(self.generate_to_bytes(*args, **kwargs)).decode("utf-8")
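
This lighter-weight handler returns raw PIL images or base64 strings rather than an ImageResponse. A minimal usage sketch, assuming a local CPU run and using an example checkpoint id:

# Sketch, not part of the diff: base64 output from the variations handler.
from litellm.llms.diffusers.image_variations.handler import DiffusersImageHandler

handler = DiffusersImageHandler()
b64_png = handler.generate_to_b64(
    prompt="a pencil sketch of a lighthouse",
    model="runwayml/stable-diffusion-v1-5",  # example checkpoint
    device="cpu",
)
print(len(b64_png))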

litellm/llms/diffusers/image_variations/transformation.py (new file)
@@ -0,0 +1,88 @@
from typing import Any, List, Optional
from PIL import Image
import io
import base64
import time

from litellm.llms.base_llm.image_variations.transformation import LiteLLMLoggingObj
from litellm.types.utils import FileTypes, ImageResponse

from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ..common_utils import LLMError


class DiffusersImageVariationConfig(BaseImageVariationConfig):
    def get_supported_diffusers_params(self) -> List[str]:
        """Return supported parameters for diffusers pipeline"""
        return [
            "prompt",
            "height",
            "width",
            "num_inference_steps",
            "guidance_scale",
            "negative_prompt",
            "num_images_per_prompt",
            "eta",
            "seed",
        ]

    def transform_request_image_variation(
        self,
        model: Optional[str],
        image: FileTypes,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Convert input to format expected by diffusers"""
        # Convert image to PIL if needed
        if not isinstance(image, Image.Image):
            if isinstance(image, str):  # file path
                image = Image.open(image)
            elif isinstance(image, bytes):  # raw bytes
                image = Image.open(io.BytesIO(image))

        return {
            "image": image,
            "model": model or "runwayml/stable-diffusion-v1-5",
            "params": {
                k: v
                for k, v in optional_params.items()
                if k in self.get_supported_diffusers_params()
            },
        }

    def transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: Any,  # Not used for local
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """Convert diffusers output to standardized ImageResponse"""
        # For diffusers, model_response should be PIL Image or list of PIL Images
        if isinstance(model_response, list):
            images = model_response
        else:
            images = [model_response]

        # Convert to base64
        image_data = []
        for img in images:
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            image_data.append(
                {"b64_json": base64.b64encode(buffered.getvalue()).decode("utf-8")}
            )

        return ImageResponse(created=int(time.time()), data=image_data)

    def get_error_class(
        self, error_message: str, status_code: int, headers: dict
    ) -> LLMError:
        """Return generic LLM error for diffusers"""
        return LLMError(status_code=status_code, message=error_message, headers=headers)
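
The request transform only forwards parameters from the supported list. A minimal sketch of that filtering, assuming the config class is directly instantiable (i.e. these are the only abstract hooks the base class requires); the image path and the extra "foo" key are made up for illustration:

# Sketch, not part of the diff: unsupported params are dropped by the transform.
from litellm.llms.diffusers.image_variations.transformation import (
    DiffusersImageVariationConfig,
)

cfg = DiffusersImageVariationConfig()
req = cfg.transform_request_image_variation(
    model=None,
    image="input.png",  # hypothetical local file
    optional_params={"guidance_scale": 7.5, "foo": "dropped"},
    headers={},
)
assert "foo" not in req["params"]
assert req["model"] == "runwayml/stable-diffusion-v1-5"  # default when model is None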

@@ -140,6 +140,7 @@ from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
 from .llms.deprecated_providers import aleph_alpha, palm
+from .llms.diffusers.diffusers import DiffusersImageHandler
 from .llms.groq.chat.handler import GroqChatCompletion
 from .llms.huggingface.embedding.handler import HuggingFaceEmbedding
 from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion

@@ -228,6 +229,7 @@ codestral_text_completions = CodestralTextCompletion()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 bedrock_embedding = BedrockEmbedding()
 bedrock_image_generation = BedrockImageGeneration()
+diffusers_image_generation = DiffusersImageHandler()
 vertex_chat_completion = VertexLLM()
 vertex_embedding = VertexEmbedding()
 vertex_multimodal_embedding = VertexMultimodalEmbedding()

@@ -4564,7 +4566,7 @@ async def aimage_generation(*args, **kwargs) -> ImageResponse:


 @client
-def image_generation(  # noqa: PLR0915
+def image_generation(
     prompt: str,
     model: Optional[str] = None,
     n: Optional[int] = None,

@@ -4573,45 +4575,75 @@ def image_generation(  # noqa: PLR0915
     size: Optional[str] = None,
     style: Optional[str] = None,
     user: Optional[str] = None,
-    timeout=600,  # default to 10 minutes
+    timeout=600,
     api_key: Optional[str] = None,
     api_base: Optional[str] = None,
     api_version: Optional[str] = None,
     custom_llm_provider=None,
+    device: Optional[str] = None,
     **kwargs,
 ) -> ImageResponse:
     """
-    Maps the https://api.openai.com/v1/images/generations endpoint.
-
-    Currently supports just Azure + OpenAI.
+    Handles image generation for various providers including local Diffusers models.
     """
     try:
         args = locals()
         aimg_generation = kwargs.get("aimg_generation", False)
         litellm_call_id = kwargs.get("litellm_call_id", None)
         logger_fn = kwargs.get("logger_fn", None)
-        mock_response: Optional[str] = kwargs.get("mock_response", None)  # type: ignore
+        mock_response = kwargs.get("mock_response", None)
         proxy_server_request = kwargs.get("proxy_server_request", None)
         azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
         model_info = kwargs.get("model_info", None)
         metadata = kwargs.get("metadata", {})
-        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
+        litellm_logging_obj = kwargs.get("litellm_logging_obj")
         client = kwargs.get("client", None)
         extra_headers = kwargs.get("extra_headers", None)
-        headers: dict = kwargs.get("headers", None) or {}
+        headers = kwargs.get("headers", None) or {}

         if extra_headers is not None:
             headers.update(extra_headers)
-        model_response: ImageResponse = litellm.utils.ImageResponse()
+
+        model_response = litellm.utils.ImageResponse()
+
+        # Get model provider info
         if model is not None or custom_llm_provider is not None:
             model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
-                model=model,  # type: ignore
+                model=model,
                 custom_llm_provider=custom_llm_provider,
                 api_base=api_base,
             )
         else:
             model = "dall-e-2"
-            custom_llm_provider = "openai"  # default to dall-e-2 on openai
+            custom_llm_provider = "openai"
+
         model_response._hidden_params["model"] = model
+
+        # Handle Diffusers/local models
+        if model.startswith("diffusers/") or custom_llm_provider == "diffusers":
+            from .llms.diffusers.diffusers import DiffusersImageHandler
+
+            model_path = model.replace("diffusers/", "")
+            width, height = (512, 512)
+            if size:
+                width, height = map(int, size.split("x"))
+
+            handler = DiffusersImageHandler()
+            diffusers_response = handler.generate_image(
+                prompt=prompt,
+                model=model_path,
+                height=height,
+                width=width,
+                num_images_per_prompt=n or 1,
+                device=device,  # Pass through device parameter
+                **kwargs,
+            )
+
+            model_response.created = diffusers_response.created
+            model_response.data = diffusers_response.data
+            return model_response
+
+        # Original provider handling remains the same
         openai_params = [
             "user",
             "request_timeout",

@@ -4629,11 +4661,12 @@ def image_generation(  # noqa: PLR0915
             "size",
             "style",
         ]

         litellm_params = all_litellm_params
         default_params = openai_params + litellm_params
         non_default_params = {
             k: v for k, v in kwargs.items() if k not in default_params
-        }  # model-specific params - pass them straight to the model/provider
+        }
+
         optional_params = get_optional_params_image_gen(
             model=model,

@@ -4649,7 +4682,7 @@ def image_generation(  # noqa: PLR0915

         litellm_params_dict = get_litellm_params(**kwargs)

-        logging: Logging = litellm_logging_obj
+        logging = litellm_logging_obj
         logging.update_environment_variables(
             model=model,
             user=user,

@@ -4667,6 +4700,7 @@ def image_generation(  # noqa: PLR0915
             },
             custom_llm_provider=custom_llm_provider,
         )
+
         if "custom_llm_provider" not in logging.model_call_details:
             logging.model_call_details["custom_llm_provider"] = custom_llm_provider
         if mock_response is not None:
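
Taken together, the routing and handler changes let a caller reach a local checkpoint through the standard image API. A minimal sketch, mirroring the arguments used in the new tests and assuming the "diffusers/" prefix convention introduced in this diff (the new device parameter is optional):

# Sketch, not part of the diff: end-to-end call through litellm.image_generation.
import litellm

resp = litellm.image_generation(
    prompt="A cute cat",
    model="diffusers/runwayml/stable-diffusion-v1-5",  # "diffusers/" prefix selects the local handler
    n=1,
    size="512x512",
    device="cpu",  # new parameter added by this diff
)
print(resp.data[0]["b64_json"][:32])  # base64-encoded PNG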

tests/litellm/llms/diffusers/test_diffusers.py (new file)
@@ -0,0 +1,151 @@
import os
import sys
from unittest.mock import MagicMock, call, patch
import pytest
import base64
from PIL import Image
import io

import numpy as np

sys.path.insert(0, os.path.abspath("../../.."))

import litellm
from litellm.llms.diffusers.diffusers import DiffusersImageHandler

API_FUNCTION_PARAMS = [
    (
        "image_generation",
        False,
        {
            "model": "diffusers/runwayml/stable-diffusion-v1-5",
            "prompt": "A cute cat",
            "n": 1,
            "size": "512x512",
        },
    ),
    (
        "image_generation",
        True,
        {
            "model": "diffusers/runwayml/stable-diffusion-v1-5",
            "prompt": "A cute cat",
            "n": 1,
            "size": "512x512",
        },
    ),
]


@pytest.fixture
def mock_diffusers():
    """Fixture that properly mocks the diffusers pipeline"""
    with patch(
        "diffusers.StableDiffusionPipeline.from_pretrained"
    ) as mock_from_pretrained:
        # Create real test images
        def create_test_image():
            arr = np.random.rand(512, 512, 3) * 255
            return Image.fromarray(arr.astype("uint8")).convert("RGB")

        test_images = [create_test_image(), create_test_image()]

        # Create mock pipeline that returns our test images
        mock_pipe = MagicMock()
        mock_pipe.return_value.images = test_images
        mock_from_pretrained.return_value = mock_pipe

        yield {
            "from_pretrained": mock_from_pretrained,
            "pipeline": mock_pipe,
            "test_images": test_images,
        }


def test_diffusers_image_handler(mock_diffusers):
    """Test that the handler properly processes images into base64 responses"""
    from litellm.llms.diffusers.diffusers import DiffusersImageHandler

    handler = DiffusersImageHandler()

    # Test with 2 images
    response = handler.generate_image(
        prompt="test prompt",
        model="runwayml/stable-diffusion-v1-5",
        num_images_per_prompt=2,
    )

    # Verify response structure
    assert hasattr(response, "data")
    assert isinstance(response.data, list)
    assert len(response.data) == 2  # Should return exactly 2 images

    # Verify each image is properly encoded
    for img_data in response.data:
        assert "b64_json" in img_data
        # Test we can decode it back to an image
        try:
            img_bytes = base64.b64decode(img_data["b64_json"])
            img = Image.open(io.BytesIO(img_bytes))
            assert img.size == (512, 512)
        except Exception as e:
            pytest.fail(f"Failed to decode base64 image: {str(e)}")

    # Verify pipeline was called correctly
    mock_diffusers["from_pretrained"].assert_called_once_with(
        "runwayml/stable-diffusion-v1-5"
    )
    mock_diffusers["pipeline"].assert_called_once_with(
        prompt="test prompt", num_images_per_prompt=2
    )


@pytest.mark.parametrize(
    "function_name,is_async,args",
    [
        (
            "image_generation",
            False,
            {
                "model": "diffusers/runwayml/stable-diffusion-v1-5",
                "prompt": "A cat",
                "n": 1,
            },
        ),
        (
            "image_generation",
            True,
            {
                "model": "diffusers/runwayml/stable-diffusion-v1-5",
                "prompt": "A cat",
                "n": 1,
            },
        ),
    ],
)
@pytest.mark.asyncio
async def test_image_generation(function_name, is_async, args, mock_diffusers):
    """Test the image generation API endpoint"""
    # Configure mock
    mock_diffusers["pipeline"].return_value.images = mock_diffusers["test_images"][
        : args["n"]
    ]

    if is_async:
        response = await litellm.aimage_generation(**args)
    else:
        response = litellm.image_generation(**args)

    # Verify response
    assert len(response.data) == args["n"]
    assert "b64_json" in response.data[0]

    # Test base64 decoding
    try:
        img_bytes = base64.b64decode(response.data[0]["b64_json"])
        img = Image.open(io.BytesIO(img_bytes))
        assert img.size == (512, 512)
    except Exception as e:
        pytest.fail(f"Invalid base64 image: {str(e)}")
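
Because the fixture patches diffusers.StableDiffusionPipeline.from_pretrained, these tests should not need a GPU or a model download; presumably they are run the usual way, e.g. `pytest tests/litellm/llms/diffusers/test_diffusers.py`.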