diff --git a/clu/checkpoint.py b/clu/checkpoint.py index 0439df1..48ccf9c 100644 --- a/clu/checkpoint.py +++ b/clu/checkpoint.py @@ -324,7 +324,7 @@ def save(self, state) -> str: assert self.current_checkpoint == next_checkpoint, ( "Expected next_checkpoint to match .current_checkpoint: " f"{next_checkpoint} != {self.current_checkpoint}") - return self.current_checkpoint + return self.current_checkpoint # pyrefly: ignore[bad-return] @utils.logged_with("Checkpoint.restore_or_initialize()") def restore_or_initialize(self, state: T) -> T: @@ -531,7 +531,7 @@ def get_latest_checkpoint_to_restore_from(self) -> Optional[str]: logging.info( "Checked checkpoint base_directories: %s - common_numbers=%s " "- exclusive_numbers=%s", base_directories, common_numbers, - all_numbers.difference(common_numbers)) + all_numbers.difference(common_numbers)) # pyrefly: ignore[bad-argument-type] if not common_numbers: return None highest_number = sorted(common_numbers)[-1] diff --git a/clu/checkpoint_test.py b/clu/checkpoint_test.py index 47b120c..9d04c8f 100644 --- a/clu/checkpoint_test.py +++ b/clu/checkpoint_test.py @@ -27,7 +27,7 @@ def _make_dataset(): inputs = tf.range(10.)[:, None] labels = inputs * 5. + tf.range(5.)[None, :] features = dict(x=inputs, y=labels) - return tf.data.Dataset.from_tensor_slices(features).repeat().batch(2) + return tf.data.Dataset.from_tensor_slices(features).repeat().batch(2) # pyrefly: ignore[bad-argument-type] @flax.struct.dataclass @@ -222,7 +222,7 @@ def test_ignores_incomplete_checkpoint(self): self.assertEqual(state.step, 1) state = TrainState(step=2) # Failed save : step=2 is stored, but TensorFlow checkpoint fails. - ckpt.tf_checkpoint_manager.save = None + ckpt.tf_checkpoint_manager.save = None # pyrefly: ignore[bad-assignment] with self.assertRaisesRegex(TypeError, r"'NoneType' object is not callable"): ckpt.save(state) @@ -283,7 +283,7 @@ def test_overwrite(self): self.assertEqual(state.step, 1) self.assertEqual(tf_step.numpy(), 1) checkpoint_info = checkpoint.CheckpointInfo.from_path( - ckpt.current_checkpoint) + ckpt.current_checkpoint) # pyrefly: ignore[bad-argument-type] # Stores steps 2, 3, 4, 5 for _ in range(4): tf_step.assign_add(1) @@ -360,7 +360,7 @@ def test_synchronize_multiple_hosts(self, process_index_mock): def test_preemption(self): multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test") state = TrainState(step=1) - state0 = state.replace(step=0) + state0 = state.replace(step=0) # pyrefly: ignore[missing-attribute] ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0) ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1) # Initialize both at step=1. diff --git a/clu/data/dataset_iterator.py b/clu/data/dataset_iterator.py index 23f552e..660bd3d 100644 --- a/clu/data/dataset_iterator.py +++ b/clu/data/dataset_iterator.py @@ -162,14 +162,14 @@ def __init__(self, dataset, *, checkpoint: bool): "depend on //third_party/py/tensorflow.") from e self._tf = tf - if not isinstance(dataset, tf.data.Dataset): + if not isinstance(dataset, tf.data.Dataset): # pyrefly: ignore[missing-attribute] raise ValueError("`dataset` must be an instance of `tf.data.Dataset` " f"but got {type(dataset)}.") self._dataset = dataset self._checkpoint = checkpoint assert self.element_spec # Verify element spec. self.iterator = iter(dataset) - self._ckpt = tf.train.Checkpoint(ds=self.iterator) + self._ckpt = tf.train.Checkpoint(ds=self.iterator) # pyrefly: ignore[missing-attribute] def get_next(self) -> Element: return next(self) @@ -179,7 +179,7 @@ def __next__(self) -> Element: def reset(self): self.iterator = iter(self._dataset) - self._ckpt = self._tf.train.Checkpoint(ds=self.iterator) + self._ckpt = self._tf.train.Checkpoint(ds=self.iterator) # pyrefly: ignore[missing-attribute] @property def element_spec(self) -> ElementSpec: @@ -189,7 +189,7 @@ def element_spec(self) -> ElementSpec: f"{element_spec}.") invalid_features = [ k for k, v in element_spec.items() - if not isinstance(v, self._tf.TensorSpec) + if not isinstance(v, self._tf.TensorSpec) # pyrefly: ignore[missing-attribute] ] if invalid_features: raise ValueError(f"Features {invalid_features} are not tensors. Dataset " diff --git a/clu/deterministic_data.py b/clu/deterministic_data.py index a4da5cd..ecf9465 100644 --- a/clu/deterministic_data.py +++ b/clu/deterministic_data.py @@ -429,7 +429,7 @@ def create_dataset(dataset_builder: DatasetBuilder, if isinstance(rng, tf.Tensor): rngs = [x.numpy() for x in tf.random.experimental.stateless_split(rng, 3)] else: - rngs = list(jax.random.key_data(jax.random.split(rng, 3))) + rngs = list(jax.random.key_data(jax.random.split(rng, 3))) # pyrefly: ignore[bad-argument-type] else: rngs = 3 * [[None, None]] @@ -458,7 +458,7 @@ def create_dataset(dataset_builder: DatasetBuilder, if preprocess_fn is not None: if rng_available: - ds = _preprocess_with_per_example_rng(ds, preprocess_fn, rng=rngs.pop()) + ds = _preprocess_with_per_example_rng(ds, preprocess_fn, rng=rngs.pop()) # pyrefly: ignore[bad-argument-type] else: ds = ds.map(preprocess_fn, num_parallel_calls=AUTOTUNE) diff --git a/clu/metric_writers/logging_writer.py b/clu/metric_writers/logging_writer.py index 50fdcc3..c3ac2f1 100644 --- a/clu/metric_writers/logging_writer.py +++ b/clu/metric_writers/logging_writer.py @@ -75,7 +75,7 @@ def write_histograms(self, if histo is not None: logging.info("[%d]%s Histogram for %r = {%s}", step, self._collection_str, key, - _get_histogram_as_string(histo, bins)) + _get_histogram_as_string(histo, bins)) # pyrefly: ignore[bad-argument-type] def write_pointcloud( self, diff --git a/clu/metric_writers/utils.py b/clu/metric_writers/utils.py index d675808..760946b 100644 --- a/clu/metric_writers/utils.py +++ b/clu/metric_writers/utils.py @@ -148,7 +148,7 @@ def create_default_writer( logdir = epath.Path(logdir) if collection is not None: logdir /= collection - writers.append(SummaryWriter(os.fspath(logdir))) + writers.append(SummaryWriter(os.fspath(logdir))) # pyrefly: ignore[bad-argument-type] if asynchronous: return AsyncMultiWriter(writers) return MultiWriter(writers) diff --git a/clu/metrics.py b/clu/metrics.py index 1f2ad97..8057349 100644 --- a/clu/metrics.py +++ b/clu/metrics.py @@ -278,7 +278,7 @@ class MultiHeadMetrics(metrics.Collection): """ @flax.struct.dataclass - class FromFun(cls): + class FromFun(cls): # pyrefly: ignore[invalid-inheritance] """Wrapper Metric class that collects output after applying `fun`.""" @classmethod @@ -336,7 +336,7 @@ class Metrics(Collection): """ @flax.struct.dataclass - class FromOutput(cls): + class FromOutput(cls): # pyrefly: ignore[invalid-inheritance] """Wrapper Metric class that collects output named `name`.""" @classmethod @@ -445,7 +445,7 @@ def from_outputs(cls, names: Sequence[str]) -> type[CollectingMetric]: """Returns a metric class that collects all model outputs named `names`.""" @flax.struct.dataclass - class FromOutputs(cls): # pylint:disable=missing-class-docstring + class FromOutputs(cls): # pylint:disable=missing-class-docstring # pyrefly: ignore[invalid-inheritance] @classmethod def from_model_output(cls: type[M], **model_output) -> M: @@ -456,7 +456,7 @@ def make_array(value): # Can't jnp.concatenate() scalars, promote to shape=(1,) in that case. return value[None] if value.ndim == 0 else value - return cls({name: (make_array(model_output[name]),) for name in names}) + return cls({name: (make_array(model_output[name]),) for name in names}) # pyrefly: ignore[bad-argument-count] return FromOutputs @@ -531,7 +531,7 @@ class MyMetrics(metrics.Collection): Returns: A subclass of Collection with fields defined by provided `metrics`. """ - return flax.struct.dataclass( + return flax.struct.dataclass( # pyrefly: ignore[bad-return] type("_InlineCollection", (Collection,), {"__annotations__": metrics})) @classmethod @@ -706,9 +706,9 @@ class LastValue(Metric): def __init__( # pytype: disable=missing-parameter # jnp-array self, - total: jnp.ndarray | _default = _default, - count: jnp.ndarray | _default = _default, - value: jnp.ndarray | _default = _default, + total: jnp.ndarray | _default = _default, # pyrefly: ignore[not-a-type] + count: jnp.ndarray | _default = _default, # pyrefly: ignore[not-a-type] + value: jnp.ndarray | _default = _default, # pyrefly: ignore[not-a-type] ): """Backward compatibility constructor. @@ -913,7 +913,7 @@ class Accuracy(Average): """ @classmethod - def from_model_output( + def from_model_output( # pyrefly: ignore[bad-override] cls, *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs ) -> Accuracy: if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32: diff --git a/clu/metrics_test.py b/clu/metrics_test.py index 91aadfe..d138266 100644 --- a/clu/metrics_test.py +++ b/clu/metrics_test.py @@ -30,7 +30,7 @@ @flax.struct.dataclass class CollectingMetricAccuracy( - metrics.CollectingMetric.from_outputs(("logits", "labels"))): + metrics.CollectingMetric.from_outputs(("logits", "labels"))): # pyrefly: ignore[invalid-inheritance] def compute(self): values = super().compute() @@ -44,7 +44,7 @@ def compute(self): @flax.struct.dataclass class Collection(metrics.Collection): train_accuracy: metrics.Accuracy - learning_rate: metrics.LastValue.from_output("learning_rate") + learning_rate: metrics.LastValue.from_output("learning_rate") # pyrefly: ignore[invalid-annotation] @flax.struct.dataclass @@ -217,12 +217,12 @@ def accuracy(*, logits, labels, **_): chex.assert_trees_all_close( self.make_compute_metric( - metrics.Average.from_fun(accuracy), + metrics.Average.from_fun(accuracy), # pyrefly: ignore[bad-argument-type] reduce=False)(self.model_outputs), self.results["train_accuracy"]) chex.assert_trees_all_close( self.make_compute_metric( - metrics.Average.from_fun(accuracy), + metrics.Average.from_fun(accuracy), # pyrefly: ignore[bad-argument-type] reduce=False)(self.model_outputs_masked), self.results_masked["train_accuracy"]) @@ -237,12 +237,12 @@ def make_accuracy_args_map(*, logits, labels, **_): chex.assert_trees_all_close( self.make_compute_metric( - metrics.Accuracy.from_fun(make_accuracy_args_map), + metrics.Accuracy.from_fun(make_accuracy_args_map), # pyrefly: ignore[bad-argument-type] reduce=False)(self.model_outputs), self.results["train_accuracy"]) chex.assert_trees_all_close( self.make_compute_metric( - metrics.Accuracy.from_fun(make_accuracy_args_map), + metrics.Accuracy.from_fun(make_accuracy_args_map), # pyrefly: ignore[bad-argument-type] reduce=False)(self.model_outputs_masked), self.results_masked["train_accuracy"]) @@ -356,8 +356,8 @@ def with_head2(logits, labels, mask, head2_mask, **_): return dict(logits=logits, labels=labels, mask=head2_mask & mask) collection = metrics.Collection.create( - head1_accuracy=metrics.Accuracy.from_fun(with_head1), - head2_accuracy=metrics.Accuracy.from_fun(with_head2) + head1_accuracy=metrics.Accuracy.from_fun(with_head1), # pyrefly: ignore[bad-argument-type] + head2_accuracy=metrics.Accuracy.from_fun(with_head2) # pyrefly: ignore[bad-argument-type] ) chex.assert_trees_all_close( diff --git a/clu/parameter_overview.py b/clu/parameter_overview.py index 13add0b..a3de0ab 100644 --- a/clu/parameter_overview.py +++ b/clu/parameter_overview.py @@ -72,7 +72,7 @@ def flatten_dict( nested_key = f"{prefix}{delimiter}{key}" if prefix else key if isinstance(value, (dict, flax.core.FrozenDict)): output_dict.update( - flatten_dict(value, prefix=nested_key, delimiter=delimiter) + flatten_dict(value, prefix=nested_key, delimiter=delimiter) # pyrefly: ignore[bad-argument-type] ) else: output_dict[nested_key] = value @@ -81,13 +81,13 @@ def flatten_dict( def _count_parameters(params: _ParamsContainer) -> int: """Returns the count of variables for the module or parameter dictionary.""" - params = flatten_dict(params) + params = flatten_dict(params) # pyrefly: ignore[bad-argument-type] return sum(np.prod(v.shape) for v in params.values() if v is not None) def _parameters_size(params: _ParamsContainer) -> int: """Returns total size (bytes) for the module or parameter dictionary.""" - params = flatten_dict(params) + params = flatten_dict(params) # pyrefly: ignore[bad-argument-type] return sum( np.prod(v.shape) * v.dtype.itemsize for v in params.values() @@ -177,7 +177,7 @@ def _get_parameter_rows( f"Expected `params` to be a dictionary but got {type(params)}" ) - params = flatten_dict(params) + params = flatten_dict(params) # pyrefly: ignore[bad-argument-type] if params: names, values = map(list, tuple(zip(*sorted(params.items())))) else: diff --git a/clu/periodic_actions.py b/clu/periodic_actions.py index 7d1e4a8..21090aa 100644 --- a/clu/periodic_actions.py +++ b/clu/periodic_actions.py @@ -91,17 +91,17 @@ def __init__(self, self._every_secs = every_secs self._on_steps = set(on_steps or []) # Step and timestamp for the last time the action triggered. - self._previous_step: int = None - self._previous_time: float = None + self._previous_step: int = None # pyrefly: ignore[bad-assignment] + self._previous_time: float = None # pyrefly: ignore[bad-assignment] # Just for checking that __call__() was called every step. - self._last_step: int = None + self._last_step: int = None # pyrefly: ignore[bad-assignment] def _init_and_check(self, step: int, t: float): """Initializes and checks it was called at every step.""" if self._previous_step is None: - self._previous_step = step - self._previous_time = t - self._last_step = step + self._previous_step = step # pyrefly: ignore[bad-assignment] + self._previous_time = t # pyrefly: ignore[bad-assignment] + self._last_step = step # pyrefly: ignore[bad-assignment] elif self._every_steps is not None and step - self._last_step != 1: raise ValueError(f"PeriodicAction must be called after every step once " f"(every_steps={self._every_steps}, " @@ -350,7 +350,7 @@ def __init__( def _should_trigger(self, step: int, t: float) -> bool: if self._session_running: # If a session is running we only check if we should stop it. - dt = t - self._session_started + dt = t - self._session_started # pyrefly: ignore[unsupported-operation] cond = (not self._profile_duration_ms or dt * 1e3 >= self._profile_duration_ms) cond &= (not self._num_profile_steps or diff --git a/clu/preprocess_spec.py b/clu/preprocess_spec.py index 1649a45..55c040c 100644 --- a/clu/preprocess_spec.py +++ b/clu/preprocess_spec.py @@ -124,7 +124,7 @@ def __call__(self, features: D) -> D: "switch to grain.tensorflow.MapTransform.") if isinstance(features, tf.data.Dataset): return features.map(self._transform, num_parallel_calls=tf.data.AUTOTUNE) - return self._transform(features) + return self._transform(features) # pyrefly: ignore[bad-return] @abc.abstractmethod def _transform(self, features: FlatFeatures) -> FlatFeatures: @@ -154,7 +154,7 @@ def __call__(self, features: D) -> D: next_seed, seed = tf.unstack( tf.random.experimental.stateless_split(features.pop(SEED_KEY))) - features = self._transform(features, seed) + features = self._transform(features, seed) # pyrefly: ignore[bad-assignment] features[SEED_KEY] = next_seed return features @@ -341,10 +341,10 @@ def _parse_single_preprocess_op( args = [ast.literal_eval(arg) for arg in expr.args] kwargs = {kv.arg: ast.literal_eval(kv.value) for kv in expr.keywords} if not args: - return op_class(**kwargs) + return op_class(**kwargs) # pyrefly: ignore[bad-unpacking] # Translate positional arguments into keyword arguments. - available_arg_names = [f.name for f in dataclasses.fields(op_class)] + available_arg_names = [f.name for f in dataclasses.fields(op_class)] # pyrefly: ignore[bad-argument-type] for i, arg in enumerate(args): name = available_arg_names[i] if name in kwargs: @@ -353,7 +353,7 @@ def _parse_single_preprocess_op( f"(value: {arg}) and keyword argument (value: {kwargs[name]}).") kwargs[name] = arg - return op_class(**kwargs) + return op_class(**kwargs) # pyrefly: ignore[bad-unpacking] def parse(spec: str, @@ -361,12 +361,12 @@ def parse(spec: str, *, only_jax_types: bool = True) -> PreprocessFn: """Parses a preprocess spec; a '|' separated list of preprocess ops.""" - available_ops = dict(available_ops) + available_ops = dict(available_ops) # pyrefly: ignore[bad-assignment] if not spec.strip(): ops = [] else: ops = [ - _parse_single_preprocess_op(s, available_ops) for s in spec.split("|") + _parse_single_preprocess_op(s, available_ops) for s in spec.split("|") # pyrefly: ignore[bad-argument-type] ] return PreprocessFn(ops, only_jax_types=only_jax_types)