Commit 5a69648d authored by Gabriele Franch's avatar Gabriele Franch
Browse files

updated loss

parent eb13bba6
......@@ -77,13 +77,15 @@ class Model(object):
self.optimizer.zero_grad()
next_frames = self.network(frames_tensor, mask_tensor)
npixels = (self.configs.total_length-self.configs.input_length)*self.configs.img_width*self.configs.img_width*self.configs.img_channel
loss = self.MSE_criterion(
next_frames[:, self.configs.input_length-1:], frames_tensor[:, self.configs.input_length:]
) + self.MAE_criterion(
next_frames[:, self.configs.input_length-1:], frames_tensor[:, self.configs.input_length:]
) + 2 * self.MAE_criterion(
next_frames[:, self.configs.input_length-1:].sum(), frames_tensor[:, self.configs.input_length:].sum()
) / npixels
loss = 1000 * (
self.MSE_criterion(
next_frames[:, self.configs.input_length-1:], frames_tensor[:, self.configs.input_length:]
) + self.MAE_criterion(
next_frames[:, self.configs.input_length-1:], frames_tensor[:, self.configs.input_length:]
) + 2 * self.MAE_criterion(
next_frames[:, self.configs.input_length-1:].sum(), frames_tensor[:, self.configs.input_length:].sum()
) / npixels
)
# 0.02*self.SSIM_criterion(next_frames, frames_tensor[:, 1:])
loss.backward()
self.optimizer.step()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment