Commit ea45c392 authored by Gabriele Franch's avatar Gabriele Franch
Browse files

updated loss

parent 190bf292
......@@ -37,8 +37,8 @@ class Model(object):
self.network = nn.DataParallel(self.network)
self.optimizer = Adam(self.network.parameters(), lr=configs.lr)
# TODO: size_average to sum
self.MSE_criterion = nn.MSELoss(size_average=False)
self.MAE_criterion = nn.L1Loss(size_average=False)
self.MSE_criterion = nn.MSELoss()
self.MAE_criterion = nn.L1Loss()
def save(self, itr, eta):
if self.configs.is_parallel:
......@@ -76,9 +76,14 @@ class Model(object):
# mask_tensor = torch.FloatTensor(mask).cuda()
self.optimizer.zero_grad()
next_frames = self.network(frames_tensor, mask_tensor)
npixels = (self.configs.total_length-self.configs.input_length)*self.configs.img_width*self.configs.img_width*self.configs.img_channel
loss = self.MSE_criterion(
next_frames, frames_tensor[:, 1:]
) + self.MAE_criterion(next_frames, frames_tensor[:, 1:])
next_frames[:, self.configs.input_length-1], frames_tensor[:, self.configs.input_length:]
) + self.MAE_criterion(
next_frames[:, self.configs.input_length-1], frames_tensor[:, self.configs.input_length:]
) + 2 * self.MAE_criterion(
next_frames[:, self.configs.input_length-1].sum(), frames_tensor[:, self.configs.input_length:].sum()
) / npixels
# 0.02*self.SSIM_criterion(next_frames, frames_tensor[:, 1:])
loss.backward()
self.optimizer.step()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment