Single-matrix cholesky much slower than batch mode with batch_size=1? #54778
Comments
Fixed on master
Single-matrix cholesky on CUDA was replaced by a cuSOLVER routine. Single-matrix cholesky_solve and cholesky_inverse PRs are ongoing. You can try this in a nightly build and leave us feedback if there is any issue. https://pytorch.org/get-started/locally/
@xwang233 in that case it may be an issue with my CUDA version? I'm still on 10.2, haven't updated to 11. I haven't noticed any issues with cholesky_solve or cholesky_inverse, just the factorization step.
CUDA 10.2 is fine. You'll just need to try a nightly build; PyTorch 1.8.0 and 1.8.1 don't have this update yet.
@xwang233 I see this now in the nightly, thank you. I read that cuSOLVER offers LDL factorization for symmetric indefinite matrices - is there a plan to include this in PyTorch? Perhaps there is a feature request somewhere.
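For reference, an LDL (Bunch-Kaufman) factorization for symmetric indefinite matrices is available on CPU today via SciPy's `scipy.linalg.ldl`; a small sketch (the example matrix and values here are made up purely for illustration):

```python
import numpy as np
from scipy.linalg import ldl

# Illustrative symmetric *indefinite* matrix (not positive definite,
# so a Cholesky factorization would fail on it).
A = np.array([[ 2.0,  1.0,  0.0],
              [ 1.0, -3.0,  1.0],
              [ 0.0,  1.0,  1.0]])

# ldl returns factors such that lu @ d @ lu.T reconstructs A;
# perm gives the row order under which lu is lower triangular.
lu, d, perm = ldl(A, lower=True)
assert np.allclose(lu @ d @ lu.T, A)
```

A CUDA-backed equivalent in `torch.linalg` would presumably wrap the corresponding cuSOLVER routine, as the comment above suggests.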
Thanks for the comments @rfeinman! I think we don't have a
Please file a new issue with that request, @rfeinman. We're actively working on improving linalg support, and feedback from community members like yourself is helpful in prioritizing our work. Including an example of how you'd like to use such a function would be especially helpful.
Thanks @rfeinman! |
🐛 Bug
Calling single-matrix torch.cholesky is significantly slower than batch mode with batch_size=1 and the same matrix. This occurs on both CPU and CUDA, but the difference is more pronounced on CUDA. The behavior is the same with torch.linalg.cholesky.
To Reproduce
Minimum working example:
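The original snippet did not survive extraction; a minimal sketch of such a benchmark (the matrix size, repetition count, and use of torch.linalg.cholesky are assumptions, not the reporter's exact code) might look like:

```python
import time
import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n)
A = A @ A.T + n * torch.eye(n)  # symmetric positive definite

def bench(fn, reps=100):
    # Crude wall-clock timing; on CUDA you would also need
    # torch.cuda.synchronize() before and after the loop.
    t0 = time.perf_counter()
    for _ in range(reps):
        fn()
    return time.perf_counter() - t0

t_single = bench(lambda: torch.linalg.cholesky(A))
t_batch = bench(lambda: torch.linalg.cholesky(A.unsqueeze(0)).squeeze(0))
print(f"single: {t_single:.4f}s  batch_size=1: {t_batch:.4f}s")
```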
Produces:
Expected behavior
I would expect the single-matrix variant to be at least as fast as the batch variant for batch_size=1. Otherwise we should always call torch.cholesky(A.unsqueeze(0)).squeeze(0).
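As a stopgap on 1.8, that workaround could be wrapped in a helper (the function name is hypothetical, not a PyTorch API) that routes single matrices through the batched kernel:

```python
import torch

def cholesky_via_batch(A: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not part of PyTorch: for a single 2-D matrix,
    # go through the batched code path, which was faster on PyTorch 1.8.
    if A.dim() == 2:
        return torch.linalg.cholesky(A.unsqueeze(0)).squeeze(0)
    return torch.linalg.cholesky(A)
```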
Environment
PyTorch version: 1.8.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: CentOS Linux release 7.9.2009 (Core) (x86_64)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
Clang version: 3.4.2 (tags/RELEASE_34/dot2-final)
CMake version: version 2.8.12.2
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
Nvidia driver version: 440.36
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @VitalyFedyunin @ngimel