RoPE kernel bug:
import mlx.core as mx
import mlx.nn as nn
batch_size = 32
seq_len = 8192
heads = 32
head_dim = 128
x = mx.random.normal(shape=(batch_size, seq_len, heads * head_dim), dtype=mx.bfloat16)
rope = nn.RoPE(head_dim)
x = x.reshape(batch_size, seq_len, heads, head_dim).transpose(0, 2, 1, 3)
rotated_x = rope(x)
mx.eval(rotated_x)
returns:
Traceback (most recent call last):
RuntimeError: cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params) failed: invalid argument