litellm-mirror/litellm/llms/diffusers/diffusers.py

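"""Local image generation for litellm via Hugging Face ``diffusers`` Stable Diffusion
pipelines, returning OpenAI-style base64-encoded image responses."""
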
import base64
import io
import time
from typing import Dict, List, Optional, Union

from pydantic import BaseModel

from litellm._logging import verbose_logger

try:
    from PIL import Image
    from diffusers import StableDiffusionPipeline
except ModuleNotFoundError:
    # Optional dependencies; failures surface when the handler is actually used.
    pass


class ImageResponse(BaseModel):
    created: int
    data: List[Dict[str, str]]  # list of dicts with a "b64_json" or "url" key


class DiffusersImageHandler:
    def __init__(self):
        self.pipeline_cache = {}  # cache of loaded pipelines, keyed by model id
        self.device = self._get_default_device()

    def _get_default_device(self):
        """Determine the best available device"""
        import torch

        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():  # Apple Silicon
            return "mps"
        else:
            return "cpu"
    def _load_pipeline(
        self, model: str, device: Optional[str] = None
    ) -> "StableDiffusionPipeline":
        """Load and cache a diffusion pipeline"""
        device = device or self.device
        if model not in self.pipeline_cache:
            pipe = StableDiffusionPipeline.from_pretrained(model)
            try:
                pipe = pipe.to(device)
            except RuntimeError as e:
                if "CUDA" in str(e):
                    # Fall back to CPU if moving the pipeline to CUDA fails
                    verbose_logger.warning(f"Falling back to CPU: {str(e)}")
                    pipe = pipe.to("cpu")
                else:
                    raise
            self.pipeline_cache[model] = pipe
        return self.pipeline_cache[model]
    def _image_to_b64(self, image: "Image.Image") -> str:
        """Convert a PIL Image to a base64-encoded PNG string"""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    def generate_image(
        self, prompt: str, model: str, num_images_per_prompt: int = 1, **kwargs
    ) -> ImageResponse:
        """Generate image(s) from a text prompt and return them base64-encoded."""
        # Get or create the cached pipeline (handles device placement and CPU fallback)
        pipe = self._load_pipeline(model)
        # Generate images
        images = pipe(
            prompt=prompt, num_images_per_prompt=num_images_per_prompt, **kwargs
        ).images
        # Convert to base64
        image_data = [{"b64_json": self._image_to_b64(img)} for img in images]
        return ImageResponse(created=int(time.time()), data=image_data)
    def generate_variation(
        self,
        image: Union["Image.Image", str, bytes],  # accepts PIL image, file path, or bytes
        prompt: Optional[str] = None,
        model: str = "runwayml/stable-diffusion-v1-5",
        strength: float = 0.8,
        **kwargs,
    ) -> ImageResponse:
        """
        Generate a variation of the input image.

        Args:
            image: Input image (PIL image, file path, or bytes)
            prompt: Optional text prompt to guide the variation
            model: Diffusers model ID
            strength: Strength of the variation (0-1)

        Returns:
            ImageResponse with base64-encoded images
        """
        # Convert input to a PIL Image
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, bytes):
            image = Image.open(io.BytesIO(image))

        # Text-to-image pipelines do not accept `image`/`strength`, so reuse the
        # cached components to build an image-to-image pipeline for this model.
        from diffusers import StableDiffusionImg2ImgPipeline

        pipe = self._load_pipeline(model)
        img2img = StableDiffusionImg2ImgPipeline(**pipe.components).to(pipe.device)

        # Generate the variation; img2img pipelines require a prompt, so default to ""
        result = img2img(prompt=prompt or "", image=image, strength=strength, **kwargs)

        # Convert to response format
        image_data = [{"b64_json": self._image_to_b64(result.images[0])}]
        return ImageResponse(created=int(time.time()), data=image_data)
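

# Usage sketch (illustrative only): exercises the handler above. The prompt is a
# placeholder and the model id matches the default used by generate_variation();
# running this requires `diffusers`, `torch`, and `Pillow` to be installed.
if __name__ == "__main__":
    handler = DiffusersImageHandler()
    response = handler.generate_image(
        prompt="a watercolor painting of a lighthouse at dusk",
        model="runwayml/stable-diffusion-v1-5",
        num_images_per_prompt=1,
    )
    # Each entry in response.data carries a base64-encoded PNG under "b64_json"
    print(response.created, len(response.data))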