Skip to content

Commit 2f36de3

Browse files
Abhishek-Varmavivekkhandelwal1
authored andcommitted
[SHARK_INFERENCE] Add ESRGAN model test file
-- This commit adds ESRGAN model test file to SHARK_INFERENCE. Signed-off-by: Abhishek Varma <[email protected]>
1 parent 2005bce commit 2f36de3

File tree

2 files changed

+255
-0
lines changed

2 files changed

+255
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
## Running ESRGAN
2+
3+
```
4+
1. pip install numpy opencv-python
5+
2. mkdir InputImages
6+
(this is where all the input images will reside in)
7+
3. mkdir OutputImages
8+
(this is where the model will generate all the images)
9+
4. mkdir models
10+
(save the .pth checkpoint file here)
11+
5. python esrgan.py
12+
```
13+
14+
- Download [RRDB_ESRGAN_x4.pth](https://drive.google.com/drive/u/0/folders/17VYV_SoZZesU6mbxz2dMAIccSSlqLecY) and place it in the `models` directory as mentioned above in step 4.
15+
- Credits : [ESRGAN](https://github.com/xinntao/ESRGAN)
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
from ast import arg
2+
import os.path as osp
3+
import glob
4+
import cv2
5+
import numpy as np
6+
import torch
7+
8+
from torch.fx.experimental.proxy_tensor import make_fx
9+
from torch._decomp import get_decompositions
10+
from shark.shark_inference import SharkInference
11+
import torch_mlir
12+
import tempfile
13+
import functools
14+
import torch
15+
import torch.nn as nn
16+
import torch.nn.functional as F
17+
18+
19+
def make_layer(block, n_layers):
20+
layers = []
21+
for _ in range(n_layers):
22+
layers.append(block())
23+
return nn.Sequential(*layers)
24+
25+
26+
class ResidualDenseBlock_5C(nn.Module):
27+
def __init__(self, nf=64, gc=32, bias=True):
28+
super(ResidualDenseBlock_5C, self).__init__()
29+
# gc: growth channel, i.e. intermediate channels
30+
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
31+
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
32+
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
33+
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
34+
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
35+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
36+
37+
# initialization
38+
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
39+
40+
def forward(self, x):
41+
x1 = self.lrelu(self.conv1(x))
42+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
43+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
44+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
45+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
46+
return x5 * 0.2 + x
47+
48+
49+
class RRDB(nn.Module):
50+
"""Residual in Residual Dense Block"""
51+
52+
def __init__(self, nf, gc=32):
53+
super(RRDB, self).__init__()
54+
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
55+
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
56+
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
57+
58+
def forward(self, x):
59+
out = self.RDB1(x)
60+
out = self.RDB2(out)
61+
out = self.RDB3(out)
62+
return out * 0.2 + x
63+
64+
65+
class RRDBNet(nn.Module):
66+
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
67+
super(RRDBNet, self).__init__()
68+
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
69+
70+
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
71+
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
72+
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
73+
#### upsampling
74+
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
75+
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
76+
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
77+
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
78+
79+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
80+
81+
def forward(self, x):
82+
fea = self.conv_first(x)
83+
trunk = self.trunk_conv(self.RRDB_trunk(fea))
84+
fea = fea + trunk
85+
86+
fea = self.lrelu(
87+
self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
88+
)
89+
fea = self.lrelu(
90+
self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
91+
)
92+
out = self.conv_last(self.lrelu(self.HRconv(fea)))
93+
94+
return out
95+
96+
97+
############### Parsing args #####################
98+
import argparse
99+
100+
p = argparse.ArgumentParser(
101+
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
102+
)
103+
104+
p.add_argument("--device", type=str, default="cpu", help="the device to use")
105+
p.add_argument(
106+
"--mlir_loc",
107+
type=str,
108+
default=None,
109+
help="location of the model's mlir file",
110+
)
111+
args = p.parse_args()
112+
###################################################
113+
114+
115+
def inference(input_m):
116+
return model(input_m)
117+
118+
119+
def load_mlir(mlir_loc):
120+
import os
121+
122+
if mlir_loc == None:
123+
return None
124+
print(f"Trying to load the model from {mlir_loc}.")
125+
with open(os.path.join(mlir_loc)) as f:
126+
mlir_module = f.read()
127+
return mlir_module
128+
129+
130+
def compile_through_fx(model, inputs, mlir_loc=None):
131+
132+
module = load_mlir(mlir_loc)
133+
if module == None:
134+
fx_g = make_fx(
135+
model,
136+
decomposition_table=get_decompositions(
137+
[
138+
torch.ops.aten.embedding_dense_backward,
139+
torch.ops.aten.native_layer_norm_backward,
140+
torch.ops.aten.slice_backward,
141+
torch.ops.aten.select_backward,
142+
torch.ops.aten.norm.ScalarOpt_dim,
143+
torch.ops.aten.native_group_norm,
144+
torch.ops.aten.upsample_bilinear2d.vec,
145+
torch.ops.aten.split.Tensor,
146+
torch.ops.aten.split_with_sizes,
147+
]
148+
),
149+
)(inputs)
150+
151+
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
152+
fx_g.recompile()
153+
154+
def strip_overloads(gm):
155+
"""
156+
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
157+
Args:
158+
gm(fx.GraphModule): The input Fx graph module to be modified
159+
"""
160+
for node in gm.graph.nodes:
161+
if isinstance(node.target, torch._ops.OpOverload):
162+
node.target = node.target.overloadpacket
163+
gm.recompile()
164+
165+
strip_overloads(fx_g)
166+
167+
ts_g = torch.jit.script(fx_g)
168+
169+
print("Torchscript graph generated successfully")
170+
module = torch_mlir.compile(
171+
ts_g,
172+
inputs,
173+
torch_mlir.OutputType.LINALG_ON_TENSORS,
174+
use_tracing=False,
175+
verbose=False,
176+
)
177+
178+
mlir_model = str(module)
179+
func_name = "forward"
180+
shark_module = SharkInference(
181+
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
182+
)
183+
shark_module.compile()
184+
185+
return shark_module
186+
187+
188+
model_path = "models/RRDB_ESRGAN_x4.pth" # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
189+
# device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu
190+
device = torch.device("cpu")
191+
192+
test_img_folder = "InputImages/*"
193+
194+
model = RRDBNet(3, 3, 64, 23, gc=32)
195+
model.load_state_dict(torch.load(model_path), strict=True)
196+
model.eval()
197+
model = model.to(device)
198+
199+
print("Model path {:s}. \nTesting...".format(model_path))
200+
201+
if __name__ == "__main__":
202+
idx = 0
203+
for path in glob.glob(test_img_folder):
204+
idx += 1
205+
base = osp.splitext(osp.basename(path))[0]
206+
print(idx, base)
207+
# read images
208+
img = cv2.imread(path, cv2.IMREAD_COLOR)
209+
img = img * 1.0 / 255
210+
img = torch.from_numpy(
211+
np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))
212+
).float()
213+
img_LR = img.unsqueeze(0)
214+
img_LR = img_LR.to(device)
215+
216+
with torch.no_grad():
217+
shark_module = compile_through_fx(inference, img_LR)
218+
shark_output = shark_module.forward((img_LR,))
219+
shark_output = torch.from_numpy(shark_output)
220+
shark_output = (
221+
shark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
222+
)
223+
esrgan_output = (
224+
model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
225+
)
226+
# SHARK OUTPUT
227+
shark_output = np.transpose(shark_output[[2, 1, 0], :, :], (1, 2, 0))
228+
shark_output = (shark_output * 255.0).round()
229+
cv2.imwrite(
230+
"OutputImages/{:s}_rlt_shark_output.png".format(base), shark_output
231+
)
232+
print("Generated SHARK's output")
233+
# ESRGAN OUTPUT
234+
esrgan_output = np.transpose(esrgan_output[[2, 1, 0], :, :], (1, 2, 0))
235+
esrgan_output = (esrgan_output * 255.0).round()
236+
cv2.imwrite(
237+
"OutputImages/{:s}_rlt_esrgan_output.png".format(base),
238+
esrgan_output,
239+
)
240+
print("Generated ESRGAN's output")

0 commit comments

Comments
 (0)