-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathtest.py
51 lines (42 loc) · 1.43 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import sys
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import utils as vutils
import utils
from dataset import Dataset
from model import Resnet_Unet as model
#####
# Adjustable parameters
test_batch_size = 1
title = 'ResNet_final'
path = '../'
test_path = path+'imgs/test/'
Model_path = '../7.pth'
# Restrict the process to the first GPU; takes effect as long as CUDA has not
# been initialised yet.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
transform = transforms.Compose([
    transforms.ToTensor()
])
save_path = path+'log/'+title+'_test/'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#####

# Build the test data pipeline; order is kept (shuffle=False) so the progress
# counter matches the dataset order.
test_set = Dataset(path=test_path, transform=transform, mode='test')
test_loader = DataLoader(test_set, batch_size=test_batch_size, shuffle=False)

# Instantiate the network and restore the trained weights.
Model = model(BN_enable=True, resnet_pretrain=False).to(device)
utils.path_checker(save_path)
# map_location lets a checkpoint saved on a GPU be loaded on whichever device
# was selected above (including the CPU fallback).
Model.load_state_dict(torch.load(Model_path, map_location=device))
# Inference mode only needs to be set once, before the loop.
Model.eval()

for index, (img, name) in enumerate(test_loader):
    img = img.to(device)
    with torch.no_grad():
        output = Model(img)
        # Binarise: values >= 0.5 become 1.0, everything else 0.0.
        output = torch.ge(output, 0.5).type(dtype=torch.float32)
        # Project-specific post-processing of the binary output.
        output = utils.post_process(output)
    # Iterate over the samples actually present in this batch; the last batch
    # can be smaller than test_batch_size.
    # NOTE(review): assumes `name` is the per-sample sequence of path strings
    # produced by the DataLoader's default collation -- confirm against Dataset.
    for i, sample_name in enumerate(name):
        # Save under the file-name component only, regardless of how many
        # directory levels the dataset put in front of it.
        out_file = os.path.join(save_path, os.path.basename(sample_name))
        vutils.save_image(output[i, :, :, :], out_file, padding=0)
    # NOTE: the epoch numbers in this progress line are hard-coded (7 of 10).
    sys.stdout.write("\r[test] [Epoch {}/{}] [Batch {}/{}]".format(7, 10, index+1, len(test_loader)))
    sys.stdout.flush()