diff --git a/include/vectorwrapper/vectorwrapper.hpp b/include/vectorwrapper/vectorwrapper.hpp index 2feb29d..1654b9e 100644 --- a/include/vectorwrapper/vectorwrapper.hpp +++ b/include/vectorwrapper/vectorwrapper.hpp @@ -169,6 +169,8 @@ namespace vwr { private: template void assign_values (const bt::number_seq&, Args... parArgs); + template + void assign_values_op (Op parOp, const bt::number_seq& parSeq, const VecBase& parOther); vector_type m_wrapped; }; diff --git a/include/vectorwrapper/vectorwrapper.inl b/include/vectorwrapper/vectorwrapper.inl index ccca6e1..7fe12e2 100644 --- a/include/vectorwrapper/vectorwrapper.inl +++ b/include/vectorwrapper/vectorwrapper.inl @@ -55,6 +55,12 @@ namespace vwr { static_cast(t); } + template + template + void VecBase::assign_values_op (Op parOp, const bt::number_seq& parSeq, const VecBase& parOther) { + this->assign_values(parSeq, parOp((*this)[I], parOther[I])...); + } + template auto VecBase::operator[] (size_type parIndex) -> scalar_type& { return VecGetter::get_at(m_wrapped, parIndex); @@ -95,36 +101,28 @@ namespace vwr { template VecBase& VecBase::operator+= (const VecBase& parOther) { static_assert(static_cast(VectorWrapperInfo::dimensions) == static_cast(VectorWrapperInfo::dimensions), "Dimensions mismatch"); - for (size_type z = 0; z < VectorWrapperInfo::dimensions; ++z) { - (*this)[z] += parOther[z]; - } + this->assign_values_op(std::plus(), bt::number_range::dimensions>(), parOther); return *this; } template template VecBase& VecBase::operator-= (const VecBase& parOther) { static_assert(static_cast(VectorWrapperInfo::dimensions) == static_cast(VectorWrapperInfo::dimensions), "Dimensions mismatch"); - for (size_type z = 0; z < VectorWrapperInfo::dimensions; ++z) { - (*this)[z] -= parOther[z]; - } + this->assign_values_op(std::minus(), bt::number_range::dimensions>(), parOther); return *this; } template template VecBase& VecBase::operator*= (const VecBase& parOther) { static_assert(static_cast(VectorWrapperInfo::dimensions) == static_cast(VectorWrapperInfo::dimensions), "Dimensions mismatch"); - for (size_type z = 0; z < VectorWrapperInfo::dimensions; ++z) { - (*this)[z] *= parOther[z]; - } + this->assign_values_op(std::multiplies(), bt::number_range::dimensions>(), parOther); return *this; } template template VecBase& VecBase::operator/= (const VecBase& parOther) { static_assert(static_cast(VectorWrapperInfo::dimensions) == static_cast(VectorWrapperInfo::dimensions), "Dimensions mismatch"); - for (int z = 0; z < VectorWrapperInfo::dimensions; ++z) { - (*this)[z] /= parOther[z]; - } + this->assign_values_op(std::divides(), bt::number_range::dimensions>(), parOther); return *this; } diff --git a/test/unit/test_operators.cpp b/test/unit/test_operators.cpp index 4c1cad4..3e839ea 100644 --- a/test/unit/test_operators.cpp +++ b/test/unit/test_operators.cpp @@ -137,3 +137,32 @@ TEST(vwr, bin_operators_scalar) { EXPECT_EQ(res, 1000 % a); } } + +TEST(vwr, bin_assign_op) { + using namespace vwr; + + { + ivec3 a(2, 4, 8); + ivec3 res(2 + 20, 4 + 20, 8 + 20); + a += ivec3(20); + EXPECT_EQ(res, a); + } + { + ivec3 a(2, 4, 8); + ivec3 res(2 - 20, 4 - 20, 8 - 20); + a -= ivec3(20); + EXPECT_EQ(res, a); + } + { + ivec3 a(2, 4, 8); + ivec3 res(2 * 20, 4 * 20, 8 * 20); + a *= ivec3(20); + EXPECT_EQ(res, a); + } + { + ivec3 a(2, 4, 8); + ivec3 res(2 / 2, 4 / 2, 8 / 2); + a /= ivec3(2); + EXPECT_EQ(res, a); + } +}