Not saved time using WeightNormSparsifier #1420

Open
phyllispeng123 opened this issue Dec 16, 2024 · 2 comments
@phyllispeng123

In https://pytorch.org/blog/accelerating-generative-ai/, mask training with WeightNormSparsifier is described. When I use my function apply_sparse(), which includes WeightNormSparsifier, to sparsify my linear model (only one linear layer), no time is saved; the sparse forward pass is actually slower. If I do not use WeightNormSparsifier, I get roughly an 18x speedup compared to the dense one. However, since I would like to have my mask trained, I need to use WeightNormSparsifier. Can you help me find the issue? Many thanks!!!! My GPU is an A800-50G, Python 3.10.15.

import operator
import time
from functools import reduce
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier

seed_everything(42)  # assumes a seed helper such as pytorch_lightning's seed_everything; torch.manual_seed(42) works too
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # a single Linear layer: 10240 -> 3072
        self.linear1 = torch.nn.Linear(10240, 3072)

    def forward(self, x):
        x = self.linear1(x)
        return x


def apply_sparse():
    device = torch.device('cuda:0')
    SparseSemiStructuredTensor._FORCE_CUTLASS = True
    # 2:4 sparsity: 2 zeros in every block of 4 weights
    sparsifier = WeightNormSparsifier(
        sparsity_level=1.0,
        sparse_block_shape=(1, 4),
        zeros_per_block=2
    )

    x = torch.randn((3072, 10240)).half().to(device)
    model = TinyModel().half().to(device)

    # time the original (unpruned) dense forward pass
    start_time = time.time()
    original_output = model(x)
    end_time = time.time()
    ori_t = end_time - start_time

    # register every Linear weight with the sparsifier
    sparse_config = []
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            sparse_config.append({"tensor_fqn": f"{name}.weight"})

    sparsifier.prepare(model, sparse_config)
    sparsifier.step()
    sparsifier.step()
    sparsifier.squash_mask()  # fold the computed mask into the weights

    with torch.inference_mode():
        # time the masked-but-still-dense forward pass
        start_time = time.time()
        dense_output = model(x)
        end_time = time.time()
        dense_t = end_time - start_time

        # convert the pruned weights to the semi-structured sparse format
        to_sparse_semi_structured_compiled = torch.compile(lambda x: to_sparse_semi_structured(x))
        for name, mod in model.named_modules():
            if isinstance(mod, torch.nn.Linear):
                mod.weight = torch.nn.Parameter(to_sparse_semi_structured_compiled(mod.weight))

        # time the sparse forward pass
        start_time = time.time()
        sparse_output = model(x)
        end_time = time.time()
        sparse_t = end_time - start_time
        # note: the time.time() deltas are in seconds, even though the label says ms
        print(f"Ori: {ori_t:.5f}ms Dense with mask: {dense_t:.5f}ms Sparse: {sparse_t:.5f}ms | Speedup: {(dense_t / sparse_t):.5f}x")
      

The result is Ori: 0.25929ms, Dense with mask: 0.00010ms, Sparse: 0.01709ms. However, if I instead use a plain hand-built mask in check_sparse_linear() (no WeightNormSparsifier), I get a satisfying result: Dense with mask: 0.087ms Sparse: 0.005ms | Speedup: 18.264x

def check_sparse_linear():
    import torch
    from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
    from torch.utils.benchmark import Timer
    SparseSemiStructuredTensor._FORCE_CUTLASS = True

    # hand-built 2:4 mask (2 zeros in every 4 elements), tiled to the (3072, 10240) weight shape
    mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
    linear = TinyModel().half().cuda()
    for name, mod in linear.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(mask * mod.weight)

    x = torch.rand(3072, 10240).half().cuda()
    import time
    with torch.inference_mode():
        # time the masked-but-still-dense forward pass
        start_time = time.time()
        dense_output = linear(x)
        end_time = time.time()
        dense_t = end_time - start_time

        # convert the masked weights to the semi-structured sparse format
        for name, mod in linear.named_modules():
            if isinstance(mod, torch.nn.Linear):
                mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))

        # time the sparse forward pass
        start_time = time.time()
        sparse_output = linear(x)
        end_time = time.time()
        sparse_t = end_time - start_time
        print(f"Dense with mask: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")
 
@drisspg
Contributor

drisspg commented Dec 16, 2024

cc @jcaip

@jcaip jcaip self-assigned this Dec 17, 2024
@jcaip
Contributor

jcaip commented Dec 18, 2024

Hi @phyllispeng123, the weight norm sparsifier does not speed up the model at all; it is meant for masking / fine-tuning. The expected workflow is to train your model with the WeightNormSparsifier, which will not be accelerated, and then accelerate the final model for inference.
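
Roughly, a minimal sketch of that two-stage workflow (reusing TinyModel and the sparsifier settings from the snippet above; the actual fine-tuning loop is elided):

import torch
from torch.ao.pruning import WeightNormSparsifier
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

SparseSemiStructuredTensor._FORCE_CUTLASS = True
model = TinyModel().half().cuda()

# 1) Mask training: the sparsifier only computes and maintains a 2:4 mask,
#    so this phase adds overhead rather than speed
sparsifier = WeightNormSparsifier(sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2)
sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
# ... fine-tune the model here, calling sparsifier.step() to refresh the mask ...
sparsifier.step()

# 2) Freeze the mask into the weights once training is done
sparsifier.squash_mask()

# 3) Accelerate for inference: convert the pruned weights to the semi-structured sparse format
model.linear1.weight = torch.nn.Parameter(to_sparse_semi_structured(model.linear1.weight))
with torch.inference_mode():
    out = model(torch.randn(3072, 10240, dtype=torch.half, device="cuda"))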

If you're interested in speeding up training with 2:4 sparsity, you can check out this blog post: https://pytorch.org/blog/accelerating-neural-network-training/
