|
| 1 | +from ast import arg |
| 2 | +import os.path as osp |
| 3 | +import glob |
| 4 | +import cv2 |
| 5 | +import numpy as np |
| 6 | +import torch |
| 7 | + |
| 8 | +from torch.fx.experimental.proxy_tensor import make_fx |
| 9 | +from torch._decomp import get_decompositions |
| 10 | +from shark.shark_inference import SharkInference |
| 11 | +import torch_mlir |
| 12 | +import tempfile |
| 13 | +import functools |
| 14 | +import torch |
| 15 | +import torch.nn as nn |
| 16 | +import torch.nn.functional as F |
| 17 | + |
| 18 | + |
def make_layer(block, n_layers):
    """Stack ``n_layers`` freshly built modules into one ``nn.Sequential``.

    Args:
        block: Zero-argument callable returning a new module on each call.
        n_layers: Number of times ``block`` is called.

    Returns:
        An ``nn.Sequential`` holding the created modules in creation order.
    """
    return nn.Sequential(*(block() for _ in range(n_layers)))
| 24 | + |
| 25 | + |
class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution dense block with a scaled residual connection.

    Each of the first four convolutions sees the block input concatenated
    (along dim 1) with every earlier activation; the fifth maps the full
    concatenation back to ``nf`` channels.

    Args:
        nf: Channel count of the block input and output.
        gc: Growth channels, i.e. channels produced by each intermediate conv.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        # Every conv is 3x3, stride 1, padding 1.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv1 = nn.Conv2d(nf, gc, **conv_kwargs)
        self.conv2 = nn.Conv2d(nf + gc, gc, **conv_kwargs)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, **conv_kwargs)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, **conv_kwargs)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, **conv_kwargs)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``conv5(dense features) * 0.2 + x``."""
        activate = self.lrelu
        # feats grows to (x, x1, x2, x3, x4); conv1 sees x alone, uncatenated.
        feats = (x, activate(self.conv1(x)))
        for conv in (self.conv2, self.conv3, self.conv4):
            feats += (activate(conv(torch.cat(feats, 1))),)
        return self.conv5(torch.cat(feats, 1)) * 0.2 + x
| 47 | + |
| 48 | + |
class RRDB(nn.Module):
    """Residual in Residual Dense Block

    Chains three ``ResidualDenseBlock_5C`` modules and adds the block input
    back onto their output scaled by 0.2.

    Args:
        nf: Channel count passed to each dense block.
        gc: Growth channels passed to each dense block.
    """

    def __init__(self, nf, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        """Return ``RDB3(RDB2(RDB1(x))) * 0.2 + x``."""
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return out * 0.2 + x
| 63 | + |
| 64 | + |
class RRDBNet(nn.Module):
    """Network built from a trunk of ``RRDB`` blocks plus two 2x upsamplings.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        nf: Number of feature channels used throughout the network.
        nb: Number of ``RRDB`` blocks in the trunk.
        gc: Growth channels forwarded to every ``RRDB`` block.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super().__init__()

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(
            functools.partial(RRDB, nf=nf, gc=gc), nb
        )
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        # Convolutions applied after each nearest-neighbour upsampling step.
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Run the trunk with a skip connection, upsample twice, project out."""
        shallow = self.conv_first(x)
        fea = shallow + self.trunk_conv(self.RRDB_trunk(shallow))

        # Two identical stages: nearest 2x interpolate -> conv -> LeakyReLU.
        for upconv in (self.upconv1, self.upconv2):
            upsampled = F.interpolate(fea, scale_factor=2, mode="nearest")
            fea = self.lrelu(upconv(upsampled))

        return self.conv_last(self.lrelu(self.HRconv(fea)))
| 95 | + |
| 96 | + |
############### Parsing args #####################
import argparse

# Command-line options are parsed at module load; ``args`` is read later by
# compile_through_fx (``args.device``).
p = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description=__doc__,
)

p.add_argument(
    "--device",
    default="cpu",
    type=str,
    help="the device to use",
)
p.add_argument(
    "--mlir_loc",
    default=None,
    type=str,
    help="location of the model's mlir file",
)
args = p.parse_args()
###################################################
| 113 | + |
| 114 | + |
def inference(input_m):
    """Call the module-level ``model`` on ``input_m`` and return its result.

    A plain function wrapper around the global model so it can be handed to
    ``compile_through_fx`` as the callable to trace.
    """
    output = model(input_m)
    return output
| 117 | + |
| 118 | + |
def load_mlir(mlir_loc):
    """Read a textual MLIR module from disk.

    Args:
        mlir_loc: Path of the MLIR file to read, or ``None``.

    Returns:
        The entire file content as a string, or ``None`` when ``mlir_loc``
        is ``None`` (the caller then generates the module itself).

    Raises:
        OSError: If the file cannot be opened.
    """
    # Identity check: ``None`` is a singleton and ``==`` can be overridden.
    if mlir_loc is None:
        return None
    print(f"Trying to load the model from {mlir_loc}.")
    # Decode as UTF-8 explicitly instead of the platform's locale encoding.
    with open(mlir_loc, encoding="utf-8") as f:
        return f.read()
| 128 | + |
| 129 | + |
def _strip_overloads(gm):
    """Replace each OpOverload node target in ``gm`` with its overload packet.

    Args:
        gm(fx.GraphModule): The input Fx graph module to be modified in place;
            it is recompiled after the targets are rewritten.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()


