
Single-matrix cholesky much slower than batch mode with batch_size=1? #54778

Closed · rfeinman opened this issue Mar 26, 2021 · 9 comments

Comments

rfeinman commented Mar 26, 2021

🐛 Bug

Calling torch.cholesky on a single matrix is significantly slower than calling it in batch mode with batch_size=1 on the same matrix. This occurs on both CPU and CUDA, but the difference is far more pronounced on CUDA. The behavior is the same with torch.linalg.cholesky.

To Reproduce

Minimum working example:

import torch
import torch.utils.benchmark as benchmark

torch.manual_seed(391)
A = torch.randn(100, 100)
A = torch.mm(A, A.t()) + 1e-3 * torch.eye(100)  # make symmetric positive definite


### CPU ###

res1 = benchmark.Timer(
    stmt="torch.cholesky(A)",
    setup="import torch",
    globals=dict(A=A),
    num_threads=torch.get_num_threads(),
    label='Cholesky (cpu)',
    sub_label='single',
    description='time',
).blocked_autorange(min_run_time=1)

res2 = benchmark.Timer(
    stmt="torch.cholesky(A)",
    setup="import torch",
    globals=dict(A=A.unsqueeze(0)),
    num_threads=torch.get_num_threads(),
    label='Cholesky (cpu)',
    sub_label='batch',
    description='time',
).blocked_autorange(min_run_time=1)

compare = benchmark.Compare([res1, res2])
compare.print()


### CUDA ###

A = A.cuda()

res1 = benchmark.Timer(
    stmt="torch.cholesky(A)",
    setup="import torch",
    globals=dict(A=A),
    num_threads=torch.get_num_threads(),
    label='Cholesky (cuda)',
    sub_label='single',
    description='time',
).blocked_autorange(min_run_time=1)

res2 = benchmark.Timer(
    stmt="torch.cholesky(A)",
    setup="import torch",
    globals=dict(A=A.unsqueeze(0)),
    num_threads=torch.get_num_threads(),
    label='Cholesky (cuda)',
    sub_label='batch',
    description='time',
).blocked_autorange(min_run_time=1)

compare = benchmark.Compare([res1, res2])
compare.print()

Produces:

[-- Cholesky (cpu) -]
              |  time
8 threads: ----------
      single  |  90.9
      batch   |  77.1

Times are in microseconds (us).

[-- Cholesky (cuda) --]
              |   time 
8 threads: ------------
      single  |  2910.4
      batch   |   151.4

Times are in microseconds (us).

Expected behavior

I would expect the single-matrix variant to be at least as fast as the batch variant with batch_size=1. Otherwise we should always call torch.cholesky(A.unsqueeze(0)).squeeze(0), as in the stopgap sketch below.
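
As a stopgap, one can route single matrices through the batched path (a minimal sketch; cholesky_fast is a hypothetical helper name, not an existing API):

import torch

def cholesky_fast(A):
    # Hypothetical stopgap helper: send a single matrix through the batched
    # kernel (currently much faster on CUDA), then drop the batch dimension.
    if A.dim() == 2:
        return torch.cholesky(A.unsqueeze(0)).squeeze(0)
    return torch.cholesky(A)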

Environment

PyTorch version: 1.8.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: CentOS Linux release 7.9.2009 (Core) (x86_64)
GCC version: (GCC) 4.8.5 2015062 (Red Hat 4.8.5-44)
Clang version: 3.4.2 (tags/RELEASE_34/dot2-final)
CMake version: version 2.8.12.2

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti

Nvidia driver version: 440.36
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @VitalyFedyunin @ngimel

@rfeinman rfeinman changed the title Single-matrix cholesky much slower than batch cholesky with bsize=1? Single-matrix cholesky much slower than batch mode with batch_size=1? Mar 26, 2021
ngimel (Collaborator) commented Mar 26, 2021

Fixed on master

[-- Cholesky (cpu) --]
              |   time
40 threads: ----------
      single  |  146.3
      batch   |  142.8

Times are in microseconds (us).

[- Cholesky (cuda) --]
              |   time
40 threads: ----------
      single  |  135.2
      batch   |  122.9

Times are in microseconds (us).

@ngimel added the module: linear algebra and module: performance labels Mar 26, 2021
xwang233 (Collaborator) commented

Single-matrix cholesky on CUDA was replaced by a cuSOLVER function. PRs for single-matrix cholesky_solve and cholesky_inverse are ongoing.

You can try this in a nightly build and leave feedback if you run into any issues: https://pytorch.org/get-started/locally/

rfeinman (Author) commented

@xwang233 In that case, could it be an issue with my CUDA version? I'm still on 10.2 and haven't updated to 11.

I haven't noticed any issues with cholesky_solve or cholesky_inverse. Just the factorization step.

xwang233 (Collaborator) commented

CUDA 10.2 is fine. You'll just need to try a nightly build; PyTorch 1.8.0 and 1.8.1 don't have this update yet.
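
To confirm you're on a nightly rather than 1.8.x, you can check the version strings (the dev tag below is illustrative; it varies by build date):

import torch
print(torch.__version__)   # nightly builds report a dev version, e.g. '1.9.0.dev20210326'
print(torch.version.cuda)  # CUDA version the build was compiled against, e.g. '10.2'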

rfeinman (Author) commented

@xwang233 I see the fix now in the nightly, thank you. I read that cuSOLVER offers LDL factorization for symmetric indefinite matrices - is there a plan to include this in PyTorch? Perhaps there is already a feature request somewhere.
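
For reference, SciPy already exposes this factorization as scipy.linalg.ldl; a minimal sketch of the kind of usage I have in mind (NumPy/SciPy here only to illustrate the operation, not a proposed PyTorch API):

import numpy as np
from scipy.linalg import ldl

# A symmetric *indefinite* matrix: Cholesky fails here (the matrix is not
# positive definite), but an LDL^T factorization still exists.
A = np.array([[ 2., -1.,  3.],
              [-1.,  2.,  0.],
              [ 3.,  0.,  1.]])

L, D, perm = ldl(A)  # L @ D @ L.T reconstructs A; L[perm] is lower triangular
print(np.allclose(L @ D @ L.T, A))  # True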

xwang233 (Collaborator) commented

Thanks for the comments, @rfeinman! I don't think we have an LDL operator in PyTorch right now, but that sounds like a good idea! If you can't find an existing issue for this, feel free to file a new one with your feature request and comments so that we can discuss and track it.

mruberry (Collaborator) commented

Please file a new issue with that request, @rfeinman. We're actively working on improving linalg support, and feedback from community members like yourself is helpful in prioritizing our work. Including an example of how you'd like to use such a function would be especially helpful.

rfeinman (Author) commented

@mruberry @xwang233 Done at #54847

We can close this ticket if you see fit.

mruberry (Collaborator) commented

Thanks @rfeinman!

@H-Huang added the triaged label Mar 30, 2021