From a4b73459781c927852b19fdb1f2f2e1d5ce23334 Mon Sep 17 00:00:00 2001 From: maito1201 Date: Sat, 3 Sep 2022 23:48:54 +0900 Subject: [PATCH] feat(example): add conditional image generation --- .../conditional_image_generation/README.md | 103 +++++++ .../dataset_example/huggingface.png | Bin 0 -> 25476 bytes .../requirements.txt | 3 + .../train_conditional.py | 265 ++++++++++++++++++ 4 files changed, 371 insertions(+) create mode 100644 examples/conditional_image_generation/README.md create mode 100644 examples/conditional_image_generation/dataset_example/huggingface.png create mode 100644 examples/conditional_image_generation/requirements.txt create mode 100644 examples/conditional_image_generation/train_conditional.py diff --git a/examples/conditional_image_generation/README.md b/examples/conditional_image_generation/README.md new file mode 100644 index 000000000000..9cd10c35a621 --- /dev/null +++ b/examples/conditional_image_generation/README.md @@ -0,0 +1,103 @@ +## Training examples + +Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets). + +### Installing the dependencies + +Before running the scipts, make sure to install the library's training dependencies: + +```bash +pip install diffusers[training] accelerate datasets +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +### conditional example + +TODO: prepare examples + +### Using your own data + +To use your own dataset, there are 2 ways: +- you can either provide your own folder as `--train_data_dir` +- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument. + +Below, we explain both in more detail. + +#### Provide the dataset as a folder + +If you provide your own folders with images, the script expects the following directory structure: + +```bash +data_dir/xxx.png +data_dir/xxy.png +data_dir/[...]/xxz.png +``` + +In other words, the script will take care of gathering all images inside the folder. You can then run the script like this: + +```bash +accelerate launch train_conditional.py \ + --train_data_dir \ + +``` + +Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects. + +#### Upload your data to the hub, as a (possibly private) repo + +It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following: + +```python +from datasets import load_dataset + +# example 1: local folder +dataset = load_dataset("imagefolder", data_dir="path_to_your_folder") + +# example 2: local files (suppoted formats are tar, gzip, zip, xz, rar, zstd) +dataset = load_dataset("imagefolder", data_files="path_to_zip_file") + +# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd) +dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip") + +# example 4: providing several splits +dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]}) +``` + +`ImageFolder` will create an `image` column containing the PIL-encoded images. + +Next, push it to the hub! + +```python +# assuming you have ran the huggingface-cli login command in a terminal +dataset.push_to_hub("name_of_your_dataset") + +# if you want to push to a private repo, simply pass private=True: +dataset.push_to_hub("name_of_your_dataset", private=True) +``` + +and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub. + +More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets). + +#### How to use in the pipeline + +```python +# make sure you're logged in with `huggingface-cli login` +from torch import autocast +from diffusers import StableDiffusionPipeline + +# Replace it to model that you want to use. +unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True) + +pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", unet=unet use_auth_token=True) +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +with autocast("cuda"): + image = pipe(prompt)["sample"][0] +``` \ No newline at end of file diff --git a/examples/conditional_image_generation/dataset_example/huggingface.png b/examples/conditional_image_generation/dataset_example/huggingface.png new file mode 100644 index 0000000000000000000000000000000000000000..c4f5bbd66df2bbb0c80b739d57afe16b8d97a6b8 GIT binary patch literal 25476 zcmbSS^;cA1yuA~22t$XobR%6uNq2{IH%R9Uh|&$xVNfC^(lK->ilCIl5E6o1e`OGCha!Z~EQ3d$JNJQaAATAIt7+?-#BL$RO z95=TS<^cByVoIS*jMyj*@_PUZ4obOOu4ybyLOgYne>HcI_s_z=f9Z=sj94pqWeUF= zf*P_xUKRjy7f)XXyWS#ZaM;9Qg183%v{-=n+EBVhZ1*?K?L4Z;DeE#RZ0c!vAhb%m zLE95ZRp99Gnku!M-q}Aq&{;~5-@S(f^efqNw40qrp541J#8fpM2;bHQfi!SPEFpu~ zz$YBFS{z}Q@88@_|H-fd@n3!w)=yP?Qo25)!@6ph+4Da8qYwW!bwA4iuznM}uC|s0 zhvFE(^ul*hwt$)^6A*t2a%U`i3zAfd=BObzQKkf;62&9wIT^u;*NDbz#FSiF07B`7*@l;&+KJ zEUC@_zY1WIgASS)er4(@fF+NmAjp&YpW^}xOWmbl?ty@{w7a~;o8$8Bo}y;36F9bn z2T@0WFz!{Pz!9QcD#ZM7h#pk~p65U=o}fmb!uQRC`YpX@9)ACDModLr2~Rh3Z6Km< zBi-(PF`oLGcBujvu$Y*r*T3zKz|D)u>oH9Q4avg22n9bUecZv)i{NXVD&}mHCy+Rj z)@yD5b4P6J)ZcM%W8qb~fAH|?<6j_Gd*$J;+v1@M-^QpKUN&9Gl;zTt{Z^<^AW4qv zUlqU?UpIV&~;1frB2h9#|nHm$A+bub*QYU806BiZ+!me@*YnHSN)BnFm?& z9+S=0+0^DxpnU0eL1O64Fbx>E$0I*CoT^{d7U*~9 z@Uu3re}gMI?%I3A9bFYp`N2DK+EG0Br>VSPKsUHIsL1N}2#0f~v_1h$GL0}ww!bt} zw#B=rr!O;rXc&(_G&;TP%`Jjrdjn!!SlbmivKxr$(t`Ee{@ zgiSwXk=S-E%JvEmeX&%_T9157UO)C3MlwqcjAoCed9Gf9r~N@?6(8R+^GLrq8zmnn zB3uLrdt$DLGb{Di1Zya3^#H*p_;1|9jbAig!4XmQS*Z}bl%Wg#HTv%%0J4 zk$qs&wY>zxy^0EgD;OZecR07^rNqI)YAinwjpOWZ+lC8#Sv<^{1kI`&jIi7*J_{2% zBz%HzUBb^PP*>s^F-Q2M{7-v;^hjEXv!}NAE{FwkMMi?xYU(|8M&w(q9%kwFB%KXY z@WK}IH>;eci+F5e1|xP*z{~aLI@`$Umrl+d7n6RkC=i@|b@(dd{mL=}(#udnuPWVz z#65=qu>d0MPVGnDy;1=tO-vS*IPOQ!k?ZnI#>-{ak^9FS4%z5+BI%4zlDHB-Y*5#P z7Cjlb#7DH9$md=TR9f{ z1**WWwIIn<=fp+*G0l=X8g9izm??HCD0r)h9&P11d%%5y=Svqf<#&~T{#&)B>jZ4P zXq6wbegJ;@innva_8N?E$S_&y6M%gP-oRmg8bx_9BC2tv($e&$xymSy z4^%-z4BdF#u%Cijuz^Skl*y|CzKI{g2rGS#O-S|EKC^Bw2`AJ$HU5}>@{TikqI;I9RAVI zY?8$roj=TzVXT0eMe9n1$+E-Bqt^Sw=oc5ywXy^#-qcrli3@v2pvKIWf3guS5L?I- z;4CY*BgG!zED!*-Fs}2xYw?0%vf9ycg|C$9?ww=W=i(D6|qH! zO)(||ws^G>xSV_7l6#&j^9SjF*9-&yNMEo7**HE6%O&=TSn`!7h&={ZPb2Cyk$eYi z<)mDsAQX%2L)0iF>j9Xo7YK@`M(B18JGGF}1|Gj#CvE3w@@2A>h5~lOg?KF5k1Xz; zZMQ$cJ&i46;JXln1Ew~C!{t>PdyR>LuALH^ZX;}+9W&yHMa&;4^k^}mJHwUTIeO63 z@NZ6VcMm9_E&37ICPTgH)j7Fvces09Y!%xpqVh2E`(~Oa1Myb!HE19XhgOYX#D60Y_r^b&bZJ!!I?p&D>6Mat))3vEUofQp#rk((qbhXqV}w1QcOeD zF17=f$e;CZ^Pygfd|pohUqaZi#33WwufNw^W-&euNzY|oGu(|9*mdA(Q6cRgnR#pd zWnVa?xF>Rr^;*Yj<_EeizMX{V0p2J4xObh?A@Y<76On!)mar za(k5ufOc!&LOR?l-QGS~mfw8!sH`rB}eaTX_6$b_!Cz140K zfQE)&6+dp0K2N?sc(=ZBrYaSc@hLH%Qm;01p}~feDISF)Lb>%ms0}lQ0?*4r%bIaQ zu4Lh>SoiIb1>gWST;O%4jpK)G0zp1o7G1u<7pCWBz@VvqJ*vQTdNL!22q1R4N?Re? zyKO}JIk(xVxa5luM!OwPJ9i#7qcGtF=0t68qR?Y2mF7{_ zPorYs3B7^0gGXsicNw2vn~RAOrfl8hJpZA?^eN#O^m+z1{X&NbpucApJToj<^w?dN zy`!c3wlu5YO7r)A)Z#n+OI*cJvE&~FQ02HyLho?-$#Yl@aV;m{>$z|ken&yw`+)j7 zzjwDZ5mYVFexIQ-Gd}ztOcr{K+trej4Vd&MvJ;(JMUCSrY7ziJL{;>Vs7p%QV^Z4f zogMzMw|mq1$&{v!UTb)Ogp>lFg8VI@kNx$M%66jOz|D6EfS*$Pq6tvt^$lV+rLFr1 zZ09yf0XM83tUN|lkIfEQaq*0ZwtE>Q{ys3l1?HqMn+Tb6Yx|oWLx%0q<{Q&emSYB*;wZ8(_^PrDy|`ms#-=&SgR>MkeTkCY-4CH0MKjP{k;?Py!D5{Ydd7z+4JJYjKu(>gs^ zVEwSWm~!Udj@fKIYu)_cpI3b^lj-xz1W`Q9oRwUJ5+^>$x+FmV=?egDws9waI)WNK zzm1cd^5edJ_bc|m#{yguA%v_v8}yr1D&{m*`9XN!ppr>X7&N_k|JRcJ`>fk*RyRIN zAa7hf#5*r-t?SvrIDEL|qL>Nssdhjp!{$&^!&g5Fta$M7MPyK$J-Y)!gNgBh@poN!R zS=n0wlNrr^eKW#!5}omM0uuu7vqzM5h6!-&FQ%4csf%;0Ens7CJP)GDKlOIJ?EVS+y@9R_M>qfRQg{src5kL7vwUZ$n;%spd! zab2pj`u(ALb`6V`N7Okd%}1dhF6c~=&uy2-^jA`ZZ0lZ`@{OJv*qa)64#OS%Y|dzN zMy3`tZ2To3KVU{t)<)HKZ*euBbW*SO_Rp`C`D3Tvfj6{KH)gae;SG?i%T2%bmD4Ms zqHqf=OjfqY0x~yNu(Yo9!Q%5^EaG`0oNQv45YEB%#QL1D(NGP;n@H~`jfy`JibGqF zVdV0Aq-HPiX#OK(n6(v)6DQu>>!=1VOaTK$|Eu@#^E6OsnOl6x z288_1kzDMOefdn>bisf9)q9rI1Bp5_j)V&4jhN-^6feAy3=yWBNMD z`ZPCgB+iNTjnFIk+iL%c?0$rs5>KnE+*PW^5 zs|$xbaMV9D?9z^pjB!ozL$!s2Ms(4qo9hw|krWiTuZ_Xn2z!QBC9AhNyat=e(9abY;%1u}u!KkPF$ z1R}C$gdS}N)sOE^?mO)AlEODRn__3JiiZNJsC>V={mmk{IC7wo=@CJ8GMX=PNxUUp zN8k2}#7roOVhYx;Ar+S2WfAkehKFB%RkKWN;|oZd)rd_ywR|9qjixNN0tF9Zb23cm2&FYqRyB z+4^MIA5fzPlwF4eDPi!B_ggy_F9N_BbTM`b_ezW^xvpYe@Nt2qTNGZL{0t8}9NdR8 zth8Scs${~O4^wATS2zD&wwV87y!yk$5Xqcyv^RygB*IK<;!?#bSpRyFHQW6M=Z3hY zpcrm#l&EhX$1ifg-QmaX{mCn#c7bW0xpz~g6)jyeD)3(k=8c!yt0zA1gV}pCV z30tj7RJFVwOnttV>%>mmQ{%UY321rR_9>Z8U~E0d^!EY|K$tYBn5rFipIQZi2tdBx zG38)Wwab@X0_iIt$5JRmbke=ht4Rdt$&LA5*eoI4kl#-;x{%M(Jz|K5IBRpko;w?b z%a;p6U2BOl21jSxWEa!A#H$$(HBHAPpG8uL`+Qq%dF4oTrkFY)bv74c_GMr+NKGmt zDt*a!~2*I@^-eHz-iS|7Ej(whiukt+noB}ZXI zbov``@RVw=C}Qh;sA2bT>N$W(h^2bQfqJWq*adc{y7Sv#v@*y2Q*pqS57&By=1GY0 z8)k&pbjRo`w{^>zgj=^vHJohjX(@{5o8Xwgda-)tS{~27QVS|sAWfILw+oOHSinc+ z$op%#9r>L~aV2JW?bs2KYQPe3Ob7&=IGSpbgUsKat>!8e;i5I=br}Rz+O01fcKW3;7iBVMP_4B4+UkvG1hpYs8LPU{sWa>+FAjBK z#S`XU(Xdll_@-wKVKWv$e<$U;A&p@)mq!Y@4Ug;RD|^=`-_F$Nz)8*wk4!d79)tSm zEuft*9WHjr+#*g@my&uKDCYa#*L_)adbNQmuxHhND<^8$X$Kl>C8{Et$5Ho(GY*(` zG0@RRqRjL;=(A5p!X#fJcobqrm@=Lq147so9!drS-KOBBU%#M0P!uBIo}J;J+%jyh z*{focF>QDF8GPa&-LCBK##s`rf`K_PooEe9N<@`N3x*)U)fgkhLLb&*ANI)|si+r;Oxzv!4zu z74wtibnj}OAGy(Fh2j}M^~7azsEJMMa-e=6P8d+}0UfI#KEGpw_2tRz3pkVDBn`I> z`9eh!y!y{WYyh<>55BU&t_^c41W*n>EURfTs^D~vW#NR)rhCFHWvb)f!;HSSu^42E zC>S6JIh-f&xm)+%lOQ2j+T_+Gmdd2`?&f}}?cq+nC*

)mN@hAwZwJz9Wmb{X+1g zohx-_%KgQArERTzuSfJ2HSrI`sI7mnFgtLS~LN;DAq@>^c{1G*lo5@wg6L z3~?*}yMTO%lgB?AVs8rmwwE-}g6z>DpenY05){Ba&Qo0YYU*Zf+XOqxb5I>VFM%sHa5Gk^1RX+&mlYVW8l<_i0XYTP`r-|uhCkOoC;1K+NIbqFpWKJdw}ObWahUM+rV3Ve^6Jrg~D zmM6x1b5tGIYhvXjiGQFxV|uXDClLJT@H5q|Bw=sl30q64ald5Yg5i=&^_OPo-W=n= zY}8_`Y7kqb(9Q2xn~k@%FmKnl|ESn!h&^tLAJ>|Ag8_*z1k?;Cx@HfSu#1jNjfl*k z)eV~f{ESQEM)tZ=jx2hv>HzuGxt<9dn4SP3+MdW@F`AwyP-TT|x3l$}PS;7=8+p7a z^Is^ic>G;v!#V`LL(Z_0k1CdzzbqjRpliGMCmeM!gZIzz%dp~`YA^`?4iE3=9Wgeb zy%7IGU6DCzKslTrXQqpC?%&I#DXK|&ATI>*`2IZN2OR*TjSLiT&-E60VCzUrVDbky zG-YG6rfs*JkT}(qn!!+C1d7?Tn9cg2#ckvgK6t#}*0B7(PwY1EhWm`+=izV0GxXf^ z7=P{OERqt}&G;>ZKYhpY9~Je!i)ov&wSW9s1~CPe1r(S7wUI(qd|^J@Oi=bLeP8M` zBytz&#b}m&WQUpay1XMj!Y330XICiXnRdLYAILGic{YsmGn&9*tc)y_{Lho z<~4+6ti4MK!CR;dq@d)n zlb%D;Ms!IoN|kj3jOr{Mry+6oI|_AUy}%ND=*1* zBEwDIa`k!~20z?vv~CeAXpfRqcjMhHev_Vsn8I4TpvQaJLJvq#HXKZ&QhSf;?jhMX znnP5oSHJvlp)oLSYj!_&+nT=>FjdNa!mLLSkLAXel?A=LfZ`t;)yBl}RRIwTof^!e z**BfPNy-8qKjjNq1@!13XZLi6{qh$hy0K58K$2UW@Y1MNdn3-5_GL|G9oT%ex}cd@ zakFu^P>ZXSPPvCO>2GwZHcaBuYpBqg=^dIq97KSFSp4|*u5-&_nLnF%wy}!xyH=7_ zjmCtr&8hz-6co?JyJ<>UDJ|4A14bqi|1jh6ojA0nd)%FqS8mm0Z9XWr^?k?oV51-$ z`NNs1r(y`Oy^))`lRpeSNU7ke!$#Q`I%R!+#g~bORN_bnJ`d+t|N1hsO5^h`?8yvt zJ=3b_d737 zB7DTwK+99o;eu5;*Zc(8Dh=|3ddK%PP;Ajy9qXP;(II1~MU6Ryp%(pM+J}(+mq?73 zxvyzS2w2}x+3r=M=5%EMcz`tUDr`VMvBc}HBVC#N@_B+QYptUXYIXOGw-@f0TbuDu z89as*pX+Vxc1tI_Fxrv@Ay9IyX?x6Ud%W6yZw=e|r#lqs)XI|NIb?y1W0NeMK@;_2 z)rL_tT+L8I5H@$jq=;kbE>Ky0R>)Z*x$NF3OQDxj@sU{s8`@-va8>6b^?$tpgxwh( z5X?*&LLuql*qr;t*9T4soY@6aMf^1w_RrhI$s(6b&U0w_J>t{{TkJ+B34{``)?V+` z#OVgvrh^dV|0(*#j-uGWx*RYWh&MqX^YBEMG!NX#8M3cX># zMKvU_NOQ`zaCzAt@Vys>+dW8#2?3!B7!vP?Z~>O2JvI?)Ao&V_yOa4YFiwp&jLfll?o zf?&ibi}ZZ|=)%wya^e826YYZuU)B;au1|}#j|F&K&%wX`H>RgGz{G+05RF}Q+4=zr z5UV3ipH%V9O+&W5>R`T(8GJkdZn*~+H)mi| zS`Z=PyYnQT=eFcIMT_jZ0P^D<1L6l5K<6QhvG(UR3V>%l9q2@^S`;_J z4UaSX(Vh&sk}*|!;=DxUQ;M-XqR=>QoP0_Q^;7bTQS}{re-oBf!@C&AFE485c*p?4 z>4&hjO0(uuSyX;2#J_?9z)1hTVS)f*kZ5q7Md2K#k4qmM3?Am;FOig~5r;`@{|hQG zoAy9=@lm{8H?d3s3e$cDkiM&luM8UTi_qwsYmQkr6xK>gy-j)V-nRc5{ME|lX?_VH z07>1t-LmkGeBe85lO?S83qbIKV|mhW9e<(f280&x4m<>NUfJd~HyO|O*`^^sczCq! zG|6ChxSCiC>#V)XY?G0J`~vZ6!t=F_Xuol{U9S>i@QFl#RWS{O9UrdyGNHBCnNXm+ z3|GY*Swc$Hhz9}q)U1Er{4}J7mHShT zo8go!42qkab>mO^Rw|$_AzXp(d=Oy&&6hAd)wRdsL-gSvzplMCu*oSk&jX-FDip{H zjaK+IZB7h#jiU!}kc=h6H2?#FsXahqsAldI*A{}|s9es6kElwIUil_^h#C0N?cxO1 z?h7?k=#J!UIy~CvHV4lt*bYQz zvcrQ>p$P$Q(M+84rkVF(!RzlPYalbxXaLQC(f}bcpSzq59W8%*KI^-C@K|Wh3;p}m zR%4xwYV<2)YIa=SpD&VD&Bj)@8rsU`yms6Ib_W^M3?X6(m~NYz9A}CM`^=Zsw%-rd zu{5;I2T-zru5QPpPymSXXV8<{;KSojcx;r>3aL-mfHx{Kph}~|qF7RuBC%(P{Vf5lwRw$Il zI7q@)V+K11bPH9mXrV&8%mON<-f;D|)$kfENd(1fe4}Xmt?>ChL zgBU(ZBHRJ%_9=FlH1FFiO8c6WwoP7(Q;c3*!2G&`00V+YbCtEWoNoM0Vc{=CleM_> z`1x=0d)8Zo7Jb!0vJW@$!2r%AA2-CrE5p^a(A!+(J=YO97%$$=0lu@BYEbqh%_79$TZh6Bc6R^R(+ znmw8s-w-3Dx7ch}_&2&?-)X1Nw`()wV4zONN)+M72y9Ex&0UDuPBVU@2B2JEKOkKh zjChxgEkE?_?DE#H>%>7xke5^^i`Z2Ga2!*x`j|1fm#=4cqDK)ct1Kvm;UV(h#u@b! z|D~CXPy)@zBvYliO31xK&$Y!Y4bWongYi@4MwZtVNt4>;ZlyQ?`B-1+?EX98D zt$>?!L6=mRhQ?z|1+Mz{!xX^Xldo8{=@mlAsy@!<*6eJ$40!*uTPjnC&MyT}(5J_Z zfjp(s|A}Q3NZOwH{HT@?!==w_UfVj3^|cTgcU}I@YP;E+8kf5DdI z4QI1>Pz!vZBwV17;4pqf!!cGt;sM+_0Yw47V#qf-T#?@$)Q!sT3klkZv@G!a`-QNZ z6biRB62a;J(!f?wJa5%1SL;?2Ygzww?M2ebTuXU}KY(gl5(m|9D@7Bcb_W4-nWQ>%eXYtnnYIZKryqGq29gQ;1S<^>(<&U5+ejm(gi##+l9wnG` z+&{z4Ehq%_$aO3a!L8Sg6TrPsgHbo$d?x#29II%B$qGKAm2=J zfdEFpbAAWa?t)aZ0i8vBktWG;5`O@YFx3Jn@DUK9;ow6XMwCeSI(2?>KO$^*ERXun z?9umNhBy{Uk|7E((4thhWdEUeK2!SL+~>M(Dm8EHH4hAq;8sOKN+Cf}lh?s)pzT_^ z2Vb9oc%D<<3^D#&OQ0ePao3euI1OxZs(K+JZgjXTo>Eu)z`JRK@!+e`|Qe-Ecu{6=PUH^(4(&t7W$z`sd+gUD3)xW35_0RIQy!MZB3U)ewiHUn!o zjEme1WOo#xFuX2Jc5%^qrC+5cQnI4aqBYe|S6@2J_i8IQ1y{rMytksWaHS~$w2cM$ zs zwD}nm^+rKJsBy{Ld>qur+f=1hc#9fe-?RJ~$lxq^_AVJ!_6TIVfrDDD2dm!wq60Gh z46qU#mGhYN{PjpppE3Im6=ne}@AUEWQ`okRV#S{lQeF`ordQrs>xfQo7LU{N#J)uFjm>R zl`UGz_Ln@^pjHLMbWM&#eT?2D0%{c1ds7D!cuaEVwVOG~+<#da2(m9x z^JEn!#p{i~Kv9{ls?4aad`T&VK?hZs_;Kghlhkmb!1)heBYK=qh{{@gcs-}>M>Zy| z#JrY=$#Y^&VtNkuJv3&M#X{d#+|~>aXkPHP*8MO?SF}D1tIAy$` z#ri|Dx)G0gcwL0Ox^K&=6)|CVE;}U2Jqe7liE=VRqt0Llq;@fIV6w0}DVn#ec6zLy zZq=lNm+ZY@9jv`-qW|T~Qq9%At;~yk$OP9`Jiw-a(k@8$N{s(VLApdJt_`eI=Y?ps z9@--~8q6vE^10WD{IFn%LN6_DuQFB1%GhCxM?+W(Tlu5M~rq$JlToLOWPbCCj^T$7I(}y+aQ~LF=*k$pu zumW!S2!mJ8q?{!$6eb1-I^u*r=aa!7E1)c4!e6shJhN5Ysjy`WgbEdrc)$A2YuJaq zaOOjYu`o!g?I=^m+E0bJ`)}2aVO=}8r0BD*$XCT-7Zwn*qOJVEsPB3ui5Al<5FhXH z^g0`~*kJFk*1^}o3qTQ*lD z5o9B{hGXT&se_+IF0lLzo#U&5KG0pEOnG-`EXI%uQvVulQ_bo%Eyk3pISQH zf=TBq$S7NytAl*N@9;(Ujob;6UOEQ|L66yF82$XXfhCDUcj1&JSbCb zyC|L=jF;nWFOsyH2z^Tiw8QAcacR6)Fu&L98iy^y=YSr?-j9sj!5i?t6+lr&H zaDf?hur)*05lmZ=9g0qwCnhcsyDKj%ZT6nkWbsb=+-5|PL40H(O++V2D06f*Yh|WjFsE(L^qWcoN<|%r6*@_q3(x=?)3=q5rm@^m5@MkqJ$y-7lsxXx} ziLZWhvqL^y;j|K%NdooERk_hfUb{tND`nI`WZqg{#Cfl#I}Kp_=Cjxp@~*jm5UL-8 z`pz}`^?KXAbaQxR7rxw7lALhTYUF#Xf@Kxz?PpiG((;rcvIk79r8n&#^WbG0Mj#pQ z`@0?ShKlFkPtP>b+Tv$8li`mhySudkOx633;OP@6=}=B#0r{o89lf=Z{{=dB%ANSR z^>VLT0tpuuJ6J>slxVK1raAc@kh$qvkz?7a!47BIVy5TE9SA@4 zT^m()!29F;ivnYFFT!D)`Gg0|8Qac2L@lR5hGoWOim1sAckKL@n*DPaLa50%>%v?z z@iM+BlvMq|{aKmB!Er;|ZLLz_!LVIq3U0EZ(C17Fxu-Iem?L)hv6{fXD;w$5GrWkt&&1n ze;&R!(%KXA(wZlhGNcr%tyvw%JZp?LKBew)7b?+nnMA=d@pynjeibpP1KqdJS$=Os z6eTd_ieoQ)WEkQ)3M#AlScBZxJ}sU$9Ghe-h35$E?AMmw3oW?C%o(37sAOG=zZ2hV91)f* zYO5<~17CAgy5#C?3#S%pM|@1QWU4JgmYe-BdgOmRc%e@LMo1ud|Am(QyVJ1^OMRXff3F7 zW^tv$9;0`kD4aDNVgGy#kEBXQt^6`oN_G!JE_VO)5Y8QY+a6NFby{eEAH(sK5%;GE zuaJ*mm(K4sT6ChLrUCS?AD9L(+1Q_kjYK;Hy~73ANU4j7O~2~1ebM?ZSnJykW`LXp zT?pWpra9+<8J05!k&$@WX+wUyAsqIph}?=VNyJ}_Dafz6sh6{d+AYV`I6lPrRkV4| zK*>Ji`j4ov3vndzcRt=;QDc$aDk09bb)>mR$xyt1JoJY;GKx9b34)_o!;nBLcqojK zdo%A!t1z)jh}mu;7>m_e2R@M0weWtQpx?$E zjWl(y{m!baW%j_4U=GBJlgUZ1WsmMtkf;F6`2K)#iJ?QtpfTvibYi-$4{W6LnMl)+ zFHLlutb=jI%GPv{VsYD!t@-}dr-ckQ05|KvZ?0rc8YX;qfbp2owr#Hf3d8n1o4f$v z>3&n4G98yl`&0&Ac-P(w3?JYeODlpA`2QM7?d>{lE?nO=)^;jTMP%XyJiIODcpy;y zgELsK>LMJRsh0GHwej!{YKbt^E0h0D^eh^GxX(!QZ=_ow*0CN> zzmT3oOR7Zq=gV7(EWy=|`%?~sl=in3y6OfbqlG&YdNHo$bPBI*Q-*D|-5XwzH9T2s zIN>fTig}lc4+Rv*pQj5m)$=IO*DC(0R$$srUXa^6=Z$mb(NO&j7dVX@bN^ zd7GT+>)EM^xJcjpZe1&nWRSZ^*d=m(H7fen2b3>@7!1RIHvtn}$fsywg4V#QWDzr( z2CE6D|F%>>zzDYI!zu=x#7^XQM#F>X-<*1MvVPME!sUjZ>1@fB(RkTsai2k>!jl6{ z)lFL2s||JbDwRX_^8~K^FxP2z(fh3Hi^s(}Q3M*~oJHn1Xom$xQm z%XRw7O?l90_R)D(e(M~_&(r;`&?W=f$1g6lsAsTPcFss(uZZ6q`jC#O)IK3_1Fz`X zj;W=id>kncXx+o|qB14!;+_&FP#&$GJ1%ndTYHq8zgjck>h3`UuYZaR z-3SjCIZiD_JTjdF(?5M)Y=|~ulc1nEYsRkR?c*A2IKeNffmX8jSC1kdO*(v37SHFD z4Cv4dXd)=UQuxTsrWMZxQCy5kf6RRLgK3Z#SBSvAl~|z#IaGvpXzbFd-Yi}GtJMLf z6RLoT&W`O#3^LLbJ8LTuS%+IalVq?ySalh~kE@Oo#HmDtwm0}^Dk6G3kbyodG>z#u z$aWB_k!zp&PQZ3gj^&Iu?ue+0i+`XlWB-F^hQ{-E?hX$uayc`y^#(qXy?U>FIGTI2 zKHVnelju8;r>BP|r;rGxnQn>j_>?Vy~MvJZ3M& zy&45AIuc>*!Q{q(slUWzWAxuHV|jFMj!sE8a(CH}@Y*)4wx>%U1!UD``O4Ym3S^B` z2{ZwTOdt)~xnqX1`*-kJr@a!~69D5a5xKX%ak2fQ&pW^9T1SQ6 zieoOVLbwXHNg(OP^CJmDc2|GL{`>QgG+4OCrT*VbRR0ddgR7V>CU$lnf*8uwsJ^m z`?ns{XiNl#pGu@h6NhXN8Lbi6AOBp*bC$p_D*kCs_3Wr5oJ*N-n8GUcW1c|r=gM)^ zLQj@vv_eqG;yoydMU&q-<`nPs^vZ`@X^^!s%aaloj(g~m>2!rzLN8mK;K(v6xmgA`}LFEWsbQvA%#>{@*H#L z=*p-Sq8c0uwv)OuqA{DEZmVDlHjUkW6ZcQmiosBA7h!WJ+tPt>R0ta2mO zWi!Rc20(RRmE0KythA>m9*BzUzfyU^jX&MP54?t2oe`C7;jcVLj{yz4v8W6mNc+~U zc%+uyhnB#T+ct|NyN~sXw;|~#cYg?G*lG+yJ~N` zS5hBxTMc4=ld}I?Sa&nl;UBdJh3>HH)8~I%1^E`zpNlnRWu3WGESiak~epYZBjII8Y?aFJi6@2^|_OLepdEug!>*$YAX3X(L-!V7OpyM@3b89R18SbV%~F=ikS625tW;F$x}Glj)&7IJ{NnT!Z7k(y)@koFgxWz#g!6) zfEJL}qV8Dl`gw>CxpBp`LC($wxl-=eHZhCpur=OK4dUDt9YRcuyyICh3afxAaLj^> z2Og7pM^>jpU&)3ODYk!?qIJ1DeBwULw2q}*j-ca7!%4>7T5oQ{f5r;93vwS;=>{s~ zTKT2DotfLL&nI{G58P=e;q$Rh^zc@BqhGHiejE3snIV(Bcl(4HWP1#uV@s&|bov{v5~TUOL_Fz{OVd!hIHxJ`ANwOl`rxb= z7%DfF&H?*V$yuPD1c@OlFe+rcS*DRPT-Ap6NywQKZ4y%vL7UVTJSK~J()EhPkjq>7 zy&m~?gImChA@u{sNyoRR!obD3ATt7QqXZNbMIvz~d!JYI?94^Seyo#bz+>Q$zlZnz z#0$S-C`d0aw{UpQJVM!w#H-DHGGTo-J0tP??r=64h2sDlHXj}Jw=C)HSG(zt!nS4e zb0&y+O}LziUwe%hFBDlJ-ef>O*tP#gw`V#<+@1xt@rRU+jFbenFNDxZZBhMm><_Ig zu_julF(SNA9oxYELn(8=PQ(TZUL?j?$s0VD3)5$}+j+!TM~rar?cgZE{FYL6Tib%Z zb`bwUIae(?H25NRMf>;1`84e?5Ad-ZR4y(=&yir@!C(lVbZuvQQF5B) zx_UcPS~cqqw~V+yYEJAlZbO6I|LbX2Xu^lW9lcS|GL{R%gs5z^0`iiAPTh51nQNJy z|FN3j+JTQ*U2`|%M4OAW;DtXr`Iiy^EJ>pC&w-acYxTP}0xPoDFZ_L&_FO)CT`{%$ zU$8T%zr~r!k!qx*-)NC8b1YufuOOwy4S1)k1%B~C>af9Cm@xu2n14K}aRTn2s**H2 z3;L0BtJw3>ZQ3qVfQ)M8eMgzlcqltV-AQ(M9ahCLoyi|dlEq2g!b%PDS<1&=ujbA0 zI3dJKPn=LhK12ina28ETeaYZ^qk2d6mOJ@@;QQL4=$Y(c{nRA!Z0^1ipuRAOZP!Y? z?;R|>*w$R%>wmoflwBr;0^e>K7PD4kpIn|kd%(&GauO6^)tY}6NLBcVsh%F-G~wER z;GXuo+(5NS2y@us4|eZG^0K#MGJW(@>tmufJY%vNU%9PDhqrRr_hq%I`X7D&)(zS2 zrrhpm0|C^XRMhu0fne;MPa#1;K}y)cV=8^c2@IQ{nLJJ7pKeHKF^N^RiZP;Xt>Z5} zbaQ;oQj8TRR4*h{$IdymXpJ3DvQ?gwDxGiY6vrtp^t~H*i?AX{4$W#qY+^J;K0Gnm z|H|=Udgr%+|EcUO|Dt&RE^Kc$qubFc`*Lk1eK`@@jcIb}e6^sfS%a=Of6E!Q)`^8dW&KB+>GEUCEx$PSv9JjuHvPjgM5&}g={nY);eDh#X64Hn+ntY|3Jv!kzhx|pmf@k& zO8eU;rrERJ?_Mj@Pcn3Zi?rG%5}r}a?bl1~WqT_Was<#7fng@J#=Eo8uQoN6=%WP$ zOW}nzg{P2)JIXgf?o0K{Y_hNCXiK3NI%@^vpbU>d!?2xMFjCngPci&Zw$|6M(~8^D zmo`eKb+NuL5iDB2W7mF13}J#;Vg6O9OGa?Xsc!?!ChauT=QP@5a7YfDa1ZI!rzFZo zR=@HFJ={bx9b5KEcfqdc_}Hj@{M~mNz8SXCs<48olDr|bsV?~UFW=q8Ui(p1tC2GQ zUQKp35^1FF_We85`}Tl5b~oNs zb5UMP5zSwU^7E%=l4tk7R{aX9F5Q3Ld9;Wdo*`bybne$gvOEy8a{*f5l#c=t+NS0` zEAkN_-&gskk$HC=)vFl#(Jb2Ru&`Vt0R?SQK+&q+wevH-gG= z{Fm;o4=pq${QIPkUor-b1J%V{#>Zg-L#~0#$FS$THW}5sn268{NES^3_Vs-ry{mej z)QbxdYFgFr_zl5t7?a3PRyY<(YRXr`0{_WHr`~K&=@ng~1aW`63t0&1(4%@bM8CCK zxC$BwQLRS6#y_B?pYK|2laFfW?QD{Fy?PWP`Kr&ZgOTcGaip0VK$Y$LOw&j`eVFs* zjP0{q4f%axJDbe@QUj5XJEw-}D&4)M|Kd+R4{-*Z89EotxFld9YU~Dj)uN`AYbohB zES2aY`|5j35WQ3q$5uvf3rKlbTq6b2^089pq1lvb&f~2Q9J6!!^R^4?ET-d_%Ix`< zpvO5-r*885+On}-OakOcB3n!;ZKO6``QB`C+IZ)y(^okxm-b1HF9c`)ex;Iep7@U6X!U#$2xCbS)r`Y`SS$UbC!|+z1UN;e%QT(bPA92F{B91Q>5aWG z`E+9qeIZv93oSFg3)BM5u~F8GNFUj>EA~WBN}q>Aw>z{5YxL@W%=?aZSTtY>X#_ zxQMCOSe`dKKvQy`B8X6h`WVO7y`58SUTU1mORopNzY>;!}Z85Jm^c7 zjiD8L+K?XKDkAhmK@OA9#iTF}?o&d*6C)~HKN>~-{{$u|5qzT0wl+G38yTBWO%NZ~ zhWgm3YiNA=}cXQHT`w`-E*eefA;~VD*49oigeQTep~UuA z^@n;S@{2I2T_m8j=vxa88{O;!|}Wu+@plU9jq5Uztt;XR`{2!2gVKtrIM%8 z9um-Y?@W1>8gdht9kU@5FwjxfvnHrfP%l`2Vs+2_3AQ^c4mEW=( znmL|UsQq+F)_hj^DL;FiQ~Sfei7aDpim$}Xq@n>&r(@EH3i&`mZ6vKT!2yU|kf8Uj zUd@6wGajf*GQ}uBe){1o+pUJZ!s?{O#EF8dnOW3cx8%V{3jK6(QYKA zFKZ5QX&OAP<5t0n_G6IE+$qb@Bh87DM5aGg*G0zq7dpIC-ciMSkqF_Tj9~GWIs)39 zAgSAP!aY1}Ub(2ee%AHN94~qY(~Zy5bp%cwmj|6>Tp>B^-tawYp3vO7zxE`o@0UA9 zURBt}YrfUJ+I33&)7`d8`?;1F{lidL!+AK~n-Jw#shZIrBu>4oB_~mdFL^Uz&m`s6 zesGT!KNlAQ+Yn-zICmVyW9^iU^;pP;>^tVxj`BWulx{eV=*2YRd$??qQtXg4w%TOo zCBYJ>)}#YH&-%jOrP@(lS;?I7QtB4O zdbe&5`<;Zno@#pW1+b$2!A0d^x@ZL>_vw zq%5@?m>Z90)=9lh@_l;iJrLvkevFJWlD4=djP#uxmH=#27ENVwAJ-w-}tC;fG^x(PS z$Hj`0OPiRo^vMwxRm=>Jbp{xws?qsafDSB?bMO3(B~bub{XwZLzPSPVRBg|`DxTfM z5`#fSY*^LsB?|J5sRv}KcV$u@)8iLbq>(zW(^OaGgW}N~lOC&a_@EW!IA~_3L#Ovk z7XUq{Y3f}1s79IaLTe_J9^X$k`;MRUK=CV^SYCjI!O;r&@NEE9rHd7yehfNz{1|TG zlSr5_2Tqa87wcV89ejKOteBX3sB8&-8kF+sb~F$os^*vgGUFVq_THK%g9xJL4Ie*iPWQ;lbl*`7rNT^Q(erc%2_5 z?EBCq;&mYKO7knK_>}~DR2K^6Bh6EcDvOb>m3NYQmFC12pQ()WS;X9u-*BO>huWjy zC$GbS3M1iun~ZFBBRnT@8uRz@GA7atn(mn-^?#Xp_;&MVwsG~SqD9LutTWO_6&C8i+!K@pUO4Y&%>0UM`9&^z!}7cF&2YbCiG~IyZ*Xq%Z*y|^9VVD?QO5d^B*qha%&aSG8du>N9T(D z6o%j>rK``oucFK<*79yPyzlbBLEKCkC&_^6#bEc}FLqp}1ere^)ftW(&p+>5LmZ$H zs9{-Z>BHe`lNv^la&;ssAC1B=Sft0%6f50bn=a=s9ht*jee1frV_}pVCq==nG+=Z^ zZnLF03;;KtY}IaSO}p0ix4c4N^)v3p*%qwv%ST26z9s#v*YC{=P!urN{%`|1JK$(s z4?x-tUlA(tZ?OnWPy}L)Chm20mR284+CdkAj%;FVtE2omXpat4dU=(_ytKMLpf&vC zW}ZVTiNM8scV92_CR3^G){mJ#xsdoxUfQ;GZByaa&y>?^f94MHQQ58dBE*7rkgRdg z`GgTg=93f`iOx(fE`v7TG-1bMmb5uTGDa3Ivd1n|^0ZwGkAn&$L~}1fq*rp2%%JA) zhk*_APaA&6ZYJW@d&->Lz}mDu;&glo`>!f@S6V)pyy2+j#MiF!GfG~1F=7nSc|EnYG-fMx{D3p=^5y60QFKp;yBB=!3SkU+8)Fv9AAcO_VW#>IIH#w{3NoN0ETXyim4il@jlgW(@iq;Lsu#l=vvQFW>E!R0Hpfw!ClAp znOKe#*a&^T!jK9{!gUkg6h?oWI6oXMevq+^`-h)r3>iRmep#M<9S7WLEXE7J_48=g zv!_)}5}~WW*~?prac@4!)Y6J}ztUR%C+me_Di>%i2Vv3r1|sVl0Wk$?8y5=K4Nep4 zI3`krgNX0}v48Pypv;r8hUEa&Ky{|;qU#HcZqi7P6U^+kiS9aMtS$F!m`$CI;aK`b zO#9r$bkj?0mVqyE&R2I*BCdu8UsP|R!YbU)JXu#F;FOTu6E7PvmJ03Le_5jI*}o}Y zNbVWhP*dYyI+ps~K?~kRcOS0Ob3qrxC;&&Kgg}i)a$<0PU)ed~Ez{?+KhEf?S^0?m zl{x%8O}&}OqIP;Z*5TJ+W{YX@MttsuaM6(XZ6V;X7 znsV{Yevik-G-mw0b?@w3r?e@2I)Osc$=a>srC^9F9udk7+Xm_uee2$ructN{)7!!E zne}`i(-Qt#RjRW0g?--Xj*{b1Ocu^B1&@EM^KBQ~yDGZ&x`4pK&nQk%k-&vl?LHo| zKI(D*Y1vWB!%+twZNr-*TGz|Ip%B2y5&Z7-M7rAL9wv4ge)sR|5KNdbj}%#+(N}IIP4_Ud#prq z&7&yL2R!t4DJ;r|0yRpywvCj>wqsvI`fCWzY{g#|Kc?yaXx{$A#hon%zYNBNZ4Y*YB3wzA1wkTzBsNlCpe`IYX)1ZyZ7j?2=wWH3^sGk__f0ll&fY9-6&pS8t04X%e zyNag>uTmr3Y?k9NV(){t31|#+iVTLr$=G*dhk{QIFR_6=rV@@U^6;=E=p$83X-s#g z2!d$fC8jTmf&hbPnzwRfBI-u9U9r~2%X*umY|zx;j_0B*D>xuU32E;q5oN?JFvvx;dK&q}4hBN@?r zI<92uNvv{)96;@FOA4vckX*(q2W11*s&4pft6kkKF=+CpV`Q|~gxX+k{}EjRe~iA7 z<12Iabw<*(A2rW~XaQOt05oADO$u=ESoDOP0~%MjbdJtq;xx2+STTnbqqob40aeE? z$`<)K!cxbdtL8emZwNwt!(a~2GN9Q3Zh;F=*8&kx3`xo_`D3l#g3Cn-KJ^4u_uN=V zuD|jl>7?)z14x;4w9D>%=Yr-?987Q>{RcUa^)%@BlJbh(2mF#4;{wb4iY}--OzpH` za*ro*98<~H25VtReg!_|{l~FS-dq(>SsQYk zwiFl%Mc8YHN)y-}!k>fB+7*$i0K<9Y=!LG63kTpr$=RvY*@p z2CMAEyG`?13Lyc_XP8^R#)L_PYV~6$nL~A){SKserYz7$AD{q%PHd>MsU6JAs(HVb zh^PE%9vAp@W5q8-|5;<>_Z2%bf=8X!8qcz^B2?JRTJY+7cCy0me$)r3)Y1E!PX=2p z3plB~jzVY>-grYn{)S4mdyor|3$E{>TfQr98Ar=gf{eu5Wy>&HQQAgBPdIJ_Gj99w z7l76OvY(DVWx@V$q#di2d!UQ6CYkvpmv^$|XDw3vS{^Me+LU(uB>q|2vKrzG!$ypI z$=W2fT$5~(UiNOH)j#PcH+<^&w1B<*TJH;J#8+r)3zW7e;q79qtx9_|yiLqB^BsFt zKLE)%hgh_F4Bj#7>wEq}a};KwwP>LUTAms;wrJA-wj|mq_$o~qLkrc6Bfxi7&u3=j z`Fm-Qd7KmZ@)u!TMdMF@QHE%fz1ap58M=)>KvbVTMM$e$0s!^RgVoxfaWZ^XZUznA z{PIaD{h7PPvOI(tEh`s_`N0?ZM3JdyjB|EPH7O?8sk6_vG=wKt7GcNS(vOn!>Q<$S zl5@6Mcep{K@pbsKQV(3Iq@SQ6{E7I+kSg`}oF|F~m#M2#_pJL*J9*AO>*)ak!ys(E ze`2oxSWqMK!E`8W5HgJ~w7+eNWAc9nM>RFIBx-jYZWJ)V7%mB6%6S$x4oC0vsY+Cr zN`Xrmk~ahz?@tvaF3Xy49U1T{-zE#VXS_MTI{qed^SW`2@Ikd6kI$&(z-yM;(JJ)U zMz<^iwL}Wr(mIO<;RLAbLq217U%p_&Z}uWPAyO4i&Ce-@Lz+hfTr|O)(zN3$xHyq^ zni!+WSs!Ss=@5p$wKRPepf$ynT!LmA-yX=D!z(GTPF4rKXJ6QQ3QDFFFm=u`)-A&F zh_f;l9HU=%qFx+N&t3d`NJp%W1eIwO?n(Xz!{ppHwy9B>kI+?v(M3*3*Qjk39<#te zw{Q2JROlBIi;gj3$Bv#?s)69w{Rj^~-0$5~%hu+_Mw;3Mdq8 z1_H7tSNXr7l!jo4z+EAxHUMq2S2S?4ng=R2l!i5)k0xlF5^PJ0>FF6`^pHzi6Rg|? zJ4oEhWCVx6`NU)2W)6w6{lV(D8OgyPUE=4#Z$>{8b!!oaJuk`f@?jRa*yuWDdqh~$ z!(pKOG<5dLY%+UOfAmDWLI-3>4{OC*dTuo>7+k-`_;NM_3lKf0& z)CIPthQYE2v7qN7XN*B10b#NE4gmD7_mQyngwF=T(TPX=4_)h=m|%Igq!wm7U-m6j zmZhKWZXKU2^k3$-u)ukoPdL_IqXCfrdqnm(F+~LHD8C~!StuikYn%}LnMJ55KpOY_ zMciD6b0>KRfhPu#11PG*-0TCjnuk(0D6S>xD_QqgAk!W3*p5up&;rWOWwbbZXSiePe_doR+4~&j_}YArON7IJ-3T%Si;*s}nJ`=$F&2kS zOp4Da@cFB&wb)fEYh>*T`Ex6BQ?$2Pdq{{})Rz<-y$s?yTj^R2s`a|iJlL1Ib&?!k z!HuOvN-Yn3K>EpHxFr;Z5B>7to057f#4)%V=ew#gn5s#B{{zbFHXM&6^Fn9g#9v`t z-VOQ}rl#L0ozisDGOR>de56=JZmojHn^0_C-7)1GsC!WeqIju;*OD_BIfpR3s2rQh zIjEiad!r=q(iu9O4tQtNDc|I$P*BhrQLiJsGGLA9#^gR92l^?rT)4Ohe20CYlPdYH#vmEYGUo~qw-p-B>i~j; zaF_C4r#A5e)p)_p3^MN|IC|)aarcb`rKnjPXhrm73?C{MvVP8(=9NL_p3`IWwtTTh zhaqCsWY9$!o3#C_#-rvcdi|i^#&r_H3uz+PbZ<^(Lm#3zsM}7(OwS30?I*zg6n7pA zhc0BlIzZ)&0{5KqWxIP{#U)*YK;DZGS7c!v^geIh^{o$=_pq*oR1AJ4t`BA)ZP|%+ zH3=rcG zcf>q?5FYO`7`HT?#VTroN>tWQPYl%MK)PZyte-}~xm%kn{(K?2OA`wW(Tge^Iec!*RF}33@o-Z9Wjg(JZJ#YY@3!!nfZ5?!=1R6%9_>W(yEob18o_#2#F-ax>BrN81D z@4EvSN~KE@do^M(t3YQjE@r}B^P3`be!O7%rJXE;jyiG4z50hEHf_@;gAKbHst}Wy zt?cW-7a|v^Yt)mreyeW>TE~WKz%d~x)#7RcbGz3Qt{wVSYTeTY{hUs)up5Kh1y*Mm z8H;U`jn8@C@9w|6rrXa<&-lBSh>C8T*%t*5C5-otAyARMG}5YtHf8ICr%$}$ll0++ z5e}#^re6|Ts`LBiS40?U=ClLowx7^`WLgj|-$lV+#fOB;Q^9Y0T5mUWI`FCx*h~H- zt(R}}JNmjUTX<0irVY-p;_*m$}Cn!P59o zSfo<@7$-dX>J~loFex4tj!Udw`@+H2_+qswyaxevZ)1-k%%vj5*d0sz|b zzMp)d`Sr?H1+bW+#8j&`B+=JT4Dn>pXSlevMwERPQ>{yW$uBP|qO}b!UaZn@4T^{6 zB2cP2$clN;-EkW@V##00(MWm#L)4G~)rzvNid0wTJAajXY=C^3x(mIktH<~CI(z#w zY+X#8rh>FTlNGCX@%H2z)XScJ6K~@ImaZSr)y;4*wNu=rRA=O7nObiDXP9pg%l=kO zJ8z+(xS?@RK}mwLJI6ZHJ23azLmkn`LTnBt>l_RL%4g%z+I{ZWG3+R_&YqPb#LYHa zJN+SVj(UTW7x}rn5J?d0)lj)tMb&8d$!8qi^aKchurh58Y)43UdyWLV4@JLjeZ z@6pi)5ZAszmon!~0HDy3Zo`ejHu<@p#r(+7mY&6u`hjT+w@@xw&YGpDFmrr~k~|$9 zs;j~ZOYcbZkO!0Za@Vw_jiQ#F!a!!f@Dr)YFIMTJs6O-DdDY)s7k%nN5715S3@-dhtbDrJ(y`M8Tej1kZM$oRwHlliYb z0BVIyUCP?`q65*P-*tsY>rpAVb;CrX8v)j0CRE*yId zSs2QylLfVO3Q_nGl>Jc4IF7Gz@6sxR$KJHG1`yN^%&f1hHiv+K*VBEQ4sn7QjTCQq z!3-;uh5sDdD%b|r-|$B|{#20)^;)w-U!L+OZ`@`v114!`_yqh!$Ul`jG>wC3o2K6@ l{oe{Ga3b&jpDX0|KYSskz+c>-wlV?ar6{i|S1n^2_CLow*4Y36 literal 0 HcmV?d00001 diff --git a/examples/conditional_image_generation/requirements.txt b/examples/conditional_image_generation/requirements.txt new file mode 100644 index 000000000000..bbc690556020 --- /dev/null +++ b/examples/conditional_image_generation/requirements.txt @@ -0,0 +1,3 @@ +accelerate +torchvision +datasets diff --git a/examples/conditional_image_generation/train_conditional.py b/examples/conditional_image_generation/train_conditional.py new file mode 100644 index 000000000000..eb333c86779f --- /dev/null +++ b/examples/conditional_image_generation/train_conditional.py @@ -0,0 +1,265 @@ +import argparse +import math +import os + +import torch +import torch.nn.functional as F + +from accelerate import Accelerator +from accelerate.logging import get_logger +from datasets import load_dataset +# from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers import DDPMPipeline, DDPMScheduler, UNet2DConditionModel +from diffusers.hub_utils import init_git_repo, push_to_hub +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + RandomHorizontalFlip, + Resize, + ToTensor, +) +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + + +logger = get_logger(__name__) + +def main(args): + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + # FIXME implement training script + model = UNet2DConditionModel( + sample_size=args.resolution, + in_channels=3, + out_channels=3, + layers_per_block=2, + block_out_channels=(128, 128, 256, 256, 512, 512), + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt") + optimizer = torch.optim.AdamW( + model.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # it is needed to generate tokenized input to train. + text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + augmentations = Compose( + [ + Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), + CenterCrop(args.resolution), + RandomHorizontalFlip(), + ToTensor(), + Normalize([0.5], [0.5]), + ] + ) + + if args.dataset_name is not None: + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + use_auth_token=True if args.use_auth_token else None, + split="train", + ) + else: + dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train") + + def transforms(examples): + images = [augmentations(image.convert("RGB")) for image in examples["image"]] + return {"input": images} + + dataset.set_transform(transforms) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + + ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay) + + if args.push_to_hub: + repo = init_git_repo(args, at_init=True) + + if accelerator.is_main_process: + run = os.path.split(__file__)[-1].split(".")[0] + accelerator.init_trackers(run) + + global_step = 0 + for epoch in range(args.num_epochs): + model.train() + progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Epoch {epoch}") + for step, batch in enumerate(train_dataloader): + clean_images = batch["input"] + # Sample noise that we'll add to the images + noise = torch.randn(clean_images.shape).to(clean_images.device) + bsz = clean_images.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device + ).long() + + # FIXME The input should probably select the appropriate one from the dataset. + # Sample a text input + uncond_input = tokenizer( + [""] * args.eval_batch_size, padding="max_length", max_length=77, return_tensors="pt" + ) + uncond_embeddings = text_encoder(uncond_input.input_ids.to(clean_images.device))[0] + hidden_state = uncond_embeddings + + # Add noise to the clean images according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) + + with accelerator.accumulate(model): + # Predict the noise residual + # FIXME Implement a successfully trainable model and training script + noise_pred = model(noisy_images, timesteps, encoder_hidden_states=hidden_state)["sample"] + loss = F.mse_loss(noise_pred, noise) + accelerator.backward(loss) + + accelerator.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + if args.use_ema: + ema_model.step(model) + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} + if args.use_ema: + logs["ema_decay"] = ema_model.decay + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + progress_bar.close() + + accelerator.wait_for_everyone() + + # Generate sample images for visual inspection + if accelerator.is_main_process: + if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: + pipeline = DDPMPipeline( + unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model), + scheduler=noise_scheduler, + ) + + generator = torch.manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"] + + # denormalize the images and save to tensorboard + images_processed = (images * 255).round().astype("uint8") + accelerator.trackers[0].writer.add_images( + "test_samples", images_processed.transpose(0, 3, 1, 2), epoch + ) + + if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: + # save the model + if args.push_to_hub: + push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) + else: + pipeline.save_pretrained(args.output_dir) + accelerator.wait_for_everyone() + + accelerator.end_training() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument("--local_rank", type=int, default=-1) + parser.add_argument("--dataset_name", type=str, default=None) + parser.add_argument("--dataset_config_name", type=str, default=None) + parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.") + parser.add_argument("--output_dir", type=str, default="ddpm-model-64") + parser.add_argument("--overwrite_output_dir", action="store_true") + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--resolution", type=int, default=64) + parser.add_argument("--train_batch_size", type=int, default=16) + parser.add_argument("--eval_batch_size", type=int, default=16) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--save_images_epochs", type=int, default=10) + parser.add_argument("--save_model_epochs", type=int, default=10) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--lr_scheduler", type=str, default="cosine") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + parser.add_argument("--adam_beta1", type=float, default=0.95) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-6) + parser.add_argument("--adam_epsilon", type=float, default=1e-08) + parser.add_argument("--use_ema", action="store_true", default=True) + parser.add_argument("--ema_inv_gamma", type=float, default=1.0) + parser.add_argument("--ema_power", type=float, default=3 / 4) + parser.add_argument("--ema_max_decay", type=float, default=0.9999) + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--use_auth_token", action="store_true") + parser.add_argument("--hub_token", type=str, default=None) + parser.add_argument("--hub_model_id", type=str, default=None) + parser.add_argument("--hub_private_repo", action="store_true") + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("You must specify either a dataset name from the hub or a train data directory.") + + main(args)