|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import glob
import importlib.util
import os
import re
import sys

import black
from doc_builder.style_doc import style_docstrings_in_code
|
|
|
|
|
|
|
|
|
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
DIFFUSERS_PATH = "src/diffusers"
REPO_PATH = "."


# Load the `diffusers` package from the repo checkout (not from any installed copy).
# `spec.loader.load_module()` is deprecated since Python 3.4, so this uses the documented
# `module_from_spec` + `exec_module` recipe instead. Like `load_module()`, the module is registered in
# `sys.modules` *before* it is executed so the package's relative imports resolve. The final object is then
# read back from `sys.modules`, in case the package's `__init__` replaces its own entry (e.g. with a lazy module).
spec = importlib.util.spec_from_file_location(
    "diffusers",
    os.path.join(DIFFUSERS_PATH, "__init__.py"),
    submodule_search_locations=[DIFFUSERS_PATH],
)
diffusers_module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = diffusers_module
spec.loader.exec_module(diffusers_module)
diffusers_module = sys.modules[spec.name]
|
|
|
|
|
def _should_continue(line, indent): |
|
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None |
|
|
|
|
|
def find_code_in_diffusers(object_name):
    """
    Find and return the source code of `object_name`.

    `object_name` is a dotted path relative to the `diffusers` package, e.g. `models.attention.Attention` or
    `models.attention.Attention.forward`. The leading parts are joined into a path until they name an existing
    `<path>.py` file under `DIFFUSERS_PATH`. The remaining parts are looked up as nested `class`/`def`
    definitions, each one indentation level deeper than the previous.

    The returned string starts on the line *after* the matched `class`/`def` header line. It extends until the
    first line rejected by `_should_continue`, with trailing blank lines removed.

    Raises:
        ValueError: if no prefix of `object_name` names a module file, or if the nested names are not found.
    """
    parts = object_name.split(".")
    i = 0

    # First let's find the module where our object lives: grow the path one part at a time.
    module = parts[i]
    while i < len(parts) and not os.path.isfile(os.path.join(DIFFUSERS_PATH, f"{module}.py")):
        i += 1
        if i < len(parts):
            module = os.path.join(module, parts[i])
    if i >= len(parts):
        raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.")

    with open(os.path.join(DIFFUSERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Now let's find the class / func in the code!
    # NOTE(review): each search resumes after the previous match but is not bounded by the parent's end, so a
    # missing method may match a same-named definition in a later class at the same indentation.
    indent = ""
    line_index = 0
    for name in parts[i + 1 :]:
        header = re.compile(rf"^{indent}(class|def)\s+{re.escape(name)}(\(|\:)")
        while line_index < len(lines) and header.search(lines[line_index]) is None:
            line_index += 1
        # One indentation level per nested name: 4 spaces, the indentation black produces.
        indent += "    "
        # Step past the header line; when `name` was not found this pushes `line_index` beyond the file.
        line_index += 1

    if line_index >= len(lines):
        raise ValueError(f" {object_name} does not match any function or class in {module}.")

    # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
    start_index = line_index
    while line_index < len(lines) and _should_continue(lines[line_index], indent):
        line_index += 1
    # Clean up empty lines at the end (if any), never walking back before the start of the block.
    while line_index > start_index and len(lines[line_index - 1]) <= 1:
        line_index -= 1

    return "".join(lines[start_index:line_index])
|
|
|
|
|
# Matches a `# Copied from diffusers.<dotted.path>` comment. Groups: (1) the comment's leading whitespace,
# (2) the dotted path after `diffusers.` (must contain at least one dot), (3) any trailing text such as
# `with A->B, C->D all-casing` (the empty string when there is none).
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)")
# Matches one `A->B [option]` replacement spec. Groups: (1) the pattern to replace, (2) its replacement,
# (3) the optional trailing option text, e.g. ` all-casing` (the empty string when there is none).
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
# Matches a `<FILL ...>` placeholder (no use of it is visible in this script).
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
|
|
|
|
|
def get_indent(code):
    """
    Return the leading whitespace of the first line of `code` that contains a non-whitespace character.

    Empty and whitespace-only lines are skipped. Previously, a whitespace-only first line made the regex
    search return `None` and crash on `.groups()`. Returns `""` when `code` has no such line.
    """
    for line in code.split("\n"):
        match = re.match(r"(\s*)\S", line)
        if match is not None:
            return match.group(1)
    return ""
|
|
|
|
|
def blackify(code):
    """
    Format `code` with black, as the black step of `make style` would, then pass the result through
    doc-builder's `style_docstrings_in_code`.

    Indented snippets (such as a method body) are not valid top-level Python. They are temporarily wrapped
    under a dummy `class Bla:` header, which is sliced off the formatted result.
    """
    wrapper = "class Bla:\n"
    wrapped = bool(get_indent(code))
    if wrapped:
        code = wrapper + code
    style = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=119, preview=True)
    formatted, _ = style_docstrings_in_code(black.format_str(code, mode=style))
    if wrapped:
        return formatted[len(wrapper) :]
    return formatted
