In https://pytorch.org/blog/accelerating-generative-ai/, mask training with WeightNormSparsifier is described. When I use my function apply_sparse(), which uses WeightNormSparsifier to sparsify a model with a single linear layer, inference does not get faster; it actually gets slower. If I skip WeightNormSparsifier, I get an ~18x speedup over the dense baseline. Since I want my mask to be trained, I need WeightNormSparsifier. Can you help me find the issue? Many thanks! My GPU is an A800-50G, Python 3.10.15.
import time

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier

seed_everything(42)  # assumed to be a seeding helper, e.g. pytorch_lightning.seed_everything

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10240, 3072)

    def forward(self, x):
        return self.linear1(x)
def apply_sparse():
    device = torch.device('cuda:0')
    SparseSemiStructuredTensor._FORCE_CUTLASS = True
    sparsifier = WeightNormSparsifier(
        sparsity_level=1.0,         # sparsify every block
        sparse_block_shape=(1, 4),
        zeros_per_block=2,          # 2:4 semi-structured pattern
    )
    x = torch.randn((3072, 10240)).half().to(device)
    model = TinyModel().half().to(device)

    # NOTE: CUDA kernels launch asynchronously, so time.time() around a single
    # forward pass mostly measures launch overhead unless torch.cuda.synchronize()
    # is called first; the printed values are also seconds, not ms as labeled.
    start_time = time.time()
    original_output = model(x)
    end_time = time.time()
    ori_t = end_time - start_time

    # Attach the sparsifier to every Linear weight, compute the mask,
    # then bake it into the weights.
    sparse_config = []
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            sparse_config.append({"tensor_fqn": f"{name}.weight"})
    sparsifier.prepare(model, sparse_config)
    sparsifier.step()
    sparsifier.step()
    sparsifier.squash_mask()

    with torch.inference_mode():
        start_time = time.time()
        dense_output = model(x)
        end_time = time.time()
        dense_t = end_time - start_time

        to_sparse_semi_structured_compiled = torch.compile(
            lambda w: to_sparse_semi_structured(w)
        )
        for name, mod in model.named_modules():
            if isinstance(mod, torch.nn.Linear):
                mod.weight = torch.nn.Parameter(
                    to_sparse_semi_structured_compiled(mod.weight)
                )
        start_time = time.time()
        sparse_output = model(x)
        end_time = time.time()
        sparse_t = end_time - start_time

    print(f"Ori: {ori_t:.5f}ms Dense with mask: {dense_t:.5f}ms "
          f"Sparse: {sparse_t:.5f}ms | Speedup: {(dense_t / sparse_t):.5f}x")
The result is Ori: 0.25929ms Dense with mask: 0.00010ms Sparse: 0.01709ms, so the semi-structured version is actually slower than the masked dense one. However, if I apply a plain 2:4 mask by hand instead, as in check_sparse_linear() below (no WeightNormSparsifier), I get a satisfying result: Dense with mask: 0.087ms Sparse: 0.005ms | Speedup: 18.264x
def check_sparse_linear():
    SparseSemiStructuredTensor._FORCE_CUTLASS = True
    # Hand-built 2:4 mask: two zeros in every block of four along each row,
    # tiled to the (3072, 10240) weight shape.
    mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
    linear = TinyModel().half().cuda()
    for name, mod in linear.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(mask * mod.weight)
    x = torch.rand(3072, 10240).half().cuda()

    with torch.inference_mode():
        start_time = time.time()
        dense_output = linear(x)
        end_time = time.time()
        dense_t = end_time - start_time

        for name, mod in linear.named_modules():
            if isinstance(mod, torch.nn.Linear):
                mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
        start_time = time.time()
        sparse_output = linear(x)
        end_time = time.time()
        sparse_t = end_time - start_time

    print(f"Dense with mask: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | "
          f"Speedup: {(dense_t / sparse_t):.3f}x")
Hi @phyllispeng123, the weight norm sparsifier does not speed up the model at all; it is meant for masking / fine-tuning. The expected workflow is to train your model with the WeightNormSparsifier (which will not be accelerated) and then accelerate the final model for inference.
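To make that workflow concrete, here is a minimal sketch under those assumptions (num_epochs and train_one_epoch are hypothetical placeholders for your own training loop; the sparsifier calls follow the pattern from the blog post):

# Illustrative sketch of the intended workflow, not a definitive recipe.
model = TinyModel().half().cuda()
sparsifier = WeightNormSparsifier(
    sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2
)
sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])

for epoch in range(num_epochs):   # mask-aware fine-tuning: NOT accelerated
    train_one_epoch(model)        # hypothetical training step
    sparsifier.step()             # recompute the 2:4 mask from weight norms

sparsifier.squash_mask()          # bake the mask into the weights

# Only now convert the pruned weights to the semi-structured layout
# to get the inference speedup.
model.linear1.weight = torch.nn.Parameter(
    to_sparse_semi_structured(model.linear1.weight)
)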