import os
import sys
import struct
import numpy as np
import torch
import torch.nn.functional as F
def bin2tensor(binName):
    """Load a raw binary file of native-endian float32 values as a 1-D tensor.

    The whole file is read in one call and decoded in bulk. Any trailing bytes
    that do not form a complete 4-byte float are ignored.

    Args:
        binName: path to the binary file.

    Returns:
        A 1-D ``torch.float64`` tensor with one element per complete float32.
        Values are widened to float64 because the original list-based
        implementation produced float64, and callers may rely on that dtype.
    """
    with open(binName, 'rb') as binfile:
        raw = binfile.read()
    # Keep only whole 4-byte records, matching the old `size // 4` read count.
    usable = (len(raw) // 4) * 4
    values = np.frombuffer(raw[:usable], dtype=np.float32)
    # astype() copies into a writable array, so torch.from_numpy does not warn
    # about the read-only buffer returned by frombuffer.
    return torch.from_numpy(values.astype(np.float64))
def mel_loss(mel_out, mel_tgt):
    """Masked mean-squared error between predicted and target mel tensors.

    Only positions where ``mel_tgt`` is non-zero contribute. The summed squared
    error over those positions is divided by their count. Exact zeros in the
    target are excluded (presumably padding — confirm against the data
    pipeline).

    Args:
        mel_out: predicted tensor, e.g. shape (batchsize, 80, 900).
        mel_tgt: target tensor with the same (or a broadcast-compatible) shape.

    Returns:
        A 0-d tensor holding the masked MSE. If ``mel_tgt`` has no non-zero
        entries, the result is 0 rather than NaN.
    """
    # The reduction is a plain masked sum, so axis order is irrelevant; no
    # transposes are needed and any input rank is accepted.
    mask = mel_tgt.ne(0).float()
    squared_error = F.mse_loss(mel_out, mel_tgt, reduction='none')
    # With an all-zero target the numerator is also 0. Clamping the count turns
    # the would-be 0/0 = NaN into 0, so one sample cannot poison an average.
    valid_count = mask.sum().clamp(min=1)
    return (squared_error * mask).sum() / valid_count
def _average_mel_loss(load_out, tgt_path, data_len):
    """Return the mean ``mel_loss`` over samples ``0 .. data_len - 1``.

    Args:
        load_out: callable mapping a sample index to its predicted mel tensor.
        tgt_path: directory holding ``mel_tgt{i}.pth`` target tensors.
        data_len: number of samples to average over; must be positive.

    Raises:
        ValueError: if ``data_len`` is not positive.
    """
    if data_len <= 0:
        raise ValueError(f"data_len must be positive, got {data_len}")
    total = 0
    for i in range(data_len):
        mel_out = load_out(i)
        # Load onto CPU so GPU-saved targets work on CPU-only hosts and match
        # the CPU tensors produced by bin2tensor.
        mel_tgt = torch.load(os.path.join(tgt_path, f"mel_tgt{i}.pth"),
                             map_location="cpu")
        total += mel_loss(mel_out, mel_tgt)
    return total / data_len


def test_om(out_path=None, tgt_path='./test/mel_tgt_pth/', data_len=99,
            mel_shape=(1, 80, 900)):
    """Print the average mel loss of offline-model binary outputs.

    Args:
        out_path: directory of ``data{i}_0.bin`` outputs. Defaults to the
            first command-line argument.
        tgt_path: directory of ``mel_tgt{i}.pth`` targets.
        data_len: number of samples to evaluate.
        mel_shape: shape each flat binary output is reshaped to.

    Raises:
        SystemExit: if ``out_path`` is omitted and no command-line argument
            was given.
    """
    if out_path is None:
        if len(sys.argv) < 2:
            raise SystemExit(f"usage: {sys.argv[0]} <om_output_dir>")
        out_path = sys.argv[1]
    mel_loss_average = _average_mel_loss(
        lambda i: bin2tensor(
            os.path.join(out_path, f"data{i}_0.bin")).reshape(*mel_shape),
        tgt_path,
        data_len,
    )
    print("mel_loss_average", mel_loss_average.item())


def test_pth(out_path='./test/mel_out_pth/', tgt_path='./test/mel_tgt_pth/',
             data_len=99):
    """Print the average mel loss of PyTorch ``mel_out{i}.pth`` outputs.

    Args:
        out_path: directory of ``mel_out{i}.pth`` predictions.
        tgt_path: directory of ``mel_tgt{i}.pth`` targets.
        data_len: number of samples to evaluate.
    """
    mel_loss_average = _average_mel_loss(
        lambda i: torch.load(os.path.join(out_path, f"mel_out{i}.pth"),
                             map_location="cpu"),
        tgt_path,
        data_len,
    )
    print("mel_loss_average", mel_loss_average.item())
if __name__ == "__main__":
    # Run both evaluations in order, each preceded by its banner line.
    _runs = (
        ("==================om==================", test_om),
        ("==================pth==================", test_pth),
    )
    for _banner, _runner in _runs:
        print(_banner)
        _runner()