39 lines
1.1 KiB
Python
39 lines
1.1 KiB
Python
#
|
|
# Copyright (C) 2023, Inria
|
|
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
|
# All rights reserved.
|
|
#
|
|
# This software is free for non-commercial, research and evaluation use
|
|
# under the terms of the LICENSE.md file.
|
|
#
|
|
# For inquiries contact george.drettakis@inria.fr
|
|
#
|
|
|
|
import torch
|
|
|
|
def mse(img1, img2):
    """Return the mean squared error between two tensors, one value per row.

    The leading dimension is kept and every remaining dimension is averaged
    over, so the result has shape ``(img1.shape[0], 1)``.

    Args:
        img1: First tensor.
        img2: Second tensor; must be broadcast-compatible with ``img1``.

    Returns:
        Tensor of shape ``(img1.shape[0], 1)`` holding the per-row MSE.
    """
    squared_error = (img1 - img2) ** 2
    # reshape (not view) so non-contiguous inputs, e.g. permuted images,
    # are accepted instead of raising a RuntimeError.
    return squared_error.reshape(img1.shape[0], -1).mean(1, keepdim=True)
|
|
@torch.no_grad()
def psnr(img1, img2, mask=None):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    The peak value is taken to be 1.0 (see the ``1.0 /`` in the formula), so
    inputs are presumably scaled to [0, 1] -- confirm against callers.

    One PSNR value is produced per entry of the leading dimension: the
    squared error is averaged over all remaining dimensions. When ``mask`` is
    given, only elements where the mask is non-zero contribute to that
    average.

    Args:
        img1: First image tensor.
        img2: Second image tensor, same shape as ``img1``.
        mask: Optional tensor broadcastable to the image shape (e.g.
            ``(1, H, W)`` or ``(H, W)`` for a ``(C, H, W)`` image). Non-zero
            entries mark the elements to evaluate.

    Returns:
        Float tensor of shape ``(img1.shape[0], 1)``. A row whose selected
        elements are identical in both images yields ``inf``; a row with no
        selected elements yields ``nan``.
    """
    squared_error = ((img1 - img2) ** 2).float()
    rows = squared_error.shape[0]

    if mask is None:
        mean_squared_error = squared_error.reshape(rows, -1).mean(1, keepdim=True)
    else:
        # Broadcast the mask over the image instead of assuming 3 channels.
        keep = (mask != 0).expand_as(squared_error).reshape(rows, -1)
        flat_error = squared_error.reshape(rows, -1)
        # Zero out (rather than multiply away) masked-out elements so that
        # NaN/inf values outside the mask cannot leak into the sum.
        flat_error = torch.where(keep, flat_error, torch.zeros_like(flat_error))
        # Average over the selected elements of each row only; indexing with
        # the mask would collapse the rows and yield per-pixel values.
        mean_squared_error = flat_error.sum(1, keepdim=True) / keep.sum(1, keepdim=True)

    return 20 * torch.log10(1.0 / torch.sqrt(mean_squared_error))
|