Skip to content

Commit

Permalink
fix: memory leak when precision=single (#5839)
Browse files Browse the repository at this point in the history
* fix: memory leak when precision=single

* change op: replace castmem_h2d_op (cast_memory_op) with syncmem_h2d_op (synchronize_memory_op)

* fix wrong logic of atomic+random
  • Loading branch information
Qianruipku authored Jan 9, 2025
1 parent 48fbc90 commit 74b2954
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 14 deletions.
7 changes: 5 additions & 2 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,18 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
container::kernels::destroyGpuBlasHandle();
container::kernels::destroyGpuSolverHandle();
#endif
delete reinterpret_cast<psi::Psi<T, Device>*>(this->kspw_psi);
}
#ifdef __DSP
std::cout << " ** Closing DSP Hardware..." << std::endl;
dspDestoryHandle(GlobalV::MY_RANK);
#endif
if(PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
{
delete this->kspw_psi;
}
if (PARAM.inp.precision == "single")
{
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
delete this->__kspw_psi;
}

delete this->psi;
Expand Down
5 changes: 4 additions & 1 deletion source/module_hamilt_pw/hamilt_pwdft/VNL_in_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,12 @@ void pseudopot_cell_vnl::getvnl(Device* ctx,
delmem_var_op()(ctx, ylm);
delmem_var_op()(ctx, vkb1);
delmem_complex_op()(ctx, sk);
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
{
delmem_var_op()(ctx, gk);
}
if (PARAM.inp.device == "gpu")
{
delmem_int_op()(ctx, atom_nh);
delmem_int_op()(ctx, atom_nb);
delmem_int_op()(ctx, atom_na);
Expand Down
16 changes: 16 additions & 0 deletions source/module_io/read_input_item_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -775,12 +775,28 @@ void ReadInput::item_system()
para.input.device=base_device::information::get_device_flag(
para.inp.device, para.inp.basis_type);
};
// Validation hook for the "device" item: para.input.device must be one of the
// values listed below; any other value is reported through
// ModuleBase::WARNING_QUIT with a message built by nofound_str.
item.check_value = [](const Input_Item& /*item*/, const Parameter& para) {
    std::vector<std::string> allowed_devices = {"cpu", "gpu"};

    const bool is_supported
        = std::find(allowed_devices.begin(), allowed_devices.end(), para.input.device) != allowed_devices.end();
    if (is_supported)
    {
        return;
    }

    // Unknown device string: build the "not found" message and report it.
    const std::string message = nofound_str(allowed_devices, "device");
    ModuleBase::WARNING_QUIT("ReadInput", message);
};
this->add_item(item);
}
{
    // Input item "precision": read into input.precision and restricted to the
    // values listed in the check below.
    Input_Item item("precision");
    item.annotation = "the computing precision for ABACUS";
    read_sync_string(input.precision);

    // Validation hook: any value outside the supported list is reported
    // through ModuleBase::WARNING_QUIT with a message built by nofound_str.
    item.check_value = [](const Input_Item& /*item*/, const Parameter& para) {
        std::vector<std::string> allowed_precisions = {"single", "double"};

        const bool is_supported
            = std::find(allowed_precisions.begin(), allowed_precisions.end(), para.input.precision)
              != allowed_precisions.end();
        if (is_supported)
        {
            return;
        }

        const std::string message = nofound_str(allowed_precisions, "precision");
        ModuleBase::WARNING_QUIT("ReadInput", message);
    };

    this->add_item(item);
}
}
Expand Down
35 changes: 26 additions & 9 deletions source/module_psi/psi_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void PSIInit<T, Device>::prepare_init(const int& random_seed)
this->psi_initer = std::unique_ptr<psi_initializer<T>>(new psi_initializer_random<T>());
}
else if (this->init_wfc == "atomic"
|| (this->init_wfc == "atomic+random" && this->ucell.natomwfc != PARAM.inp.nbands))
|| (this->init_wfc == "atomic+random" && this->ucell.natomwfc < PARAM.inp.nbands))
{
this->psi_initer = std::unique_ptr<psi_initializer<T>>(new psi_initializer_atomic<T>());
}
Expand Down Expand Up @@ -99,17 +99,30 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
const int nbands_start = this->psi_initer->nbands_start();
const int nbands = psi->get_nbands();
const int nbasis = psi->get_nbasis();
const bool another_psi_space = (nbands_start != nbands || PARAM.inp.precision == "single");
const bool not_equal = (nbands_start != nbands);

Psi<T>* psi_cpu = reinterpret_cast<psi::Psi<T>*>(psi);
Psi<T, Device>* psi_device = kspw_psi;

if (another_psi_space)
if (not_equal)
{
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nullptr);
psi_device = PARAM.inp.device == "gpu" ? new psi::Psi<T, Device>(psi_cpu[0])
: reinterpret_cast<psi::Psi<T, Device>*>(psi_cpu);
}
else if (PARAM.inp.precision == "single")
{
if (PARAM.inp.device == "cpu")
{
psi_cpu = reinterpret_cast<psi::Psi<T>*>(kspw_psi);
psi_device = kspw_psi;
}
else
{
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nullptr);
psi_device = kspw_psi;
}
}

// loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
// like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
Expand All @@ -126,16 +139,16 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
this->psi_initer->init_psig(psi_cpu->get_pointer(), ik);
if (psi_device->get_pointer() != psi_cpu->get_pointer())
{
castmem_h2d_op()(ctx, cpu_ctx, psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis);
syncmem_h2d_op()(ctx, cpu_ctx, psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis);
}

std::vector<typename GetTypeReal<T>::type> etatom(nbands_start, 0.0);

if (this->ks_solver == "cg")
{
if (another_psi_space)
if (not_equal)
{
// for diagH_subspace_init, psi_cpu->get_pointer() and kspw_psi->get_pointer() should be different
// for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be different
hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init(p_hamilt,
psi_device->get_pointer(),
nbands_start,
Expand All @@ -145,7 +158,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
}
else
{
// for diagH_subspace_init, psi_cpu->get_pointer() and kspw_psi->get_pointer() can be the same
// for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same
hsolver::DiagoIterAssist<T, Device>::diagH_subspace(p_hamilt,
*psi_device,
*kspw_psi,
Expand All @@ -155,21 +168,25 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
}
else // dav, bpcg
{
if (another_psi_space)
if (psi_device->get_pointer() != kspw_psi->get_pointer())
{
syncmem_complex_op()(ctx, ctx, kspw_psi->get_pointer(), psi_device->get_pointer(), nbands * nbasis);
}
}
} // end k-point loop

if (another_psi_space)
if (not_equal)
{
delete psi_cpu;
if(PARAM.inp.device == "gpu")
{
delete psi_device;
}
}
else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu")
{
delete psi_cpu;
}

ModuleBase::timer::tick("PSIInit", "initialize_psi");
}
Expand Down
3 changes: 1 addition & 2 deletions source/module_psi/psi_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ class PSIInit

//-------------------------OP--------------------------------------------
using syncmem_complex_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
using castmem_h2d_op
= base_device::memory::cast_memory_op<T, T, Device, base_device::DEVICE_CPU>;
using syncmem_h2d_op = base_device::memory::synchronize_memory_op<T, Device, base_device::DEVICE_CPU>;
};

///@brief allocate the wavefunction
Expand Down

0 comments on commit 74b2954

Please sign in to comment.