Skip to content

Commit

Permalink
Refactor: Remove the global dependence of all remained functions in D…
Browse files Browse the repository at this point in the history
…eePKS. (#5835)

* Remove global dependence of cal_gevdm and rearrange the calling order for simplifying.

* Move some checks from FORCE_STRESS to LCAO_Deepks_interface.

* Remove the global dependence of cal_e_delta_band.

* Move cal_gedm to deepks_basic.cpp

* Remove the global dependence of functions related to pdm in DeePKS.

* Revert "Remove the global dependence of functions related to pdm in DeePKS."

This reverts commit 7a97a95.

* Remove global dependence of pdm related functions in DeePKS.

* Fix the compile bug of DeePKS UT test.

* Remove the global dependence of functions related to phialpha in DeePKS.

* Simplify some function for LCAO_deepks_io.

* Update FORCE_STRESS.cpp

* [pre-commit.ci lite] apply automatic fixes

* Update esolver_ks_lcao.cpp

* Update LCAO_deepks_interface.cpp

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
ErjieWu and pre-commit-ci-lite[bot] authored Jan 9, 2025
1 parent 74b2954 commit 24abddd
Show file tree
Hide file tree
Showing 31 changed files with 1,043 additions and 1,286 deletions.
11 changes: 5 additions & 6 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -193,21 +193,20 @@ OBJS_CELL=atom_pseudo.o\
read_atom_species.o\

OBJS_DEEPKS=LCAO_deepks.o\
deepks_basic.o\
deepks_descriptor.o\
deepks_force.o\
deepks_fpre.o\
deepks_spre.o\
deepks_descriptor.o\
deepks_orbital.o\
deepks_orbpre.o\
deepks_vdelta.o\
deepks_vdpre.o\
deepks_hmat.o\
deepks_pdm.o\
deepks_phialpha.o\
LCAO_deepks_io.o\
LCAO_deepks_pdm.o\
LCAO_deepks_phialpha.o\
LCAO_deepks_torch.o\
LCAO_deepks_vdelta.o\
LCAO_deepks_interface.o\
cal_gedm.o\


OBJS_ELECSTAT=elecstate.o\
Expand Down
18 changes: 10 additions & 8 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,16 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
if (PARAM.inp.deepks_scf)
{
// load the DeePKS model from deep neural network
GlobalC::ld.load_model(PARAM.inp.deepks_model);
DeePKS_domain::load_model(PARAM.inp.deepks_model, GlobalC::ld.model_deepks);
// read pdm from file for NSCF or SCF-restart, do it only once in whole calculation
GlobalC::ld.read_projected_DM((PARAM.inp.init_chg == "file"), PARAM.inp.deepks_equiv, *orb_.Alpha);
DeePKS_domain::read_pdm((PARAM.inp.init_chg == "file"),
PARAM.inp.deepks_equiv,
GlobalC::ld.init_pdm,
GlobalC::ld.inlmax,
GlobalC::ld.lmaxd,
GlobalC::ld.inl_l,
*orb_.Alpha,
GlobalC::ld.pdm);
}
#endif

Expand Down Expand Up @@ -928,9 +935,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
// 1) calculate the kinetic energy density tau, sunliang 2024-09-18
if (PARAM.inp.out_elf[0] > 0)
{
elecstate::lcao_cal_tau<TK>(&(this->GG),
&(this->GK),
this->pelec->charge);
elecstate::lcao_cal_tau<TK>(&(this->GG), &(this->GK), this->pelec->charge);
}

//! 2) call after_scf() of ESolver_KS
Expand Down Expand Up @@ -1047,7 +1052,6 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
std::shared_ptr<LCAO_Deepks> ld_shared_ptr(&GlobalC::ld, [](LCAO_Deepks*) {});
LCAO_Deepks_Interface<TK, TR> LDI(ld_shared_ptr);

ModuleBase::timer::tick("ESolver_KS_LCAO", "out_deepks_labels");
LDI.out_deepks_labels(this->pelec->f_en.etot,
this->pelec->klist->get_nks(),
ucell.nat,
Expand All @@ -1061,8 +1065,6 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
*(this->psi),
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
p_ham_deepks);

ModuleBase::timer::tick("ESolver_KS_LCAO", "out_deepks_labels");
}
#endif

