Skip to content

Commit

Permalink
Refactor simclr
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Nov 22, 2020
1 parent e7966ac commit 64cf526
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def add_model_specific_args(parent_parser):
parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur")
parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength")
parser.add_argument("--dataset", type=str, default="cifar10", help="stl10, cifar10")
parser.add_argument("--data_path", type=str, default=".", help="path to download data")
parser.add_argument("--data_dir", type=str, default=".", help="path to download data")

# training params
parser.add_argument("--fast_dev_run", action='store_true')
Expand Down Expand Up @@ -368,7 +368,7 @@ def cli_main():

if args.dataset == 'stl10':
dm = STL10DataModule(
data_dir=args.data_path,
data_dir=args.data_dir,
batch_size=args.batch_size,
num_workers=args.num_workers
)
Expand All @@ -391,7 +391,7 @@ def cli_main():
val_split = args.nodes * args.gpus * args.batch_size

dm = CIFAR10DataModule(
data_dir=args.data_path,
data_dir=args.data_dir,
batch_size=args.batch_size,
num_workers=args.num_workers,
val_split=val_split
Expand Down Expand Up @@ -429,7 +429,7 @@ def cli_main():
args.online_ft = True

dm = ImagenetDataModule(
data_dir=args.data_path,
data_dir=args.data_dir,
batch_size=args.batch_size,
num_workers=args.num_workers
)
Expand Down

0 comments on commit 64cf526

Please sign in to comment.