# Mirror of https://github.com/BerriAI/litellm.git
# Synced 2025-04-25 02:34:29 +00:00
import base64
import io
import logging
import time
from typing import Dict, List, Optional, Union

try:
    from PIL import Image
    from diffusers import StableDiffusionPipeline
except ModuleNotFoundError:
    pass

from pydantic import BaseModel
class ImageResponse(BaseModel):
|
|
created: int
|
|
data: List[Dict[str, str]] # List of dicts with "b64_json" or "url"
|
|
|
|
|
|
class DiffusersImageHandler:
    """Local text-to-image generation via HuggingFace ``diffusers``.

    Pipelines are cached per model id in ``pipeline_cache`` so repeated
    calls do not re-download or re-initialize the model. Annotations that
    name PIL / diffusers types are quoted (lazy forward references) so the
    module stays importable when those optional packages are missing —
    matching the guarded ``try: import`` at the top of the file.
    """

    def __init__(self):
        # model id -> loaded (and device-placed) StableDiffusionPipeline
        self.pipeline_cache = {}
        self.device = self._get_default_device()

    def _get_default_device(self) -> str:
        """Return the best available torch device: "cuda", "mps", or "cpu"."""
        import torch

        if torch.cuda.is_available():
            return "cuda"
        # torch.backends.mps only exists on newer torch builds; guard so
        # older installs fall through to CPU instead of raising.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"  # Apple Silicon
        return "cpu"

    def _load_pipeline(
        self, model: str, device: Optional[str] = None
    ) -> "StableDiffusionPipeline":
        """Load, device-place, and cache a diffusion pipeline.

        Args:
            model: Diffusers model id.
            device: Target device; defaults to the handler's auto-detected one.

        Raises:
            RuntimeError: re-raised for non-CUDA failures when moving the
                pipeline to the requested device.
        """
        device = device or self.device

        if model not in self.pipeline_cache:
            pipe = StableDiffusionPipeline.from_pretrained(model)
            # Only the device move is in the try: if from_pretrained itself
            # fails, `pipe` would be unbound in the handler (original bug).
            try:
                pipe = pipe.to(device)
            except RuntimeError as e:
                if "CUDA" not in str(e):
                    raise
                # Fall back to CPU on CUDA-specific failures (e.g. OOM).
                logging.getLogger(__name__).warning(
                    "Falling back to CPU: %s", str(e)
                )
                pipe = pipe.to("cpu")
            self.pipeline_cache[model] = pipe

        return self.pipeline_cache[model]

    def _image_to_b64(self, image: "Image.Image") -> str:
        """Encode a PIL image as a base64 PNG string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def generate_image(
        self, prompt: str, model: str, num_images_per_prompt: int = 1, **kwargs
    ) -> "ImageResponse":
        """Generate image(s) from a text prompt.

        Args:
            prompt: Text description of the desired image(s).
            model: Diffusers model id.
            num_images_per_prompt: Number of images to generate.
            **kwargs: Forwarded verbatim to the pipeline call.

        Returns:
            ImageResponse with base64-encoded PNG images.
        """
        # Use the shared loader so this path gets the same caching, device
        # placement, and CUDA fallback as generate_variation (the original
        # re-loaded the model inline with none of that).
        pipe = self._load_pipeline(model)

        images = pipe(
            prompt=prompt, num_images_per_prompt=num_images_per_prompt, **kwargs
        ).images

        image_data = [{"b64_json": self._image_to_b64(img)} for img in images]

        return ImageResponse(created=int(time.time()), data=image_data)

    def generate_variation(
        self,
        image: Union["Image.Image", str, bytes],  # PIL, file path, or bytes
        prompt: Optional[str] = None,
        model: str = "runwayml/stable-diffusion-v1-5",
        strength: float = 0.8,
        **kwargs,
    ) -> "ImageResponse":
        """Generate a variation of an input image.

        Args:
            image: Input image (PIL image, file path, or raw bytes).
            prompt: Optional text prompt to guide the variation.
            model: Diffusers model id.
            strength: Strength of the variation (0-1).

        Returns:
            ImageResponse with base64-encoded images.
        """
        # Normalize the input to a PIL Image.
        if isinstance(image, str):
            image = Image.open(image)  # treated as a file path
        elif isinstance(image, bytes):
            image = Image.open(io.BytesIO(image))

        pipe = self._load_pipeline(model)

        # NOTE(review): a txt2img StableDiffusionPipeline may not accept
        # image/strength kwargs — an img2img pipeline is likely intended;
        # confirm against callers before changing behavior.
        result = pipe(prompt=prompt, image=image, strength=strength, **kwargs)

        image_data = [{"b64_json": self._image_to_b64(result.images[0])}]

        return ImageResponse(created=int(time.time()), data=image_data)