def compile_through_fx(model, inputs, mlir_loc=None):
    """Compile ``model`` into a SHARK inference module.

    When ``mlir_loc`` points at an MLIR file, its text is used directly.
    Otherwise ``model`` is traced with ``make_fx`` on ``inputs``, scripted
    with TorchScript and lowered by torch-mlir to linalg-on-tensors.

    Args:
        model: Callable to trace; invoked with ``inputs`` during tracing.
        inputs: Example input handed to ``make_fx`` and ``torch_mlir.compile``.
        mlir_loc: Optional path of a pre-generated MLIR file.

    Returns:
        A compiled ``SharkInference`` object targeting ``args.device`` (the
        module-level parsed command-line arguments).
    """
    module = load_mlir(mlir_loc)
    if module is None:
        fx_g = make_fx(
            model,
            decomposition_table=get_decompositions(
                [
                    torch.ops.aten.embedding_dense_backward,
                    torch.ops.aten.native_layer_norm_backward,
                    torch.ops.aten.slice_backward,
                    torch.ops.aten.select_backward,
                    torch.ops.aten.norm.ScalarOpt_dim,
                    torch.ops.aten.native_group_norm,
                    torch.ops.aten.upsample_bilinear2d.vec,
                    torch.ops.aten.split.Tensor,
                    torch.ops.aten.split_with_sizes,
                ]
            ),
        )(inputs)

        # Install a plain CodeGen before regenerating the module's code --
        # presumably to drop make_fx's default codegen ahead of scripting;
        # TODO confirm.
        fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
        fx_g.recompile()

        _strip_overloads(fx_g)

        ts_g = torch.jit.script(fx_g)

        print("Torchscript graph generated successfully")
        module = torch_mlir.compile(
            ts_g,
            inputs,
            torch_mlir.OutputType.LINALG_ON_TENSORS,
            use_tracing=False,
            verbose=False,
        )

    # str() is a no-op for text loaded from disk and serializes a freshly
    # compiled torch-mlir module.
    mlir_model = str(module)
    func_name = "forward"
    shark_module = SharkInference(
        mlir_model, func_name, device=args.device, mlir_dialect="linalg"
    )
    shark_module.compile()

    return shark_module
| 186 | + |
| 187 | + |
# Checkpoint to load: models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
model_path = "models/RRDB_ESRGAN_x4.pth"
# device = torch.device('cuda')  # if you want to run on CPU, change 'cuda' -> cpu
device = torch.device("cpu")

# Glob pattern selecting the input images to process.
test_img_folder = "InputImages/*"

model = RRDBNet(3, 3, 64, 23, gc=32)
# map_location remaps the checkpoint's tensors onto ``device`` while loading,
# so a checkpoint saved from CUDA tensors still loads on a CPU-only host.
model.load_state_dict(
    torch.load(model_path, map_location=device), strict=True
)
model.eval()
model = model.to(device)

print("Model path {:s}. \nTesting...".format(model_path))
| 200 | + |
if __name__ == "__main__":
    # Script entry point: for every file matched by ``test_img_folder``, run
    # the image through both the SHARK-compiled module and the eager PyTorch
    # model, and write one PNG per result into OutputImages/.
    import os

    # cv2.imwrite does not create missing directories; make sure it exists.
    os.makedirs("OutputImages", exist_ok=True)

    for idx, path in enumerate(glob.glob(test_img_folder), start=1):
        base = osp.splitext(osp.basename(path))[0]
        print(idx, base)
        # read images
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # imread signals an unreadable / non-image file by returning None.
            print(f"Skipping {path}: could not be read as an image")
            continue
        img = img * 1.0 / 255
        # Reverse the channel order (index [2, 1, 0]) and move channels first.
        img = torch.from_numpy(
            np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))
        ).float()
        img_LR = img.unsqueeze(0)
        img_LR = img_LR.to(device)

        with torch.no_grad():
            # Forward --mlir_loc so a pre-generated MLIR file is actually
            # used; with the default (None) the model is traced as before.
            shark_module = compile_through_fx(
                inference, img_LR, mlir_loc=args.mlir_loc
            )
            shark_output = shark_module.forward((img_LR,))
            shark_output = torch.from_numpy(shark_output)
            shark_output = (
                shark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
            )
            esrgan_output = (
                model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
            )
        # SHARK OUTPUT
        # Undo the input transform: channels last, channel order reversed.
        shark_output = np.transpose(shark_output[[2, 1, 0], :, :], (1, 2, 0))
        shark_output = (shark_output * 255.0).round()
        cv2.imwrite(
            "OutputImages/{:s}_rlt_shark_output.png".format(base), shark_output
        )
        print("Generated SHARK's output")
        # ESRGAN OUTPUT
        esrgan_output = np.transpose(esrgan_output[[2, 1, 0], :, :], (1, 2, 0))
        esrgan_output = (esrgan_output * 255.0).round()
        cv2.imwrite(
            "OutputImages/{:s}_rlt_esrgan_output.png".format(base),
            esrgan_output,
        )
        print("Generated ESRGAN's output")
0 commit comments