diff --git a/Modules/Cluster.py b/Modules/Cluster.py index 0664914b..fe2a7f99 100644 --- a/Modules/Cluster.py +++ b/Modules/Cluster.py @@ -333,6 +333,30 @@ def __setattr__(self, name, value): super(Cluster, self).__setattr__(name, value) + def __getstate__(self): + """ + Return the picklable state of the cluster. + + The thread lock created by compute_ensemble_batch cannot be pickled, + so it is dropped here. This allows sscha.Utilities.save_binary to + store objects holding a cluster after a calculation has run. + """ + state = self.__dict__.copy() + state["lock"] = None + return state + + + def __setstate__(self, state): + """ + Restore the cluster from a pickled state. + + The thread lock is transient runtime state and is reset to None, + as after __init__; compute_ensemble_batch recreates it when needed. + """ + state["lock"] = None + self.__dict__.update(state) + + def copy_file(self, source, destination, server_source = False, server_dest = True, raise_error=False, **kwargs): """ diff --git a/tests/test_save_binary/test_save_binary.py b/tests/test_save_binary/test_save_binary.py new file mode 100644 index 00000000..cf777bc7 --- /dev/null +++ b/tests/test_save_binary/test_save_binary.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function +from __future__ import division + +import os +import tempfile +import threading + +import cellconstructor as CC +import cellconstructor.Phonons + +import sscha +import sscha.Cluster +import sscha.Ensemble +import sscha.Relax +import sscha.SchaMinimizer +import sscha.Utilities + +""" +Regression test for issue #114: save_binary failed with +TypeError: cannot pickle '_thread.lock' object +when the relax object holds a cluster that already ran a calculation +(Cluster.compute_ensemble_batch stores a threading.Lock on the cluster). +""" + + +def test_save_binary_relax_with_cluster(verbose=False): + total_path = os.path.dirname(os.path.abspath(__file__)) + os.chdir(total_path) + + DATA_PATH = "../../Examples/ensemble_data_test/" + + dyn = CC.Phonons.Phonons(os.path.join(DATA_PATH, "dyn")) + + ens = sscha.Ensemble.Ensemble(dyn, 0, dyn.GetSupercell()) + ens.load(DATA_PATH, 2, 10) + + minim = sscha.SchaMinimizer.SSCHA_Minimizer(ens) + + cluster = sscha.Cluster.Cluster(hostname="localhost") + relax = sscha.Relax.SSCHA(minim, N_configs=10, max_pop=2, + cluster=cluster) + + # Cluster.compute_ensemble_batch leaves a threading.Lock on the + # cluster after the ensemble calculation; reproduce that state. + relax.cluster.lock = threading.Lock() + + with tempfile.TemporaryDirectory() as tmpdir: + filename = os.path.join(tmpdir, "relax.bin") + sscha.Utilities.save_binary(relax, filename) + + loaded = sscha.Utilities.load_binary(filename) + + # The lock is transient runtime state and must come back unset. + assert loaded.cluster.lock is None + assert loaded.cluster.hostname == "localhost" + assert loaded.N_configs == relax.N_configs + assert loaded.minim.ensemble.N == ens.N + + if verbose: + print("save_binary/load_binary round trip succeeded") + + +if __name__ == "__main__": + test_save_binary_relax_with_cluster(True)