import os
import shutil
import sys
import time

import torch
from torch import nn

from models.swin import PatchMerging
from utils import format_time
class PatchMergingFixed(PatchMerging):
    """Subclass of ``models.swin.PatchMerging`` with its own ``forward``.

    The merged patches are explicitly moved back to the input tensor's
    device before the linear projection.
    """

    def forward(self, x):
        """Merge non-overlapping patches of ``x`` and project them.

        ``x`` is unpacked as four dimensions (batch, channels, height, width).
        With ``f = self.downscaling_factor``, the output of
        ``self.patch_merge`` is viewed as
        ``(batch, -1, height // f, width // f)`` and permuted channels-last.
        It is then moved to ``x``'s device and passed through ``self.linear``.
        """
        batch, _channels, height, width = x.shape
        factor = self.downscaling_factor
        out_height = height // factor
        out_width = width // factor
        source_device = x.device

        merged = self.patch_merge(x)
        merged = merged.view(batch, -1, out_height, out_width)
        merged = merged.permute(0, 2, 3, 1).to(source_device)
        return self.linear(merged)
# Swap PatchMergingFixed into the already-imported models.swin module.
# This only affects lookups of models.swin.PatchMerging made after this
# point. Names already bound elsewhere, such as the PatchMerging imported at
# the top of this file, are untouched. A direct sys.modules lookup replaces
# a linear scan over every loaded module.
_swin_module = sys.modules.get('models.swin')
if _swin_module is not None and getattr(_swin_module, "PatchMerging", None):
    setattr(_swin_module, "PatchMerging", PatchMergingFixed)
    print("PatchMerging has been replaced with PatchMergingFixed for performance optimization.")
# Terminal width in columns, used to pad and right-align the progress bar.
# shutil.get_terminal_size checks the COLUMNS environment variable, then the
# terminal attached to stdout, and otherwise uses the fallback. Unlike a
# `stty size` pipe, it spawns no shell, leaves no unclosed pipe and works on
# Windows. The `or 80` guards against a reported width of 0.
term_width = shutil.get_terminal_size(fallback=(80, 24)).columns or 80
# Character width of the progress bar body; a float because progress_bar
# scales it by current / total.
TOTAL_BAR_LENGTH = 65.0

# Timing state shared across progress_bar calls:
#   last_time       - wall-clock time of the previous call (or module import)
#   begin_time      - wall-clock time the current bar started
#   total_step_time - sum of per-call step times since the bar started
last_time = begin_time = time.time()
total_step_time = 0
def progress_bar(current, total, msg=None):
    """Render a one-line text progress bar for step ``current`` of ``total``.

    The line written to stdout contains:

    * the bar itself;
    * the time since the previous call (``Step``);
    * the time since the bar started (``Tot``);
    * the optional ``msg``;
    * a ``current+1/total`` counter.

    The line ends with a carriage return, so the next call overwrites it.
    On the final step (``current == total - 1``) it ends with a newline.

    Args:
        current: zero-based step index. ``0`` restarts the bar's elapsed
            clock and its accumulated step time.
        total: number of steps in the bar. Must be non-zero; ``0`` raises
            ``ZeroDivisionError``.
        msg: optional text appended after the timing information.

    Returns:
        The step time accumulated since ``current == 0``, divided by
        ``total``.
    """
    global last_time, begin_time, total_step_time
    if current == 0:
        # First step of a new bar: restart the elapsed clock and the
        # accumulated step time.
        # NOTE: last_time is not reset, so the step time for current == 0 is
        # measured from the previous call (or from module import).
        begin_time = time.time()
        total_step_time = 0

    bar_len = int(TOTAL_BAR_LENGTH)
    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    # For 0 <= current < total, cur_len is already in [0, bar_len - 1].
    # Clamping keeps an out-of-range current from widening the bar.
    cur_len = min(max(cur_len, 0), bar_len - 1)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
    bar = ' [' + '=' * cur_len + '>' + '.' * rest_len + ']'

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time
    total_step_time += step_time

    info = ' Step: %s' % format_time(step_time) + ' | Tot: %s' % format_time(tot_time)
    if msg:
        info += ' | ' + msg

    # Pad out to the terminal edge, then back up into the bar so the step
    # counter overwrites part of it. Negative counts give empty strings.
    padding = ' ' * (term_width - bar_len - len(info) - 3)
    backspaces = '\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2)
    counter = ' %d/%d ' % (current + 1, total)
    line_end = '\r' if current < total - 1 else '\n'

    # One write instead of one write per character.
    sys.stdout.write(bar + info + padding + backspaces + counter + line_end)
    sys.stdout.flush()
    return total_step_time / total