Skip to content

Factorized (Tucker) FactorizedLinear layer is slower than the equivalent torch.nn.Linear layer #35

@KaidDuong

Description

@KaidDuong
import tltorch
import torch
from torch.profiler import profile, record_function, ProfilerActivity

# Input: 4 rows of 16 float32 features each.
data = torch.randn((4, 16), dtype=torch.float32)
# Dense baseline layer mapping 16 -> 10 features.
linear = torch.nn.Linear(16, 10)

# Tucker-factorized version of `linear`: the 16 input features are tensorized as
# (4, 4) and the 10 output features as (2, 5). `rank=0.1` presumably requests a
# rank relative to the full size — NOTE(review): confirm against tltorch docs.
fact_linear = tltorch.FactorizedLinear.from_linear(linear, auto_tensorize=False,
                    in_tensorized_features=(4, 4), out_tensorized_features=(2, 5), rank=0.1, factorization="tucker")

# Move the input and both layers to the GPU before profiling.
data = data.to("cuda")
linear = linear.to("cuda")
fact_linear = fact_linear.to("cuda")
# Profile a single forward pass of the dense layer on CPU and CUDA.
# NOTE(review): only one call is measured, with no warm-up iterations and with
# autograd enabled, so one-time costs (e.g. first kernel launch / library setup)
# may be included. The table below shows cudaLaunchKernel alone taking ~1.3 ms
# of CPU time versus 4 us of CUDA time. For steady-state numbers, consider a few
# warm-up calls, torch.no_grad(), and averaging over many iterations.
with profile(activities=[
        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("model_inference"):
        linear(data)

# Show the 10 most expensive ops, ordered by total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        16.99%       1.054ms        99.71%       6.186ms       6.186ms       0.000us         0.00%       4.000us       4.000us             1  
                                           aten::linear         0.29%      18.000us        82.72%       5.132ms       5.132ms       0.000us         0.00%       4.000us       4.000us             1  
                                            aten::addmm        60.54%       3.756ms        81.43%       5.052ms       5.052ms       4.000us       100.00%       4.000us       4.000us             1  
void gemmSN_TN_kernel<float, 128, 16, 2, 4, 4, 4, tr...         0.00%       0.000us         0.00%       0.000us       0.000us       4.000us       100.00%       4.000us       4.000us             1  
                                                aten::t         0.63%      39.000us         1.00%      62.000us      62.000us       0.000us         0.00%       0.000us       0.000us             1  
                                        aten::transpose         0.24%      15.000us         0.37%      23.000us      23.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       aten::as_strided         0.13%       8.000us         0.13%       8.000us       8.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel        20.89%       1.296ms        20.89%       1.296ms       1.296ms       0.000us         0.00%       0.000us       0.000us             1  
                                  cudaDeviceSynchronize         0.29%      18.000us         0.29%      18.000us      18.000us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.204ms
Self CUDA time total: 4.000us

# Profile a single forward pass of the factorized layer with the same profiler
# settings as the dense baseline.
# NOTE(review): in the resulting table the forward pass issues several small ops
# (4x aten::matmul/aten::mm, 3 clone/copy_ kernels from reshapes, plus one addmm)
# instead of the single addmm of nn.Linear. At this tiny size (16 -> 10), per-kernel
# launch overhead likely dominates, so a slowdown here is plausible rather than a
# bug — factorization mainly reduces parameter count. As with the baseline, this
# measures one un-warmed call only.
with profile(activities=[
        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("model_inference"):
        fact_linear(data)

# Show the 10 most expensive ops, ordered by total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        15.19%       1.098ms        99.79%       7.215ms       7.215ms       0.000us         0.00%      27.000us      27.000us             1  
                                           aten::matmul         0.40%      29.000us        62.63%       4.528ms       1.132ms       0.000us         0.00%      12.000us       3.000us             4  
                                               aten::mm        48.37%       3.497ms        62.23%       4.499ms       1.125ms      12.000us        44.44%      12.000us       3.000us             4  
                                          aten::reshape         0.91%      66.000us         6.10%     441.000us      44.100us       0.000us         0.00%      10.000us       1.000us            10  
                                            aten::clone         0.55%      40.000us         3.91%     283.000us      94.333us       0.000us         0.00%      10.000us       3.333us             3  
                                            aten::copy_         1.40%     101.000us         2.28%     165.000us      55.000us      10.000us        37.04%      10.000us       3.333us             3  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us        37.04%      10.000us       3.333us             3  
void gemmk1_kernel<int, float, 256, 5, false, false,...         0.00%       0.000us         0.00%       0.000us       0.000us       9.000us        33.33%       9.000us       3.000us             3  
                                           aten::linear         0.36%      26.000us         2.23%     161.000us     161.000us       0.000us         0.00%       5.000us       5.000us             1  
                                            aten::addmm         1.00%      72.000us         1.27%      92.000us      92.000us       5.000us        18.52%       5.000us       5.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.230ms
Self CUDA time total: 27.000us

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions