diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..1691f22a315762e154b4e48ff0b620d47d2e4003 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,2 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..34296f3f74a4fdc83385ae540d1a3b3e91a388ef --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.DS_Store + +.gradio/ + +outputs/ +gradio_tmp/ +__pycache__/ \ No newline at end of file diff --git a/README.md b/README.md index 29d5bf6db2b13751c10bedfc14de3eb1a1785008..711082365593be2fa515ae0408ae43742a2f2a13 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,14 @@ --- -title: HD Painter -emoji: 💻 -colorFrom: red -colorTo: red +title: HD-Painter +emoji: 🧑‍🎨 +colorFrom: green +colorTo: blue sdk: gradio -sdk_version: 4.11.0 +sdk_version: 3.47.1 +python_version: 3.9 +suggested_hardware: a100-large app_file: app.py pinned: false +pipeline_tag: hd-painter --- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Paper: https://arxiv.org/abs/2312.14091 \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..8430827d8af24a2444b65c965e81f90ad63e3d6f --- /dev/null +++ b/app.py @@ -0,0 +1,350 @@ +import os +from collections import OrderedDict + +import gradio as gr +import shutil +import uuid +import torch +from pathlib import Path +from lib.utils.iimage import IImage +from PIL import Image + +from lib import models +from lib.methods import rasg, sd, sr +from lib.utils import poisson_blend, image_from_url_text + + +TMP_DIR = 'gradio_tmp' +if Path(TMP_DIR).exists(): + shutil.rmtree(TMP_DIR) +Path(TMP_DIR).mkdir(exist_ok=True, parents=True) + +os.environ['GRADIO_TEMP_DIR'] = TMP_DIR + +on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR" + +negative_prompt_str = "text, bad anatomy, bad proportions, blurry, cropped, deformed, disfigured, duplicate, error, extra limbs, gross proportions, jpeg artifacts, long neck, low quality, lowres, malformed, morbid, mutated, mutilated, out of frame, ugly, worst quality" +positive_prompt_str = "Full HD, 4K, high quality, high resolution" + +example_inputs = [ + ['assets/examples/images/a40.jpg', 'medieval castle'], + ['assets/examples/images/a4.jpg', 'parrot'], + ['assets/examples/images/a65.jpg', 'hoodie'], + ['assets/examples/images/a54.jpg', 'salad'], + ['assets/examples/images/a51.jpg', 'space helmet'], + ['assets/examples/images/a46.jpg', 'teddy bear'], + ['assets/examples/images/a19.jpg', 'antique greek vase'], + ['assets/examples/images/a2.jpg', 'sunglasses'], +] +thumbnails = [ + 'https://lh3.googleusercontent.com/fife/AK0iWDxaRlJeZGIBuB3_oGKKhd5buKaL3kJ6moPp7r6svDYFkehrv5XKyF6mj_Pqy3yV-qDQQZj_n8CMpuYH_iDy5717rPL-qpXf-prcIv2pET4LDjFInFQVLoxuurB3_7fugCUogt5ZYIGlTgSbirJkHDqN5min3riiRJLd0ZuGN-ETDDCs5e0wohdX_Wl_Kv5RAdYjZqWFGfKcmzCF-1ny6bjCab-1hcDzIaokggkl3INTG23nhSLWhNB8EdeCdfkQmoF5fROfCPe0Lsvk6RwlAr-ZQ9jaszJ355oXOz4Y-IQLfWvnyfdyxQ02anJF7DiBZbfhH4WcA6faK7Pbjo7RIt-nTN6THwTGlEmxHEzrO-1iuy3j4UfMxPB7r-RjDrn7M9KwLuWSybIJ3dgSx_DF9OqIYkHeG_TSvs1Vi-ugbq9E0K20HNXkFlhe6ty7Ee5xr0nNqhD6lVr-6wbBvI1SJUQg03KoJkduYbTTQ-ibGIwCu7J24kdlo-_d1xBWC01zeW2MfqPjpdNHUtnPAF2IuldEsWMJMHEpWjmXwYfM1D1BmIcOuRLvdNEA6IyPD6VCgLVD27MdLWxdKHgpSvTZ1beTcMM_QuHV2vMJnBT4H7faPmLBe9yMTKlIe7Yf8fGiehOrfEgXZwNyRzmbOyGKfNRKiSVuqeqJt3h1Vze0UKFzAx3rYnibzc--58atwasvdKY_dj8f0_QKR0l2vzWogqh3NdhO9m2r77ni5JXaik25QQzF-BPe70ikVgEomHa4xySlf-Gr3Z1v_HWDa4kYg2KE3P0WjqUD_cmdylbZ45TIOkGAU1qmiEcTCs_wOkOIfCC58Z9P6Lff_BRxc_lhut8hp-Pe8tfRXhITYFFRthXforXyDuqzPmWBBz2EnUHqxa1aYOo4WeQc7KTXK2kF36qAzPwm2QFDreqV3QmS02Gev8MCz2U3bQQa9H8VSB4ZGhrzNWIwG1R8YU_G1Xb4BAgjYEZ5qX6WJaQrjuD6_Zw_pDxRew8t0mjj_tjCrmoZjpxjHsgtudH4IBah5xag0bGdUSThnszYJPM5g1weldimKE63HqaQTG-IN_N51nBken4K_0-liw73ABzUiA6EJzqCQKEoT2pejNTN88N9RFXXB5ZJ9x0NvtuMcy_JsrsVArfA5b7m4OGbwF6b5wN3Ag3XNQ3d58hVJ_Hw99HNIvrjTVCVmU0-DsYIu_njIfyjdqps1cyv6_f23F0X-q4ZsbooPoNg2lc3oeFtC2K58Dgr8JsBjt33Lnmra3YG2nBh3lkIycCvrUOS69xo8Aq9z2ODklCd9soUUQVNa2XKKsMofRi8ESGwuiWYKWdSI7XAXi4dbzFQhWwMRfsSAk3KGMpnUnlOd8Jx68fiGMwTKCCsgIPgo8mFZHSaN29ipzNoACG1bd0ZFX2qxa7UfMnnRobl-AcFLzTtOYD2T2PcIKKjnxy1-gG1Ff0mgGS-BCDl9TtCscAoKcXqTYWCSY5otnfVPcjgEFw_0KEmIbmrf6Rpyr7YGJFPTQfDRah5Ro5LhBIbXhFAkyyvWQZCecGWi3lRAh-pSm2zuGl0nVdykgMqzVwiat8_lhnHMBNRg9xWyLWrgw3bWpdf64Bjgshr0V1XV7kikgRDpyPMJPDNzX1jVASiL5S3dmWDnd49tdnaKBUXGAIWmTyUWfs1bDFZZoOfrmLrnLLUT9R11lMO0EgkZ48_UU7CMJKwgq6hegh-ErsV6S6_SLa2tQYqYlSNUESO7jw1w4_T9KTkM2635QlPH-A81yvgrSsRo0lq3uPtFEHP35-dfGT63yFd=w3385-h1364', + 'https://lh3.googleusercontent.com/fife/AK0iWDw7CKs6uhv5cJnJ48zW_5UW3IjrXwWqvemhbJGzB_xNhHgNyt4NIP0xPkCK_nyGFEsbb-4iaeWnoJbbeQxU4yBnnuckhD5d-t1qgOuFke8NRVaGCDgtDimbwNY7Jkkc9pj3pMd33P7UhBg7rVueTL7hmAyhtT_wLD9B5a50VfM_T9ErEeYUgsiEYPoE-msuALr2Sl5t0jXttYNt8R2JHdXG8Y6EegiQaEdZjZbJxdwA5S175pCmrOvVVBH334NwarLF_HDr8ESIVEkYxNtLDwb0QPDZSDeci-JS2WZcvEQUmClbxuk3hjtOR0Rm6VX_wXo0-nczybNDdTiZFsBc1Fj1L_08VLYP1OIz1fbcRF7iv2luyGmgUijGaJLr0wtW_C65eXK9aU-NJ5KVt67JT-TncumCqBpvn1-msbzbcL8krOmYqiZthOcvnrIQHeW1jIUVD_zMZADmuRdpjxO-Dn3yC3uev0ve0u4b3vEHR4iiLx-4Jf5DrMvpfdHHIL76verfqhLkiz3gtZENg1jsRTcjH35AWspo3-lMJnZqMP9wzvw7ubBBP1QrblSwfflz4aAIzmh-WisQ5UWMSltQ9DwpT0UggiXe0JbbtloWkbo2_VpaaMJh7VhBQbvRbFUxCm__UPVfglTfdiNv1m2777oGwyDbv682I_qGDW5nG91D22pqo9-enRGLvCnN-STKtWTnJQ5Qod2QYRAEI-1IR9_h-UWtCyBLLpqcGKxkHaLZvjpDmTdiPhVkoG1irracCLbPJGvrrclorr1k0nqTIhvVVH8pNZdx-yCK6KFFGNAz3aSsmxGJURWEt4TQEVDLfet8iFzuLfD2Tg9gJ4kozB88G8PrfLnGopwlPO4y6GkcfJggtshQr7yuo3xnxtci1FlLhOJNkMKCg0xhL-tVDIfWsMCzSIk08XtGgitU7DK8CE7DlkA7chCuq2BHUTeFCEALF-5DLFxa-3ru30gtNsbd71sNH01S8qhTVegWM8yQwwXXhIVqoWM-e-SSUyQldgklvatev1V6iDxLh9u2nbhoxxCcVZvyKx7rUyhSEEQzXF6nYNQn5HaAmp4jW1rCoVVvVMNP_wfvY76vamAPkw22z_nylCpbKW0BQmTBZStyoB4u_lZANVkL0fbyWoxyfZHKOoDgyTKdzDWmbeiIPmr7nvF8QWYsq160pKzIKIhj9ccF2NqyXHqjaDzmqVfZ6EUg0sJ8CeiCUskbSnyxZHVeM5GgG7cj9xMncUUsdvd0WTw3g0aqstb2MAxF3ucO0EXP8dAWqU2XrJ6E0rA51Jdn16CW9KV5HMGEcB5Y6bN-Md8-FT1r_0NQq8ZOJ5nmTzg5JNSTEw-FOBwSvukEMDmbf6ADs1emocozcqk9KiYv1ii3niD58XPIhQWfBcgyYBsqbJQ5x-UqijqcihN7i3RgQMOVJem954MKU8D-DzSet2FUcGWseyzF6Sr4GJQn5g0rOXtFP2HwT2kTXy_pqad1ukQkovfQG30gbRvIRAzMmigVJskadvadasF7Lc5eo5Pm7CPPZQ-ZJUlZuowbagWm0Kz_T-PIreVNJ5WxxQ4w_HH27QqKxe_LZvHJ9y74O8oVCywVJeKQirxDc-yKuUNJ7vagJ5a-paB_cXf6RECpu7caFM6U3g43xVMtxDf-3d0r6G7zs0oPHb4uDbovgLzqooAnjNY3_RGWldXUks0pmCqFbEhLgl8JJd4Lyv5mt34WTDqPggRnraAV5AEtZ8NvGNDWk-ItblCZQ-neXYVj0z1p1rWSJiG1XaDa8ro7G9-7XJtwfJDnFZsf=w3385-h1364', + 'https://lh3.googleusercontent.com/fife/AK0iWDxK7tdzr1tx6g49uS-pyR-7BdPJfk0_wWoErGJiQPMRbggXquguwth2I35go_GFsW8TUoj17p-jtCKCi0ryGH2gTGNvZEki8xfPn1aroIxOXRt7Ucl64Lu2cyixNSBqyJUBQEOL6LzY6DlwXJ5SVxRRvJkj1Vz-yJYz4mNA12YdDFPj7UkLs_7PEDRanydfEVFhIoiMrYitDW8fGsqn53PHaB8GoxoQHAXFyf82HPfhqqgKNRUR9yz1LAN5q72ERzxg5h4apStr16aLzijuuJF7wmtltf3HdTIMwAAd8Blsgk0d88rNZzItVZMdCeHgAyvoD7JKUCTWWJbu-uXf1PRd6mIg11OzweG04c-uloZFvJdF0pdfcfocibniKt-filYASFN43KYtO7Eyzc-YUHJn1qm73eDr0RkntHy1kQaDPmlRvlZrtbDgZIEau_FsMa-BRNq_jHlTDkyRR_8dyBf__na0I45lcChJWZCOfinbQwzQfiryK13yBB1bQDGk47JM5PTIEBK2JvsOEwezA9geZWhi3oIM6N4b1YXzvWbDYbN-cQvrBd6cuXoecUL-i_Qqda2-xGByMP4BUI0J1qF3mKwzTLvzZAk1f-zCUNHpOrzF2-WsXCUoL4R_t7MOZ2OeT7a5pCAnyiFX8Vdq3x9QzrJjjTqvttS5ElXVUGi5NBVeRV0hbXI5XLMKmiu-lWV3625eosYK0FE6hmFZk9ZAeSkrLDjkh9k1XwpB2x83kcpC_tkk25G9czy6utoNhC2YHc1uVvnUjZ0DjK2d0naLVOCKOCVRcTSvN3kt7PaLASArW5vOzZX0_LcjuhHSzK63mok6XvXXZ5AOiF2fDhUNUl1W8TaDz7aIssbibuLXxjpluone6WJglRrxAkpwtIw2rwk73icupAeVG6Frx2QctrHC0vLG4PKSfrIrZRiJC52Oxpp2dAMnhj76CIwX9wwbj86uSyIuxeIviiKumwTolZjbOKhPrHY-ZOydVDn_ZUsln4mOfDZXwUl9p1CanzpLCUsZTXcq12zTbHJRPwNw6SH9srQc3cTYYsWpBu77VmS5zVkIndUPItXWmUqds_2AI8LCUSWE9NVHiCRSw-B8J8j3SkqOD95-np2cMxDxoVX4nD11CwBr2W0xWS1kqc4mZL4aHdwhKdDcSKk6EGG4kmY1Eq1RsYxc0I08TJK-_nrbWTgA4NDTjh-oJpFFiF11ZHbEksKlSWBhH-MXF-0zxmiar-EIhAQe4hQX_suDty4GXxMwzcF84JthDAvWC0tGLJV5WQGLkLvBkTODythSNoN7l5HjnzoJf2JKI5AF8W3HJJWgntD0F2iFe5K0Ik_huovMbzGjkgPZDvWzrYpA2V6VZyPI2q08axvAfaevPrCd7G8Jy_gliK3hj3qxjIvqJXBUSO2puqum-TlUaSYgbjhWUCLXKE1BH8RRDt0brLklVvNwGpHxG6Eg9vmzp00fV3qhjoJqokTYOAxcAumtfUpyJFVk3cvpgm4826o4kKUM6eXRz8ke1L34YI9Rtf7Ppk2864hIBIy5xA6ajIpI_rIXWG8ogrhEp9GMbXGkLGiLMfd4t7P50JynPc9hYWfachaTlQfQWuz1NtV3ZGdTaEe5y_SpEwWG-UQ-_MglvY8p0EYNjGc6yhu4oEx9D0RN5H4QA2e0YjmVybG05Hfm7LwEBMBvHb8GHnpHGTxL-WDlZtU8DqDmGmhv5l4npcLnuVmSE4JcEtRJ651ZYn88wD2ghTAJtcEuyFW3nzZ2Na3z0LW4Y8Y42YdWb3hmgGsJAYBIa9YC_Cmh=w3385-h1364', + 'https://lh3.googleusercontent.com/fife/AK0iWDwPG9L-g7hX4NKVZ45PuqnRkQUV3XZlMEWZxXASV2zYVytOhbhds--yBA0ZXUzxeEAbpGa3qCl8sXlu-of8ZjVVNSdSen-zgpoN0BT-R7JGmTqjT9aEdhnysy85Gqr4e4A5KPrInLLvCCkFSHolBEhu-hO1u8kEZ7aDuk8FsPvchP1SvqjnuWtY_OCYa2FpjHH_i24cITjs58nlDxNTFdQQNCnX4KLJQAzh5cKxtG_7pqoBfBzpXBvVDwoxPafbFn0X0u8oJ7V-VOh7faO0JtJrsfcdS06tvw_J29RbDdWojpUS_pRUtA8w9Z6flNlbShj4Ib8X1V6veQLrEYWi3SojO5tfntpl5bKXPGgWhsT8mykvfhg4Tq-Ti9kyPBDPQqEqf0ll9-wFHYoAWSMsmCvZ7kakuDMH1766rOk8QgqYbDr6kMuJw_OFBR2qX7DaQw9XnJqGv7J3guCj-yU5vXes4sOtZ3n1IOheXnlvJL79KhXbsLgznYBjC9uv0fDLuqKaFL34YisLY5xY_2zi_uzTc5BXmtoAFH8otrUVUTVyt6sEPaCtjyPzG2uoSq094QFY_FxagW0E6OYPqskBUwPzg_Wc2eBwazaGy4MXEoDzgvK3PId5N4MWqU232uHsUrEEaCUUKm8-KX43c0C6O2daqjwsh-bxKOIic3pHqDoAOcq-QA83qB7pPyWwGddsaOWRIdzf-QLrB55YuvTeTmOEL84m2YDtxNEHWdYcnXYlZEXAex1xMOfqkbCQGM8jSgC2di_794HKMpsNtKwYH31WoI-Pl73t4THuq9CX1pWYdjhH0ss1j3PUMJ4hELE187E2m6fhQLHNNSHRajfesIwPVP1FgP0W5o-AKLaC7o53R0XrOZ7SqYjOua5TX0RXyJE6ceZATyiZ_2tzUpc4baqRGJb88vr8dhfECh_1J7O_8ufMvYL47HWhVNqltcIGjujtIXp64XEqShlut8TkqtB1-3OqE7gNwfiP4859pPkk0q-kfIdaLCjqB6PeqNFNgcX0dfK_-weXpM-LjrgNM_aVlQEBKwTDjJteUHY3pFTrEU8FoPnPeIjbU_rGIcwA3Lf_NE0CJUhKG1gzKYWG00CCg0srvQRipJpVrW-4SkUrxtW28iq6kVRMER83sN2RnVi7ugEuZ3S-OPdgXkUglk3bIz9ehDcFfQrll0iCabIGQxrN9-7GfKSpi3j2nU0aLeZ6DxgQ7x9f9f-hUgxM97i_90SX-S_M6Lo6bA28RB515HqqXc8FN72Wsp_XlFp3oTb4cJtOI5SSUYkdxGtrn0AQWXuE0Rp1DEaUDEdVbu757FQsct4Us0jLByauEfiZcQS_lGTjzGQxud-2NMKIeIYOWAxBGz08eTWv4b2k_IsnekNnXAHzP5WCYFoNtDiGbgURj7QWz6Qh8HbTsFJCrI3mPYkTN2jMUwviTUKpeElUwqDeA_DhT5tiHa7ldtdzncNzpxH-J75asams5J_O2W5dJMN4PYUxGmVw5mhWEClFo2stMJOfPkrmaga4D1OXc_C3Utf6OWB5CBBHGNjfAekE3QWm7ibEtwC91g1pIujyCUEYVs8YiFi0RWcMsmhWG2yrghA9Hu3kERWuVT0nHHfLRx8L0_PjlQBkNijUjK_gI-C66729qLNsmdwZxym1JCFgV4xRT7Vu9EQGL3tyhbOLRWYlHAjBE6itM-DM5T7idIyWamFBb9Nt6ZFehCpslKzEHy2VtEyiRlZ2z9kH-IEIuZ2qCu3kiGL8m6yQTyboTCI7LNLYbSn497VR8h3WkcqWqlVlWOzGUikL=w3385-h1364', + 'https://lh3.googleusercontent.com/fife/AK0iWDyoheia4yOtefRVFWnA5EHolnS-xa4pPy0yL4Wb3uIr1-mSHxQ54i0wKr3lk6uhAm5qjJ9UVqwut3UqR54log5XVnaeNu_3y9Gsn44quU_HsGAY0-84HygWkr9Ld0_Dt1JEefD_f82Ijp1TXQf1CS6IbruUOnrljFOraQ2Bu-1To3p2Pk-T3tU8xVCxU3pr6zvFz9UYHFTss8_Xw70ZtLRMT4x4suHtOaSPI42VTq_T8HEnm04Ie_0Yh0Ri2_P-qsaP2ysz-Wnw4Ykbj9nc2VIDqtvCRwti36mlNheyg_8xOLD9sMNWDu3PXoRtn7aBUpw50GMCqeGUMAUMPKhJuZoTDdQK2SVHJ8QwNhYcC8mhAmRFvt85hvuT6NyZ00SeYwyj9_rux77vZThx5ioDoUAH3CQBcwgH82xahatReym7ehL3DXm3JLHDdQLbRM6xvsE0X0MsMXkuNlx8wn6HyteW_yq8fK3wQJ_3XLh-gK5YOFdvd08A7IIK6qi__-o8nvEK_hfHgMMS-O9eT9acfa2Sr0rGGNvoUlpljVOyONyQftP2nGD92R51K2Xcq9oV2sjSu8TDDel2t2DYr5gMB3FsDfgEhQEWE-O9fRLkzIZnOTTAcUDoS-b3R_kB485Ry56FzKFbz3w3tvHhwJ3W-sqvygb8LDoF3qjURWrf7Pau6UMjTSPH6FTjbVzZeWITKsSnA14xA2wj6xi9Bp3JkCOsT0qOfXrPkK7uT3H2U1M8uqFjpNTj6u3tZyF8GgueprmH13rJjYst9d_vevhpXpXIhSDvAbJuA3xG-YNr_SG8BMXpi7N3NC0VHmBXhl_wDBVUnAD6VVqqNtXzB-6NdZzjZKnxApDdi5SGp8C9kDd6bkaXUmwG__BRNVqdbchMw3H1re5t9VxiqWTelfGl6UqAX1W8RzQR21bgu1x7EAGbVsC_UpMxeDeJq9PprMF9cCRC9ziT_H2-ubctN7O9qPpADPT0nqWN28vH-9CqB4jPeBqYwi658twIpLwRFvEukajsrvb-OqmeesnT9QpCPEpL4G0HrjB9rkRX7g-T6q3kqbGfvnqgW4Q8ilOUnkEsFK3qLCIVxwp6B3t2yg8XSOsM3tnzsA1Ua8MofFKqvmwaq7QbIBcMOHa50I1fxRMs4YEVLgu89fSZxRrKSr_8oUXcRWiqSgF-pLLU37GYMrn3yXUJxxO4bXiEifeK8id5H1khO-8ZEBXzwuZBQBYLXbCkCou6enZu98tfPA-prr_NpKfSZM4e0clZWhjo1761-BmJSlIG0JrTo2N2cKhVz-WM5BZjVr1FPYOri5fIjORUL17_RbqMw5MefYN6tLPnpVrOSUmKW9bVgdFdOVpj9Wg5lZxxswAM5qK-wOzEjfdBCW0xjKzxD97zhszCKxc7Rj2uoaJzk9CeauU83LYcihkyMHn5IKhLeAou2yKEwgqXkU0LUObdUxqjavnVgVMcVYRds_j7zpmpKNT_KOV9s2jus8aptJl8sXZ_Gzw0vi7wC_AuAGfmNCsZBEFhn_b_ZgONqdaR9EKwP0hVRDNw4dZkYn9MGqeiX77I40eEbnwVbKaUZvK2Nrt7ukjkgSP6FdvZFfs-aVIUFMc6rBAknAFDFHPzFYcy9ANPDgVAlms8fO5GGuid8kgpxtjSoneUG_A9Az6JY-suY8zFr6mDtJC1LuY9ftIKW37ZaMtutqhbBX5b5w4DLnO32Uv_ZtK1nbV8E4T0KY6qHm2kZiIYuYCsHysbvNW4MODUihrpDh1UfILdMEsae6zGpCTQ7gnKlrB7QJ5Ig28y96Om=w3385-h1364', + 'https://lh3.googleusercontent.com/fife/AK0iWDxriikbWcV-sJ5xBcy8xcJHsRC9EEBmCimXzsrBiSOy-UnRwSoGhdXNqnB92vMZk1LiOn0h3KbBMgbFr0I0SmQG5TwG__bM8AvtMrXA-DGAaNTktO1JcSPb-wVr_OepLK6P1hHyGYSvcGDdF03pIFNcKxN6QLCZB6rgFLdaWd72z3Dx8eB8QLtU8P_4G4sT2oJ0hAUlz9mKz9lTxWDWl0-1ufKUctmvWfU8EjNuQQzckJA0YwvV2jDl4ZA1r42UHDss6dy6hkjmPoDZN-p2UW_3Ju4vOVAdv0Pf73demNu-L1LALuK0rq4d7OHUaAuP0bubXJAH-wsuVervwPQDEsmBwR-FdW8jfdppKxy4MC2ISf-eyLmsTYR9dLPIlKkOAHh_84vdLeWdGtxs9gES-jhrqOiW-brFtIZoKbH_nR1yLeq9IJ7Z7-GJk3PVi_Ex7gT9WJyIuySNi6s7GH6AnDFf9wfHkzyJ1qLDKMddrNi4GfEyO89S0yScxZFW7hERAH7T2_1YqkeMv48ik9dMA0RcpJK0LYA8GDD5_MQaycXUjeoCO5tvlGQUEE8Dt815Ev3xh9dlKJiKJf_cClK_kL7iBICasjcVSNxP606Zn1fBTc_hF3QymDP8Q_Xl-g9p-ufobo_x2x6nOOdfiq1q37ik_3kPZCPPCnI2OtM3EYZ_yVlFmwWbtqxp2Rz3jaWo6t1TL8PWkFZno-aqJ8YEm4ppZVyX7ne6GTFKzFpW1SKnAnjqu67LjS0DCFhFATAcYhWRmdb5cMXve72eU5DNLgKZOaM0THDa5dnQPaRmu-c-7HlB3WISebcJHj0vIDw7DLwxaCnqqLgybvqW7O-Rt82bat5Lc-jIyMjjkZOutnc3OoYTPRN5PrLZ6HXGFBTq_s5fCimxpXvlw0bzNHqOzovgP4NC6UChXwn9CxSrbLou8vuqeD4YqyjBhh5Do3l7KcGMZtYUUMUhVf7fUrZIJcLb-ZBrg6UdiPRc3h9jcubLjXcPIrzxZeqgQ2ccJRljZqq4CBoX8WU6AiEe51Hbp0C96693G7rVomhzxa8JCMCCL3sy8v8nl1tqfqQ3kE53XnvzuqNMzJCNIfZg-GehqMBmJZp-Vaup9DOLWYWnzRqYsO9pC9r37Ajhh-wkxZyV0XhXD6EecYGcXBX8TOQRF9OMV9BYRDYVnpnGfTBGyxI7lkCU3prZQL0H4JlS48k38oWj9fE1F-PnGysDXLAxYqEERP-AXIwgsPyoKYrq7mG5UzBcEn6qminlmT2wM-EnInAb7F9e2aMxFwWv17RG2fWkO28p4t9hiMEyvAQfYyWKhGui1yqPnHnytNe3BJspV0uekBUBauvvHHLF0_tgHTJwcU03KlOLto7iIOUuwLgBT3z-_diWu_w9WTiOqZppBYKHUVfQlR1JsFn0j0Tg-kk5NbRfRU5BB4nMtQlkpW3vjTrQW88SrSKNOr_vepd1F39EqxBNG4tIyR6lvSCSGefWzUxwHBoHou8MtGPvcdxInB7imHdpwCFtoVy0wGh2kXy36CyxXqV82VehVvmfc8AnNtgVWTPJ8vvs2AQ2WL7xd7lE-DCfy-la9d31t6iLxVAkq_S067Z3iyTiJg508a_BGt_UF41VW2m2dW2v8KRjgXks4Wz7kAvO7rLLhL2CMdh-W7bBP4bfVa9BfcKVKGxaa6kCQB4rXtFs2gEbWx4o6iXf6KIUgl-RAdAgYYjIsavlkyihtgQ28-1--7JjaeO61d3wLYgJ5_t_POL8OjmFJfPeO1m9HTzDJf30C5hVP348dRhrkxSTPlSo=w3385-h1364', + 'https://lh3.googleusercontent.com/fife/AK0iWDzcrCgKc1IMiriL10KOqvN79VCXca9U5g657RP5HU1zsUNVcZPyefBpRCtevRE0k6FkGxUa4yKW5ELtqaSmVRE8R2jwZPHmpd0xeBvIHYeoySmntJD3wJl_iC9Ma20qubbH3OFIxNFFLCXMJDPY9cJ2D1xFzjD5t6jQi7tJKXc0B5W9vGTyQ3dzHlZuLEfTp4D58WszOFBLPsOq8zeve0ej-2wCEMkCrT8kTwfKTnsi_GXRpK1xFRZjczbA9kSwIxZ_x1iWNWW_XY8aw5Kgn5Du-4r3rBqpHr4_fzv5ehY20T7Cq6Rf15zCBnBad1HGxLLnrJrXujKwxHELaRfqXPItfQAmuoIfL-tgK6giSSNKDDT2Ynn7AXxHz5vzD3m4ALYWql9Qpe1Jw69AxGZt3vyuSC76LPhMmNIJzVIXVMLnw55hHCLH28GVg2WxLTLyUXoX7o-PHhFjGoBbx3X58yQjyrEQPbfmY3gtuijG_vrjoz3SPcoG4eXB6d_NnEM3b84Ml9eqE8BjlACMEJriwtqLEstPBhIsJygDZKB822bY63pG2hYzYi5-bWS19NRwun-jdNraiye8D9AIqxrXaknLvbCnWbgbKDJzlRnt9Nz8GJKU__oZ1wQgzq2DOQqEoJxPwpocJxwYjnzIb-dFqhcLkR9SQVx7rAEyJ_u23WBSp4AVTw6c7sYxd597C5EFBvzOv4qYtJp1G-hBrvHYFvYD13oDtgcFDO_F0nXg7XwmCJ7aQ28Hp79dWewQiSq0nBcgWSOic1Q4feWLFEL21Dw0pFmSVF9f3C67YNA7ZXukOAnv6xUN5RNVoBl58-KDOFSQZYysT8FGYaaY8bFRvYvI-VXcBtMuAw1VYMKLKHrhg2mSvvdPCxsGH_rQWijgGUm9pvqbmyrPzYRd-Tr_i6pKH7EMcBYgcvZNMk9_B2EMj6NUWLVkHEUTDiovgFQJDrYAuiUcKZgp_MXJcTDO79qJEsiP52C83C8Vg91w4t_Q0dPklY8EjjpbMZGvsr5NC-TlQraNFszsrUA_JsNdnMW6Te2GMnXHDaHjEDCdX2kY3XC1Ltzvl3f4tHY12OMOUzMKHFWHWyBNYNKlHnwPdxYNxvHKa-9p_okvxw547oBfsXwUrRpQQVxljmLZJbbGpxkfbEW6Rg70MRKHdbEUg0h0FS7F90kuD7pR1zUFv92fPoI5BUNUn2XjQb-DZZaC1oF5VtRMgO8RSFM_Dolrx_c6ZrPLy7bllakKTDj48CNvafL3UxaD6x9FFu4dNQVHKidcVS__EZ1SYMSZoDkD2sna5Pzpjl0Qz61U4KVAz5lSCJbF1stdiwO4jMwzmAgMNV5-fnJA3kkf9dDaIzqk2diKUh-WGSOwEwwizH54Y-e2EASTqCzGK7FGoVRVr1d1PdN7wcd8MuRMXqBfonrIGmf5cSiuvOL6odbSChO2WFOKkRbBHeV1pO46uaBVeLKGjcbAALymYsv23_veIW-RQdHnaVlvB2HE7wD1afI-3LaRKCjCUD_X5QJuTXff_EQayDwtW46_0hYRI76LGXjQuOc_LUdIi2QHlO4kt2kR8eLBPm09gKl0GJrZoS9HxKS0sERHCjUvugFcvVmgt7idjFKfi0AuUz9XmqYCoSRiMYj9z612Ot5L2D00SLTOcTLV6nlx-PeBpAyiu2ia0ehvVjzLUn9ai0_XOg1bU-ab8fvzrRJIrao5ZxdRp9wF41lJpkanwvurynNJXXu0uk2SoA1_soIeLshOsONxT4DG_PitVFMYjxY8rgx7hfNBJmAJW4GZPsnx1_P5Ojip=w3385-h1364', + 'https://lh3.googleusercontent.com/fife/AK0iWDziEIX1beA1lUdgcqMcnVcmRuODfH0IHpIkZW5YYZzhQRcRmYCfi9N7-vRWfcfMTuR28ZWDA5EngVjUpwIbVxRBF2DaIb_lJhd9zG3arGxRbi7CwmWdhAXeODEvniYR-IxtWB9lYpNd9hJ8wdleTP_ai10Xscy7iFeXOmFZ76dnr0r690LnZULOd7iyGv9EWZmhKhw4wtEJsqi3e1Yu4CXRsVrM5KLYKG-EWdRW_-m9H3o0G2W7KvOwDvqAewIz5zApPBHvuEE6x-XaUOuF_FuYQVhsKpfI_1Y70SCjOCphpbn51Bv8idg0tgTDn8oL8hkvSl50VqgQqjCLNxmCHlQE88xmjK_4NMI3kbIBLWfiPGCURr95dZt8eniqVi7yu8LNgkaMixAdRBCrQF_z56EsIvkozenXBdS2FiaUEh7LqIHLBcOa6ZaxV6_3t5Q83wgJTaM4cNkvH5_nCeQ9wkwKjf7zcBxFusa5LvhM-qSm4BJz3WzE1zgqTLVnDeh-EFNPMilPhevOdBuNfTY_VF8chvWS5Nwwcxlls8xSdVVqblYGw8YBlzWdi_X5PqynTKn6aWE0IiWOzA_O0hg2q1FtAHRT-PaINo4wjIbBar6fiNNwcZIeTKJrijHcpkIhnI8PHxrUtO3s0c2pfLFvuxCCRMSEfxpcwt0rz-ODEWIZkALajwE2SyFV6Qioc4fH_xWnI-jgRvzHjDf2c14vx3bXjM_gy-25mrECLQYcSWWZVINUbvKf6_YQDwwzAKL9zhMpyGa4EToTBhMSmroGi-NwIPxh8gAfdBCAh8TFAdg2aA7D3N_KpAv4Eh5bkovhCiALFYkGLch6KogZcn7NU3OX8qyn_wJ3oGO2CFmfKkMtLHqmjQpnLtM1U9BPnRELir7pNyG9bTNzs-Vz7-Hzu3vJavGeRhypl0JCoGfO08be-ee_7EnUcKSdepfd3dG39Gc1eulLgIVCRb82Ga5mAkNp_SDSa2BGI24--uOyAUwTBazpQjJ25W0wsHpLRF4obk8Tygl8Fgt2F-VPXYz1-q0x3_KZVWf-PJmKjYD6t3ICuBMoFeJtQUxp88WlSKC7KvhEZYdEaHmEabNNK7j-VTAgi0BeBaw_dTTO0tad9rXbCW9Co3Tc1YXv53oz96VURj-FAKHk_PKPRSV7-NO-BWAk1DOTq3ZDnlKUTA5-x6k4IR5HyNzW9C7rIPGzd_PRA9ddSiRxOjSiBru_P8xS0zQn6p75V48ZkoNsLPWEWCKhANJOaOB7Y01pg3wjjnftuxkp0KpokrlCZVUn2eKPmB0Oee6TP_6DVFhgM6ksqLHO-sNxpehUjWDx84znkN0MihGRgl6TK-6xnWzD9tjvIOsK0mBzk_XY3Vuvb5OEZvLzDJ5POqNHjLcAFaDtX7gsAUtEWk20qmRbpGBnHiZv2kLOUWCy6ICkc3yFv5uUMx7pxgfc_YO95ybO8-FTDG7m1yaoz-WdLV3tHao4_MfFaRXGKtV0_7xnlyXEZ3tMYwKu4hRx2lIOsL4Aff_O8-H0jmJId0llt__iOdVDkuypQWQDOKGGP9B1_gfLkV-ymEP0Bl59jQWNnAqE-jUpTeRRcUB6FkcH8XBPKL7F9N0sq-6XeOmPPpsecmm3SflF6zJ1YV8Uv6H4_9_uQLVBB8wXSvtcQuwgzYnrtpjpMQwFqSvJDhcCPGRfRCR6H7oa-T_ACYAMcICpl8felwVUOQs4O03ywLHNrZBY05hS13cj-_aYw69kw9TdetT-GbvTKC6eY5uwBTq4ytb4eeJQJc4zBlB2Dw1vKmcgIFfZ=w3385-h1364', +] + +example_previews = [ + [thumbnails[0], 'Prompt: medieval castle'], + [thumbnails[1], 'Prompt: parrot'], + [thumbnails[2], 'Prompt: hoodie'], + [thumbnails[3], 'Prompt: salad'], + [thumbnails[4], 'Prompt: space helmet'], + [thumbnails[5], 'Prompt: laptop'], + [thumbnails[6], 'Prompt: antique greek vase'], + [thumbnails[7], 'Prompt: sunglasses'], +] + +# Load models +inpainting_models = OrderedDict([ + ("Dreamshaper Inpainting V8", models.ds_inp.load_model()), + ("Stable-Inpainting 2.0", models.sd2_inp.load_model()), + ("Stable-Inpainting 1.5", models.sd15_inp.load_model()) +]) +sr_model = models.sd2_sr.load_model() +sam_predictor = models.sam.load_model() + +inp_model = None +cached_inp_model_name = '' + +def remove_cached_inpainting_model(): + global inp_model + global cached_inp_model_name + del inp_model + inp_model = None + cached_inp_model_name = '' + torch.cuda.empty_cache() + + +def set_model_from_name(inp_model_name): + global cached_inp_model_name + global inp_model + + if inp_model_name == cached_inp_model_name: + print (f"Activating Cached Inpaintng Model: {inp_model_name}") + return + + print (f"Activating Inpaintng Model: {inp_model_name}") + inp_model = inpainting_models[inp_model_name] + cached_inp_model_name = inp_model_name + + +def rasg_run(use_painta, prompt, input, seed, eta, negative_prompt, positive_prompt, ddim_steps, +guidance_scale=7.5, batch_size=4): + torch.cuda.empty_cache() + + seed = int(seed) + batch_size = max(1, min(int(batch_size), 4)) + + image = IImage(input['image']).resize(512) + mask = IImage(input['mask']).rgb().resize(512) + + method = ['rasg'] + if use_painta: method.append('painta') + + inpainted_images = [] + blended_images = [] + for i in range(batch_size): + inpainted_image = rasg.run( + ddim = inp_model, + method = '-'.join(method), + prompt = prompt, + image = image.padx(64), + mask = mask.alpha().padx(64), + seed = seed+i*1000, + eta = eta, + prefix = '{}', + negative_prompt = negative_prompt, + positive_prompt = f', {positive_prompt}', + dt = 1000 // ddim_steps, + guidance_scale = guidance_scale + ).crop(image.size) + blended_image = poisson_blend(orig_img = image.data[0], fake_img = inpainted_image.data[0], + mask = mask.data[0], dilation = 12) + + blended_images.append(blended_image) + inpainted_images.append(inpainted_image.numpy()[0]) + + return blended_images, inpainted_images + + +def sd_run(use_painta, prompt, input, seed, eta, negative_prompt, positive_prompt, ddim_steps, +guidance_scale=7.5, batch_size=4): + torch.cuda.empty_cache() + + seed = int(seed) + batch_size = max(1, min(int(batch_size), 4)) + + image = IImage(input['image']).resize(512) + mask = IImage(input['mask']).rgb().resize(512) + + method = ['default'] + if use_painta: method.append('painta') + + inpainted_images = [] + blended_images = [] + for i in range(batch_size): + inpainted_image = sd.run( + ddim = inp_model, + method = '-'.join(method), + prompt = prompt, + image = image.padx(64), + mask = mask.alpha().padx(64), + seed = seed+i*1000, + eta = eta, + prefix = '{}', + negative_prompt = negative_prompt, + positive_prompt = f', {positive_prompt}', + dt = 1000 // ddim_steps, + guidance_scale = guidance_scale + ).crop(image.size) + + blended_image = poisson_blend(orig_img = image.data[0], fake_img = inpainted_image.data[0], + mask = mask.data[0], dilation = 12) + + blended_images.append(blended_image) + inpainted_images.append(inpainted_image.numpy()[0]) + + return blended_images, inpainted_images + + +def upscale_run( + prompt, input, ddim_steps, seed, use_sam_mask, gallery, img_index, +negative_prompt='', positive_prompt=', high resolution professional photo'): + torch.cuda.empty_cache() + + # Load SR model and SAM predictor + # sr_model = models.sd2_sr.load_model() + # sam_predictor = None + # if use_sam_mask: + # sam_predictor = models.sam.load_model() + + seed = int(seed) + img_index = int(img_index) + + img_index = 0 if img_index < 0 else img_index + img_index = len(gallery) - 1 if img_index >= len(gallery) else img_index + img_info = gallery[img_index if img_index >= 0 else 0] + inpainted_image = image_from_url_text(img_info) + lr_image = IImage(inpainted_image) + hr_image = IImage(input['image']).resize(2048) + hr_mask = IImage(input['mask']).resize(2048) + output_image = sr.run(sr_model, sam_predictor, lr_image, hr_image, hr_mask, prompt=prompt + positive_prompt, + noise_level=0, blend_trick=True, blend_output=True, negative_prompt=negative_prompt, + seed=seed, use_sam_mask=use_sam_mask) + return output_image.numpy()[0], output_image.numpy()[0] + + +def switch_run(use_rasg, model_name, *args): + set_model_from_name(model_name) + if use_rasg: + return rasg_run(*args) + return sd_run(*args) + + +with gr.Blocks(css='style.css') as demo: + gr.HTML( + """ +
+