Expand Down
12 changes: 9 additions & 3 deletions source/module_esolver/lcao_before_scf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,19 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
{
const Parallel_Orbitals* pv = &this->pv;
// allocate <phi(0)|alpha(R)>, phialpha is different every ion step, so it is allocated here
GlobalC::ld.allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
// build and save <phi(0)|alpha(R)> at beginning
GlobalC::ld.build_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, *(two_center_bundle_.overlap_orb_alpha));
DeePKS_domain::build_phialpha(PARAM.inp.cal_force,
ucell,
orb_,
this->gd,
pv,
*(two_center_bundle_.overlap_orb_alpha),
GlobalC::ld.phialpha);

if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
}
}
#endif
Expand Down
12 changes: 9 additions & 3 deletions source/module_esolver/lcao_others.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,19 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
{
const Parallel_Orbitals* pv = &this->pv;
// allocate <phi(0)|alpha(R)>, phialpha is different every ion step, so it is allocated here
GlobalC::ld.allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
// build and save <phi(0)|alpha(R)> at beginning
GlobalC::ld.build_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, *(two_center_bundle_.overlap_orb_alpha));
DeePKS_domain::build_phialpha(PARAM.inp.cal_force,
ucell,
orb_,
this->gd,
pv,
*(two_center_bundle_.overlap_orb_alpha),
GlobalC::ld.phialpha);

if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
}
}
#endif
Expand Down
153 changes: 10 additions & 143 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,87 +500,16 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
if (PARAM.inp.deepks_out_labels) // not parallelized yet
{
const std::string file_ftot = PARAM.globalv.global_out_dir + "deepks_ftot.npy";
LCAO_deepks_io::save_npy_f(fcs, file_ftot, ucell.nat,
GlobalV::MY_RANK); // Ty/Bohr, F_tot
LCAO_deepks_io::save_npy_f(fcs, file_ftot, GlobalV::MY_RANK); // Ry/Bohr, F_tot

const std::string file_fbase = PARAM.globalv.global_out_dir + "deepks_fbase.npy";
if (PARAM.inp.deepks_scf)
{
const std::string file_fbase = PARAM.globalv.global_out_dir + "deepks_fbase.npy";
LCAO_deepks_io::save_npy_f(fcs - fvnl_dalpha,
file_fbase,
ucell.nat,
GlobalV::MY_RANK); // Ry/Bohr, F_base

if (!PARAM.inp.deepks_equiv) // training with force label not supported by equivariant version now
{
torch::Tensor gdmx;
if (PARAM.globalv.gamma_only_local)
{
const std::vector<std::vector<double>>& dm_gamma
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();

DeePKS_domain::cal_gdmx(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_gamma,
ucell,
orb,
pv,
gd,
gdmx);
}
else
{
const std::vector<std::vector<std::complex<double>>>& dm_k
= dynamic_cast<const elecstate::ElecStateLCAO<std::complex<double>>*>(pelec)
->get_DM()
->get_DMK_vector();

DeePKS_domain::cal_gdmx(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_k,
ucell,
orb,
pv,
gd,
gdmx);
}
std::vector<torch::Tensor> gevdm;
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
torch::Tensor gvx;
DeePKS_domain::cal_gvx(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.des_per_atom,
GlobalC::ld.inl_l,
gevdm,
gdmx,
gvx);

if (PARAM.inp.deepks_out_unittest)
{
DeePKS_domain::check_gdmx(gdmx);
DeePKS_domain::check_gvx(gvx);
}

LCAO_deepks_io::save_npy_gvx(ucell.nat,
GlobalC::ld.des_per_atom,
gvx,
PARAM.globalv.global_out_dir,
GlobalV::MY_RANK);
}
LCAO_deepks_io::save_npy_f(fcs - fvnl_dalpha, file_fbase, GlobalV::MY_RANK); // Ry/Bohr, F_base
}
else
{
const std::string file_fbase = PARAM.globalv.global_out_dir + "deepks_fbase.npy";
LCAO_deepks_io::save_npy_f(fcs, file_fbase, ucell.nat,
GlobalV::MY_RANK); // no scf, F_base=F_tot
LCAO_deepks_io::save_npy_f(fcs, file_fbase, GlobalV::MY_RANK); // no scf, F_base=F_tot
}
}
#endif
Expand Down Expand Up @@ -758,80 +687,18 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
ucell.omega,
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_tot, w/ model

