Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clu/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def save(self, state) -> str:
assert self.current_checkpoint == next_checkpoint, (
"Expected next_checkpoint to match .current_checkpoint: "
f"{next_checkpoint} != {self.current_checkpoint}")
return self.current_checkpoint
return self.current_checkpoint # pyrefly: ignore[bad-return]

@utils.logged_with("Checkpoint.restore_or_initialize()")
def restore_or_initialize(self, state: T) -> T:
Expand Down Expand Up @@ -531,7 +531,7 @@ def get_latest_checkpoint_to_restore_from(self) -> Optional[str]:
logging.info(
"Checked checkpoint base_directories: %s - common_numbers=%s "
"- exclusive_numbers=%s", base_directories, common_numbers,
all_numbers.difference(common_numbers))
all_numbers.difference(common_numbers)) # pyrefly: ignore[bad-argument-type]
if not common_numbers:
return None
highest_number = sorted(common_numbers)[-1]
Expand Down
8 changes: 4 additions & 4 deletions clu/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _make_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
features = dict(x=inputs, y=labels)
return tf.data.Dataset.from_tensor_slices(features).repeat().batch(2)
return tf.data.Dataset.from_tensor_slices(features).repeat().batch(2) # pyrefly: ignore[bad-argument-type]


@flax.struct.dataclass
Expand Down Expand Up @@ -222,7 +222,7 @@ def test_ignores_incomplete_checkpoint(self):
self.assertEqual(state.step, 1)
state = TrainState(step=2)
# Failed save : step=2 is stored, but TensorFlow checkpoint fails.
ckpt.tf_checkpoint_manager.save = None
ckpt.tf_checkpoint_manager.save = None # pyrefly: ignore[bad-assignment]
with self.assertRaisesRegex(TypeError,
r"'NoneType' object is not callable"):
ckpt.save(state)
Expand Down Expand Up @@ -283,7 +283,7 @@ def test_overwrite(self):
self.assertEqual(state.step, 1)
self.assertEqual(tf_step.numpy(), 1)
checkpoint_info = checkpoint.CheckpointInfo.from_path(
ckpt.current_checkpoint)
ckpt.current_checkpoint) # pyrefly: ignore[bad-argument-type]
# Stores steps 2, 3, 4, 5
for _ in range(4):
tf_step.assign_add(1)
Expand Down Expand Up @@ -360,7 +360,7 @@ def test_synchronize_multiple_hosts(self, process_index_mock):
def test_preemption(self):
multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
state = TrainState(step=1)
state0 = state.replace(step=0)
state0 = state.replace(step=0) # pyrefly: ignore[missing-attribute]
ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
# Initialize both at step=1.
Expand Down
8 changes: 4 additions & 4 deletions clu/data/dataset_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ def __init__(self, dataset, *, checkpoint: bool):
"depend on //third_party/py/tensorflow.") from e
self._tf = tf

if not isinstance(dataset, tf.data.Dataset):
if not isinstance(dataset, tf.data.Dataset): # pyrefly: ignore[missing-attribute]
raise ValueError("`dataset` must be an instance of `tf.data.Dataset` "
f"but got {type(dataset)}.")
self._dataset = dataset
self._checkpoint = checkpoint
assert self.element_spec # Verify element spec.
self.iterator = iter(dataset)
self._ckpt = tf.train.Checkpoint(ds=self.iterator)
self._ckpt = tf.train.Checkpoint(ds=self.iterator) # pyrefly: ignore[missing-attribute]

def get_next(self) -> Element:
return next(self)
Expand All @@ -179,7 +179,7 @@ def __next__(self) -> Element:

def reset(self):
self.iterator = iter(self._dataset)
self._ckpt = self._tf.train.Checkpoint(ds=self.iterator)
self._ckpt = self._tf.train.Checkpoint(ds=self.iterator) # pyrefly: ignore[missing-attribute]

