
[BUG] [CUDA] rope cudaGraphAddKernelNode #3659

@nastya236


The RoPE kernel fails on the CUDA backend. Minimal reproduction:

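```python
import mlx.core as mx
import mlx.nn as nn

batch_size = 32
seq_len = 8192
heads = 32
head_dim = 128

# bfloat16 input of shape (batch_size, seq_len, heads * head_dim)
x = mx.random.normal(shape=(batch_size, seq_len, heads * head_dim), dtype=mx.bfloat16)
rope = nn.RoPE(head_dim)

# Split heads and move them in front of the sequence axis:
# (batch_size, seq_len, heads, head_dim) -> (batch_size, heads, seq_len, head_dim)
x = x.reshape(batch_size, seq_len, heads, head_dim).transpose(0, 2, 1, 3)
rotated_x = rope(x)

# MLX is lazy: the computation (including the RoPE kernel) runs here
mx.eval(rotated_x)
```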

Running it fails with:

```
Traceback (most recent call last):

RuntimeError: cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params) failed: invalid argument
```
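To help narrow this down, the sketch below works out the shape, element strides, and size of the tensor that `rope` receives in the reproduction. It assumes NumPy-style row-major strides (counted in elements) and that `transpose` returns a permuted view rather than a copy; the helper `row_major_strides` is hypothetical and written only for this illustration. The result shows that the input is non-contiguous after the transpose and holds 2**30 bfloat16 elements, i.e. 2**31 bytes, one more than the largest signed 32-bit integer. Whether either property is what triggers the `invalid argument` error is an unconfirmed hypothesis.

```python
from math import prod

# Shapes taken from the reproduction.
batch_size, seq_len, heads, head_dim = 32, 8192, 32, 128


def row_major_strides(shape):
    """Hypothetical helper: element strides of a contiguous row-major array."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


# x.reshape(batch_size, seq_len, heads, head_dim) is contiguous
reshaped = (batch_size, seq_len, heads, head_dim)
reshaped_strides = row_major_strides(reshaped)

# .transpose(0, 2, 1, 3) permutes shape and strides together (assumed view, no copy)
perm = (0, 2, 1, 3)
transposed = tuple(reshaped[p] for p in perm)
transposed_strides = tuple(reshaped_strides[p] for p in perm)

num_elements = prod(transposed)
num_bytes = num_elements * 2  # bfloat16 uses 2 bytes per element

print(transposed, transposed_strides)  # (32, 32, 8192, 128) (33554432, 128, 4096, 1)
print(num_elements, num_bytes)  # 1073741824 2147483648
print(num_bytes > 2**31 - 1)  # True
```

The stride along the sequence axis (4096) is larger than the stride along the heads axis (128), which is what makes the transposed input non-contiguous.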
