Skip to content

Commit

Permalink
Fix: use gemm instead of einsum in BPCG (#5827)
Browse files Browse the repository at this point in the history
* Add dimension parameter for BPCG method

* Add utils for hsovler gemm_op

* Change code to fit new bpcg init interface

* using gemm instead of einsum in orth_cholesky

* using gemm instead of einsum in orth_projection

* replace einsum by gemm in orth_projection

* replace einsum by gemm in rotate_wf

* replace einsum by gemm in diag_hsub

* Update 102_PW_BPCG totalstressref reference value
  • Loading branch information
Cstandardlib authored Jan 10, 2025
1 parent 24abddd commit c898e52
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 10 deletions.
99 changes: 93 additions & 6 deletions source/module_hsolver/diago_bpcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ DiagoBPCG<T, Device>::DiagoBPCG(const Real* precondition_in)
this->device_type = ct::DeviceTypeToEnum<Device>::value;

this->h_prec = std::move(ct::TensorMap((void *) precondition_in, r_type, device_type, {this->n_basis}));

this->one = &one_;
this->zero = &zero_;
this->neg_one = &neg_one_;
}

template<typename T, typename Device>
Expand All @@ -30,11 +34,11 @@ DiagoBPCG<T, Device>::~DiagoBPCG() {
}

template<typename T, typename Device>
void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis, const int ndim) {
// Specify the problem size n_basis, n_band, while lda is n_basis
this->n_band = nband;
this->n_basis = nbasis;

this->n_dim = ndim;

// All column major tensors

Expand Down Expand Up @@ -93,7 +97,23 @@ void DiagoBPCG<T, Device>::orth_cholesky(
// hsub_out = psi_out * transc(psi_out)
ct::EinsumOption option(
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
// hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);

// gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band)
gemm_op()(this->ctx,
'C',
'N',
this->n_band, //m
this->n_band, //n
this->n_dim, //k
this->one, //1.0
psi_out.data<T>(),
this->n_basis, //lda
psi_out.data<T>(),
this->n_basis, //ldb
this->zero, //0.0
hsub_out.data<T>(),
this->n_band); //ldc

// set hsub matrix to lower format;
ct::kernels::set_matrix<T, ct_Device>()(
Expand Down Expand Up @@ -145,12 +165,45 @@ void DiagoBPCG<T, Device>::orth_projection(
{
ct::EinsumOption option(
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in);
hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
// hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);

// this->orth_projection(this->psi, this->hsub, this->grad);
// gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band)
gemm_op()(this->ctx,
'C',
'N',
this->n_band, //m
this->n_band, //n
this->n_dim, //k
this->one, //1.0
psi_in.data<T>(),
this->n_basis, //lda
grad_out.data<T>(),
this->n_basis, //ldb
this->zero, //0.0
hsub_in.data<T>(),
this->n_band); //ldc

// set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
option = ct::EinsumOption(
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out);
grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
// grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);

// grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
gemm_op()(this->ctx,
'N',
'N',
this->n_dim, //m
this->n_band, //n
this->n_band, //k
this->neg_one, //-1.0
psi_in.data<T>(),
this->n_basis, //lda
hsub_in.data<T>(),
this->n_band, //ldb
this->one, //1.0
grad_out.data<T>(),
this->n_basis); //ldc

return;
}
Expand All @@ -165,6 +218,24 @@ void DiagoBPCG<T, Device>::rotate_wf(
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in);
workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);

// this->rotate_wf(hsub_out, psi_out, workspace_in);
// this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);
// gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
// gemm_op()(this->ctx,
// 'N',
// 'N',
// this->n_basis, //m
// this->n_band, //n
// this->n_band, //k
// this->one, //1.0
// psi_out.data<T>(),
// this->n_basis, //lda
// hsub_in.data<T>(),
// this->n_band, //ldb
// this->zero, //0.0
// workspace_in.data<T>(),
// this->n_basis); //ldc

syncmem_complex_op()(psi_out.template data<T>(), workspace_in.template data<T>(), this->n_band * this->n_basis);

return;
Expand Down Expand Up @@ -192,7 +263,23 @@ void DiagoBPCG<T, Device>::diag_hsub(
// it controls the ops to use the corresponding device to calculate results
ct::EinsumOption option(
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
// hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);

// gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band)
gemm_op()(this->ctx,
'C',
'N',
this->n_band, //m
this->n_band, //n
this->n_dim, //k
this->one, //1.0
hpsi_in.data<T>(),
this->n_basis, //lda
psi_in.data<T>(),
this->n_basis, //ldb
this->zero, //0.0
hsub_out.data<T>(),
this->n_band); //ldc

ct::kernels::lapack_dnevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());

Expand Down
13 changes: 12 additions & 1 deletion source/module_hsolver/diago_bpcg.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ class DiagoBPCG
*
* @param nband The number of bands.
* @param nbasis The number of basis functions. Leading dimension of psi.
* @param ndim The number of valid dimension of psi.
*/
void init_iter(const int nband, const int nbasis);
void init_iter(const int nband, const int nbasis, const int ndim);

using HPsiFunc = std::function<void(T*, T*, const int, const int)>;

Expand All @@ -77,6 +78,8 @@ class DiagoBPCG
int n_band = 0;
/// the number of cols of the input psi
int n_basis = 0;
/// valid dimension of psi
int n_dim = 0;
/// max iter steps for all-band cg loop
int nline = 4;

Expand Down Expand Up @@ -107,6 +110,13 @@ class DiagoBPCG
/// work for some calculations within this class, including rotate_wf call
ct::Tensor work = {};

// These are for hsolver gemm_op use
/// ctx is nothing but the devices used in gemm_op (Device * ctx = nullptr;),
Device * ctx = {};
// Pointer to objects of 1 and 0 for gemm
const T *one = nullptr, *zero = nullptr, *neg_one = nullptr;
const T one_ = static_cast<T>(1.0), zero_ = static_cast<T>(0.0), neg_one_ = static_cast<T>(-1.0);

/**
* @brief Update the precondition array.
*
Expand Down Expand Up @@ -332,6 +342,7 @@ class DiagoBPCG

using calc_grad_with_block_op = hsolver::calc_grad_with_block_op<T, Device>;
using line_minimize_with_block_op = hsolver::line_minimize_with_block_op<T, Device>;
using gemm_op = hsolver::gemm_op<T, Device>;

};

Expand Down
3 changes: 2 additions & 1 deletion source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
{
const int nband = psi.get_nbands();
const int nbasis = psi.get_nbasis();
const int ndim = psi.get_current_ngk();
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
Expand All @@ -499,7 +500,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
};
DiagoBPCG<T, Device> bpcg(pre_condition.data());
bpcg.init_iter(nband, nbasis);
bpcg.init_iter(nband, nbasis, ndim);
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band);
}
else if (this->method == "dav_subspace")
Expand Down
3 changes: 2 additions & 1 deletion source/module_hsolver/test/diago_bpcg_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ class DiagoBPCGPrepare
zero_,
hpsi_out, ld_psi);
};
bpcg.init_iter(nband, npw);
const int ndim = psi_local.get_current_ngk();
bpcg.init_iter(nband, npw, ndim);
std::vector<double> ethr_band(nband, 1e-5);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
Expand Down
2 changes: 1 addition & 1 deletion tests/integrate/102_PW_BPCG/result.ref
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
etotref -4869.74705201
etotperatomref -2434.87352600
totalforceref 5.19483000
totalstressref 37241.44843500
totalstressref 37241.45334600
pointgroupref C_1
spacegroupref C_1
nksibzref 8
Expand Down

0 comments on commit c898e52

Please sign in to comment.