@property
def element_spec(self) -> ElementSpec:
Expand All @@ -189,7 +189,7 @@ def element_spec(self) -> ElementSpec:
f"{element_spec}.")
invalid_features = [
k for k, v in element_spec.items()
if not isinstance(v, self._tf.TensorSpec)
if not isinstance(v, self._tf.TensorSpec) # pyrefly: ignore[missing-attribute]
]
if invalid_features:
raise ValueError(f"Features {invalid_features} are not tensors. Dataset "
Expand Down
4 changes: 2 additions & 2 deletions clu/deterministic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def create_dataset(dataset_builder: DatasetBuilder,
if isinstance(rng, tf.Tensor):
rngs = [x.numpy() for x in tf.random.experimental.stateless_split(rng, 3)]
else:
rngs = list(jax.random.key_data(jax.random.split(rng, 3)))
rngs = list(jax.random.key_data(jax.random.split(rng, 3))) # pyrefly: ignore[bad-argument-type]
else:
rngs = 3 * [[None, None]]

Expand Down Expand Up @@ -458,7 +458,7 @@ def create_dataset(dataset_builder: DatasetBuilder,

if preprocess_fn is not None:
if rng_available:
ds = _preprocess_with_per_example_rng(ds, preprocess_fn, rng=rngs.pop())
ds = _preprocess_with_per_example_rng(ds, preprocess_fn, rng=rngs.pop()) # pyrefly: ignore[bad-argument-type]
else:
ds = ds.map(preprocess_fn, num_parallel_calls=AUTOTUNE)

Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/logging_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def write_histograms(self,
if histo is not None:
logging.info("[%d]%s Histogram for %r = {%s}", step,
self._collection_str, key,
_get_histogram_as_string(histo, bins))
_get_histogram_as_string(histo, bins)) # pyrefly: ignore[bad-argument-type]

def write_pointcloud(
self,
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def create_default_writer(
logdir = epath.Path(logdir)
if collection is not None:
logdir /= collection
writers.append(SummaryWriter(os.fspath(logdir)))
writers.append(SummaryWriter(os.fspath(logdir))) # pyrefly: ignore[bad-argument-type]
if asynchronous:
return AsyncMultiWriter(writers)
return MultiWriter(writers)
18 changes: 9 additions & 9 deletions clu/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ class MultiHeadMetrics(metrics.Collection):
"""

@flax.struct.dataclass
class FromFun(cls):
class FromFun(cls): # pyrefly: ignore[invalid-inheritance]
"""Wrapper Metric class that collects output after applying `fun`."""

@classmethod
Expand Down Expand Up @@ -336,7 +336,7 @@ class Metrics(Collection):
"""

@flax.struct.dataclass
class FromOutput(cls):
class FromOutput(cls): # pyrefly: ignore[invalid-inheritance]
"""Wrapper Metric class that collects output named `name`."""

@classmethod
Expand Down Expand Up @@ -445,7 +445,7 @@ def from_outputs(cls, names: Sequence[str]) -> type[CollectingMetric]:
"""Returns a metric class that collects all model outputs named `names`."""

@flax.struct.dataclass
class FromOutputs(cls): # pylint:disable=missing-class-docstring
class FromOutputs(cls): # pylint:disable=missing-class-docstring # pyrefly: ignore[invalid-inheritance]

@classmethod
def from_model_output(cls: type[M], **model_output) -> M:
Expand All @@ -456,7 +456,7 @@ def make_array(value):
# Can't jnp.concatenate() scalars, promote to shape=(1,) in that case.
return value[None] if value.ndim == 0 else value

return cls({name: (make_array(model_output[name]),) for name in names})
return cls({name: (make_array(model_output[name]),) for name in names}) # pyrefly: ignore[bad-argument-count]

return FromOutputs

Expand Down Expand Up @@ -531,7 +531,7 @@ class MyMetrics(metrics.Collection):
Returns:
A subclass of Collection with fields defined by provided `metrics`.
"""
return flax.struct.dataclass(
return flax.struct.dataclass( # pyrefly: ignore[bad-return]
type("_InlineCollection", (Collection,), {"__annotations__": metrics}))

@classmethod
Expand Down Expand Up @@ -706,9 +706,9 @@ class LastValue(Metric):

def __init__( # pytype: disable=missing-parameter # jnp-array
self,
total: jnp.ndarray | _default = _default,
count: jnp.ndarray | _default = _default,
value: jnp.ndarray | _default = _default,
total: jnp.ndarray | _default = _default, # pyrefly: ignore[not-a-type]
count: jnp.ndarray | _default = _default, # pyrefly: ignore[not-a-type]
value: jnp.ndarray | _default = _default, # pyrefly: ignore[not-a-type]
):
"""Backward compatibility constructor.

Expand Down Expand Up @@ -913,7 +913,7 @@ class Accuracy(Average):
"""

@classmethod
def from_model_output(
def from_model_output( # pyrefly: ignore[bad-override]
cls, *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs
) -> Accuracy:
if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32:
Expand Down
16 changes: 8 additions & 8 deletions clu/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

@flax.struct.dataclass
class CollectingMetricAccuracy(
metrics.CollectingMetric.from_outputs(("logits", "labels"))):
metrics.CollectingMetric.from_outputs(("logits", "labels"))): # pyrefly: ignore[invalid-inheritance]

def compute(self):
values = super().compute()
Expand All @@ -44,7 +44,7 @@ def compute(self):
@flax.struct.dataclass
class Collection(metrics.Collection):
train_accuracy: metrics.Accuracy
learning_rate: metrics.LastValue.from_output("learning_rate")
learning_rate: metrics.LastValue.from_output("learning_rate") # pyrefly: ignore[invalid-annotation]


@flax.struct.dataclass
Expand Down Expand Up @@ -217,12 +217,12 @@ def accuracy(*, logits, labels, **_):

chex.assert_trees_all_close(
self.make_compute_metric(
metrics.Average.from_fun(accuracy),
metrics.Average.from_fun(accuracy), # pyrefly: ignore[bad-argument-type]
reduce=False)(self.model_outputs), self.results["train_accuracy"])

chex.assert_trees_all_close(
self.make_compute_metric(
metrics.Average.from_fun(accuracy),
metrics.Average.from_fun(accuracy), # pyrefly: ignore[bad-argument-type]
reduce=False)(self.model_outputs_masked),
self.results_masked["train_accuracy"])

Expand All @@ -237,12 +237,12 @@ def make_accuracy_args_map(*, logits, labels, **_):

chex.assert_trees_all_close(
self.make_compute_metric(
metrics.Accuracy.from_fun(make_accuracy_args_map),
metrics.Accuracy.from_fun(make_accuracy_args_map), # pyrefly: ignore[bad-argument-type]
reduce=False)(self.model_outputs), self.results["train_accuracy"])

chex.assert_trees_all_close(
self.make_compute_metric(
metrics.Accuracy.from_fun(make_accuracy_args_map),
metrics.Accuracy.from_fun(make_accuracy_args_map), # pyrefly: ignore[bad-argument-type]
reduce=False)(self.model_outputs_masked),
self.results_masked["train_accuracy"])

Expand Down Expand Up @@ -356,8 +356,8 @@ def with_head2(logits, labels, mask, head2_mask, **_):
return dict(logits=logits, labels=labels, mask=head2_mask & mask)

collection = metrics.Collection.create(
head1_accuracy=metrics.Accuracy.from_fun(with_head1),
head2_accuracy=metrics.Accuracy.from_fun(with_head2)
head1_accuracy=metrics.Accuracy.from_fun(with_head1), # pyrefly: ignore[bad-argument-type]
head2_accuracy=metrics.Accuracy.from_fun(with_head2) # pyrefly: ignore[bad-argument-type]
)

chex.assert_trees_all_close(
Expand Down
8 changes: 4 additions & 4 deletions clu/parameter_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def flatten_dict(
nested_key = f"{prefix}{delimiter}{key}" if prefix else key
if isinstance(value, (dict, flax.core.FrozenDict)):
output_dict.update(
flatten_dict(value, prefix=nested_key, delimiter=delimiter)
flatten_dict(value, prefix=nested_key, delimiter=delimiter) # pyrefly: ignore[bad-argument-type]
)
else:
output_dict[nested_key] = value
Expand All @@ -81,13 +81,13 @@ def flatten_dict(

def _count_parameters(params: _ParamsContainer) -> int:
"""Returns the count of variables for the module or parameter dictionary."""
params = flatten_dict(params)
params = flatten_dict(params) # pyrefly: ignore[bad-argument-type]
return sum(np.prod(v.shape) for v in params.values() if v is not None)


def _parameters_size(params: _ParamsContainer) -> int:
"""Returns total size (bytes) for the module or parameter dictionary."""
params = flatten_dict(params)
params = flatten_dict(params) # pyrefly: ignore[bad-argument-type]
return sum(
np.prod(v.shape) * v.dtype.itemsize
for v in params.values()
Expand Down Expand Up @@ -177,7 +177,7 @@ def _get_parameter_rows(
f"Expected `params` to be a dictionary but got {type(params)}"
)

params = flatten_dict(params)
params = flatten_dict(params) # pyrefly: ignore[bad-argument-type]
if params:
names, values = map(list, tuple(zip(*sorted(params.items()))))
else:
Expand Down
14 changes: 7 additions & 7 deletions clu/periodic_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ def __init__(self,
self._every_secs = every_secs
self._on_steps = set(on_steps or [])
# Step and timestamp for the last time the action triggered.
self._previous_step: int = None
self._previous_time: float = None
self._previous_step: int = None # pyrefly: ignore[bad-assignment]
self._previous_time: float = None # pyrefly: ignore[bad-assignment]
# Just for checking that __call__() was called every step.
self._last_step: int = None
self._last_step: int = None # pyrefly: ignore[bad-assignment]

def _init_and_check(self, step: int, t: float):
"""Initializes and checks it was called at every step."""
if self._previous_step is None:
self._previous_step = step
self._previous_time = t
self._last_step = step
self._previous_step = step # pyrefly: ignore[bad-assignment]
self._previous_time = t # pyrefly: ignore[bad-assignment]
self._last_step = step # pyrefly: ignore[bad-assignment]
elif self._every_steps is not None and step - self._last_step != 1:
raise ValueError(f"PeriodicAction must be called after every step once "
f"(every_steps={self._every_steps}, "
Expand Down Expand Up @@ -350,7 +350,7 @@ def __init__(
def _should_trigger(self, step: int, t: float) -> bool:
if self._session_running:
# If a session is running we only check if we should stop it.
dt = t - self._session_started
dt = t - self._session_started # pyrefly: ignore[unsupported-operation]
cond = (not self._profile_duration_ms or
dt * 1e3 >= self._profile_duration_ms)
cond &= (not self._num_profile_steps or
Expand Down
14 changes: 7 additions & 7 deletions clu/preprocess_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __call__(self, features: D) -> D:
"switch to grain.tensorflow.MapTransform.")
if isinstance(features, tf.data.Dataset):
return features.map(self._transform, num_parallel_calls=tf.data.AUTOTUNE)
return self._transform(features)
return self._transform(features) # pyrefly: ignore[bad-return]

@abc.abstractmethod
def _transform(self, features: FlatFeatures) -> FlatFeatures:
Expand Down Expand Up @@ -154,7 +154,7 @@ def __call__(self, features: D) -> D:

next_seed, seed = tf.unstack(
tf.random.experimental.stateless_split(features.pop(SEED_KEY)))
features = self._transform(features, seed)
features = self._transform(features, seed) # pyrefly: ignore[bad-assignment]
features[SEED_KEY] = next_seed
return features

Expand Down Expand Up @@ -341,10 +341,10 @@ def _parse_single_preprocess_op(
args = [ast.literal_eval(arg) for arg in expr.args]
kwargs = {kv.arg: ast.literal_eval(kv.value) for kv in expr.keywords}
if not args:
return op_class(**kwargs)
return op_class(**kwargs) # pyrefly: ignore[bad-unpacking]

# Translate positional arguments into keyword arguments.
available_arg_names = [f.name for f in dataclasses.fields(op_class)]
available_arg_names = [f.name for f in dataclasses.fields(op_class)] # pyrefly: ignore[bad-argument-type]
for i, arg in enumerate(args):
name = available_arg_names[i]
if name in kwargs:
Expand All @@ -353,20 +353,20 @@ def _parse_single_preprocess_op(
f"(value: {arg}) and keyword argument (value: {kwargs[name]}).")
kwargs[name] = arg

return op_class(**kwargs)
return op_class(**kwargs) # pyrefly: ignore[bad-unpacking]


def parse(spec: str,
available_ops: List[Tuple[str, Type[PreprocessOp]]],
*,
only_jax_types: bool = True) -> PreprocessFn:
"""Parses a preprocess spec; a '|' separated list of preprocess ops."""
available_ops = dict(available_ops)
available_ops = dict(available_ops) # pyrefly: ignore[bad-assignment]
if not spec.strip():
ops = []
else:
ops = [
_parse_single_preprocess_op(s, available_ops) for s in spec.split("|")
_parse_single_preprocess_op(s, available_ops) for s in spec.split("|") # pyrefly: ignore[bad-argument-type]
]
return PreprocessFn(ops, only_jax_types=only_jax_types)

Expand Down
Loading