diff --git a/configs/actors/obama.yml b/configs/actors/obama.yml new file mode 100644 index 0000000..2ecf3f3 --- /dev/null +++ b/configs/actors/obama.yml @@ -0,0 +1,12 @@ +# Tracker config + +actor: './input/obama' +save_folder: './output/' +flame_geom_path: 'data/FLAME2023/generic.pkl' +optimize_shape: true +optimize_jaw: true +begin_frames: 269 +end_frames: 2950 +keyframes: [ 0, 1, 2, 3 ] +pyr_levels: [ [ 0.25, 8 ], [ 0.5, 8 ], [ 1.0, 8 ] ] +raster_update: 4 \ No newline at end of file diff --git a/configs/actors/yufeng_exp.yml b/configs/actors/yufeng_exp.yml new file mode 100644 index 0000000..c045a1e --- /dev/null +++ b/configs/actors/yufeng_exp.yml @@ -0,0 +1,11 @@ +# Tracker config + +actor: './input/yufeng_exp' +save_folder: './output/' +optimize_shape: true +optimize_jaw: true +begin_frames: 150 +end_frames: 1000 +keyframes: [ 0, 1, 2 ] +pyr_levels: [ [ 0.25, 8 ], [ 0.5, 8 ], [ 1.0, 8 ] ] +raster_update: 1 \ No newline at end of file diff --git a/configs/actors/yufeng_rot.yml b/configs/actors/yufeng_rot.yml new file mode 100644 index 0000000..5aff6f6 --- /dev/null +++ b/configs/actors/yufeng_rot.yml @@ -0,0 +1,11 @@ +# Tracker config + +actor: './input/yufeng_rot' +save_folder: './output/' +optimize_shape: true +optimize_jaw: true +begin_frames: 0 +end_frames: 400 +keyframes: [ 0, 1, 2 ] +pyr_levels: [ [ 0.25, 8 ], [ 0.5, 8 ], [ 1.0, 8 ] ] +raster_update: 1 \ No newline at end of file diff --git a/configs/config.py b/configs/config.py index 668fbf7..fc2f617 100644 --- a/configs/config.py +++ b/configs/config.py @@ -24,7 +24,7 @@ cfg.flame_geom_path = 'data/FLAME2020/generic_model.pkl' cfg.flame_template_path = 'data/uv_template.obj' cfg.flame_lmk_path = 'data/landmark_embedding.npy' -cfg.tex_space_path = 'data/FLAME2020/FLAME_texture.npz' +cfg.tex_space_path = 'data/FLAME2020/FLAME_albedo_from_BFM.npz' cfg.num_shape_params = 300 cfg.num_exp_params = 100 @@ -39,18 +39,19 @@ cfg.begin_frames = 0 cfg.end_frames = 0 cfg.image_size = [512, 512] # height, width -cfg.rotation_lr = 0.075 -cfg.translation_lr = 0.003 +cfg.rotation_lr = 0.001 +cfg.translation_lr = 0.001 cfg.raster_update = 16 cfg.pyr_levels = [[0.25, 90], [0.5, 90], [1.0, 90]] # Gaussian pyramid levels (scaling, iters per level) cfg.optimize_shape = False +cfg.optimize_dense = True cfg.optimize_jaw = False cfg.crop_image = True cfg.save_folder = './output/' # Weights -cfg.w_pho = 350 -cfg.w_lmks = 5000 +cfg.w_pho = 1000 +cfg.w_lmks = 15000 cfg.w_lmks_lid = 1000 cfg.w_lmks_mouth = 15000 cfg.w_lmks_iris = 1000 @@ -58,7 +59,7 @@ cfg.w_exp = 0.02 cfg.w_shape = 0.3 -cfg.w_tex = 0.04 +cfg.w_tex = 0.05 def get_cfg_defaults(): diff --git a/masking.py b/masking.py index c7f3040..2e5c406 100644 --- a/masking.py +++ b/masking.py @@ -79,10 +79,17 @@ def get_mask_lips(self): return self.masks.lips def get_mask_rendering(self): - face_mask = torch.ones_like(self.vertices)[None] - face_mask[:, self.masks.boundary, :] = 0.0 - face_mask[:, self.masks.left_ear, :] = 0.0 - face_mask[:, self.masks.right_ear, :] = 0.0 + # face_mask = torch.ones_like(self.vertices)[None] + # face_mask[:, self.masks.boundary, :] = 0.0 + # face_mask[:, self.masks.left_ear, :] = 0.0 + # face_mask[:, self.masks.right_ear, :] = 0.0 + # face_mask[:, self.masks.scalp, :] = 0.0 + + face_mask = torch.zeros_like(self.vertices)[None] + face_mask[:, self.masks.face, :] = 1.0 + face_mask[:, self.masks.left_eyeball, :] = 1.0 + face_mask[:, self.masks.right_eyeball, :] = 1.0 + return face_mask def get_mask_depth(self): diff --git a/renderer.py b/renderer.py index abcf895..dc8c0ea 100644 --- a/renderer.py +++ b/renderer.py @@ -161,11 +161,12 @@ def forward(self, vertices_world, albedos, lights, cameras): attributes = torch.cat([uv, face_vertices_ndc, face_normals, face_mask, face_vertices_view, render_mask, depth_mask, eyes_region_mask, eyes_mask], -1) rendering, zbuffer = self.rasterizer(meshes_world, attributes, cameras=cameras) - uvcoords_images = rendering[:, 0:3, :, :].detach() + uvcoords_images = rendering[:, 0:3, :, :] ndc_vertices_images = rendering[:, 3:6, :, :] + view_vertices_images = rendering[:, 12:15, :, :] + normal_images = rendering[:, 6:9, :, :].detach() mask_images_mesh = rendering[:, 9:12, :, :].detach() - view_vertices_images = rendering[:, 12:15, :, :] mask_images_rendering = rendering[:, 15:18, :, :].detach() mask_images_depth = rendering[:, 18:21, :, :].detach() mask_images_eyes_region = rendering[:, 21:24, :, :].detach() @@ -184,12 +185,12 @@ def forward(self, vertices_world, albedos, lights, cameras): 'images': images * alpha_images, 'albedo_images': albedo_images, 'alpha_images': alpha_images, - 'mask_images': (mask_images * alpha_images > 0).float(), - 'mask_images_mesh': (mask_images_mesh > 0).float(), - 'mask_images_rendering': (mask_images_rendering > 0).float(), - 'mask_images_depth': (mask_images_depth > 0).float(), - 'mask_images_eyes_region': (mask_images_eyes_region > 0).float(), - 'mask_images_eyes': (mask_images_eyes > 0).float(), + 'mask_images': mask_images * alpha_images, + 'mask_images_mesh': mask_images_mesh, + 'mask_images_rendering': mask_images_rendering, + 'mask_images_depth': mask_images_depth, + 'mask_images_eyes_region': mask_images_eyes_region, + 'mask_images_eyes': mask_images_eyes, 'position_images': ndc_vertices_images, 'position_view_images': view_vertices_images, 'zbuffer': zbuffer diff --git a/tracker.py b/tracker.py index 3c54b90..39a9e3d 100644 --- a/tracker.py +++ b/tracker.py @@ -72,6 +72,11 @@ class View(Enum): DEPTH = 64 +class Mode(Enum): + SPARSE = 1 + DENSE = 2 + + class Tracker(object): def __init__(self, config, device='cuda:0'): self.config = config @@ -87,7 +92,9 @@ def __init__(self, config, device='cuda:0'): logger.add(os.path.join(self.config.save_folder, self.actor_name, 'train.log')) # Latter will be set up - self.frame = 0 + self.sparse_frame = 0 + self.dense_frame = 0 + self.optimization_mode = Mode.SPARSE self.is_initializing = False self.image_size = torch.tensor([[config.image_size[0], config.image_size[1]]]).cuda() self.save_folder = self.config.save_folder @@ -140,10 +147,10 @@ def setup_renderer(self): shader=SoftPhongShader(device=self.device, lights=self.lights) ) - def load_checkpoint(self, idx=-1): + def load_checkpoint(self): if not os.path.exists(self.checkpoint_folder): return False - snaps = sorted(glob(self.checkpoint_folder + '/*.frame')) + snaps = sorted(glob(self.checkpoint_folder + '/checkpoint.frame')) if len(snaps) == 0: logger.info('Training from beginning...') return False @@ -151,51 +158,58 @@ def load_checkpoint(self, idx=-1): logger.info('Training has finished...') exit(0) - last_snap = snaps[idx] + last_snap = snaps[0] payload = torch.load(last_snap) + def to_gpu(params_list): + return [nn.Parameter(torch.from_numpy(params_list[i]).to(self.device)) for i in range(len(params_list))] + camera_params = payload['camera'] - self.R = nn.Parameter(torch.from_numpy(camera_params['R']).to(self.device)) - self.t = nn.Parameter(torch.from_numpy(camera_params['t']).to(self.device)) + self.R = to_gpu(camera_params['R']) + self.t = to_gpu(camera_params['t']) self.focal_length = nn.Parameter(torch.from_numpy(camera_params['fl']).to(self.device)) self.principal_point = nn.Parameter(torch.from_numpy(camera_params['pp']).to(self.device)) flame_params = payload['flame'] self.tex = nn.Parameter(torch.from_numpy(flame_params['tex']).to(self.device)) - self.exp = nn.Parameter(torch.from_numpy(flame_params['exp']).to(self.device)) - self.sh = nn.Parameter(torch.from_numpy(flame_params['sh']).to(self.device)) + self.exp = to_gpu(flame_params['exp']) + self.sh = to_gpu(flame_params['sh']) self.shape = nn.Parameter(torch.from_numpy(flame_params['shape']).to(self.device)) - self.mica_shape = nn.Parameter(torch.from_numpy(flame_params['shape']).to(self.device)) - self.eyes = nn.Parameter(torch.from_numpy(flame_params['eyes']).to(self.device)) - self.eyelids = nn.Parameter(torch.from_numpy(flame_params['eyelids']).to(self.device)) - self.jaw = nn.Parameter(torch.from_numpy(flame_params['jaw']).to(self.device)) + self.mica_shape = torch.from_numpy(flame_params['shape']).to(self.device) + self.eyes = to_gpu(flame_params['eyes']) + self.eyelids = to_gpu(flame_params['eyelids']) + self.jaw = to_gpu(flame_params['jaw']) - self.frame = int(payload['frame_id']) + self.sparse_frame = int(payload['sparse_frame']) + self.dense_frame = int(payload['dense_frame']) self.global_step = payload['global_step'] - self.update_prev_frame() + self.image_size = torch.from_numpy(payload['img_size'])[None].to(self.device) self.setup_renderer() - logger.info(f'Snapshot loaded for frame {self.frame}') + logger.info(f'Snapshot loaded from Sparse Frame {self.sparse_frame} | Dense Frame {self.dense_frame}') return True - def save_checkpoint(self, frame_id): + def save_checkpoint(self): opencv = opencv_from_cameras_projection(self.cameras, self.image_size) + def to_cpu(params_list): + return [params_list[i].clone().detach().cpu().numpy() for i in range(len(self.dataset))] + frame = { 'flame': { - 'exp': self.exp.clone().detach().cpu().numpy(), + 'exp': to_cpu(self.exp), 'shape': self.shape.clone().detach().cpu().numpy(), 'tex': self.tex.clone().detach().cpu().numpy(), - 'sh': self.sh.clone().detach().cpu().numpy(), - 'eyes': self.eyes.clone().detach().cpu().numpy(), - 'eyelids': self.eyelids.clone().detach().cpu().numpy(), - 'jaw': self.jaw.clone().detach().cpu().numpy() + 'sh': to_cpu(self.sh), + 'eyes': to_cpu(self.eyes), + 'eyelids': to_cpu(self.eyelids), + 'jaw': to_cpu(self.jaw) }, 'camera': { - 'R': self.R.clone().detach().cpu().numpy(), - 't': self.t.clone().detach().cpu().numpy(), + 'R': to_cpu(self.R), + 't': to_cpu(self.t), 'fl': self.focal_length.clone().detach().cpu().numpy(), 'pp': self.principal_point.clone().detach().cpu().numpy(), }, @@ -205,24 +219,27 @@ def save_checkpoint(self, frame_id): 'K': opencv[2].clone().detach().cpu().numpy(), }, 'img_size': self.image_size.clone().detach().cpu().numpy()[0], - 'frame_id': frame_id, + 'sparse_frame': self.sparse_frame, + 'dense_frame': self.dense_frame, 'global_step': self.global_step } + frame_id = str(self.get_frame()).zfill(5) + vertices, _, _ = self.flame( cameras=torch.inverse(self.cameras.R), shape_params=self.shape, - expression_params=self.exp, - eye_pose_params=self.eyes, - jaw_pose_params=self.jaw, - eyelid_params=self.eyelids + expression_params=self.exp[self.get_frame()], + eye_pose_params=self.eyes[self.get_frame()], + jaw_pose_params=self.jaw[self.get_frame()], + eyelid_params=self.eyelids[self.get_frame()] ) f = self.diff_renderer.faces[0].cpu().numpy() v = vertices[0].cpu().numpy() trimesh.Trimesh(faces=f, vertices=v, process=False).export(f'{self.mesh_folder}/{frame_id}.ply') - torch.save(frame, f'{self.checkpoint_folder}/{frame_id}.frame') + torch.save(frame, f'{self.checkpoint_folder}/checkpoint.frame') def save_canonical(self): canon = os.path.join(self.save_folder, self.actor_name, "canonical.obj") @@ -245,13 +262,6 @@ def get_heatmap(self, values): return heatmap - def update_prev_frame(self): - self.prev_R = self.R.clone().detach() - self.prev_t = self.t.clone().detach() - self.prev_exp = self.exp.clone().detach() - self.prev_eyes = self.eyes.clone().detach() - self.prev_jaw = self.jaw.clone().detach() - def render_shape(self, vertices, faces=None, white=True): B = vertices.shape[0] V = vertices.shape[1] @@ -281,20 +291,20 @@ def to_cuda(self, batch, unsqueeze=False): return batch def create_parameters(self): - bz = 1 + frames = len(self.dataset) R, T = look_at_view_transform(dist=1.0) - self.R = nn.Parameter(matrix_to_rotation_6d(R).to(self.device)) - self.t = nn.Parameter(T.to(self.device)) - self.shape = nn.Parameter(self.mica_shape) - self.mica_shape = nn.Parameter(self.mica_shape) - self.tex = nn.Parameter(torch.zeros(bz, self.config.tex_params).float().to(self.device)) - self.exp = nn.Parameter(torch.zeros(bz, self.config.num_exp_params).float().to(self.device)) - self.sh = nn.Parameter(torch.zeros(bz, 9, 3).float().to(self.device)) + self.R = [nn.Parameter(matrix_to_rotation_6d(R).to(self.device)) for _ in range(frames)] + self.t = [nn.Parameter(T.to(self.device)) for _ in range(frames)] + self.shape = nn.Parameter(self.mica_shape.clone()) + self.mica_shape = self.mica_shape.clone() + self.tex = nn.Parameter(torch.zeros(1, self.config.tex_params).float().to(self.device)) + self.exp = [nn.Parameter(torch.zeros(1, self.config.num_exp_params).float().to(self.device)) for _ in range(frames)] + self.sh = [nn.Parameter(torch.zeros(1, 9, 3).float().to(self.device)) for _ in range(frames)] self.focal_length = nn.Parameter(torch.tensor([[5000 / self.get_image_size()[0]]]).to(self.device)) - self.principal_point = nn.Parameter(torch.zeros(bz, 2).float().to(self.device)) - self.eyes = nn.Parameter(torch.cat([matrix_to_rotation_6d(I), matrix_to_rotation_6d(I)], dim=1)) - self.jaw = nn.Parameter(matrix_to_rotation_6d(I)) - self.eyelids = nn.Parameter(torch.zeros(bz, 2).float().to(self.device)) + self.principal_point = nn.Parameter(torch.zeros(1, 2).float().to(self.device)) + self.eyes = [nn.Parameter(torch.cat([matrix_to_rotation_6d(I), matrix_to_rotation_6d(I)], dim=1)) for _ in range(frames)] + self.jaw = [nn.Parameter(matrix_to_rotation_6d(I)) for _ in range(frames)] + self.eyelids = [nn.Parameter(torch.zeros(1, 2).float().to(self.device)) for _ in range(frames)] @staticmethod def save_tensor(tensor, path='tensor.jpg'): @@ -302,14 +312,14 @@ def save_tensor(tensor, path='tensor.jpg'): img = np.minimum(np.maximum(img, 0), 255).astype(np.uint8) cv2.imwrite(path, img) - def parse_mask(self, ops, batch, visualization=False): + def parse_mask(self, ops, visualization=False): _, _, h, w = ops['alpha_images'].shape - result = ops['mask_images_rendering'] * 0.25 + ops['mask_images'] + result = ops['mask_images'] # * 0.25 + ops['mask_images'] # Lower the region for eyes blinking - if not self.is_initializing: - eyes = ops['mask_images_eyes_region'] - result = (1.0 - eyes) * result + eyes * 0.5 + # if not self.is_initializing: + # eyes = ops['mask_images_eyes_region'] + # result = (1.0 - eyes) * result + eyes * 1.5 if visualization: result = ops['mask_images'] @@ -319,47 +329,74 @@ def parse_mask(self, ops, batch, visualization=False): def update(self, param_groups): for param in param_groups: for i, name in enumerate(param['name']): - setattr(self, name, nn.Parameter(param['params'][i].clone().detach())) + attr = getattr(self, name) + if type(attr) is list: + attr[self.get_frame()] = nn.Parameter(param['params'][i].clone().detach()) + else: + setattr(self, name, nn.Parameter(param['params'][i].clone().detach())) def get_param(self, name, param_groups): for param in param_groups: if name in param['name']: return param['params'][param['name'].index(name)] - return getattr(self, name) + attr = getattr(self, name) + if type(attr) is list: + return attr[self.get_frame()] + else: + return attr def clone_params_tracking(self): params = [ - {'params': [nn.Parameter(self.exp.clone())], 'lr': 0.01, 'name': ['exp']}, - {'params': [nn.Parameter(self.eyes.clone())], 'lr': 0.001, 'name': ['eyes']}, - {'params': [nn.Parameter(self.eyelids.clone())], 'lr': 0.001, 'name': ['eyelids']}, - {'params': [nn.Parameter(self.R.clone())], 'lr': self.config.rotation_lr, 'name': ['R']}, - {'params': [nn.Parameter(self.t.clone())], 'lr': self.config.translation_lr, 'name': ['t']}, - {'params': [nn.Parameter(self.sh.clone())], 'lr': 0.001, 'name': ['sh']} + # Initialized form the previous + {'params': [nn.Parameter(self.eyes[self.frame].detach().clone())], 'lr': 0.001, 'name': ['eyes']}, + {'params': [nn.Parameter(self.eyelids[self.frame].detach().clone())], 'lr': 0.001, 'name': ['eyelids']}, + {'params': [nn.Parameter(self.sh[self.frame].detach().clone())], 'lr': 0.001, 'name': ['sh']}, + {'params': [nn.Parameter(self.exp[self.frame].detach().clone())], 'lr': 0.01, 'name': ['exp']}, + {'params': [nn.Parameter(self.R[self.frame].detach().clone())], 'lr': self.config.rotation_lr, 'name': ['R']}, + {'params': [nn.Parameter(self.t[self.frame].detach().clone())], 'lr': self.config.translation_lr, 'name': ['t']} ] if self.config.optimize_jaw: - params.append({'params': [nn.Parameter(self.jaw.clone().detach())], 'lr': 0.001, 'name': ['jaw']}) + params.append({'params': [nn.Parameter(self.jaw[self.frame].detach().clone())], 'lr': 0.001, 'name': ['jaw']}) + + return params + + def clone_params_camera(self): + params = [ + {'params': [nn.Parameter(self.t[self.frame].detach().clone())], 'lr': 0.05, 'name': ['t']}, + {'params': [nn.Parameter(self.R[self.frame].detach().clone())], 'lr': 0.05, 'name': ['R']}, + {'params': [nn.Parameter(self.principal_point.detach().clone())], 'lr': 0.05, 'name': ['principal_point']}, + {'params': [nn.Parameter(self.focal_length.detach().clone())], 'lr': 0.05, 'name': ['focal_length']} + ] + + return params + + def clone_params_light(self): + params = [ + {'params': [nn.Parameter(self.sh[self.frame].detach().clone())], 'lr': 0.01, 'name': ['sh']}, + {'params': [nn.Parameter(self.tex).detach().clone()], 'lr': 0.005, 'name': ['tex']}, + ] return params def clone_params_color(self): params = [ - {'params': [nn.Parameter(self.exp.clone())], 'lr': 0.025, 'name': ['exp']}, - {'params': [nn.Parameter(self.eyes.clone())], 'lr': 0.001, 'name': ['eyes']}, - {'params': [nn.Parameter(self.eyelids.clone())], 'lr': 0.01, 'name': ['eyelids']}, - {'params': [nn.Parameter(self.sh.clone())], 'lr': 0.01, 'name': ['sh']}, - {'params': [nn.Parameter(self.tex.clone())], 'lr': 0.005, 'name': ['tex']}, - {'params': [nn.Parameter(self.t.clone())], 'lr': 0.005, 'name': ['t']}, - {'params': [nn.Parameter(self.R.clone())], 'lr': 0.005, 'name': ['R']}, - {'params': [nn.Parameter(self.principal_point.clone())], 'lr': 0.001, 'name': ['principal_point']}, - {'params': [nn.Parameter(self.focal_length.clone())], 'lr': 0.001, 'name': ['focal_length']} + {'params': [nn.Parameter(self.exp[self.frame].detach().clone())], 'lr': 0.025, 'name': ['exp']}, + {'params': [nn.Parameter(self.eyes[self.frame].detach().clone())], 'lr': 0.001, 'name': ['eyes']}, + {'params': [nn.Parameter(self.eyelids[self.frame].detach().clone())], 'lr': 0.01, 'name': ['eyelids']}, + {'params': [nn.Parameter(self.sh[self.frame].detach().clone())], 'lr': 0.01, 'name': ['sh']}, + {'params': [nn.Parameter(self.tex).detach().clone()], 'lr': 0.005, 'name': ['tex']}, + {'params': [nn.Parameter(self.t[self.frame].detach().clone())], 'lr': 0.005, 'name': ['t']}, + {'params': [nn.Parameter(self.R[self.frame].detach().clone())], 'lr': 0.005, 'name': ['R']}, + {'params': [nn.Parameter(self.principal_point.detach().clone())], 'lr': 0.001, 'name': ['principal_point']}, + {'params': [nn.Parameter(self.focal_length.detach().clone())], 'lr': 0.001, 'name': ['focal_length']} ] if self.config.optimize_shape: - params.append({'params': [nn.Parameter(self.shape.clone().detach())], 'lr': 0.025, 'name': ['shape']}) + params.append({'params': [nn.Parameter(self.shape.detach().clone())], 'lr': 0.025, 'name': ['shape']}) if self.config.optimize_jaw: - params.append({'params': [nn.Parameter(self.jaw.clone().detach())], 'lr': 0.001, 'name': ['jaw']}) + params.append({'params': [nn.Parameter(self.jaw[self.frame].detach().clone())], 'lr': 0.001, 'name': ['jaw']}) return params @@ -371,84 +408,199 @@ def reduce_loss(losses): losses['all_loss'] = all_loss return all_loss - def optimize_camera(self, batch, steps=1000): - batch = self.to_cuda(batch) - images, landmarks, landmarks_dense, lmk_dense_mask, lmk_mask = self.parse_batch(batch) + def optimize_sparse(self): + self.optimization_mode = Mode.SPARSE + for i in tqdm(list(range(self.sparse_frame, len(self.dataset)))): + batch = self.to_cuda(self.dataset[i], unsqueeze=True) + if type(batch) is torch.Tensor: + continue - h, w = images.shape[2:4] - self.shape = batch['shape'] - self.mica_shape = batch['shape'].clone().detach() # Save it for regularization + images, landmarks, landmarks_dense, lmk_dense_mask, lmk_mask = self.parse_batch(batch) + h, w = images.shape[2:4] + + frame = self.sparse_frame + if self.sparse_frame > 0: + frame -= 1 + + left_iris = batch['left_iris'] + right_iris = batch['right_iris'] + mask_left_iris = batch['mask_left_iris'] + mask_right_iris = batch['mask_right_iris'] + + R = nn.Parameter(self.R[frame].detach().clone()) + T = nn.Parameter(self.t[frame].detach().clone()) + exp = nn.Parameter(self.exp[frame].detach().clone()) + eyes = nn.Parameter(self.eyes[frame].detach().clone()) + eyelids = nn.Parameter(self.eyelids[frame].detach().clone()) + jaw = nn.Parameter(self.jaw[frame].detach().clone()) + + optimizer = torch.optim.LBFGS( + [R, T, exp, eyes, eyelids, jaw], + lr=0.5, + max_iter=32, + line_search_fn="strong_wolfe") + + def closure(): + self.cameras = PerspectiveCameras( + device=self.device, + principal_point=self.principal_point, + focal_length=self.focal_length, + R=rotation_6d_to_matrix(R), T=T, + image_size=self.image_size + ) + vertices, lmk68, lmkMP = self.flame( + cameras=torch.inverse(self.cameras.R), + shape_params=self.shape, + expression_params=exp, + eye_pose_params=eyes, + jaw_pose_params=jaw, + eyelid_params=eyelids + ) + points68 = self.cameras.transform_points_screen(lmk68)[..., :2] + pointsMP = self.cameras.transform_points_screen(lmkMP)[..., :2] + proj_vertices = self.cameras.transform_points_screen(vertices)[..., :2] - # Important to initialize - self.create_parameters() + losses = {} + losses['lmkMP'] = util.lmk_loss(pointsMP, landmarks_dense[..., :2], [h, w], lmk_dense_mask) * self.config.w_lmks + losses['lmk_oval'] = util.oval_lmk_loss(points68, landmarks[..., :2], [h, w], lmk_mask) * self.config.w_lmks_oval + losses['lmk_mouth'] = util.mouth_lmk_loss(pointsMP, landmarks_dense[..., :2], [h, w], True, lmk_dense_mask) * self.config.w_lmks_mouth + losses['lmk_eye'] = util.eye_closure_lmk_loss(pointsMP, landmarks_dense[..., :2], [h, w], lmk_dense_mask) * self.config.w_lmks_lid + losses['lmk_eyelids'] = util.eye_lids_lmk_loss(pointsMP, landmarks_dense[..., :2], [h, w], lmk_dense_mask) * self.config.w_lmks_lid + losses['lmk_iris_left'] = util.lmk_loss(proj_vertices[:, left_iris_flame, ...], left_iris, [h, w], mask_left_iris) * self.config.w_lmks_iris + losses['lmk_iris_right'] = util.lmk_loss(proj_vertices[:, right_iris_flame, ...], right_iris, [h, w], mask_right_iris) * self.config.w_lmks_iris - params = [{'params': [self.t, self.R, self.focal_length, self.principal_point], 'lr': 0.05}] + right_eye, left_eye = eyes[:, :6], eyes[:, 6:] - optimizer = torch.optim.Adam(params) - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=300, gamma=0.1) + losses['reg/exp'] = torch.sum(exp ** 2) * self.config.w_exp + losses['reg/eyelids'] = torch.sum(eyelids ** 2) * 0.001 + losses['reg/sym'] = torch.sum((right_eye - left_eye) ** 2) * 8.0 + losses['reg/jaw'] = torch.sum((I6D - jaw) ** 2) * 16.0 + losses['reg/eye_lids'] = torch.sum((eyelids[:, 0] - eyelids[:, 1]) ** 2) + losses['reg/eye_left'] = torch.sum((I6D - left_eye) ** 2) + losses['reg/eye_right'] = torch.sum((I6D - right_eye) ** 2) - t = tqdm(range(steps), desc='', leave=True, miniters=100) - for k in t: - self.cameras = PerspectiveCameras( - device=self.device, - principal_point=self.principal_point, - focal_length=self.focal_length, - R=rotation_6d_to_matrix(self.R), T=self.t, - image_size=self.image_size - ) - _, lmk68, lmkMP = self.flame(cameras=torch.inverse(self.cameras.R), shape_params=self.shape, expression_params=self.exp, eye_pose_params=self.eyes, jaw_pose_params=self.jaw) - points68 = self.cameras.transform_points_screen(lmk68)[..., :2] - pointsMP = self.cameras.transform_points_screen(lmkMP)[..., :2] - - losses = {} - losses['pp_reg'] = torch.sum(self.principal_point ** 2) - losses['lmk68'] = util.lmk_loss(points68, landmarks[..., :2], [h, w], lmk_mask) * self.config.w_lmks - losses['lmkMP'] = util.lmk_loss(pointsMP, landmarks_dense[..., :2], [h, w], lmk_dense_mask) * self.config.w_lmks - - all_loss = 0. - for key in losses.keys(): - all_loss = all_loss + losses[key] - losses['all_loss'] = all_loss - - optimizer.zero_grad() - all_loss.backward() - optimizer.step() - scheduler.step() - - loss = all_loss.item() - # self.writer.add_scalar('camera', loss, global_step=k) - t.set_description(f'Loss for camera {loss:.4f}') - self.frame += 1 - if k % 100 == 0 and k > 0: - self.checkpoint(batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, View.SHAPE_OVERLAY]], frame_dst='/camera', save=False, dump_directly=True) - - self.frame = 0 + all_loss = 0. + for key in losses.keys(): + all_loss = all_loss + losses[key] + losses['all_loss'] = all_loss + + optimizer.zero_grad() + all_loss.backward() + + return all_loss + + for _ in range(4): + optimizer.step(closure) + + self.R[self.sparse_frame] = nn.Parameter(R.detach().clone()) + self.t[self.sparse_frame] = nn.Parameter(T.detach().clone()) + self.exp[self.sparse_frame] = nn.Parameter(exp.detach().clone()) + self.eyes[self.sparse_frame] = nn.Parameter(eyes.detach().clone()) + self.eyelids[self.sparse_frame] = nn.Parameter(eyelids.detach().clone()) + self.jaw[self.sparse_frame] = nn.Parameter(jaw.detach().clone()) + + self.checkpoint(batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, View.SHAPE]], frame_dst='/sparse', dump_directly=True) + self.sparse_frame += 1 + + def optimize_camera(self, steps=700): + for i, j in enumerate(self.config.keyframes[:1]): + self.frame = j + batch = self.to_cuda(self.dataset[j], unsqueeze=True) + + images, landmarks, landmarks_dense, lmk_dense_mask, lmk_mask = self.parse_batch(batch) + + h, w = images.shape[2:4] + self.shape = batch['shape'] + self.mica_shape = batch['shape'].clone().detach() # Save it for regularization + + # Important to initialize + self.create_parameters() + + params = self.clone_params_camera() + optimizer = torch.optim.Adam(params) + + t = self.get_param('t', params) + R = self.get_param('R', params) + fl = self.get_param('focal_length', params) + pp = self.get_param('principal_point', params) + + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=300, gamma=0.1) + + tq = tqdm(range(steps), desc='', leave=True, miniters=100) + for k in tq: + self.cameras = PerspectiveCameras( + device=self.device, + principal_point=pp, + focal_length=fl, + R=rotation_6d_to_matrix(R), T=t, + image_size=self.image_size + ) + _, lmk68, lmkMP = self.flame(cameras=torch.inverse(self.cameras.R), shape_params=self.shape) + points68 = self.cameras.transform_points_screen(lmk68)[..., :2] + pointsMP = self.cameras.transform_points_screen(lmkMP)[..., :2] + + losses = {} + losses['pp_reg'] = torch.sum(self.principal_point ** 2) + losses['lmk68'] = util.lmk_loss(points68, landmarks[..., :2], [h, w], lmk_mask) * self.config.w_lmks + losses['lmkMP'] = util.lmk_loss(pointsMP, landmarks_dense[..., :2], [h, w], lmk_dense_mask) * self.config.w_lmks + + all_loss = 0. + for key in losses.keys(): + all_loss = all_loss + losses[key] + losses['all_loss'] = all_loss + + optimizer.zero_grad() + all_loss.backward() + optimizer.step() + scheduler.step() + + loss = all_loss.item() + # self.writer.add_scalar('camera', loss, global_step=k) + tq.set_description(f'Loss for camera {loss:.4f}') + + self.update(params) + self.checkpoint(batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, View.SHAPE_OVERLAY]], frame_dst='/camera', save=False, dump_directly=True) def optimize_color(self, batch, pyramid, params_func, pho_weight_func, reg_from_prev=False): - self.update_prev_frame() images, landmarks, landmarks_dense, lmk_dense_mask, lmk_mask = self.parse_batch(batch) aspect_ratio = util.get_aspect_ratio(images) h, w = images.shape[2:4] logs = [] + frame = self.dense_frame + if self.dense_frame > 0: + frame -= 1 + + levels = len(pyramid) + for k, level in enumerate(pyramid): img, iters, size, image_size = level - # Optimizer per step - optimizer = torch.optim.Adam(params_func()) - params = optimizer.param_groups - - shape = self.get_param('shape', params) - exp = self.get_param('exp', params) - eyes = self.get_param('eyes', params) - eyelids = self.get_param('eyelids', params) - jaw = self.get_param('jaw', params) - tex = self.get_param('tex', params) - sh = self.get_param('sh', params) - t = self.get_param('t', params) - R = self.get_param('R', params) - fl = self.get_param('focal_length', params) - pp = self.get_param('principal_point', params) + + shape = nn.Parameter(self.shape.detach().clone()) + exp = nn.Parameter(self.exp[self.dense_frame].detach().clone()) + eyes = nn.Parameter(self.eyes[self.dense_frame].detach().clone()) + eyelids = nn.Parameter(self.eyelids[self.dense_frame].detach().clone()) + jaw = nn.Parameter(self.jaw[self.dense_frame].detach().clone()) + tex = nn.Parameter(self.tex.detach().clone()) + sh = nn.Parameter(self.sh[frame].detach().clone()) + R = nn.Parameter(self.R[self.dense_frame].detach().clone()) + t = nn.Parameter(self.t[self.dense_frame].detach().clone()) + fl = nn.Parameter(self.focal_length.detach().clone()) + pp = nn.Parameter(self.principal_point.detach().clone()) + + max_iters = 8 + params = [shape, exp, eyelids, eyes, jaw, sh, R, t, fl, pp] + + if self.dense_frame == 0: + params = [tex, sh] + + if not self.is_initializing: + params = [exp, eyelids, eyes, jaw, sh, R, t] + max_iters = 4 + + optimizer = torch.optim.LBFGS(params, lr=1.0, max_iter=32, line_search_fn="strong_wolfe") + # optimizer = torch.optim.Adam(params, lr=0.01) scale = image_size[0] / h self.diff_renderer.set_size(size) @@ -462,13 +614,10 @@ def optimize_color(self, batch, pyramid, params_func, pho_weight_func, reg_from_ mask_left_iris = batch['mask_left_iris'] * scale mask_right_iris = batch['mask_right_iris'] * scale - self.diff_renderer.rasterizer.reset() - - best_loss = np.inf - - for p in range(iters): - if p % self.config.raster_update == 0: - self.diff_renderer.rasterizer.reset() + # for p in range(iters): + def closure(): + # if p % self.config.raster_update == 0: + # self.diff_renderer.rasterizer.reset() losses = {} self.cameras = PerspectiveCameras( device=self.device, @@ -496,18 +645,14 @@ def optimize_color(self, batch, pyramid, params_func, pho_weight_func, reg_from_ losses['loss/lmk_oval'] = util.oval_lmk_loss(proj_lmks68, image_lmks68, image_size, lmk_mask) * self.config.w_lmks_oval losses['loss/lmk_MP'] = util.face_lmk_loss(proj_lmksMP, image_lmksMP, image_size, True, lmk_dense_mask) * self.config.w_lmks losses['loss/lmk_eye'] = util.eye_closure_lmk_loss(proj_lmksMP, image_lmksMP, image_size, lmk_dense_mask) * self.config.w_lmks_lid + losses['loss/lmk_eyelids'] = util.eye_lids_lmk_loss(proj_lmksMP, image_lmksMP, image_size, lmk_dense_mask) * self.config.w_lmks_lid losses['loss/lmk_mouth'] = util.mouth_lmk_loss(proj_lmksMP, image_lmksMP, image_size, True, lmk_dense_mask) * self.config.w_lmks_mouth losses['loss/lmk_iris_left'] = util.lmk_loss(proj_vertices[:, left_iris_flame, ...], left_iris, image_size, mask_left_iris) * self.config.w_lmks_iris losses['loss/lmk_iris_right'] = util.lmk_loss(proj_vertices[:, right_iris_flame, ...], right_iris, image_size, mask_right_iris) * self.config.w_lmks_iris - # Increase landmark weight for the lower level of the pyramid - lmk_scale = np.exp(0.6 * len(pyramid) / (k + 1)) - for key in losses.keys(): - if "lmk_" in key: - losses[key] = losses[key] * lmk_scale - # Reguralizers losses['reg/exp'] = torch.sum(exp ** 2) * self.config.w_exp + losses['reg/eyelids'] = torch.sum(eyelids ** 2) * 0.001 losses['reg/sym'] = torch.sum((right_eye - left_eye) ** 2) * 8.0 losses['reg/jaw'] = torch.sum((I6D - jaw) ** 2) * 16.0 losses['reg/eye_lids'] = torch.sum((eyelids[:, 0] - eyelids[:, 1]) ** 2) @@ -517,11 +662,15 @@ def optimize_color(self, batch, pyramid, params_func, pho_weight_func, reg_from_ losses['reg/tex'] = torch.sum(tex ** 2) * self.config.w_tex losses['reg/pp'] = torch.sum(pp ** 2) - # Temporal smoothing (only to t - 1) - if reg_from_prev: - losses['reg/exp_prev_r'] = torch.sum((self.prev_exp - exp) ** 2) * 0.01 - losses['reg/trans_prev_r'] = torch.sum((self.prev_t - t) ** 2) * 100.0 - losses['reg/rot_prev_r'] = torch.sum((self.prev_R - R) ** 2) * 100.0 + # Temporal smoothing (only to t - 1) L1 + if 0 < self.dense_frame < len(self.dataset) - 1 and not self.is_initializing: + losses['reg/T-1'] = torch.sum((self.t[self.dense_frame - 1].clone().detach() - t).abs()) + losses['reg/R-1'] = torch.sum((self.R[self.dense_frame - 1].clone().detach() - R).abs()) + # losses['reg/exp-1'] = torch.sum((self.exp[self.dense_frame - 1].clone().detach() - exp).abs()) + + losses['reg/T+1'] = torch.sum((self.t[self.dense_frame + 1].clone().detach() - t).abs()) + losses['reg/R+1'] = torch.sum((self.R[self.dense_frame + 1].clone().detach() - R).abs()) + # losses['reg/exp+1'] = torch.sum((self.exp[self.dense_frame + 1].clone().detach() - exp).abs()) # Render RGB albedos = self.flametex(tex) @@ -530,26 +679,40 @@ def optimize_color(self, batch, pyramid, params_func, pho_weight_func, reg_from_ # Photometric dense term grid = ops['position_images'].permute(0, 2, 3, 1)[:, :, :, :2] sampled_image = F.grid_sample(flipped, grid * aspect_ratio, align_corners=False) - losses['loss/pho'] = util.pixel_loss(ops['images'], sampled_image, self.parse_mask(ops, batch)) * pho_weight_func(k) + losses['loss/pho'] = util.pixel_loss(ops['images'], sampled_image, self.parse_mask(ops)) * pho_weight_func(k) all_loss = self.reduce_loss(losses) - optimizer.zero_grad() - all_loss.backward() - optimizer.step() for key in losses.keys(): self.writer.add_scalar(key, losses[key], global_step=self.global_step) self.global_step += 1 + logs.append(f"Color loss for level {k} [frame {str(self.get_frame()).zfill(4)}] =" + reduce(lambda a, b: a + f' {b}={round(losses[b].item(), 4)}', [""] + list(losses.keys()))) + + optimizer.zero_grad() + all_loss.backward() - if p % iters == 0: - logs.append(f"Color loss for level {k} [frame {str(self.frame).zfill(4)}] =" + reduce(lambda a, b: a + f' {b}={round(losses[b].item(), 4)}', [""] + list(losses.keys()))) + return all_loss - loss_color = all_loss.item() + best_loss = np.inf - if loss_color < best_loss: - best_loss = loss_color - self.update(optimizer.param_groups) + for _ in range(max_iters): + self.diff_renderer.rasterizer.reset() + loss = optimizer.step(closure) + if best_loss > loss.item(): + best_loss = loss.item() + print(f'Loss = {best_loss}') + self.shape = nn.Parameter(shape.detach().clone()) + self.exp[self.dense_frame] = nn.Parameter(exp.detach().clone()) + self.eyes[self.dense_frame] = nn.Parameter(eyes.detach().clone()) + self.eyelids[self.dense_frame] = nn.Parameter(eyelids.detach().clone()) + self.jaw[self.dense_frame] = nn.Parameter(jaw.detach().clone()) + self.tex = nn.Parameter(tex.detach().clone()) + self.sh[self.dense_frame] = nn.Parameter(sh.detach().clone()) + self.t[self.dense_frame] = nn.Parameter(t.detach().clone()) + self.R[self.dense_frame] = nn.Parameter(R.detach().clone()) + self.focal_length = nn.Parameter(fl.detach().clone()) + self.principal_point = nn.Parameter(pp.detach().clone()) for log in logs: logger.info(log) @@ -562,12 +725,14 @@ def checkpoint(self, batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, savefolder = self.save_folder + self.actor_name + frame_dst Path(savefolder).mkdir(parents=True, exist_ok=True) + frame = self.get_frame() + with torch.no_grad(): self.cameras = PerspectiveCameras( device=self.device, principal_point=self.principal_point, focal_length=self.focal_length, - R=rotation_6d_to_matrix(self.R), T=self.t, + R=rotation_6d_to_matrix(self.R[frame]), T=self.t[frame], image_size=self.image_size) self.diff_renderer.rasterizer.reset() @@ -577,10 +742,10 @@ def checkpoint(self, batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, vertices, lmk68, lmkMP = self.flame( cameras=torch.inverse(self.cameras.R), shape_params=self.shape, - expression_params=self.exp, - eye_pose_params=self.eyes, - jaw_pose_params=self.jaw, - eyelid_params=self.eyelids + expression_params=self.exp[frame], + eye_pose_params=self.eyes[frame], + jaw_pose_params=self.jaw[frame], + eyelid_params=self.eyelids[frame] ) lmk68 = self.cameras.transform_points_screen(lmk68, image_size=self.image_size) @@ -588,10 +753,10 @@ def checkpoint(self, batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, albedos = self.flametex(self.tex) albedos = F.interpolate(albedos, self.get_image_size(), mode='bilinear') - ops = self.diff_renderer(vertices, albedos, self.sh, cameras=self.cameras) - mask = (self.parse_mask(ops, batch, visualization=True) > 0).float() + ops = self.diff_renderer(vertices, albedos, self.sh[frame], cameras=self.cameras) + mask = (self.parse_mask(ops, visualization=True) > 0).float() predicted_images = (ops['images'] * mask + (images * (1.0 - mask)))[0] - shape_mask = ((ops['alpha_images'] * ops['mask_images_mesh']) > 0.).int()[0] + shape_mask = ((ops['alpha_images'] * ops['mask_images']) > 0.).int()[0] final_views = [] @@ -624,7 +789,7 @@ def checkpoint(self, batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, # VIDEO final_views = util.merge_views(final_views) - frame_id = str(self.frame).zfill(5) + frame_id = str(frame).zfill(5) cv2.imwrite('{}/{}.jpg'.format(savefolder, frame_id), final_views) cv2.imwrite('{}/{}.png'.format(self.input_folder, frame_id), input_image) @@ -633,7 +798,7 @@ def checkpoint(self, batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, return # CHECKPOINT - self.save_checkpoint(frame_id) + self.save_checkpoint() # DEPTH depth_view = self.diff_renderer.render_depth(vertices, cameras=self.cameras, faces=torch.cat([util.get_flame_extra_faces(), self.diff_renderer.faces], dim=1)) @@ -649,17 +814,25 @@ def optimize_frame(self, batch): self.optimize_color(batch, pyramid, self.clone_params_tracking, lambda k: self.config.w_pho, reg_from_prev=True) self.checkpoint(batch, visualizations=[[View.GROUND_TRUTH, View.COLOR_OVERLAY, View.LANDMARKS, View.SHAPE]]) - def optimize_video(self): - self.is_initializing = False - for i in list(range(self.frame, len(self.dataset))): + def optimize_dense(self): + self.optimization_mode = Mode.DENSE + for i in list(range(self.dense_frame, len(self.dataset))): batch = self.to_cuda(self.dataset[i], unsqueeze=True) if type(batch) is torch.Tensor: continue self.optimize_frame(batch) - self.frame += 1 + self.dense_frame += 1 + + def get_frame(self): + if self.optimization_mode == Mode.SPARSE: + return self.sparse_frame + if self.optimization_mode == Mode.DENSE: + return self.dense_frame + return -1 def output_video(self): - util.images_to_video(self.output_folder, self.config.fps) + util.images_to_video(self.output_folder, self.config.fps, src='video') + util.images_to_video(self.output_folder, self.config.fps, src='sparse') def parse_batch(self, batch): images = batch['image'] @@ -687,33 +860,49 @@ def prepare_data(self): self.dataset = ImagesDataset(self.config) self.dataloader = DataLoader(self.dataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=True, drop_last=False) - def initialize_tracking(self): + def initialize_light(self): + sh = [nn.Parameter(self.sh[self.dense_frame].detach().clone()) for _ in range(len(self.dataset))] + self.sh = sh + + def initialize_dense(self): + if self.dense_frame > 0: + return + self.optimization_mode = Mode.DENSE self.is_initializing = True for i, j in enumerate(self.config.keyframes): + self.dense_frame = i batch = self.to_cuda(self.dataset[j], unsqueeze=True) images = self.parse_batch(batch)[0] h, w = images.shape[2:4] pyramid_size = np.array([h, w]) pyramid = util.get_gaussian_pyramid([(pyramid_size * size, util.round_up_to_odd(steps * 2)) for size, steps in self.pyr_levels], images, self.kernel_size, self.sigma) - weighting = lambda k: self.config.w_pho + weighting_fun = lambda k: self.config.w_pho + params_fun = self.clone_params_color if i == 0: - self.optimize_camera(batch) - weighting = lambda k: self.config.w_pho if k > 0 else self.config.w_pho / 32.0 for k, level in enumerate(pyramid): self.save_tensor(level[0], f"{self.pyramid_folder}/{k}.png") - self.optimize_color(batch, pyramid, self.clone_params_color, weighting) - self.checkpoint(batch, visualizations=[[View.GROUND_TRUTH, View.COLOR_OVERLAY, View.LANDMARKS, View.SHAPE]], frame_dst='/initialization') - self.frame += 1 + self.optimize_color(batch, pyramid, params_fun, weighting_fun) + self.checkpoint(batch, visualizations=[[View.GROUND_TRUTH, View.COLOR_OVERLAY, View.LANDMARKS, View.SHAPE]], frame_dst='/initialization', save=False) + + if i == 0: + self.initialize_light() + + self.is_initializing = False self.save_canonical() + self.dense_frame = 0 def run(self): self.prepare_data() if not self.load_checkpoint(): - self.initialize_tracking() - self.frame = 0 + self.optimize_camera() + + self.optimize_sparse() + + if self.config.optimize_dense: + self.initialize_dense() + self.optimize_dense() - self.optimize_video() self.output_video() diff --git a/util.py b/util.py index 882addb..056503d 100644 --- a/util.py +++ b/util.py @@ -36,13 +36,6 @@ nose_mask[:, [27, 28, 29, 30, 31, 32, 33, 34, 35], :] *= 4.0 oval_mask[:, [i for i in range(17)], :] *= 0.4 -nose_mask_mp = torch.ones([1, 105, 2]).cuda().float() -face_mask_mp = torch.ones([1, 105, 2]).cuda().float() - -nose_mask_mp[:, get_idx(NOSE_LANDMARK_IDS), :] *= 8.0 - - -# face_mask_mp[:, get_idx(LEFT_EYE_LANDMARK_IDS) + get_idx(RIGHT_EYE_LANDMARK_IDS), :] *= 0.1 # Input is R, t in opencv spave @@ -109,7 +102,7 @@ def face_lmk_loss(opt_lmks, target_lmks, image_size, is_mediapipe, lmk_mask): diff = torch.pow(opt_lmks - target_lmks, 2) if not is_mediapipe: return (diff * face_mask * nose_mask * oval_mask * lmk_mask).mean() - return (diff * nose_mask_mp * lmk_mask).mean() + return (diff * lmk_mask).mean() def oval_lmk_loss(opt_lmks, target_lmks, image_size, lmk_mask): @@ -139,6 +132,13 @@ def eye_closure_lmk_loss(opt_lmks, target_lmks, image_size, lmk_mask): return (diff * lmk_mask[:, upper_eyelid_lmk_ids, :]).mean() +def eye_lids_lmk_loss(opt_lmks, target_lmks, image_size, lmk_mask): + eyelid_lmk_ids = [i for i in range(20, 20 + 32)] + opt_lmks, target_lmks = scale_lmks(opt_lmks, target_lmks, image_size) + diff = torch.pow(opt_lmks[:, eyelid_lmk_ids, :] - target_lmks[:, eyelid_lmk_ids, :], 2) + return (diff * lmk_mask[:, eyelid_lmk_ids, :]).mean() + + def mouth_closure_lmk_loss(opt_lmks, target_lmks, image_size, lmk_mask): upper_mouth_lmk_ids = [49, 50, 51, 52, 53, 61, 62, 63] lower_mouth_lmk_ids = [59, 58, 57, 56, 55, 67, 66, 65] @@ -149,11 +149,9 @@ def mouth_closure_lmk_loss(opt_lmks, target_lmks, image_size, lmk_mask): return (diff * lmk_mask[:, upper_mouth_lmk_ids, :]).mean() -def pixel_loss(opt_img, target_img, mask=None): - if mask is None: - mask = torch.ones_like(opt_img) - n_pixels = torch.sum((mask[:, 0, ...] > 0).int()).detach().float() - loss = (mask * (opt_img - target_img)).abs() +def pixel_loss(opt_img, target_img, mask): + n_pixels = torch.sum(mask) + loss = mask * (opt_img - target_img).abs() loss = torch.sum(loss) / n_pixels return loss @@ -376,7 +374,7 @@ def images_to_video(path, fps=25, src='video', video_format='DIVX'): img_array.append(img) if len(img_array) > 0: - out = cv2.VideoWriter(f'{path}/video.avi', cv2.VideoWriter_fourcc(*video_format), fps, size) + out = cv2.VideoWriter(f'{path}/{src}.avi', cv2.VideoWriter_fourcc(*video_format), fps, size) for i in range(len(img_array)): out.write(img_array[i]) out.release()