From c898e529d36e78eb4dc82db8f8b4bc0ab313f016 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Fri, 10 Jan 2025 11:41:24 +0800 Subject: [PATCH] Fix: use gemm instead of einsum in BPCG (#5827) * Add dimension parameter for BPCG method * Add utils for hsovler gemm_op * Change code to fit new bpcg init interface * using gemm instead of einsum in orth_cholesky * using gemm instead of einsum in orth_projection * replace einsum by gemm in orth_projection * replace einsum by gemm in rotate_wf * replace einsum by gemm in diag_hsub * Update 102_PW_BPCG totalstressref reference value --- source/module_hsolver/diago_bpcg.cpp | 99 +++++++++++++++++-- source/module_hsolver/diago_bpcg.h | 13 ++- source/module_hsolver/hsolver_pw.cpp | 3 +- .../module_hsolver/test/diago_bpcg_test.cpp | 3 +- tests/integrate/102_PW_BPCG/result.ref | 2 +- 5 files changed, 110 insertions(+), 10 deletions(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 635e3a7943..7830af2f67 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -22,6 +22,10 @@ DiagoBPCG::DiagoBPCG(const Real* precondition_in) this->device_type = ct::DeviceTypeToEnum::value; this->h_prec = std::move(ct::TensorMap((void *) precondition_in, r_type, device_type, {this->n_basis})); + + this->one = &one_; + this->zero = &zero_; + this->neg_one = &neg_one_; } template @@ -30,11 +34,11 @@ DiagoBPCG::~DiagoBPCG() { } template -void DiagoBPCG::init_iter(const int nband, const int nbasis) { +void DiagoBPCG::init_iter(const int nband, const int nbasis, const int ndim) { // Specify the problem size n_basis, n_band, while lda is n_basis this->n_band = nband; this->n_basis = nbasis; - + this->n_dim = ndim; // All column major tensors @@ -93,7 +97,23 @@ void DiagoBPCG::orth_cholesky( // hsub_out = psi_out * transc(psi_out) ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, 
/*Tensor out=*/&hsub_out); - hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); + // hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); + + // gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band) + gemm_op()(this->ctx, + 'C', + 'N', + this->n_band, //m + this->n_band, //n + this->n_dim, //k + this->one, //1.0 + psi_out.data(), + this->n_basis, //lda + psi_out.data(), + this->n_basis, //ldb + this->zero, //0.0 + hsub_out.data(), + this->n_band); //ldc // set hsub matrix to lower format; ct::kernels::set_matrix()( @@ -145,12 +165,45 @@ void DiagoBPCG::orth_projection( { ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in); - hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); + // hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); + + // this->orth_projection(this->psi, this->hsub, this->grad); + // gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band) + gemm_op()(this->ctx, + 'C', + 'N', + this->n_band, //m + this->n_band, //n + this->n_dim, //k + this->one, //1.0 + psi_in.data(), + this->n_basis, //lda + grad_out.data(), + this->n_basis, //ldb + this->zero, //0.0 + hsub_in.data(), + this->n_band); //ldc // set_matrix_op()('L', hsub_in->data(), this->n_band); option = ct::EinsumOption( /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out); - grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); + // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); + + // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band) + gemm_op()(this->ctx, + 'N', + 'N', + this->n_dim, //m + this->n_band, //n + this->n_band, //k + this->neg_one, //-1.0 + psi_in.data(), + this->n_basis, //lda + hsub_in.data(), + this->n_band, //ldb + this->one, //1.0 + grad_out.data(), + 
this->n_basis); //ldc return; } @@ -165,6 +218,24 @@ void DiagoBPCG::rotate_wf( /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in); workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option); + // this->rotate_wf(hsub_out, psi_out, workspace_in); + // this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub); + // gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band) + // gemm_op()(this->ctx, + // 'N', + // 'N', + // this->n_basis, //m + // this->n_band, //n + // this->n_band, //k + // this->one, //1.0 + // psi_out.data(), + // this->n_basis, //lda + // hsub_in.data(), + // this->n_band, //ldb + // this->zero, //0.0 + // workspace_in.data(), + // this->n_basis); //ldc + syncmem_complex_op()(psi_out.template data(), workspace_in.template data(), this->n_band * this->n_basis); return; @@ -192,7 +263,23 @@ void DiagoBPCG::diag_hsub( // it controls the ops to use the corresponding device to calculate results ct::EinsumOption option( /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option); + // hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option); + + // gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band) + gemm_op()(this->ctx, + 'C', + 'N', + this->n_band, //m + this->n_band, //n + this->n_dim, //k + this->one, //1.0 + hpsi_in.data(), + this->n_basis, //lda + psi_in.data(), + this->n_basis, //ldb + this->zero, //0.0 + hsub_out.data(), + this->n_band); //ldc ct::kernels::lapack_dnevd()('V', 'U', hsub_out.data(), this->n_band, eigenvalue_out.data()); diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index c57ed5e5ee..44ddd9736f 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -52,8 +52,9 @@ class DiagoBPCG * * @param nband The number of bands. 
* @param nbasis The number of basis functions. Leading dimension of psi. + * @param ndim The valid dimension of psi: the number of entries per band actually used in the gemm calls. */ - void init_iter(const int nband, const int nbasis); + void init_iter(const int nband, const int nbasis, const int ndim); using HPsiFunc = std::function; @@ -77,6 +78,8 @@ class DiagoBPCG int n_band = 0; /// the number of cols of the input psi int n_basis = 0; + /// valid dimension of psi + int n_dim = 0; /// max iter steps for all-band cg loop int nline = 4; @@ -107,6 +110,13 @@ class DiagoBPCG /// work for some calculations within this class, including rotate_wf call ct::Tensor work = {}; + // These are for hsolver gemm_op use + /// ctx is the device pointer passed to gemm_op; it is default-initialized to nullptr (Device * ctx = nullptr). + Device * ctx = {}; + // Pointers to the constants 1, 0 and -1 that are passed to gemm as alpha/beta + const T *one = nullptr, *zero = nullptr, *neg_one = nullptr; + const T one_ = static_cast(1.0), zero_ = static_cast(0.0), neg_one_ = static_cast(-1.0); + /** * @brief Update the precondition array. 
* @@ -332,6 +342,7 @@ class DiagoBPCG using calc_grad_with_block_op = hsolver::calc_grad_with_block_op; using line_minimize_with_block_op = hsolver::line_minimize_with_block_op; + using gemm_op = hsolver::gemm_op; }; diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index f7ef10711b..de627d3474 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -483,6 +483,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, { const int nband = psi.get_nbands(); const int nbasis = psi.get_nbasis(); + const int ndim = psi.get_current_ngk(); // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); @@ -499,7 +500,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ModuleBase::timer::tick("DavSubspace", "hpsi_func"); }; DiagoBPCG bpcg(pre_condition.data()); - bpcg.init_iter(nband, nbasis); + bpcg.init_iter(nband, nbasis, ndim); bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band); } else if (this->method == "dav_subspace") diff --git a/source/module_hsolver/test/diago_bpcg_test.cpp b/source/module_hsolver/test/diago_bpcg_test.cpp index d8060b1763..8978334106 100644 --- a/source/module_hsolver/test/diago_bpcg_test.cpp +++ b/source/module_hsolver/test/diago_bpcg_test.cpp @@ -153,7 +153,8 @@ class DiagoBPCGPrepare zero_, hpsi_out, ld_psi); }; - bpcg.init_iter(nband, npw); + const int ndim = psi_local.get_current_ngk(); + bpcg.init_iter(nband, npw, ndim); std::vector ethr_band(nband, 1e-5); bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); diff --git a/tests/integrate/102_PW_BPCG/result.ref b/tests/integrate/102_PW_BPCG/result.ref index e702dfbb6b..2972395a15 100644 --- a/tests/integrate/102_PW_BPCG/result.ref +++ 
b/tests/integrate/102_PW_BPCG/result.ref @@ -1,7 +1,7 @@ etotref -4869.74705201 etotperatomref -2434.87352600 totalforceref 5.19483000 -totalstressref 37241.44843500 +totalstressref 37241.45334600 pointgroupref C_1 spacegroupref C_1 nksibzref 8