+ 🧑‍🎨 HD-Painter Demo +

+

+ Hayk Manukyan1*, Andranik Sargsyan1*, Barsegh Atanyan1, Zhangyang Wang1,2, Shant Navasardyan1 + and Humphrey Shi1,3 +

+

+ 1Picsart AI Resarch (PAIR), 2UT Austin, 3Georgia Tech +

+

+ [arXiv] + [GitHub] +

+

+ HD-Painter enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. +

+
+ """) + + if on_huggingspace: + gr.HTML(""" +

For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. +
+ + Duplicate Space +

""") + + with open('script.js', 'r') as f: + js_str = f.read() + + demo.load(_js=js_str) + + with gr.Row(): + with gr.Column(): + model_picker = gr.Dropdown( + list(inpainting_models.keys()), + value=0, + label = "Please select a model!", + ) + with gr.Column(): + use_painta = gr.Checkbox(value = True, label = "Use PAIntA") + use_rasg = gr.Checkbox(value = True, label = "Use RASG") + + prompt = gr.Textbox(label = "Inpainting Prompt") + with gr.Row(): + with gr.Column(): + input = gr.ImageMask(label = "Input Image", brush_color='#ff0000', elem_id="inputmask") + + with gr.Row(): + inpaint_btn = gr.Button("Inpaint", scale = 0) + + with gr.Accordion('Advanced options', open=False): + guidance_scale = gr.Slider(minimum = 0, maximum = 30, value = 7.5, label = "Guidance Scale") + eta = gr.Slider(minimum = 0, maximum = 1, value = 0.1, label = "eta") + ddim_steps = gr.Slider(minimum = 10, maximum = 100, value = 50, step = 1, label = 'Number of diffusion steps') + with gr.Row(): + seed = gr.Number(value = 49123, label = "Seed") + batch_size = gr.Number(value = 1, label = "Batch size", minimum=1, maximum=4) + negative_prompt = gr.Textbox(value=negative_prompt_str, label = "Negative prompt", lines=3) + positive_prompt = gr.Textbox(value=positive_prompt_str, label = "Positive prompt", lines=1) + + with gr.Column(): + with gr.Row(): + output_gallery = gr.Gallery( + [], + columns = 4, + preview = True, + allow_preview = True, + object_fit='scale-down', + elem_id='outputgallery' + ) + with gr.Row(): + upscale_btn = gr.Button("Send to Inpainting-Specialized Super-Resolution (x4)", scale = 1) + with gr.Row(): + use_sam_mask = gr.Checkbox(value = False, label = "Use SAM mask for background preservation (for SR only, experimental feature)") + with gr.Row(): + hires_image = gr.Image(label = "Hi-res Image") + + label = gr.Markdown("## High-Resolution Generation Samples (2048px large side)") + + with gr.Column(): + example_container = gr.Gallery( + example_previews, + columns = 4, + preview = True, + allow_preview = True, + object_fit='scale-down' + ) + + gr.Examples( + [ + example_inputs[i] + [[example_previews[i]]] + for i in range(len(example_previews)) + ], + [input, prompt, example_container] + ) + + mock_output_gallery = gr.Gallery([], columns = 4, visible=False) + mock_hires = gr.Image(label = "__MHRO__", visible = False) + html_info = gr.HTML(elem_id=f'html_info', elem_classes="infotext") + + inpaint_btn.click( + fn=switch_run, + inputs=[ + use_rasg, + model_picker, + use_painta, + prompt, + input, + seed, + eta, + negative_prompt, + positive_prompt, + ddim_steps, + guidance_scale, + batch_size + ], + outputs=[output_gallery, mock_output_gallery], + api_name="inpaint" + ) + upscale_btn.click( + fn=upscale_run, + inputs=[ + prompt, + input, + ddim_steps, + seed, + use_sam_mask, + mock_output_gallery, + html_info + ], + outputs=[hires_image, mock_hires], + api_name="upscale", + _js="function(a, b, c, d, e, f, g){ return [a, b, c, d, e, f, selected_gallery_index()] }", + ) + +demo.queue() +demo.launch(share=True, allowed_paths=[TMP_DIR]) \ No newline at end of file diff --git a/assets/.gitignore b/assets/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6ea8874968d000cd47f52f55f32a92f0127532b3 --- /dev/null +++ b/assets/.gitignore @@ -0,0 +1 @@ +models/ \ No newline at end of file diff --git a/assets/config/ddpm/v1.yaml b/assets/config/ddpm/v1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..95c4053aac12d443ea8071c23f07c3d1a8b97488 --- /dev/null +++ b/assets/config/ddpm/v1.yaml @@ -0,0 +1,14 @@ +linear_start: 0.00085 +linear_end: 0.0120 +num_timesteps_cond: 1 +log_every_t: 200 +timesteps: 1000 +first_stage_key: "jpg" +cond_stage_key: "txt" +image_size: 64 +channels: 4 +cond_stage_trainable: false +conditioning_key: crossattn +monitor: val/loss_simple_ema +scale_factor: 0.18215 +use_ema: False # we set this to false because this is an inference only config \ No newline at end of file diff --git a/assets/config/ddpm/v2-upsample.yaml b/assets/config/ddpm/v2-upsample.yaml new file mode 100644 index 0000000000000000000000000000000000000000..450576d21d0cb33958465db0179151a521828606 --- /dev/null +++ b/assets/config/ddpm/v2-upsample.yaml @@ -0,0 +1,24 @@ +parameterization: "v" +low_scale_key: "lr" +linear_start: 0.0001 +linear_end: 0.02 +num_timesteps_cond: 1 +log_every_t: 200 +timesteps: 1000 +first_stage_key: "jpg" +cond_stage_key: "txt" +image_size: 128 +channels: 4 +cond_stage_trainable: false +conditioning_key: "hybrid-adm" +monitor: val/loss_simple_ema +scale_factor: 0.08333 +use_ema: False + +low_scale_config: + target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation + params: + noise_schedule_config: # image space + linear_start: 0.0001 + linear_end: 0.02 + max_noise_level: 350 diff --git a/assets/config/encoders/clip.yaml b/assets/config/encoders/clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8082b5b56b0cec7d586f4e0830d206ab3fccde10 --- /dev/null +++ b/assets/config/encoders/clip.yaml @@ -0,0 +1 @@ +__class__: smplfusion.models.encoders.clip_embedder.FrozenCLIPEmbedder \ No newline at end of file diff --git a/assets/config/encoders/openclip.yaml b/assets/config/encoders/openclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ca74e9c97a230642a8023d843d793858c9e4c5c0 --- /dev/null +++ b/assets/config/encoders/openclip.yaml @@ -0,0 +1,4 @@ +__class__: smplfusion.models.encoders.open_clip_embedder.FrozenOpenCLIPEmbedder +__init__: + freeze: True + layer: "penultimate" \ No newline at end of file diff --git a/assets/config/unet/inpainting/v1.yaml b/assets/config/unet/inpainting/v1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b6be2f6b5c129bb8c52f0432824723851bda325d --- /dev/null +++ b/assets/config/unet/inpainting/v1.yaml @@ -0,0 +1,15 @@ +__class__: smplfusion.models.unet.UNetModel +__init__: + image_size: 32 # unused + in_channels: 9 # 4 data + 4 downscaled image + 1 mask + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False \ No newline at end of file diff --git a/assets/config/unet/inpainting/v2.yaml b/assets/config/unet/inpainting/v2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c78bc2a37a344dd1499aaab01580e0c2cd7e27bc --- /dev/null +++ b/assets/config/unet/inpainting/v2.yaml @@ -0,0 +1,16 @@ +__class__: smplfusion.models.unet.UNetModel +__init__: + use_checkpoint: False + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False \ No newline at end of file diff --git a/assets/config/unet/upsample/v2.yaml b/assets/config/unet/upsample/v2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1fcbd87e98a510aa718afc552bb111e1110d71fe --- /dev/null +++ b/assets/config/unet/upsample/v2.yaml @@ -0,0 +1,19 @@ +__class__: smplfusion.models.unet.UNetModel +__init__: + use_checkpoint: False + num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) + image_size: 128 + in_channels: 7 + out_channels: 4 + model_channels: 256 + attention_resolutions: [ 2,4,8] + num_res_blocks: 2 + channel_mult: [ 1, 2, 2, 4] + disable_self_attentions: [True, True, True, False] + disable_middle_self_attn: False + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + use_linear_in_transformer: True \ No newline at end of file diff --git a/assets/config/vae-upsample.yaml b/assets/config/vae-upsample.yaml new file mode 100644 index 0000000000000000000000000000000000000000..989a5c6d374ebc3bce469c151765250d49071330 --- /dev/null +++ b/assets/config/vae-upsample.yaml @@ -0,0 +1,16 @@ +__class__: smplfusion.models.vae.AutoencoderKL +__init__: + embed_dim: 4 + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4 ] + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity \ No newline at end of file diff --git a/assets/config/vae.yaml b/assets/config/vae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52e68334ce00ab350838f4dd865aff9da6482dae --- /dev/null +++ b/assets/config/vae.yaml @@ -0,0 +1,17 @@ +__class__: smplfusion.models.vae.AutoencoderKL +__init__: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1,2,4,4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity \ No newline at end of file diff --git a/assets/examples/images/a19.jpg b/assets/examples/images/a19.jpg new file mode 100644 index 0000000000000000000000000000000000000000..97d84ac02147682222dddf20d09f60dfc20830db --- /dev/null +++ b/assets/examples/images/a19.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4622138454df716ad6a8015c13cf7889a94c63e54c759f130a576eb5280eabf1 +size 237183 diff --git a/assets/examples/images/a2.jpg b/assets/examples/images/a2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6317ae4bab3d7a1d09850e53f6f0d578c30abc15 --- /dev/null +++ b/assets/examples/images/a2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74cc2a7407234fc477e66a5a776a57d7b21618e5fea166a32f9b20d6dbd272ba +size 1320325 diff --git a/assets/examples/images/a4.jpg b/assets/examples/images/a4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3e88f8539c77878897b6803ee40913458a15bc4e --- /dev/null +++ b/assets/examples/images/a4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7811c416c6352720d853de721cc41f6d52b2d034c10277c35301c272c3843f7f +size 824215 diff --git a/assets/examples/images/a40.jpg b/assets/examples/images/a40.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c563c52dcdbd9ca45c4926d89aac129d03ce8f70 --- /dev/null +++ b/assets/examples/images/a40.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03c76588a2a8782e2bab2f5b309e1b3a69932b1c87c38b16599ea0fedb9d30e7 +size 470085 diff --git a/assets/examples/images/a46.jpg b/assets/examples/images/a46.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dfea11f744ea69d55e5b7e74994a8215e5b37a60 --- /dev/null +++ b/assets/examples/images/a46.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0e06678355e5798e1ae280932b806891d111ac9f23fbbc8fde7429df666aadb +size 118803 diff --git a/assets/examples/images/a51.jpg b/assets/examples/images/a51.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3d6d0e79b5339dc0d898469de82dcaf2ddf4f7ab --- /dev/null +++ b/assets/examples/images/a51.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c18be20ed65489dfc851624923828a481275f680d200d1e57ba743d04208ff6 +size 190989 diff --git a/assets/examples/images/a54.jpg b/assets/examples/images/a54.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6c76911cae56a861539ba243c17b5afdfad26c59 --- /dev/null +++ b/assets/examples/images/a54.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4333811c5597d088979d58ff778e854bc55f829e9c9b4f564148c1f59b99c36 +size 638627 diff --git a/assets/examples/images/a65.jpg b/assets/examples/images/a65.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bec1fe3ed351887be1702e92f2145e7cc35df0b2 --- /dev/null +++ b/assets/examples/images/a65.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f7afcb446f4903c37df7a640f931278c9defe79d1d014a50268b3c0ae232543 +size 239481 diff --git a/assets/examples/masked/a19.png b/assets/examples/masked/a19.png new file mode 100644 index 0000000000000000000000000000000000000000..b8499e168933d3503cf122c7d4d678366ab8269e --- /dev/null +++ b/assets/examples/masked/a19.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af2982b36bce993fa9bc0152317a805dea3e765dd5ffae961eba05cf7c54164f +size 1801745 diff --git a/assets/examples/masked/a2.png b/assets/examples/masked/a2.png new file mode 100644 index 0000000000000000000000000000000000000000..1724a2f07a18472e2d0f6b7395ccfce1a1354802 --- /dev/null +++ b/assets/examples/masked/a2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e860721ef68093859d6aae3dc4b595c07719149fb9f5c55cc3cc147ef82ed6b +size 2069494 diff --git a/assets/examples/masked/a4.png b/assets/examples/masked/a4.png new file mode 100644 index 0000000000000000000000000000000000000000..0613458692b90037ec73f54d5215410a9d202ef9 --- /dev/null +++ b/assets/examples/masked/a4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d8f25579699a11a94dc9dafdcab95000a71e82707dc8f6f0d18ceaf4fe44c0a +size 3243031 diff --git a/assets/examples/masked/a40.png b/assets/examples/masked/a40.png new file mode 100644 index 0000000000000000000000000000000000000000..98eec49323ec88ac0ecf6a1951a7aa2141cb345b --- /dev/null +++ b/assets/examples/masked/a40.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:988a74bfd29509b1fede12e4234ca8afd9926ad8bca2b87c9a7e6c5e0e758bb9 +size 3727734 diff --git a/assets/examples/masked/a46.png b/assets/examples/masked/a46.png new file mode 100644 index 0000000000000000000000000000000000000000..ad0ca0f769a73914ce6ea3b527078965d8f8ee13 --- /dev/null +++ b/assets/examples/masked/a46.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a624e2155e1c88c11c7f221179c329e44e7a3f500e22775805ce0280ee39b4af +size 768909 diff --git a/assets/examples/masked/a51.png b/assets/examples/masked/a51.png new file mode 100644 index 0000000000000000000000000000000000000000..ee34c510787386918fa72d114b494a56bef9a9fb --- /dev/null +++ b/assets/examples/masked/a51.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cfa02db8e37622915725c9a3bdafb8b16462bddca414f004f51a3c41d2faaa51 +size 1869562 diff --git a/assets/examples/masked/a54.png b/assets/examples/masked/a54.png new file mode 100644 index 0000000000000000000000000000000000000000..8936dc7500a92486c4ec80a83dc204ea9bb5d6e8 --- /dev/null +++ b/assets/examples/masked/a54.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31cd7b538ee33e1e83a5aea6879fa3b898dc379d4e2d210bcd80246d8bec9a40 +size 5330619 diff --git a/assets/examples/masked/a65.png b/assets/examples/masked/a65.png new file mode 100644 index 0000000000000000000000000000000000000000..23f79b0971ff26b0155739f90b23750fd458efb2 --- /dev/null +++ b/assets/examples/masked/a65.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0039ca6e324c3603cb1ed4c830b1d9b69224de12a3fa33dddfaa558b4076cefe +size 2310014 diff --git a/assets/examples/sbs/a19.png b/assets/examples/sbs/a19.png new file mode 100644 index 0000000000000000000000000000000000000000..7f228601d02bc5b1f49ecbd5dac89ac6443c730b --- /dev/null +++ b/assets/examples/sbs/a19.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:305df57833fcaf6d75c15cb1a29d048a43e687bb23e400ee37b2b8d48004bb39 +size 2421308 diff --git a/assets/examples/sbs/a2.png b/assets/examples/sbs/a2.png new file mode 100644 index 0000000000000000000000000000000000000000..bb1a2b18417fa31c6f33833937b1742fb9916229 --- /dev/null +++ b/assets/examples/sbs/a2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a034a4eeef7f49654e7d0efc29517f0cf79b5acef398721df8db437bd91a3ada +size 4051036 diff --git a/assets/examples/sbs/a4.png b/assets/examples/sbs/a4.png new file mode 100644 index 0000000000000000000000000000000000000000..d389d9e0d520f95a87017e42c0fa18ef0cadca0f --- /dev/null +++ b/assets/examples/sbs/a4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd4e51f460c957abf6195d76a49915531faa3a0945d1f6dd37b870dd172f4de8 +size 7042935 diff --git a/assets/examples/sbs/a40.png b/assets/examples/sbs/a40.png new file mode 100644 index 0000000000000000000000000000000000000000..9a89d71a75951a2cee7bf4d9f603ad46d75c5482 --- /dev/null +++ b/assets/examples/sbs/a40.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:186d588650180703e7046ecc61bd8c53f35d53769ceae45083304f1adc7e31d7 +size 5865496 diff --git a/assets/examples/sbs/a46.png b/assets/examples/sbs/a46.png new file mode 100644 index 0000000000000000000000000000000000000000..b04b1548540facd46653a715b8471dd43eaae47b --- /dev/null +++ b/assets/examples/sbs/a46.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87487caeccfa7fea5280b8fa6c97ab0179d7c4568cf0f01d363852d6a48168aa +size 1298406 diff --git a/assets/examples/sbs/a51.png b/assets/examples/sbs/a51.png new file mode 100644 index 0000000000000000000000000000000000000000..ea485ab125644bc45d66b0d859585310167c42b6 --- /dev/null +++ b/assets/examples/sbs/a51.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e5c983156abb921c25e43794e6a6be5dc5bc4af815121331bccb8259f35efd4 +size 2445924 diff --git a/assets/examples/sbs/a54.png b/assets/examples/sbs/a54.png new file mode 100644 index 0000000000000000000000000000000000000000..8cd45a08fcb9511bfe672546c3003ff18b828c39 --- /dev/null +++ b/assets/examples/sbs/a54.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e12ff2497806dfe1f04f2c29a62cec89eef7e8c9382b25ac04cc966cdfdd157 +size 8355872 diff --git a/assets/examples/sbs/a65.png b/assets/examples/sbs/a65.png new file mode 100644 index 0000000000000000000000000000000000000000..4313ec3896a3a68b3c66688ebdaca3eecb9ef355 --- /dev/null +++ b/assets/examples/sbs/a65.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60f66a473f8e52bbcdfdd6c92ff99bae8a2af654a6e0f1b4c6d75385400fe0f9 +size 3288688 diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/methods/__init__.py b/lib/methods/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/methods/rasg.py b/lib/methods/rasg.py new file mode 100644 index 0000000000000000000000000000000000000000..abb017fe55e833c4b39c72fe8461d118c4d36f79 --- /dev/null +++ b/lib/methods/rasg.py @@ -0,0 +1,88 @@ +import torch +from lib.utils.iimage import IImage +from pytorch_lightning import seed_everything +from tqdm import tqdm + +from lib.smplfusion import share, router, attentionpatch, transformerpatch +from lib.smplfusion.patches.attentionpatch import painta +from lib.utils import tokenize, scores + +verbose = False + + +def init_painta(token_idx): + # Initialize painta + router.attention_forward = attentionpatch.painta.forward + router.basic_transformer_forward = transformerpatch.painta.forward + painta.painta_on = True + painta.painta_res = [16, 32] + painta.token_idx = token_idx + +def init_guidance(): + # Setup model for guidance only! + router.attention_forward = attentionpatch.default.forward_and_save + router.basic_transformer_forward = transformerpatch.default.forward + +def run(ddim, method, prompt, image, mask, seed, eta, prefix, negative_prompt, positive_prompt, dt, guidance_scale): + # Text condition + prompt = prefix.format(prompt) + context = ddim.encoder.encode([negative_prompt, prompt + positive_prompt]) + token_idx = list(range(1 + prefix.split(' ').index('{}'), tokenize(prompt).index(''))) + token_idx += [tokenize(prompt + positive_prompt).index('')] + + # Initialize painta + if 'painta' in method: init_painta(token_idx) + else: init_guidance() + + # Image condition + unet_condition = ddim.get_inpainting_condition(image, mask) + share.set_mask(mask) + + # Starting latent + seed_everything(seed) + zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda() + + # Setup unet for guidance + ddim.unet.requires_grad_(True) + + pbar = tqdm(range(999, 0, -dt)) if verbose else range(999, 0, -dt) + + for timestep in share.DDIMIterator(pbar): + if 'painta' in method and share.timestep <= 500: init_guidance() + + zt = zt.detach() + zt.requires_grad = True + + # Reset storage + share._crossattn_similarity_res16 = [] + + # Run the model + _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1) + eps_uncond, eps = ddim.unet( + torch.cat([_zt, _zt]), + timesteps = torch.tensor([timestep, timestep]).cuda(), + context = context + ).detach().chunk(2) + + # Unconditional guidance + eps = (eps_uncond + guidance_scale * (eps - eps_uncond)) + z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep] + + # Gradient Computation + score = scores.bce(share._crossattn_similarity_res16, share.mask16, token_idx = token_idx) + score.backward() + grad = zt.grad.detach() + ddim.unet.zero_grad() # Cleanup already + + # DDIM Step + with torch.no_grad(): + sigma = share.schedule.sigma(share.timestep, dt) + # Standartization + grad -= grad.mean() + grad /= grad.std() + + zt = share.schedule.sqrt_alphas[share.timestep - dt] * z0 + torch.sqrt(1 - share.schedule.alphas[share.timestep - dt] - sigma ** 2) * eps + eta * sigma * grad + + with torch.no_grad(): + output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor)) + return output_image \ No newline at end of file diff --git a/lib/methods/sd.py b/lib/methods/sd.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff79cd071b82f93d491350aa151ec4b0b8e11cf --- /dev/null +++ b/lib/methods/sd.py @@ -0,0 +1,74 @@ +import torch +from pytorch_lightning import seed_everything +from tqdm import tqdm + +from lib.utils.iimage import IImage +from lib.smplfusion import share, router, attentionpatch, transformerpatch +from lib.smplfusion.patches.attentionpatch import painta +from lib.utils import tokenize + +verbose = False + + +def init_painta(token_idx): + # Initialize painta + router.attention_forward = attentionpatch.painta.forward + router.basic_transformer_forward = transformerpatch.painta.forward + painta.painta_on = True + painta.painta_res = [16, 32] + painta.token_idx = token_idx + +def run( + ddim, + method, + prompt, + image, + mask, + seed, + eta, + prefix, + negative_prompt, + positive_prompt, + dt, + guidance_scale +): + # Text condition + context = ddim.encoder.encode([negative_prompt, prompt + positive_prompt]) + token_idx = list(range(1 + prefix.split(' ').index('{}'), tokenize(prompt).index(''))) + token_idx += [tokenize(prompt + positive_prompt).index('')] + + # Setup painta if needed + if 'painta' in method: init_painta(token_idx) + else: router.reset() + + # Image condition + unet_condition = ddim.get_inpainting_condition(image, mask) + share.set_mask(mask) + + # Starting latent + seed_everything(seed) + zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda() + + # Turn off gradients + ddim.unet.requires_grad_(False) + + pbar = tqdm(range(999, 0, -dt)) if verbose else range(999, 0, -dt) + + for timestep in share.DDIMIterator(pbar): + if share.timestep <= 500: router.reset() + + _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1) + eps_uncond, eps = ddim.unet( + torch.cat([_zt, _zt]), + timesteps = torch.tensor([timestep, timestep]).cuda(), + context = context + ).chunk(2) + + eps = (eps_uncond + guidance_scale * (eps - eps_uncond)) + z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep] + zt = share.schedule.sqrt_alphas[timestep - dt] * z0 + share.schedule.sqrt_one_minus_alphas[timestep - dt] * eps + + with torch.no_grad(): + output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor)) + + return output_image diff --git a/lib/methods/sr.py b/lib/methods/sr.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ce34a63b76b2f668dfda1b974001d88cff0955 --- /dev/null +++ b/lib/methods/sr.py @@ -0,0 +1,141 @@ +import os +from functools import partial +from glob import glob +from pathlib import Path as PythonPath + +import cv2 +import torchvision.transforms.functional as TvF +import torch +import torch.nn as nn +import numpy as np +from inspect import isfunction +from PIL import Image + +from lib import smplfusion +from lib.smplfusion import share, router, attentionpatch, transformerpatch +from lib.utils.iimage import IImage +from lib.utils import poisson_blend +from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v + + +def refine_mask(hr_image, hr_mask, lr_image, sam_predictor): + lr_mask = hr_mask.resize(512) + + x_min, y_min, rect_w, rect_h = cv2.boundingRect(lr_mask.data[0][:, :, 0]) + x_min = max(x_min - 1, 0) + y_min = max(y_min - 1, 0) + x_max = x_min + rect_w + 1 + y_max = y_min + rect_h + 1 + + input_box = np.array([x_min, y_min, x_max, y_max]) + + sam_predictor.set_image(hr_image.resize(512).data[0]) + masks, _, _ = sam_predictor.predict( + point_coords=None, + point_labels=None, + box=input_box[None, :], + multimask_output=True, + ) + dilation_kernel = np.ones((13, 13)) + original_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8) + original_object_mask = cv2.dilate(original_object_mask, dilation_kernel) + + sam_predictor.set_image(lr_image.resize(512).data[0]) + masks, _, _ = sam_predictor.predict( + point_coords=None, + point_labels=None, + box=input_box[None, :], + multimask_output=True, + ) + dilation_kernel = np.ones((3, 3)) + inpainted_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8) + inpainted_object_mask = cv2.dilate(inpainted_object_mask, dilation_kernel) + + lr_mask_masking = ((original_object_mask + inpainted_object_mask ) > 0).astype(np.uint8) + new_mask = lr_mask.data[0] * lr_mask_masking[:, :, np.newaxis] + new_mask = IImage(new_mask).resize(2048, resample = Image.BICUBIC) + return new_mask + + +def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt = 'high resolution professional photo', noise_level=20, +blend_output = True, blend_trick = True, no_superres = False, +dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = False, dtype=torch.bfloat16): + torch.manual_seed(seed) + + router.attention_forward = attentionpatch.default.forward_xformers + router.basic_transformer_forward = transformerpatch.default.forward + + if use_sam_mask: + with torch.no_grad(): + hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor) + + orig_h, orig_w = hr_image.torch().shape[2], hr_image.torch().shape[3] + hr_image = hr_image.padx(256, padding_mode='reflect') + hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19) + hr_mask_orig = hr_mask + lr_image = lr_image.padx(64, padding_mode='reflect') + lr_mask = hr_mask.resize((lr_image.torch().shape[2], lr_image.torch().shape[3]), resample = Image.BICUBIC).alpha().torch(vmin=0).cuda() + lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19) + + if no_superres: + output_tensor = lr_image.resize((hr_image.torch().shape[2], hr_image.torch().shape[3]), resample = Image.BICUBIC).torch().cuda() + output_tensor = (255*((output_tensor.clip(-1, 1) + 1) / 2)).to(torch.uint8) + output_tensor = poisson_blend( + orig_img=hr_image.data[0][:orig_h, :orig_w, :], + fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :], + mask=hr_mask_orig.alpha().data[0][:orig_h, :orig_w, :] + ) + return IImage(output_tensor[:orig_h, :orig_w, :]) + + # encode hr image + with torch.no_grad(): + hr_z0 = ddim.vae.encode(hr_image.torch().cuda().to(dtype)).mean * ddim.config.scale_factor + + assert hr_z0.shape[2] == lr_image.torch().shape[2] + assert hr_z0.shape[3] == lr_image.torch().shape[3] + + unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype) + zT = torch.randn((1,4,unet_condition.shape[2], unet_condition.shape[3])).cuda().to(dtype) + + with torch.no_grad(): + context = ddim.encoder.encode([negative_prompt, prompt]) + + noise_level = torch.Tensor(1 * [noise_level]).to('cuda').long() + unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level) + + with torch.autocast('cuda'), torch.no_grad(): + zt = zT + for index,t in enumerate(range(999, 0, -dt)): + + _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1) + + eps_uncond, eps = ddim.unet( + torch.cat([_zt, _zt]).to(dtype), + timesteps = torch.tensor([t, t]).cuda(), + context = context, + y=torch.cat([noise_level]*2) + ).chunk(2) + + ts = torch.full((zt.shape[0],), t, device='cuda', dtype=torch.long) + model_output = (eps_uncond + guidance_scale * (eps - eps_uncond)) + eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype) + z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype) + + if blend_trick: + z0 = z0 * lr_mask + hr_z0 * (1-lr_mask) + + zt = ddim.schedule.sqrt_alphas[t - dt] * z0 + ddim.schedule.sqrt_one_minus_alphas[t - dt] * eps + + with torch.no_grad(): + output_tensor = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor) + + if blend_output: + output_tensor = (255*((output_tensor + 1) / 2).clip(0, 1)).to(torch.uint8) + output_tensor = poisson_blend( + orig_img=hr_image.data[0][:orig_h, :orig_w, :], + fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :], + mask=hr_mask_orig.alpha().data[0][:orig_h, :orig_w, :] + ) + return IImage(output_tensor[:orig_h, :orig_w, :]) + else: + return IImage(output_tensor[:, :, :orig_h, :orig_w]) diff --git a/lib/models/__init__.py b/lib/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f38837760733e194c475623e9b3821017918940 --- /dev/null +++ b/lib/models/__init__.py @@ -0,0 +1 @@ +from . import sd2_inp, ds_inp, sd15_inp, sd2_sr, sam \ No newline at end of file diff --git a/lib/models/common.py b/lib/models/common.py new file mode 100644 index 0000000000000000000000000000000000000000..70c5473274efadf65e4bfc67da956a4317273998 --- /dev/null +++ b/lib/models/common.py @@ -0,0 +1,49 @@ +import importlib +import requests +from pathlib import Path +from os.path import dirname + +from omegaconf import OmegaConf +from tqdm import tqdm + + +PROJECT_DIR = dirname(dirname(dirname(__file__))) +CONFIG_FOLDER = f'{PROJECT_DIR}/assets/config' +MODEL_FOLDER = f'{PROJECT_DIR}/assets/models' + + +def download_file(url, save_path, chunk_size=1024): + try: + save_path = Path(save_path) + if save_path.exists(): + print(f'{save_path.name} exists') + return + save_path.parent.mkdir(exist_ok=True, parents=True) + resp = requests.get(url, stream=True) + total = int(resp.headers.get('content-length', 0)) + with open(save_path, 'wb') as file, tqdm( + desc=save_path.name, + total=total, + unit='iB', + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in resp.iter_content(chunk_size=chunk_size): + size = file.write(data) + bar.update(size) + print(f'{save_path.name} download finished') + except Exception as e: + raise Exception(f"Download failed: {e}") + + +def get_obj_from_str(string): + module, cls = string.rsplit(".", 1) + try: + return getattr(importlib.import_module(module, package=None), cls) + except: + return getattr(importlib.import_module('lib.' + module, package=None), cls) + + +def load_obj(path): + objyaml = OmegaConf.load(path) + return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {})) diff --git a/lib/models/ds_inp.py b/lib/models/ds_inp.py new file mode 100644 index 0000000000000000000000000000000000000000..93d00049d3ba98fa10ab67f516d48bf8e091c022 --- /dev/null +++ b/lib/models/ds_inp.py @@ -0,0 +1,46 @@ +import importlib +from omegaconf import OmegaConf +import torch +import safetensors +import safetensors.torch + +from lib.smplfusion import DDIM, share, scheduler +from .common import * + + +MODEL_PATH = f'{MODEL_FOLDER}/dreamshaper/dreamshaper_8Inpainting.safetensors' +DOWNLOAD_URL = 'https://civitai.com/api/download/models/131004' + +# pre-download +download_file(DOWNLOAD_URL, MODEL_PATH) + + +def load_model(): + print ("Loading model: Dreamshaper Inpainting V8") + + download_file(DOWNLOAD_URL, MODEL_PATH) + + state_dict = safetensors.torch.load_file(MODEL_PATH) + + config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml') + unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda() + vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda() + encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda() + + extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x} + unet_state = extract(state_dict, 'model.diffusion_model') + encoder_state = extract(state_dict, 'cond_stage_model') + vae_state = extract(state_dict, 'first_stage_model') + + unet.load_state_dict(unet_state) + encoder.load_state_dict(encoder_state) + vae.load_state_dict(vae_state) + + unet = unet.requires_grad_(False) + encoder = encoder.requires_grad_(False) + vae = vae.requires_grad_(False) + + ddim = DDIM(config, vae, encoder, unet) + share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end) + + return ddim diff --git a/lib/models/sam.py b/lib/models/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..3b649e4827fef52361a325f40b06f3f144e2bcd1 --- /dev/null +++ b/lib/models/sam.py @@ -0,0 +1,20 @@ +from segment_anything import sam_model_registry, SamPredictor +from .common import * + +MODEL_PATH = f'{MODEL_FOLDER}/sam/sam_vit_h_4b8939.pth' +DOWNLOAD_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth' + +# pre-download +download_file(DOWNLOAD_URL, MODEL_PATH) + + +def load_model(): + print ("Loading model: SAM") + download_file(DOWNLOAD_URL, MODEL_PATH) + model_type = "vit_h" + device = "cuda" + sam = sam_model_registry[model_type](checkpoint=MODEL_PATH) + sam.to(device=device) + sam_predictor = SamPredictor(sam) + print ("SAM loaded") + return sam_predictor diff --git a/lib/models/sd15_inp.py b/lib/models/sd15_inp.py new file mode 100644 index 0000000000000000000000000000000000000000..c1261b8562d37215c7afee96e333ad6cd45688b8 --- /dev/null +++ b/lib/models/sd15_inp.py @@ -0,0 +1,44 @@ +from omegaconf import OmegaConf +import torch + +from lib.smplfusion import DDIM, share, scheduler +from .common import * + + +DOWNLOAD_URL = 'https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt?download=true' +MODEL_PATH = f'{MODEL_FOLDER}/sd-1-5-inpainting/sd-v1-5-inpainting.ckpt' + +# pre-download +download_file(DOWNLOAD_URL, MODEL_PATH) + + +def load_model(): + download_file(DOWNLOAD_URL, MODEL_PATH) + + state_dict = torch.load(MODEL_PATH)['state_dict'] + + config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml') + + print ("Loading model: Stable-Inpainting 1.5") + + unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda() + vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda() + encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda() + + extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x} + unet_state = extract(state_dict, 'model.diffusion_model') + encoder_state = extract(state_dict, 'cond_stage_model') + vae_state = extract(state_dict, 'first_stage_model') + + unet.load_state_dict(unet_state) + encoder.load_state_dict(encoder_state) + vae.load_state_dict(vae_state) + + unet = unet.requires_grad_(False) + encoder = encoder.requires_grad_(False) + vae = vae.requires_grad_(False) + + ddim = DDIM(config, vae, encoder, unet) + share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end) + + return ddim diff --git a/lib/models/sd2_inp.py b/lib/models/sd2_inp.py new file mode 100644 index 0000000000000000000000000000000000000000..9d59f3dacb437b4d08d257732ac446a006f8f1c0 --- /dev/null +++ b/lib/models/sd2_inp.py @@ -0,0 +1,47 @@ +import safetensors +import safetensors.torch +import torch +from omegaconf import OmegaConf + +from lib.smplfusion import DDIM, share, scheduler +from .common import * + +MODEL_PATH = f'{MODEL_FOLDER}/sd-2-0-inpainting/512-inpainting-ema.safetensors' +DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.safetensors?download=true' + +# pre-download +download_file(DOWNLOAD_URL, MODEL_PATH) + + +def load_model(): + print ("Loading model: Stable-Inpainting 2.0") + + download_file(DOWNLOAD_URL, MODEL_PATH) + + state_dict = safetensors.torch.load_file(MODEL_PATH) + + config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml') + + unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v2.yaml').eval().cuda() + vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda() + encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda() + ddim = DDIM(config, vae, encoder, unet) + + extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x} + unet_state = extract(state_dict, 'model.diffusion_model') + encoder_state = extract(state_dict, 'cond_stage_model') + vae_state = extract(state_dict, 'first_stage_model') + + unet.load_state_dict(unet_state) + encoder.load_state_dict(encoder_state) + vae.load_state_dict(vae_state) + + unet = unet.requires_grad_(False) + encoder = encoder.requires_grad_(False) + vae = vae.requires_grad_(False) + + ddim = DDIM(config, vae, encoder, unet) + share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end) + + print('Stable-Inpainting 2.0 loaded') + return ddim diff --git a/lib/models/sd2_sr.py b/lib/models/sd2_sr.py new file mode 100644 index 0000000000000000000000000000000000000000..0c4a1adbdf27e16ef4f7244bb4563d6f8b853486 --- /dev/null +++ b/lib/models/sd2_sr.py @@ -0,0 +1,204 @@ +import importlib +from functools import partial + +import cv2 +import numpy as np +import safetensors +import safetensors.torch +import torch +import torch.nn as nn +from inspect import isfunction +from omegaconf import OmegaConf + +from lib.smplfusion import DDIM, share, scheduler +from .common import * + + +DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/resolve/main/x4-upscaler-ema.safetensors?download=true' +MODEL_PATH = f'{MODEL_FOLDER}/sd-2-0-upsample/x4-upscaler-ema.safetensors' + +# pre-download +download_file(DOWNLOAD_URL, MODEL_PATH) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def predict_eps_from_z_and_v(schedule, x_t, t, v): + return ( + extract_into_tensor(schedule.sqrt_alphas.cuda(), t, x_t.shape) * v + + extract_into_tensor(schedule.sqrt_one_minus_alphas.cuda(), t, x_t.shape) * x_t + ) + + +def predict_start_from_z_and_v(schedule, x_t, t, v): + return ( + extract_into_tensor(schedule.sqrt_alphas.cuda(), t, x_t.shape) * x_t - + extract_into_tensor(schedule.sqrt_one_minus_alphas.cuda(), t, x_t.shape) * v + ) + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + +def get_obj_from_str(string): + module, cls = string.rsplit(".", 1) + try: + return getattr(importlib.import_module(module, package=None), cls) + except: + return getattr(importlib.import_module('lib.' + module, package=None), cls) +def load_obj(path): + objyaml = OmegaConf.load(path) + return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {})) + + +def load_model(dtype=torch.bfloat16): + print ("Loading model: SD2 superresolution...") + + download_file(DOWNLOAD_URL, MODEL_PATH) + + state_dict = safetensors.torch.load_file(MODEL_PATH) + + config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v2-upsample.yaml') + + unet = load_obj(f'{CONFIG_FOLDER}/unet/upsample/v2.yaml').eval().cuda() + vae = load_obj(f'{CONFIG_FOLDER}/vae-upsample.yaml').eval().cuda() + encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda() + ddim = DDIM(config, vae, encoder, unet) + + extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x} + unet_state = extract(state_dict, 'model.diffusion_model') + encoder_state = extract(state_dict, 'cond_stage_model') + vae_state = extract(state_dict, 'first_stage_model') + + unet.load_state_dict(unet_state) + encoder.load_state_dict(encoder_state) + vae.load_state_dict(vae_state) + + unet = unet.requires_grad_(False) + encoder = encoder.requires_grad_(False) + vae = vae.requires_grad_(False) + + unet.to(dtype) + vae.to(dtype) + encoder.to(dtype) + + ddim = DDIM(config, vae, encoder, unet) + + params = { + 'noise_schedule_config': { + 'linear_start': 0.0001, + 'linear_end': 0.02 + }, + 'max_noise_level': 350 + } + + low_scale_model = ImageConcatWithNoiseAugmentation(**params).eval().to('cuda') + low_scale_model.train = disabled_train + for param in low_scale_model.parameters(): + param.requires_grad = False + + ddim.low_scale_model = low_scale_model + print('SD2 superresolution loaded') + return ddim diff --git a/lib/smplfusion/__init__.py b/lib/smplfusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2e660b96b945f0d0d1ee30b783f3e2123f21b0 --- /dev/null +++ b/lib/smplfusion/__init__.py @@ -0,0 +1,3 @@ +from . import share, scheduler +from .ddim import DDIM +from .patches import router, attentionpatch, transformerpatch \ No newline at end of file diff --git a/lib/smplfusion/ddim.py b/lib/smplfusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..b40e70e5eaf413f76ddcfc5bb6ebbbc55f0b7976 --- /dev/null +++ b/lib/smplfusion/ddim.py @@ -0,0 +1,55 @@ +import torch +from tqdm.notebook import tqdm +from . import scheduler +from . import share + +from lib.utils.iimage import IImage + +class DDIM: + def __init__(self, config, vae, encoder, unet): + self.vae = vae + self.encoder = encoder + self.unet = unet + self.config = config + self.schedule = scheduler.linear(1000, config.linear_start, config.linear_end) + + def __call__( + self, prompt = '', dt = 50, shape = (1,4,64,64), seed = None, negative_prompt = '', unet_condition = None, + context = None, verbose = True): + if seed is not None: torch.manual_seed(seed) + if unet_condition is not None: + zT = torch.randn((1,4) + unet_condition.shape[2:]).cuda() + else: + zT = torch.randn(shape).cuda() + + with torch.autocast('cuda'), torch.no_grad(): + if context is None: context = self.encoder.encode([negative_prompt, prompt]) + + zt = zT + pbar = tqdm(range(999, 0, -dt)) if verbose else range(999, 0, -dt) + for timestep in share.DDIMIterator(pbar): + _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1) + eps_uncond, eps = self.unet( + torch.cat([_zt, _zt]), + timesteps = torch.tensor([timestep, timestep]).cuda(), + context = context + ).chunk(2) + + eps = (eps_uncond + 7.5 * (eps - eps_uncond)) + + z0 = (zt - self.schedule.sqrt_one_minus_alphas[timestep] * eps) / self.schedule.sqrt_alphas[timestep] + zt = self.schedule.sqrt_alphas[timestep - dt] * z0 + self.schedule.sqrt_one_minus_alphas[timestep - dt] * eps + return IImage(self.vae.decode(z0 / self.config.scale_factor)) + + def get_inpainting_condition(self, image, mask): + latent_size = [x//8 for x in image.size] + with torch.no_grad(): + condition_x0 = self.vae.encode(image.torch().cuda() * ~mask.torch(0).bool().cuda()).mean * self.config.scale_factor + + condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().float() + + condition_x0 += 0.01 * condition_mask * torch.randn_like(condition_mask) + return torch.cat([condition_mask, condition_x0], 1) + + inpainting_condition = get_inpainting_condition + diff --git a/lib/smplfusion/models/__init__.py b/lib/smplfusion/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/smplfusion/models/encoders/clip_embedder.py b/lib/smplfusion/models/encoders/clip_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c31fd669cd70ab4d10177486cc8b9ef0c2966e --- /dev/null +++ b/lib/smplfusion/models/encoders/clip_embedder.py @@ -0,0 +1,48 @@ +from torch import nn +from transformers import CLIPTokenizer, CLIPTextModel + +class FrozenCLIPEmbedder(nn.Module): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) \ No newline at end of file diff --git a/lib/smplfusion/models/encoders/open_clip_embedder.py b/lib/smplfusion/models/encoders/open_clip_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..fe3de8869dc2f6077c0dff0a860a24c88e81ee73 --- /dev/null +++ b/lib/smplfusion/models/encoders/open_clip_embedder.py @@ -0,0 +1,56 @@ +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +import open_clip + +class FrozenOpenCLIPEmbedder(nn.Module): + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) \ No newline at end of file diff --git a/lib/smplfusion/models/unet.py b/lib/smplfusion/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..903fa392adf13398e70f1d3a24292680170634d0 --- /dev/null +++ b/lib/smplfusion/models/unet.py @@ -0,0 +1,495 @@ +import math +from abc import abstractmethod + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .util import ( + checkpoint,conv_nd,linear,avg_pool_nd, + zero_module,normalization,timestep_embedding, +) +from ..modules.attention.spatial_transformer import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): pass +def convert_module_to_f32(x): pass + + +## go +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims,self.channels,self.out_channels,3,stride=stride,padding=padding,) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self,channels,emb_channels,dropout,out_channels=None,use_conv=False,use_scale_shift_norm=False, + dims=2,use_checkpoint=False,up=False,down=False + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels, ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self,image_size,in_channels,model_channels,out_channels,num_res_blocks,attention_resolutions,dropout=0, + channel_mult=(1, 2, 4, 8),conv_resample=True,dims=2,num_classes=None,use_checkpoint=False,use_fp16=False, + use_bf16=False,num_heads=-1,num_head_channels=-1,num_heads_upsample=-1,use_scale_shift_norm=False,resblock_updown=False, + use_new_attention_order=False,use_spatial_transformer=False,transformer_depth=1,context_dim=None, + n_embed=None,legacy=True,disable_self_attentions=None,num_attention_blocks=None,disable_middle_self_attn=False, + use_linear_in_transformer=False,adm_in_channels=None, + ): + super().__init__() + + if context_dim is not None: + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.dtype = th.bfloat16 if use_bf16 else self.dtype + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ]) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch,time_embed_dim,dropout,out_channels=mult * model_channels,dims=dims, + use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = (ch // num_heads if use_spatial_transformer else num_head_channels) + if disable_self_attentions is not None: + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if num_attention_blocks is None or nr < num_attention_blocks[level]: + layers.append( + SpatialTransformer( + ch,num_heads,dim_head,depth=transformer_depth,context_dim=context_dim, + disable_self_attn=disabled_sa,use_linear=use_linear_in_transformer,use_checkpoint=use_checkpoint, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch,time_embed_dim,dropout,out_channels=out_ch,dims=dims,use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm,down=True, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch,time_embed_dim,dropout,dims=dims, + use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm, + ), + SpatialTransformer( # always uses a self-attn + ch,num_heads,dim_head,depth=transformer_depth,context_dim=context_dim,disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer,use_checkpoint=use_checkpoint, + ), + ResBlock( + ch,time_embed_dim,dropout,dims=dims,use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich,time_embed_dim,dropout,out_channels=model_channels * mult,dims=dims, + use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if disable_self_attentions is not None: + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if (num_attention_blocks is None or i < num_attention_blocks[level]): + layers.append( + SpatialTransformer( + ch,num_heads,dim_head,depth=transformer_depth,context_dim=context_dim,disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer,use_checkpoint=use_checkpoint, + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch,time_embed_dim,dropout,out_channels=out_ch,dims=dims, + use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :return: an [N x C x ...] Tensor of outputs. + """ + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/lib/smplfusion/models/util.py b/lib/smplfusion/models/util.py new file mode 100644 index 0000000000000000000000000000000000000000..e1e7f25b93728f3f2af48efc0ff7eeff9157c422 --- /dev/null +++ b/lib/smplfusion/models/util.py @@ -0,0 +1,257 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "squaredcos_cap_v2": # used for karlo prior + # return early + return betas_for_alpha_bar( + n_timestep, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/lib/smplfusion/models/vae.py b/lib/smplfusion/models/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..f909b5b418e5a143c30832940177247e258d6c61 --- /dev/null +++ b/lib/smplfusion/models/vae.py @@ -0,0 +1,197 @@ +import torch +import torch.nn.functional as F +import pytorch_lightning as pl +from contextlib import contextmanager + +from ..modules.autoencoder import Encoder, Decoder +from ..modules.distributions import DiagonalGaussianDistribution + +from ..util import instantiate_from_config +from ..modules.ema import LitEma + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False + ): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0. < ema_decay < 1. + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( + self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x \ No newline at end of file diff --git a/lib/smplfusion/modules/__init__.py b/lib/smplfusion/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/smplfusion/modules/attention/__init__.py b/lib/smplfusion/modules/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/smplfusion/modules/attention/basic_transformer_block.py b/lib/smplfusion/modules/attention/basic_transformer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ea5d4c344095d982ab0a644b6f90b7e9f66d2a --- /dev/null +++ b/lib/smplfusion/modules/attention/basic_transformer_block.py @@ -0,0 +1,62 @@ +import torch +from torch import nn +from .feed_forward import FeedForward + +try: + from .cross_attention import PatchedCrossAttention as CrossAttention +except: + try: + from .memory_efficient_cross_attention import MemoryEfficientCrossAttention as CrossAttention + except: + from .cross_attention import CrossAttention +from ..util import checkpoint +from ...patches import router + +class BasicTransformerBlock(nn.Module): + def __init__( + self,dim,n_heads,d_head,dropout=0.0,context_dim=None, + gated_ff=True,checkpoint=True,disable_self_attn=False, + ): + super().__init__() + self.disable_self_attn = disable_self_attn + # is a self-attention if not self.disable_self_attn + self.attn1 = CrossAttention(query_dim=dim,heads=n_heads,dim_head=d_head,dropout=dropout,context_dim=context_dim if self.disable_self_attn else None) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + # is self-attn if context is none + self.attn2 = CrossAttention(query_dim=dim,context_dim=context_dim,heads=n_heads,dim_head=d_head,dropout=dropout) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = x + self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x = x + self.attn2(self.norm2(x), context=context) + x = x + self.ff(self.norm3(x)) + return x + +class PatchedBasicTransformerBlock(nn.Module): + def __init__( + self,dim,n_heads,d_head,dropout=0.0,context_dim=None, + gated_ff=True,checkpoint=True,disable_self_attn=False, + ): + super().__init__() + self.disable_self_attn = disable_self_attn + # is a self-attention if not self.disable_self_attn + self.attn1 = CrossAttention(query_dim=dim,heads=n_heads,dim_head=d_head,dropout=dropout,context_dim=context_dim if self.disable_self_attn else None) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + # is self-attn if context is none + self.attn2 = CrossAttention(query_dim=dim,context_dim=context_dim,heads=n_heads,dim_head=d_head,dropout=dropout) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + return router.basic_transformer_forward(self, x, context) diff --git a/lib/smplfusion/modules/attention/cross_attention.py b/lib/smplfusion/modules/attention/cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ad00d1c8feffb84182237df17dca3f341496b9 --- /dev/null +++ b/lib/smplfusion/modules/attention/cross_attention.py @@ -0,0 +1,85 @@ +# CrossAttn precision handling +import os + +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + +import torch +from torch import nn + +from torch import einsum +from einops import rearrange, repeat +import torch +from torch import nn +from typing import Optional, Any +from ...patches import router + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = context_dim or query_dim + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION == "fp32": + with torch.autocast(enabled=False, device_type="cuda"): + q, k = q.float(), k.float() + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + else: + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + del q, k + + if mask is not None: + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", sim, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + +class PatchedCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = context_dim or query_dim + + self.heads = heads + self.dim_head = dim_head + self.scale = dim_head**-0.5 + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + return router.attention_forward(self, x, context, mask) \ No newline at end of file diff --git a/lib/smplfusion/modules/attention/feed_forward.py b/lib/smplfusion/modules/attention/feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..13bc0fdf026683147bed675d30f225e537a132e7 --- /dev/null +++ b/lib/smplfusion/modules/attention/feed_forward.py @@ -0,0 +1,33 @@ +import torch +from torch import nn +import torch.nn.functional as F + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out or dim + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) \ No newline at end of file diff --git a/lib/smplfusion/modules/attention/memory_efficient_cross_attention.py b/lib/smplfusion/modules/attention/memory_efficient_cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c1ba6a212f45be15a4ed09f53d28f78028abbb --- /dev/null +++ b/lib/smplfusion/modules/attention/memory_efficient_cross_attention.py @@ -0,0 +1,56 @@ +import torch +from torch import nn +from typing import Optional, Any + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + # print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using {heads} heads.") + inner_dim = dim_head * heads + context_dim = context_dim or query_dim + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if mask is not None: + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) \ No newline at end of file diff --git a/lib/smplfusion/modules/attention/spatial_transformer.py b/lib/smplfusion/modules/attention/spatial_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6a32aaaace110723c0205c46f8dc90e7f168eb9a --- /dev/null +++ b/lib/smplfusion/modules/attention/spatial_transformer.py @@ -0,0 +1,88 @@ +import torch +from torch import nn +import math + +from torch import einsum +from einops import rearrange, repeat +from .basic_transformer_block import PatchedBasicTransformerBlock as BasicTransformerBlock + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +def zero_module(module): + for p in module.parameters(): + p.detach().zero_() + return module + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if context_dim is not None and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + diff --git a/lib/smplfusion/modules/autoencoder.py b/lib/smplfusion/modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..59273c58cd694adba9f8b6b9b3bbd4f26228bf71 --- /dev/null +++ b/lib/smplfusion/modules/autoencoder.py @@ -0,0 +1,422 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from typing import Optional, Any + +from .attention.memory_efficient_cross_attention import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=1,padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate( + x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=2,padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=1,padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0) + self.k = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0) + self.v = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0) + self.proj_out = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h*w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0) + self.k = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0) + self.v = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0) + self.proj_out = torch.nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", + "linear", "none"], f'attn_type {attn_type} unknown' + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + # print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in,2*z_channels if double_z else z_channels,kernel_size=3,stride=1,padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1, z_channels, curr_res, curr_res) + # print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels,block_in,kernel_size=3,stride=1,padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in,out_ch,kernel_size=3,stride=1,padding=1) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h diff --git a/lib/smplfusion/modules/distributions.py b/lib/smplfusion/modules/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..d51aed2c1ca3845b4f1031730f2f5aca6844478c --- /dev/null +++ b/lib/smplfusion/modules/distributions.py @@ -0,0 +1,73 @@ +import torch +import numpy as np + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/lib/smplfusion/modules/ema.py b/lib/smplfusion/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4 --- /dev/null +++ b/lib/smplfusion/modules/ema.py @@ -0,0 +1,80 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates + else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/lib/smplfusion/modules/util.py b/lib/smplfusion/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..33e4f0d4b2f9ffd9508a9b27c1581607c97a7fec --- /dev/null +++ b/lib/smplfusion/modules/util.py @@ -0,0 +1,262 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "squaredcos_cap_v2": # used for karlo prior + # return early + return betas_for_alpha_bar( + n_timestep, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/lib/smplfusion/patches/__init__.py b/lib/smplfusion/patches/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/smplfusion/patches/attentionpatch/__init__.py b/lib/smplfusion/patches/attentionpatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..609e406c8f352c8e7a1bbd9574a55601b727c83b --- /dev/null +++ b/lib/smplfusion/patches/attentionpatch/__init__.py @@ -0,0 +1,2 @@ +from . import default +from . import painta diff --git a/lib/smplfusion/patches/attentionpatch/default.py b/lib/smplfusion/patches/attentionpatch/default.py new file mode 100644 index 0000000000000000000000000000000000000000..96559723fb64c1d3860fc1a5ec112356e33de932 --- /dev/null +++ b/lib/smplfusion/patches/attentionpatch/default.py @@ -0,0 +1,102 @@ +from ... import share + +import xformers +import xformers.ops + + +import torch +from torch import nn, einsum +import torchvision.transforms.functional as TF +from einops import rearrange, repeat + +_ATTN_PRECISION = None + +def forward_sd2(self, x, context=None, mask=None): + h = self.heads + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + if _ATTN_PRECISION =="fp32": # force cast to fp32 to avoid overflowing + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + +def forward_xformers(self, x, context=None, mask=None): + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if mask is not None: + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + +forward = forward_xformers + +import traceback + +def forward_and_save(self, x, context=None, mask=None): + att_type = "self" if context is None else "cross" + + h = self.heads + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + if hasattr(share, '_crossattn_similarity_res8') and x.shape[1] == share.input_shape.res8 and att_type == 'cross': + share._crossattn_similarity_res8.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts + if hasattr(share, '_crossattn_similarity_res16') and x.shape[1] == share.input_shape.res16 and att_type == 'cross': + share._crossattn_similarity_res16.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts + if hasattr(share, '_crossattn_similarity_res32') and x.shape[1] == share.input_shape.res32 and att_type == 'cross': + share._crossattn_similarity_res32.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts + if hasattr(share, '_crossattn_similarity_res64') and x.shape[1] == share.input_shape.res64 and att_type == 'cross': + share._crossattn_similarity_res64.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + out = einsum("b i j, b j d -> b i d", sim, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) \ No newline at end of file diff --git a/lib/smplfusion/patches/attentionpatch/painta.py b/lib/smplfusion/patches/attentionpatch/painta.py new file mode 100644 index 0000000000000000000000000000000000000000..0ffbe84f02eb433b0eb7303d55cbc632739a02a5 --- /dev/null +++ b/lib/smplfusion/patches/attentionpatch/painta.py @@ -0,0 +1,156 @@ +import cv2 +import math +import numbers +import numpy as np +import torch +import torch.nn.functional as F +import torchvision +from torch import nn, einsum +from einops import rearrange, repeat + +from ... import share +from lib.utils.iimage import IImage + +# params +painta_res = [16, 32] +painta_on = True +token_idx = [1,2] + + +# GaussianSmoothing is taken from https://github.com/yuval-alaluf/Attend-and-Excite/blob/main/utils/gaussian_smoothing.py +class GaussianSmoothing(nn.Module): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + def __init__(self, channels, kernel_size, sigma, dim=2): + super(GaussianSmoothing, self).__init__() + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ + torch.exp(-((mgrid - mean) / (2 * std)) ** 2) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) + ) + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding='same') + + +def forward(self, x, context=None, mask=None): + is_cross = context is not None + att_type = "self" if context is None else "cross" + + h = self.heads + + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim_before = sim + del q, k + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + if hasattr(share, '_crossattn_similarity_res8') and x.shape[1] == share.input_shape.res8 and att_type == 'cross': + share._crossattn_similarity_res8.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts + if hasattr(share, '_crossattn_similarity_res16') and x.shape[1] == share.input_shape.res16 and att_type == 'cross': + share._crossattn_similarity_res16.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts + if hasattr(share, '_crossattn_similarity_res32') and x.shape[1] == share.input_shape.res32 and att_type == 'cross': + share._crossattn_similarity_res32.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts + if hasattr(share, '_crossattn_similarity_res64') and x.shape[1] == share.input_shape.res64 and att_type == 'cross': + share._crossattn_similarity_res64.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts + + sim = sim.softmax(dim=-1) + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + if is_cross: + return self.to_out(out) + + return self.to_out(out), v, sim_before + + +def painta_rescale(y, self_v, self_sim, cross_sim, self_h, to_out): + mask = share.painta_mask.get_res(self_v) + shape = share.painta_mask.get_shape(self_v) + res = share.painta_mask.get_res_val(self_v) + + mask = (mask > 0.5).to(y.dtype) + m = mask.to(self_v.device) + m = rearrange(m, 'b c h w -> b (h w) c').contiguous() + m = torch.matmul(m, m.permute(0, 2, 1)) + (1-m) + + cross_sim = cross_sim[:, token_idx].sum(dim=1) + cross_sim = cross_sim.reshape(shape) + gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).cuda() + cross_sim = gaussian_smoothing(cross_sim.unsqueeze(0))[0] # optional smoothing + cross_sim = cross_sim.reshape(-1) + cross_sim = ((cross_sim - torch.median(cross_sim.ravel())) / torch.max(cross_sim.ravel())).clip(0, 1) + + if painta_on and res in painta_res: + c = (1 - m) * cross_sim.reshape(1, 1, -1) + m + self_sim = self_sim * c + self_sim = self_sim.softmax(dim=-1) + out = einsum('b i j, b j d -> b i d', self_sim, self_v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=self_h) + out = to_out(out) + else: + out = y + return out + diff --git a/lib/smplfusion/patches/router.py b/lib/smplfusion/patches/router.py new file mode 100644 index 0000000000000000000000000000000000000000..fa03d3888c4b204eec87b64af08ea7bba870caaa --- /dev/null +++ b/lib/smplfusion/patches/router.py @@ -0,0 +1,10 @@ +from . import attentionpatch +from . import transformerpatch + +attention_forward = attentionpatch.default.forward +basic_transformer_forward = transformerpatch.default.forward + +def reset(): + global attention_forward, basic_transformer_forward + attention_forward = attentionpatch.default.forward + basic_transformer_forward = transformerpatch.default.forward diff --git a/lib/smplfusion/patches/transformerpatch/__init__.py b/lib/smplfusion/patches/transformerpatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..609e406c8f352c8e7a1bbd9574a55601b727c83b --- /dev/null +++ b/lib/smplfusion/patches/transformerpatch/__init__.py @@ -0,0 +1,2 @@ +from . import default +from . import painta diff --git a/lib/smplfusion/patches/transformerpatch/default.py b/lib/smplfusion/patches/transformerpatch/default.py new file mode 100644 index 0000000000000000000000000000000000000000..7bfe2d5648c098b5f868f75ce5f292686506baa3 --- /dev/null +++ b/lib/smplfusion/patches/transformerpatch/default.py @@ -0,0 +1,8 @@ +import torch +from ... import share + +def forward(self, x, context=None): + x = x + self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) # Self Attn. + x = x + self.attn2(self.norm2(x), context=context) # Cross Attn. + x = x + self.ff(self.norm3(x)) + return x diff --git a/lib/smplfusion/patches/transformerpatch/painta.py b/lib/smplfusion/patches/transformerpatch/painta.py new file mode 100644 index 0000000000000000000000000000000000000000..3374a46d4b3ad7f3aeb2ee6b73f8c115a26cd57f --- /dev/null +++ b/lib/smplfusion/patches/transformerpatch/painta.py @@ -0,0 +1,50 @@ +import torch +from torch import nn, einsum +from einops import rearrange, repeat +from ... import share +from ..attentionpatch import painta + + +use_grad = True + +def forward(self, x, context=None): + # Todo: add batch inference support + if use_grad: + y, self_v, self_sim = self.attn1(self.norm1(x), None) # Self Attn. + + x_uncond, x_cond = x.chunk(2) + context_uncond, context_cond = context.chunk(2) + + y_uncond, y_cond = y.chunk(2) + self_sim_uncond, self_sim_cond = self_sim.chunk(2) + self_v_uncond, self_v_cond = self_v.chunk(2) + + # Calculate CA similarities with conditional context + cross_h = self.attn2.heads + cross_q = self.attn2.to_q(self.norm2(x_cond+y_cond)) + cross_k = self.attn2.to_k(context_cond) + cross_v = self.attn2.to_v(context_cond) + + cross_q, cross_k, cross_v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=cross_h), (cross_q, cross_k, cross_v)) + + with torch.autocast(enabled=False, device_type = 'cuda'): + cross_q, cross_k = cross_q.float(), cross_k.float() + cross_sim = einsum('b i d, b j d -> b i j', cross_q, cross_k) * self.attn2.scale + + del cross_q, cross_k + cross_sim = cross_sim.softmax(dim=-1) # Up to this point cross_sim is regular cross_sim in CA layer + + cross_sim = cross_sim.mean(dim=0) # Calculate mean across heads + + # PAIntA rescale + y_cond = painta.painta_rescale( + y_cond, self_v_cond, self_sim_cond, cross_sim, self.attn1.heads, self.attn1.to_out) # Rescale cond + y_uncond = painta.painta_rescale( + y_uncond, self_v_uncond, self_sim_uncond, cross_sim, self.attn1.heads, self.attn1.to_out) # Rescale uncond + + y = torch.cat([y_uncond, y_cond], dim=0) + + x = x + y + x = x + self.attn2(self.norm2(x), context=context) # Cross Attn. + x = x + self.ff(self.norm3(x)) + return x \ No newline at end of file diff --git a/lib/smplfusion/scheduler.py b/lib/smplfusion/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..9e39fefc55f8f8537eca5b395a0d94a60886d094 --- /dev/null +++ b/lib/smplfusion/scheduler.py @@ -0,0 +1,21 @@ +import torch + +def linear(n_timestep = 1000, start = 1e-4, end = 2e-2): + return Schedule(torch.linspace(start ** 0.5, end ** 0.5, n_timestep, dtype = torch.float64) ** 2) + +class Schedule: + def __init__(self, betas): + self.betas = betas + self._alphas = 1 - betas + self.alphas = torch.cumprod(self._alphas, 0) + self.one_minus_alphas = 1 - self.alphas + self.sqrt_alphas = torch.sqrt(self.alphas) + self.sqrt_one_minus_alphas = torch.sqrt(1 - self.alphas) + self.sqrt_noise_signal_ratio = self.sqrt_one_minus_alphas / self.sqrt_alphas + self.noise_signal_ratio = (1 - self.alphas) / self.alphas + + def range(self, dt): + return range(len(self.betas)-1, 0, -dt) + + def sigma(self, t, dt): + return torch.sqrt((1 - self._alphas[t - dt]) / (1 - self._alphas[t]) * (1 - self._alphas[t] / self._alphas[t - dt])) diff --git a/lib/smplfusion/share.py b/lib/smplfusion/share.py new file mode 100644 index 0000000000000000000000000000000000000000..3807657fdb54db02ff21f4831ad8dff0513fe559 --- /dev/null +++ b/lib/smplfusion/share.py @@ -0,0 +1,61 @@ +import torchvision.transforms.functional as TF +from lib.utils.iimage import IImage +import torch +import sys +from .utils import * + +input_mask = None +input_shape = None +timestep = None +timestep_index = None + +class Seed: + def __getitem__(self, idx): + if isinstance(idx, slice): + idx = list(range(*idx.indices(idx.stop))) + if isinstance(idx, list) or isinstance(idx, tuple): + return [self[_idx] for _idx in idx] + return 12345 ** idx % 54321 + +class DDIMIterator: + def __init__(self, iterator): + self.iterator = iterator + def __iter__(self): + self.iterator = iter(self.iterator) + global timestep_index + timestep_index = 0 + return self + def __next__(self): + global timestep, timestep_index + timestep = next(self.iterator) + timestep_index += 1 + return timestep +seed = Seed() +self = sys.modules[__name__] + +def reshape(x): + return input_shape.reshape(x) + +def set_shape(image_or_shape): + global input_shape + # if isinstance(image_or_shape, IImage): + if hasattr(image_or_shape, 'size'): + input_shape = InputShape(image_or_shape.size) + if isinstance(image_or_shape, torch.Tensor): + input_shape = InputShape(image_or_shape.shape[-2:][::-1]) + elif isinstance(image_or_shape, list) or isinstance(image_or_shape, tuple): + input_shape = InputShape(image_or_shape) + +def set_mask(mask): + global input_mask, mask64, mask32, mask16, mask8, painta_mask + input_mask = InputMask(mask) + painta_mask = InputMask(mask) + + mask64 = input_mask.val64[0,0] + mask32 = input_mask.val32[0,0] + mask16 = input_mask.val16[0,0] + mask8 = input_mask.val8[0,0] + set_shape(mask) + +def exists(name): + return hasattr(self, name) and getattr(self, name) is not None \ No newline at end of file diff --git a/lib/smplfusion/util.py b/lib/smplfusion/util.py new file mode 100644 index 0000000000000000000000000000000000000000..5cfadaf9a094146e67a0bc1fe12e4b27e9cfc6f1 --- /dev/null +++ b/lib/smplfusion/util.py @@ -0,0 +1,20 @@ +import importlib +from lib.utils import IImage + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) \ No newline at end of file diff --git a/lib/smplfusion/utils/__init__.py b/lib/smplfusion/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..41c1db38215f0b6367d525dad5c010ccda7f77f3 --- /dev/null +++ b/lib/smplfusion/utils/__init__.py @@ -0,0 +1,3 @@ +from .input_image import InputImage +from .input_mask import InputMask +from .input_shape import InputShape diff --git a/lib/smplfusion/utils/input_image.py b/lib/smplfusion/utils/input_image.py new file mode 100644 index 0000000000000000000000000000000000000000..5b56a7087f9f455d4676bdec304f5533724067aa --- /dev/null +++ b/lib/smplfusion/utils/input_image.py @@ -0,0 +1,62 @@ +import torch +from lib.utils.iimage import IImage + +class InputImage: + def to(self, device): return InputImage(self.image, device = device) + def cuda(self): return InputImage(self.image, device = 'cuda') + def cpu(self): return InputImage(self.image, device = 'cpu') + + def __init__(self, input_image): + ''' + args: + input_image: (b,c,h,w) tensor + ''' + if hasattr(input_image, 'is_iimage'): + self.image = input_image + self.val512 = self.full = input_image.torch(0) + elif isinstance(input_image, torch.Tensor): + self.val512 = self.full = input_image.clone() + self.image = IImage(input_image,0) + + self.h,self.w = h,w = self.val512.shape[-2:] + self.shape = [self.h, self.w] + self.shape64 = [self.h // 8, self.w // 8] + self.shape32 = [self.h // 16, self.w // 16] + self.shape16 = [self.h // 32, self.w // 32] + self.shape8 = [self.h // 64, self.w // 64] + + self.res = self.h * self.w + self.res64 = self.res // 64 + self.res32 = self.res // 64 // 4 + self.res16 = self.res // 64 // 16 + self.res8 = self.res // 64 // 64 + + self.img = self.image + self.img512 = self.image + self.img64 = self.image.resize((h//8,w//8)) + self.img32 = self.image.resize((h//16,w//16)) + self.img16 = self.image.resize((h//32,w//32)) + self.img8 = self.image.resize((h//64,w//64)) + + self.val64 = self.img64.torch() + self.val32 = self.img32.torch() + self.val16 = self.img16.torch() + self.val8 = self.img8.torch() + + def get_res(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return self.val64.to(device) + if q.shape[1] == self.res32: return self.val32.to(device) + if q.shape[1] == self.res16: return self.val16.to(device) + if q.shape[1] == self.res8: return self.val8.to(device) + + def get_shape(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return self.shape64 + if q.shape[1] == self.res32: return self.shape32 + if q.shape[1] == self.res16: return self.shape16 + if q.shape[1] == self.res8: return self.shape8 + + def get_res_val(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return 64 + if q.shape[1] == self.res32: return 32 + if q.shape[1] == self.res16: return 16 + if q.shape[1] == self.res8: return 8 diff --git a/lib/smplfusion/utils/input_mask.py b/lib/smplfusion/utils/input_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..755c297226b100ef2a83860c72ca364f62551ad2 --- /dev/null +++ b/lib/smplfusion/utils/input_mask.py @@ -0,0 +1,137 @@ +import torch +from lib.utils.iimage import IImage + +class InputMask: + def to(self, device): return InputMask(self.image, device = device) + def cuda(self): return InputMask(self.image, device = 'cuda') + def cpu(self): return InputMask(self.image, device = 'cpu') + + def __init__(self, input_image, device = 'cpu'): + ''' + args: + input_image: (b,c,h,w) tensor + ''' + if hasattr(input_image, 'is_iimage'): + self.image = input_image + self.val512 = self.full = (input_image.torch(0) > 0.5).float() + elif isinstance(input_image, torch.Tensor): + self.val512 = self.full = input_image.clone() + self.image = IImage(input_image,0) + + self.h,self.w = h,w = self.val512.shape[-2:] + self.shape = [self.h, self.w] + self.shape64 = [self.h // 8, self.w // 8] + self.shape32 = [self.h // 16, self.w // 16] + self.shape16 = [self.h // 32, self.w // 32] + self.shape8 = [self.h // 64, self.w // 64] + + self.res = self.h * self.w + self.res64 = self.res // 64 + self.res32 = self.res // 64 // 4 + self.res16 = self.res // 64 // 16 + self.res8 = self.res // 64 // 64 + + self.img = self.image + self.img512 = self.image + self.img64 = self.image.resize((h//8,w//8)) + self.img32 = self.image.resize((h//16,w//16)) + self.img16 = self.image.resize((h//32,w//32)) + self.img8 = self.image.resize((h//64,w//64)) + + self.val64 = (self.img64.torch(0) > 0.5).float() + self.val32 = (self.img32.torch(0) > 0.5).float() + self.val16 = (self.img16.torch(0) > 0.5).float() + self.val8 = ( self.img8.torch(0) > 0.5).float() + + + def get_res(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return self.val64.to(device) + if q.shape[1] == self.res32: return self.val32.to(device) + if q.shape[1] == self.res16: return self.val16.to(device) + if q.shape[1] == self.res8: return self.val8.to(device) + + def get_res(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return self.val64.to(device) + if q.shape[1] == self.res32: return self.val32.to(device) + if q.shape[1] == self.res16: return self.val16.to(device) + if q.shape[1] == self.res8: return self.val8.to(device) + + def get_shape(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return self.shape64 + if q.shape[1] == self.res32: return self.shape32 + if q.shape[1] == self.res16: return self.shape16 + if q.shape[1] == self.res8: return self.shape8 + + def get_res_val(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return 64 + if q.shape[1] == self.res32: return 32 + if q.shape[1] == self.res16: return 16 + if q.shape[1] == self.res8: return 8 + + +class InputMask2: + def to(self, device): return InputMask2(self.image, device = device) + def cuda(self): return InputMask2(self.image, device = 'cuda') + def cpu(self): return InputMask2(self.image, device = 'cpu') + + def __init__(self, input_image, device = 'cpu'): + ''' + args: + input_image: (b,c,h,w) tensor + ''' + if hasattr(input_image, 'is_iimage'): + self.image = input_image + self.val512 = self.full = input_image.torch(0).bool().float() + elif isinstance(input_image, torch.Tensor): + self.val512 = self.full = input_image.clone() + self.image = IImage(input_image,0) + + self.h,self.w = h,w = self.val512.shape[-2:] + self.shape = [self.h, self.w] + self.shape64 = [self.h // 8, self.w // 8] + self.shape32 = [self.h // 16, self.w // 16] + self.shape16 = [self.h // 32, self.w // 32] + self.shape8 = [self.h // 64, self.w // 64] + + self.res = self.h * self.w + self.res64 = self.res // 64 + self.res32 = self.res // 64 // 4 + self.res16 = self.res // 64 // 16 + self.res8 = self.res // 64 // 64 + + self.img = self.image + self.img512 = self.image + self.img64 = self.image.resize((h//8,w//8)) + self.img32 = self.image.resize((h//16,w//16)) + self.img16 = self.image.resize((h//32,w//32)).dilate(1) + self.img8 = self.image.resize((h//64,w//64)).dilate(1) + + self.val64 = self.img64.torch(0).bool().float() + self.val32 = self.img32.torch(0).bool().float() + self.val16 = self.img16.torch(0).bool().float() + self.val8 = self.img8.torch(0).bool().float() + + + def get_res(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return self.val64.to(device) + if q.shape[1] == self.res32: return self.val32.to(device) + if q.shape[1] == self.res16: return self.val16.to(device) + if q.shape[1] == self.res8: return self.val8.to(device) + + def get_res(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return self.val64.to(device) + if q.shape[1] == self.res32: return self.val32.to(device) + if q.shape[1] == self.res16: return self.val16.to(device) + if q.shape[1] == self.res8: return self.val8.to(device) + + def get_shape(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return self.shape64 + if q.shape[1] == self.res32: return self.shape32 + if q.shape[1] == self.res16: return self.shape16 + if q.shape[1] == self.res8: return self.shape8 + + def get_res_val(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return 64 + if q.shape[1] == self.res32: return 32 + if q.shape[1] == self.res16: return 16 + if q.shape[1] == self.res8: return 8 \ No newline at end of file diff --git a/lib/smplfusion/utils/input_shape.py b/lib/smplfusion/utils/input_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..abeb21fac1a30d64f2dcbec5991440df7d138055 --- /dev/null +++ b/lib/smplfusion/utils/input_shape.py @@ -0,0 +1,27 @@ +class InputShape: + def __init__(self, image_size): + self.h,self.w = image_size[::-1] + self.res = self.h * self.w + self.res64 = self.res // 64 + self.res32 = self.res // 64 // 4 + self.res16 = self.res // 64 // 16 + self.res8 = self.res // 64 // 64 + self.shape = [self.h, self.w] + self.shape64 = [self.h // 8, self.w // 8] + self.shape32 = [self.h // 16, self.w // 16] + self.shape16 = [self.h // 32, self.w // 32] + self.shape8 = [self.h // 64, self.w // 64] + + def reshape(self, x): + assert len(x.shape) == 3 + if x.shape[1] == self.res64: return x.reshape([x.shape[0]] + self.shape64 + [x.shape[-1]]) + if x.shape[1] == self.res32: return x.reshape([x.shape[0]] + self.shape32 + [x.shape[-1]]) + if x.shape[1] == self.res16: return x.reshape([x.shape[0]] + self.shape16 + [x.shape[-1]]) + if x.shape[1] == self.res8: return x.reshape([x.shape[0]] + self.shape8 + [x.shape[-1]]) + raise Exception("Unknown shape") + + def get_res(self, q, device = 'cpu'): + if q.shape[1] == self.res64: return 64 + if q.shape[1] == self.res32: return 32 + if q.shape[1] == self.res16: return 16 + if q.shape[1] == self.res8: return 8 \ No newline at end of file diff --git a/lib/utils/__init__.py b/lib/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f07d571faa4022a8312a7094ea9a2706bc89e11 --- /dev/null +++ b/lib/utils/__init__.py @@ -0,0 +1,76 @@ +import base64 + +import cv2 +import numpy as np +import open_clip +from PIL import Image +from tqdm import tqdm + +from .iimage import IImage + + +def tokenize(prompt): + tokens = open_clip.tokenize(prompt)[0] + return [open_clip.tokenizer._tokenizer.decoder[x.item()] for x in tokens] + + +def poisson_blend( + orig_img: np.ndarray, + fake_img: np.ndarray, + mask: np.ndarray, + pad_width: int = 32, + dilation: int = 48 +) -> np.ndarray: + """Does poisson blending with some tricks. + + Args: + orig_img (np.ndarray): Original image. + fake_img (np.ndarray): Generated fake image to blend. + mask (np.ndarray): Binary 0-1 mask to use for blending. + pad_width (np.ndarray): Amount of padding to add before blending (useful to avoid some issues). + dilation (np.ndarray): Amount of dilation to add to the mask before blending (useful to avoid some issues). + + Returns: + np.ndarray: Blended image. + """ + mask = mask[:, :, 0] + padding_config = ((pad_width, pad_width), (pad_width, pad_width), (0, 0)) + padded_fake_img = np.pad(fake_img, pad_width=padding_config, mode="reflect") + padded_orig_img = np.pad(orig_img, pad_width=padding_config, mode="reflect") + padded_orig_img[:pad_width, :, :] = padded_fake_img[:pad_width, :, :] + padded_orig_img[:, :pad_width, :] = padded_fake_img[:, :pad_width, :] + padded_orig_img[-pad_width:, :, :] = padded_fake_img[-pad_width:, :, :] + padded_orig_img[:, -pad_width:, :] = padded_fake_img[:, -pad_width:, :] + padded_mask = np.pad(mask, pad_width=padding_config[:2], mode="constant") + padded_dmask = cv2.dilate(padded_mask, np.ones((dilation, dilation), np.uint8), iterations=1) + x_min, y_min, rect_w, rect_h = cv2.boundingRect(padded_dmask) + center = (x_min + rect_w // 2, y_min + rect_h // 2) + output = cv2.seamlessClone(padded_fake_img, padded_orig_img, padded_dmask, center, cv2.NORMAL_CLONE) + output = output[pad_width:-pad_width, pad_width:-pad_width] + return output + + +def image_from_url_text(filedata): + if filedata is None: + return None + + if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False): + filedata = filedata[0] + + if type(filedata) == dict and filedata.get("is_file", False): + filename = filedata["name"] + filename = filename.rsplit('?', 1)[0] + return Image.open(filename) + + if type(filedata) == list: + if len(filedata) == 0: + return None + + filedata = filedata[0] + + if filedata.startswith("data:image/png;base64,"): + filedata = filedata[len("data:image/png;base64,"):] + + filedata = base64.decodebytes(filedata.encode('utf-8')) + image = Image.open(io.BytesIO(filedata)) + return image diff --git a/lib/utils/iimage.py b/lib/utils/iimage.py new file mode 100644 index 0000000000000000000000000000000000000000..a217b853681b8451713149acb2c08392a25ff99a --- /dev/null +++ b/lib/utils/iimage.py @@ -0,0 +1,180 @@ +import io +import math +import os +import warnings + +import PIL.Image +import numpy as np +import cv2 +import torch +import torchvision.transforms.functional as tvF +from scipy.ndimage import binary_dilation + + +def stack(images, axis = 0): + return IImage(np.concatenate([x.data for x in images], axis)) + + +def torch2np(x, vmin=-1, vmax=1): + if x.ndim != 4: + # raise Exception("Please only use (B,C,H,W) torch tensors!") + warnings.warn( + "Warning! Shape of the image was not provided in (B,C,H,W) format, the shape was inferred automatically!") + if x.ndim == 3: + x = x[None] + if x.ndim == 2: + x = x[None, None] + x = x.detach().cpu().float() + if x.dtype == torch.uint8: + return x.numpy().astype(np.uint8) + elif vmin is not None and vmax is not None: + x = (255 * (x.clip(vmin, vmax) - vmin) / (vmax - vmin)) + x = x.permute(0, 2, 3, 1).to(torch.uint8) + return x.numpy() + else: + raise NotImplementedError() + + +class IImage: + @staticmethod + def open(path): + data = np.array(PIL.Image.open(path)) + if data.ndim == 3: + data = data[..., None] + image = IImage(data) + return image + + @staticmethod + def normalized(x, dims=[-1, -2]): + x = (x - x.amin(dims, True)) / \ + (x.amax(dims, True) - x.amin(dims, True)) + return IImage(x, 0) + + def numpy(self): return self.data + + def torch(self, vmin=-1, vmax=1): + if self.data.ndim == 3: + data = self.data.transpose(2, 0, 1) / 255. + else: + data = self.data.transpose(0, 3, 1, 2) / 255. + return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin) + + def cuda(self): + self.device = 'cuda' + return self + + def cpu(self): + self.device = 'cpu' + return self + + def pil(self): + ans = [] + for x in self.data: + if x.shape[-1] == 1: + x = x[..., 0] + + ans.append(PIL.Image.fromarray(x)) + if len(ans) == 1: + return ans[0] + return ans + + def is_iimage(self): + return True + + @property + def shape(self): return self.data.shape + @property + def size(self): return (self.data.shape[-2], self.data.shape[-3]) + + def __init__(self, x, vmin=-1, vmax=1): + if isinstance(x, PIL.Image.Image): + self.data = np.array(x) + if self.data.ndim == 2: + self.data = self.data[..., None] # (H,W,C) + self.data = self.data[None] # (B,H,W,C) + elif isinstance(x, IImage): + self.data = x.data.copy() # Simple Copy + elif isinstance(x, np.ndarray): + self.data = x.copy().astype(np.uint8) + if self.data.ndim == 2: + self.data = self.data[None, ..., None] + if self.data.ndim == 3: + warnings.warn( + "Inferred dimensions for a 3D array as (H,W,C), but could've been (B,H,W)") + self.data = self.data[None] + elif isinstance(x, torch.Tensor): + self.data = torch2np(x, vmin, vmax) + self.device = 'cpu' + + def resize(self, size, *args, **kwargs): + if size is None: + return self + use_small_edge_when_int = kwargs.pop('use_small_edge_when_int', False) + + resample = kwargs.pop('filter', PIL.Image.BICUBIC) # Backward compatibility + resample = kwargs.pop('resample', resample) + + if isinstance(size, int): + if use_small_edge_when_int: + h, w = self.data.shape[1:3] + aspect_ratio = h / w + size = (max(size, int(size * aspect_ratio)), + max(size, int(size / aspect_ratio))) + else: + h, w = self.data.shape[1:3] + aspect_ratio = h / w + size = (min(size, int(size * aspect_ratio)), + min(size, int(size / aspect_ratio))) + + if self.size == size[::-1]: + return self + return stack([IImage(x.pil().resize(size[::-1], *args, resample=resample, **kwargs)) for x in self]) + + def pad(self, padding, *args, **kwargs): + return IImage(tvF.pad(self.torch(0), padding=padding, *args, **kwargs), 0) + + def padx(self, multiplier, *args, **kwargs): + size = np.array(self.size) + padding = np.concatenate( + [[0, 0], np.ceil(size / multiplier).astype(int) * multiplier - size]) + return self.pad(list(padding), *args, **kwargs) + + def pad2wh(self, w=0, h=0, **kwargs): + cw, ch = self.size + return self.pad([0, 0, max(0, w - cw), max(0, h-ch)], **kwargs) + + def pad2square(self, *args, **kwargs): + if self.size[0] > self.size[1]: + dx = self.size[0] - self.size[1] + return self.pad([0, dx//2, 0, dx-dx//2], *args, **kwargs) + elif self.size[0] < self.size[1]: + dx = self.size[1] - self.size[0] + return self.pad([dx//2, 0, dx-dx//2, 0], *args, **kwargs) + return self + + def alpha(self): + return IImage(self.data[..., -1, None]) + + def rgb(self): + return IImage(self.pil().convert('RGB')) + + def dilate(self, iterations=1, *args, **kwargs): + return IImage((binary_dilation(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8)) + + def save(self, path): + _, ext = os.path.splitext(path) + data = self.data if self.data.ndim == 3 else self.data[0] + PIL.Image.fromarray(data).save(path) + return self + + def crop(self, bbox): + assert len(bbox) in [2,4] + if len(bbox) == 2: + x,y = 0,0 + w,h = bbox + elif len(bbox) == 4: + x, y, w, h = bbox + return IImage(self.data[:, y:y+h, x:x+w, :]) + + def __getitem__(self, idx): + return IImage(self.data[None, idx]) diff --git a/lib/utils/scores.py b/lib/utils/scores.py new file mode 100644 index 0000000000000000000000000000000000000000..2b703302cee77c282e88a4fbce7aed8402322ba6 --- /dev/null +++ b/lib/utils/scores.py @@ -0,0 +1,31 @@ +import torch +from torch import nn +import torch.nn.functional as F + +def l1(_crossattn_similarity, mask, token_idx = [1,2]): + similarity = torch.cat(_crossattn_similarity,1)[1] + similarity = similarity.mean(0).permute(2,0,1) + # similarity = similarity.softmax(dim = 0) + + return (similarity[token_idx] * mask.cuda()).sum() + +def bce(_crossattn_similarity, mask, token_idx = [1,2]): + similarity = torch.cat(_crossattn_similarity,1)[1] + similarity = similarity.mean(0).permute(2,0,1) + # similarity = similarity.softmax(dim = 0) + + return -sum([ + F.binary_cross_entropy_with_logits(x - 1.0, mask.cuda()) + for x in similarity[token_idx] + ]) + +def softmax(_crossattn_similarity, mask, token_idx = [1,2]): + similarity = torch.cat(_crossattn_similarity,1)[1] + similarity = similarity.mean(0).permute(2,0,1) + + similarity = similarity[1:].softmax(dim = 0) # Comute the softmax to obtain probability values + token_idx = [x - 1 for x in token_idx] + + score = similarity[token_idx].sum(dim = 0) # Sum up all relevant tokens to get pixel-wise probability of belonging to the correct class + score = torch.log(score) # Obtain log-probabilities per-pixel + return (score * mask.cuda()).sum() # Sum up log-probabilities (equivalent to multiplying P-values) for all pixels inside of the mask \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..aeeac8a517dfaf20a2a72351229744f1cdb76015 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 + +einops==0.7.0 +gradio==3.47.1 +numpy==1.24.1 +omegaconf==2.3.0 +open-clip-torch==2.23.0 +opencv-python==4.7.* +Pillow==9.4.0 +pytorch-lightning==2.1.2 +PyYAML==6.0.1 +safetensors==0.3.2 +scipy==1.10.0 +segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 +torch==2.1.1 +torchvision==0.16.1 +tqdm==4.66.1 +transformers==4.28.0 +xformers==0.0.23 \ No newline at end of file diff --git a/script.js b/script.js new file mode 100644 index 0000000000000000000000000000000000000000..a1d0e4c8bd0f078c2a784aa45a707a731d3b41ff --- /dev/null +++ b/script.js @@ -0,0 +1,39 @@ +function demo_load(x) { + document.body.scrollTop = document.documentElement.scrollTop = 0; + + function gradioApp() { + const elems = document.getElementsByTagName('gradio-app'); + const elem = elems.length == 0 ? document : elems[0]; + + if (elem !== document) { + elem.getElementById = function(id) { + return document.getElementById(id); + }; + } + return elem.shadowRoot ? elem.shadowRoot : elem; + } + + function all_gallery_buttons() { + var allGalleryButtons = gradioApp().querySelectorAll('#outputgallery .thumbnail-item.thumbnail-small'); + var visibleGalleryButtons = []; + allGalleryButtons.forEach(function(elem) { + if (elem.parentElement.offsetParent) { + visibleGalleryButtons.push(elem); + } + }); + return visibleGalleryButtons; + } + + function selected_gallery_button() { + return all_gallery_buttons().find(elem => elem.classList.contains('selected')) ?? null; + } + + function selected_gallery_index() { + return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected')); + } + + window.gradioApp = gradioApp + window.all_gallery_buttons = all_gallery_buttons + window.selected_gallery_button = selected_gallery_button + window.selected_gallery_index = selected_gallery_index +} \ No newline at end of file diff --git a/style.css b/style.css new file mode 100644 index 0000000000000000000000000000000000000000..eeecbd0cceccc5d6b0e1e17eba535e05665800d6 --- /dev/null +++ b/style.css @@ -0,0 +1,32 @@ +/* Extra small devices (phones, 768px and down) */ +@media only screen and (max-width: 768px) { + #inputmask { + height: 400px !important; + } +} + +/* Small devices (portrait tablets and large phones, 768px and up) */ +@media only screen and (min-width: 768px) { + +} + +/* Medium devices (landscape tablets, 992px and up) */ +@media only screen and (min-width: 992px) { + #inputmask { + height: 300px !important; + } +} + +/* Large devices (laptops/desktops, 1200px and up) */ +@media only screen and (min-width: 1200px) { + #inputmask { + height: 400px !important; + } +} + +/* Extra large devices (large laptops and desktops, 1400px and up) */ +@media only screen and (min-width: 1400px) { + #inputmask { + height: 400px !important; + } +} \ No newline at end of file