Skip to content
41 changes: 41 additions & 0 deletions include/xtensor/generators/xbuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,32 @@ namespace xt

namespace detail
{
template <class S>
struct vstack_fixed_shape_impl;

template <std::size_t N>
struct vstack_fixed_shape_impl<fixed_shape<N>>
{
using type = fixed_shape<1, N>;
};

template <std::size_t I, std::size_t... J>
struct vstack_fixed_shape_impl<fixed_shape<I, J...>>
{
using type = fixed_shape<I, J...>;
};

template <class... CT>
struct vstack_fixed_shape
{
using type = concat_fixed_shape_t<
0,
typename vstack_fixed_shape_impl<typename std::decay_t<CT>::shape_type>::type...>;
};

template <class... CT>
using vstack_fixed_shape_t = typename vstack_fixed_shape<CT...>::type;

template <class S, class... CT>
inline auto vstack_shape(std::tuple<CT...>& t, const S& shape)
{
Expand Down Expand Up @@ -948,6 +974,21 @@ namespace xt
return detail::make_xgenerator(detail::vstack_impl<CT...>(std::move(t), size_t(0)), new_shape);
}

/**
* @brief Stack fixed-shape xexpressions in sequence vertically (row wise).
* This overload preserves the result shape at compile time by treating
* 1-D fixed shapes as ``(1, N)`` row vectors before concatenation.
*
* @param t \ref xtuple of fixed-shape xexpressions to stack
* @return xgenerator evaluating to stacked elements with a fixed compile-time shape
*/
template <fixed_shape_container_concept... CT>
inline auto vstack(std::tuple<CT...>&& t)
{
using shape_type = detail::vstack_fixed_shape_t<CT...>;
return detail::make_xgenerator(detail::vstack_impl<CT...>(std::move(t), size_t(0)), shape_type{});
}

namespace detail
{

Expand Down
17 changes: 12 additions & 5 deletions include/xtensor/views/index_mapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ namespace xt
* @throws Assertion failure if `i != 0` for integral slices.
* @throws Assertion failure if `i >= slice.size()` for non-integral slices.
*/
template <size_t I, std::integral Index>
template <access_t ACCESS, size_t I, std::integral Index>
size_t map_ith_index(const view_type& view, const Index i) const;

/**
Expand Down Expand Up @@ -490,16 +490,16 @@ namespace xt
{
if constexpr (ACCESS == access_t::SAFE)
{
return container.at(map_ith_index<Is>(view, indices[Is])...);
return container.at(map_ith_index<ACCESS, Is>(view, indices[Is])...);
}
else
{
return container(map_ith_index<Is>(view, indices[Is])...);
return container(map_ith_index<ACCESS, Is>(view, indices[Is])...);
}
}

template <class UnderlyingContainer, class... Slices>
template <size_t I, std::integral Index>
template <access_t ACCESS, size_t I, std::integral Index>
auto
index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_ith_index(const view_type& view, const Index i) const
-> size_t
Expand All @@ -518,10 +518,17 @@ namespace xt
assert(i == 0);
return size_t(slice);
}
else if constexpr (xt::detail::is_xall_slice<std::decay_t<current_slice>>::value)
{
return size_t(i);
}
else
{
using slice_size_type = typename current_slice::size_type;
assert(i < slice.size());
if constexpr (ACCESS == access_t::UNSAFE)
{
assert(static_cast<slice_size_type>(i) < slice.size());
}
return size_t(slice(static_cast<slice_size_type>(i)));
}
}
Expand Down
15 changes: 15 additions & 0 deletions test/test_xbuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,21 @@ namespace xt
ASSERT_TRUE(arange(8) == w1);
ASSERT_TRUE(w1 == w2);
}

TEST(xbuilder, vstack_fixed)
{
xtensor_fixed<float, fixed_shape<1, 2>> a = {{1.f, 2.f}};
xtensor_fixed<float, fixed_shape<2, 2>> b = {{3.f, 4.f}, {5.f, 6.f}};

auto c = vstack(xtuple(a, b));

using expected_shape_t = fixed_shape<3, 2>;
ASSERT_EQ(expected_shape_t{}, c.shape());
EXPECT_EQ(1.f, c(0, 0));
EXPECT_EQ(2.f, c(0, 1));
EXPECT_EQ(3.f, c(1, 0));
EXPECT_EQ(6.f, c(2, 1));
}
#endif

TEST(xbuilder, access)
Expand Down
Loading