Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions src/base/rotary_embedding_infinilm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#ifndef INFINI_OPS_BASE_ROTARY_EMBEDDING_INFINILM_H_
#define INFINI_OPS_BASE_ROTARY_EMBEDDING_INFINILM_H_

#include <cassert>
#include <cstddef>

#include "data_type.h"
#include "operator.h"
#include "tensor.h"

namespace infini::ops {

/// Backend-independent base of the InfiniLM-style rotary embedding operator.
///
/// The constructor captures the geometry every backend needs (rank, sizes and
/// strides read from `out`, `input`, `pos_ids` and `sin_table`) and validates
/// the argument contract with `assert`. Backends derive from this class and
/// implement `operator()`.
///
/// Sizes are read from `out` with negative indices: `size(-1)` is the head
/// dimension, `size(-2)` the head count, `size(-3)` the sequence length and,
/// for 4-D tensors only, `size(-4)` the batch size.
class RotaryEmbeddingInfinilm : public Operator<RotaryEmbeddingInfinilm> {
 public:
  /// Records shapes/strides and checks the operator contract.
  ///
  /// @param input      Source tensor; must match `out` in shape and dtype.
  /// @param pos_ids    1-D `[seq]` or 2-D `[batch, seq]` int32/int64 ids.
  /// @param sin_table  2-D sine table; `size(1)` must be half the head dim.
  /// @param cos_table  2-D cosine table; same shape and dtype as `sin_table`.
  /// @param is_neox    Layout selector; not inspected by this constructor.
  /// @param out        3-D or 4-D destination tensor.
  ///
  /// All checks are `assert`s, so they vanish when `NDEBUG` is defined.
  RotaryEmbeddingInfinilm(const Tensor input, const Tensor pos_ids,
                          const Tensor sin_table, const Tensor cos_table,
                          bool is_neox, Tensor out)
      : ndim_{out.ndim()},
        batch_size_{ndim_ == 4 ? out.size(-4) : 1},
        seq_len_{out.size(-3)},
        nhead_{out.size(-2)},
        table_dim_{sin_table.size(1)},
        has_batch_dim_{ndim_ == 4},
        pos_has_batch_dim_{pos_ids.ndim() == 2},
        input_strides_{input.strides()},
        out_strides_{out.strides()},
        pos_strides_{pos_ids.strides()} {
    // These locals are consumed only by `assert`. `[[maybe_unused]]` keeps
    // `-DNDEBUG` builds free of unused-variable warnings.
    [[maybe_unused]] const auto head_dim = out.size(-1);
    [[maybe_unused]] const auto table_len = sin_table.size(0);
    [[maybe_unused]] const auto angle_dtype = sin_table.dtype();
    [[maybe_unused]] const auto pos_dtype = pos_ids.dtype();

    // Rank checks come first so that a wrong rank is reported as such rather
    // than as a confusing size mismatch further down.
    assert((ndim_ == 3 || ndim_ == 4) &&
           "`RotaryEmbeddingInfinilm` requires 3D or 4D tensor");
    assert((pos_ids.ndim() == 1 || pos_ids.ndim() == 2) &&
           "`RotaryEmbeddingInfinilm` requires 1D or 2D position ids");
    assert(sin_table.ndim() == 2 && cos_table.ndim() == 2 &&
           "`RotaryEmbeddingInfinilm` requires 2D sin/cos tables");

    // `input` / `out` agreement.
    assert(input.shape() == out.shape() &&
           "`RotaryEmbeddingInfinilm` requires `input` and `out` same shape");
    assert(input.dtype() == out.dtype() &&
           "`RotaryEmbeddingInfinilm` requires `input` and `out` same dtype");

    // Head dimension versus table width: each table entry rotates one pair.
    assert(head_dim % 2 == 0 &&
           "`RotaryEmbeddingInfinilm` requires head dimension to be even");
    assert(
        head_dim == table_dim_ * 2 &&
        "`RotaryEmbeddingInfinilm` requires table dim to be half of head dim");

    // Position ids.
    assert((pos_dtype == DataType::kInt32 || pos_dtype == DataType::kInt64) &&
           "`RotaryEmbeddingInfinilm` requires int32 or int64 position ids");
    assert((pos_has_batch_dim_ ? (pos_ids.size(0) == batch_size_ &&
                                  pos_ids.size(1) == seq_len_)
                               : (pos_ids.size(0) == seq_len_)) &&
           "`RotaryEmbeddingInfinilm` requires pos_ids shape [seq] or [batch, "
           "seq]");

    // Sine/cosine tables.
    assert(sin_table.shape() == cos_table.shape() &&
           "`RotaryEmbeddingInfinilm` requires sin_table and cos_table same "
           "shape");
    assert(sin_table.dtype() == cos_table.dtype() &&
           "`RotaryEmbeddingInfinilm` requires sin_table and cos_table same "
           "dtype");
    assert((angle_dtype == DataType::kFloat16 ||
            angle_dtype == DataType::kBFloat16 ||
            angle_dtype == DataType::kFloat32) &&
           "`RotaryEmbeddingInfinilm` requires float sin/cos tables");
    assert(
        table_len >= seq_len_ &&
        "`RotaryEmbeddingInfinilm` requires table length >= sequence length");

    // Innermost dimensions must be unit-stride.
    assert(out_strides_[ndim_ - 1] == 1 && input_strides_[ndim_ - 1] == 1 &&
           "`RotaryEmbeddingInfinilm` requires contiguous head dimension");
    assert(sin_table.strides()[1] == 1 && cos_table.strides()[1] == 1 &&
           "`RotaryEmbeddingInfinilm` requires contiguous table dimension");
  }

