Commit eb13bba6 authored by Gabriele Franch's avatar Gabriele Franch
Browse files

fixed bug in loss calculation

parent ea45c392
......@@ -78,11 +78,11 @@ class Model(object):
next_frames = self.network(frames_tensor, mask_tensor)
npixels = (self.configs.total_length-self.configs.input_length)*self.configs.img_width*self.configs.img_width*self.configs.img_channel
loss = self.MSE_criterion(
next_frames[:, self.configs.input_length-1], frames_tensor[:, self.configs.input_length:]
next_frames[:, self.configs.input_length-1:], frames_tensor[:, self.configs.input_length:]
) + self.MAE_criterion(
next_frames[:, self.configs.input_length-1], frames_tensor[:, self.configs.input_length:]
next_frames[:, self.configs.input_length-1:], frames_tensor[:, self.configs.input_length:]
) + 2 * self.MAE_criterion(
next_frames[:, self.configs.input_length-1].sum(), frames_tensor[:, self.configs.input_length:].sum()
next_frames[:, self.configs.input_length-1:].sum(), frames_tensor[:, self.configs.input_length:].sum()
) / npixels
# 0.02*self.SSIM_criterion(next_frames, frames_tensor[:, 1:])
loss.backward()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment