#!/usr/bin/python

import io
import gc
import random
import logging

import torch
import tractor

from diffusers import (
    StableDiffusionPipeline,
    EulerAncestralDiscreteScheduler
)

from .types import ImageGenRequest
from .constants import ALGOS


# Largest seed accepted by ``torch.Generator.manual_seed`` (unsigned 64-bit).
# ``random.randint`` is inclusive on both ends, so the bound must be 2**64 - 1,
# not 2**64 (which torch would reject with an overflow error).
_MAX_SEED = 2 ** 64 - 1


def pipeline_for(algo: str, mem_fraction: float):
    """Load a Stable Diffusion pipeline for ``algo`` and move it to CUDA.

    Side effects: empties the CUDA caching allocator, caps this process's
    CUDA memory usage to ``mem_fraction`` and enables TF32 for CUDA matmul
    and cuDNN globally.

    Args:
        algo: key into ``ALGOS``; the mapped value is passed to
            ``StableDiffusionPipeline.from_pretrained`` as the model to load.
        mem_fraction: fraction of device memory this process may allocate,
            forwarded to ``torch.cuda.set_per_process_memory_fraction``.

    Returns:
        The pipeline, on the ``"cuda"`` device, using an
        ``EulerAncestralDiscreteScheduler`` built from its original config.

    Raises:
        RuntimeError: if CUDA is not available.
        KeyError: if ``algo`` is not a key of ``ALGOS``.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA is not available; a GPU is required')

    # Release cached blocks (e.g. from a previously dropped pipeline) before
    # applying the memory cap and loading new weights.
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(mem_fraction)
    # Trade a little precision for faster matmul/conv on TF32-capable GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    params = {
        'torch_dtype': torch.float16,
        'safety_checker': None
    }

    # Only the 'stable' model is loaded from its 'fp16' revision.
    if algo == 'stable':
        params['revision'] = 'fp16'

    pipe = StableDiffusionPipeline.from_pretrained(
        ALGOS[algo], **params)

    # Swap in Euler-ancestral sampling, reusing the stock scheduler config.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe.scheduler.config)

    return pipe.to("cuda")


@tractor.context
async def open_gpu_worker(
    ctx: tractor.Context,
    start_algo: str,
    mem_fraction: float
):
    """tractor context serving image-generation requests over a stream.

    Loads the pipeline for ``start_algo``, signals ``ctx.started()``, then
    for every request received on the context stream renders one image and
    sends it back as PNG-encoded ``bytes``. If a request names a different
    ``algo`` than the one currently loaded, the pipeline is reloaded first.

    Args:
        ctx: the tractor context this worker is serving.
        start_algo: ``ALGOS`` key of the model to load up front.
        mem_fraction: per-process CUDA memory fraction (see ``pipeline_for``).
    """
    log = tractor.log.get_logger(name='gpu', _root_name='skynet')
    log.info(f'starting gpu worker with algo {start_algo}...')
    current_algo = start_algo
    with torch.no_grad():
        pipe = pipeline_for(current_algo, mem_fraction)
        log.info('pipeline loaded')
        await ctx.started()
        async with ctx.open_stream() as bus:
            # presumably each message is an ImageGenRequest — verify with
            # the sending side.
            async for ireq in bus:
                if ireq.algo != current_algo:
                    current_algo = ireq.algo
                    # Drop the old pipeline *before* loading the new one so
                    # its GPU memory can be reclaimed by the empty_cache()
                    # inside pipeline_for; otherwise both models coexist on
                    # the device under the memory cap. gc.collect() covers
                    # any reference cycles keeping the old pipeline alive.
                    pipe = None
                    gc.collect()
                    pipe = pipeline_for(current_algo, mem_fraction)

                # NOTE(review): a falsy seed (None *or* 0) selects a random
                # seed, so an explicit seed of 0 cannot be requested —
                # confirm this is intended.
                seed = ireq.seed if ireq.seed else random.randint(0, _MAX_SEED)

                # NOTE: this call is synchronous and blocks this actor's
                # event loop for the whole generation.
                image = pipe(
                    ireq.prompt,
                    width=ireq.width,
                    height=ireq.height,
                    guidance_scale=ireq.guidance,
                    num_inference_steps=ireq.step,
                    generator=torch.Generator("cuda").manual_seed(seed)
                ).images[0]

                # Return cached allocator blocks between requests.
                torch.cuda.empty_cache()

                # Encode the PIL image as PNG bytes for transport.
                with io.BytesIO() as img_bytes:
                    image.save(img_bytes, format='PNG')
                    payload = img_bytes.getvalue()

                await bus.send(payload)