import os
import sys
from setuptools import find_packages, setup
# Training scripts that receive the injected snippet. The paths start with
# "./", so they are resolved against the current working directory of the
# process running this setup script.
TARGET_FILES = [
    "./train_dreambooth_lora_flux2.py",
    "./train_dreambooth_lora_flux2_img2img.py",
]
# 1-based line number in each target script at which INJECTION_CODE is
# inserted (patch_target_files converts it to a 0-based list index).
INJECTION_LINE = 110
# Source text spliced verbatim into each target script. It:
#   * imports and calls apply_patches() from transformer_patches,
#   * initialises the default process group with the "hccl" backend using
#     the "env://" init method,
#   * creates an additional "gloo" process group and registers it under the
#     name "gloo" through the private torch._C._distributed_c10d API,
#   * raises RuntimeError if the default group was not initialised.
# The literal substring "apply_patches()" doubles as the marker that
# patch_target_files uses to detect an already-patched script.
# The snippet uses the name `dist` without importing it, so that name must
# already be bound at the insertion point in the target script
# (presumably torch.distributed -- TODO confirm against the target scripts).
# NOTE(review): as visible here the snippet starts at column 0 and the
# bodies of its `if`/`else` carry no indentation -- confirm the text is
# valid Python at the insertion point before relying on it.
INJECTION_CODE = """from transformer_patches import apply_patches
apply_patches()
dist.init_process_group(
backend="hccl", # Primary backend for NPU
init_method="env://"
)
if dist.is_initialized():
cpu_group = dist.new_group(backend="gloo")
else:
raise RuntimeError("Distributed initialization failed")
from torch._C._distributed_c10d import _register_process_group
_register_process_group("gloo", cpu_group)
"""
def patch_target_files(targets=None, injection_line=None, injection_code=None):
    """Insert a code snippet into each target script, at most once per file.

    Files that do not exist are reported on stderr and skipped. A file that
    already contains the text ``apply_patches()`` is treated as patched and
    left untouched, which makes repeated runs idempotent.

    Args:
        targets: Iterable of file paths to patch. Defaults to ``TARGET_FILES``.
        injection_line: 1-based line number at which the snippet is inserted.
            Defaults to ``INJECTION_LINE``. When the file is shorter than
            this, the snippet is appended at the end instead.
        injection_code: Text to insert; a trailing newline is added after it.
            Defaults to ``INJECTION_CODE``.

    Raises:
        ValueError: If ``injection_line`` is smaller than 1.
    """
    if targets is None:
        targets = TARGET_FILES
    if injection_line is None:
        injection_line = INJECTION_LINE
    if injection_code is None:
        injection_code = INJECTION_CODE
    if injection_line < 1:
        # A zero/negative value would turn into a negative list index and
        # silently insert relative to the end of the file.
        raise ValueError(f"injection_line must be >= 1, got {injection_line}")

    insert_index = injection_line - 1
    for target in targets:
        if not os.path.exists(target):
            print(
                f"Set up process cannot find {target}. Patch will not be applied.",
                file=sys.stderr,
            )
            continue

        # Explicit UTF-8 so the result does not depend on the locale encoding.
        with open(target, "r", encoding="utf-8") as f:
            lines = f.readlines()

        # Already patched on an earlier run: leave the file alone.
        if any("apply_patches()" in line for line in lines):
            continue

        if insert_index > len(lines):
            print(
                f"Warning: {target} has only {len(lines)} lines. Appending at end.",
                file=sys.stderr,
            )
            # The leading newline also terminates a final line that lacks one.
            lines.append("\n" + injection_code + "\n")
        else:
            if insert_index == len(lines) and lines and not lines[-1].endswith("\n"):
                # Inserting right after an unterminated last line would glue
                # the snippet's first line onto it; terminate that line first.
                lines[-1] += "\n"
            lines.insert(insert_index, injection_code + "\n")

        with open(target, "w", encoding="utf-8") as f:
            f.writelines(lines)
        print(f"Patched {target} at line {injection_line}.")
# Patch the training scripts as a side effect of executing this setup script,
# before setup() runs. The guard skips patching only when this module's
# __name__ equals "distutils.core".
# NOTE(review): presumably intended to avoid patching when the script is
# loaded by distutils machinery rather than run directly -- confirm that
# __name__ can actually take this value in the supported install paths.
if __name__ != "distutils.core":
    patch_target_files()
# Package definition: sources live under src/ and are discovered with
# find_packages().
setup(
    name="flux2-patch",
    version="0.1",
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    # Runtime dependencies. accelerate, torch, transformers and peft are
    # pinned to exact versions; the remaining packages are unpinned.
    install_requires=[
        "accelerate==1.10.0",
        "torch==2.7.1",
        "torchvision",
        "transformers==4.55.0",
        "ftfy",
        "tensorboard",
        "Jinja2",
        "peft==0.17.0",
        "sentencepiece",
        "einops",
        "attrs",
        "scipy",
        "decorator",
        "mindstudio-probe",
        "datasets",
    ],
)