|
|
|
|
|
def is_copy_consistent(filename, overwrite=False):
    """
    Check if the code commented as a copy in `filename` matches the original.

    Each `# Copied from diffusers.<path> [with A->B, ...]` comment marks the start of a copied block. The block
    ends at the first line rejected by `_should_continue` for the original's indentation, or at a `# End copy`
    comment. The original is fetched with `find_code_in_diffusers` and processed in order:
    - nested `Copied from` comments are stripped;
    - the optional `A->B` replacements are applied (`all-casing` also applies the lower- and upper-case
      variants);
    - the result is run through `blackify`.
    It is then compared with the observed code.

    Args:
        filename: path of the file to check.
        overwrite: if `True`, replace every mismatching block with the expected code and rewrite the file.

    Returns:
        A list of `[object_name, start_index]` pairs, one per mismatching block. `start_index` is the 0-based
        index into the file's lines of the first compared line.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    diffs = []
    line_index = 0
    # A `while` loop rather than `for`, because `lines` is rebuilt when `overwrite=True`.
    while line_index < len(lines):
        search = _re_copy_warning.search(lines[line_index])
        if search is None:
            line_index += 1
            continue

        # There is some copied code here, let's retrieve the original.
        indent, object_name, replace_pattern = search.groups()
        theoretical_code = find_code_in_diffusers(object_name)
        theoretical_indent = get_indent(theoretical_code)

        # `find_code_in_diffusers` returns the code after the `class`/`def` header line. When the comment is not
        # at that code's indentation (i.e. it sits above the header), the observed code starts two lines down.
        start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
        indent = theoretical_indent
        line_index = start_index

        # Loop to check the observed code, stop when indentation diminishes or if we see a `# End copy` comment.
        should_continue = True
        while line_index < len(lines) and should_continue:
            line_index += 1
            if line_index >= len(lines):
                break
            line = lines[line_index]
            should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
        # Clean up empty lines at the end (if any).
        while len(lines[line_index - 1]) <= 1:
            line_index -= 1

        observed_code = "".join(lines[start_index:line_index])

        # Remove any nested `Copied from` comments to avoid circular copies.
        theoretical_code = "\n".join(
            line for line in theoretical_code.split("\n") if _re_copy_warning.search(line) is None
        )

        # Before comparing, apply the `A->B` replacements to the original code.
        if len(replace_pattern) > 0:
            # Strip only the leading `with` keyword. A blanket `str.replace("with", "")` would also corrupt
            # patterns whose identifiers contain "with" (e.g. `unet->unet_with_lora` -> `unet->unet__lora`).
            patterns = re.sub(r"^with\b", "", replace_pattern).split(",")
            patterns = [_re_replace_pattern.search(p) for p in patterns]
            for pattern in patterns:
                if pattern is None:
                    continue
                obj1, obj2, option = pattern.groups()
                theoretical_code = re.sub(obj1, obj2, theoretical_code)
                if option.strip() == "all-casing":
                    theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
                    theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)

        # Blackify after replacement. Black needs the enclosing header, so the line just before the compared
        # code is prepended for formatting and then sliced off again.
        theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
        theoretical_code = theoretical_code[len(lines[start_index - 1]) :]

        # Test for a diff and act accordingly.
        if observed_code != theoretical_code:
            diffs.append([object_name, start_index])
            if overwrite:
                lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
                line_index = start_index + 1

    if overwrite and len(diffs) > 0:
        # Warn the user a file has been modified.
        print(f"Detected changes, rewriting {filename}.")
        with open(filename, "w", encoding="utf-8", newline="\n") as f:
            f.writelines(lines)
    return diffs
|
|
|
|
|
def check_copies(overwrite: bool = False):
    """
    Run `is_copy_consistent` on every Python file found recursively under `DIFFUSERS_PATH`.

    With `overwrite=False`, raise an `Exception` listing every inconsistency found. With `overwrite=True`,
    `is_copy_consistent` rewrites the offending files and nothing is raised.
    """
    report = []
    for path in glob.glob(os.path.join(DIFFUSERS_PATH, "**/*.py"), recursive=True):
        report.extend(
            f"- {path}: copy does not match {d[0]} at line {d[1]}" for d in is_copy_consistent(path, overwrite)
        )
    if overwrite or not report:
        return
    raise Exception(
        "Found the following copy inconsistencies:\n"
        + "\n".join(report)
        + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
    )
|
|
|
|
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    cli_args = cli.parse_args()
    # Without the flag, inconsistencies are reported by raising; with it, the offending files are rewritten.
    check_copies(overwrite=cli_args.fix_and_overwrite)
|
|