Skip to content

Commit

Permalink
Merge pull request #81 from slothkong/master
Browse files Browse the repository at this point in the history
Correct paths and imports for step 4
  • Loading branch information
yanx27 authored Jul 21, 2020
2 parents 0080a4d + 2de8ea7 commit a740351
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 22 deletions.
15 changes: 7 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,29 @@ We train and evaluate on Ubuntu 16.04, so if you don't have linux environment, y
#### Make source pictures
* Put source video mv.mp4 in `./data/source/` and run `make_source.py`, the label images and coordinate of head will save in `./data/source/test_label_ori/` and `./data/source/pose_souce.npy` (will use in step6). If you want to capture video by camera, you can directly run `./src/utils/save_img.py`
#### Make target pictures
* Put target video mv.mp4 in `./data/target/` and run `make_target.py`, `pose.npy` will save in `./data/target/`, which contain the coordinate of faces (will use in step6).
* Rename your own target video as mv.mp4 and put it in `./data/target/` and run `make_target.py`, `pose.npy` will save in `./data/target/`, which contain the coordinate of faces (will use in step6).
![](/result/fig3.png)
#### Train and use pose2vid network
* Run `train_pose2vid.py` and check loss and full training process in `./checkpoints/`

* If you break the traning and want to continue last training, set `load_pretrain = './checkpoints/target/` in `./src/config/train_opt.py`
* Run `normalization.py` rescale the label images, you can use two sample images from `./data/target/train/train_label/` and `./data/source/test_label_ori/` to complete normalization between two skeleton size
* Run `transfer.py` and get results in `./result`
* Run `transfer.py` and get results in `./results`
#### Face enhancement network

![](/result/fig2.png)
#### Train and use face enhancement network
* Run `./face_enhancer/prepare.py` and check the results in `./data/face/test_sync` and `./data/face/test_real`.
* Run `./face_enhancer/main.py` train face enhancer and run`./face_enhancer/enhance.py` to gain results <br>
* Run `cd ./face_enhancer`.
* Run `prepare.py` and check the results in `data` directory at the root of the repo (`data/face/test_sync` and `data/face/test_real`).
* Run `main.py` to rain the face enhancer. Then run `enhance.py` to obtain the results <br>
This is comparision in original (left), generated image before face enhancement (median) and after enhancement (right). FaceGAN can learn the residual error between the real picture and the generated picture faces.

#### Performance of face enhancement
#### Performance of face enhancement
![](/result/37500_enhanced_full.png)
![](/result/37500_enhanced_head.png)

#### Gain results
* Run `make_gif.py` and make result pictures to gif picture
* `cd` back to the root dir and run `make_gif.py` to create a gif out of the resulting images.

![Result](/result/output.gif)

Expand All @@ -66,5 +67,3 @@ Ubuntu 16.04 <br>
Python 3.6.5 <br>
Pytorch 0.4.1 <br>
OpenCV 3.4.4 <br>


5 changes: 3 additions & 2 deletions face_enhancer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ def __init__(self, root, cache=None, is_test=False):
with open(cache, 'rb') as f:
self.root, self.images, self.size = pickle.load(f)
else:
self.images = sorted(os.listdir(os.path.join(root, 'test_real')))
self.root = root
self.images = sorted(os.listdir(os.path.join(root, 'test_real')))
if self.images[0] == " ":
self.images.pop(0)
tmp = imread(os.path.join(self.root, 'test_real', self.images[0]))
self.size = tmp.shape[:-1]
if cache is not None:
Expand Down Expand Up @@ -84,4 +86,3 @@ def __getitem__(self, item):

def __len__(self):
return len(self.image_dataset)

20 changes: 10 additions & 10 deletions face_enhancer/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import cv2
from tqdm import tqdm

face_sync_dir = Path('../data/face/ ')
face_sync_dir = Path('../data/face/')
face_sync_dir.mkdir(exist_ok=True)
test_sync_dir = Path('../data/face/test_sync/ ')
test_sync_dir = Path('../data/face/test_sync/')
test_sync_dir.mkdir(exist_ok=True)
test_real_dir = Path('../data/face/test_real/ ')
test_real_dir = Path('../data/face/test_real/')
test_real_dir.mkdir(exist_ok=True)
test_img = Path('../data/target/test_img/ ')
test_img = Path('../data/target/test_img/')
test_img.mkdir(exist_ok=True)
test_label = Path('../data/target/test_label/ ')
test_label = Path('../data/target/test_label/')
test_label.mkdir(exist_ok=True)

train_dir = '../data/target/train/train_img/'
Expand All @@ -32,15 +32,17 @@
from pathlib import Path
from tqdm import tqdm
import sys
pix2pixhd_dir = Path('../src/pix2pixHD/')
sys.path.append(str(pix2pixhd_dir))
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_dir)
pix2pixhd_dir=os.path.join(root_dir, "src/pix2pixHD/")
sys.path.append(pix2pixhd_dir)

from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
import src.config.test_opt as opt
from src.config import test_opt as opt
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
opt.checkpoints_dir = '../checkpoints/'
opt.dataroot='../data/target/'
Expand Down Expand Up @@ -74,5 +76,3 @@
for img_idx in tqdm(range(len(os.listdir(synthesized_image_dir)))):
img = cv2.imread(synthesized_image_dir+' {:05}_synthesized_image.jpg'.format(img_idx))
cv2.imwrite(str(test_sync_dir) + '{:05}.png'.format(img_idx), img)


4 changes: 2 additions & 2 deletions normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

target_img = cv2.imread('./data/target/train/train_label/00001.png')[:,:,0]
target_img_rgb = cv2.imread('./data/target/train/train_img/00001.png')
source_img = cv2.imread('./data/target/train/train_label/00001.png')[:,:,0]
source_img_rgb = cv2.imread('./data/target/train/train_img/00001.png')
source_img = cv2.imread('./data/source/test_label_ori/00001.png')[:,:,0]
source_img_rgb = cv2.imread('./data/source/test_img/00001.png')

path = './data/source/test_label_ori/'
save_dir = Path('./data/source/')
Expand Down

0 comments on commit a740351

Please sign in to comment.