  /// Applies the rotary embedding; implemented by each backend.
  virtual void operator()(const Tensor input, const Tensor pos_ids,
                          const Tensor sin_table, const Tensor cos_table,
                          bool is_neox, Tensor out) const = 0;

 protected:
  /// Rank of `out` (3 or 4).
  Tensor::Size ndim_{0};

  /// `out.size(-4)` for 4-D tensors, otherwise 1.
  Tensor::Size batch_size_{0};

  /// `out.size(-3)`.
  Tensor::Size seq_len_{0};

  /// `out.size(-2)`.
  Tensor::Size nhead_{0};

  /// `sin_table.size(1)`; asserted to be half of the head dimension.
  Tensor::Size table_dim_{0};

  /// True when `out` is 4-D.
  bool has_batch_dim_{false};

  /// True when `pos_ids` is 2-D.
  bool pos_has_batch_dim_{false};

  /// Strides of `input` as passed to the constructor.
  Tensor::Strides input_strides_;

  /// Strides of `out` as passed to the constructor.
  Tensor::Strides out_strides_;

  /// Strides of `pos_ids` as passed to the constructor.
  Tensor::Strides pos_strides_;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/iluvatar/ops/rotary_embedding_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_ILUVATAR_ROTARY_EMBEDDING_INFINILM_KERNEL_H_
#define INFINI_OPS_ILUVATAR_ROTARY_EMBEDDING_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/iluvatar/caster.cuh"
#include "native/cuda/iluvatar/runtime_.h"
#include "native/cuda/ops/rotary_embedding_infinilm/kernel.h"

namespace infini::ops {

// Explicit specialization that binds `RotaryEmbeddingInfinilm` to
// `Device::Type::kIluvatar`. It adds no members of its own: everything comes
// from `CudaRotaryEmbeddingInfinilm` instantiated with the Iluvatar runtime.
template <>
class Operator<RotaryEmbeddingInfinilm, Device::Type::kIluvatar>
: public CudaRotaryEmbeddingInfinilm<Runtime<Device::Type::kIluvatar>> {
public:
// Inherit the base class constructors unchanged.
using CudaRotaryEmbeddingInfinilm<
Runtime<Device::Type::kIluvatar>>::CudaRotaryEmbeddingInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/metax/ops/rotary_embedding_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_METAX_ROTARY_EMBEDDING_INFINILM_KERNEL_H_
#define INFINI_OPS_METAX_ROTARY_EMBEDDING_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/metax/caster.cuh"
#include "native/cuda/metax/runtime_.h"
#include "native/cuda/ops/rotary_embedding_infinilm/kernel.h"

namespace infini::ops {

// Explicit specialization that binds `RotaryEmbeddingInfinilm` to
// `Device::Type::kMetax`. It adds no members of its own: everything comes
// from `CudaRotaryEmbeddingInfinilm` instantiated with the Metax runtime.
template <>
class Operator<RotaryEmbeddingInfinilm, Device::Type::kMetax>
: public CudaRotaryEmbeddingInfinilm<Runtime<Device::Type::kMetax>> {
public:
// Inherit the base class constructors unchanged.
using CudaRotaryEmbeddingInfinilm<
Runtime<Device::Type::kMetax>>::CudaRotaryEmbeddingInfinilm;
};

} // namespace infini::ops

#endif
26 changes: 26 additions & 0 deletions src/native/cuda/moore/ops/rotary_embedding_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef INFINI_OPS_MOORE_ROTARY_EMBEDDING_INFINILM_KERNEL_H_
#define INFINI_OPS_MOORE_ROTARY_EMBEDDING_INFINILM_KERNEL_H_

#include <utility>

// clang-format off
#include <musa_runtime.h>
// clang-format on

#include "native/cuda/moore/caster.cuh"
#include "native/cuda/moore/runtime_.h"
#include "native/cuda/ops/rotary_embedding_infinilm/kernel.h"

namespace infini::ops {

// Explicit specialization that binds `RotaryEmbeddingInfinilm` to
// `Device::Type::kMoore`. It adds no members of its own: everything comes
// from `CudaRotaryEmbeddingInfinilm` instantiated with the Moore runtime.
template <>
class Operator<RotaryEmbeddingInfinilm, Device::Type::kMoore>
: public CudaRotaryEmbeddingInfinilm<Runtime<Device::Type::kMoore>> {
public:
// Inherit the base class constructors unchanged.
using CudaRotaryEmbeddingInfinilm<
Runtime<Device::Type::kMoore>>::CudaRotaryEmbeddingInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/nvidia/ops/rotary_embedding_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_NVIDIA_ROTARY_EMBEDDING_INFINILM_KERNEL_H_
#define INFINI_OPS_NVIDIA_ROTARY_EMBEDDING_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/rotary_embedding_infinilm/kernel.h"

namespace infini::ops {

// Explicit specialization that binds `RotaryEmbeddingInfinilm` to
// `Device::Type::kNvidia`. It adds no members of its own: everything comes
// from `CudaRotaryEmbeddingInfinilm` instantiated with the NVIDIA runtime.
template <>
class Operator<RotaryEmbeddingInfinilm, Device::Type::kNvidia>
: public CudaRotaryEmbeddingInfinilm<Runtime<Device::Type::kNvidia>> {
public:
// Inherit the base class constructors unchanged.
using CudaRotaryEmbeddingInfinilm<
Runtime<Device::Type::kNvidia>>::CudaRotaryEmbeddingInfinilm;
};

} // namespace infini::ops

#endif
127 changes: 127 additions & 0 deletions src/native/cuda/ops/rotary_embedding_infinilm/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#ifndef INFINI_OPS_CUDA_ROPE_KERNEL_CUH_
#define INFINI_OPS_CUDA_ROPE_KERNEL_CUH_

#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "native/cuda/caster.cuh"

namespace infini::ops {

// Maps a scalar element type to its packed two-lane counterpart.
// The primary template has no members; `type2`, `make_half2`, `low` and
// `high` exist only in the explicit specializations for `half` and
// `cuda_bfloat16`.
template <typename T>
struct VecTypeHelper {};

// `half` specialization: thin wrappers over the CUDA `half2` intrinsics.
template <>
struct VecTypeHelper<half> {
  // Packed two-lane type paired with `half`.
  using type2 = half2;

  // Builds a `half2` from two floats via `__floats2half2_rn`.
  static __device__ __forceinline__ half2 make_half2(float first,
                                                     float second) {
    const half2 packed = __floats2half2_rn(first, second);
    return packed;
  }

  // Reads the low lane as a float via `__low2float`.
  static __device__ __forceinline__ float low(half2 pair) {
    const float lane = __low2float(pair);
    return lane;
  }

  // Reads the high lane as a float via `__high2float`.
  static __device__ __forceinline__ float high(half2 pair) {
    const float lane = __high2float(pair);
    return lane;
  }
};

// `cuda_bfloat16` specialization: thin wrappers over the bfloat16 pair
// intrinsics. The packing function keeps the name `make_half2` so generic
// code can call it for either element type.
template <>
struct VecTypeHelper<cuda_bfloat16> {
  // Packed two-lane type paired with `cuda_bfloat16`.
  using type2 = cuda_bfloat162;

  // Builds a `cuda_bfloat162` from two floats via `__floats2bfloat162_rn`.
  static __device__ __forceinline__ cuda_bfloat162 make_half2(float first,
                                                              float second) {
    const cuda_bfloat162 packed = __floats2bfloat162_rn(first, second);
    return packed;
  }

  // Reads the low lane as a float via `__low2float`.
  static __device__ __forceinline__ float low(cuda_bfloat162 pair) {
    const float lane = __low2float(pair);
    return lane;
  }

  // Reads the high lane as a float via `__high2float`.
  static __device__ __forceinline__ float high(cuda_bfloat162 pair) {
    const float lane = __high2float(pair);
    return lane;
  }
};

// Rotary position embedding kernel.
//
// Launch mapping (as read by this kernel): `blockIdx.x` selects the sequence
// position, `blockIdx.y` the head and, when `has_batch_dim` is true,
// `blockIdx.z` the batch entry. The threads of a block walk the `table_dim`
// rotation pairs of that head in steps of `blockDim.x`.
//
// For every pair `(x0, x1)` and its table entry `(sin, cos)` the kernel writes
//   y0 = x0 * cos - x1 * sin
//   y1 = x0 * sin + x1 * cos
// with the arithmetic done in `float`.
//
// Template parameters:
//   IsNeox  true  -> pair `i` is elements `(2i, 2i + 1)` of the head;
//           false -> pair `i` is elements `(i, i + table_dim)`.
//   kDev    device tag forwarded to `Caster`.
//   TData   element type of `input_ptr` / `out_ptr`.
//   TIndex  element type of `pos_ids_ptr`.
//   TAngle  element type of `sin_ptr` / `cos_ptr`.
//
// Parameters:
//   out_ptr / input_ptr    destination / source element pointers.
//   pos_ids_ptr            position id for each (batch, seq) or each seq.
//   sin_ptr / cos_ptr      tables indexed as `pos_id * table_dim + i`.
//   table_dim              number of rotation pairs per head.
//   *_stride_*             element strides of `out` / `input`.
//   pos_stride_batch       batch stride of `pos_ids` (used only when
//                          `pos_has_batch_dim` is true).
//
// NOTE(review): many RoPE implementations call the split-half layout "NeoX"
// and the adjacent-pair layout "GPT-J"; here `IsNeox == true` selects the
// adjacent-pair path. Confirm how the launcher maps `is_neox` to this flag.
// NOTE(review): the sequence stride of `pos_ids` is taken to be 1 and table
// rows are taken to be `table_dim` elements apart; neither stride is passed
// in, so the caller has to guarantee both.
// NOTE(review): the packed `half2` / `bfloat162` path reinterprets element
// pairs, which presumably needs pair-aligned base pointers and even strides;
// verify against the launcher.
template <bool IsNeox, Device::Type kDev, typename TData, typename TIndex,
          typename TAngle>
__global__ void RoPEKernel(
    TData* __restrict__ out_ptr, const TData* __restrict__ input_ptr,
    const TIndex* __restrict__ pos_ids_ptr, const TAngle* __restrict__ sin_ptr,
    const TAngle* __restrict__ cos_ptr, size_t table_dim,
    ptrdiff_t out_stride_batch, ptrdiff_t out_stride_seqlen,
    ptrdiff_t out_stride_nhead, ptrdiff_t input_stride_batch,
    ptrdiff_t input_stride_seqlen, ptrdiff_t input_stride_nhead,
    ptrdiff_t pos_stride_batch, bool pos_has_batch_dim, bool has_batch_dim) {
  // Keep block coordinates signed: multiplying an unsigned index by a
  // `ptrdiff_t` stride would convert the stride to unsigned and only work for
  // negative strides through wrap-around.
  const ptrdiff_t batch_idx =
      has_batch_dim ? static_cast<ptrdiff_t>(blockIdx.z) : 0;
  const ptrdiff_t seq_idx = static_cast<ptrdiff_t>(blockIdx.x);
  const ptrdiff_t head_idx = static_cast<ptrdiff_t>(blockIdx.y);

  // `batch_idx` is 0 without a batch dimension, so its term drops out.
  const ptrdiff_t out_offset = batch_idx * out_stride_batch +
                               seq_idx * out_stride_seqlen +
                               head_idx * out_stride_nhead;
  const ptrdiff_t input_offset = batch_idx * input_stride_batch +
                                 seq_idx * input_stride_seqlen +
                                 head_idx * input_stride_nhead;

  const ptrdiff_t pos_offset =
      pos_has_batch_dim ? batch_idx * pos_stride_batch + seq_idx : seq_idx;

  // Row of the sine/cosine tables selected by this token's position id.
  const size_t pos_id = static_cast<size_t>(pos_ids_ptr[pos_offset]);
  const size_t table_offset = pos_id * table_dim;

  // Loop-invariant base pointers for this (batch, seq, head) slice.
  TData* const out_row = out_ptr + out_offset;
  const TData* const in_row = input_ptr + input_offset;
  const TAngle* const sin_row = sin_ptr + table_offset;
  const TAngle* const cos_row = cos_ptr + table_offset;

  using VecHelper = VecTypeHelper<TData>;

  for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
    const float sin_val = Caster<kDev>::template Cast<float>(sin_row[i]);
    const float cos_val = Caster<kDev>::template Cast<float>(cos_row[i]);

    if constexpr (IsNeox) {
      if constexpr (std::is_same_v<TData, half> ||
                    std::is_same_v<TData, cuda_bfloat16>) {
        // 16-bit types: load and store the adjacent pair as one packed value.
        auto& y =
            reinterpret_cast<typename VecHelper::type2&>(out_row[2 * i]);
        const auto& x =
            reinterpret_cast<const typename VecHelper::type2&>(in_row[2 * i]);

        const float x0 = VecHelper::low(x);
        const float x1 = VecHelper::high(x);

        const float y0 = x0 * cos_val - x1 * sin_val;
        const float y1 = x0 * sin_val + x1 * cos_val;

        y = VecHelper::make_half2(y0, y1);
      } else {
        // Other element types: scalar access to the adjacent pair.
        const float x0 = Caster<kDev>::template Cast<float>(in_row[2 * i]);
        const float x1 = Caster<kDev>::template Cast<float>(in_row[2 * i + 1]);

        out_row[2 * i] =
            Caster<kDev>::template Cast<TData>(x0 * cos_val - x1 * sin_val);
        out_row[2 * i + 1] =
            Caster<kDev>::template Cast<TData>(x0 * sin_val + x1 * cos_val);
      }
    } else {
      // Split-half layout: the partner element sits `table_dim` further on.
      const size_t pos0 = i;
      const size_t pos1 = i + table_dim;

      const float x0 = Caster<kDev>::template Cast<float>(in_row[pos0]);
      const float x1 = Caster<kDev>::template Cast<float>(in_row[pos1]);

      const float y0 = x0 * cos_val - x1 * sin_val;
      const float y1 = x0 * sin_val + x1 * cos_val;

      out_row[pos0] = Caster<kDev>::template Cast<TData>(y0);
      out_row[pos1] = Caster<kDev>::template Cast<TData>(y1);
    }
  }
}

} // namespace infini::ops

#endif
Loading
Loading