Spaces:
Running
on
L4
Running
on
L4
Initial commit
Browse files- .gitattributes +2 -0
- LICENSE.md +51 -0
- README.md +11 -5
- app.py +357 -0
- demo_files/comp.gif +3 -0
- demo_files/examples/animal_character.png +3 -0
- demo_files/examples/animal_character_2.png +3 -0
- demo_files/examples/axe.png +3 -0
- demo_files/examples/chair1.png +3 -0
- demo_files/examples/character1.png +3 -0
- demo_files/examples/otter_samurai.png +3 -0
- demo_files/examples/raccoon_wizard.png +3 -0
- demo_files/examples/stylized-rocks.png +3 -0
- demo_files/examples/tree.png +3 -0
- demo_files/hdri/abandoned_tiled_room_1k.hdr +0 -0
- demo_files/hdri/metro_noord_1k.hdr +0 -0
- demo_files/hdri/neon_photostudio_1k.hdr +0 -0
- demo_files/hdri/peppermint_powerplant_1k.hdr +0 -0
- demo_files/hdri/rainforest_trail_1k.hdr +0 -0
- demo_files/hdri/studio_small_08_1k.hdr +0 -0
- demo_files/hdri/urban_alley_01_1k.hdr +0 -0
- demo_files/scatterplot.jpg +0 -0
- demo_files/teaser.gif +3 -0
- load/tets/160_tets.npz +3 -0
- requirements.txt +13 -0
- sf3d/box_uv_unwrap.py +610 -0
- sf3d/models/camera.py +32 -0
- sf3d/models/global_estimator/multi_head_estimator.py +118 -0
- sf3d/models/image_estimator/clip_based_estimator.py +168 -0
- sf3d/models/isosurface.py +229 -0
- sf3d/models/mesh.py +172 -0
- sf3d/models/network.py +195 -0
- sf3d/models/tokenizers/dinov2.py +1196 -0
- sf3d/models/tokenizers/image.py +99 -0
- sf3d/models/tokenizers/triplane.py +49 -0
- sf3d/models/transformers/attention.py +31 -0
- sf3d/models/transformers/backbone.py +515 -0
- sf3d/models/utils.py +292 -0
- sf3d/system.py +483 -0
- sf3d/texture_baker.py +87 -0
- sf3d/texture_baker.slang +93 -0
- sf3d/utils.py +91 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE.md
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
STABILITY AI COMMUNITY LICENSE AGREEMENT
|
2 |
+
Last Updated: July 5, 2024
|
3 |
+
|
4 |
+
|
5 |
+
I. INTRODUCTION
|
6 |
+
|
7 |
+
This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
|
8 |
+
|
9 |
+
|
10 |
+
This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
|
11 |
+
|
12 |
+
|
13 |
+
By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.
|
14 |
+
|
15 |
+
II. RESEARCH & NON-COMMERCIAL USE LICENSE
|
16 |
+
|
17 |
+
Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
|
18 |
+
|
19 |
+
III. COMMERCIAL USE LICENSE
|
20 |
+
|
21 |
+
Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations.
|
22 |
+
If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
|
23 |
+
|
24 |
+
IV. GENERAL TERMS
|
25 |
+
|
26 |
+
Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
|
27 |
+
a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified.
|
28 |
+
b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
|
29 |
+
c. Intellectual Property.
|
30 |
+
(i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
|
31 |
+
(ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
|
32 |
+
(iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
|
33 |
+
(iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
|
34 |
+
(v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback.
|
35 |
+
d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
|
36 |
+
e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
37 |
+
f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
|
38 |
+
g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
|
39 |
+
|
40 |
+
V. DEFINITIONS
|
41 |
+
|
42 |
+
"Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
|
43 |
+
"Agreement" means this Stability AI Community License Agreement.
|
44 |
+
"AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
|
45 |
+
"Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including"fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model.
|
46 |
+
"Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
|
47 |
+
"Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time.
|
48 |
+
"Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
|
49 |
+
"Software" means Stability AI's proprietary software made available under this Agreement now or in the future.
|
50 |
+
"Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
|
51 |
+
"Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
|
README.md
CHANGED
@@ -1,12 +1,18 @@
|
|
1 |
---
|
2 |
-
title: Stable Fast
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: Stable Fast 3D
|
3 |
+
emoji: 🎮
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.31.4
|
8 |
+
python_version: 3.10.13
|
9 |
app_file: app.py
|
10 |
pinned: false
|
11 |
+
models:
|
12 |
+
- stabilityai/stable-fast-3d
|
13 |
+
license: other
|
14 |
+
license_name: stabilityai-ai-community
|
15 |
+
license_link: LICENSE.md
|
16 |
---
|
17 |
|
18 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import time
|
4 |
+
from functools import lru_cache
|
5 |
+
from typing import Any
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import numpy as np
|
9 |
+
import rembg
|
10 |
+
import torch
|
11 |
+
from gradio_litmodel3d import LitModel3D
|
12 |
+
import spaces
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
import sf3d.utils as sf3d_utils
|
16 |
+
from sf3d.system import SF3D
|
17 |
+
|
18 |
+
rembg_session = rembg.new_session()
|
19 |
+
|
20 |
+
COND_WIDTH = 512
|
21 |
+
COND_HEIGHT = 512
|
22 |
+
COND_DISTANCE = 1.6
|
23 |
+
COND_FOVY_DEG = 40
|
24 |
+
BACKGROUND_COLOR = [0.5, 0.5, 0.5]
|
25 |
+
|
26 |
+
# Cached. Doesn't change
|
27 |
+
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
|
28 |
+
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
|
29 |
+
COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
model = SF3D.from_pretrained(
|
34 |
+
"stabilityai/stable-fast-3d",
|
35 |
+
config_name="config.yaml",
|
36 |
+
weight_name="model.safetensors",
|
37 |
+
)
|
38 |
+
model.eval().cuda()
|
39 |
+
|
40 |
+
example_files = [
|
41 |
+
os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
|
42 |
+
]
|
43 |
+
|
44 |
+
|
45 |
+
@spaces.GPU
|
46 |
+
def run_model(input_image):
|
47 |
+
start = time.time()
|
48 |
+
with torch.no_grad():
|
49 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
50 |
+
model_batch = create_batch(input_image)
|
51 |
+
model_batch = {k: v.cuda() for k, v in model_batch.items()}
|
52 |
+
trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
|
53 |
+
trimesh_mesh = trimesh_mesh[0]
|
54 |
+
|
55 |
+
# Create new tmp file
|
56 |
+
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
|
57 |
+
|
58 |
+
trimesh_mesh.export(tmp_file.name, file_type="glb")
|
59 |
+
|
60 |
+
print("Generation took:", time.time() - start, "s")
|
61 |
+
|
62 |
+
return tmp_file.name
|
63 |
+
|
64 |
+
|
65 |
+
def create_batch(input_image: Image) -> dict[str, Any]:
|
66 |
+
img_cond = (
|
67 |
+
torch.from_numpy(
|
68 |
+
np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
|
69 |
+
/ 255.0
|
70 |
+
)
|
71 |
+
.float()
|
72 |
+
.clip(0, 1)
|
73 |
+
)
|
74 |
+
mask_cond = img_cond[:, :, -1:]
|
75 |
+
rgb_cond = torch.lerp(
|
76 |
+
torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
|
77 |
+
)
|
78 |
+
|
79 |
+
batch_elem = {
|
80 |
+
"rgb_cond": rgb_cond,
|
81 |
+
"mask_cond": mask_cond,
|
82 |
+
"c2w_cond": c2w_cond.unsqueeze(0),
|
83 |
+
"intrinsic_cond": intrinsic.unsqueeze(0),
|
84 |
+
"intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
|
85 |
+
}
|
86 |
+
# Add batch dim
|
87 |
+
batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
|
88 |
+
return batched
|
89 |
+
|
90 |
+
|
91 |
+
@lru_cache
|
92 |
+
def checkerboard(squares: int, size: int, min_value: float = 0.5):
|
93 |
+
base = np.zeros((squares, squares)) + min_value
|
94 |
+
base[1::2, ::2] = 1
|
95 |
+
base[::2, 1::2] = 1
|
96 |
+
|
97 |
+
repeat_mult = size // squares
|
98 |
+
return (
|
99 |
+
base.repeat(repeat_mult, axis=0)
|
100 |
+
.repeat(repeat_mult, axis=1)[:, :, None]
|
101 |
+
.repeat(3, axis=-1)
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
def remove_background(input_image: Image) -> Image:
|
106 |
+
return rembg.remove(input_image, session=rembg_session)
|
107 |
+
|
108 |
+
|
109 |
+
def resize_foreground(
|
110 |
+
image: Image,
|
111 |
+
ratio: float,
|
112 |
+
) -> Image:
|
113 |
+
image = np.array(image)
|
114 |
+
assert image.shape[-1] == 4
|
115 |
+
alpha = np.where(image[..., 3] > 0)
|
116 |
+
y1, y2, x1, x2 = (
|
117 |
+
alpha[0].min(),
|
118 |
+
alpha[0].max(),
|
119 |
+
alpha[1].min(),
|
120 |
+
alpha[1].max(),
|
121 |
+
)
|
122 |
+
# crop the foreground
|
123 |
+
fg = image[y1:y2, x1:x2]
|
124 |
+
# pad to square
|
125 |
+
size = max(fg.shape[0], fg.shape[1])
|
126 |
+
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
127 |
+
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
128 |
+
new_image = np.pad(
|
129 |
+
fg,
|
130 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
131 |
+
mode="constant",
|
132 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
133 |
+
)
|
134 |
+
|
135 |
+
# compute padding according to the ratio
|
136 |
+
new_size = int(new_image.shape[0] / ratio)
|
137 |
+
# pad to size, double side
|
138 |
+
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
139 |
+
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
140 |
+
new_image = np.pad(
|
141 |
+
new_image,
|
142 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
143 |
+
mode="constant",
|
144 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
145 |
+
)
|
146 |
+
new_image = Image.fromarray(new_image, mode="RGBA").resize(
|
147 |
+
(COND_WIDTH, COND_HEIGHT)
|
148 |
+
)
|
149 |
+
return new_image
|
150 |
+
|
151 |
+
|
152 |
+
def square_crop(input_image: Image) -> Image:
|
153 |
+
# Perform a center square crop
|
154 |
+
min_size = min(input_image.size)
|
155 |
+
left = (input_image.size[0] - min_size) // 2
|
156 |
+
top = (input_image.size[1] - min_size) // 2
|
157 |
+
right = (input_image.size[0] + min_size) // 2
|
158 |
+
bottom = (input_image.size[1] + min_size) // 2
|
159 |
+
return input_image.crop((left, top, right, bottom)).resize(
|
160 |
+
(COND_WIDTH, COND_HEIGHT)
|
161 |
+
)
|
162 |
+
|
163 |
+
|
164 |
+
def show_mask_img(input_image: Image) -> Image:
|
165 |
+
img_numpy = np.array(input_image)
|
166 |
+
alpha = img_numpy[:, :, 3] / 255.0
|
167 |
+
chkb = checkerboard(32, 512) * 255
|
168 |
+
new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
|
169 |
+
return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
|
170 |
+
|
171 |
+
|
172 |
+
def run_button(run_btn, input_image, background_state, foreground_ratio):
|
173 |
+
if run_btn == "Run":
|
174 |
+
glb_file: str = run_model(background_state)
|
175 |
+
|
176 |
+
return (
|
177 |
+
gr.update(),
|
178 |
+
gr.update(),
|
179 |
+
gr.update(),
|
180 |
+
gr.update(),
|
181 |
+
gr.update(value=glb_file, visible=True),
|
182 |
+
gr.update(visible=True),
|
183 |
+
)
|
184 |
+
elif run_btn == "Remove Background":
|
185 |
+
rem_removed = remove_background(input_image)
|
186 |
+
|
187 |
+
sqr_crop = square_crop(rem_removed)
|
188 |
+
fr_res = resize_foreground(sqr_crop, foreground_ratio)
|
189 |
+
|
190 |
+
return (
|
191 |
+
gr.update(value="Run", visible=True),
|
192 |
+
sqr_crop,
|
193 |
+
fr_res,
|
194 |
+
gr.update(value=show_mask_img(fr_res), visible=True),
|
195 |
+
gr.update(value=None, visible=False),
|
196 |
+
gr.update(visible=False),
|
197 |
+
)
|
198 |
+
|
199 |
+
|
200 |
+
def requires_bg_remove(image, fr):
|
201 |
+
if image is None:
|
202 |
+
return (
|
203 |
+
gr.update(visible=False, value="Run"),
|
204 |
+
None,
|
205 |
+
None,
|
206 |
+
gr.update(value=None, visible=False),
|
207 |
+
gr.update(visible=False),
|
208 |
+
gr.update(visible=False),
|
209 |
+
)
|
210 |
+
alpha_channel = np.array(image.getchannel("A"))
|
211 |
+
min_alpha = alpha_channel.min()
|
212 |
+
|
213 |
+
if min_alpha == 0:
|
214 |
+
print("Already has alpha")
|
215 |
+
sqr_crop = square_crop(image)
|
216 |
+
fr_res = resize_foreground(sqr_crop, fr)
|
217 |
+
return (
|
218 |
+
gr.update(value="Run", visible=True),
|
219 |
+
sqr_crop,
|
220 |
+
fr_res,
|
221 |
+
gr.update(value=show_mask_img(fr_res), visible=True),
|
222 |
+
gr.update(visible=False),
|
223 |
+
gr.update(visible=False),
|
224 |
+
)
|
225 |
+
return (
|
226 |
+
gr.update(value="Remove Background", visible=True),
|
227 |
+
None,
|
228 |
+
None,
|
229 |
+
gr.update(value=None, visible=False),
|
230 |
+
gr.update(visible=False),
|
231 |
+
gr.update(visible=False),
|
232 |
+
)
|
233 |
+
|
234 |
+
|
235 |
+
def update_foreground_ratio(img_proc, fr):
|
236 |
+
foreground_res = resize_foreground(img_proc, fr)
|
237 |
+
return (
|
238 |
+
foreground_res,
|
239 |
+
gr.update(value=show_mask_img(foreground_res)),
|
240 |
+
)
|
241 |
+
|
242 |
+
|
243 |
+
with gr.Blocks() as demo:
|
244 |
+
img_proc_state = gr.State()
|
245 |
+
background_remove_state = gr.State()
|
246 |
+
gr.Markdown("""
|
247 |
+
# SF3D: Stable Fast 3D Mesh Reconstruction with UV-unwrapping and Illumination Disentanglement
|
248 |
+
|
249 |
+
**SF3D** is a state-of-the-art method for 3D mesh reconstruction from a single image.
|
250 |
+
This demo allows you to upload an image and generate a 3D mesh model from it.
|
251 |
+
|
252 |
+
**Tips**
|
253 |
+
1. If the image already has an alpha channel, you can skip the background removal step.
|
254 |
+
2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the shape
|
255 |
+
3. You can upload your own HDR environment map to light the 3D model.
|
256 |
+
""")
|
257 |
+
with gr.Row(variant="panel"):
|
258 |
+
with gr.Column():
|
259 |
+
with gr.Row():
|
260 |
+
input_img = gr.Image(
|
261 |
+
type="pil", label="Input Image", sources="upload", image_mode="RGBA"
|
262 |
+
)
|
263 |
+
preview_removal = gr.Image(
|
264 |
+
label="Preview Background Removal",
|
265 |
+
type="pil",
|
266 |
+
image_mode="RGB",
|
267 |
+
interactive=False,
|
268 |
+
visible=False,
|
269 |
+
)
|
270 |
+
|
271 |
+
foreground_ratio = gr.Slider(
|
272 |
+
label="Foreground Ratio",
|
273 |
+
minimum=0.5,
|
274 |
+
maximum=1.0,
|
275 |
+
value=0.85,
|
276 |
+
step=0.05,
|
277 |
+
)
|
278 |
+
|
279 |
+
foreground_ratio.change(
|
280 |
+
update_foreground_ratio,
|
281 |
+
inputs=[img_proc_state, foreground_ratio],
|
282 |
+
outputs=[background_remove_state, preview_removal],
|
283 |
+
)
|
284 |
+
|
285 |
+
run_btn = gr.Button("Run", variant="primary", visible=False)
|
286 |
+
|
287 |
+
with gr.Column():
|
288 |
+
output_3d = LitModel3D(
|
289 |
+
label="3D Model",
|
290 |
+
visible=False,
|
291 |
+
clear_color=[0.0, 0.0, 0.0, 0.0],
|
292 |
+
tonemapping="aces",
|
293 |
+
contrast=1.0,
|
294 |
+
scale=1.0,
|
295 |
+
)
|
296 |
+
with gr.Column(visible=False, scale=1.0) as hdr_row:
|
297 |
+
gr.Markdown("""## HDR Environment Map
|
298 |
+
|
299 |
+
Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
|
300 |
+
""")
|
301 |
+
|
302 |
+
with gr.Row():
|
303 |
+
hdr_illumination_file = gr.File(
|
304 |
+
label="HDR Env Map", file_types=[".hdr"], file_count="single"
|
305 |
+
)
|
306 |
+
example_hdris = [
|
307 |
+
os.path.join("demo_files/hdri", f)
|
308 |
+
for f in os.listdir("demo_files/hdri")
|
309 |
+
]
|
310 |
+
hdr_illumination_example = gr.Examples(
|
311 |
+
examples=example_hdris,
|
312 |
+
inputs=hdr_illumination_file,
|
313 |
+
)
|
314 |
+
|
315 |
+
hdr_illumination_file.change(
|
316 |
+
lambda x: gr.update(env_map=x.name if x is not None else None),
|
317 |
+
inputs=hdr_illumination_file,
|
318 |
+
outputs=[output_3d],
|
319 |
+
)
|
320 |
+
|
321 |
+
examples = gr.Examples(
|
322 |
+
examples=example_files,
|
323 |
+
inputs=input_img,
|
324 |
+
)
|
325 |
+
|
326 |
+
input_img.change(
|
327 |
+
requires_bg_remove,
|
328 |
+
inputs=[input_img, foreground_ratio],
|
329 |
+
outputs=[
|
330 |
+
run_btn,
|
331 |
+
img_proc_state,
|
332 |
+
background_remove_state,
|
333 |
+
preview_removal,
|
334 |
+
output_3d,
|
335 |
+
hdr_row,
|
336 |
+
],
|
337 |
+
)
|
338 |
+
|
339 |
+
run_btn.click(
|
340 |
+
run_button,
|
341 |
+
inputs=[
|
342 |
+
run_btn,
|
343 |
+
input_img,
|
344 |
+
background_remove_state,
|
345 |
+
foreground_ratio,
|
346 |
+
],
|
347 |
+
outputs=[
|
348 |
+
run_btn,
|
349 |
+
img_proc_state,
|
350 |
+
background_remove_state,
|
351 |
+
preview_removal,
|
352 |
+
output_3d,
|
353 |
+
hdr_row,
|
354 |
+
],
|
355 |
+
)
|
356 |
+
|
357 |
+
demo.launch()
|
demo_files/comp.gif
ADDED
Git LFS Details
|
demo_files/examples/animal_character.png
ADDED
Git LFS Details
|
demo_files/examples/animal_character_2.png
ADDED
Git LFS Details
|
demo_files/examples/axe.png
ADDED
Git LFS Details
|
demo_files/examples/chair1.png
ADDED
Git LFS Details
|
demo_files/examples/character1.png
ADDED
Git LFS Details
|
demo_files/examples/otter_samurai.png
ADDED
Git LFS Details
|
demo_files/examples/raccoon_wizard.png
ADDED
Git LFS Details
|
demo_files/examples/stylized-rocks.png
ADDED
Git LFS Details
|
demo_files/examples/tree.png
ADDED
Git LFS Details
|
demo_files/hdri/abandoned_tiled_room_1k.hdr
ADDED
Binary file (478 kB). View file
|
|
demo_files/hdri/metro_noord_1k.hdr
ADDED
Binary file (467 kB). View file
|
|
demo_files/hdri/neon_photostudio_1k.hdr
ADDED
Binary file (438 kB). View file
|
|
demo_files/hdri/peppermint_powerplant_1k.hdr
ADDED
Binary file (473 kB). View file
|
|
demo_files/hdri/rainforest_trail_1k.hdr
ADDED
Binary file (512 kB). View file
|
|
demo_files/hdri/studio_small_08_1k.hdr
ADDED
Binary file (412 kB). View file
|
|
demo_files/hdri/urban_alley_01_1k.hdr
ADDED
Binary file (458 kB). View file
|
|
demo_files/scatterplot.jpg
ADDED
demo_files/teaser.gif
ADDED
Git LFS Details
|
load/tets/160_tets.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
|
3 |
+
size 15408790
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.1.2
|
2 |
+
torchvision==0.16.2
|
3 |
+
einops==0.7.0
|
4 |
+
jaxtyping==0.2.31
|
5 |
+
omegaconf==2.3.0
|
6 |
+
transformers==4.42.3
|
7 |
+
slangtorch==1.2.2
|
8 |
+
open_clip_torch==2.24.0
|
9 |
+
trimesh==4.4.1
|
10 |
+
numpy==1.26.4
|
11 |
+
huggingface-hub==0.23.4
|
12 |
+
rembg[gpu]==2.0.57
|
13 |
+
gradio-litmodel3d==0.0.1
|
sf3d/box_uv_unwrap.py
ADDED
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from jaxtyping import Float, Integer
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
from sf3d.models.utils import dot, triangle_intersection_2d
|
10 |
+
|
11 |
+
|
12 |
+
def _box_assign_vertex_to_cube_face(
|
13 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
14 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
15 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
16 |
+
bbox: Float[Tensor, "2 3"],
|
17 |
+
) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]:
|
18 |
+
# Test to not have a scaled model to fit the space better
|
19 |
+
# bbox_min = bbox[:1].mean(-1, keepdim=True)
|
20 |
+
# bbox_max = bbox[1:].mean(-1, keepdim=True)
|
21 |
+
# v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
|
22 |
+
|
23 |
+
# Create a [0, 1] normalized vertex position
|
24 |
+
v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
|
25 |
+
# And to [-1, 1]
|
26 |
+
v_pos_normalized = 2.0 * v_pos_normalized - 1.0
|
27 |
+
|
28 |
+
# Get all vertex positions for each triangle
|
29 |
+
# Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
|
30 |
+
v0 = v_pos_normalized[triangle_idxs[:, 0]]
|
31 |
+
v1 = v_pos_normalized[triangle_idxs[:, 1]]
|
32 |
+
v2 = v_pos_normalized[triangle_idxs[:, 2]]
|
33 |
+
tri_stack = torch.stack([v0, v1, v2], dim=1)
|
34 |
+
|
35 |
+
vn0 = vertex_normals[triangle_idxs[:, 0]]
|
36 |
+
vn1 = vertex_normals[triangle_idxs[:, 1]]
|
37 |
+
vn2 = vertex_normals[triangle_idxs[:, 2]]
|
38 |
+
tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
|
39 |
+
|
40 |
+
# Just average the normals per face
|
41 |
+
face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
|
42 |
+
|
43 |
+
# Now decide based on the face normal in which box map we project
|
44 |
+
# abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
|
45 |
+
abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
|
46 |
+
|
47 |
+
axis = torch.tensor(
|
48 |
+
[
|
49 |
+
[1, 0, 0], # 0
|
50 |
+
[-1, 0, 0], # 1
|
51 |
+
[0, 1, 0], # 2
|
52 |
+
[0, -1, 0], # 3
|
53 |
+
[0, 0, 1], # 4
|
54 |
+
[0, 0, -1], # 5
|
55 |
+
],
|
56 |
+
device=face_normal.device,
|
57 |
+
dtype=face_normal.dtype,
|
58 |
+
)
|
59 |
+
face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
|
60 |
+
index = face_normal_axis.argmax(-1)
|
61 |
+
|
62 |
+
max_axis, uc, vc = (
|
63 |
+
torch.ones_like(abs_x),
|
64 |
+
torch.zeros_like(tri_stack[..., :1]),
|
65 |
+
torch.zeros_like(tri_stack[..., :1]),
|
66 |
+
)
|
67 |
+
mask_pos_x = index == 0
|
68 |
+
max_axis[mask_pos_x] = abs_x[mask_pos_x]
|
69 |
+
uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
|
70 |
+
vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
|
71 |
+
|
72 |
+
mask_neg_x = index == 1
|
73 |
+
max_axis[mask_neg_x] = abs_x[mask_neg_x]
|
74 |
+
uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
|
75 |
+
vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
|
76 |
+
|
77 |
+
mask_pos_y = index == 2
|
78 |
+
max_axis[mask_pos_y] = abs_y[mask_pos_y]
|
79 |
+
uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
|
80 |
+
vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
|
81 |
+
|
82 |
+
mask_neg_y = index == 3
|
83 |
+
max_axis[mask_neg_y] = abs_y[mask_neg_y]
|
84 |
+
uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
|
85 |
+
vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
|
86 |
+
|
87 |
+
mask_pos_z = index == 4
|
88 |
+
max_axis[mask_pos_z] = abs_z[mask_pos_z]
|
89 |
+
uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
|
90 |
+
vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
|
91 |
+
|
92 |
+
mask_neg_z = index == 5
|
93 |
+
max_axis[mask_neg_z] = abs_z[mask_neg_z]
|
94 |
+
uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
|
95 |
+
vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
|
96 |
+
|
97 |
+
# UC from [-1, 1] to [0, 1]
|
98 |
+
max_dim_div = max_axis.max(dim=0, keepdims=True).values
|
99 |
+
uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
100 |
+
vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
101 |
+
|
102 |
+
uv = torch.stack([uc, vc], dim=-1)
|
103 |
+
|
104 |
+
return uv, index
|
105 |
+
|
106 |
+
|
107 |
+
def _assign_faces_uv_to_atlas_index(
|
108 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
109 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
110 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
111 |
+
face_index: Integer[Tensor, "Nf 3"],
|
112 |
+
) -> Integer[Tensor, "Nf"]: # noqa: F821
|
113 |
+
triangle_pos = vertex_positions[triangle_idxs]
|
114 |
+
# We need to do perform 3 overlap checks.
|
115 |
+
# The first set is placed in the upper two thirds of the UV atlas.
|
116 |
+
# Conceptually, this is the direct visible surfaces from the each cube side
|
117 |
+
# The second set is placed in the lower thirds and the left half of the UV atlas.
|
118 |
+
# This is the first set of occluded surfaces. They will also be saved in the projected fashion
|
119 |
+
# The third pass finds all non assigned faces. They will be placed in the bottom right half of
|
120 |
+
# the UV atlas in scattered fashion.
|
121 |
+
assign_idx = face_index.clone()
|
122 |
+
for overlap_step in range(3):
|
123 |
+
overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
|
124 |
+
for i in range(overlap_step * 6, (overlap_step + 1) * 6):
|
125 |
+
mask = assign_idx == i
|
126 |
+
if not mask.any():
|
127 |
+
continue
|
128 |
+
# Get all elements belonging to the projection face
|
129 |
+
uv_triangle = face_uv[mask]
|
130 |
+
cur_triangle_pos = triangle_pos[mask]
|
131 |
+
# Find the center of the uv coordinates
|
132 |
+
center_uv = uv_triangle.mean(dim=1, keepdim=True)
|
133 |
+
# And also the radius of the triangle
|
134 |
+
uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values
|
135 |
+
|
136 |
+
potentially_overlapping_mask = (
|
137 |
+
# Find all close triangles
|
138 |
+
(center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
|
139 |
+
# Do not select the same element by offseting with an large valued identity matrix
|
140 |
+
+ torch.eye(
|
141 |
+
uv_triangle.shape[0],
|
142 |
+
device=uv_triangle.device,
|
143 |
+
dtype=uv_triangle.dtype,
|
144 |
+
).unsqueeze(-1)
|
145 |
+
* 1000
|
146 |
+
)
|
147 |
+
# Mark all potentially overlapping triangles to reduce the number of triangle intersection tests
|
148 |
+
potentially_overlapping_mask = (
|
149 |
+
potentially_overlapping_mask
|
150 |
+
<= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
|
151 |
+
).squeeze(-1)
|
152 |
+
overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)
|
153 |
+
|
154 |
+
# Only unique triangles (A|B and B|A should be the same)
|
155 |
+
f = torch.min(overlap_coords, dim=-1).values
|
156 |
+
s = torch.max(overlap_coords, dim=-1).values
|
157 |
+
overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
|
158 |
+
first, second = overlap_coords.unbind(-1)
|
159 |
+
|
160 |
+
# Get the triangles
|
161 |
+
tri_1 = uv_triangle[first]
|
162 |
+
tri_2 = uv_triangle[second]
|
163 |
+
|
164 |
+
# Perform the actual set with the reduced number of potentially overlapping triangles
|
165 |
+
its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)
|
166 |
+
|
167 |
+
# So we now need to detect which triangles are the occluded ones.
|
168 |
+
# We always assume the first to be the visible one (the others should move)
|
169 |
+
# In the previous step we use a lexigraphical sort to get the unique pairs
|
170 |
+
# In this we use a sort based on the orthographic projection
|
171 |
+
ax = 0 if i < 2 else 1 if i < 4 else 2
|
172 |
+
use_max = i % 2 == 1
|
173 |
+
|
174 |
+
tri1_c = cur_triangle_pos[first].mean(dim=1)
|
175 |
+
tri2_c = cur_triangle_pos[second].mean(dim=1)
|
176 |
+
|
177 |
+
mark_first = (
|
178 |
+
(tri1_c[..., ax] > tri2_c[..., ax])
|
179 |
+
if use_max
|
180 |
+
else (tri1_c[..., ax] < tri2_c[..., ax])
|
181 |
+
)
|
182 |
+
first[mark_first] = second[mark_first]
|
183 |
+
|
184 |
+
# Lastly the same index can be tested multiple times.
|
185 |
+
# If one marks it as overlapping we keep it marked as such.
|
186 |
+
# We do this by testing if it has been marked at least once.
|
187 |
+
unique_idx, rev_idx = torch.unique(first, return_inverse=True)
|
188 |
+
|
189 |
+
add = torch.zeros_like(unique_idx, dtype=torch.float32)
|
190 |
+
add.index_add_(0, rev_idx, its.float())
|
191 |
+
its_mask = add > 0
|
192 |
+
|
193 |
+
# And fill it in the overlapping indicator
|
194 |
+
idx = torch.where(mask)[0][unique_idx]
|
195 |
+
overlapping_indicator[idx] = its_mask
|
196 |
+
|
197 |
+
# Move the index to the overlap regions (shift by 6)
|
198 |
+
assign_idx[overlapping_indicator] += 6
|
199 |
+
|
200 |
+
# We do not care about the correct face placement after the first 2 slices
|
201 |
+
max_idx = 6 * 2
|
202 |
+
return assign_idx.clamp(0, max_idx)
|
203 |
+
|
204 |
+
|
205 |
+
def _find_slice_offset_and_scale(
|
206 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
207 |
+
) -> Tuple[
|
208 |
+
Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"] # noqa: F821
|
209 |
+
]: # noqa: F821
|
210 |
+
# 6 due to the 6 cube faces
|
211 |
+
off = 1 / 3
|
212 |
+
dupl_off = 1 / 6
|
213 |
+
|
214 |
+
# Here, we need to decide how to pack the textures in the case of overlap
|
215 |
+
def x_offset_calc(x, i):
|
216 |
+
offset_calc = i // 6
|
217 |
+
# Initial coordinates - just 3x2 grid
|
218 |
+
if offset_calc == 0:
|
219 |
+
return off * x
|
220 |
+
else:
|
221 |
+
# Smaller 3x2 grid plus eventual shift to right for
|
222 |
+
# second overlap
|
223 |
+
return dupl_off * x + min(offset_calc - 1, 1) * 0.5
|
224 |
+
|
225 |
+
def y_offset_calc(x, i):
|
226 |
+
offset_calc = i // 6
|
227 |
+
# Initial coordinates - just a 3x2 grid
|
228 |
+
if offset_calc == 0:
|
229 |
+
return off * x
|
230 |
+
else:
|
231 |
+
# Smaller coordinates in the lowest row
|
232 |
+
return dupl_off * x + off * 2
|
233 |
+
|
234 |
+
offset_x = torch.zeros_like(index, dtype=torch.float32)
|
235 |
+
offset_y = torch.zeros_like(index, dtype=torch.float32)
|
236 |
+
offset_x_vals = [0, 1, 2, 0, 1, 2]
|
237 |
+
offset_y_vals = [0, 0, 0, 1, 1, 1]
|
238 |
+
for i in range(index.max().item() + 1):
|
239 |
+
mask = index == i
|
240 |
+
if not mask.any():
|
241 |
+
continue
|
242 |
+
offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
|
243 |
+
offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
|
244 |
+
|
245 |
+
div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
|
246 |
+
# All overlap elements are saved in half scale
|
247 |
+
div_x[index >= 6] = 6
|
248 |
+
div_y = div_x.clone() # Same for y
|
249 |
+
# Except for the random overlaps
|
250 |
+
div_x[index >= 12] = 2
|
251 |
+
# But the random overlaps are saved in a large block in the lower thirds
|
252 |
+
div_y[index >= 12] = 3
|
253 |
+
|
254 |
+
return offset_x, offset_y, div_x, div_y
|
255 |
+
|
256 |
+
|
257 |
+
def rotation_flip_matrix_2d(
|
258 |
+
rad: float, flip_x: bool = False, flip_y: bool = False
|
259 |
+
) -> Float[Tensor, "2 2"]:
|
260 |
+
cos = math.cos(rad)
|
261 |
+
sin = math.sin(rad)
|
262 |
+
rot_mat = torch.tensor([[cos, -sin], [sin, cos]], dtype=torch.float32)
|
263 |
+
flip_mat = torch.tensor(
|
264 |
+
[
|
265 |
+
[-1 if flip_x else 1, 0],
|
266 |
+
[0, -1 if flip_y else 1],
|
267 |
+
],
|
268 |
+
dtype=torch.float32,
|
269 |
+
)
|
270 |
+
|
271 |
+
return flip_mat @ rot_mat
|
272 |
+
|
273 |
+
|
274 |
+
def calculate_tangents(
|
275 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
276 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
277 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
278 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
279 |
+
) -> Float[Tensor, "Nf 3 4"]: # noqa: F821
|
280 |
+
vn_idx = [None] * 3
|
281 |
+
pos = [None] * 3
|
282 |
+
tex = face_uv.unbind(1)
|
283 |
+
for i in range(0, 3):
|
284 |
+
pos[i] = vertex_positions[triangle_idxs[:, i]]
|
285 |
+
# t_nrm_idx is always the same as t_pos_idx
|
286 |
+
vn_idx[i] = triangle_idxs[:, i]
|
287 |
+
|
288 |
+
tangents = torch.zeros_like(vertex_normals)
|
289 |
+
tansum = torch.zeros_like(vertex_normals)
|
290 |
+
|
291 |
+
# Compute tangent space for each triangle
|
292 |
+
duv1 = tex[1] - tex[0]
|
293 |
+
duv2 = tex[2] - tex[0]
|
294 |
+
dpos1 = pos[1] - pos[0]
|
295 |
+
dpos2 = pos[2] - pos[0]
|
296 |
+
|
297 |
+
tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
|
298 |
+
|
299 |
+
denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
|
300 |
+
|
301 |
+
# Avoid division by zero for degenerated texture coordinates
|
302 |
+
denom_safe = denom.clip(1e-6)
|
303 |
+
tang = tng_nom / denom_safe
|
304 |
+
|
305 |
+
# Update all 3 vertices
|
306 |
+
for i in range(0, 3):
|
307 |
+
idx = vn_idx[i][:, None].repeat(1, 3)
|
308 |
+
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
309 |
+
tansum.scatter_add_(
|
310 |
+
0, idx, torch.ones_like(tang)
|
311 |
+
) # tansum[n_i] = tansum[n_i] + 1
|
312 |
+
# Also normalize it. Here we do not normalize the individual triangles first so larger area
|
313 |
+
# triangles influence the tangent space more
|
314 |
+
tangents = tangents / tansum
|
315 |
+
|
316 |
+
# Normalize and make sure tangent is perpendicular to normal
|
317 |
+
tangents = F.normalize(tangents, dim=1)
|
318 |
+
tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)
|
319 |
+
|
320 |
+
return tangents
|
321 |
+
|
322 |
+
|
323 |
+
def _rotate_uv_slices_consistent_space(
|
324 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
325 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
326 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
327 |
+
uv: Float[Tensor, "Nf 3 2"],
|
328 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
329 |
+
):
|
330 |
+
tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
|
331 |
+
pos_stack = torch.stack(
|
332 |
+
[
|
333 |
+
-vertex_positions[..., 1],
|
334 |
+
vertex_positions[..., 0],
|
335 |
+
torch.zeros_like(vertex_positions[..., 0]),
|
336 |
+
],
|
337 |
+
dim=-1,
|
338 |
+
)
|
339 |
+
expected_tangents = F.normalize(
|
340 |
+
torch.linalg.cross(
|
341 |
+
vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
|
342 |
+
),
|
343 |
+
-1,
|
344 |
+
)
|
345 |
+
|
346 |
+
actual_tangents = tangents[triangle_idxs]
|
347 |
+
expected_tangents = expected_tangents[triangle_idxs]
|
348 |
+
|
349 |
+
def rotation_matrix_2d(theta):
|
350 |
+
c, s = torch.cos(theta), torch.sin(theta)
|
351 |
+
return torch.tensor([[c, -s], [s, c]])
|
352 |
+
|
353 |
+
# Now find the rotation
|
354 |
+
index_mod = index % 6 # Shouldn't happen. Just for safety
|
355 |
+
for i in range(6):
|
356 |
+
mask = index_mod == i
|
357 |
+
if not mask.any():
|
358 |
+
continue
|
359 |
+
|
360 |
+
actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
|
361 |
+
expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
|
362 |
+
|
363 |
+
dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
|
364 |
+
cross_product = (
|
365 |
+
actual_mean_tangent[0] * expected_mean_tangent[1]
|
366 |
+
- actual_mean_tangent[1] * expected_mean_tangent[0]
|
367 |
+
)
|
368 |
+
angle = torch.atan2(cross_product, dot_product)
|
369 |
+
|
370 |
+
rot_matrix = rotation_matrix_2d(angle).to(mask.device)
|
371 |
+
# Center the uv coordinate to be in the range of -1 to 1 and 0 centered
|
372 |
+
uv_cur = uv[mask] * 2 - 1 # Center it first
|
373 |
+
# Rotate it
|
374 |
+
uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
|
375 |
+
|
376 |
+
# Rescale uv[mask] to be within the 0-1 range
|
377 |
+
uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())
|
378 |
+
|
379 |
+
return uv
|
380 |
+
|
381 |
+
|
382 |
+
def _handle_slice_uvs(
|
383 |
+
uv: Float[Tensor, "Nf 3 2"],
|
384 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
385 |
+
island_padding: float,
|
386 |
+
max_index: int = 6 * 2,
|
387 |
+
) -> Float[Tensor, "Nf 3 2"]: # noqa: F821
|
388 |
+
uc, vc = uv.unbind(-1)
|
389 |
+
|
390 |
+
# Get the second slice (The first overlap)
|
391 |
+
index_filter = [index == i for i in range(6, max_index)]
|
392 |
+
|
393 |
+
# Normalize them to always fully fill the atlas patch
|
394 |
+
for i, fi in enumerate(index_filter):
|
395 |
+
if fi.sum() > 0:
|
396 |
+
# Scale the slice but only up to a factor of 2
|
397 |
+
# This keeps the texture resolution with the first slice in line (Half space in UV)
|
398 |
+
uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
|
399 |
+
vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)
|
400 |
+
|
401 |
+
uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
402 |
+
vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
403 |
+
|
404 |
+
return torch.stack([uc_padded, vc_padded], dim=-1)
|
405 |
+
|
406 |
+
|
407 |
+
def _handle_remaining_uvs(
|
408 |
+
uv: Float[Tensor, "Nf 3 2"],
|
409 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
410 |
+
island_padding: float,
|
411 |
+
) -> Float[Tensor, "Nf 3 2"]:
|
412 |
+
uc, vc = uv.unbind(-1)
|
413 |
+
# Get all remaining elements
|
414 |
+
remaining_filter = index >= 6 * 2
|
415 |
+
squares_left = remaining_filter.sum()
|
416 |
+
|
417 |
+
if squares_left == 0:
|
418 |
+
return uv
|
419 |
+
|
420 |
+
uc = uc[remaining_filter]
|
421 |
+
vc = vc[remaining_filter]
|
422 |
+
|
423 |
+
# Or remaining triangles are distributed in a rectangle
|
424 |
+
# The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
|
425 |
+
ratio = 0.5 * (1 / 3) # 1.5
|
426 |
+
# sqrt(744/(0.5*(1/3)))
|
427 |
+
|
428 |
+
mult = math.sqrt(squares_left / ratio)
|
429 |
+
num_square_width = int(math.ceil(0.5 * mult))
|
430 |
+
num_square_height = int(math.ceil(squares_left / num_square_width))
|
431 |
+
|
432 |
+
width = 1 / num_square_width
|
433 |
+
height = 1 / num_square_height
|
434 |
+
|
435 |
+
# The idea is again to keep the texture resolution consistent with the first slice
|
436 |
+
# This only occupys half the region in the texture chart but the scaling on the squares
|
437 |
+
# assumes full coverage.
|
438 |
+
clip_val = min(width, height) * 1.5
|
439 |
+
# Now normalize the UVs with taking into account the maximum scaling
|
440 |
+
uc = (uc - uc.min(dim=1, keepdim=True).values) / (
|
441 |
+
uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
|
442 |
+
).clip(clip_val)
|
443 |
+
vc = (vc - vc.min(dim=1, keepdim=True).values) / (
|
444 |
+
vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
|
445 |
+
).clip(clip_val)
|
446 |
+
# Add a small padding
|
447 |
+
uc = (
|
448 |
+
uc * (1 - island_padding * num_square_width * 0.5)
|
449 |
+
+ island_padding * num_square_width * 0.25
|
450 |
+
).clip(0, 1)
|
451 |
+
vc = (
|
452 |
+
vc * (1 - island_padding * num_square_height * 0.5)
|
453 |
+
+ island_padding * num_square_height * 0.25
|
454 |
+
).clip(0, 1)
|
455 |
+
|
456 |
+
uc = uc * width
|
457 |
+
vc = vc * height
|
458 |
+
|
459 |
+
# And calculate offsets for each element
|
460 |
+
idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
|
461 |
+
x_idx = idx % num_square_width
|
462 |
+
y_idx = idx // num_square_width
|
463 |
+
# And move each triangle to its own spot
|
464 |
+
uc = uc + x_idx[:, None] * width
|
465 |
+
vc = vc + y_idx[:, None] * height
|
466 |
+
|
467 |
+
uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
468 |
+
vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
469 |
+
|
470 |
+
uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
|
471 |
+
|
472 |
+
return uv
|
473 |
+
|
474 |
+
|
475 |
+
def _distribute_individual_uvs_in_atlas(
|
476 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
477 |
+
assigned_faces: Integer[Tensor, "Nf"], # noqa: F821
|
478 |
+
offset_x: Float[Tensor, "Nf"], # noqa: F821
|
479 |
+
offset_y: Float[Tensor, "Nf"], # noqa: F821
|
480 |
+
div_x: Float[Tensor, "Nf"], # noqa: F821
|
481 |
+
div_y: Float[Tensor, "Nf"], # noqa: F821
|
482 |
+
island_padding: float,
|
483 |
+
):
|
484 |
+
# Place the slice first
|
485 |
+
placed_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
|
486 |
+
# Then handle the remaining overlap elements
|
487 |
+
placed_uv = _handle_remaining_uvs(placed_uv, assigned_faces, island_padding)
|
488 |
+
|
489 |
+
uc, vc = placed_uv.unbind(-1)
|
490 |
+
uc = uc / div_x[:, None] + offset_x[:, None]
|
491 |
+
vc = vc / div_y[:, None] + offset_y[:, None]
|
492 |
+
|
493 |
+
uv = torch.stack([uc, vc], dim=-1).view(-1, 2)
|
494 |
+
|
495 |
+
return uv
|
496 |
+
|
497 |
+
|
498 |
+
def _get_unique_face_uv(
|
499 |
+
uv: Float[Tensor, "Nf 3 2"],
|
500 |
+
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
|
501 |
+
unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
|
502 |
+
# And add the face to uv index mapping
|
503 |
+
vtex_idx = unique_idx.view(-1, 3)
|
504 |
+
|
505 |
+
return unique_uv, vtex_idx
|
506 |
+
|
507 |
+
|
508 |
+
def _align_mesh_with_main_axis(
|
509 |
+
vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
|
510 |
+
) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
|
511 |
+
# Use pca to find the 2 main axis (third is derived by cross product)
|
512 |
+
# Set the random seed so it's repeatable
|
513 |
+
torch.manual_seed(0)
|
514 |
+
_, _, v = torch.pca_lowrank(vertex_positions, q=2)
|
515 |
+
main_axis, seconday_axis = v[:, 0], v[:, 1]
|
516 |
+
|
517 |
+
main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)
|
518 |
+
# Orthogonalize the second axis
|
519 |
+
seconday_axis: Float[Tensor, "3"] = F.normalize(
|
520 |
+
seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1
|
521 |
+
)
|
522 |
+
# Create perpendicular third axis
|
523 |
+
third_axis: Float[Tensor, "3"] = F.normalize(
|
524 |
+
torch.cross(main_axis, seconday_axis), dim=-1, eps=1e-6
|
525 |
+
)
|
526 |
+
|
527 |
+
# Check to which canonical axis each aligns
|
528 |
+
main_axis_max_idx = main_axis.abs().argmax().item()
|
529 |
+
seconday_axis_max_idx = seconday_axis.abs().argmax().item()
|
530 |
+
third_axis_max_idx = third_axis.abs().argmax().item()
|
531 |
+
|
532 |
+
# Now sort the axes based on the argmax so they align with thecanonoical axes
|
533 |
+
# If two axes have the same argmax move one of them
|
534 |
+
all_possible_axis = {0, 1, 2}
|
535 |
+
cur_index = 1
|
536 |
+
while len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx])) != 3:
|
537 |
+
# Find missing axis
|
538 |
+
missing_axis = all_possible_axis - set(
|
539 |
+
[main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
|
540 |
+
)
|
541 |
+
missing_axis = missing_axis.pop()
|
542 |
+
# Just assign it to third axis as it had the smallest contribution to the
|
543 |
+
# overall shape
|
544 |
+
if cur_index == 1:
|
545 |
+
third_axis_max_idx = missing_axis
|
546 |
+
elif cur_index == 2:
|
547 |
+
seconday_axis_max_idx = missing_axis
|
548 |
+
else:
|
549 |
+
raise ValueError("Could not find 3 unique axis")
|
550 |
+
cur_index += 1
|
551 |
+
|
552 |
+
if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
|
553 |
+
raise ValueError("Could not find 3 unique axis")
|
554 |
+
|
555 |
+
axes = [None] * 3
|
556 |
+
axes[main_axis_max_idx] = main_axis
|
557 |
+
axes[seconday_axis_max_idx] = seconday_axis
|
558 |
+
axes[third_axis_max_idx] = third_axis
|
559 |
+
# Create rotation matrix from the individual axes
|
560 |
+
rot_mat = torch.stack(axes, dim=1).T
|
561 |
+
|
562 |
+
# Now rotate the vertex positions and vertex normals so the mesh aligns with the main axis
|
563 |
+
vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
|
564 |
+
vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)
|
565 |
+
|
566 |
+
return vertex_positions, vertex_normals
|
567 |
+
|
568 |
+
|
569 |
+
def box_projection_uv_unwrap(
|
570 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
571 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
572 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
573 |
+
island_padding: float,
|
574 |
+
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
|
575 |
+
# Align the mesh with main axis directions first
|
576 |
+
vertex_positions, vertex_normals = _align_mesh_with_main_axis(
|
577 |
+
vertex_positions, vertex_normals
|
578 |
+
)
|
579 |
+
|
580 |
+
bbox: Float[Tensor, "2 3"] = torch.stack(
|
581 |
+
[vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
|
582 |
+
)
|
583 |
+
# First decide in which cube face the triangle is placed
|
584 |
+
face_uv, face_index = _box_assign_vertex_to_cube_face(
|
585 |
+
vertex_positions, vertex_normals, triangle_idxs, bbox
|
586 |
+
)
|
587 |
+
|
588 |
+
# Rotate the UV islands in a way that they align with the radial z tangent space
|
589 |
+
face_uv = _rotate_uv_slices_consistent_space(
|
590 |
+
vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
|
591 |
+
)
|
592 |
+
|
593 |
+
# Then find where where the face is placed in the atlas.
|
594 |
+
# This has to detect potential overlaps
|
595 |
+
assigned_atlas_index = _assign_faces_uv_to_atlas_index(
|
596 |
+
vertex_positions, triangle_idxs, face_uv, face_index
|
597 |
+
)
|
598 |
+
|
599 |
+
# Then figure out the final place in the atlas based on the assignment
|
600 |
+
offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
|
601 |
+
assigned_atlas_index
|
602 |
+
)
|
603 |
+
|
604 |
+
# Next distribute the faces in the uv atlas
|
605 |
+
placed_uv = _distribute_individual_uvs_in_atlas(
|
606 |
+
face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
|
607 |
+
)
|
608 |
+
|
609 |
+
# And get the unique per-triangle UV coordinates
|
610 |
+
return _get_unique_face_uv(placed_uv)
|
sf3d/models/camera.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from sf3d.models.utils import BaseModule
|
8 |
+
|
9 |
+
|
10 |
+
class LinearCameraEmbedder(BaseModule):
|
11 |
+
@dataclass
|
12 |
+
class Config(BaseModule.Config):
|
13 |
+
in_channels: int = 25
|
14 |
+
out_channels: int = 768
|
15 |
+
conditions: List[str] = field(default_factory=list)
|
16 |
+
|
17 |
+
cfg: Config
|
18 |
+
|
19 |
+
def configure(self) -> None:
|
20 |
+
self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
|
21 |
+
|
22 |
+
def forward(self, **kwargs):
|
23 |
+
cond_tensors = []
|
24 |
+
for cond_name in self.cfg.conditions:
|
25 |
+
assert cond_name in kwargs
|
26 |
+
cond = kwargs[cond_name]
|
27 |
+
# cond in shape (B, Nv, ...)
|
28 |
+
cond_tensors.append(cond.view(*cond.shape[:2], -1))
|
29 |
+
cond_tensor = torch.cat(cond_tensors, dim=-1)
|
30 |
+
assert cond_tensor.shape[-1] == self.cfg.in_channels
|
31 |
+
embedding = self.linear(cond_tensor)
|
32 |
+
return embedding
|
sf3d/models/global_estimator/multi_head_estimator.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Any, List, Optional
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from jaxtyping import Float
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from sf3d.models.network import get_activation
|
9 |
+
from sf3d.models.utils import BaseModule
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class HeadSpec:
|
14 |
+
name: str
|
15 |
+
out_channels: int
|
16 |
+
n_hidden_layers: int
|
17 |
+
output_activation: Optional[str] = None
|
18 |
+
output_bias: float = 0.0
|
19 |
+
add_to_decoder_features: bool = False
|
20 |
+
shape: Optional[list[int]] = None
|
21 |
+
|
22 |
+
|
23 |
+
class MultiHeadEstimator(BaseModule):
|
24 |
+
@dataclass
|
25 |
+
class Config(BaseModule.Config):
|
26 |
+
triplane_features: int = 1024
|
27 |
+
|
28 |
+
n_layers: int = 2
|
29 |
+
hidden_features: int = 512
|
30 |
+
activation: str = "relu"
|
31 |
+
|
32 |
+
pool: str = "max"
|
33 |
+
# Literal["mean", "max"] = "mean" # noqa: F821
|
34 |
+
|
35 |
+
heads: List[HeadSpec] = field(default_factory=lambda: [])
|
36 |
+
|
37 |
+
cfg: Config
|
38 |
+
|
39 |
+
def configure(self):
|
40 |
+
layers = []
|
41 |
+
cur_features = self.cfg.triplane_features * 3
|
42 |
+
for _ in range(self.cfg.n_layers):
|
43 |
+
layers.append(
|
44 |
+
nn.Conv2d(
|
45 |
+
cur_features,
|
46 |
+
self.cfg.hidden_features,
|
47 |
+
kernel_size=3,
|
48 |
+
padding=0,
|
49 |
+
stride=2,
|
50 |
+
)
|
51 |
+
)
|
52 |
+
layers.append(self.make_activation(self.cfg.activation))
|
53 |
+
|
54 |
+
cur_features = self.cfg.hidden_features
|
55 |
+
|
56 |
+
self.layers = nn.Sequential(*layers)
|
57 |
+
|
58 |
+
assert len(self.cfg.heads) > 0
|
59 |
+
heads = {}
|
60 |
+
for head in self.cfg.heads:
|
61 |
+
head_layers = []
|
62 |
+
for i in range(head.n_hidden_layers):
|
63 |
+
head_layers += [
|
64 |
+
nn.Linear(
|
65 |
+
self.cfg.hidden_features,
|
66 |
+
self.cfg.hidden_features,
|
67 |
+
),
|
68 |
+
self.make_activation(self.cfg.activation),
|
69 |
+
]
|
70 |
+
head_layers += [
|
71 |
+
nn.Linear(
|
72 |
+
self.cfg.hidden_features,
|
73 |
+
head.out_channels,
|
74 |
+
),
|
75 |
+
]
|
76 |
+
heads[head.name] = nn.Sequential(*head_layers)
|
77 |
+
self.heads = nn.ModuleDict(heads)
|
78 |
+
|
79 |
+
def make_activation(self, activation):
|
80 |
+
if activation == "relu":
|
81 |
+
return nn.ReLU(inplace=True)
|
82 |
+
elif activation == "silu":
|
83 |
+
return nn.SiLU(inplace=True)
|
84 |
+
else:
|
85 |
+
raise NotImplementedError
|
86 |
+
|
87 |
+
def forward(
|
88 |
+
self,
|
89 |
+
triplane: Float[Tensor, "B 3 F Ht Wt"],
|
90 |
+
) -> dict[str, Any]:
|
91 |
+
x = self.layers(
|
92 |
+
triplane.reshape(
|
93 |
+
triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
|
94 |
+
)
|
95 |
+
)
|
96 |
+
|
97 |
+
if self.cfg.pool == "max":
|
98 |
+
x = x.amax(dim=[-2, -1])
|
99 |
+
elif self.cfg.pool == "mean":
|
100 |
+
x = x.mean(dim=[-2, -1])
|
101 |
+
else:
|
102 |
+
raise NotImplementedError
|
103 |
+
|
104 |
+
out = {
|
105 |
+
("decoder_" if head.add_to_decoder_features else "")
|
106 |
+
+ head.name: get_activation(head.output_activation)(
|
107 |
+
self.heads[head.name](x) + head.output_bias
|
108 |
+
)
|
109 |
+
for head in self.cfg.heads
|
110 |
+
}
|
111 |
+
for head in self.cfg.heads:
|
112 |
+
if head.shape:
|
113 |
+
head_name = (
|
114 |
+
"decoder_" if head.add_to_decoder_features else ""
|
115 |
+
) + head.name
|
116 |
+
out[head_name] = out[head_name].reshape(*head.shape)
|
117 |
+
|
118 |
+
return out
|
sf3d/models/image_estimator/clip_based_estimator.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Any, List, Optional
|
3 |
+
|
4 |
+
import open_clip
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from jaxtyping import Float
|
8 |
+
from torch import Tensor
|
9 |
+
from torchvision.transforms import Normalize
|
10 |
+
|
11 |
+
from sf3d.models.network import get_activation
|
12 |
+
from sf3d.models.utils import BaseModule
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class HeadSpec:
|
17 |
+
name: str
|
18 |
+
out_channels: int
|
19 |
+
n_hidden_layers: int
|
20 |
+
output_activation: Optional[str] = None
|
21 |
+
output_bias: float = 0.0
|
22 |
+
add_to_decoder_features: bool = False
|
23 |
+
shape: Optional[list[int]] = None
|
24 |
+
|
25 |
+
|
26 |
+
class ClipBasedHeadEstimator(BaseModule):
|
27 |
+
@dataclass
|
28 |
+
class Config(BaseModule.Config):
|
29 |
+
model: str = "ViT-B-32"
|
30 |
+
pretrain: str = "laion2b_s34b_b79k"
|
31 |
+
|
32 |
+
distribution: str = "beta"
|
33 |
+
|
34 |
+
# ["mean", "mode", "sample", "sample_mean"]
|
35 |
+
distribution_eval: str = "mode"
|
36 |
+
|
37 |
+
activation: str = "relu"
|
38 |
+
hidden_features: int = 512
|
39 |
+
heads: List[HeadSpec] = field(default_factory=lambda: [])
|
40 |
+
|
41 |
+
cfg: Config
|
42 |
+
|
43 |
+
def configure(self):
|
44 |
+
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
|
45 |
+
self.cfg.model, pretrained=self.cfg.pretrain
|
46 |
+
)
|
47 |
+
self.model.eval()
|
48 |
+
|
49 |
+
# Do not add the weights in self.model to the optimizer
|
50 |
+
for param in self.model.parameters():
|
51 |
+
param.requires_grad = False
|
52 |
+
|
53 |
+
assert len(self.cfg.heads) > 0
|
54 |
+
heads = {}
|
55 |
+
for head in self.cfg.heads:
|
56 |
+
head_layers = []
|
57 |
+
|
58 |
+
for i in range(head.n_hidden_layers):
|
59 |
+
head_layers += [
|
60 |
+
nn.Linear(
|
61 |
+
self.cfg.hidden_features,
|
62 |
+
self.cfg.hidden_features,
|
63 |
+
),
|
64 |
+
self.make_activation(self.cfg.activation),
|
65 |
+
]
|
66 |
+
|
67 |
+
head_layers = [nn.Sequential(*head_layers)]
|
68 |
+
head_layers += [
|
69 |
+
nn.Sequential(
|
70 |
+
nn.Linear(
|
71 |
+
self.cfg.hidden_features,
|
72 |
+
self.cfg.hidden_features,
|
73 |
+
),
|
74 |
+
self.make_activation(self.cfg.activation),
|
75 |
+
nn.Linear(self.cfg.hidden_features, 1),
|
76 |
+
)
|
77 |
+
for _ in range(2)
|
78 |
+
]
|
79 |
+
heads[head.name] = nn.ModuleList(head_layers)
|
80 |
+
self.heads = nn.ModuleDict(heads)
|
81 |
+
|
82 |
+
def make_activation(self, activation):
|
83 |
+
if activation == "relu":
|
84 |
+
return nn.ReLU(inplace=True)
|
85 |
+
elif activation == "silu":
|
86 |
+
return nn.SiLU(inplace=True)
|
87 |
+
else:
|
88 |
+
raise NotImplementedError
|
89 |
+
|
90 |
+
def forward(
|
91 |
+
self,
|
92 |
+
cond_image: Float[Tensor, "B 1 H W 3"],
|
93 |
+
sample: bool = True,
|
94 |
+
) -> dict[str, Any]:
|
95 |
+
# Run the model
|
96 |
+
# Resize cond_image to 224
|
97 |
+
cond_image = nn.functional.interpolate(
|
98 |
+
cond_image.flatten(0, 1).permute(0, 3, 1, 2),
|
99 |
+
size=(224, 224),
|
100 |
+
mode="bilinear",
|
101 |
+
align_corners=False,
|
102 |
+
)
|
103 |
+
cond_image = Normalize(
|
104 |
+
mean=open_clip.constants.OPENAI_DATASET_MEAN,
|
105 |
+
std=open_clip.constants.OPENAI_DATASET_STD,
|
106 |
+
)(cond_image)
|
107 |
+
image_features = self.model.encode_image(cond_image)
|
108 |
+
|
109 |
+
# Run the heads
|
110 |
+
outputs = {}
|
111 |
+
|
112 |
+
for head_dict in self.cfg.heads:
|
113 |
+
head_name = head_dict.name
|
114 |
+
shared_head, d1_h, d2_h = self.heads[head_name]
|
115 |
+
shared_features = shared_head(image_features)
|
116 |
+
d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
|
117 |
+
if self.cfg.distribution == "normal":
|
118 |
+
mean = d1
|
119 |
+
var = d2
|
120 |
+
if mean.shape[-1] == 1:
|
121 |
+
outputs[head_name] = torch.distributions.Normal(
|
122 |
+
mean + head_dict.output_bias,
|
123 |
+
torch.nn.functional.softplus(var),
|
124 |
+
)
|
125 |
+
else:
|
126 |
+
outputs[head_name] = torch.distributions.MultivariateNormal(
|
127 |
+
mean + head_dict.output_bias,
|
128 |
+
torch.nn.functional.softplus(var).diag_embed(),
|
129 |
+
)
|
130 |
+
elif self.cfg.distribution == "beta":
|
131 |
+
outputs[head_name] = torch.distributions.Beta(
|
132 |
+
torch.nn.functional.softplus(d1 + head_dict.output_bias),
|
133 |
+
torch.nn.functional.softplus(d2 + head_dict.output_bias),
|
134 |
+
)
|
135 |
+
else:
|
136 |
+
raise NotImplementedError
|
137 |
+
|
138 |
+
if sample:
|
139 |
+
for head_dict in self.cfg.heads:
|
140 |
+
head_name = head_dict.name
|
141 |
+
dist = outputs[head_name]
|
142 |
+
|
143 |
+
if self.cfg.distribution_eval == "mean":
|
144 |
+
out = dist.mean
|
145 |
+
elif self.cfg.distribution_eval == "mode":
|
146 |
+
out = dist.mode
|
147 |
+
elif self.cfg.distribution_eval == "sample_mean":
|
148 |
+
out = dist.sample([10]).mean(-1)
|
149 |
+
else:
|
150 |
+
# use rsample if gradient is needed
|
151 |
+
out = dist.rsample() if self.training else dist.sample()
|
152 |
+
|
153 |
+
outputs[head_name] = get_activation(head_dict.output_activation)(out)
|
154 |
+
outputs[f"{head_name}_dist"] = dist
|
155 |
+
|
156 |
+
for head in self.cfg.heads:
|
157 |
+
if head.shape:
|
158 |
+
if not sample:
|
159 |
+
raise ValueError(
|
160 |
+
"Cannot reshape non-sampled probabilisitic outputs"
|
161 |
+
)
|
162 |
+
outputs[head.name] = outputs[head.name].reshape(*head.shape)
|
163 |
+
|
164 |
+
if head.add_to_decoder_features:
|
165 |
+
outputs[f"decoder_{head.name}"] = outputs[head.name]
|
166 |
+
del outputs[head.name]
|
167 |
+
|
168 |
+
return outputs
|
sf3d/models/isosurface.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from jaxtyping import Float, Integer
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
from .mesh import Mesh
|
10 |
+
|
11 |
+
|
12 |
+
class IsosurfaceHelper(nn.Module):
|
13 |
+
points_range: Tuple[float, float] = (0, 1)
|
14 |
+
|
15 |
+
@property
|
16 |
+
def grid_vertices(self) -> Float[Tensor, "N 3"]:
|
17 |
+
raise NotImplementedError
|
18 |
+
|
19 |
+
@property
|
20 |
+
def requires_instance_per_batch(self) -> bool:
|
21 |
+
return False
|
22 |
+
|
23 |
+
|
24 |
+
class MarchingTetrahedraHelper(IsosurfaceHelper):
|
25 |
+
def __init__(self, resolution: int, tets_path: str):
|
26 |
+
super().__init__()
|
27 |
+
self.resolution = resolution
|
28 |
+
self.tets_path = tets_path
|
29 |
+
|
30 |
+
self.triangle_table: Float[Tensor, "..."]
|
31 |
+
self.register_buffer(
|
32 |
+
"triangle_table",
|
33 |
+
torch.as_tensor(
|
34 |
+
[
|
35 |
+
[-1, -1, -1, -1, -1, -1],
|
36 |
+
[1, 0, 2, -1, -1, -1],
|
37 |
+
[4, 0, 3, -1, -1, -1],
|
38 |
+
[1, 4, 2, 1, 3, 4],
|
39 |
+
[3, 1, 5, -1, -1, -1],
|
40 |
+
[2, 3, 0, 2, 5, 3],
|
41 |
+
[1, 4, 0, 1, 5, 4],
|
42 |
+
[4, 2, 5, -1, -1, -1],
|
43 |
+
[4, 5, 2, -1, -1, -1],
|
44 |
+
[4, 1, 0, 4, 5, 1],
|
45 |
+
[3, 2, 0, 3, 5, 2],
|
46 |
+
[1, 3, 5, -1, -1, -1],
|
47 |
+
[4, 1, 2, 4, 3, 1],
|
48 |
+
[3, 0, 4, -1, -1, -1],
|
49 |
+
[2, 0, 1, -1, -1, -1],
|
50 |
+
[-1, -1, -1, -1, -1, -1],
|
51 |
+
],
|
52 |
+
dtype=torch.long,
|
53 |
+
),
|
54 |
+
persistent=False,
|
55 |
+
)
|
56 |
+
self.num_triangles_table: Integer[Tensor, "..."]
|
57 |
+
self.register_buffer(
|
58 |
+
"num_triangles_table",
|
59 |
+
torch.as_tensor(
|
60 |
+
[0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
|
61 |
+
),
|
62 |
+
persistent=False,
|
63 |
+
)
|
64 |
+
self.base_tet_edges: Integer[Tensor, "..."]
|
65 |
+
self.register_buffer(
|
66 |
+
"base_tet_edges",
|
67 |
+
torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
|
68 |
+
persistent=False,
|
69 |
+
)
|
70 |
+
|
71 |
+
tets = np.load(self.tets_path)
|
72 |
+
self._grid_vertices: Float[Tensor, "..."]
|
73 |
+
self.register_buffer(
|
74 |
+
"_grid_vertices",
|
75 |
+
torch.from_numpy(tets["vertices"]).float(),
|
76 |
+
persistent=False,
|
77 |
+
)
|
78 |
+
self.indices: Integer[Tensor, "..."]
|
79 |
+
self.register_buffer(
|
80 |
+
"indices", torch.from_numpy(tets["indices"]).long(), persistent=False
|
81 |
+
)
|
82 |
+
|
83 |
+
self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
|
84 |
+
|
85 |
+
center_indices, boundary_indices = self.get_center_boundary_index(
|
86 |
+
self._grid_vertices
|
87 |
+
)
|
88 |
+
self.center_indices: Integer[Tensor, "..."]
|
89 |
+
self.register_buffer("center_indices", center_indices, persistent=False)
|
90 |
+
self.boundary_indices: Integer[Tensor, "..."]
|
91 |
+
self.register_buffer("boundary_indices", boundary_indices, persistent=False)
|
92 |
+
|
93 |
+
def get_center_boundary_index(self, verts):
|
94 |
+
magn = torch.sum(verts**2, dim=-1)
|
95 |
+
|
96 |
+
center_idx = torch.argmin(magn)
|
97 |
+
boundary_neg = verts == verts.max()
|
98 |
+
boundary_pos = verts == verts.min()
|
99 |
+
|
100 |
+
boundary = torch.bitwise_or(boundary_pos, boundary_neg)
|
101 |
+
boundary = torch.sum(boundary.float(), dim=-1)
|
102 |
+
|
103 |
+
boundary_idx = torch.nonzero(boundary)
|
104 |
+
return center_idx, boundary_idx.squeeze(dim=-1)
|
105 |
+
|
106 |
+
def normalize_grid_deformation(
|
107 |
+
self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
|
108 |
+
) -> Float[Tensor, "Nv 3"]:
|
109 |
+
return (
|
110 |
+
(self.points_range[1] - self.points_range[0])
|
111 |
+
/ self.resolution # half tet size is approximately 1 / self.resolution
|
112 |
+
* torch.tanh(grid_vertex_offsets)
|
113 |
+
) # FIXME: hard-coded activation
|
114 |
+
|
115 |
+
@property
|
116 |
+
def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
|
117 |
+
return self._grid_vertices
|
118 |
+
|
119 |
+
@property
|
120 |
+
def all_edges(self) -> Integer[Tensor, "Ne 2"]:
|
121 |
+
if self._all_edges is None:
|
122 |
+
# compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
|
123 |
+
edges = torch.tensor(
|
124 |
+
[0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
|
125 |
+
dtype=torch.long,
|
126 |
+
device=self.indices.device,
|
127 |
+
)
|
128 |
+
_all_edges = self.indices[:, edges].reshape(-1, 2)
|
129 |
+
_all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
|
130 |
+
_all_edges = torch.unique(_all_edges_sorted, dim=0)
|
131 |
+
self._all_edges = _all_edges
|
132 |
+
return self._all_edges
|
133 |
+
|
134 |
+
def sort_edges(self, edges_ex2):
|
135 |
+
with torch.no_grad():
|
136 |
+
order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
|
137 |
+
order = order.unsqueeze(dim=1)
|
138 |
+
|
139 |
+
a = torch.gather(input=edges_ex2, index=order, dim=1)
|
140 |
+
b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
|
141 |
+
|
142 |
+
return torch.stack([a, b], -1)
|
143 |
+
|
144 |
+
def _forward(self, pos_nx3, sdf_n, tet_fx4):
|
145 |
+
with torch.no_grad():
|
146 |
+
occ_n = sdf_n > 0
|
147 |
+
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
|
148 |
+
occ_sum = torch.sum(occ_fx4, -1)
|
149 |
+
valid_tets = (occ_sum > 0) & (occ_sum < 4)
|
150 |
+
occ_sum = occ_sum[valid_tets]
|
151 |
+
|
152 |
+
# find all vertices
|
153 |
+
all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
|
154 |
+
all_edges = self.sort_edges(all_edges)
|
155 |
+
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
|
156 |
+
|
157 |
+
unique_edges = unique_edges.long()
|
158 |
+
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
159 |
+
mapping = (
|
160 |
+
torch.ones(
|
161 |
+
(unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
|
162 |
+
)
|
163 |
+
* -1
|
164 |
+
)
|
165 |
+
mapping[mask_edges] = torch.arange(
|
166 |
+
mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
|
167 |
+
)
|
168 |
+
idx_map = mapping[idx_map] # map edges to verts
|
169 |
+
|
170 |
+
interp_v = unique_edges[mask_edges]
|
171 |
+
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
|
172 |
+
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
|
173 |
+
edges_to_interp_sdf[:, -1] *= -1
|
174 |
+
|
175 |
+
denominator = edges_to_interp_sdf.sum(1, keepdim=True)
|
176 |
+
|
177 |
+
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
|
178 |
+
verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
|
179 |
+
|
180 |
+
idx_map = idx_map.reshape(-1, 6)
|
181 |
+
|
182 |
+
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
|
183 |
+
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
|
184 |
+
num_triangles = self.num_triangles_table[tetindex]
|
185 |
+
|
186 |
+
# Generate triangle indices
|
187 |
+
faces = torch.cat(
|
188 |
+
(
|
189 |
+
torch.gather(
|
190 |
+
input=idx_map[num_triangles == 1],
|
191 |
+
dim=1,
|
192 |
+
index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
|
193 |
+
).reshape(-1, 3),
|
194 |
+
torch.gather(
|
195 |
+
input=idx_map[num_triangles == 2],
|
196 |
+
dim=1,
|
197 |
+
index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
|
198 |
+
).reshape(-1, 3),
|
199 |
+
),
|
200 |
+
dim=0,
|
201 |
+
)
|
202 |
+
|
203 |
+
return verts, faces
|
204 |
+
|
205 |
+
def forward(
|
206 |
+
self,
|
207 |
+
level: Float[Tensor, "N3 1"],
|
208 |
+
deformation: Optional[Float[Tensor, "N3 3"]] = None,
|
209 |
+
) -> Mesh:
|
210 |
+
if deformation is not None:
|
211 |
+
grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
|
212 |
+
deformation
|
213 |
+
)
|
214 |
+
else:
|
215 |
+
grid_vertices = self.grid_vertices
|
216 |
+
|
217 |
+
v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
|
218 |
+
|
219 |
+
mesh = Mesh(
|
220 |
+
v_pos=v_pos,
|
221 |
+
t_pos_idx=t_pos_idx,
|
222 |
+
# extras
|
223 |
+
grid_vertices=grid_vertices,
|
224 |
+
tet_edges=self.all_edges,
|
225 |
+
grid_level=level,
|
226 |
+
grid_deformation=deformation,
|
227 |
+
)
|
228 |
+
|
229 |
+
return mesh
|
sf3d/models/mesh.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Any, Dict, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from jaxtyping import Float, Integer
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from sf3d.box_uv_unwrap import box_projection_uv_unwrap
|
11 |
+
from sf3d.models.utils import dot
|
12 |
+
|
13 |
+
|
14 |
+
class Mesh:
|
15 |
+
def __init__(
|
16 |
+
self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
|
17 |
+
) -> None:
|
18 |
+
self.v_pos: Float[Tensor, "Nv 3"] = v_pos
|
19 |
+
self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
|
20 |
+
self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
|
21 |
+
self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
|
22 |
+
self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
|
23 |
+
self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
|
24 |
+
self.extras: Dict[str, Any] = {}
|
25 |
+
for k, v in kwargs.items():
|
26 |
+
self.add_extra(k, v)
|
27 |
+
|
28 |
+
def add_extra(self, k, v) -> None:
|
29 |
+
self.extras[k] = v
|
30 |
+
|
31 |
+
@property
|
32 |
+
def requires_grad(self):
|
33 |
+
return self.v_pos.requires_grad
|
34 |
+
|
35 |
+
@property
|
36 |
+
def v_nrm(self):
|
37 |
+
if self._v_nrm is None:
|
38 |
+
self._v_nrm = self._compute_vertex_normal()
|
39 |
+
return self._v_nrm
|
40 |
+
|
41 |
+
@property
|
42 |
+
def v_tng(self):
|
43 |
+
if self._v_tng is None:
|
44 |
+
self._v_tng = self._compute_vertex_tangent()
|
45 |
+
return self._v_tng
|
46 |
+
|
47 |
+
@property
|
48 |
+
def v_tex(self):
|
49 |
+
if self._v_tex is None:
|
50 |
+
self.unwrap_uv()
|
51 |
+
return self._v_tex
|
52 |
+
|
53 |
+
@property
|
54 |
+
def edges(self):
|
55 |
+
if self._edges is None:
|
56 |
+
self._edges = self._compute_edges()
|
57 |
+
return self._edges
|
58 |
+
|
59 |
+
def _compute_vertex_normal(self):
|
60 |
+
i0 = self.t_pos_idx[:, 0]
|
61 |
+
i1 = self.t_pos_idx[:, 1]
|
62 |
+
i2 = self.t_pos_idx[:, 2]
|
63 |
+
|
64 |
+
v0 = self.v_pos[i0, :]
|
65 |
+
v1 = self.v_pos[i1, :]
|
66 |
+
v2 = self.v_pos[i2, :]
|
67 |
+
|
68 |
+
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
69 |
+
|
70 |
+
# Splat face normals to vertices
|
71 |
+
v_nrm = torch.zeros_like(self.v_pos)
|
72 |
+
v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
|
73 |
+
v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
|
74 |
+
v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
|
75 |
+
|
76 |
+
# Normalize, replace zero (degenerated) normals with some default value
|
77 |
+
v_nrm = torch.where(
|
78 |
+
dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
|
79 |
+
)
|
80 |
+
v_nrm = F.normalize(v_nrm, dim=1)
|
81 |
+
|
82 |
+
if torch.is_anomaly_enabled():
|
83 |
+
assert torch.all(torch.isfinite(v_nrm))
|
84 |
+
|
85 |
+
return v_nrm
|
86 |
+
|
87 |
+
def _compute_vertex_tangent(self):
|
88 |
+
vn_idx = [None] * 3
|
89 |
+
pos = [None] * 3
|
90 |
+
tex = [None] * 3
|
91 |
+
for i in range(0, 3):
|
92 |
+
pos[i] = self.v_pos[self.t_pos_idx[:, i]]
|
93 |
+
tex[i] = self.v_tex[self.t_pos_idx[:, i]]
|
94 |
+
# t_nrm_idx is always the same as t_pos_idx
|
95 |
+
vn_idx[i] = self.t_pos_idx[:, i]
|
96 |
+
|
97 |
+
tangents = torch.zeros_like(self.v_nrm)
|
98 |
+
tansum = torch.zeros_like(self.v_nrm)
|
99 |
+
|
100 |
+
# Compute tangent space for each triangle
|
101 |
+
duv1 = tex[1] - tex[0]
|
102 |
+
duv2 = tex[2] - tex[0]
|
103 |
+
dpos1 = pos[1] - pos[0]
|
104 |
+
dpos2 = pos[2] - pos[0]
|
105 |
+
|
106 |
+
tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
|
107 |
+
|
108 |
+
denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
|
109 |
+
|
110 |
+
# Avoid division by zero for degenerated texture coordinates
|
111 |
+
denom_safe = denom.clip(1e-6)
|
112 |
+
tang = tng_nom / denom_safe
|
113 |
+
|
114 |
+
# Update all 3 vertices
|
115 |
+
for i in range(0, 3):
|
116 |
+
idx = vn_idx[i][:, None].repeat(1, 3)
|
117 |
+
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
118 |
+
tansum.scatter_add_(
|
119 |
+
0, idx, torch.ones_like(tang)
|
120 |
+
) # tansum[n_i] = tansum[n_i] + 1
|
121 |
+
# Also normalize it. Here we do not normalize the individual triangles first so larger area
|
122 |
+
# triangles influence the tangent space more
|
123 |
+
tangents = tangents / tansum
|
124 |
+
|
125 |
+
# Normalize and make sure tangent is perpendicular to normal
|
126 |
+
tangents = F.normalize(tangents, dim=1)
|
127 |
+
tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
|
128 |
+
|
129 |
+
if torch.is_anomaly_enabled():
|
130 |
+
assert torch.all(torch.isfinite(tangents))
|
131 |
+
|
132 |
+
return tangents
|
133 |
+
|
134 |
+
@torch.no_grad()
|
135 |
+
def unwrap_uv(
|
136 |
+
self,
|
137 |
+
island_padding: float = 0.02,
|
138 |
+
) -> Mesh:
|
139 |
+
uv, indices = box_projection_uv_unwrap(
|
140 |
+
self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
|
141 |
+
)
|
142 |
+
|
143 |
+
# Do store per vertex UVs.
|
144 |
+
# This means we need to duplicate some vertices at the seams
|
145 |
+
individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
|
146 |
+
individual_faces = torch.arange(
|
147 |
+
individual_vertices.shape[0],
|
148 |
+
device=individual_vertices.device,
|
149 |
+
dtype=self.t_pos_idx.dtype,
|
150 |
+
).reshape(-1, 3)
|
151 |
+
uv_flat = uv[indices].reshape((-1, 2))
|
152 |
+
# uv_flat[:, 1] = 1 - uv_flat[:, 1]
|
153 |
+
|
154 |
+
self.v_pos = individual_vertices
|
155 |
+
self.t_pos_idx = individual_faces
|
156 |
+
self._v_tex = uv_flat
|
157 |
+
self._v_nrm = self._compute_vertex_normal()
|
158 |
+
self._v_tng = self._compute_vertex_tangent()
|
159 |
+
|
160 |
+
def _compute_edges(self):
|
161 |
+
# Compute edges
|
162 |
+
edges = torch.cat(
|
163 |
+
[
|
164 |
+
self.t_pos_idx[:, [0, 1]],
|
165 |
+
self.t_pos_idx[:, [1, 2]],
|
166 |
+
self.t_pos_idx[:, [2, 0]],
|
167 |
+
],
|
168 |
+
dim=0,
|
169 |
+
)
|
170 |
+
edges = edges.sort()[0]
|
171 |
+
edges = torch.unique(edges, dim=0)
|
172 |
+
return edges
|
sf3d/models/network.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Callable, List, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from jaxtyping import Float
|
9 |
+
from torch import Tensor
|
10 |
+
from torch.autograd import Function
|
11 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
12 |
+
|
13 |
+
from sf3d.models.utils import BaseModule, normalize
|
14 |
+
|
15 |
+
|
16 |
+
class PixelShuffleUpsampleNetwork(BaseModule):
|
17 |
+
@dataclass
|
18 |
+
class Config(BaseModule.Config):
|
19 |
+
in_channels: int = 1024
|
20 |
+
out_channels: int = 40
|
21 |
+
scale_factor: int = 4
|
22 |
+
|
23 |
+
conv_layers: int = 4
|
24 |
+
conv_kernel_size: int = 3
|
25 |
+
|
26 |
+
cfg: Config
|
27 |
+
|
28 |
+
def configure(self) -> None:
|
29 |
+
layers = []
|
30 |
+
output_channels = self.cfg.out_channels * self.cfg.scale_factor**2
|
31 |
+
|
32 |
+
in_channels = self.cfg.in_channels
|
33 |
+
for i in range(self.cfg.conv_layers):
|
34 |
+
cur_out_channels = (
|
35 |
+
in_channels if i != self.cfg.conv_layers - 1 else output_channels
|
36 |
+
)
|
37 |
+
layers.append(
|
38 |
+
nn.Conv2d(
|
39 |
+
in_channels,
|
40 |
+
cur_out_channels,
|
41 |
+
self.cfg.conv_kernel_size,
|
42 |
+
padding=(self.cfg.conv_kernel_size - 1) // 2,
|
43 |
+
)
|
44 |
+
)
|
45 |
+
if i != self.cfg.conv_layers - 1:
|
46 |
+
layers.append(nn.ReLU(inplace=True))
|
47 |
+
|
48 |
+
layers.append(nn.PixelShuffle(self.cfg.scale_factor))
|
49 |
+
|
50 |
+
self.upsample = nn.Sequential(*layers)
|
51 |
+
|
52 |
+
def forward(
|
53 |
+
self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
|
54 |
+
) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
|
55 |
+
return rearrange(
|
56 |
+
self.upsample(
|
57 |
+
rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
|
58 |
+
),
|
59 |
+
"(B Np) Co Hp Wp -> B Np Co Hp Wp",
|
60 |
+
Np=3,
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
class _TruncExp(Function): # pylint: disable=abstract-method
|
65 |
+
# Implementation from torch-ngp:
|
66 |
+
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
|
67 |
+
@staticmethod
|
68 |
+
@custom_fwd(cast_inputs=torch.float32)
|
69 |
+
def forward(ctx, x): # pylint: disable=arguments-differ
|
70 |
+
ctx.save_for_backward(x)
|
71 |
+
return torch.exp(x)
|
72 |
+
|
73 |
+
@staticmethod
|
74 |
+
@custom_bwd
|
75 |
+
def backward(ctx, g): # pylint: disable=arguments-differ
|
76 |
+
x = ctx.saved_tensors[0]
|
77 |
+
return g * torch.exp(torch.clamp(x, max=15))
|
78 |
+
|
79 |
+
|
80 |
+
trunc_exp = _TruncExp.apply
|
81 |
+
|
82 |
+
|
83 |
+
def get_activation(name) -> Callable:
|
84 |
+
if name is None:
|
85 |
+
return lambda x: x
|
86 |
+
name = name.lower()
|
87 |
+
if name == "none" or name == "linear" or name == "identity":
|
88 |
+
return lambda x: x
|
89 |
+
elif name == "lin2srgb":
|
90 |
+
return lambda x: torch.where(
|
91 |
+
x > 0.0031308,
|
92 |
+
torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
|
93 |
+
12.92 * x,
|
94 |
+
).clamp(0.0, 1.0)
|
95 |
+
elif name == "exp":
|
96 |
+
return lambda x: torch.exp(x)
|
97 |
+
elif name == "shifted_exp":
|
98 |
+
return lambda x: torch.exp(x - 1.0)
|
99 |
+
elif name == "trunc_exp":
|
100 |
+
return trunc_exp
|
101 |
+
elif name == "shifted_trunc_exp":
|
102 |
+
return lambda x: trunc_exp(x - 1.0)
|
103 |
+
elif name == "sigmoid":
|
104 |
+
return lambda x: torch.sigmoid(x)
|
105 |
+
elif name == "tanh":
|
106 |
+
return lambda x: torch.tanh(x)
|
107 |
+
elif name == "shifted_softplus":
|
108 |
+
return lambda x: F.softplus(x - 1.0)
|
109 |
+
elif name == "scale_-11_01":
|
110 |
+
return lambda x: x * 0.5 + 0.5
|
111 |
+
elif name == "negative":
|
112 |
+
return lambda x: -x
|
113 |
+
elif name == "normalize_channel_last":
|
114 |
+
return lambda x: normalize(x)
|
115 |
+
elif name == "normalize_channel_first":
|
116 |
+
return lambda x: normalize(x, dim=1)
|
117 |
+
else:
|
118 |
+
try:
|
119 |
+
return getattr(F, name)
|
120 |
+
except AttributeError:
|
121 |
+
raise ValueError(f"Unknown activation function: {name}")
|
122 |
+
|
123 |
+
|
124 |
+
@dataclass
|
125 |
+
class HeadSpec:
|
126 |
+
name: str
|
127 |
+
out_channels: int
|
128 |
+
n_hidden_layers: int
|
129 |
+
output_activation: Optional[str] = None
|
130 |
+
out_bias: float = 0.0
|
131 |
+
|
132 |
+
|
133 |
+
class MaterialMLP(BaseModule):
|
134 |
+
@dataclass
|
135 |
+
class Config(BaseModule.Config):
|
136 |
+
in_channels: int = 120
|
137 |
+
n_neurons: int = 64
|
138 |
+
activation: str = "silu"
|
139 |
+
heads: List[HeadSpec] = field(default_factory=lambda: [])
|
140 |
+
|
141 |
+
cfg: Config
|
142 |
+
|
143 |
+
def configure(self) -> None:
|
144 |
+
assert len(self.cfg.heads) > 0
|
145 |
+
heads = {}
|
146 |
+
for head in self.cfg.heads:
|
147 |
+
head_layers = []
|
148 |
+
for i in range(head.n_hidden_layers):
|
149 |
+
head_layers += [
|
150 |
+
nn.Linear(
|
151 |
+
self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
|
152 |
+
self.cfg.n_neurons,
|
153 |
+
),
|
154 |
+
self.make_activation(self.cfg.activation),
|
155 |
+
]
|
156 |
+
head_layers += [
|
157 |
+
nn.Linear(
|
158 |
+
self.cfg.n_neurons,
|
159 |
+
head.out_channels,
|
160 |
+
),
|
161 |
+
]
|
162 |
+
heads[head.name] = nn.Sequential(*head_layers)
|
163 |
+
self.heads = nn.ModuleDict(heads)
|
164 |
+
|
165 |
+
def make_activation(self, activation):
|
166 |
+
if activation == "relu":
|
167 |
+
return nn.ReLU(inplace=True)
|
168 |
+
elif activation == "silu":
|
169 |
+
return nn.SiLU(inplace=True)
|
170 |
+
else:
|
171 |
+
raise NotImplementedError
|
172 |
+
|
173 |
+
def keys(self):
|
174 |
+
return self.heads.keys()
|
175 |
+
|
176 |
+
def forward(
|
177 |
+
self, x, include: Optional[List] = None, exclude: Optional[List] = None
|
178 |
+
):
|
179 |
+
if include is not None and exclude is not None:
|
180 |
+
raise ValueError("Cannot specify both include and exclude.")
|
181 |
+
if include is not None:
|
182 |
+
heads = [h for h in self.cfg.heads if h.name in include]
|
183 |
+
elif exclude is not None:
|
184 |
+
heads = [h for h in self.cfg.heads if h.name not in exclude]
|
185 |
+
else:
|
186 |
+
heads = self.cfg.heads
|
187 |
+
|
188 |
+
out = {
|
189 |
+
head.name: get_activation(head.output_activation)(
|
190 |
+
self.heads[head.name](x) + head.out_bias
|
191 |
+
)
|
192 |
+
for head in heads
|
193 |
+
}
|
194 |
+
|
195 |
+
return out
|
sf3d/models/tokenizers/dinov2.py
ADDED
@@ -0,0 +1,1196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch DINOv2 model."""
|
16 |
+
|
17 |
+
import collections.abc
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
+
from transformers.activations import ACT2FN
|
28 |
+
from transformers.modeling_outputs import (
|
29 |
+
BackboneOutput,
|
30 |
+
BaseModelOutput,
|
31 |
+
BaseModelOutputWithPooling,
|
32 |
+
ImageClassifierOutput,
|
33 |
+
)
|
34 |
+
from transformers.modeling_utils import PreTrainedModel
|
35 |
+
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
|
36 |
+
from transformers.pytorch_utils import (
|
37 |
+
find_pruneable_heads_and_indices,
|
38 |
+
prune_linear_layer,
|
39 |
+
)
|
40 |
+
from transformers.utils import (
|
41 |
+
add_code_sample_docstrings,
|
42 |
+
add_start_docstrings,
|
43 |
+
add_start_docstrings_to_model_forward,
|
44 |
+
logging,
|
45 |
+
replace_return_docstrings,
|
46 |
+
)
|
47 |
+
from transformers.utils.backbone_utils import BackboneMixin
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
# General docstring
|
52 |
+
_CONFIG_FOR_DOC = "Dinov2Config"
|
53 |
+
|
54 |
+
# Base docstring
|
55 |
+
_CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
|
56 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
|
57 |
+
|
58 |
+
# Image classification docstring
|
59 |
+
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
|
60 |
+
|
61 |
+
|
62 |
+
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
63 |
+
"facebook/dinov2-base",
|
64 |
+
# See all DINOv2 models at https://huggingface.co/models?filter=dinov2
|
65 |
+
]
|
66 |
+
|
67 |
+
|
68 |
+
class Dinov2Embeddings(nn.Module):
|
69 |
+
"""
|
70 |
+
Construct the CLS token, mask token, position and patch embeddings.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, config: Dinov2Config) -> None:
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
77 |
+
# register as mask token as it's not used in optimization
|
78 |
+
# to avoid the use of find_unused_parameters_true
|
79 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
|
80 |
+
self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
|
81 |
+
self.patch_embeddings = Dinov2PatchEmbeddings(config)
|
82 |
+
num_patches = self.patch_embeddings.num_patches
|
83 |
+
self.position_embeddings = nn.Parameter(
|
84 |
+
torch.randn(1, num_patches + 1, config.hidden_size)
|
85 |
+
)
|
86 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
87 |
+
self.config = config
|
88 |
+
|
89 |
+
def interpolate_pos_encoding(
|
90 |
+
self, embeddings: torch.Tensor, height: int, width: int
|
91 |
+
) -> torch.Tensor:
|
92 |
+
"""
|
93 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
94 |
+
resolution images.
|
95 |
+
|
96 |
+
Source:
|
97 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
98 |
+
"""
|
99 |
+
|
100 |
+
num_patches = embeddings.shape[1] - 1
|
101 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
102 |
+
if num_patches == num_positions and height == width:
|
103 |
+
return self.position_embeddings
|
104 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
105 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
106 |
+
dim = embeddings.shape[-1]
|
107 |
+
height = height // self.config.patch_size
|
108 |
+
width = width // self.config.patch_size
|
109 |
+
# we add a small number to avoid floating point error in the interpolation
|
110 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
111 |
+
height, width = height + 0.1, width + 0.1
|
112 |
+
patch_pos_embed = patch_pos_embed.reshape(
|
113 |
+
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
114 |
+
)
|
115 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
116 |
+
patch_pos_embed = nn.functional.interpolate(
|
117 |
+
patch_pos_embed,
|
118 |
+
scale_factor=(
|
119 |
+
height / math.sqrt(num_positions),
|
120 |
+
width / math.sqrt(num_positions),
|
121 |
+
),
|
122 |
+
mode="bicubic",
|
123 |
+
align_corners=False,
|
124 |
+
)
|
125 |
+
if (
|
126 |
+
int(height) != patch_pos_embed.shape[-2]
|
127 |
+
or int(width) != patch_pos_embed.shape[-1]
|
128 |
+
):
|
129 |
+
raise ValueError(
|
130 |
+
"Width or height does not match with the interpolated position embeddings"
|
131 |
+
)
|
132 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
133 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
134 |
+
|
135 |
+
def forward(
|
136 |
+
self,
|
137 |
+
pixel_values: torch.Tensor,
|
138 |
+
bool_masked_pos: Optional[torch.Tensor] = None,
|
139 |
+
) -> torch.Tensor:
|
140 |
+
batch_size, _, height, width = pixel_values.shape
|
141 |
+
patch_embeddings = self.patch_embeddings(pixel_values)
|
142 |
+
embeddings = patch_embeddings
|
143 |
+
|
144 |
+
if bool_masked_pos is not None:
|
145 |
+
embeddings = torch.where(
|
146 |
+
bool_masked_pos.unsqueeze(-1),
|
147 |
+
self.mask_token.to(embeddings.dtype).unsqueeze(0),
|
148 |
+
embeddings,
|
149 |
+
)
|
150 |
+
|
151 |
+
# add the [CLS] token to the embedded patch tokens
|
152 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
153 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
154 |
+
|
155 |
+
# add positional encoding to each token
|
156 |
+
embeddings = embeddings + self.interpolate_pos_encoding(
|
157 |
+
embeddings, height, width
|
158 |
+
)
|
159 |
+
|
160 |
+
embeddings = self.dropout(embeddings)
|
161 |
+
|
162 |
+
return embeddings
|
163 |
+
|
164 |
+
|
165 |
+
class Dinov2PatchEmbeddings(nn.Module):
|
166 |
+
"""
|
167 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
168 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
169 |
+
Transformer.
|
170 |
+
"""
|
171 |
+
|
172 |
+
def __init__(self, config):
|
173 |
+
super().__init__()
|
174 |
+
image_size, patch_size = config.image_size, config.patch_size
|
175 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
176 |
+
|
177 |
+
image_size = (
|
178 |
+
image_size
|
179 |
+
if isinstance(image_size, collections.abc.Iterable)
|
180 |
+
else (image_size, image_size)
|
181 |
+
)
|
182 |
+
patch_size = (
|
183 |
+
patch_size
|
184 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
185 |
+
else (patch_size, patch_size)
|
186 |
+
)
|
187 |
+
num_patches = (image_size[1] // patch_size[1]) * (
|
188 |
+
image_size[0] // patch_size[0]
|
189 |
+
)
|
190 |
+
self.image_size = image_size
|
191 |
+
self.patch_size = patch_size
|
192 |
+
self.num_channels = num_channels
|
193 |
+
self.num_patches = num_patches
|
194 |
+
|
195 |
+
self.projection = nn.Conv2d(
|
196 |
+
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
|
197 |
+
)
|
198 |
+
|
199 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
200 |
+
"""
|
201 |
+
num_channels = pixel_values.shape[1]
|
202 |
+
if num_channels != self.num_channels:
|
203 |
+
raise ValueError(
|
204 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
205 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
206 |
+
)
|
207 |
+
"""
|
208 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
209 |
+
return embeddings
|
210 |
+
|
211 |
+
|
212 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
|
213 |
+
class Dinov2SelfAttention(nn.Module):
|
214 |
+
def __init__(self, config: Dinov2Config) -> None:
|
215 |
+
super().__init__()
|
216 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
217 |
+
config, "embedding_size"
|
218 |
+
):
|
219 |
+
raise ValueError(
|
220 |
+
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
221 |
+
f"heads {config.num_attention_heads}."
|
222 |
+
)
|
223 |
+
|
224 |
+
self.num_attention_heads = config.num_attention_heads
|
225 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
226 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
227 |
+
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
228 |
+
|
229 |
+
self.query = nn.Linear(
|
230 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
231 |
+
)
|
232 |
+
self.key = nn.Linear(
|
233 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
234 |
+
)
|
235 |
+
self.value = nn.Linear(
|
236 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
237 |
+
)
|
238 |
+
|
239 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
240 |
+
|
241 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
242 |
+
new_x_shape = x.size()[:-1] + (
|
243 |
+
self.num_attention_heads,
|
244 |
+
self.attention_head_size,
|
245 |
+
)
|
246 |
+
x = x.view(new_x_shape)
|
247 |
+
return x.permute(0, 2, 1, 3)
|
248 |
+
|
249 |
+
def forward(
|
250 |
+
self,
|
251 |
+
hidden_states,
|
252 |
+
head_mask: Optional[torch.Tensor] = None,
|
253 |
+
output_attentions: bool = False,
|
254 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
255 |
+
mixed_query_layer = self.query(hidden_states)
|
256 |
+
|
257 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
258 |
+
assert head_mask is None and not output_attentions
|
259 |
+
new_size = hidden_states.size()[:-1] + (
|
260 |
+
self.num_attention_heads,
|
261 |
+
self.attention_head_size,
|
262 |
+
)
|
263 |
+
key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
|
264 |
+
value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
|
265 |
+
query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
|
266 |
+
context_layer = F.scaled_dot_product_attention(
|
267 |
+
query_layer,
|
268 |
+
key_layer,
|
269 |
+
value_layer,
|
270 |
+
dropout_p=self.attention_probs_dropout_prob,
|
271 |
+
is_causal=False,
|
272 |
+
)
|
273 |
+
context_layer = context_layer.transpose(1, 2).reshape(
|
274 |
+
*hidden_states.size()[:-1], -1
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
278 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
279 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
280 |
+
|
281 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
282 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
283 |
+
|
284 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
285 |
+
|
286 |
+
# Normalize the attention scores to probabilities.
|
287 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
288 |
+
|
289 |
+
# This is actually dropping out entire tokens to attend to, which might
|
290 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
291 |
+
attention_probs = self.dropout(attention_probs)
|
292 |
+
|
293 |
+
# Mask heads if we want to
|
294 |
+
if head_mask is not None:
|
295 |
+
attention_probs = attention_probs * head_mask
|
296 |
+
|
297 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
298 |
+
|
299 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
300 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
301 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
302 |
+
|
303 |
+
outputs = (
|
304 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
305 |
+
)
|
306 |
+
|
307 |
+
return outputs
|
308 |
+
|
309 |
+
|
310 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
|
311 |
+
class Dinov2SelfOutput(nn.Module):
|
312 |
+
"""
|
313 |
+
The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
|
314 |
+
layernorm applied before each block.
|
315 |
+
"""
|
316 |
+
|
317 |
+
def __init__(self, config: Dinov2Config) -> None:
|
318 |
+
super().__init__()
|
319 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
320 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
321 |
+
|
322 |
+
def forward(
|
323 |
+
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
|
324 |
+
) -> torch.Tensor:
|
325 |
+
hidden_states = self.dense(hidden_states)
|
326 |
+
hidden_states = self.dropout(hidden_states)
|
327 |
+
|
328 |
+
return hidden_states
|
329 |
+
|
330 |
+
|
331 |
+
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
|
332 |
+
class Dinov2Attention(nn.Module):
|
333 |
+
def __init__(self, config: Dinov2Config) -> None:
|
334 |
+
super().__init__()
|
335 |
+
self.attention = Dinov2SelfAttention(config)
|
336 |
+
self.output = Dinov2SelfOutput(config)
|
337 |
+
self.pruned_heads = set()
|
338 |
+
|
339 |
+
def prune_heads(self, heads: Set[int]) -> None:
|
340 |
+
if len(heads) == 0:
|
341 |
+
return
|
342 |
+
heads, index = find_pruneable_heads_and_indices(
|
343 |
+
heads,
|
344 |
+
self.attention.num_attention_heads,
|
345 |
+
self.attention.attention_head_size,
|
346 |
+
self.pruned_heads,
|
347 |
+
)
|
348 |
+
|
349 |
+
# Prune linear layers
|
350 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
351 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
352 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
353 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
354 |
+
|
355 |
+
# Update hyper params and store pruned heads
|
356 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(
|
357 |
+
heads
|
358 |
+
)
|
359 |
+
self.attention.all_head_size = (
|
360 |
+
self.attention.attention_head_size * self.attention.num_attention_heads
|
361 |
+
)
|
362 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
363 |
+
|
364 |
+
def forward(
|
365 |
+
self,
|
366 |
+
hidden_states: torch.Tensor,
|
367 |
+
head_mask: Optional[torch.Tensor] = None,
|
368 |
+
output_attentions: bool = False,
|
369 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
370 |
+
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
371 |
+
|
372 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
373 |
+
|
374 |
+
outputs = (attention_output,) + self_outputs[
|
375 |
+
1:
|
376 |
+
] # add attentions if we output them
|
377 |
+
return outputs
|
378 |
+
|
379 |
+
|
380 |
+
class Dinov2LayerScale(nn.Module):
|
381 |
+
def __init__(self, config) -> None:
|
382 |
+
super().__init__()
|
383 |
+
self.lambda1 = nn.Parameter(
|
384 |
+
config.layerscale_value * torch.ones(config.hidden_size)
|
385 |
+
)
|
386 |
+
|
387 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
388 |
+
return hidden_state * self.lambda1
|
389 |
+
|
390 |
+
|
391 |
+
# Copied from transformers.models.beit.modeling_beit.drop_path
|
392 |
+
def drop_path(
|
393 |
+
input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
|
394 |
+
) -> torch.Tensor:
|
395 |
+
"""
|
396 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
397 |
+
|
398 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
399 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
400 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
401 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
402 |
+
argument.
|
403 |
+
"""
|
404 |
+
if drop_prob == 0.0 or not training:
|
405 |
+
return input
|
406 |
+
keep_prob = 1 - drop_prob
|
407 |
+
shape = (input.shape[0],) + (1,) * (
|
408 |
+
input.ndim - 1
|
409 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
410 |
+
random_tensor = keep_prob + torch.rand(
|
411 |
+
shape, dtype=input.dtype, device=input.device
|
412 |
+
)
|
413 |
+
random_tensor.floor_() # binarize
|
414 |
+
output = input.div(keep_prob) * random_tensor
|
415 |
+
return output
|
416 |
+
|
417 |
+
|
418 |
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
|
419 |
+
class Dinov2DropPath(nn.Module):
|
420 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
421 |
+
|
422 |
+
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
423 |
+
super().__init__()
|
424 |
+
self.drop_prob = drop_prob
|
425 |
+
|
426 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
427 |
+
return drop_path(hidden_states, self.drop_prob, self.training)
|
428 |
+
|
429 |
+
def extra_repr(self) -> str:
|
430 |
+
return "p={}".format(self.drop_prob)
|
431 |
+
|
432 |
+
|
433 |
+
class Dinov2MLP(nn.Module):
|
434 |
+
def __init__(self, config) -> None:
|
435 |
+
super().__init__()
|
436 |
+
in_features = out_features = config.hidden_size
|
437 |
+
hidden_features = int(config.hidden_size * config.mlp_ratio)
|
438 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
|
439 |
+
if isinstance(config.hidden_act, str):
|
440 |
+
self.activation = ACT2FN[config.hidden_act]
|
441 |
+
else:
|
442 |
+
self.activation = config.hidden_act
|
443 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
|
444 |
+
|
445 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
446 |
+
hidden_state = self.fc1(hidden_state)
|
447 |
+
hidden_state = self.activation(hidden_state)
|
448 |
+
hidden_state = self.fc2(hidden_state)
|
449 |
+
return hidden_state
|
450 |
+
|
451 |
+
|
452 |
+
class Dinov2SwiGLUFFN(nn.Module):
|
453 |
+
def __init__(self, config) -> None:
|
454 |
+
super().__init__()
|
455 |
+
in_features = out_features = config.hidden_size
|
456 |
+
hidden_features = int(config.hidden_size * config.mlp_ratio)
|
457 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
458 |
+
|
459 |
+
self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
|
460 |
+
self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
|
461 |
+
|
462 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
463 |
+
hidden_state = self.weights_in(hidden_state)
|
464 |
+
x1, x2 = hidden_state.chunk(2, dim=-1)
|
465 |
+
hidden = nn.functional.silu(x1) * x2
|
466 |
+
return self.weights_out(hidden)
|
467 |
+
|
468 |
+
|
469 |
+
class Dinov2Layer(nn.Module):
|
470 |
+
"""This corresponds to the Block class in the original implementation."""
|
471 |
+
|
472 |
+
def __init__(self, config: Dinov2Config) -> None:
|
473 |
+
super().__init__()
|
474 |
+
|
475 |
+
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
476 |
+
self.norm1_modulation = None
|
477 |
+
self.attention = Dinov2Attention(config)
|
478 |
+
self.layer_scale1 = Dinov2LayerScale(config)
|
479 |
+
self.drop_path1 = (
|
480 |
+
Dinov2DropPath(config.drop_path_rate)
|
481 |
+
if config.drop_path_rate > 0.0
|
482 |
+
else nn.Identity()
|
483 |
+
)
|
484 |
+
|
485 |
+
self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
486 |
+
self.norm2_modulation = None
|
487 |
+
|
488 |
+
if config.use_swiglu_ffn:
|
489 |
+
self.mlp = Dinov2SwiGLUFFN(config)
|
490 |
+
else:
|
491 |
+
self.mlp = Dinov2MLP(config)
|
492 |
+
self.layer_scale2 = Dinov2LayerScale(config)
|
493 |
+
self.drop_path2 = (
|
494 |
+
Dinov2DropPath(config.drop_path_rate)
|
495 |
+
if config.drop_path_rate > 0.0
|
496 |
+
else nn.Identity()
|
497 |
+
)
|
498 |
+
|
499 |
+
def forward(
|
500 |
+
self,
|
501 |
+
hidden_states: torch.Tensor,
|
502 |
+
head_mask: Optional[torch.Tensor] = None,
|
503 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
504 |
+
output_attentions: bool = False,
|
505 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
506 |
+
hidden_states_norm = self.norm1(hidden_states)
|
507 |
+
if self.norm1_modulation is not None:
|
508 |
+
assert modulation_cond is not None
|
509 |
+
hidden_states_norm = self.norm1_modulation(
|
510 |
+
hidden_states_norm, modulation_cond
|
511 |
+
)
|
512 |
+
self_attention_outputs = self.attention(
|
513 |
+
hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
|
514 |
+
head_mask,
|
515 |
+
output_attentions=output_attentions,
|
516 |
+
)
|
517 |
+
attention_output = self_attention_outputs[0]
|
518 |
+
|
519 |
+
attention_output = self.layer_scale1(attention_output)
|
520 |
+
outputs = self_attention_outputs[
|
521 |
+
1:
|
522 |
+
] # add self attentions if we output attention weights
|
523 |
+
|
524 |
+
# first residual connection
|
525 |
+
hidden_states = attention_output + hidden_states
|
526 |
+
|
527 |
+
# in Dinov2, layernorm is also applied after self-attention
|
528 |
+
layer_output = self.norm2(hidden_states)
|
529 |
+
if self.norm2_modulation is not None:
|
530 |
+
assert modulation_cond is not None
|
531 |
+
layer_output = self.norm2_modulation(layer_output, modulation_cond)
|
532 |
+
layer_output = self.mlp(layer_output)
|
533 |
+
layer_output = self.layer_scale2(layer_output)
|
534 |
+
|
535 |
+
# second residual connection
|
536 |
+
layer_output = layer_output + hidden_states
|
537 |
+
|
538 |
+
outputs = (layer_output,) + outputs
|
539 |
+
|
540 |
+
return outputs
|
541 |
+
|
542 |
+
def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
|
543 |
+
self.norm1_modulation = norm1_mod
|
544 |
+
self.norm2_modulation = norm2_mod
|
545 |
+
|
546 |
+
|
547 |
+
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
|
548 |
+
class Dinov2Encoder(nn.Module):
|
549 |
+
def __init__(self, config: Dinov2Config) -> None:
|
550 |
+
super().__init__()
|
551 |
+
self.config = config
|
552 |
+
self.layer = nn.ModuleList(
|
553 |
+
[Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
|
554 |
+
)
|
555 |
+
self.gradient_checkpointing = False
|
556 |
+
|
557 |
+
def forward(
|
558 |
+
self,
|
559 |
+
hidden_states: torch.Tensor,
|
560 |
+
head_mask: Optional[torch.Tensor] = None,
|
561 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
562 |
+
output_attentions: bool = False,
|
563 |
+
output_hidden_states: bool = False,
|
564 |
+
return_dict: bool = True,
|
565 |
+
) -> Union[tuple, BaseModelOutput]:
|
566 |
+
all_hidden_states = () if output_hidden_states else None
|
567 |
+
all_self_attentions = () if output_attentions else None
|
568 |
+
|
569 |
+
for i, layer_module in enumerate(self.layer):
|
570 |
+
if output_hidden_states:
|
571 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
572 |
+
|
573 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
574 |
+
|
575 |
+
if self.gradient_checkpointing and self.training:
|
576 |
+
|
577 |
+
def create_custom_forward(module):
|
578 |
+
def custom_forward(*inputs):
|
579 |
+
return module(*inputs, output_attentions)
|
580 |
+
|
581 |
+
return custom_forward
|
582 |
+
|
583 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
584 |
+
create_custom_forward(layer_module),
|
585 |
+
hidden_states,
|
586 |
+
layer_head_mask,
|
587 |
+
modulation_cond,
|
588 |
+
use_reentrant=False,
|
589 |
+
)
|
590 |
+
else:
|
591 |
+
layer_outputs = layer_module(
|
592 |
+
hidden_states, layer_head_mask, modulation_cond, output_attentions
|
593 |
+
)
|
594 |
+
|
595 |
+
hidden_states = layer_outputs[0]
|
596 |
+
|
597 |
+
if output_attentions:
|
598 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
599 |
+
|
600 |
+
if output_hidden_states:
|
601 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
602 |
+
|
603 |
+
if not return_dict:
|
604 |
+
return tuple(
|
605 |
+
v
|
606 |
+
for v in [hidden_states, all_hidden_states, all_self_attentions]
|
607 |
+
if v is not None
|
608 |
+
)
|
609 |
+
return BaseModelOutput(
|
610 |
+
last_hidden_state=hidden_states,
|
611 |
+
hidden_states=all_hidden_states,
|
612 |
+
attentions=all_self_attentions,
|
613 |
+
)
|
614 |
+
|
615 |
+
|
616 |
+
class Dinov2PreTrainedModel(PreTrainedModel):
|
617 |
+
"""
|
618 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
619 |
+
models.
|
620 |
+
"""
|
621 |
+
|
622 |
+
config_class = Dinov2Config
|
623 |
+
base_model_prefix = "dinov2"
|
624 |
+
main_input_name = "pixel_values"
|
625 |
+
supports_gradient_checkpointing = True
|
626 |
+
|
627 |
+
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
628 |
+
"""Initialize the weights"""
|
629 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
630 |
+
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
631 |
+
# `trunc_normal_cpu` not implemented in `half` issues
|
632 |
+
module.weight.data = nn.init.trunc_normal_(
|
633 |
+
module.weight.data.to(torch.float32),
|
634 |
+
mean=0.0,
|
635 |
+
std=self.config.initializer_range,
|
636 |
+
).to(module.weight.dtype)
|
637 |
+
if module.bias is not None:
|
638 |
+
module.bias.data.zero_()
|
639 |
+
elif isinstance(module, nn.LayerNorm):
|
640 |
+
module.bias.data.zero_()
|
641 |
+
module.weight.data.fill_(1.0)
|
642 |
+
elif isinstance(module, Dinov2Embeddings):
|
643 |
+
module.position_embeddings.data = nn.init.trunc_normal_(
|
644 |
+
module.position_embeddings.data.to(torch.float32),
|
645 |
+
mean=0.0,
|
646 |
+
std=self.config.initializer_range,
|
647 |
+
).to(module.position_embeddings.dtype)
|
648 |
+
|
649 |
+
module.cls_token.data = nn.init.trunc_normal_(
|
650 |
+
module.cls_token.data.to(torch.float32),
|
651 |
+
mean=0.0,
|
652 |
+
std=self.config.initializer_range,
|
653 |
+
).to(module.cls_token.dtype)
|
654 |
+
|
655 |
+
def _set_gradient_checkpointing(
|
656 |
+
self, module: Dinov2Encoder, value: bool = False
|
657 |
+
) -> None:
|
658 |
+
if isinstance(module, Dinov2Encoder):
|
659 |
+
module.gradient_checkpointing = value
|
660 |
+
|
661 |
+
|
662 |
+
DINOV2_START_DOCSTRING = r"""
|
663 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
664 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
665 |
+
behavior.
|
666 |
+
|
667 |
+
Parameters:
|
668 |
+
config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
|
669 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
670 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
671 |
+
"""
|
672 |
+
|
673 |
+
DINOV2_BASE_INPUTS_DOCSTRING = r"""
|
674 |
+
Args:
|
675 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
676 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
677 |
+
[`BitImageProcessor.preprocess`] for details.
|
678 |
+
|
679 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
|
680 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
|
681 |
+
pre-training.
|
682 |
+
|
683 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
684 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
685 |
+
|
686 |
+
- 1 indicates the head is **not masked**,
|
687 |
+
- 0 indicates the head is **masked**.
|
688 |
+
|
689 |
+
output_attentions (`bool`, *optional*):
|
690 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
691 |
+
tensors for more detail.
|
692 |
+
output_hidden_states (`bool`, *optional*):
|
693 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
694 |
+
more detail.
|
695 |
+
return_dict (`bool`, *optional*):
|
696 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
697 |
+
"""
|
698 |
+
|
699 |
+
DINOV2_INPUTS_DOCSTRING = r"""
|
700 |
+
Args:
|
701 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
702 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
703 |
+
[`BitImageProcessor.preprocess`] for details.
|
704 |
+
|
705 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
706 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
707 |
+
|
708 |
+
- 1 indicates the head is **not masked**,
|
709 |
+
- 0 indicates the head is **masked**.
|
710 |
+
|
711 |
+
output_attentions (`bool`, *optional*):
|
712 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
713 |
+
tensors for more detail.
|
714 |
+
output_hidden_states (`bool`, *optional*):
|
715 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
716 |
+
more detail.
|
717 |
+
return_dict (`bool`, *optional*):
|
718 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
719 |
+
"""
|
720 |
+
|
721 |
+
|
722 |
+
@dataclass
|
723 |
+
class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
|
724 |
+
patch_embeddings: Optional[torch.FloatTensor] = None
|
725 |
+
|
726 |
+
|
727 |
+
@add_start_docstrings(
|
728 |
+
"The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
|
729 |
+
DINOV2_START_DOCSTRING,
|
730 |
+
)
|
731 |
+
class Dinov2Model(Dinov2PreTrainedModel):
|
732 |
+
def __init__(self, config: Dinov2Config):
|
733 |
+
super().__init__(config)
|
734 |
+
self.config = config
|
735 |
+
|
736 |
+
self.embeddings = Dinov2Embeddings(config)
|
737 |
+
self.encoder = Dinov2Encoder(config)
|
738 |
+
|
739 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
740 |
+
|
741 |
+
# Initialize weights and apply final processing
|
742 |
+
self.post_init()
|
743 |
+
|
744 |
+
def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
|
745 |
+
return self.embeddings.patch_embeddings
|
746 |
+
|
747 |
+
def expand_input_channels(self, extra_input_channels: int) -> None:
|
748 |
+
if extra_input_channels == 0:
|
749 |
+
return
|
750 |
+
conv_old = self.embeddings.patch_embeddings.projection
|
751 |
+
conv_new = nn.Conv2d(
|
752 |
+
self.config.num_channels + extra_input_channels,
|
753 |
+
self.config.hidden_size,
|
754 |
+
kernel_size=self.config.patch_size,
|
755 |
+
stride=self.config.patch_size,
|
756 |
+
).to(self.device)
|
757 |
+
with torch.no_grad():
|
758 |
+
conv_new.weight[:, :3] = conv_old.weight
|
759 |
+
conv_new.bias = conv_old.bias
|
760 |
+
self.embeddings.patch_embeddings.projection = conv_new
|
761 |
+
del conv_old
|
762 |
+
|
763 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
764 |
+
"""
|
765 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
766 |
+
class PreTrainedModel
|
767 |
+
"""
|
768 |
+
for layer, heads in heads_to_prune.items():
|
769 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
770 |
+
|
771 |
+
@add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
|
772 |
+
@add_code_sample_docstrings(
|
773 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
774 |
+
output_type=BaseModelOutputWithPooling,
|
775 |
+
config_class=_CONFIG_FOR_DOC,
|
776 |
+
modality="vision",
|
777 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
778 |
+
)
|
779 |
+
def forward(
|
780 |
+
self,
|
781 |
+
pixel_values: Optional[torch.Tensor] = None,
|
782 |
+
bool_masked_pos: Optional[torch.Tensor] = None,
|
783 |
+
head_mask: Optional[torch.Tensor] = None,
|
784 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
785 |
+
output_attentions: Optional[bool] = None,
|
786 |
+
output_hidden_states: Optional[bool] = None,
|
787 |
+
return_dict: Optional[bool] = None,
|
788 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
789 |
+
output_attentions = (
|
790 |
+
output_attentions
|
791 |
+
if output_attentions is not None
|
792 |
+
else self.config.output_attentions
|
793 |
+
)
|
794 |
+
output_hidden_states = (
|
795 |
+
output_hidden_states
|
796 |
+
if output_hidden_states is not None
|
797 |
+
else self.config.output_hidden_states
|
798 |
+
)
|
799 |
+
return_dict = (
|
800 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
801 |
+
)
|
802 |
+
|
803 |
+
if pixel_values is None:
|
804 |
+
raise ValueError("You have to specify pixel_values")
|
805 |
+
|
806 |
+
# Prepare head mask if needed
|
807 |
+
# 1.0 in head_mask indicate we keep the head
|
808 |
+
# attention_probs has shape bsz x n_heads x N x N
|
809 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
810 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
811 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
812 |
+
|
813 |
+
embedding_output = self.embeddings(
|
814 |
+
pixel_values, bool_masked_pos=bool_masked_pos
|
815 |
+
)
|
816 |
+
|
817 |
+
encoder_outputs = self.encoder(
|
818 |
+
embedding_output,
|
819 |
+
head_mask=head_mask,
|
820 |
+
modulation_cond=modulation_cond,
|
821 |
+
output_attentions=output_attentions,
|
822 |
+
output_hidden_states=output_hidden_states,
|
823 |
+
return_dict=return_dict,
|
824 |
+
)
|
825 |
+
sequence_output = encoder_outputs[0]
|
826 |
+
sequence_output = self.layernorm(sequence_output)
|
827 |
+
pooled_output = sequence_output[:, 0, :]
|
828 |
+
|
829 |
+
if not return_dict:
|
830 |
+
head_outputs = (sequence_output, pooled_output)
|
831 |
+
return head_outputs + encoder_outputs[1:]
|
832 |
+
|
833 |
+
return CustomBaseModelOutputWithPooling(
|
834 |
+
last_hidden_state=sequence_output,
|
835 |
+
pooler_output=pooled_output,
|
836 |
+
hidden_states=encoder_outputs.hidden_states,
|
837 |
+
attentions=encoder_outputs.attentions,
|
838 |
+
patch_embeddings=embedding_output,
|
839 |
+
)
|
840 |
+
|
841 |
+
def set_gradient_checkpointing(self, value: bool = False) -> None:
|
842 |
+
self._set_gradient_checkpointing(self.encoder, value)
|
843 |
+
|
844 |
+
|
845 |
+
@add_start_docstrings(
|
846 |
+
"""
|
847 |
+
Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
|
848 |
+
of the [CLS] token) e.g. for ImageNet.
|
849 |
+
""",
|
850 |
+
DINOV2_START_DOCSTRING,
|
851 |
+
)
|
852 |
+
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
|
853 |
+
def __init__(self, config: Dinov2Config) -> None:
|
854 |
+
super().__init__(config)
|
855 |
+
|
856 |
+
self.num_labels = config.num_labels
|
857 |
+
self.dinov2 = Dinov2Model(config)
|
858 |
+
|
859 |
+
# Classifier head
|
860 |
+
self.classifier = (
|
861 |
+
nn.Linear(config.hidden_size * 2, config.num_labels)
|
862 |
+
if config.num_labels > 0
|
863 |
+
else nn.Identity()
|
864 |
+
)
|
865 |
+
|
866 |
+
# Initialize weights and apply final processing
|
867 |
+
self.post_init()
|
868 |
+
|
869 |
+
@add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
|
870 |
+
@add_code_sample_docstrings(
|
871 |
+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
872 |
+
output_type=ImageClassifierOutput,
|
873 |
+
config_class=_CONFIG_FOR_DOC,
|
874 |
+
)
|
875 |
+
def forward(
|
876 |
+
self,
|
877 |
+
pixel_values: Optional[torch.Tensor] = None,
|
878 |
+
head_mask: Optional[torch.Tensor] = None,
|
879 |
+
labels: Optional[torch.Tensor] = None,
|
880 |
+
output_attentions: Optional[bool] = None,
|
881 |
+
output_hidden_states: Optional[bool] = None,
|
882 |
+
return_dict: Optional[bool] = None,
|
883 |
+
) -> Union[tuple, ImageClassifierOutput]:
|
884 |
+
r"""
|
885 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
886 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
887 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
888 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
889 |
+
"""
|
890 |
+
return_dict = (
|
891 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
892 |
+
)
|
893 |
+
|
894 |
+
outputs = self.dinov2(
|
895 |
+
pixel_values,
|
896 |
+
head_mask=head_mask,
|
897 |
+
output_attentions=output_attentions,
|
898 |
+
output_hidden_states=output_hidden_states,
|
899 |
+
return_dict=return_dict,
|
900 |
+
)
|
901 |
+
|
902 |
+
sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
|
903 |
+
|
904 |
+
cls_token = sequence_output[:, 0]
|
905 |
+
patch_tokens = sequence_output[:, 1:]
|
906 |
+
|
907 |
+
linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
|
908 |
+
|
909 |
+
logits = self.classifier(linear_input)
|
910 |
+
|
911 |
+
loss = None
|
912 |
+
if labels is not None:
|
913 |
+
# move labels to correct device to enable model parallelism
|
914 |
+
labels = labels.to(logits.device)
|
915 |
+
if self.config.problem_type is None:
|
916 |
+
if self.num_labels == 1:
|
917 |
+
self.config.problem_type = "regression"
|
918 |
+
elif self.num_labels > 1 and (
|
919 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
920 |
+
):
|
921 |
+
self.config.problem_type = "single_label_classification"
|
922 |
+
else:
|
923 |
+
self.config.problem_type = "multi_label_classification"
|
924 |
+
|
925 |
+
if self.config.problem_type == "regression":
|
926 |
+
loss_fct = MSELoss()
|
927 |
+
if self.num_labels == 1:
|
928 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
929 |
+
else:
|
930 |
+
loss = loss_fct(logits, labels)
|
931 |
+
elif self.config.problem_type == "single_label_classification":
|
932 |
+
loss_fct = CrossEntropyLoss()
|
933 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
934 |
+
elif self.config.problem_type == "multi_label_classification":
|
935 |
+
loss_fct = BCEWithLogitsLoss()
|
936 |
+
loss = loss_fct(logits, labels)
|
937 |
+
|
938 |
+
if not return_dict:
|
939 |
+
output = (logits,) + outputs[2:]
|
940 |
+
return ((loss,) + output) if loss is not None else output
|
941 |
+
|
942 |
+
return ImageClassifierOutput(
|
943 |
+
loss=loss,
|
944 |
+
logits=logits,
|
945 |
+
hidden_states=outputs.hidden_states,
|
946 |
+
attentions=outputs.attentions,
|
947 |
+
)
|
948 |
+
|
949 |
+
|
950 |
+
@add_start_docstrings(
|
951 |
+
"""
|
952 |
+
Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
|
953 |
+
""",
|
954 |
+
DINOV2_START_DOCSTRING,
|
955 |
+
)
|
956 |
+
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
|
957 |
+
def __init__(self, config):
|
958 |
+
super().__init__(config)
|
959 |
+
super()._init_backbone(config)
|
960 |
+
|
961 |
+
self.num_features = [
|
962 |
+
config.hidden_size for _ in range(config.num_hidden_layers + 1)
|
963 |
+
]
|
964 |
+
self.embeddings = Dinov2Embeddings(config)
|
965 |
+
self.encoder = Dinov2Encoder(config)
|
966 |
+
|
967 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
968 |
+
|
969 |
+
# Initialize weights and apply final processing
|
970 |
+
self.post_init()
|
971 |
+
|
972 |
+
def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
|
973 |
+
return self.embeddings.patch_embeddings
|
974 |
+
|
975 |
+
@add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
|
976 |
+
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
|
977 |
+
def forward(
|
978 |
+
self,
|
979 |
+
pixel_values: torch.Tensor,
|
980 |
+
output_hidden_states: Optional[bool] = None,
|
981 |
+
output_attentions: Optional[bool] = None,
|
982 |
+
return_dict: Optional[bool] = None,
|
983 |
+
) -> BackboneOutput:
|
984 |
+
"""
|
985 |
+
Returns:
|
986 |
+
|
987 |
+
Examples:
|
988 |
+
|
989 |
+
```python
|
990 |
+
>>> from transformers import AutoImageProcessor, AutoBackbone
|
991 |
+
>>> import torch
|
992 |
+
>>> from PIL import Image
|
993 |
+
>>> import requests
|
994 |
+
|
995 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
996 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
997 |
+
|
998 |
+
>>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
|
999 |
+
>>> model = AutoBackbone.from_pretrained(
|
1000 |
+
... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
|
1001 |
+
... )
|
1002 |
+
|
1003 |
+
>>> inputs = processor(image, return_tensors="pt")
|
1004 |
+
|
1005 |
+
>>> outputs = model(**inputs)
|
1006 |
+
>>> feature_maps = outputs.feature_maps
|
1007 |
+
>>> list(feature_maps[-1].shape)
|
1008 |
+
[1, 768, 16, 16]
|
1009 |
+
```"""
|
1010 |
+
return_dict = (
|
1011 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1012 |
+
)
|
1013 |
+
output_hidden_states = (
|
1014 |
+
output_hidden_states
|
1015 |
+
if output_hidden_states is not None
|
1016 |
+
else self.config.output_hidden_states
|
1017 |
+
)
|
1018 |
+
output_attentions = (
|
1019 |
+
output_attentions
|
1020 |
+
if output_attentions is not None
|
1021 |
+
else self.config.output_attentions
|
1022 |
+
)
|
1023 |
+
|
1024 |
+
embedding_output = self.embeddings(pixel_values)
|
1025 |
+
|
1026 |
+
outputs = self.encoder(
|
1027 |
+
embedding_output,
|
1028 |
+
output_hidden_states=True,
|
1029 |
+
output_attentions=output_attentions,
|
1030 |
+
return_dict=return_dict,
|
1031 |
+
)
|
1032 |
+
|
1033 |
+
hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
1034 |
+
|
1035 |
+
feature_maps = ()
|
1036 |
+
for stage, hidden_state in zip(self.stage_names, hidden_states):
|
1037 |
+
if stage in self.out_features:
|
1038 |
+
if self.config.apply_layernorm:
|
1039 |
+
hidden_state = self.layernorm(hidden_state)
|
1040 |
+
if self.config.reshape_hidden_states:
|
1041 |
+
batch_size, _, height, width = pixel_values.shape
|
1042 |
+
patch_size = self.config.patch_size
|
1043 |
+
hidden_state = hidden_state[:, 1:, :].reshape(
|
1044 |
+
batch_size, width // patch_size, height // patch_size, -1
|
1045 |
+
)
|
1046 |
+
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
|
1047 |
+
feature_maps += (hidden_state,)
|
1048 |
+
|
1049 |
+
if not return_dict:
|
1050 |
+
if output_hidden_states:
|
1051 |
+
output = (feature_maps,) + outputs[1:]
|
1052 |
+
else:
|
1053 |
+
output = (feature_maps,) + outputs[2:]
|
1054 |
+
return output
|
1055 |
+
|
1056 |
+
return BackboneOutput(
|
1057 |
+
feature_maps=feature_maps,
|
1058 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
1059 |
+
attentions=outputs.attentions if output_attentions else None,
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
|
1063 |
+
class CustomPatchEmbeddings(nn.Module):
|
1064 |
+
"""
|
1065 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
1066 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
1067 |
+
Transformer.
|
1068 |
+
"""
|
1069 |
+
|
1070 |
+
def __init__(
|
1071 |
+
self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
|
1072 |
+
):
|
1073 |
+
super().__init__()
|
1074 |
+
|
1075 |
+
image_size = (
|
1076 |
+
image_size
|
1077 |
+
if isinstance(image_size, collections.abc.Iterable)
|
1078 |
+
else (image_size, image_size)
|
1079 |
+
)
|
1080 |
+
patch_size = (
|
1081 |
+
patch_size
|
1082 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
1083 |
+
else (patch_size, patch_size)
|
1084 |
+
)
|
1085 |
+
num_patches = (image_size[1] // patch_size[1]) * (
|
1086 |
+
image_size[0] // patch_size[0]
|
1087 |
+
)
|
1088 |
+
self.image_size = image_size
|
1089 |
+
self.patch_size = patch_size
|
1090 |
+
self.num_channels = num_channels
|
1091 |
+
self.num_patches = num_patches
|
1092 |
+
|
1093 |
+
self.projection = nn.Conv2d(
|
1094 |
+
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
1098 |
+
num_channels = pixel_values.shape[1]
|
1099 |
+
if num_channels != self.num_channels:
|
1100 |
+
raise ValueError(
|
1101 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
1102 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
1103 |
+
)
|
1104 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
1105 |
+
return embeddings
|
1106 |
+
|
1107 |
+
|
1108 |
+
class CustomEmbeddings(nn.Module):
|
1109 |
+
"""
|
1110 |
+
Construct the CLS token, mask token, position and patch embeddings.
|
1111 |
+
"""
|
1112 |
+
|
1113 |
+
def __init__(
|
1114 |
+
self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
|
1115 |
+
) -> None:
|
1116 |
+
super().__init__()
|
1117 |
+
|
1118 |
+
self.image_size = image_size
|
1119 |
+
self.patch_size = patch_size
|
1120 |
+
self.num_channels = num_channels
|
1121 |
+
self.hidden_size = hidden_size
|
1122 |
+
|
1123 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
|
1124 |
+
|
1125 |
+
self.patch_embeddings = CustomPatchEmbeddings(
|
1126 |
+
image_size, patch_size, num_channels, hidden_size
|
1127 |
+
)
|
1128 |
+
num_patches = self.patch_embeddings.num_patches
|
1129 |
+
self.position_embeddings = nn.Parameter(
|
1130 |
+
torch.randn(1, num_patches + 1, self.hidden_size)
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
def interpolate_pos_encoding(
|
1134 |
+
self, embeddings: torch.Tensor, height: int, width: int
|
1135 |
+
) -> torch.Tensor:
|
1136 |
+
"""
|
1137 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
1138 |
+
resolution images.
|
1139 |
+
|
1140 |
+
Source:
|
1141 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
1142 |
+
"""
|
1143 |
+
|
1144 |
+
num_patches = embeddings.shape[1] - 1
|
1145 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
1146 |
+
if num_patches == num_positions and height == width:
|
1147 |
+
return self.position_embeddings
|
1148 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
1149 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
1150 |
+
dim = embeddings.shape[-1]
|
1151 |
+
height = height // self.patch_size
|
1152 |
+
width = width // self.patch_size
|
1153 |
+
# we add a small number to avoid floating point error in the interpolation
|
1154 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
1155 |
+
height, width = height + 0.1, width + 0.1
|
1156 |
+
patch_pos_embed = patch_pos_embed.reshape(
|
1157 |
+
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
1158 |
+
)
|
1159 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
1160 |
+
patch_pos_embed = nn.functional.interpolate(
|
1161 |
+
patch_pos_embed,
|
1162 |
+
scale_factor=(
|
1163 |
+
height / math.sqrt(num_positions),
|
1164 |
+
width / math.sqrt(num_positions),
|
1165 |
+
),
|
1166 |
+
mode="bicubic",
|
1167 |
+
align_corners=False,
|
1168 |
+
)
|
1169 |
+
if (
|
1170 |
+
int(height) != patch_pos_embed.shape[-2]
|
1171 |
+
or int(width) != patch_pos_embed.shape[-1]
|
1172 |
+
):
|
1173 |
+
raise ValueError(
|
1174 |
+
"Width or height does not match with the interpolated position embeddings"
|
1175 |
+
)
|
1176 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
1177 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
1178 |
+
|
1179 |
+
def forward(
|
1180 |
+
self,
|
1181 |
+
pixel_values: torch.Tensor,
|
1182 |
+
) -> torch.Tensor:
|
1183 |
+
batch_size, _, height, width = pixel_values.shape
|
1184 |
+
patch_embeddings = self.patch_embeddings(pixel_values)
|
1185 |
+
embeddings = patch_embeddings
|
1186 |
+
|
1187 |
+
# add the [CLS] token to the embedded patch tokens
|
1188 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
1189 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
1190 |
+
|
1191 |
+
# add positional encoding to each token
|
1192 |
+
embeddings = embeddings + self.interpolate_pos_encoding(
|
1193 |
+
embeddings, height, width
|
1194 |
+
)
|
1195 |
+
|
1196 |
+
return embeddings
|
sf3d/models/tokenizers/image.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange
|
7 |
+
from jaxtyping import Float
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from sf3d.models.tokenizers.dinov2 import Dinov2Model
|
11 |
+
from sf3d.models.transformers.attention import Modulation
|
12 |
+
from sf3d.models.utils import BaseModule
|
13 |
+
|
14 |
+
|
15 |
+
class DINOV2SingleImageTokenizer(BaseModule):
|
16 |
+
@dataclass
|
17 |
+
class Config(BaseModule.Config):
|
18 |
+
pretrained_model_name_or_path: str = "facebook/dinov2-large"
|
19 |
+
width: int = 512
|
20 |
+
height: int = 512
|
21 |
+
modulation_cond_dim: int = 768
|
22 |
+
|
23 |
+
cfg: Config
|
24 |
+
|
25 |
+
def configure(self) -> None:
|
26 |
+
self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)
|
27 |
+
|
28 |
+
for p in self.model.parameters():
|
29 |
+
p.requires_grad_(False)
|
30 |
+
self.model.eval()
|
31 |
+
|
32 |
+
self.model.set_gradient_checkpointing(False)
|
33 |
+
|
34 |
+
# add modulation
|
35 |
+
modulations = []
|
36 |
+
for layer in self.model.encoder.layer:
|
37 |
+
norm1_modulation = Modulation(
|
38 |
+
self.model.config.hidden_size,
|
39 |
+
self.cfg.modulation_cond_dim,
|
40 |
+
zero_init=True,
|
41 |
+
single_layer=True,
|
42 |
+
)
|
43 |
+
norm2_modulation = Modulation(
|
44 |
+
self.model.config.hidden_size,
|
45 |
+
self.cfg.modulation_cond_dim,
|
46 |
+
zero_init=True,
|
47 |
+
single_layer=True,
|
48 |
+
)
|
49 |
+
layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
|
50 |
+
modulations += [norm1_modulation, norm2_modulation]
|
51 |
+
self.modulations = nn.ModuleList(modulations)
|
52 |
+
|
53 |
+
self.register_buffer(
|
54 |
+
"image_mean",
|
55 |
+
torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
|
56 |
+
persistent=False,
|
57 |
+
)
|
58 |
+
self.register_buffer(
|
59 |
+
"image_std",
|
60 |
+
torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
|
61 |
+
persistent=False,
|
62 |
+
)
|
63 |
+
|
64 |
+
def forward(
|
65 |
+
self,
|
66 |
+
images: Float[Tensor, "B *N C H W"],
|
67 |
+
modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
|
68 |
+
**kwargs,
|
69 |
+
) -> Float[Tensor, "B *N Ct Nt"]:
|
70 |
+
model = self.model
|
71 |
+
|
72 |
+
packed = False
|
73 |
+
if images.ndim == 4:
|
74 |
+
packed = True
|
75 |
+
images = images.unsqueeze(1)
|
76 |
+
if modulation_cond is not None:
|
77 |
+
assert modulation_cond.ndim == 2
|
78 |
+
modulation_cond = modulation_cond.unsqueeze(1)
|
79 |
+
|
80 |
+
batch_size, n_input_views = images.shape[:2]
|
81 |
+
images = (images - self.image_mean) / self.image_std
|
82 |
+
out = model(
|
83 |
+
rearrange(images, "B N C H W -> (B N) C H W"),
|
84 |
+
modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
|
85 |
+
if modulation_cond is not None
|
86 |
+
else None,
|
87 |
+
)
|
88 |
+
local_features = out.last_hidden_state
|
89 |
+
local_features = local_features.permute(0, 2, 1)
|
90 |
+
local_features = rearrange(
|
91 |
+
local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
|
92 |
+
)
|
93 |
+
if packed:
|
94 |
+
local_features = local_features.squeeze(1)
|
95 |
+
|
96 |
+
return local_features
|
97 |
+
|
98 |
+
def detokenize(self, *args, **kwargs):
|
99 |
+
raise NotImplementedError
|
sf3d/models/tokenizers/triplane.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from jaxtyping import Float
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from sf3d.models.utils import BaseModule
|
11 |
+
|
12 |
+
|
13 |
+
class TriplaneLearnablePositionalEmbedding(BaseModule):
|
14 |
+
@dataclass
|
15 |
+
class Config(BaseModule.Config):
|
16 |
+
plane_size: int = 96
|
17 |
+
num_channels: int = 1024
|
18 |
+
|
19 |
+
cfg: Config
|
20 |
+
|
21 |
+
def configure(self) -> None:
|
22 |
+
self.embeddings = nn.Parameter(
|
23 |
+
torch.randn(
|
24 |
+
(3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
|
25 |
+
dtype=torch.float32,
|
26 |
+
)
|
27 |
+
* 1
|
28 |
+
/ math.sqrt(self.cfg.num_channels)
|
29 |
+
)
|
30 |
+
|
31 |
+
def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
|
32 |
+
return rearrange(
|
33 |
+
repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
|
34 |
+
"B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
|
35 |
+
)
|
36 |
+
|
37 |
+
def detokenize(
|
38 |
+
self, tokens: Float[Tensor, "B Ct Nt"]
|
39 |
+
) -> Float[Tensor, "B 3 Ct Hp Wp"]:
|
40 |
+
batch_size, Ct, Nt = tokens.shape
|
41 |
+
assert Nt == self.cfg.plane_size**2 * 3
|
42 |
+
assert Ct == self.cfg.num_channels
|
43 |
+
return rearrange(
|
44 |
+
tokens,
|
45 |
+
"B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
|
46 |
+
Np=3,
|
47 |
+
Hp=self.cfg.plane_size,
|
48 |
+
Wp=self.cfg.plane_size,
|
49 |
+
)
|
sf3d/models/transformers/attention.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class Modulation(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
embedding_dim: int,
|
9 |
+
condition_dim: int,
|
10 |
+
zero_init: bool = False,
|
11 |
+
single_layer: bool = False,
|
12 |
+
):
|
13 |
+
super().__init__()
|
14 |
+
self.silu = nn.SiLU()
|
15 |
+
if single_layer:
|
16 |
+
self.linear1 = nn.Identity()
|
17 |
+
else:
|
18 |
+
self.linear1 = nn.Linear(condition_dim, condition_dim)
|
19 |
+
|
20 |
+
self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)
|
21 |
+
|
22 |
+
# Only zero init the last linear layer
|
23 |
+
if zero_init:
|
24 |
+
nn.init.zeros_(self.linear2.weight)
|
25 |
+
nn.init.zeros_(self.linear2.bias)
|
26 |
+
|
27 |
+
def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
|
28 |
+
emb = self.linear2(self.silu(self.linear1(condition)))
|
29 |
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
30 |
+
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
31 |
+
return x
|
sf3d/models/transformers/backbone.py
ADDED
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from sf3d.models.utils import BaseModule
|
9 |
+
|
10 |
+
|
11 |
+
class GEGLU(nn.Module):
|
12 |
+
r"""
|
13 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
14 |
+
|
15 |
+
Parameters:
|
16 |
+
dim_in (`int`): The number of channels in the input.
|
17 |
+
dim_out (`int`): The number of channels in the output.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, dim_in: int, dim_out: int):
|
21 |
+
super().__init__()
|
22 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
23 |
+
|
24 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
25 |
+
if gate.device.type != "mps":
|
26 |
+
return F.gelu(gate)
|
27 |
+
# mps: gelu is not implemented for float16
|
28 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
29 |
+
|
30 |
+
def forward(self, hidden_states, scale: float = 1.0):
|
31 |
+
args = ()
|
32 |
+
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
|
33 |
+
return hidden_states * self.gelu(gate)
|
34 |
+
|
35 |
+
|
36 |
+
class CrossAttention(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
dim,
|
40 |
+
kv_dim=None,
|
41 |
+
num_heads=16,
|
42 |
+
qkv_bias=False,
|
43 |
+
attn_drop=0.0,
|
44 |
+
proj_drop=0.0,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
self.num_heads = num_heads
|
48 |
+
head_dim = dim // num_heads
|
49 |
+
self.scale = head_dim**-0.5
|
50 |
+
kv_dim = dim if not kv_dim else kv_dim
|
51 |
+
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
|
52 |
+
self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
|
53 |
+
self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
|
54 |
+
self.attn_drop = attn_drop
|
55 |
+
self.proj = nn.Linear(dim, dim)
|
56 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
57 |
+
|
58 |
+
def forward(self, x_q, x_kv):
|
59 |
+
B, N_q, C = x_q.shape
|
60 |
+
B, N_kv, _ = x_kv.shape
|
61 |
+
# [B, N_q, C] -> [B, N_q, H, C/H]
|
62 |
+
q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads)
|
63 |
+
# [B, N_kv, C] -> [B, N_kv, H, C/H]
|
64 |
+
k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
|
65 |
+
v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
|
66 |
+
|
67 |
+
# attention
|
68 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
69 |
+
q.permute(0, 2, 1, 3),
|
70 |
+
k.permute(0, 2, 1, 3),
|
71 |
+
v.permute(0, 2, 1, 3),
|
72 |
+
attn_mask=None,
|
73 |
+
dropout_p=self.attn_drop,
|
74 |
+
scale=self.scale,
|
75 |
+
).permute(0, 2, 1, 3)
|
76 |
+
|
77 |
+
# [B, N_q, H, C/H] -> [B, N_q, C]
|
78 |
+
x = x.reshape(B, N_q, C)
|
79 |
+
x = self.proj(x)
|
80 |
+
x = self.proj_drop(x)
|
81 |
+
return x
|
82 |
+
|
83 |
+
|
84 |
+
class FeedForward(nn.Module):
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
dim: int,
|
88 |
+
dim_out: Optional[int] = None,
|
89 |
+
mult: int = 4,
|
90 |
+
dropout: float = 0.0,
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
inner_dim = int(dim * mult)
|
94 |
+
dim_out = dim_out if dim_out is not None else dim
|
95 |
+
act_fn = GEGLU(dim, inner_dim)
|
96 |
+
self.net = nn.ModuleList([])
|
97 |
+
self.net.append(act_fn)
|
98 |
+
self.net.append(nn.Dropout(dropout))
|
99 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
100 |
+
|
101 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
102 |
+
for module in self.net:
|
103 |
+
x = module(x)
|
104 |
+
return x
|
105 |
+
|
106 |
+
|
107 |
+
class BasicBlock(nn.Module):
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
dim: int,
|
111 |
+
kv_dim: Optional[int] = None,
|
112 |
+
num_heads: int = 16,
|
113 |
+
qkv_bias: bool = False,
|
114 |
+
attn_drop: float = 0.0,
|
115 |
+
proj_drop: float = 0.0,
|
116 |
+
ff_drop: float = 0.0,
|
117 |
+
):
|
118 |
+
super().__init__()
|
119 |
+
self.norm1 = nn.LayerNorm(dim)
|
120 |
+
self.attn1 = CrossAttention(
|
121 |
+
dim,
|
122 |
+
kv_dim=dim,
|
123 |
+
num_heads=num_heads,
|
124 |
+
qkv_bias=qkv_bias,
|
125 |
+
attn_drop=attn_drop,
|
126 |
+
proj_drop=proj_drop,
|
127 |
+
)
|
128 |
+
self.norm2 = nn.LayerNorm(dim)
|
129 |
+
self.attn2 = CrossAttention(
|
130 |
+
dim,
|
131 |
+
kv_dim=kv_dim,
|
132 |
+
num_heads=num_heads,
|
133 |
+
qkv_bias=qkv_bias,
|
134 |
+
attn_drop=attn_drop,
|
135 |
+
proj_drop=proj_drop,
|
136 |
+
)
|
137 |
+
self.norm3 = nn.LayerNorm(dim)
|
138 |
+
self.ff = FeedForward(dim, dropout=ff_drop)
|
139 |
+
|
140 |
+
def forward(self, z, x):
|
141 |
+
z_norm = self.norm1(z)
|
142 |
+
z = z + self.attn1(z_norm, z_norm)
|
143 |
+
# TODO: do we need to have the second attention when x is None?
|
144 |
+
z_norm = self.norm2(z)
|
145 |
+
z = z + self.attn2(z_norm, x if x is not None else z_norm)
|
146 |
+
z_norm = self.norm3(z)
|
147 |
+
z = z + self.ff(z_norm)
|
148 |
+
return z
|
149 |
+
|
150 |
+
|
151 |
+
class SingleStreamTransformer(BaseModule):
|
152 |
+
@dataclass
|
153 |
+
class Config(BaseModule.Config):
|
154 |
+
num_attention_heads: int = 16
|
155 |
+
attention_head_dim: int = 88
|
156 |
+
in_channels: Optional[int] = None
|
157 |
+
out_channels: Optional[int] = None
|
158 |
+
num_layers: int = 16
|
159 |
+
dropout: float = 0.0
|
160 |
+
norm_num_groups: int = 32
|
161 |
+
cross_attention_dim: Optional[int] = None
|
162 |
+
attention_bias: bool = False
|
163 |
+
|
164 |
+
cfg: Config
|
165 |
+
|
166 |
+
def configure(self) -> None:
|
167 |
+
self.num_attention_heads = self.cfg.num_attention_heads
|
168 |
+
self.attention_head_dim = self.cfg.attention_head_dim
|
169 |
+
inner_dim = self.num_attention_heads * self.attention_head_dim
|
170 |
+
|
171 |
+
# Define input layers
|
172 |
+
self.norm = torch.nn.GroupNorm(
|
173 |
+
num_groups=self.cfg.norm_num_groups,
|
174 |
+
num_channels=self.cfg.in_channels,
|
175 |
+
eps=1e-6,
|
176 |
+
affine=True,
|
177 |
+
)
|
178 |
+
self.proj_in = nn.Linear(self.cfg.in_channels, inner_dim)
|
179 |
+
|
180 |
+
# Define transformers blocks
|
181 |
+
self.transformer_blocks = nn.ModuleList(
|
182 |
+
[
|
183 |
+
BasicBlock(
|
184 |
+
inner_dim,
|
185 |
+
kv_dim=self.cfg.cross_attention_dim,
|
186 |
+
num_heads=self.num_attention_heads,
|
187 |
+
qkv_bias=self.cfg.attention_bias,
|
188 |
+
proj_drop=self.cfg.dropout,
|
189 |
+
ff_drop=self.cfg.dropout,
|
190 |
+
)
|
191 |
+
for d in range(self.cfg.num_layers)
|
192 |
+
]
|
193 |
+
)
|
194 |
+
|
195 |
+
# 4. Define output layers
|
196 |
+
self.proj_out = nn.Linear(inner_dim, self.cfg.in_channels)
|
197 |
+
|
198 |
+
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
199 |
+
residual = hidden_states
|
200 |
+
hidden_states = self.norm(hidden_states)
|
201 |
+
hidden_states = hidden_states.permute(0, 2, 1)
|
202 |
+
hidden_states = self.proj_in(hidden_states)
|
203 |
+
for block in self.transformer_blocks:
|
204 |
+
hidden_states = block(hidden_states, encoder_hidden_states)
|
205 |
+
hidden_states = self.proj_out(hidden_states).permute(0, 2, 1).contiguous()
|
206 |
+
# TODO: do we really need to add the residual?
|
207 |
+
hidden_states = hidden_states + residual
|
208 |
+
return hidden_states
|
209 |
+
|
210 |
+
|
211 |
+
class FuseBlock(nn.Module):
|
212 |
+
"""
|
213 |
+
Fuse X in to Z with cross attention
|
214 |
+
"""
|
215 |
+
|
216 |
+
def __init__(
|
217 |
+
self,
|
218 |
+
dim_z: int,
|
219 |
+
dim_x: int,
|
220 |
+
num_heads: int = 16,
|
221 |
+
qkv_bias: bool = False,
|
222 |
+
attn_drop: float = 0.0,
|
223 |
+
proj_drop: float = 0.0,
|
224 |
+
ff_drop: float = 0.0,
|
225 |
+
norm_x_input: bool = True,
|
226 |
+
):
|
227 |
+
super().__init__()
|
228 |
+
self.norm_x_input = norm_x_input
|
229 |
+
if self.norm_x_input:
|
230 |
+
self.norm_x = nn.LayerNorm(dim_x)
|
231 |
+
self.attn = CrossAttention(
|
232 |
+
dim_z,
|
233 |
+
kv_dim=dim_x,
|
234 |
+
num_heads=num_heads,
|
235 |
+
qkv_bias=qkv_bias,
|
236 |
+
attn_drop=attn_drop,
|
237 |
+
proj_drop=proj_drop,
|
238 |
+
)
|
239 |
+
self.norm_z1 = nn.LayerNorm(dim_z)
|
240 |
+
self.norm_z2 = nn.LayerNorm(dim_z)
|
241 |
+
self.ff = FeedForward(dim_z, dropout=ff_drop)
|
242 |
+
|
243 |
+
def forward(self, z, x):
|
244 |
+
# TODO: do we need to normalize x?
|
245 |
+
z = z + self.attn(self.norm_z1(z), self.norm_x(x) if self.norm_x_input else x)
|
246 |
+
z = z + self.ff(self.norm_z2(z))
|
247 |
+
return z
|
248 |
+
|
249 |
+
|
250 |
+
@torch.no_grad()
|
251 |
+
def get_triplane_attention_mask(res):
|
252 |
+
N = 3 * res * res
|
253 |
+
attn_mask = torch.zeros(3, res, res, 3, res, res)
|
254 |
+
|
255 |
+
i, j = torch.meshgrid(torch.arange(res), torch.arange(res))
|
256 |
+
|
257 |
+
attn_mask[0, i, j, 1, i, :] = 1.0
|
258 |
+
attn_mask[0, i, j, 2, j, :] = 1.0
|
259 |
+
attn_mask[1, i, j, 0, i, :] = 1.0
|
260 |
+
attn_mask[1, i, j, 2, :, j] = 1.0
|
261 |
+
attn_mask[2, i, j, 0, :, i] = 1.0
|
262 |
+
attn_mask[2, i, j, 1, :, j] = 1.0
|
263 |
+
attn_mask = attn_mask.bool()
|
264 |
+
|
265 |
+
attn_bias = torch.empty_like(attn_mask, dtype=torch.float)
|
266 |
+
attn_bias.masked_fill_(attn_mask, 0.0)
|
267 |
+
attn_bias.masked_fill_(~attn_mask, float("-inf"))
|
268 |
+
|
269 |
+
return attn_bias.reshape(N, N)
|
270 |
+
|
271 |
+
|
272 |
+
class TriplaneAttention(nn.Module):
|
273 |
+
def __init__(
|
274 |
+
self,
|
275 |
+
dim: int,
|
276 |
+
resolution: int,
|
277 |
+
num_heads: int = 16,
|
278 |
+
qkv_bias: bool = False,
|
279 |
+
attn_drop: float = 0.0,
|
280 |
+
proj_drop: float = 0.0,
|
281 |
+
full_attention: bool = False,
|
282 |
+
):
|
283 |
+
super().__init__()
|
284 |
+
self.num_heads = num_heads
|
285 |
+
head_dim = dim // num_heads
|
286 |
+
self.scale = head_dim**-0.5
|
287 |
+
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
|
288 |
+
self.wk = nn.Linear(dim, dim, bias=qkv_bias)
|
289 |
+
self.wv = nn.Linear(dim, dim, bias=qkv_bias)
|
290 |
+
self.attn_drop = attn_drop
|
291 |
+
self.proj = nn.Linear(dim, dim)
|
292 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
293 |
+
|
294 |
+
self.resolution = resolution
|
295 |
+
self.full_attention = full_attention
|
296 |
+
self.attn_mask = (
|
297 |
+
get_triplane_attention_mask(resolution) if not full_attention else None
|
298 |
+
)
|
299 |
+
|
300 |
+
def forward(self, x):
|
301 |
+
B, N, C = x.shape
|
302 |
+
# [B, N, C] -> [B, N, H, C/H]
|
303 |
+
q = self.wq(x).reshape(B, N, self.num_heads, C // self.num_heads)
|
304 |
+
k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads)
|
305 |
+
v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads)
|
306 |
+
|
307 |
+
# detokenize the planes
|
308 |
+
assert N == self.resolution**2 * 3
|
309 |
+
attn_bias = (
|
310 |
+
self.attn_mask.to(q)
|
311 |
+
.unsqueeze(0)
|
312 |
+
.unsqueeze(0)
|
313 |
+
.expand(B, self.num_heads, -1, -1)
|
314 |
+
if not self.full_attention
|
315 |
+
else None
|
316 |
+
)
|
317 |
+
|
318 |
+
# full attention
|
319 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
320 |
+
q.permute(0, 2, 1, 3),
|
321 |
+
k.permute(0, 2, 1, 3),
|
322 |
+
v.permute(0, 2, 1, 3),
|
323 |
+
attn_mask=attn_bias,
|
324 |
+
dropout_p=self.attn_drop,
|
325 |
+
scale=self.scale,
|
326 |
+
).permute(0, 2, 1, 3)
|
327 |
+
|
328 |
+
# [B, N_q, H, C/H] -> [B, N_q, C]
|
329 |
+
x = x.reshape(B, N, C)
|
330 |
+
x = self.proj(x)
|
331 |
+
x = self.proj_drop(x)
|
332 |
+
return x
|
333 |
+
|
334 |
+
|
335 |
+
class TwoStreamBlock(nn.Module):
|
336 |
+
def __init__(
|
337 |
+
self,
|
338 |
+
dim_latent: int,
|
339 |
+
dim_input: int,
|
340 |
+
num_basic_blocks: int = 4,
|
341 |
+
num_heads: int = 16,
|
342 |
+
qkv_bias: bool = False,
|
343 |
+
attn_drop: float = 0.0,
|
344 |
+
proj_drop: float = 0.0,
|
345 |
+
ff_drop: float = 0.0,
|
346 |
+
norm_x_input: bool = True,
|
347 |
+
dim_cross: Optional[int] = None,
|
348 |
+
):
|
349 |
+
super().__init__()
|
350 |
+
|
351 |
+
# Define the fuse block that fuse the input into the latent
|
352 |
+
self.fuse_block_in = FuseBlock(
|
353 |
+
dim_latent,
|
354 |
+
dim_input,
|
355 |
+
num_heads=num_heads,
|
356 |
+
qkv_bias=qkv_bias,
|
357 |
+
attn_drop=attn_drop,
|
358 |
+
proj_drop=proj_drop,
|
359 |
+
ff_drop=ff_drop,
|
360 |
+
norm_x_input=norm_x_input,
|
361 |
+
)
|
362 |
+
|
363 |
+
# Define the transformer block that process the latent
|
364 |
+
self.transformer_block = nn.ModuleList(
|
365 |
+
[
|
366 |
+
BasicBlock(
|
367 |
+
dim_latent,
|
368 |
+
kv_dim=dim_cross,
|
369 |
+
num_heads=num_heads,
|
370 |
+
qkv_bias=qkv_bias,
|
371 |
+
proj_drop=proj_drop,
|
372 |
+
ff_drop=ff_drop,
|
373 |
+
)
|
374 |
+
for _ in range(num_basic_blocks)
|
375 |
+
]
|
376 |
+
)
|
377 |
+
|
378 |
+
# Define the fuse block that fuse the latent into the input
|
379 |
+
self.fuse_block_out = FuseBlock(
|
380 |
+
dim_input,
|
381 |
+
dim_latent,
|
382 |
+
num_heads=num_heads,
|
383 |
+
qkv_bias=qkv_bias,
|
384 |
+
attn_drop=attn_drop,
|
385 |
+
proj_drop=proj_drop,
|
386 |
+
ff_drop=ff_drop,
|
387 |
+
norm_x_input=norm_x_input,
|
388 |
+
)
|
389 |
+
|
390 |
+
def forward(self, latent, input, cross_input):
|
391 |
+
latent = self.fuse_block_in(latent, input)
|
392 |
+
for block in self.transformer_block:
|
393 |
+
latent = block(latent, cross_input)
|
394 |
+
input = self.fuse_block_out(input, latent)
|
395 |
+
return latent, input
|
396 |
+
|
397 |
+
|
398 |
+
class TwoStreamInterleaveTransformer(BaseModule):
|
399 |
+
@dataclass
|
400 |
+
class Config(BaseModule.Config):
|
401 |
+
num_attention_heads: int = 16
|
402 |
+
attention_head_dim: int = 64
|
403 |
+
raw_triplane_channels: int = 1024
|
404 |
+
triplane_channels: int = 1024
|
405 |
+
raw_image_channels: int = 1024
|
406 |
+
num_latents: int = 1792
|
407 |
+
num_blocks: int = 4
|
408 |
+
num_basic_blocks: int = 3
|
409 |
+
dropout: float = 0.0
|
410 |
+
latent_init_std: float = 0.02
|
411 |
+
norm_num_groups: int = 32
|
412 |
+
attention_bias: bool = False
|
413 |
+
norm_x_input: bool = False
|
414 |
+
cross_attention_dim: int = 1024
|
415 |
+
mix_latent: bool = True
|
416 |
+
|
417 |
+
cfg: Config
|
418 |
+
|
419 |
+
def configure(self) -> None:
|
420 |
+
self.mix_latent = self.cfg.mix_latent
|
421 |
+
|
422 |
+
# Define the dimensions
|
423 |
+
self.num_attention_heads = self.cfg.num_attention_heads
|
424 |
+
self.attention_head_dim = self.cfg.attention_head_dim
|
425 |
+
self.num_latents = self.cfg.num_latents
|
426 |
+
self.latent_dim = self.num_attention_heads * self.attention_head_dim
|
427 |
+
|
428 |
+
# Define input layers
|
429 |
+
if self.cfg.norm_num_groups > 0:
|
430 |
+
self.norm_triplane = torch.nn.GroupNorm(
|
431 |
+
num_groups=self.cfg.norm_num_groups,
|
432 |
+
num_channels=self.cfg.raw_triplane_channels,
|
433 |
+
eps=1e-6,
|
434 |
+
affine=True,
|
435 |
+
)
|
436 |
+
else:
|
437 |
+
self.norm_triplane = nn.LayerNorm(self.cfg.raw_triplane_channels)
|
438 |
+
self.proj_triplane = nn.Linear(
|
439 |
+
self.cfg.raw_triplane_channels, self.cfg.triplane_channels
|
440 |
+
)
|
441 |
+
if self.mix_latent:
|
442 |
+
self.norm_image = nn.LayerNorm(self.cfg.raw_image_channels)
|
443 |
+
self.proj_image = nn.Linear(self.cfg.raw_image_channels, self.latent_dim)
|
444 |
+
self.norm_latent = nn.LayerNorm(self.latent_dim)
|
445 |
+
self.proj_latent = nn.Linear(self.latent_dim, self.latent_dim)
|
446 |
+
|
447 |
+
# Define the latents
|
448 |
+
self.latent_init = nn.Parameter(
|
449 |
+
torch.zeros(1, self.num_latents, self.latent_dim)
|
450 |
+
)
|
451 |
+
nn.init.normal_(self.latent_init, std=self.cfg.latent_init_std)
|
452 |
+
|
453 |
+
# Define the transformer blocks
|
454 |
+
self.main_blocks = nn.ModuleList(
|
455 |
+
[
|
456 |
+
TwoStreamBlock(
|
457 |
+
self.latent_dim,
|
458 |
+
self.cfg.triplane_channels,
|
459 |
+
num_basic_blocks=self.cfg.num_basic_blocks,
|
460 |
+
num_heads=self.num_attention_heads,
|
461 |
+
qkv_bias=self.cfg.attention_bias,
|
462 |
+
proj_drop=self.cfg.dropout,
|
463 |
+
ff_drop=self.cfg.dropout,
|
464 |
+
norm_x_input=self.cfg.norm_x_input,
|
465 |
+
dim_cross=self.cfg.cross_attention_dim,
|
466 |
+
)
|
467 |
+
for _ in range(self.cfg.num_blocks)
|
468 |
+
]
|
469 |
+
)
|
470 |
+
|
471 |
+
# 4. Define output layers
|
472 |
+
self.proj_out = nn.Linear(
|
473 |
+
self.cfg.triplane_channels, self.cfg.raw_triplane_channels
|
474 |
+
)
|
475 |
+
|
476 |
+
def forward(self, hidden_states, encoder_hidden_states, **kwargs):
|
477 |
+
# hidden_states: [B, triplane_dim, N_triplane] is triplane tokens
|
478 |
+
# encoder_hidden_states: [B, N_image, image_dim] is the image tokens
|
479 |
+
if isinstance(self.norm_triplane, nn.GroupNorm):
|
480 |
+
triplane_tokens = self.norm_triplane(hidden_states)
|
481 |
+
triplane_tokens = triplane_tokens.permute(
|
482 |
+
0, 2, 1
|
483 |
+
) # [B, N_triplane, triplane_dim]
|
484 |
+
elif isinstance(self.norm_triplane, nn.LayerNorm):
|
485 |
+
triplane_tokens = self.norm_triplane(hidden_states.permute(0, 2, 1))
|
486 |
+
else:
|
487 |
+
raise ValueError("Unknown normalization layer")
|
488 |
+
triplane_tokens = self.proj_triplane(triplane_tokens)
|
489 |
+
if self.mix_latent:
|
490 |
+
image_tokens = self.norm_image(
|
491 |
+
encoder_hidden_states
|
492 |
+
) # [B, N_image, image_dim]
|
493 |
+
image_tokens = self.proj_image(image_tokens)
|
494 |
+
init_latents = self.latent_init.expand(
|
495 |
+
hidden_states.shape[0], -1, -1
|
496 |
+
) # [B, N_latent_init, latent_dim]
|
497 |
+
init_latents = self.norm_latent(init_latents)
|
498 |
+
init_latents = self.proj_latent(init_latents)
|
499 |
+
if self.mix_latent:
|
500 |
+
latent_tokens = torch.cat(
|
501 |
+
[image_tokens, init_latents], dim=1
|
502 |
+
) # [B, N_latent, latent_dim]
|
503 |
+
else:
|
504 |
+
latent_tokens = init_latents
|
505 |
+
|
506 |
+
# forward the main blocks
|
507 |
+
for block in self.main_blocks:
|
508 |
+
latent_tokens, triplane_tokens = block(
|
509 |
+
latent_tokens, triplane_tokens, encoder_hidden_states
|
510 |
+
)
|
511 |
+
|
512 |
+
# project the triplane tokens back to the original dimension
|
513 |
+
triplane_tokens = self.proj_out(triplane_tokens).permute(0, 2, 1).contiguous()
|
514 |
+
triplane_tokens = triplane_tokens + hidden_states
|
515 |
+
return triplane_tokens
|
sf3d/models/utils.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import importlib
|
3 |
+
import math
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Any, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import PIL
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from jaxtyping import Bool, Float, Int, Num
|
13 |
+
from omegaconf import DictConfig, OmegaConf
|
14 |
+
from torch import Tensor
|
15 |
+
|
16 |
+
|
17 |
+
class BaseModule(nn.Module):
|
18 |
+
@dataclass
|
19 |
+
class Config:
|
20 |
+
pass
|
21 |
+
|
22 |
+
cfg: Config # add this to every subclass of BaseModule to enable static type checking
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
|
26 |
+
) -> None:
|
27 |
+
super().__init__()
|
28 |
+
self.cfg = parse_structured(self.Config, cfg)
|
29 |
+
self.configure(*args, **kwargs)
|
30 |
+
|
31 |
+
def configure(self, *args, **kwargs) -> None:
|
32 |
+
raise NotImplementedError
|
33 |
+
|
34 |
+
|
35 |
+
def find_class(cls_string):
|
36 |
+
module_string = ".".join(cls_string.split(".")[:-1])
|
37 |
+
cls_name = cls_string.split(".")[-1]
|
38 |
+
module = importlib.import_module(module_string, package=None)
|
39 |
+
cls = getattr(module, cls_name)
|
40 |
+
return cls
|
41 |
+
|
42 |
+
|
43 |
+
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
|
44 |
+
# Check if cfg.keys are in fields
|
45 |
+
cfg_ = cfg.copy()
|
46 |
+
keys = list(cfg_.keys())
|
47 |
+
|
48 |
+
field_names = {f.name for f in dataclasses.fields(fields)}
|
49 |
+
for key in keys:
|
50 |
+
# This is helpful when swapping out modules from CLI
|
51 |
+
if key not in field_names:
|
52 |
+
print(f"Ignoring {key} as it's not supported by {fields}")
|
53 |
+
cfg_.pop(key)
|
54 |
+
scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg_)
|
55 |
+
return scfg
|
56 |
+
|
57 |
+
|
58 |
+
EPS_DTYPE = {
|
59 |
+
torch.float16: 1e-4,
|
60 |
+
torch.bfloat16: 1e-4,
|
61 |
+
torch.float32: 1e-7,
|
62 |
+
torch.float64: 1e-8,
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def dot(x, y, dim=-1):
|
67 |
+
return torch.sum(x * y, dim, keepdim=True)
|
68 |
+
|
69 |
+
|
70 |
+
def reflect(x, n):
|
71 |
+
return x - 2 * dot(x, n) * n
|
72 |
+
|
73 |
+
|
74 |
+
def normalize(x, dim=-1, eps=None):
|
75 |
+
if eps is None:
|
76 |
+
eps = EPS_DTYPE[x.dtype]
|
77 |
+
return F.normalize(x, dim=dim, p=2, eps=eps)
|
78 |
+
|
79 |
+
|
80 |
+
def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
|
81 |
+
# One pad for determinant
|
82 |
+
tri_sq = F.pad(tri, (0, 1), "constant", 1.0)
|
83 |
+
det_tri = torch.det(tri_sq)
|
84 |
+
tri_rev = torch.cat(
|
85 |
+
(tri_sq[..., 0:1, :], tri_sq[..., 2:3, :], tri_sq[..., 1:2, :]), -2
|
86 |
+
)
|
87 |
+
tri_sq[det_tri < 0] = tri_rev[det_tri < 0]
|
88 |
+
return tri_sq
|
89 |
+
|
90 |
+
|
91 |
+
def triangle_intersection_2d(
|
92 |
+
t1: Float[Tensor, "*B 3 2"],
|
93 |
+
t2: Float[Tensor, "*B 3 2"],
|
94 |
+
eps=1e-12,
|
95 |
+
) -> Float[Tensor, "*B"]: # noqa: F821
|
96 |
+
"""Returns True if triangles collide, False otherwise"""
|
97 |
+
|
98 |
+
def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]: # noqa: F821
|
99 |
+
logdetx = torch.logdet(x.double())
|
100 |
+
if eps is None:
|
101 |
+
return ~torch.isfinite(logdetx)
|
102 |
+
return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))
|
103 |
+
|
104 |
+
t1s = tri_winding(t1)
|
105 |
+
t2s = tri_winding(t2)
|
106 |
+
|
107 |
+
# Assume the triangles do not collide in the begging
|
108 |
+
ret = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
|
109 |
+
for i in range(3):
|
110 |
+
edge = torch.roll(t1s, i, dims=1)[:, :2, :]
|
111 |
+
# Check if all points of triangle 2 lay on the external side of edge E.
|
112 |
+
# If this is the case the triangle do not collide
|
113 |
+
upd = (
|
114 |
+
chk_edge(torch.cat((edge, t2s[:, 0:1]), 1))
|
115 |
+
& chk_edge(torch.cat((edge, t2s[:, 1:2]), 1))
|
116 |
+
& chk_edge(torch.cat((edge, t2s[:, 2:3]), 1))
|
117 |
+
)
|
118 |
+
# Here no collision is still True due to inversion
|
119 |
+
ret = ret | upd
|
120 |
+
|
121 |
+
for i in range(3):
|
122 |
+
edge = torch.roll(t2s, i, dims=1)[:, :2, :]
|
123 |
+
|
124 |
+
upd = (
|
125 |
+
chk_edge(torch.cat((edge, t1s[:, 0:1]), 1))
|
126 |
+
& chk_edge(torch.cat((edge, t1s[:, 1:2]), 1))
|
127 |
+
& chk_edge(torch.cat((edge, t1s[:, 2:3]), 1))
|
128 |
+
)
|
129 |
+
# Here no collision is still True due to inversion
|
130 |
+
ret = ret | upd
|
131 |
+
|
132 |
+
return ~ret # Do the inversion
|
133 |
+
|
134 |
+
|
135 |
+
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
|
136 |
+
|
137 |
+
|
138 |
+
def scale_tensor(
|
139 |
+
dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
|
140 |
+
):
|
141 |
+
if inp_scale is None:
|
142 |
+
inp_scale = (0, 1)
|
143 |
+
if tgt_scale is None:
|
144 |
+
tgt_scale = (0, 1)
|
145 |
+
if isinstance(tgt_scale, Tensor):
|
146 |
+
assert dat.shape[-1] == tgt_scale.shape[-1]
|
147 |
+
dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
|
148 |
+
dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
|
149 |
+
return dat
|
150 |
+
|
151 |
+
|
152 |
+
def dilate_fill(img, mask, iterations=10):
|
153 |
+
oldMask = mask.float()
|
154 |
+
oldImg = img
|
155 |
+
|
156 |
+
mask_kernel = torch.ones(
|
157 |
+
(1, 1, 3, 3),
|
158 |
+
dtype=oldMask.dtype,
|
159 |
+
device=oldMask.device,
|
160 |
+
)
|
161 |
+
|
162 |
+
for i in range(iterations):
|
163 |
+
newMask = torch.nn.functional.max_pool2d(oldMask, 3, 1, 1)
|
164 |
+
|
165 |
+
# Fill the extension with mean color of old valid regions
|
166 |
+
img_unfold = F.unfold(oldImg, (3, 3)).view(1, 3, 3 * 3, -1)
|
167 |
+
mask_unfold = F.unfold(oldMask, (3, 3)).view(1, 1, 3 * 3, -1)
|
168 |
+
new_mask_unfold = F.unfold(newMask, (3, 3)).view(1, 1, 3 * 3, -1)
|
169 |
+
|
170 |
+
# Average color of the valid region
|
171 |
+
mean_color = (img_unfold.sum(dim=2) / mask_unfold.sum(dim=2).clip(1)).unsqueeze(
|
172 |
+
2
|
173 |
+
)
|
174 |
+
# Extend it to the new region
|
175 |
+
fill_color = (mean_color * new_mask_unfold).view(1, 3 * 3 * 3, -1)
|
176 |
+
|
177 |
+
mask_conv = F.conv2d(
|
178 |
+
newMask, mask_kernel, padding=1
|
179 |
+
) # Get the sum for each kernel patch
|
180 |
+
newImg = F.fold(
|
181 |
+
fill_color, (img.shape[-2], img.shape[-1]), (3, 3)
|
182 |
+
) / mask_conv.clamp(1)
|
183 |
+
|
184 |
+
diffMask = newMask - oldMask
|
185 |
+
|
186 |
+
oldMask = newMask
|
187 |
+
oldImg = torch.lerp(oldImg, newImg, diffMask)
|
188 |
+
|
189 |
+
return oldImg
|
190 |
+
|
191 |
+
|
192 |
+
def float32_to_uint8_np(
|
193 |
+
x: Float[np.ndarray, "*B H W C"],
|
194 |
+
dither: bool = True,
|
195 |
+
dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None,
|
196 |
+
dither_strength: float = 1.0,
|
197 |
+
) -> Int[np.ndarray, "*B H W C"]:
|
198 |
+
if dither:
|
199 |
+
dither = (
|
200 |
+
dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5
|
201 |
+
)
|
202 |
+
if dither_mask is not None:
|
203 |
+
dither = dither * dither_mask
|
204 |
+
return np.clip(np.floor((256.0 * x + dither)), 0, 255).astype(np.uint8)
|
205 |
+
return np.clip(np.floor((256.0 * x)), 0, 255).astype(torch.uint8)
|
206 |
+
|
207 |
+
|
208 |
+
def convert_data(data):
|
209 |
+
if data is None:
|
210 |
+
return None
|
211 |
+
elif isinstance(data, np.ndarray):
|
212 |
+
return data
|
213 |
+
elif isinstance(data, torch.Tensor):
|
214 |
+
if data.dtype in [torch.float16, torch.bfloat16]:
|
215 |
+
data = data.float()
|
216 |
+
return data.detach().cpu().numpy()
|
217 |
+
elif isinstance(data, list):
|
218 |
+
return [convert_data(d) for d in data]
|
219 |
+
elif isinstance(data, dict):
|
220 |
+
return {k: convert_data(v) for k, v in data.items()}
|
221 |
+
else:
|
222 |
+
raise TypeError(
|
223 |
+
"Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting",
|
224 |
+
type(data),
|
225 |
+
)
|
226 |
+
|
227 |
+
|
228 |
+
class ImageProcessor:
|
229 |
+
def convert_and_resize(
|
230 |
+
self,
|
231 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
232 |
+
size: int,
|
233 |
+
):
|
234 |
+
if isinstance(image, PIL.Image.Image):
|
235 |
+
image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
|
236 |
+
elif isinstance(image, np.ndarray):
|
237 |
+
if image.dtype == np.uint8:
|
238 |
+
image = torch.from_numpy(image.astype(np.float32) / 255.0)
|
239 |
+
else:
|
240 |
+
image = torch.from_numpy(image)
|
241 |
+
elif isinstance(image, torch.Tensor):
|
242 |
+
pass
|
243 |
+
|
244 |
+
batched = image.ndim == 4
|
245 |
+
|
246 |
+
if not batched:
|
247 |
+
image = image[None, ...]
|
248 |
+
image = F.interpolate(
|
249 |
+
image.permute(0, 3, 1, 2),
|
250 |
+
(size, size),
|
251 |
+
mode="bilinear",
|
252 |
+
align_corners=False,
|
253 |
+
antialias=True,
|
254 |
+
).permute(0, 2, 3, 1)
|
255 |
+
if not batched:
|
256 |
+
image = image[0]
|
257 |
+
return image
|
258 |
+
|
259 |
+
def __call__(
|
260 |
+
self,
|
261 |
+
image: Union[
|
262 |
+
PIL.Image.Image,
|
263 |
+
np.ndarray,
|
264 |
+
torch.FloatTensor,
|
265 |
+
List[PIL.Image.Image],
|
266 |
+
List[np.ndarray],
|
267 |
+
List[torch.FloatTensor],
|
268 |
+
],
|
269 |
+
size: int,
|
270 |
+
) -> Any:
|
271 |
+
if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
|
272 |
+
image = self.convert_and_resize(image, size)
|
273 |
+
else:
|
274 |
+
if not isinstance(image, list):
|
275 |
+
image = [image]
|
276 |
+
image = [self.convert_and_resize(im, size) for im in image]
|
277 |
+
image = torch.stack(image, dim=0)
|
278 |
+
return image
|
279 |
+
|
280 |
+
|
281 |
+
def get_intrinsic_from_fov(fov, H, W, bs=-1):
|
282 |
+
focal_length = 0.5 * H / np.tan(0.5 * fov)
|
283 |
+
intrinsic = np.identity(3, dtype=np.float32)
|
284 |
+
intrinsic[0, 0] = focal_length
|
285 |
+
intrinsic[1, 1] = focal_length
|
286 |
+
intrinsic[0, 2] = W / 2.0
|
287 |
+
intrinsic[1, 2] = H / 2.0
|
288 |
+
|
289 |
+
if bs > 0:
|
290 |
+
intrinsic = intrinsic[None].repeat(bs, axis=0)
|
291 |
+
|
292 |
+
return torch.from_numpy(intrinsic)
|
sf3d/system.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
from typing import Any, List, Optional, Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import trimesh
|
9 |
+
from einops import rearrange
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
from jaxtyping import Float
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from PIL import Image
|
14 |
+
from safetensors.torch import load_model
|
15 |
+
from torch import Tensor
|
16 |
+
|
17 |
+
from sf3d.models.isosurface import MarchingTetrahedraHelper
|
18 |
+
from sf3d.models.mesh import Mesh
|
19 |
+
from sf3d.models.utils import (
|
20 |
+
BaseModule,
|
21 |
+
ImageProcessor,
|
22 |
+
convert_data,
|
23 |
+
dilate_fill,
|
24 |
+
dot,
|
25 |
+
find_class,
|
26 |
+
float32_to_uint8_np,
|
27 |
+
normalize,
|
28 |
+
scale_tensor,
|
29 |
+
)
|
30 |
+
from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w
|
31 |
+
|
32 |
+
from .texture_baker import TextureBaker
|
33 |
+
|
34 |
+
|
35 |
+
class SF3D(BaseModule):
|
36 |
+
@dataclass
|
37 |
+
class Config(BaseModule.Config):
|
38 |
+
cond_image_size: int
|
39 |
+
isosurface_resolution: int
|
40 |
+
isosurface_threshold: float = 10.0
|
41 |
+
radius: float = 1.0
|
42 |
+
background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
|
43 |
+
default_fovy_deg: float = 40.0
|
44 |
+
default_distance: float = 1.6
|
45 |
+
|
46 |
+
camera_embedder_cls: str = ""
|
47 |
+
camera_embedder: dict = field(default_factory=dict)
|
48 |
+
|
49 |
+
image_tokenizer_cls: str = ""
|
50 |
+
image_tokenizer: dict = field(default_factory=dict)
|
51 |
+
|
52 |
+
tokenizer_cls: str = ""
|
53 |
+
tokenizer: dict = field(default_factory=dict)
|
54 |
+
|
55 |
+
backbone_cls: str = ""
|
56 |
+
backbone: dict = field(default_factory=dict)
|
57 |
+
|
58 |
+
post_processor_cls: str = ""
|
59 |
+
post_processor: dict = field(default_factory=dict)
|
60 |
+
|
61 |
+
decoder_cls: str = ""
|
62 |
+
decoder: dict = field(default_factory=dict)
|
63 |
+
|
64 |
+
image_estimator_cls: str = ""
|
65 |
+
image_estimator: dict = field(default_factory=dict)
|
66 |
+
|
67 |
+
global_estimator_cls: str = ""
|
68 |
+
global_estimator: dict = field(default_factory=dict)
|
69 |
+
|
70 |
+
cfg: Config
|
71 |
+
|
72 |
+
@classmethod
|
73 |
+
def from_pretrained(
|
74 |
+
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
|
75 |
+
):
|
76 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
77 |
+
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
78 |
+
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
79 |
+
else:
|
80 |
+
config_path = hf_hub_download(
|
81 |
+
repo_id=pretrained_model_name_or_path, filename=config_name
|
82 |
+
)
|
83 |
+
weight_path = hf_hub_download(
|
84 |
+
repo_id=pretrained_model_name_or_path, filename=weight_name
|
85 |
+
)
|
86 |
+
|
87 |
+
cfg = OmegaConf.load(config_path)
|
88 |
+
OmegaConf.resolve(cfg)
|
89 |
+
model = cls(cfg)
|
90 |
+
load_model(model, weight_path)
|
91 |
+
return model
|
92 |
+
|
93 |
+
@property
|
94 |
+
def device(self):
|
95 |
+
return next(self.parameters()).device
|
96 |
+
|
97 |
+
def configure(self):
|
98 |
+
self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
|
99 |
+
self.cfg.image_tokenizer
|
100 |
+
)
|
101 |
+
self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
|
102 |
+
self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
|
103 |
+
self.cfg.camera_embedder
|
104 |
+
)
|
105 |
+
self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
|
106 |
+
self.post_processor = find_class(self.cfg.post_processor_cls)(
|
107 |
+
self.cfg.post_processor
|
108 |
+
)
|
109 |
+
self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
|
110 |
+
self.image_estimator = find_class(self.cfg.image_estimator_cls)(
|
111 |
+
self.cfg.image_estimator
|
112 |
+
)
|
113 |
+
self.global_estimator = find_class(self.cfg.global_estimator_cls)(
|
114 |
+
self.cfg.global_estimator
|
115 |
+
)
|
116 |
+
|
117 |
+
self.bbox: Float[Tensor, "2 3"]
|
118 |
+
self.register_buffer(
|
119 |
+
"bbox",
|
120 |
+
torch.as_tensor(
|
121 |
+
[
|
122 |
+
[-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
|
123 |
+
[self.cfg.radius, self.cfg.radius, self.cfg.radius],
|
124 |
+
],
|
125 |
+
dtype=torch.float32,
|
126 |
+
),
|
127 |
+
)
|
128 |
+
self.isosurface_helper = MarchingTetrahedraHelper(
|
129 |
+
self.cfg.isosurface_resolution,
|
130 |
+
os.path.join(
|
131 |
+
os.path.dirname(__file__),
|
132 |
+
"..",
|
133 |
+
"load",
|
134 |
+
"tets",
|
135 |
+
f"{self.cfg.isosurface_resolution}_tets.npz",
|
136 |
+
),
|
137 |
+
)
|
138 |
+
|
139 |
+
self.baker = TextureBaker()
|
140 |
+
self.image_processor = ImageProcessor()
|
141 |
+
|
142 |
+
def triplane_to_meshes(
|
143 |
+
self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
|
144 |
+
) -> list[Mesh]:
|
145 |
+
meshes = []
|
146 |
+
for i in range(triplanes.shape[0]):
|
147 |
+
triplane = triplanes[i]
|
148 |
+
grid_vertices = scale_tensor(
|
149 |
+
self.isosurface_helper.grid_vertices.to(triplanes.device),
|
150 |
+
self.isosurface_helper.points_range,
|
151 |
+
self.bbox,
|
152 |
+
)
|
153 |
+
|
154 |
+
values = self.query_triplane(grid_vertices, triplane)
|
155 |
+
decoded = self.decoder(values, include=["vertex_offset", "density"])
|
156 |
+
sdf = decoded["density"] - self.cfg.isosurface_threshold
|
157 |
+
|
158 |
+
deform = decoded["vertex_offset"].squeeze(0)
|
159 |
+
|
160 |
+
mesh: Mesh = self.isosurface_helper(
|
161 |
+
sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
|
162 |
+
)
|
163 |
+
mesh.v_pos = scale_tensor(
|
164 |
+
mesh.v_pos, self.isosurface_helper.points_range, self.bbox
|
165 |
+
)
|
166 |
+
|
167 |
+
meshes.append(mesh)
|
168 |
+
|
169 |
+
return meshes
|
170 |
+
|
171 |
+
def query_triplane(
|
172 |
+
self,
|
173 |
+
positions: Float[Tensor, "*B N 3"],
|
174 |
+
triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
|
175 |
+
) -> Float[Tensor, "*B N F"]:
|
176 |
+
batched = positions.ndim == 3
|
177 |
+
if not batched:
|
178 |
+
# no batch dimension
|
179 |
+
triplanes = triplanes[None, ...]
|
180 |
+
positions = positions[None, ...]
|
181 |
+
assert triplanes.ndim == 5 and positions.ndim == 3
|
182 |
+
|
183 |
+
positions = scale_tensor(
|
184 |
+
positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
|
185 |
+
)
|
186 |
+
|
187 |
+
indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
|
188 |
+
(positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
|
189 |
+
dim=-3,
|
190 |
+
).to(triplanes.dtype)
|
191 |
+
out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
|
192 |
+
rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
|
193 |
+
rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
|
194 |
+
align_corners=True,
|
195 |
+
mode="bilinear",
|
196 |
+
)
|
197 |
+
out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
|
198 |
+
|
199 |
+
return out
|
200 |
+
|
201 |
+
def get_scene_codes(self, batch) -> Float[Tensor, "B 3 C H W"]:
|
202 |
+
# if batch[rgb_cond] is only one view, add a view dimension
|
203 |
+
if len(batch["rgb_cond"].shape) == 4:
|
204 |
+
batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
|
205 |
+
batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
|
206 |
+
batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
|
207 |
+
batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
|
208 |
+
batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
|
209 |
+
batch_size, n_input_views = batch["rgb_cond"].shape[:2]
|
210 |
+
|
211 |
+
camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
|
212 |
+
camera_embeds = self.camera_embedder(**batch)
|
213 |
+
|
214 |
+
input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
|
215 |
+
rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
|
216 |
+
modulation_cond=camera_embeds,
|
217 |
+
)
|
218 |
+
|
219 |
+
input_image_tokens = rearrange(
|
220 |
+
input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
|
221 |
+
)
|
222 |
+
|
223 |
+
tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)
|
224 |
+
|
225 |
+
tokens = self.backbone(
|
226 |
+
tokens,
|
227 |
+
encoder_hidden_states=input_image_tokens,
|
228 |
+
modulation_cond=None,
|
229 |
+
)
|
230 |
+
|
231 |
+
direct_codes = self.tokenizer.detokenize(tokens)
|
232 |
+
scene_codes = self.post_processor(direct_codes)
|
233 |
+
return scene_codes, direct_codes
|
234 |
+
|
235 |
+
def run_image(
|
236 |
+
self,
|
237 |
+
image: Image,
|
238 |
+
bake_resolution: int,
|
239 |
+
estimate_illumination: bool = False,
|
240 |
+
) -> Tuple[trimesh.Trimesh, dict[str, Any]]:
|
241 |
+
if image.mode != "RGBA":
|
242 |
+
raise ValueError("Image must be in RGBA mode")
|
243 |
+
img_cond = (
|
244 |
+
torch.from_numpy(
|
245 |
+
np.asarray(
|
246 |
+
image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
|
247 |
+
).astype(np.float32)
|
248 |
+
/ 255.0
|
249 |
+
)
|
250 |
+
.float()
|
251 |
+
.clip(0, 1)
|
252 |
+
.to(self.device)
|
253 |
+
)
|
254 |
+
mask_cond = img_cond[:, :, -1:]
|
255 |
+
rgb_cond = torch.lerp(
|
256 |
+
torch.tensor(self.cfg.background_color, device=self.device)[None, None, :],
|
257 |
+
img_cond[:, :, :3],
|
258 |
+
mask_cond,
|
259 |
+
)
|
260 |
+
|
261 |
+
c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
|
262 |
+
intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
|
263 |
+
self.cfg.default_fovy_deg,
|
264 |
+
self.cfg.cond_image_size,
|
265 |
+
self.cfg.cond_image_size,
|
266 |
+
)
|
267 |
+
|
268 |
+
batch = {
|
269 |
+
"rgb_cond": rgb_cond,
|
270 |
+
"mask_cond": mask_cond,
|
271 |
+
"c2w_cond": c2w_cond.unsqueeze(0),
|
272 |
+
"intrinsic_cond": intrinsic.to(self.device).unsqueeze(0),
|
273 |
+
"intrinsic_normed_cond": intrinsic_normed_cond.to(self.device).unsqueeze(0),
|
274 |
+
}
|
275 |
+
|
276 |
+
meshes, global_dict = self.generate_mesh(
|
277 |
+
batch, bake_resolution, estimate_illumination
|
278 |
+
)
|
279 |
+
return meshes[0], global_dict
|
280 |
+
|
281 |
+
def generate_mesh(
|
282 |
+
self,
|
283 |
+
batch,
|
284 |
+
bake_resolution: int,
|
285 |
+
estimate_illumination: bool = False,
|
286 |
+
) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
|
287 |
+
batch["rgb_cond"] = self.image_processor(
|
288 |
+
batch["rgb_cond"], self.cfg.cond_image_size
|
289 |
+
)
|
290 |
+
batch["mask_cond"] = self.image_processor(
|
291 |
+
batch["mask_cond"], self.cfg.cond_image_size
|
292 |
+
)
|
293 |
+
scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)
|
294 |
+
|
295 |
+
global_dict = {}
|
296 |
+
if self.image_estimator is not None:
|
297 |
+
global_dict.update(
|
298 |
+
self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
|
299 |
+
)
|
300 |
+
if self.global_estimator is not None and estimate_illumination:
|
301 |
+
global_dict.update(self.global_estimator(non_postprocessed_codes))
|
302 |
+
|
303 |
+
with torch.no_grad():
|
304 |
+
with torch.autocast(device_type="cuda", enabled=False):
|
305 |
+
meshes = self.triplane_to_meshes(scene_codes)
|
306 |
+
|
307 |
+
rets = []
|
308 |
+
for i, mesh in enumerate(meshes):
|
309 |
+
# Check for empty mesh
|
310 |
+
if mesh.v_pos.shape[0] == 0:
|
311 |
+
rets.append(trimesh.Trimesh())
|
312 |
+
continue
|
313 |
+
|
314 |
+
mesh.unwrap_uv()
|
315 |
+
|
316 |
+
# Build textures
|
317 |
+
rast = self.baker.rasterize(
|
318 |
+
mesh.v_tex, mesh.t_pos_idx, bake_resolution
|
319 |
+
)
|
320 |
+
bake_mask = self.baker.get_mask(rast)
|
321 |
+
|
322 |
+
pos_bake = self.baker.interpolate(
|
323 |
+
mesh.v_pos,
|
324 |
+
rast,
|
325 |
+
mesh.t_pos_idx,
|
326 |
+
mesh.v_tex,
|
327 |
+
)
|
328 |
+
gb_pos = pos_bake[bake_mask]
|
329 |
+
|
330 |
+
tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
|
331 |
+
decoded = self.decoder(
|
332 |
+
tri_query, exclude=["density", "vertex_offset"]
|
333 |
+
)
|
334 |
+
|
335 |
+
nrm = self.baker.interpolate(
|
336 |
+
mesh.v_nrm,
|
337 |
+
rast,
|
338 |
+
mesh.t_pos_idx,
|
339 |
+
mesh.v_tex,
|
340 |
+
)
|
341 |
+
gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
|
342 |
+
decoded["normal"] = gb_nrm
|
343 |
+
|
344 |
+
# Check if any keys in global_dict start with decoded_
|
345 |
+
for k, v in global_dict.items():
|
346 |
+
if k.startswith("decoder_"):
|
347 |
+
decoded[k.replace("decoder_", "")] = v[i]
|
348 |
+
|
349 |
+
mat_out = {
|
350 |
+
"albedo": decoded["features"],
|
351 |
+
"roughness": decoded["roughness"],
|
352 |
+
"metallic": decoded["metallic"],
|
353 |
+
"normal": normalize(decoded["perturb_normal"]),
|
354 |
+
"bump": None,
|
355 |
+
}
|
356 |
+
|
357 |
+
for k, v in mat_out.items():
|
358 |
+
if v is None:
|
359 |
+
continue
|
360 |
+
if v.shape[0] == 1:
|
361 |
+
# Skip and directly add a single value
|
362 |
+
mat_out[k] = v[0]
|
363 |
+
else:
|
364 |
+
f = torch.zeros(
|
365 |
+
bake_resolution,
|
366 |
+
bake_resolution,
|
367 |
+
v.shape[-1],
|
368 |
+
dtype=v.dtype,
|
369 |
+
device=v.device,
|
370 |
+
)
|
371 |
+
if v.shape == f.shape:
|
372 |
+
continue
|
373 |
+
if k == "normal":
|
374 |
+
# Use un-normalized tangents here so that larger smaller tris
|
375 |
+
# Don't effect the tangents that much
|
376 |
+
tng = self.baker.interpolate(
|
377 |
+
mesh.v_tng,
|
378 |
+
rast,
|
379 |
+
mesh.t_pos_idx,
|
380 |
+
mesh.v_tex,
|
381 |
+
)
|
382 |
+
gb_tng = tng[bake_mask]
|
383 |
+
gb_tng = F.normalize(gb_tng, dim=-1)
|
384 |
+
gb_btng = F.normalize(
|
385 |
+
torch.cross(gb_tng, gb_nrm, dim=-1), dim=-1
|
386 |
+
)
|
387 |
+
normal = F.normalize(mat_out["normal"], dim=-1)
|
388 |
+
|
389 |
+
bump = torch.cat(
|
390 |
+
# Check if we have to flip some things
|
391 |
+
(
|
392 |
+
dot(normal, gb_tng),
|
393 |
+
dot(normal, gb_btng),
|
394 |
+
dot(normal, gb_nrm).clip(
|
395 |
+
0.3, 1
|
396 |
+
), # Never go below 0.3. This would indicate a flipped (or close to one) normal
|
397 |
+
),
|
398 |
+
-1,
|
399 |
+
)
|
400 |
+
bump[..., :2] *= 0.5
|
401 |
+
bump = (bump * 0.5 + 0.5).clamp(0, 1)
|
402 |
+
|
403 |
+
f[bake_mask] = bump.view(-1, 3)
|
404 |
+
mat_out["bump"] = f
|
405 |
+
else:
|
406 |
+
f[bake_mask] = v.view(-1, v.shape[-1])
|
407 |
+
mat_out[k] = f
|
408 |
+
|
409 |
+
def uv_padding(arr):
|
410 |
+
if arr.ndim == 1:
|
411 |
+
return arr
|
412 |
+
return (
|
413 |
+
dilate_fill(
|
414 |
+
arr.permute(2, 0, 1)[None, ...],
|
415 |
+
bake_mask.unsqueeze(0).unsqueeze(0),
|
416 |
+
iterations=bake_resolution // 150,
|
417 |
+
)
|
418 |
+
.squeeze(0)
|
419 |
+
.permute(1, 2, 0)
|
420 |
+
)
|
421 |
+
|
422 |
+
verts_np = convert_data(mesh.v_pos)
|
423 |
+
faces = convert_data(mesh.t_pos_idx)
|
424 |
+
uvs = convert_data(mesh.v_tex)
|
425 |
+
|
426 |
+
basecolor_tex = Image.fromarray(
|
427 |
+
float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
|
428 |
+
).convert("RGB")
|
429 |
+
basecolor_tex.format = "JPEG"
|
430 |
+
|
431 |
+
metallic = mat_out["metallic"].squeeze().cpu().item()
|
432 |
+
roughness = mat_out["roughness"].squeeze().cpu().item()
|
433 |
+
|
434 |
+
if "bump" in mat_out and mat_out["bump"] is not None:
|
435 |
+
bump_np = convert_data(uv_padding(mat_out["bump"]))
|
436 |
+
bump_up = np.ones_like(bump_np)
|
437 |
+
bump_up[..., :2] = 0.5
|
438 |
+
bump_up[..., 2:] = 1
|
439 |
+
bump_tex = Image.fromarray(
|
440 |
+
float32_to_uint8_np(
|
441 |
+
bump_np,
|
442 |
+
dither=True,
|
443 |
+
# Do not dither if something is perfectly flat
|
444 |
+
dither_mask=np.all(
|
445 |
+
bump_np == bump_up, axis=-1, keepdims=True
|
446 |
+
).astype(np.float32),
|
447 |
+
)
|
448 |
+
).convert("RGB")
|
449 |
+
bump_tex.format = (
|
450 |
+
"JPEG" # PNG would be better but the assets are larger
|
451 |
+
)
|
452 |
+
else:
|
453 |
+
bump_tex = None
|
454 |
+
|
455 |
+
material = trimesh.visual.material.PBRMaterial(
|
456 |
+
baseColorTexture=basecolor_tex,
|
457 |
+
roughnessFactor=roughness,
|
458 |
+
metallicFactor=metallic,
|
459 |
+
normalTexture=bump_tex,
|
460 |
+
)
|
461 |
+
|
462 |
+
tmesh = trimesh.Trimesh(
|
463 |
+
vertices=verts_np,
|
464 |
+
faces=faces,
|
465 |
+
visual=trimesh.visual.texture.TextureVisuals(
|
466 |
+
uv=uvs, material=material
|
467 |
+
),
|
468 |
+
)
|
469 |
+
rot = trimesh.transformations.rotation_matrix(
|
470 |
+
np.radians(-90), [1, 0, 0]
|
471 |
+
)
|
472 |
+
tmesh.apply_transform(rot)
|
473 |
+
tmesh.apply_transform(
|
474 |
+
trimesh.transformations.rotation_matrix(
|
475 |
+
np.radians(90), [0, 1, 0]
|
476 |
+
)
|
477 |
+
)
|
478 |
+
|
479 |
+
tmesh.invert()
|
480 |
+
|
481 |
+
rets.append(tmesh)
|
482 |
+
|
483 |
+
return rets, global_dict
|
sf3d/texture_baker.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import slangtorch
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from jaxtyping import Bool, Float
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
|
10 |
+
class TextureBaker(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
self.baker = slangtorch.loadModule(
|
14 |
+
os.path.join(os.path.dirname(__file__), "texture_baker.slang")
|
15 |
+
)
|
16 |
+
|
17 |
+
def rasterize(
|
18 |
+
self,
|
19 |
+
uv: Float[Tensor, "Nv 2"],
|
20 |
+
face_indices: Float[Tensor, "Nf 3"],
|
21 |
+
bake_resolution: int,
|
22 |
+
) -> Float[Tensor, "bake_resolution bake_resolution 4"]:
|
23 |
+
if not face_indices.is_cuda or not uv.is_cuda:
|
24 |
+
raise ValueError("All input tensors must be on cuda")
|
25 |
+
|
26 |
+
face_indices = face_indices.to(torch.int32)
|
27 |
+
uv = uv.to(torch.float32)
|
28 |
+
|
29 |
+
rast_result = torch.empty(
|
30 |
+
bake_resolution, bake_resolution, 4, device=uv.device, dtype=torch.float32
|
31 |
+
)
|
32 |
+
|
33 |
+
block_size = 16
|
34 |
+
grid_size = bake_resolution // block_size
|
35 |
+
self.baker.bake_uv(uv=uv, indices=face_indices, output=rast_result).launchRaw(
|
36 |
+
blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
|
37 |
+
)
|
38 |
+
|
39 |
+
return rast_result
|
40 |
+
|
41 |
+
def get_mask(
|
42 |
+
self, rast: Float[Tensor, "bake_resolution bake_resolution 4"]
|
43 |
+
) -> Bool[Tensor, "bake_resolution bake_resolution"]:
|
44 |
+
return rast[..., -1] >= 0
|
45 |
+
|
46 |
+
def interpolate(
|
47 |
+
self,
|
48 |
+
attr: Float[Tensor, "Nv 3"],
|
49 |
+
rast: Float[Tensor, "bake_resolution bake_resolution 4"],
|
50 |
+
face_indices: Float[Tensor, "Nf 3"],
|
51 |
+
uv: Float[Tensor, "Nv 2"],
|
52 |
+
) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
|
53 |
+
# Make sure all input tensors are on torch
|
54 |
+
if not attr.is_cuda or not face_indices.is_cuda or not rast.is_cuda:
|
55 |
+
raise ValueError("All input tensors must be on cuda")
|
56 |
+
|
57 |
+
attr = attr.to(torch.float32)
|
58 |
+
face_indices = face_indices.to(torch.int32)
|
59 |
+
uv = uv.to(torch.float32)
|
60 |
+
|
61 |
+
pos_bake = torch.zeros(
|
62 |
+
rast.shape[0],
|
63 |
+
rast.shape[1],
|
64 |
+
3,
|
65 |
+
device=attr.device,
|
66 |
+
dtype=attr.dtype,
|
67 |
+
)
|
68 |
+
|
69 |
+
block_size = 16
|
70 |
+
grid_size = rast.shape[0] // block_size
|
71 |
+
self.baker.interpolate(
|
72 |
+
attr=attr, indices=face_indices, rast=rast, output=pos_bake
|
73 |
+
).launchRaw(
|
74 |
+
blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
|
75 |
+
)
|
76 |
+
|
77 |
+
return pos_bake
|
78 |
+
|
79 |
+
def forward(
|
80 |
+
self,
|
81 |
+
attr: Float[Tensor, "Nv 3"],
|
82 |
+
uv: Float[Tensor, "Nv 2"],
|
83 |
+
face_indices: Float[Tensor, "Nf 3"],
|
84 |
+
bake_resolution: int,
|
85 |
+
) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
|
86 |
+
rast = self.rasterize(uv, face_indices, bake_resolution)
|
87 |
+
return self.interpolate(attr, rast, face_indices, uv)
|
sf3d/texture_baker.slang
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// xy: 2D test position
|
2 |
+
// v1: vertex position 1
|
3 |
+
// v2: vertex position 2
|
4 |
+
// v3: vertex position 3
|
5 |
+
//
|
6 |
+
bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, out float u, out float v, out float w)
|
7 |
+
{
|
8 |
+
// Return true if the point (x,y) is inside the triangle defined by the vertices v1, v2, v3.
|
9 |
+
// If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
|
10 |
+
float2 v1v2 = v2 - v1;
|
11 |
+
float2 v1v3 = v3 - v1;
|
12 |
+
float2 xyv1 = xy - v1;
|
13 |
+
|
14 |
+
float d00 = dot(v1v2, v1v2);
|
15 |
+
float d01 = dot(v1v2, v1v3);
|
16 |
+
float d11 = dot(v1v3, v1v3);
|
17 |
+
float d20 = dot(xyv1, v1v2);
|
18 |
+
float d21 = dot(xyv1, v1v3);
|
19 |
+
|
20 |
+
float denom = d00 * d11 - d01 * d01;
|
21 |
+
v = (d11 * d20 - d01 * d21) / denom;
|
22 |
+
w = (d00 * d21 - d01 * d20) / denom;
|
23 |
+
u = 1.0 - v - w;
|
24 |
+
|
25 |
+
return (v >= 0.0) && (w >= 0.0) && (v + w <= 1.0);
|
26 |
+
}
|
27 |
+
|
28 |
+
[AutoPyBindCUDA]
|
29 |
+
[CUDAKernel]
|
30 |
+
void interpolate(
|
31 |
+
TensorView<float3> attr,
|
32 |
+
TensorView<int3> indices,
|
33 |
+
TensorView<float4> rast,
|
34 |
+
TensorView<float3> output)
|
35 |
+
{
|
36 |
+
// Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
|
37 |
+
|
38 |
+
uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
|
39 |
+
|
40 |
+
if (dispatch_id.x > output.size(0) || dispatch_id.y > output.size(1))
|
41 |
+
return;
|
42 |
+
|
43 |
+
float4 barycentric = rast[dispatch_id.x, dispatch_id.y];
|
44 |
+
int triangle_idx = int(barycentric.w);
|
45 |
+
|
46 |
+
if (triangle_idx < 0) {
|
47 |
+
output[dispatch_id.x, dispatch_id.y] = float3(0.0, 0.0, 0.0);
|
48 |
+
return;
|
49 |
+
}
|
50 |
+
|
51 |
+
float3 v1 = attr[indices[triangle_idx].x];
|
52 |
+
float3 v2 = attr[indices[triangle_idx].y];
|
53 |
+
float3 v3 = attr[indices[triangle_idx].z];
|
54 |
+
|
55 |
+
output[dispatch_id.x, dispatch_id.y] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
|
56 |
+
}
|
57 |
+
|
58 |
+
[AutoPyBindCUDA]
|
59 |
+
[CUDAKernel]
|
60 |
+
void bake_uv(
|
61 |
+
TensorView<float2> uv,
|
62 |
+
TensorView<int3> indices,
|
63 |
+
TensorView<float4> output)
|
64 |
+
{
|
65 |
+
uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
|
66 |
+
|
67 |
+
if (dispatch_id.y > output.size(0) || dispatch_id.x > output.size(1))
|
68 |
+
return;
|
69 |
+
|
70 |
+
// We index x,y but the orginal coords are HW. So swap them
|
71 |
+
float2 pixel_coord = float2(dispatch_id.y, dispatch_id.x);
|
72 |
+
// Normalize to [0, 1]
|
73 |
+
pixel_coord /= float2(output.size(1), output.size(0));
|
74 |
+
pixel_coord = clamp(pixel_coord, 0.0, 1.0);
|
75 |
+
// Flip x-axis
|
76 |
+
pixel_coord.y = 1 - pixel_coord.y;
|
77 |
+
|
78 |
+
for (int i = 0; i < indices.size(0); i++) {
|
79 |
+
float2 v1 = float2(uv[indices[i].x].x, uv[indices[i].x].y);
|
80 |
+
float2 v2 = float2(uv[indices[i].y].x, uv[indices[i].y].y);
|
81 |
+
float2 v3 = float2(uv[indices[i].z].x, uv[indices[i].z].y);
|
82 |
+
|
83 |
+
float u, v, w;
|
84 |
+
bool hit = barycentric_coordinates(pixel_coord, v1, v2, v3, u, v, w);
|
85 |
+
|
86 |
+
if (hit){
|
87 |
+
output[dispatch_id.x, dispatch_id.y] = float4(u, v, w, i);
|
88 |
+
return;
|
89 |
+
}
|
90 |
+
}
|
91 |
+
|
92 |
+
output[dispatch_id.x, dispatch_id.y] = float4(0.0, 0.0, 0.0, -1);
|
93 |
+
}
|
sf3d/utils.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import rembg
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
import sf3d.models.utils as sf3d_utils
|
9 |
+
|
10 |
+
|
11 |
+
def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
|
12 |
+
intrinsic = sf3d_utils.get_intrinsic_from_fov(
|
13 |
+
np.deg2rad(fov_deg),
|
14 |
+
H=cond_height,
|
15 |
+
W=cond_width,
|
16 |
+
)
|
17 |
+
intrinsic_normed_cond = intrinsic.clone()
|
18 |
+
intrinsic_normed_cond[..., 0, 2] /= cond_width
|
19 |
+
intrinsic_normed_cond[..., 1, 2] /= cond_height
|
20 |
+
intrinsic_normed_cond[..., 0, 0] /= cond_width
|
21 |
+
intrinsic_normed_cond[..., 1, 1] /= cond_height
|
22 |
+
|
23 |
+
return intrinsic, intrinsic_normed_cond
|
24 |
+
|
25 |
+
|
26 |
+
def default_cond_c2w(distance: float):
|
27 |
+
c2w_cond = torch.as_tensor(
|
28 |
+
[
|
29 |
+
[0, 0, 1, distance],
|
30 |
+
[1, 0, 0, 0],
|
31 |
+
[0, 1, 0, 0],
|
32 |
+
[0, 0, 0, 1],
|
33 |
+
]
|
34 |
+
).float()
|
35 |
+
return c2w_cond
|
36 |
+
|
37 |
+
|
38 |
+
def remove_background(
|
39 |
+
image: Image,
|
40 |
+
rembg_session: Any = None,
|
41 |
+
force: bool = False,
|
42 |
+
**rembg_kwargs,
|
43 |
+
) -> Image:
|
44 |
+
do_remove = True
|
45 |
+
if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
|
46 |
+
do_remove = False
|
47 |
+
do_remove = do_remove or force
|
48 |
+
if do_remove:
|
49 |
+
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
50 |
+
return image
|
51 |
+
|
52 |
+
|
53 |
+
def resize_foreground(
|
54 |
+
image: Image,
|
55 |
+
ratio: float,
|
56 |
+
) -> Image:
|
57 |
+
image = np.array(image)
|
58 |
+
assert image.shape[-1] == 4
|
59 |
+
alpha = np.where(image[..., 3] > 0)
|
60 |
+
y1, y2, x1, x2 = (
|
61 |
+
alpha[0].min(),
|
62 |
+
alpha[0].max(),
|
63 |
+
alpha[1].min(),
|
64 |
+
alpha[1].max(),
|
65 |
+
)
|
66 |
+
# crop the foreground
|
67 |
+
fg = image[y1:y2, x1:x2]
|
68 |
+
# pad to square
|
69 |
+
size = max(fg.shape[0], fg.shape[1])
|
70 |
+
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
71 |
+
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
72 |
+
new_image = np.pad(
|
73 |
+
fg,
|
74 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
75 |
+
mode="constant",
|
76 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
77 |
+
)
|
78 |
+
|
79 |
+
# compute padding according to the ratio
|
80 |
+
new_size = int(new_image.shape[0] / ratio)
|
81 |
+
# pad to size, double side
|
82 |
+
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
83 |
+
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
84 |
+
new_image = np.pad(
|
85 |
+
new_image,
|
86 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
87 |
+
mode="constant",
|
88 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
89 |
+
)
|
90 |
+
new_image = Image.fromarray(new_image, mode="RGBA")
|
91 |
+
return new_image
|