diff --git a/source/source_esolver/esolver_dp.cpp b/source/source_esolver/esolver_dp.cpp index 879193e668b..7eb2d9e1fab 100644 --- a/source/source_esolver/esolver_dp.cpp +++ b/source/source_esolver/esolver_dp.cpp @@ -36,6 +36,10 @@ void ESolver_DP::before_all_runners(UnitCell& ucell, const Input_para& inp) dp_potential = 0; dp_force.create(ucell.nat, 3); dp_virial.create(3, 3); + dp_cell.resize(9); + dp_coord.resize(3 * ucell.nat); + dp_model_force.clear(); + dp_model_virial.clear(); ModuleIO::CifParser::write(PARAM.globalv.global_out_dir + "STRU.cif", ucell, @@ -59,38 +63,38 @@ void ESolver_DP::runner(UnitCell& ucell, const int istep) ModuleBase::TITLE("ESolver_DP", "runner"); ModuleBase::timer::start("ESolver_DP", "runner"); - std::vector cell(9, 0.0); - cell[0] = ucell.latvec.e11 * ucell.lat0_angstrom; - cell[1] = ucell.latvec.e12 * ucell.lat0_angstrom; - cell[2] = ucell.latvec.e13 * ucell.lat0_angstrom; - cell[3] = ucell.latvec.e21 * ucell.lat0_angstrom; - cell[4] = ucell.latvec.e22 * ucell.lat0_angstrom; - cell[5] = ucell.latvec.e23 * ucell.lat0_angstrom; - cell[6] = ucell.latvec.e31 * ucell.lat0_angstrom; - cell[7] = ucell.latvec.e32 * ucell.lat0_angstrom; - cell[8] = ucell.latvec.e33 * ucell.lat0_angstrom; - - std::vector coord(3 * ucell.nat, 0.0); + dp_cell[0] = ucell.latvec.e11 * ucell.lat0_angstrom; + dp_cell[1] = ucell.latvec.e12 * ucell.lat0_angstrom; + dp_cell[2] = ucell.latvec.e13 * ucell.lat0_angstrom; + dp_cell[3] = ucell.latvec.e21 * ucell.lat0_angstrom; + dp_cell[4] = ucell.latvec.e22 * ucell.lat0_angstrom; + dp_cell[5] = ucell.latvec.e23 * ucell.lat0_angstrom; + dp_cell[6] = ucell.latvec.e31 * ucell.lat0_angstrom; + dp_cell[7] = ucell.latvec.e32 * ucell.lat0_angstrom; + dp_cell[8] = ucell.latvec.e33 * ucell.lat0_angstrom; + + dp_coord.resize(3 * ucell.nat); int iat = 0; for (int it = 0; it < ucell.ntype; ++it) { for (int ia = 0; ia < ucell.atoms[it].na; ++ia) { - coord[3 * iat] = ucell.atoms[it].tau[ia].x * ucell.lat0_angstrom; - coord[3 * iat + 1] = ucell.atoms[it].tau[ia].y * ucell.lat0_angstrom; - coord[3 * iat + 2] = ucell.atoms[it].tau[ia].z * ucell.lat0_angstrom; + dp_coord[3 * iat] = ucell.atoms[it].tau[ia].x * ucell.lat0_angstrom; + dp_coord[3 * iat + 1] = ucell.atoms[it].tau[ia].y * ucell.lat0_angstrom; + dp_coord[3 * iat + 2] = ucell.atoms[it].tau[ia].z * ucell.lat0_angstrom; iat++; } } assert(ucell.nat == iat); #ifdef __DPMD - std::vector f, v; dp_potential = 0; dp_force.zero_out(); dp_virial.zero_out(); + dp_model_force.clear(); + dp_model_virial.clear(); - dp.compute(dp_potential, f, v, coord, atype, cell, fparam, aparam); + dp.compute(dp_potential, dp_model_force, dp_model_virial, dp_coord, atype, dp_cell, fparam, aparam); // rescale the energy, force, and stress const double fact_e = rescaling / ModuleBase::Ry_to_eV; @@ -103,16 +107,16 @@ void ESolver_DP::runner(UnitCell& ucell, const int istep) for (int i = 0; i < ucell.nat; ++i) { - dp_force(i, 0) = f[3 * i] * fact_f; - dp_force(i, 1) = f[3 * i + 1] * fact_f; - dp_force(i, 2) = f[3 * i + 2] * fact_f; + dp_force(i, 0) = dp_model_force[3 * i] * fact_f; + dp_force(i, 1) = dp_model_force[3 * i + 1] * fact_f; + dp_force(i, 2) = dp_model_force[3 * i + 2] * fact_f; } for (int i = 0; i < 3; ++i) { for (int j = 0; j < 3; ++j) { - dp_virial(i, j) = v[3 * i + j] * fact_v; + dp_virial(i, j) = dp_model_virial[3 * i + j] * fact_v; } } #else diff --git a/source/source_esolver/esolver_dp.h b/source/source_esolver/esolver_dp.h index 405bae44461..52ddd869911 100644 --- a/source/source_esolver/esolver_dp.h +++ b/source/source_esolver/esolver_dp.h @@ -115,6 +115,10 @@ class ESolver_DP : public ESolver double dp_potential = 0.0; ///< computed potential energy ModuleBase::matrix dp_force; ///< computed atomic forces ModuleBase::matrix dp_virial; ///< computed lattice virials + std::vector dp_cell; ///< DP cell buffer in Angstrom + std::vector dp_coord; ///< DP coordinate buffer in Angstrom + std::vector dp_model_force; ///< raw force buffer returned by DP + std::vector dp_model_virial; ///< raw virial buffer returned by DP }; } // namespace ModuleESolver diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 454ef7532a2..57d0a1100be 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -279,9 +279,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // Calculate kinetic energy density tau for ELF if needed if (PARAM.inp.out_elf[0] > 0) { - auto* elec_pw = static_cast*>(this->pelec); - auto& psi = *this->stp.template get_psi_t(); - elec_pw->cal_tau(psi); + this->pelec->cal_tau(*(this->stp.psi_cpu)); } ESolver_KS::after_scf(ucell, istep, conv_esolver); diff --git a/source/source_esolver/esolver_nep.cpp b/source/source_esolver/esolver_nep.cpp index 8944776aaa6..b586983a647 100644 --- a/source/source_esolver/esolver_nep.cpp +++ b/source/source_esolver/esolver_nep.cpp @@ -23,7 +23,7 @@ #include "source_io/module_output/output_log.h" #include "source_io/module_output/cif_io.h" -#include +#include #include using namespace ModuleESolver; @@ -34,9 +34,26 @@ void ESolver_NEP::before_all_runners(UnitCell& ucell, const Input_para& inp) nep_force.create(ucell.nat, 3); nep_virial.create(3, 3); atype.resize(ucell.nat); + nep_cell.resize(9); + nep_coord.resize(3 * ucell.nat); + nep_virial_sum.resize(9); _e.resize(ucell.nat); _f.resize(3 * ucell.nat); _v.resize(9 * ucell.nat); + atom_type_index.resize(ucell.nat); + atom_local_index.resize(ucell.nat); + + int iat = 0; + for (int it = 0; it < ucell.ntype; ++it) + { + for (int ia = 0; ia < ucell.atoms[it].na; ++ia) + { + atom_type_index[iat] = it; + atom_local_index[iat] = ia; + ++iat; + } + } + assert(ucell.nat == iat); ModuleIO::CifParser::write(PARAM.globalv.global_out_dir + "STRU.cif", ucell, @@ -56,39 +73,35 @@ void ESolver_NEP::runner(UnitCell& ucell, const int istep) // note that NEP are column major, thus a transpose is needed // cell - std::vector cell(9, 0.0); - cell[0] = ucell.latvec.e11 * ucell.lat0_angstrom; - cell[1] = ucell.latvec.e21 * ucell.lat0_angstrom; - cell[2] = ucell.latvec.e31 * ucell.lat0_angstrom; - cell[3] = ucell.latvec.e12 * ucell.lat0_angstrom; - cell[4] = ucell.latvec.e22 * ucell.lat0_angstrom; - cell[5] = ucell.latvec.e32 * ucell.lat0_angstrom; - cell[6] = ucell.latvec.e13 * ucell.lat0_angstrom; - cell[7] = ucell.latvec.e23 * ucell.lat0_angstrom; - cell[8] = ucell.latvec.e33 * ucell.lat0_angstrom; + nep_cell[0] = ucell.latvec.e11 * ucell.lat0_angstrom; + nep_cell[1] = ucell.latvec.e21 * ucell.lat0_angstrom; + nep_cell[2] = ucell.latvec.e31 * ucell.lat0_angstrom; + nep_cell[3] = ucell.latvec.e12 * ucell.lat0_angstrom; + nep_cell[4] = ucell.latvec.e22 * ucell.lat0_angstrom; + nep_cell[5] = ucell.latvec.e32 * ucell.lat0_angstrom; + nep_cell[6] = ucell.latvec.e13 * ucell.lat0_angstrom; + nep_cell[7] = ucell.latvec.e23 * ucell.lat0_angstrom; + nep_cell[8] = ucell.latvec.e33 * ucell.lat0_angstrom; // coord - std::vector coord(3 * ucell.nat, 0.0); - int iat = 0; + nep_coord.resize(3 * ucell.nat); const int nat = ucell.nat; - for (int it = 0; it < ucell.ntype; ++it) +#pragma omp parallel for schedule(static) if (nat >= 256) + for (int iat = 0; iat < nat; ++iat) { - for (int ia = 0; ia < ucell.atoms[it].na; ++ia) - { - coord[iat] = ucell.atoms[it].tau[ia].x * ucell.lat0_angstrom; - coord[iat + nat] = ucell.atoms[it].tau[ia].y * ucell.lat0_angstrom; - coord[iat + 2 * nat] = ucell.atoms[it].tau[ia].z * ucell.lat0_angstrom; - iat++; - } + const int it = atom_type_index[iat]; + const int ia = atom_local_index[iat]; + nep_coord[iat] = ucell.atoms[it].tau[ia].x * ucell.lat0_angstrom; + nep_coord[iat + nat] = ucell.atoms[it].tau[ia].y * ucell.lat0_angstrom; + nep_coord[iat + 2 * nat] = ucell.atoms[it].tau[ia].z * ucell.lat0_angstrom; } - assert(ucell.nat == iat); #ifdef __NEP nep_potential = 0.0; nep_force.zero_out(); nep_virial.zero_out(); - nep.compute(atype, cell, coord, _e, _f, _v); + nep.compute(atype, nep_cell, nep_coord, _e, _f, _v); // unit conversion const double fact_e = 1.0 / ModuleBase::Ry_to_eV; @@ -97,11 +110,18 @@ void ESolver_NEP::runner(UnitCell& ucell, const int istep) // potential energy - nep_potential = fact_e * std::accumulate(_e.begin(), _e.end(), 0.0) ; + double energy_sum = 0.0; +#pragma omp parallel for reduction(+:energy_sum) schedule(static) if (nat >= 256) + for (int i = 0; i < nat; ++i) + { + energy_sum += _e[i]; + } + nep_potential = fact_e * energy_sum; GlobalV::ofs_running << " #TOTAL ENERGY# " << std::setprecision(11) << nep_potential * ModuleBase::Ry_to_eV << " eV" << std::endl; // forces +#pragma omp parallel for schedule(static) if (nat >= 256) for (int i = 0; i < nat; ++i) { nep_force(i, 0) = _f[i] * fact_f; @@ -110,22 +130,44 @@ void ESolver_NEP::runner(UnitCell& ucell, const int istep) } // virial - std::vector v_sum(9, 0.0); - for (int j = 0; j < 9; ++j) + double v0 = 0.0; + double v1 = 0.0; + double v2 = 0.0; + double v3 = 0.0; + double v4 = 0.0; + double v5 = 0.0; + double v6 = 0.0; + double v7 = 0.0; + double v8 = 0.0; +#pragma omp parallel for reduction(+:v0, v1, v2, v3, v4, v5, v6, v7, v8) schedule(static) if (nat >= 256) + for (int i = 0; i < nat; ++i) { - for (int i = 0; i < nat; ++i) - { - int index = j * nat + i; - v_sum[j] += _v[index]; - } + v0 += _v[i]; + v1 += _v[nat + i]; + v2 += _v[2 * nat + i]; + v3 += _v[3 * nat + i]; + v4 += _v[4 * nat + i]; + v5 += _v[5 * nat + i]; + v6 += _v[6 * nat + i]; + v7 += _v[7 * nat + i]; + v8 += _v[8 * nat + i]; } + nep_virial_sum[0] = v0; + nep_virial_sum[1] = v1; + nep_virial_sum[2] = v2; + nep_virial_sum[3] = v3; + nep_virial_sum[4] = v4; + nep_virial_sum[5] = v5; + nep_virial_sum[6] = v6; + nep_virial_sum[7] = v7; + nep_virial_sum[8] = v8; // virial -> stress for (int i = 0; i < 3; ++i) { for (int j = 0; j < 3; ++j) { - nep_virial(i, j) = v_sum[3 * i + j] * fact_v; + nep_virial(i, j) = nep_virial_sum[3 * i + j] * fact_v; } } #else diff --git a/source/source_esolver/esolver_nep.h b/source/source_esolver/esolver_nep.h index dfec17a83c2..bcbd658c319 100644 --- a/source/source_esolver/esolver_nep.h +++ b/source/source_esolver/esolver_nep.h @@ -95,9 +95,14 @@ class ESolver_NEP : public ESolver std::string nep_file; ///< directory of NEP model file std::vector atype = {}; ///< atom type mapping for NEP model + std::vector atom_type_index; ///< global atom index to UnitCell atom type + std::vector atom_local_index; ///< global atom index to local index inside atom type double nep_potential; ///< computed potential energy ModuleBase::matrix nep_force; ///< computed atomic forces ModuleBase::matrix nep_virial; ///< computed lattice virials + std::vector nep_cell; ///< NEP cell buffer in Angstrom, column-major + std::vector nep_coord; ///< NEP coordinate buffer in Angstrom, column-major + std::vector nep_virial_sum; ///< summed per-atom virial components std::vector _e; ///< temporary storage for energy computation std::vector _f; ///< temporary storage for force computation std::vector _v; ///< temporary storage for virial computation @@ -105,4 +110,4 @@ class ESolver_NEP : public ESolver } // namespace ModuleESolver -#endif \ No newline at end of file +#endif diff --git a/source/source_md/md_base.cpp b/source/source_md/md_base.cpp index 390e1c2b082..f87fca9d2ee 100644 --- a/source/source_md/md_base.cpp +++ b/source/source_md/md_base.cpp @@ -96,7 +96,9 @@ void MD_base::update_pos() { if (my_rank == 0) { - for (int i = 0; i < ucell.nat; ++i) + const int natom = ucell.nat; +#pragma omp parallel for schedule(static) if (natom >= 256) + for (int i = 0; i < natom; ++i) { for (int k = 0; k < 3; ++k) { @@ -127,7 +129,9 @@ void MD_base::update_vel(const ModuleBase::Vector3* force) { if (my_rank == 0) { - for (int i = 0; i < ucell.nat; ++i) + const int natom = ucell.nat; +#pragma omp parallel for schedule(static) if (natom >= 256) + for (int i = 0; i < natom; ++i) { for (int k = 0; k < 3; ++k) { diff --git a/source/source_md/md_func.cpp b/source/source_md/md_func.cpp index 6bd1b60dd59..3c7526cd9c2 100644 --- a/source/source_md/md_func.cpp +++ b/source/source_md/md_func.cpp @@ -44,6 +44,7 @@ double kinetic_energy(const int& natom, const ModuleBase::Vector3* vel, { double ke = 0; +#pragma omp parallel for reduction(+:ke) schedule(static) if (natom >= 256) for (int ion = 0; ion < natom; ++ion) { ke += 0.5 * allmass[ion] * vel[ion].norm2(); @@ -52,6 +53,43 @@ double kinetic_energy(const int& natom, const ModuleBase::Vector3* vel, return ke; } +MDKineticState calc_kinetic_state(const int& natom, + const int& frozen_freedom, + const double* allmass, + const ModuleBase::Vector3* vel) +{ + MDKineticState state; + if (3 * natom == frozen_freedom) + { + return state; + } + + state.kinetic = kinetic_energy(natom, vel, allmass); + state.temperature = 2 * state.kinetic / (3 * natom - frozen_freedom); + return state; +} + +MDStressState calc_stress_state(const int& natom, + const double& omega, + const ModuleBase::Vector3* vel, + const double* allmass, + const ModuleBase::matrix& virial) +{ + MDStressState state; + temp_vector(natom, vel, allmass, state.temperature_tensor); + state.stress.create(3, 3); + + for (int i = 0; i < 3; ++i) + { + for (int j = 0; j < 3; ++j) + { + state.stress(i, j) = virial(i, j) + state.temperature_tensor(i, j) / omega; + } + } + + return state; +} + void compute_stress(const UnitCell& unit_in, const ModuleBase::Vector3* vel, const double* allmass, @@ -61,17 +99,7 @@ void compute_stress(const UnitCell& unit_in, { if (cal_stress) { - ModuleBase::matrix t_vector; - - temp_vector(unit_in.nat, vel, allmass, t_vector); - - for (int i = 0; i < 3; ++i) - { - for (int j = 0; j < 3; ++j) - { - stress(i, j) = virial(i, j) + t_vector(i, j) / unit_in.omega; - } - } + stress = calc_stress_state(unit_in.nat, unit_in.omega, vel, allmass, virial).stress; } return; @@ -273,7 +301,9 @@ void force_virial(ModuleESolver::ESolver* p_esolver, force_temp *= 0.5; virial *= 0.5; - for (int i = 0; i < unit_in.nat; ++i) + const int natom = unit_in.nat; +#pragma omp parallel for schedule(static) if (natom >= 256) + for (int i = 0; i < natom; ++i) { for (int j = 0; j < 3; ++j) { @@ -463,8 +493,9 @@ double current_temp(double& kinetic, } else { - kinetic = kinetic_energy(natom, vel, allmass); - return 2 * kinetic / (3 * natom - frozen_freedom); + const MDKineticState state = calc_kinetic_state(natom, frozen_freedom, allmass, vel); + kinetic = state.kinetic; + return state.temperature; } } @@ -475,17 +506,45 @@ void temp_vector(const int& natom, { t_vector.create(3, 3); + double t00 = 0.0; + double t01 = 0.0; + double t02 = 0.0; + double t10 = 0.0; + double t11 = 0.0; + double t12 = 0.0; + double t20 = 0.0; + double t21 = 0.0; + double t22 = 0.0; + +#pragma omp parallel for reduction(+:t00, t01, t02, t10, t11, t12, t20, t21, t22) schedule(static) if (natom >= 256) for (int ion = 0; ion < natom; ++ion) { - for (int i = 0; i < 3; ++i) - { - for (int j = 0; j < 3; ++j) - { - t_vector(i, j) += allmass[ion] * vel[ion][i] * vel[ion][j]; - } - } + const double mass = allmass[ion]; + const double vx = vel[ion].x; + const double vy = vel[ion].y; + const double vz = vel[ion].z; + + t00 += mass * vx * vx; + t01 += mass * vx * vy; + t02 += mass * vx * vz; + t10 += mass * vy * vx; + t11 += mass * vy * vy; + t12 += mass * vy * vz; + t20 += mass * vz * vx; + t21 += mass * vz * vy; + t22 += mass * vz * vz; } + t_vector(0, 0) = t00; + t_vector(0, 1) = t01; + t_vector(0, 2) = t02; + t_vector(1, 0) = t10; + t_vector(1, 1) = t11; + t_vector(1, 2) = t12; + t_vector(2, 0) = t20; + t_vector(2, 1) = t21; + t_vector(2, 2) = t22; + return; } diff --git a/source/source_md/md_func.h b/source/source_md/md_func.h index be433ffe4ac..51c4eb47d83 100644 --- a/source/source_md/md_func.h +++ b/source/source_md/md_func.h @@ -1,6 +1,7 @@ #ifndef MD_FUNC_H #define MD_FUNC_H +#include "md_statistics.h" #include "source_esolver/esolver.h" class Parameter; @@ -117,6 +118,14 @@ void force_virial(ModuleESolver::ESolver* p_esolver, */ double kinetic_energy(const int& natom, const ModuleBase::Vector3* vel, const double* allmass); +/** + * @brief calculate kinetic energy and temperature without writing caller-owned state + */ +MDKineticState calc_kinetic_state(const int& natom, + const int& frozen_freedom, + const double* allmass, + const ModuleBase::Vector3* vel); + /** * @brief calculate the total stress tensor * @@ -134,6 +143,15 @@ void compute_stress(const UnitCell& unit_in, const ModuleBase::matrix& virial, ModuleBase::matrix& stress); +/** + * @brief calculate stress and ionic temperature tensor without writing caller-owned state + */ +MDStressState calc_stress_state(const int& natom, + const double& omega, + const ModuleBase::Vector3* vel, + const double* allmass, + const ModuleBase::matrix& virial); + /** * @brief output the stress information * diff --git a/source/source_md/md_statistics.h b/source/source_md/md_statistics.h new file mode 100644 index 00000000000..e7bef175be5 --- /dev/null +++ b/source/source_md/md_statistics.h @@ -0,0 +1,23 @@ +#ifndef MD_STATISTICS_H +#define MD_STATISTICS_H + +#include "source_base/matrix.h" + +namespace MD_func +{ + +struct MDKineticState +{ + double kinetic = 0.0; + double temperature = 0.0; +}; + +struct MDStressState +{ + ModuleBase::matrix stress; + ModuleBase::matrix temperature_tensor; +}; + +} // namespace MD_func + +#endif // MD_STATISTICS_H diff --git a/source/source_md/run_md.cpp b/source/source_md/run_md.cpp index ef28e5c8975..b7aab6511a0 100644 --- a/source/source_md/run_md.cpp +++ b/source/source_md/run_md.cpp @@ -12,41 +12,48 @@ #include "verlet.h" #include "source_cell/update_cell.h" #include "source_cell/print_cell.h" -namespace Run_MD -{ +#include -void md_line(UnitCell& unit_in, ModuleESolver::ESolver* p_esolver, const Parameter& param_in) +namespace +{ +std::unique_ptr create_md_runner(const Parameter& param_in, UnitCell& unit_in) { - ModuleBase::TITLE("Run_MD", "md_line"); - ModuleBase::timer::start("Run_MD", "md_line"); - - /// determine the md_type - MD_base* mdrun = nullptr; if (param_in.mdp.md_type == "fire") { - mdrun = new FIRE(param_in, unit_in); + return std::unique_ptr(new FIRE(param_in, unit_in)); } - else if ((param_in.mdp.md_type == "nvt" && param_in.mdp.md_thermostat == "nhc") || param_in.mdp.md_type == "npt") + if ((param_in.mdp.md_type == "nvt" && param_in.mdp.md_thermostat == "nhc") || param_in.mdp.md_type == "npt") { - mdrun = new Nose_Hoover(param_in, unit_in); + return std::unique_ptr(new Nose_Hoover(param_in, unit_in)); } - else if (param_in.mdp.md_type == "nve" || param_in.mdp.md_type == "nvt") + if (param_in.mdp.md_type == "nve" || param_in.mdp.md_type == "nvt") { - mdrun = new Verlet(param_in, unit_in); + return std::unique_ptr(new Verlet(param_in, unit_in)); } - else if (param_in.mdp.md_type == "langevin") + if (param_in.mdp.md_type == "langevin") { - mdrun = new Langevin(param_in, unit_in); + return std::unique_ptr(new Langevin(param_in, unit_in)); } - else if (param_in.mdp.md_type == "msst") + if (param_in.mdp.md_type == "msst") { - mdrun = new MSST(param_in, unit_in); - } - else - { - ModuleBase::WARNING_QUIT("md_line", "no such md_type!"); + return std::unique_ptr(new MSST(param_in, unit_in)); } + ModuleBase::WARNING_QUIT("md_line", "no such md_type!"); + return nullptr; +} +} // namespace + +namespace Run_MD +{ + +void md_line(UnitCell& unit_in, ModuleESolver::ESolver* p_esolver, const Parameter& param_in) +{ + ModuleBase::TITLE("Run_MD", "md_line"); + ModuleBase::timer::start("Run_MD", "md_line"); + + std::unique_ptr mdrun = create_md_runner(param_in, unit_in); + /// md cycle, mohan update 2026-01-04, change '<=' to '<' while ((mdrun->step_ + mdrun->step_rst_) < param_in.mdp.md_nstep && !mdrun->stop) { @@ -129,7 +136,6 @@ void md_line(UnitCell& unit_in, ModuleESolver::ESolver* p_esolver, const Paramet mdrun->step_++; } - delete mdrun; ModuleBase::timer::end("Run_MD", "md_line"); return; } diff --git a/source/source_md/test/CMakeLists.txt b/source/source_md/test/CMakeLists.txt index c4e3a1c2d2f..682bd682b49 100644 --- a/source/source_md/test/CMakeLists.txt +++ b/source/source_md/test/CMakeLists.txt @@ -37,7 +37,6 @@ list(APPEND depend_files ../../source_base/realarray.cpp ../../source_base/complexarray.cpp ../../source_base/complexmatrix.cpp - ../../source_base/global_variable.cpp ../../source_base/libm/branred.cpp ../../source_base/libm/sincos.cpp ../../source_base/math_integral.cpp diff --git a/source/source_md/test/fire_test.cpp b/source/source_md/test/fire_test.cpp index 3b294da46ac..e52ae4cbecd 100644 --- a/source/source_md/test/fire_test.cpp +++ b/source/source_md/test/fire_test.cpp @@ -5,9 +5,8 @@ #undef private #define private public #define protected public -#include "source_esolver/esolver_lj.h" #include "source_md/fire.h" -#include "setcell.h" +#include "md_test_fixture.h" #define doublethreshold 1e-12 /************************************************ @@ -35,31 +34,8 @@ * - output MD information such as energy, temperature, and pressure */ -class FIREtest : public testing::Test +class FIREtest : public MdIntegratorFixture { - protected: - MD_base* mdrun; - UnitCell ucell; - Parameter param_in; - ModuleESolver::ESolver* p_esolver; - - void SetUp() - { - Setcell::setupcell(ucell); - Setcell::parameters(param_in.input); - - p_esolver = new ModuleESolver::ESolver_LJ(); - p_esolver->before_all_runners(ucell, param_in.inp); - - mdrun = new FIRE(param_in, ucell); - mdrun->setup(p_esolver, PARAM.sys.global_readin_dir); - } - - void TearDown() - { - delete mdrun; - delete p_esolver; - } }; TEST_F(FIREtest, Setup) @@ -167,7 +143,7 @@ TEST_F(FIREtest, Restart) mdrun->restart(PARAM.sys.global_readin_dir); remove("Restart_md.txt"); - FIRE* fire = dynamic_cast(mdrun); + FIRE* fire = dynamic_cast(mdrun.get()); EXPECT_EQ(mdrun->step_rst_, 3); EXPECT_EQ(fire->alpha, 0.1); EXPECT_EQ(fire->negative_count, 0); diff --git a/source/source_md/test/langevin_test.cpp b/source/source_md/test/langevin_test.cpp index 69df605b153..65d462a86da 100644 --- a/source/source_md/test/langevin_test.cpp +++ b/source/source_md/test/langevin_test.cpp @@ -5,9 +5,8 @@ #undef private #define private public #define protected public -#include "source_esolver/esolver_lj.h" #include "source_md/langevin.h" -#include "setcell.h" +#include "md_test_fixture.h" #define doublethreshold 1e-12 /************************************************ @@ -35,31 +34,8 @@ * - output MD information such as energy, temperature, and pressure */ -class Langevin_test : public testing::Test +class Langevin_test : public MdIntegratorFixture { - protected: - MD_base* mdrun; - UnitCell ucell; - Parameter param_in; - ModuleESolver::ESolver* p_esolver; - - void SetUp() - { - Setcell::setupcell(ucell); - Setcell::parameters(param_in.input); - - p_esolver = new ModuleESolver::ESolver_LJ(); - p_esolver->before_all_runners(ucell, param_in.inp); - - mdrun = new Langevin(param_in, ucell); - mdrun->setup(p_esolver, PARAM.sys.global_readin_dir); - } - - void TearDown() - { - delete mdrun; - delete p_esolver; - } }; TEST_F(Langevin_test, setup) diff --git a/source/source_md/test/lj_pot_test.cpp b/source/source_md/test/lj_pot_test.cpp index 64c0b52fe6b..99cec432b56 100644 --- a/source/source_md/test/lj_pot_test.cpp +++ b/source/source_md/test/lj_pot_test.cpp @@ -1,9 +1,9 @@ #include "gtest/gtest.h" #define private public #include "source_io/module_parameter/parameter.h" +#include "md_test_fixture.h" #include "source_esolver/esolver_lj.h" #include "source_md/md_func.h" -#include "setcell.h" #undef private #define doublethreshold 1e-12 @@ -17,46 +17,23 @@ * - calculate energy, force, virial for lj pot */ -class LJ_pot_test : public testing::Test +class LJ_pot_test : public LjPotTestFixture { - protected: - ModuleBase::Vector3* force; - ModuleBase::matrix stress; - double potential; - int natom; - UnitCell ucell; - Input_para input; - - void SetUp() - { - Setcell::setupcell(ucell); - - natom = ucell.nat; - force = new ModuleBase::Vector3[natom]; - stress.create(3, 3); - - Setcell::parameters(input); - } - - void TearDown() - { - delete[] force; - } }; TEST_F(LJ_pot_test, potential) { - ModuleESolver::ESolver* p_esolver = new ModuleESolver::ESolver_LJ(); + std::unique_ptr p_esolver(new ModuleESolver::ESolver_LJ()); p_esolver->before_all_runners(ucell, input); - MD_func::force_virial(p_esolver, 0, ucell, potential, force, true, stress); + MD_func::force_virial(p_esolver.get(), 0, ucell, potential, force, true, stress); EXPECT_NEAR(potential, -0.011957818623534381, doublethreshold); } TEST_F(LJ_pot_test, force) { - ModuleESolver::ESolver* p_esolver = new ModuleESolver::ESolver_LJ(); + std::unique_ptr p_esolver(new ModuleESolver::ESolver_LJ()); p_esolver->before_all_runners(ucell, input); - MD_func::force_virial(p_esolver, 0, ucell, potential, force, true, stress); + MD_func::force_virial(p_esolver.get(), 0, ucell, potential, force, true, stress); EXPECT_NEAR(force[0].x, 0.00049817733089377704, doublethreshold); EXPECT_NEAR(force[0].y, 0.00082237246837022328, doublethreshold); EXPECT_NEAR(force[0].z, -3.0493186101154812e-20, doublethreshold); @@ -73,9 +50,9 @@ TEST_F(LJ_pot_test, force) TEST_F(LJ_pot_test, stress) { - ModuleESolver::ESolver* p_esolver = new ModuleESolver::ESolver_LJ(); + std::unique_ptr p_esolver(new ModuleESolver::ESolver_LJ()); p_esolver->before_all_runners(ucell, input); - MD_func::force_virial(p_esolver, 0, ucell, potential, force, true, stress); + MD_func::force_virial(p_esolver.get(), 0, ucell, potential, force, true, stress); EXPECT_NEAR(stress(0, 0), 8.0360222227631859e-07, doublethreshold); EXPECT_NEAR(stress(0, 1), 1.7207745586539077e-07, doublethreshold); EXPECT_NEAR(stress(0, 2), 0, doublethreshold); @@ -89,7 +66,7 @@ TEST_F(LJ_pot_test, stress) TEST_F(LJ_pot_test, RcutSearchRadius) { - ModuleESolver::ESolver_LJ* p_esolver = new ModuleESolver::ESolver_LJ(); + std::unique_ptr p_esolver(new ModuleESolver::ESolver_LJ()); ucell.ntype = 2; std::vector rcut = {3.0}; p_esolver->rcut_search_radius(ucell.ntype, rcut); @@ -114,7 +91,7 @@ TEST_F(LJ_pot_test, RcutSearchRadius) TEST_F(LJ_pot_test, SetC6C12) { - ModuleESolver::ESolver_LJ* p_esolver = new ModuleESolver::ESolver_LJ(); + std::unique_ptr p_esolver(new ModuleESolver::ESolver_LJ()); ucell.ntype = 2; // no rule @@ -187,7 +164,7 @@ TEST_F(LJ_pot_test, SetC6C12) TEST_F(LJ_pot_test, CalEnShift) { - ModuleESolver::ESolver_LJ* p_esolver = new ModuleESolver::ESolver_LJ(); + std::unique_ptr p_esolver(new ModuleESolver::ESolver_LJ()); ucell.ntype = 2; std::vector rcut = {3.0}; @@ -214,4 +191,4 @@ TEST_F(LJ_pot_test, CalEnShift) EXPECT_NEAR(p_esolver->en_shift(0, 1), -3.303688865319793e-07, doublethreshold); EXPECT_NEAR(p_esolver->en_shift(1, 0), -3.303688865319793e-07, doublethreshold); EXPECT_NEAR(p_esolver->en_shift(1, 1), -5.6443326024140752e-06, doublethreshold); -} \ No newline at end of file +} diff --git a/source/source_md/test/md_func_test.cpp b/source/source_md/test/md_func_test.cpp index eb9ce57a5f9..a9bae026fa8 100644 --- a/source/source_md/test/md_func_test.cpp +++ b/source/source_md/test/md_func_test.cpp @@ -5,9 +5,8 @@ #undef private #define private public #define protected public -#include "source_esolver/esolver_lj.h" +#include "md_test_fixture.h" #include "source_md/md_func.h" -#include "setcell.h" #define doublethreshold 1e-12 /************************************************ @@ -50,45 +49,8 @@ * - test the current_md_info function with an incorrect file path */ -class MD_func_test : public testing::Test +class MD_func_test : public MdFuncTestFixture { - protected: - UnitCell ucell; - double* allmass; // atom mass - ModuleBase::Vector3* pos; // atom position - ModuleBase::Vector3* vel; // atom velocity - ModuleBase::Vector3* ionmbl; // atom is frozen or not - ModuleBase::Vector3* force; // atom force - ModuleBase::matrix virial; // virial for this lattice - ModuleBase::matrix stress; // stress for this lattice - double potential; // potential energy - int natom; // atom number - double temperature; // temperature - int frozen_freedom; // frozen_freedom - Parameter param_in; - - void SetUp() - { - Setcell::setupcell(ucell); - Setcell::parameters(param_in.input); - natom = ucell.nat; - allmass = new double[natom]; - pos = new ModuleBase::Vector3[natom]; - ionmbl = new ModuleBase::Vector3[natom]; - vel = new ModuleBase::Vector3[natom]; - force = new ModuleBase::Vector3[natom]; - stress.create(3, 3); - virial.create(3, 3); - } - - void TearDown() - { - delete[] allmass; - delete[] pos; - delete[] vel; - delete[] ionmbl; - delete[] force; - } }; TEST_F(MD_func_test, gaussrand) diff --git a/source/source_md/test/md_test_fixture.h b/source/source_md/test/md_test_fixture.h new file mode 100644 index 00000000000..fccc19e96f0 --- /dev/null +++ b/source/source_md/test/md_test_fixture.h @@ -0,0 +1,111 @@ +#ifndef MD_TEST_FIXTURE_H +#define MD_TEST_FIXTURE_H + +#include "gtest/gtest.h" +#include "source_esolver/esolver_lj.h" +#include "source_io/module_parameter/parameter.h" +#include "source_md/md_base.h" +#include "setcell.h" + +#include +#include + +class MdTestBase : public testing::Test +{ + protected: + UnitCell ucell; + Parameter param_in; + std::unique_ptr p_esolver; + + void SetUp() override + { + Setcell::setupcell(ucell); + Setcell::parameters(param_in.input); + + p_esolver.reset(new ModuleESolver::ESolver_LJ()); + p_esolver->before_all_runners(ucell, param_in.inp); + } +}; + +template +class MdIntegratorFixture : public MdTestBase +{ + protected: + std::unique_ptr mdrun; + + void SetUp() override + { + MdTestBase::SetUp(); + mdrun.reset(new Integrator(param_in, ucell)); + mdrun->setup(p_esolver.get(), PARAM.sys.global_readin_dir); + } +}; + +class MdFuncTestFixture : public testing::Test +{ + protected: + UnitCell ucell; + std::vector allmass_store; + std::vector> pos_store; + std::vector> vel_store; + std::vector> ionmbl_store; + std::vector> force_store; + double* allmass = nullptr; + ModuleBase::Vector3* pos = nullptr; + ModuleBase::Vector3* vel = nullptr; + ModuleBase::Vector3* ionmbl = nullptr; + ModuleBase::Vector3* force = nullptr; + ModuleBase::matrix virial; + ModuleBase::matrix stress; + double potential = 0.0; + int natom = 0; + double temperature = 0.0; + int frozen_freedom = 0; + Parameter param_in; + + void SetUp() override + { + Setcell::setupcell(ucell); + Setcell::parameters(param_in.input); + natom = ucell.nat; + + allmass_store.resize(natom); + pos_store.resize(natom); + vel_store.resize(natom); + ionmbl_store.resize(natom); + force_store.resize(natom); + allmass = allmass_store.data(); + pos = pos_store.data(); + vel = vel_store.data(); + ionmbl = ionmbl_store.data(); + force = force_store.data(); + stress.create(3, 3); + virial.create(3, 3); + } +}; + +class LjPotTestFixture : public testing::Test +{ + protected: + std::vector> force_store; + ModuleBase::Vector3* force = nullptr; + ModuleBase::matrix stress; + double potential = 0.0; + int natom = 0; + UnitCell ucell; + Input_para input; + + void SetUp() override + { + Setcell::setupcell(ucell); + + natom = ucell.nat; + force_store.resize(natom); + force = force_store.data(); + stress.create(3, 3); + + Setcell::parameters(input); + } +}; + +#endif // MD_TEST_FIXTURE_H diff --git a/source/source_md/test/msst_test.cpp b/source/source_md/test/msst_test.cpp index 7d0fd8054d1..8b3982be73b 100644 --- a/source/source_md/test/msst_test.cpp +++ b/source/source_md/test/msst_test.cpp @@ -5,9 +5,8 @@ #undef private #define private public #define protected public -#include "source_esolver/esolver_lj.h" #include "source_md/msst.h" -#include "setcell.h" +#include "md_test_fixture.h" #define doublethreshold 1e-12 /************************************************ @@ -35,31 +34,8 @@ * - output MD information such as energy, temperature, and pressure */ -class MSST_test : public testing::Test +class MSST_test : public MdIntegratorFixture { - protected: - MD_base* mdrun; - UnitCell ucell; - Parameter param_in; - ModuleESolver::ESolver* p_esolver; - - void SetUp() - { - Setcell::setupcell(ucell); - Setcell::parameters(param_in.input); - - p_esolver = new ModuleESolver::ESolver_LJ(); - p_esolver->before_all_runners(ucell, param_in.inp); - - mdrun = new MSST(param_in, ucell); - mdrun->setup(p_esolver, PARAM.sys.global_readin_dir); - } - - void TearDown() - { - delete mdrun; - delete p_esolver; - } }; TEST_F(MSST_test, setup) @@ -208,7 +184,7 @@ TEST_F(MSST_test, restart) mdrun->restart(PARAM.sys.global_readin_dir); remove("Restart_md.txt"); - MSST* msst = dynamic_cast(mdrun); + MSST* msst = dynamic_cast(mdrun.get()); EXPECT_EQ(mdrun->step_rst_, 3); EXPECT_EQ(msst->omega[mdrun->mdp.msst_direction], -0.00977662); EXPECT_EQ(msst->e0, -0.00768262); diff --git a/source/source_md/test/nhchain_test.cpp b/source/source_md/test/nhchain_test.cpp index 647df0a730d..4f7a4244011 100644 --- a/source/source_md/test/nhchain_test.cpp +++ b/source/source_md/test/nhchain_test.cpp @@ -5,9 +5,8 @@ #undef private #define private public #define protected public -#include "source_esolver/esolver_lj.h" #include "source_md/nhchain.h" -#include "setcell.h" +#include "md_test_fixture.h" #define doublethreshold 1e-12 /************************************************ * unit test of functions in nhchain.h @@ -33,34 +32,20 @@ * - Nose_Hoover::print_md * - output MD information such as energy, temperature, and pressure */ -class NHC_test : public testing::Test +class NHC_test : public MdTestBase { protected: - MD_base* mdrun; - UnitCell ucell; - Parameter param_in; - ModuleESolver::ESolver* p_esolver; + std::unique_ptr mdrun; - void SetUp() + void SetUp() override { - Setcell::setupcell(ucell); - Setcell::parameters(param_in.input); - - p_esolver = new ModuleESolver::ESolver_LJ(); - p_esolver->before_all_runners(ucell, param_in.inp); - + MdTestBase::SetUp(); param_in.input.mdp.md_type = "npt"; param_in.input.mdp.md_pmode = "tri"; param_in.input.mdp.md_pfirst = 1; param_in.input.mdp.md_plast = 1; - mdrun = new Nose_Hoover(param_in, ucell); - mdrun->setup(p_esolver, PARAM.sys.global_readin_dir); - } - - void TearDown() - { - delete mdrun; - delete p_esolver; + mdrun.reset(new Nose_Hoover(param_in, ucell)); + mdrun->setup(p_esolver.get(), PARAM.sys.global_readin_dir); } }; @@ -179,7 +164,7 @@ TEST_F(NHC_test, restart) mdrun->restart(PARAM.sys.global_readin_dir); remove("Restart_md.txt"); - Nose_Hoover* nhc = dynamic_cast(mdrun); + Nose_Hoover* nhc = dynamic_cast(mdrun.get()); EXPECT_EQ(mdrun->step_rst_, 3); EXPECT_EQ(mdrun->mdp.md_tchain, 4); EXPECT_EQ(mdrun->mdp.md_pchain, 4); diff --git a/source/source_md/test/verlet_test.cpp b/source/source_md/test/verlet_test.cpp index 8f2c00f74f3..b16cd522756 100644 --- a/source/source_md/test/verlet_test.cpp +++ b/source/source_md/test/verlet_test.cpp @@ -5,9 +5,8 @@ #undef private #define private public #define protected public -#include "source_esolver/esolver_lj.h" #include "source_md/verlet.h" -#include "setcell.h" +#include "md_test_fixture.h" #define doublethreshold 1e-12 @@ -36,30 +35,8 @@ * - output MD information such as energy, temperature, and pressure */ -class Verlet_test : public testing::Test +class Verlet_test : public MdIntegratorFixture { - protected: - MD_base* mdrun; - UnitCell ucell; - Parameter param_in; - ModuleESolver::ESolver* p_esolver; - - void SetUp() - { - Setcell::setupcell(ucell); - Setcell::parameters(param_in.input); - - p_esolver = new ModuleESolver::ESolver_LJ(); - p_esolver->before_all_runners(ucell, param_in.inp); - - mdrun = new Verlet(param_in, ucell); - mdrun->setup(p_esolver, PARAM.sys.global_readin_dir); - } - - void TearDown() - { - delete mdrun; - } }; TEST_F(Verlet_test, setup) diff --git a/source/source_pw/module_pwdft/hamilt_pw.cpp b/source/source_pw/module_pwdft/hamilt_pw.cpp index 47fae9bcb96..d56061cbb25 100644 --- a/source/source_pw/module_pwdft/hamilt_pw.cpp +++ b/source/source_pw/module_pwdft/hamilt_pw.cpp @@ -1,16 +1,17 @@ #include "hamilt_pw.h" +#include "source_io/module_parameter/parameter.h" +#include "source_base/global_function.h" +#include "source_base/global_variable.h" +#include "source_base/parallel_reduce.h" + +#include "op_pw_veff.h" #include "op_pw_ekin.h" -#include "op_pw_exx.h" #include "op_pw_meta.h" #include "op_pw_nl.h" #include "op_pw_proj.h" -#include "op_pw_veff.h" -#include "source_base/global_function.h" -#include "source_base/global_variable.h" -#include "source_base/parallel_reduce.h" +#include "op_pw_exx.h" #include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info -#include "source_io/module_parameter/parameter.h" namespace hamilt { @@ -21,8 +22,7 @@ HamiltPW::HamiltPW(elecstate::Potential* pot_in, K_Vectors* pkv, pseudopot_cell_vnl* nlpp, Plus_U* p_dftu, // mohan add 2025-11-06 - const UnitCell* ucell) - : ucell(ucell) + const UnitCell* ucell): ucell(ucell) { this->classname = "HamiltPW"; this->ppcell = nlpp; @@ -39,7 +39,7 @@ HamiltPW::HamiltPW(elecstate::Potential* pot_in, // Operator* ekinetic = new Ekinetic> Operator* ekinetic = new Ekinetic>(tpiba2, gk2, wfc_basis->nks, wfc_basis->npwk_max); - if (this->ops == nullptr) + if(this->ops == nullptr) { this->ops = ekinetic; } @@ -59,7 +59,7 @@ HamiltPW::HamiltPW(elecstate::Potential* pot_in, { pot_register_in.push_back("hartree"); } - // no variable can choose xc, maybe it is necessary + //no variable can choose xc, maybe it is necessary pot_register_in.push_back("xc"); if (PARAM.inp.imp_sol) { @@ -78,21 +78,20 @@ HamiltPW::HamiltPW(elecstate::Potential* pot_in, pot_register_in.push_back("ml_exx"); } // DFT-1/2 - if (PARAM.inp.dfthalf_type == 1) - { + if (PARAM.inp.dfthalf_type == 1) { pot_register_in.push_back("dfthalf"); } - // only Potential is not empty, Veff and Meta are available - if (pot_register_in.size() > 0) + //only Potential is not empty, Veff and Meta are available + if(pot_register_in.size()>0) { - // register Potential by gathered operator + //register Potential by gathered operator pot_in->pot_register(pot_register_in); Operator* veff = new Veff>(isk, pot_in->get_veff_smooth_data(), pot_in->get_veff_smooth().nr, pot_in->get_veff_smooth().nc, wfc_basis); - if (this->ops == nullptr) + if(this->ops == nullptr) { this->ops = veff; } @@ -111,8 +110,9 @@ HamiltPW::HamiltPW(elecstate::Potential* pot_in, } if (PARAM.inp.vnl_in_h) { - Operator* nonlocal = new Nonlocal>(isk, this->ppcell, ucell, wfc_basis); - if (this->ops == nullptr) + Operator* nonlocal + = new Nonlocal>(isk, this->ppcell, ucell, wfc_basis); + if(this->ops == nullptr) { this->ops = nonlocal; } @@ -121,13 +121,11 @@ HamiltPW::HamiltPW(elecstate::Potential* pot_in, this->ops->add(nonlocal); } } - if (PARAM.inp.sc_mag_switch || PARAM.inp.dft_plus_u) + if(PARAM.inp.sc_mag_switch || PARAM.inp.dft_plus_u) { - Operator* onsite_proj = new OnsiteProj>(isk, - ucell, - p_dftu, - PARAM.inp.sc_mag_switch, - (PARAM.inp.dft_plus_u > 0)); + Operator* onsite_proj + = new OnsiteProj>(isk, ucell, p_dftu, + PARAM.inp.sc_mag_switch, (PARAM.inp.dft_plus_u>0)); this->ops->add(onsite_proj); } if (GlobalC::exx_info.info_global.cal_exx) @@ -146,21 +144,97 @@ HamiltPW::HamiltPW(elecstate::Potential* pot_in, return; } -template +template HamiltPW::~HamiltPW() { - if (this->ops != nullptr) + if(this->ops!= nullptr) { delete this->ops; } } -template +template void HamiltPW::updateHk(const int ik) { - ModuleBase::TITLE("HamiltPW", "updateHk"); + ModuleBase::TITLE("HamiltPW","updateHk"); this->ops->init(ik); - ModuleBase::TITLE("HamiltPW", "updateHk"); + ModuleBase::TITLE("HamiltPW","updateHk"); +} + +template +template +HamiltPW::HamiltPW(const HamiltPW *hamilt) +{ + this->classname = hamilt->classname; + this->ppcell = hamilt->ppcell; + this->qq_nt = hamilt->qq_nt; + this->qq_so = hamilt->qq_so; + this->vkb = hamilt->vkb; + OperatorPW, Device_in> * node = + reinterpret_cast, Device_in> *>(hamilt->ops); + + while(node != nullptr) { + if (node->classname == "Ekinetic") { + Operator* ekinetic = + new Ekinetic>( + reinterpret_cast>*>(node)); + if(this->ops == nullptr) { + this->ops = ekinetic; + } + else { + this->ops->add(ekinetic); + } + // this->ops = reinterpret_cast*>(node); + } + else if (node->classname == "Nonlocal") { + Operator* nonlocal = + new Nonlocal>( + reinterpret_cast>*>(node)); + if(this->ops == nullptr) { + this->ops = nonlocal; + } + else { + this->ops->add(nonlocal); + } + } + else if (node->classname == "Veff") { + Operator* veff = + new Veff>( + reinterpret_cast>*>(node)); + if(this->ops == nullptr) { + this->ops = veff; + } + else { + this->ops->add(veff); + } + } + else if (node->classname == "Meta") { + Operator* meta = + new Meta>( + reinterpret_cast>*>(node)); + if(this->ops == nullptr) { + this->ops = meta; + } + else { + this->ops->add(meta); + } + } + else if (node->classname == "OnsiteProj") { + Operator* onsite_proj = + new OnsiteProj>( + reinterpret_cast>*>(node)); + if(this->ops == nullptr) { + this->ops = onsite_proj; + } + else { + this->ops->add(onsite_proj); + } + } + else { + ModuleBase::WARNING_QUIT("HamiltPW", "Unrecognized Operator type!"); + } + node = reinterpret_cast, Device_in> *>(node->next_op); + } } // This routine applies the S matrix to m wavefunctions psi and puts @@ -316,8 +390,8 @@ void HamiltPW::sPsi(const T* psi_in, // psi } } -template -void HamiltPW::set_exx_helper(Exx_Helper& exx_helper) +template +void HamiltPW::set_exx_helper(Exx_Helper &exx_helper) { auto op = this->ops; while (op != nullptr) @@ -326,6 +400,7 @@ void HamiltPW::set_exx_helper(Exx_Helper& exx_helper) { exx_helper.op_exx = reinterpret_cast*>(op); exx_helper.set_op(); + } op = op->next_op; } diff --git a/source/source_pw/module_pwdft/hamilt_pw.h b/source/source_pw/module_pwdft/hamilt_pw.h index 958f82c8e3b..0aa878f26c1 100644 --- a/source/source_pw/module_pwdft/hamilt_pw.h +++ b/source/source_pw/module_pwdft/hamilt_pw.h @@ -1,15 +1,15 @@ #ifndef HAMILTPW_H #define HAMILTPW_H -#include "source_base/kernels/math_kernel_op.h" #include "source_base/macros.h" #include "source_cell/klist.h" -#include "source_esolver/esolver_ks_pw.h" #include "source_estate/module_pot/potential_new.h" +#include "source_esolver/esolver_ks_pw.h" #include "source_hamilt/hamilt.h" -#include "source_lcao/module_dftu/dftu.h" // mohan add 2025-11-06 -#include "source_pw/module_pwdft/exx_helper.h" #include "source_pw/module_pwdft/vnl_pw.h" +#include "source_base/kernels/math_kernel_op.h" +#include "source_pw/module_pwdft/exx_helper.h" +#include "source_lcao/module_dftu/dftu.h" // mohan add 2025-11-06 namespace hamilt { @@ -24,12 +24,16 @@ class HamiltPW : public Hamilt using Real = typename GetTypeReal::type; public: - HamiltPW(elecstate::Potential* pot_in, - ModulePW::PW_Basis_K* wfc_basis, - K_Vectors* p_kv, - pseudopot_cell_vnl* nlpp, - Plus_U* p_dftu, // mohan add 2025-11-06 - const UnitCell* ucell); + + HamiltPW(elecstate::Potential* pot_in, + ModulePW::PW_Basis_K* wfc_basis, + K_Vectors* p_kv, + pseudopot_cell_vnl* nlpp, + Plus_U *p_dftu, // mohan add 2025-11-06 + const UnitCell* ucell); + + template + explicit HamiltPW(const HamiltPW* hamilt); ~HamiltPW(); @@ -45,7 +49,7 @@ class HamiltPW : public Hamilt void set_exx_helper(Exx_Helper& exx_helper_in); - protected: +protected: // used in sPhi, which are calculated in hPsi or sPhi const pseudopot_cell_vnl* ppcell = nullptr; const UnitCell* const ucell = nullptr;