// wenfei add 2021/11/2
const std::string file_sbase = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
if (PARAM.inp.deepks_scf)
{
const std::string file_sbase = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
LCAO_deepks_io::save_npy_s(scs - svnl_dalpha,
file_sbase,
ucell.omega,
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;

if (!PARAM.inp.deepks_equiv) // training with stress label not supported by equivariant version now
{
torch::Tensor gdmepsl;
if (PARAM.globalv.gamma_only_local)
{
const std::vector<std::vector<double>>& dm_gamma
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();

DeePKS_domain::cal_gdmepsl(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_gamma,
ucell,
orb,
pv,
gd,
gdmepsl);
}
else
{
const std::vector<std::vector<std::complex<double>>>& dm_k
= dynamic_cast<const elecstate::ElecStateLCAO<std::complex<double>>*>(pelec)
->get_DM()
->get_DMK_vector();

DeePKS_domain::cal_gdmepsl(GlobalC::ld.lmaxd,
GlobalC::ld.inlmax,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
GlobalC::ld.inl_index,
dm_k,
ucell,
orb,
pv,
gd,
gdmepsl);
}

std::vector<torch::Tensor> gevdm;
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
torch::Tensor gvepsl;
DeePKS_domain::cal_gvepsl(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.des_per_atom,
GlobalC::ld.inl_l,
gevdm,
gdmepsl,
gvepsl);

if (PARAM.inp.deepks_out_unittest)
{
DeePKS_domain::check_gdmepsl(gdmepsl);
DeePKS_domain::check_gvepsl(gvepsl);
}

LCAO_deepks_io::save_npy_gvepsl(ucell.nat,
GlobalC::ld.des_per_atom,
gvepsl,
PARAM.globalv.global_out_dir,
GlobalV::MY_RANK); // unitless, grad_vepsl
}
}
else
{
LCAO_deepks_io::save_npy_s(scs, file_sbase, ucell.omega,
GlobalV::MY_RANK); // sbase = stot
}
}
#endif
Expand Down
38 changes: 11 additions & 27 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,23 @@ void Force_LCAO<double>::ftable(const bool isforce,
if (PARAM.inp.deepks_scf)
{
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);

DeePKS_domain::cal_descriptor(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.inl_l,
GlobalC::ld.pdm,
descriptor,
GlobalC::ld.des_per_atom);
GlobalC::ld.cal_gedm(ucell.nat, descriptor);
DeePKS_domain::cal_gedm(ucell.nat,
GlobalC::ld.lmaxd,
GlobalC::ld.nmaxd,
GlobalC::ld.inlmax,
GlobalC::ld.des_per_atom,
GlobalC::ld.inl_l,
descriptor,
GlobalC::ld.pdm,
GlobalC::ld.model_deepks,
GlobalC::ld.gedm,
GlobalC::ld.E_delta);

const int nks = 1;
DeePKS_domain::cal_f_delta<double>(dm_gamma,
Expand Down Expand Up @@ -302,32 +310,8 @@ void Force_LCAO<double>::ftable(const bool isforce,
}

#ifdef __DEEPKS
// It seems these test should not all be here, should be moved in the future
// Also, these test are not in multi-k case now
if (PARAM.inp.deepks_scf && PARAM.inp.deepks_out_unittest)
{
const int nks = 1; // 1 for gamma-only
LCAO_deepks_io::print_dm(nks, PARAM.globalv.nlocal, this->ParaV->nrow, dm_gamma);

GlobalC::ld.check_projected_dm();

DeePKS_domain::check_descriptor(GlobalC::ld.inlmax,
GlobalC::ld.des_per_atom,
GlobalC::ld.inl_l,
ucell,
PARAM.globalv.global_out_dir,
descriptor);

GlobalC::ld.check_gedm();

GlobalC::ld.cal_e_delta_band(dm_gamma, nks);

std::ofstream ofs("E_delta_bands.dat");
ofs << std::setprecision(10) << GlobalC::ld.e_delta_band;

std::ofstream ofs1("E_delta.dat");
ofs1 << std::setprecision(10) << GlobalC::ld.E_delta;

DeePKS_domain::check_f_delta(ucell.nat, fvnl_dalpha, svnl_dalpha);
}
#endif
Expand Down
Loading

0 comments on commit 24abddd

Please sign in to comment.