adding test_diffusers.py and updating diffusers

This commit is contained in:
Chris Agostino 2025-04-04 23:29:20 -04:00
parent 3638109576
commit 639a4470a5
2 changed files with 169 additions and 38 deletions

View file

@ -61,49 +61,29 @@ class DiffusersImageHandler:
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def generate_image(
    self, prompt: str, model: str, num_images_per_prompt: int = 1, **kwargs
) -> ImageResponse:
    """
    Generate image(s) from a text prompt.

    Args:
        prompt: Text prompt to generate image from
        model: Diffusers model ID
        num_images_per_prompt: Number of images to generate
        **kwargs: Extra keyword arguments forwarded unchanged to the
            pipeline call

    Returns:
        ImageResponse with base64 encoded images
    """
    # Get or create pipeline. The import is function-scoped, so diffusers
    # is only imported the first time a model is missing from the cache.
    if model not in self.pipeline_cache:
        from diffusers import StableDiffusionPipeline

        self.pipeline_cache[model] = StableDiffusionPipeline.from_pretrained(model)
    pipe = self.pipeline_cache[model]

    # Generate images
    images = pipe(
        prompt=prompt, num_images_per_prompt=num_images_per_prompt, **kwargs
    ).images

    # Convert to base64: each image is serialized as PNG into an in-memory
    # buffer and the raw bytes are base64-encoded to a UTF-8 string.
    # NOTE(review): this looks like the same encoding the helper defined just
    # above this method performs -- confirm and consider calling it instead.
    image_data = []
    for img in images:
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        image_data.append(
            {"b64_json": base64.b64encode(buffered.getvalue()).decode("utf-8")}
        )

    # `created` is the current Unix time truncated to whole seconds.
    return ImageResponse(created=int(time.time()), data=image_data)