Skip to content

About the performance of the safety checker? #275

@BIGJUN777

Description

@BIGJUN777

This is an issue coming from #213 (comment)
image

My code is based on the Stable Diffusion implementation. Please take a look when you have time, thanks. @patrickvonplaten

import os
import glob
import json
import random
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
from torch.utils import data
from transformers import AutoFeatureExtractor
from torch.utils.data.dataloader import DataLoader
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, InterpolationMode

class SafetyChecker:
    """Run the pretrained Stable Diffusion safety checker over image batches.

    Holds a feature extractor and a ``StableDiffusionSafetyChecker`` that are
    both loaded from the same model id, plus the device the checker lives on.
    """

    def __init__(self, safety_model_id="CompVis/stable-diffusion-safety-checker", device='cuda'):
        """Load the feature extractor and the checker, moving the checker to ``device``."""
        self.device = device
        self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
        checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
        self.safety_checker = checker.to(device)

    def numpy_to_pil(self, images):
        """Turn one numpy image, or a batch of them, into a list of PIL images.

        A 3-dimensional array is treated as a single image and gets a leading
        batch axis. Values are multiplied by 255, rounded and cast to uint8,
        so inputs are presumably expected in [0, 1] -- TODO confirm.
        """
        batch = images if images.ndim != 3 else images[None, ...]
        as_bytes = (batch * 255).round().astype("uint8")
        return [Image.fromarray(frame) for frame in as_bytes]

    def check_safety(self, x_image):
        """Feed a batch of images through the safety checker.

        Args:
            x_image: a ``torch.Tensor`` or a numpy array. Tensors are moved to
                the CPU, their axes reordered with ``permute(0, 2, 3, 1)``
                (presumably NCHW -> NHWC -- TODO confirm) and converted to
                numpy; anything else is used unchanged.

        Returns:
            The pair returned by the safety checker: the checked images and
            the per-image NSFW flags.
        """
        if isinstance(x_image, torch.Tensor):
            x_image = x_image.cpu().permute(0, 2, 3, 1).numpy()

        pil_batch = self.numpy_to_pil(x_image)
        clip_features = self.safety_feature_extractor(pil_batch, return_tensors="pt")
        clip_input = clip_features.pixel_values.to(self.device)

        checked, flags = self.safety_checker(images=x_image, clip_input=clip_input)
        # Sanity check: the checker must report exactly one flag per image.
        assert checked.shape[0] == len(flags)
        return checked, flags

class SafetyCheckerDataset(data.Dataset):
    """Dataset over the ``*.jpg`` files found directly inside a folder.

    Each item is a dict with the (optionally transformed) image under
    ``"image"`` and the file's base name under ``"image_name"``.
    """

    def __init__(
        self,
        dataset_folder,
        transform=None,
        sample=0
    ):
        """Collect the image paths, optionally keeping only a random subset.

        Args:
            dataset_folder: directory that is searched (non-recursively) for
                ``*.jpg`` files.
            transform: optional callable applied to each opened image.
            sample: number of images to draw at random. A falsy value (the
                default ``0``) keeps every image. A value larger than the
                number of images found keeps all of them, in random order.
        """
        super().__init__()
        self.dataset_folder = dataset_folder
        self.transform = transform
        # glob order depends on the filesystem; sorting makes the listing, and
        # therefore a seeded random subset, reproducible across machines.
        imgs_path = sorted(glob.glob(os.path.join(dataset_folder, '*.jpg')))
        if sample:
            # Clamp the request: random.sample raises ValueError when asked
            # for more items than the population holds.
            imgs_path = random.sample(imgs_path, min(sample, len(imgs_path)))
        self.imgs_path = imgs_path

    def __len__(self):
        """Return the number of image paths kept by the constructor."""
        return len(self.imgs_path)

    def __getitem__(self, index):
        """Open the image at ``index`` and return it together with its file name."""
        img_path = self.imgs_path[index]
        img_name = os.path.basename(img_path)
        img = Image.open(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return dict(image=img, image_name=img_name)

def test_safety_checker(dataset_dir, sample_num=0):
    """Run the safety checker over the ``*.jpg`` images in ``dataset_dir``.

    Args:
        dataset_dir: folder handed to ``SafetyCheckerDataset``.
        sample_num: number of images to sample at random; ``0`` (the default)
            checks every image. This used to be read from an undefined name,
            which raised ``NameError`` on every call.

    Returns:
        A dict mapping each image file name to the flag the checker reported
        for it.
    """
    def _convert_image_to_rgb(image):
        return image.convert("RGB")

    def _transform(n_px):
        # Resize, centre-crop to n_px x n_px, force RGB, then convert to a tensor.
        return Compose([
            Resize(n_px, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(n_px),
            _convert_image_to_rgb,
            ToTensor(),
        ])

    dataset = SafetyCheckerDataset(dataset_dir, _transform(256), sample_num)
    dataloader = DataLoader(dataset, batch_size=4, num_workers=4, pin_memory=True, shuffle=False)
    checker = SafetyChecker(device='cuda')

    flags_by_name = {}
    # The loop variable is not called `data`, so it no longer shadows the
    # imported torch.utils.data module.
    for batch in tqdm(dataloader):
        _, flags = checker.check_safety(batch['image'])
        # Assumes the default collate yields one file name per image, in the
        # same order as the flags -- TODO confirm.
        flags_by_name.update(zip(batch['image_name'], flags))
    return flags_by_name

Metadata

Metadata

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions