diff --git a/include/vectorwrapper/vectorwrapper.hpp b/include/vectorwrapper/vectorwrapper.hpp index 4892f96..6e157a8 100644 --- a/include/vectorwrapper/vectorwrapper.hpp +++ b/include/vectorwrapper/vectorwrapper.hpp @@ -291,6 +291,11 @@ namespace vwr { scalar_type& x ( void ); const scalar_type& x ( void ) const; }; + + template + constexpr bool compare ( const Vec& parLeft, const Vec& parRight, Op parComposeOp, Op parOp, bt::number_seq ); + template + bool compare ( const Vec& parLeft, const Vec& parRight, ComposeOp parComposeOp, Op parOp, bt::number_seq ); } //namespace implem template @@ -392,7 +397,15 @@ namespace vwr { template bool operator== ( const Vec& parLeft, const Vec& parRight ); template + bool operator!= ( const Vec& parLeft, const Vec& parRight ); + template bool operator< ( const Vec& parLeft, const Vec& parRight ); + template + bool operator> ( const Vec& parLeft, const Vec& parRight ); + template + bool operator<= ( const Vec& parLeft, const Vec& parRight ); + template + bool operator>= ( const Vec& parLeft, const Vec& parRight ); template bool operator== ( const Vec& parLeft, const typename VectorWrapperInfo::scalar_type& parRight ); diff --git a/include/vectorwrapper/vectorwrapper.inl b/include/vectorwrapper/vectorwrapper.inl index e68883f..b3f0f66 100644 --- a/include/vectorwrapper/vectorwrapper.inl +++ b/include/vectorwrapper/vectorwrapper.inl @@ -279,6 +279,18 @@ namespace vwr { { static_assert(sizeof...(I) == S, "Bug?"); } + + template + inline constexpr + bool compare (const Vec&, const Vec&, ComposeOp, Op, bt::number_seq) { + return LastVal; + } + template + inline + bool compare (const Vec& parLeft, const Vec& parRight, ComposeOp parComposeOp, Op parOp, bt::number_seq) { + static_assert(I1 < VectorWrapperInfo::dimensions, "Index out of range"); + return parComposeOp(parOp(parLeft[I1], parRight[I1]), compare(parLeft, parRight, parComposeOp, parOp, bt::number_seq())); + } } //namespace implem template const Vec Vec::unit_x(scalar_type(1)); @@ -296,20 +308,74 @@ namespace vwr { template inline bool operator== (const Vec& parLeft, const Vec& parRight) { static_assert(static_cast(VectorWrapperInfo::dimensions) == static_cast(VectorWrapperInfo::dimensions), "Dimensions mismatch"); - bool retval = true; - for (size_type z = 0; z < VectorWrapperInfo::dimensions; ++z) { - retval &= (parLeft[z] == parRight[z]); - } - return retval; + typedef typename std::common_type::scalar_type, typename VectorWrapperInfo::scalar_type>::type scalar_type; + return implem::compare( + parLeft, + parRight, + std::logical_and(), + std::equal_to(), + bt::number_range::dimensions>() + ); + } + template + inline bool operator!= (const Vec& parLeft, const Vec& parRight) { + static_assert(static_cast(VectorWrapperInfo::dimensions) == static_cast(VectorWrapperInfo::dimensions), "Dimensions mismatch"); + typedef typename std::common_type::scalar_type, typename VectorWrapperInfo::scalar_type>::type scalar_type; + return implem::compare( + parLeft, + parRight, + std::logical_or(), + std::not_equal_to(), + bt::number_range::dimensions>() + ); } template inline bool operator< (const Vec& parLeft, const Vec& parRight) { static_assert(static_cast(VectorWrapperInfo::dimensions) == static_cast(VectorWrapperInfo::dimensions), "Dimensions mismatch"); - bool retval = true; - for (size_type z = 0; z < VectorWrapperInfo::dimensions; ++z) { - retval &= (parLeft[z] < parRight[z]); - } - return retval; + typedef typename std::common_type::scalar_type, typename VectorWrapperInfo::scalar_type>::type scalar_type; + return implem::compare( + parLeft, + parRight, + std::logical_and(), + std::less(), + bt::number_range::dimensions>() + ); + } + template + inline bool operator> (const Vec& parLeft, const Vec& parRight) { + static_assert(static_cast(VectorWrapperInfo::dimensions) == static_cast(VectorWrapperInfo::dimensions), "Dimensions mismatch"); + typedef typename std::common_type::scalar_type, typename VectorWrapperInfo::scalar_type>::type scalar_type; + return implem::compare( + parLeft, + parRight, + std::logical_and(), + std::greater(), + bt::number_range::dimensions>() + ); + } + template + inline bool operator<= (const Vec& parLeft, const Vec& parRight) { + static_assert(static_cast(VectorWrapperInfo::dimensions) == static_cast(VectorWrapperInfo::dimensions), "Dimensions mismatch"); + typedef typename std::common_type::scalar_type, typename VectorWrapperInfo::scalar_type>::type scalar_type; + return implem::compare( + parLeft, + parRight, + std::logical_and(), + std::less_equal(), + bt::number_range::dimensions>() + ); + } + template + inline bool operator>= (const Vec& parLeft, const Vec& parRight) { + static_assert(static_cast(VectorWrapperInfo::dimensions) == static_cast(VectorWrapperInfo::dimensions), "Dimensions mismatch"); + typedef typename std::common_type::scalar_type, typename VectorWrapperInfo::scalar_type>::type scalar_type; + return implem::compare( + parLeft, + parRight, + std::logical_and(), + std::greater_equal(), + bt::number_range::dimensions>() + ); } template diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index cbbf5c0..0b65f23 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -6,6 +6,7 @@ add_executable(${PROJECT_NAME} test_ops.cpp example.cpp test_get_at.cpp + test_operators.cpp ) target_link_libraries(${PROJECT_NAME} diff --git a/test/unit/test_operators.cpp b/test/unit/test_operators.cpp new file mode 100644 index 0000000..f773616 --- /dev/null +++ b/test/unit/test_operators.cpp @@ -0,0 +1,44 @@ +#include "sample_vectors.hpp" +#include + +TEST(vwr, operators) { + using namespace vwr; + + { + ivec3 a(5, 6, 7); + ivec3 b(6, 7, 8); + + EXPECT_LT(a, b); + EXPECT_LE(a, b); + EXPECT_NE(a, b); + EXPECT_FALSE(a == b); + EXPECT_FALSE(a > b); + EXPECT_FALSE(a >= b); + EXPECT_GT(b, a); + EXPECT_GE(b, a); + } + { + ivec3 a(6, 6, 7); + ivec3 b(6, 7, 8); + + EXPECT_FALSE(a < b); + EXPECT_LE(a, b); + EXPECT_NE(a, b); + EXPECT_FALSE(a == b); + EXPECT_FALSE(a > b); + EXPECT_FALSE(a >= b); + EXPECT_GE(b, a); + } + { + ivec3 a(0xAABB, 0xAABB, 0xAABB); + ivec3 b(0xAABB, 0xAABB, 0xAABB); + + EXPECT_FALSE(a < b); + EXPECT_LE(a, b); + EXPECT_FALSE(a != b); + EXPECT_EQ(a, b); + EXPECT_FALSE(a > b); + EXPECT_GE(a, b); + EXPECT_GE(b, a); + } +}