-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Closed
Labels
Description
This issue follows up on #213 (comment).

My code is based on the Stable Diffusion safety-checker code. Please take a look when you have time, thanks. @patrickvonplaten
import os
import glob
import json
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils import data
from transformers import AutoFeatureExtractor
from torch.utils.data.dataloader import DataLoader
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, InterpolationMode
class SafetyChecker:
    """Wrapper around the Stable Diffusion safety checker.

    Loads the feature extractor and ``StableDiffusionSafetyChecker`` weights
    from ``safety_model_id`` and exposes :meth:`check_safety` for batches of
    images.
    """

    def __init__(self, safety_model_id="CompVis/stable-diffusion-safety-checker", device='cuda'):
        # The feature extractor prepares the checker's `clip_input`; the
        # checker model itself is moved to `device`.
        self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id).to(device)
        self.device = device

    @staticmethod
    def _to_numpy_nhwc(images):
        """Return *images* as a NumPy array in (batch, height, width, channel) order.

        A 4-D ``torch.Tensor`` in (batch, channel, height, width) order is
        permuted to channels-last. Any other input is returned unchanged.
        """
        if isinstance(images, torch.Tensor):
            # detach(): .numpy() refuses tensors that require grad.
            # cpu(): .numpy() refuses tensors that live on an accelerator.
            return images.detach().cpu().permute(0, 2, 3, 1).numpy()
        return images

    def numpy_to_pil(self, images):
        """
        Convert a numpy image or a batch of images to a list of PIL images.

        Args:
            images: float array of shape (H, W, C) or (B, H, W, C), with
                values expected in [0, 1].

        Returns:
            list of ``PIL.Image.Image``, one per image in the batch.
        """
        if images.ndim == 3:
            images = images[None, ...]
        # Clip first: out-of-range floats would otherwise wrap around when
        # cast to uint8 (e.g. 1.01 -> 258 -> 2).
        images = (np.clip(images, 0.0, 1.0) * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # PIL cannot build an image from an (H, W, 1) array. Dropping the
            # channel axis gives a 2-D uint8 array, i.e. a grayscale image.
            return [Image.fromarray(image[..., 0]) for image in images]
        return [Image.fromarray(image) for image in images]

    def check_safety(self, x_image):
        """Run the safety checker on a batch of images.

        Args:
            x_image: either a (B, C, H, W) ``torch.Tensor`` or a
                (B, H, W, C) numpy array, values presumably in [0, 1]
                (the range ``numpy_to_pil`` expects).

        Returns:
            ``(checked_images, has_nsfw_concept)`` as returned by the
            underlying ``StableDiffusionSafetyChecker``.
        """
        x_image = self._to_numpy_nhwc(x_image)
        safety_checker_input = self.safety_feature_extractor(self.numpy_to_pil(x_image), return_tensors="pt")
        # NOTE(review): the checker may overwrite flagged images in `x_image`
        # in place. For CPU tensors `x_image` shares memory with the caller's
        # tensor; verify against the installed diffusers version.
        x_checked_image, has_nsfw_concept = self.safety_checker(
            images=x_image,
            clip_input=safety_checker_input.pixel_values.to(self.device),
        )
        # Internal invariant: one flag per returned image.
        assert x_checked_image.shape[0] == len(has_nsfw_concept)
        return x_checked_image, has_nsfw_concept
class SafetyCheckerDataset(data.Dataset):
    """Dataset over the ``*.jpg`` files directly inside one folder.

    Each item is ``dict(image=..., image_name=...)``. ``image`` is the opened
    PIL image after ``transform`` (if any) has been applied, and
    ``image_name`` is the file's base name.

    Args:
        dataset_folder: directory to scan (non-recursive) for ``*.jpg`` files.
        transform: optional callable applied to each PIL image.
        sample: if non-zero, keep a random subset of this many files. The
            subset is capped at the number of files found.
    """

    def __init__(
        self,
        dataset_folder,
        transform=None,
        sample=0
    ):
        super().__init__()
        self.dataset_folder = dataset_folder
        self.transform = transform
        # Sort so the file order, and therefore which files `sample` picks
        # for a given random seed, does not depend on filesystem listing order.
        imgs_path = sorted(glob.glob(os.path.join(dataset_folder, '*.jpg')))
        if sample:
            # Cap the request so asking for more files than exist returns
            # all of them instead of raising ValueError from random.sample.
            imgs_path = random.sample(imgs_path, min(sample, len(imgs_path)))
        self.imgs_path = imgs_path

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        img_name = os.path.basename(img_path)
        # Image.open is lazy and keeps the file open. Copy the decoded pixels
        # and close the file immediately so long runs don't leak descriptors.
        with Image.open(img_path) as src:
            img = src.copy()
        if self.transform is not None:
            img = self.transform(img)
        return dict(image=img, image_name=img_name)
def _convert_image_to_rgb(image):
    """Convert a PIL image to RGB.

    Defined at module level (not nested) so the transform pipeline can be
    pickled and sent to DataLoader worker processes under the `spawn` start
    method.
    """
    return image.convert("RGB")


def _transform(n_px):
    """Build the resize / center-crop / RGB / to-tensor preprocessing pipeline."""
    return Compose([
        Resize(n_px, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
    ])


def test_safety_checker(dataset_dir, sample_num=0, device='cuda', batch_size=4, num_workers=4):
    """Run the safety checker over the ``*.jpg`` images in *dataset_dir*.

    Args:
        dataset_dir: folder scanned (non-recursively) for ``*.jpg`` files.
        sample_num: if non-zero, check only a random subset of this many
            images. 0 checks every image.
        device: device the safety-checker model runs on.
        batch_size: images per DataLoader batch.
        num_workers: DataLoader worker processes.

    Returns:
        dict mapping each image file name to the checker's NSFW flag for it.
    """
    dataset = SafetyCheckerDataset(dataset_dir, _transform(256), sample_num)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-CUDA copies.
        pin_memory=str(device).startswith('cuda'),
        shuffle=False,
    )
    checker = SafetyChecker(device=device)
    results = {}
    # `batch`, not `data`, to avoid shadowing the `torch.utils.data` module.
    for batch in tqdm(dataloader):
        _, has_nsfw = checker.check_safety(batch['image'])
        results.update(zip(batch['image_name'], has_nsfw))
    return results