leonelhs commited on
Commit
125f259
·
verified ·
1 Parent(s): 54a5078

Upload model.py

Browse files

since this file is out of the root path, here is another file updated needed for the previous PR.

Files changed (1) hide show
  1. models/model.py +29 -56
models/model.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
 
2
  from typing import Dict
3
- from typing import Optional
4
  from typing import Tuple
5
 
6
  import kornia
@@ -10,28 +10,30 @@ import torch.nn as nn
10
  import torch.nn.functional as F
11
  from loguru import logger
12
 
13
- from arcface_torch.backbones.iresnet import iresnet100
14
- from configs.train_config import TrainConfig
15
  from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
16
  from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
17
  from HRNet.hrnet import HighResolutionNet
 
18
  from models.discriminator import Discriminator
19
  from models.gan_loss import GANLoss
20
  from models.generator import Generator
21
  from models.init_weight import init_net
22
 
23
-
24
  class HifiFace:
25
  def __init__(
26
  self,
27
  identity_extractor_config,
28
- is_training=True,
29
- device="cpu",
30
- load_checkpoint: Optional[Tuple[str, int]] = None,
31
  ):
32
  super(HifiFace, self).__init__()
 
 
33
  self.generator = Generator(identity_extractor_config)
34
  self.is_training = is_training
 
 
35
 
36
  if self.is_training:
37
  self.lr = TrainConfig().lr
@@ -80,10 +82,9 @@ class HifiFace:
80
 
81
  self.dilation_kernel = torch.ones(5, 5)
82
 
83
- if load_checkpoint is not None:
84
- self.load(load_checkpoint[0], load_checkpoint[1])
85
 
86
- self.setup(device)
87
 
88
  def save(self, path, idx=None):
89
  os.makedirs(path, exist_ok=True)
@@ -100,18 +101,9 @@ class HifiFace:
100
  torch.save(self.generator.state_dict(), g_path)
101
  torch.save(self.discriminator.state_dict(), d_path)
102
 
103
- def load(self, path, idx=None):
104
- if idx is None:
105
- g_path = os.path.join(path, "generator.pth")
106
- d_path = os.path.join(path, "discriminator.pth")
107
- else:
108
- g_path = os.path.join(path, f"generator_{idx}.pth")
109
- d_path = os.path.join(path, f"discriminator_{idx}.pth")
110
- logger.info(f"Loading generator from {g_path}")
111
- self.generator.load_state_dict(torch.load(g_path, map_location="cpu"))
112
- if self.is_training:
113
- logger.info(f"Loading discriminator from {d_path}")
114
- self.discriminator.load_state_dict(torch.load(d_path, map_location="cpu"))
115
 
116
  def setup(self, device):
117
  self.generator.to(device)
@@ -399,37 +391,18 @@ class HifiFace:
399
  }
400
 
401
 
402
- if __name__ == "__main__":
403
- import torch
404
- import cv2
405
- from configs.train_config import TrainConfig
406
-
407
- identity_extractor_config = TrainConfig().identity_extractor_config
408
-
409
- model = HifiFace(identity_extractor_config, is_training=True)
410
-
411
- # src = cv2.imread("/home/xuehongyang/data/test1.jpg")
412
- # tgt = cv2.imread("/home/xuehongyang/data/test2.jpg")
413
- # src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
414
- # tgt = cv2.cvtColor(tgt, cv2.COLOR_BGR2RGB)
415
- # src = cv2.resize(src, (256, 256))
416
- # tgt = cv2.resize(tgt, (256, 256))
417
- # src = src.transpose(2, 0, 1)[None, ...]
418
- # tgt = tgt.transpose(2, 0, 1)[None, ...]
419
- # source_img = torch.from_numpy(src).float() / 255.0
420
- # target_img = torch.from_numpy(tgt).float() / 255.0
421
- # same_id_mask = torch.Tensor([1]).unsqueeze(0)
422
- # tgt_mask = target_img[:, 0, :, :].unsqueeze(1)
423
- # if torch.cuda.is_available():
424
- # model.to("cuda:3")
425
- # source_img = source_img.to("cuda:3")
426
- # target_img = target_img.to("cuda:3")
427
- # tgt_mask = tgt_mask.to("cuda:3")
428
- # same_id_mask = same_id_mask.to("cuda:3")
429
- # source_img = source_img.repeat(16, 1, 1, 1)
430
- # target_img = target_img.repeat(16, 1, 1, 1)
431
- # tgt_mask = tgt_mask.repeat(16, 1, 1, 1)
432
- # same_id_mask = same_id_mask.repeat(16, 1)
433
- # while True:
434
- # x = model.optimize(source_img, target_img, tgt_mask, same_id_mask)
435
- # print(x[0]["loss_generator"])
 
1
  import os
2
+ from abc import abstractmethod
3
  from typing import Dict
 
4
  from typing import Tuple
5
 
6
  import kornia
 
10
  import torch.nn.functional as F
11
  from loguru import logger
12
 
 
 
13
  from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel
14
  from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
15
  from HRNet.hrnet import HighResolutionNet
16
+ from arcface_torch.backbones.iresnet import iresnet100
17
  from models.discriminator import Discriminator
18
  from models.gan_loss import GANLoss
19
  from models.generator import Generator
20
  from models.init_weight import init_net
21
 
 
22
  class HifiFace:
23
  def __init__(
24
  self,
25
  identity_extractor_config,
26
+ generator_path,
27
+ is_training=False,
28
+ device="cpu"
29
  ):
30
  super(HifiFace, self).__init__()
31
+ self.d_optimizer = None
32
+ self.g_optimizer = None
33
  self.generator = Generator(identity_extractor_config)
34
  self.is_training = is_training
35
+ self.device = device
36
+ self.generator_path = generator_path
37
 
38
  if self.is_training:
39
  self.lr = TrainConfig().lr
 
82
 
83
  self.dilation_kernel = torch.ones(5, 5)
84
 
85
+ self.load_checkpoint()
 
86
 
87
+ self.setup(self.device)
88
 
89
  def save(self, path, idx=None):
90
  os.makedirs(path, exist_ok=True)
 
101
  torch.save(self.generator.state_dict(), g_path)
102
  torch.save(self.discriminator.state_dict(), d_path)
103
 
104
+ @abstractmethod
105
+ def load_checkpoint(self):
106
+ pass
 
 
 
 
 
 
 
 
 
107
 
108
  def setup(self, device):
109
  self.generator.to(device)
 
391
  }
392
 
393
 
394
+ class HifiFaceST(HifiFace):
395
+ def __init__(self, identity_extractor_config, device, generator_path):
396
+ super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
397
+
398
+ def load_checkpoint(self):
399
+ self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
400
+ logger.info(f"Loading generator from {self.generator_path}")
401
+
402
+ class HifiFaceWGM(HifiFace):
403
+ def __init__(self, identity_extractor_config, device, generator_path):
404
+ super().__init__(identity_extractor_config, device=device, generator_path=generator_path)
405
+
406
+ def load_checkpoint(self):
407
+ self.generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
408
+ logger.info(f"Loading generator from {self.generator_path}")