Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 50 additions & 10 deletions stan/math/fwd/functor/operands_and_partials.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class ops_partials_edge<Dx, fvar<Dx>> {
: partial_(0), partials_(partial_), operand_(op) {}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
const Op& operand_;

Expand Down Expand Up @@ -62,19 +63,25 @@ class ops_partials_edge<Dx, fvar<Dx>> {
* @tparam Op3 type of the third operand
* @tparam Op4 type of the fourth operand
* @tparam Op5 type of the fifth operand
* @tparam Op6 type of the sixth operand
* @tparam Op7 type of the seventh operand
* @tparam Op8 type of the eighth operand
* @tparam T_return_type return type of the expression. This defaults
* to a template metaprogram that calculates the scalar promotion of
* Op1 -- Op5
* Op1 -- Op8
*/
template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5,
typename Dx>
class operands_and_partials<Op1, Op2, Op3, Op4, Op5, fvar<Dx>> {
typename Op6, typename Op7, typename Op8, typename Dx>
class operands_and_partials<Op1, Op2, Op3, Op4, Op5, Op6, Op7, Op8, fvar<Dx>> {
public:
internal::ops_partials_edge<Dx, std::decay_t<Op1>> edge1_;
internal::ops_partials_edge<Dx, std::decay_t<Op2>> edge2_;
internal::ops_partials_edge<Dx, std::decay_t<Op3>> edge3_;
internal::ops_partials_edge<Dx, std::decay_t<Op4>> edge4_;
internal::ops_partials_edge<Dx, std::decay_t<Op5>> edge5_;
internal::ops_partials_edge<Dx, std::decay_t<Op6>> edge6_;
internal::ops_partials_edge<Dx, std::decay_t<Op7>> edge7_;
internal::ops_partials_edge<Dx, std::decay_t<Op8>> edge8_;
using T_return_type = fvar<Dx>;
explicit operands_and_partials(const Op1& o1) : edge1_(o1) {}
operands_and_partials(const Op1& o1, const Op2& o2)
Expand All @@ -87,6 +94,35 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, fvar<Dx>> {
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
const Op4& o4, const Op5& o5)
: edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4), edge5_(o5) {}
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
const Op4& o4, const Op5& o5, const Op6& o6)
: edge1_(o1),
edge2_(o2),
edge3_(o3),
edge4_(o4),
edge5_(o5),
edge6_(o6) {}
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
const Op4& o4, const Op5& o5, const Op6& o6,
const Op7& o7)
: edge1_(o1),
edge2_(o2),
edge3_(o3),
edge4_(o4),
edge5_(o5),
edge6_(o6),
edge7_(o7) {}
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
const Op4& o4, const Op5& o5, const Op6& o6,
const Op7& o7, const Op8& o8)
: edge1_(o1),
edge2_(o2),
edge3_(o3),
edge4_(o4),
edge5_(o5),
edge6_(o6),
edge7_(o7),
edge8_(o8) {}

/** \ingroup type_trait
* Build the node to be stored on the autodiff graph.
Expand All @@ -102,8 +138,8 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, fvar<Dx>> {
* @return the value with its derivative
*/
T_return_type build(Dx value) {
Dx deriv
= edge1_.dx() + edge2_.dx() + edge3_.dx() + edge4_.dx() + edge5_.dx();
Dx deriv = edge1_.dx() + edge2_.dx() + edge3_.dx() + edge4_.dx()
+ edge5_.dx() + edge6_.dx() + edge7_.dx() + edge8_.dx();
return T_return_type(value, deriv);
}
};
Expand All @@ -124,7 +160,8 @@ class ops_partials_edge<Dx, std::vector<fvar<Dx>>> {
operands_(ops) {}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
const Op& operands_;

Expand All @@ -150,7 +187,8 @@ class ops_partials_edge<Dx, Eigen::Matrix<fvar<Dx>, R, C>> {
operands_(ops) {}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
const Op& operands_;

Expand Down Expand Up @@ -178,7 +216,8 @@ class ops_partials_edge<Dx, std::vector<Eigen::Matrix<fvar<Dx>, R, C>>> {
}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
const Op& operands_;

Expand Down Expand Up @@ -207,7 +246,8 @@ class ops_partials_edge<Dx, std::vector<std::vector<fvar<Dx>>>> {
}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
const Op& operands_;

Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/rev/operands_and_partials.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class ops_partials_edge<double, var_value<Op>,
operands_(ops) {}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
var_value<Op> operands_;
static constexpr int size() noexcept { return 0; }
Expand Down
32 changes: 26 additions & 6 deletions stan/math/prim/functor/operands_and_partials.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
namespace stan {
namespace math {
template <typename Op1 = double, typename Op2 = double, typename Op3 = double,
typename Op4 = double, typename Op5 = double,
typename T_return_type = return_type_t<Op1, Op2, Op3, Op4, Op5>>
typename Op4 = double, typename Op5 = double, typename Op6 = double,
typename Op7 = double, typename Op8 = double,
typename T_return_type
= return_type_t<Op1, Op2, Op3, Op4, Op5, Op6, Op7, Op8>>
class operands_and_partials; // Forward declaration

namespace internal {
Expand Down Expand Up @@ -70,7 +72,8 @@ class ops_partials_edge<ViewElt, Op, require_st_arithmetic<Op>> {
static constexpr int size() noexcept { return 0; } // reverse mode

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
};
template <typename ViewElt, typename Op>
Expand Down Expand Up @@ -100,7 +103,7 @@ constexpr double
*
* This base template is instantiated when all operands are
* primitives and we don't want to calculate derivatives at
* all. So all Op1 - Op5 must be arithmetic primitives
* all. So all Op1 - Op8 must be arithmetic primitives
* like int or double. This is controlled with the
* T_return_type type parameter.
*
Expand All @@ -109,12 +112,15 @@ constexpr double
* @tparam Op3 type of the third operand
* @tparam Op4 type of the fourth operand
* @tparam Op5 type of the fifth operand
* @tparam Op6 type of the sixth operand
* @tparam Op7 type of the seventh operand
* @tparam Op8 type of the eighth operand
* @tparam T_return_type return type of the expression. This defaults
* to calling a template metaprogram that calculates the scalar
* promotion of Op1..Op4
* promotion of Op1..Op8
*/
template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5,
typename T_return_type>
typename Op6, typename Op7, typename Op8, typename T_return_type>
class operands_and_partials {
public:
explicit operands_and_partials(const Op1& /* op1 */) noexcept {}
Expand All @@ -126,6 +132,17 @@ class operands_and_partials {
operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
const Op3& /* op3 */, const Op4& /* op4 */,
const Op5& /* op5 */) noexcept {}
operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
const Op3& /* op3 */, const Op4& /* op4 */,
const Op5& /* op5 */, const Op6& /* op6 */) noexcept {}
operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
const Op3& /* op3 */, const Op4& /* op4 */,
const Op5& /* op5 */, const Op6& /* op6 */,
const Op7& /* op7 */) noexcept {}
operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
const Op3& /* op3 */, const Op4& /* op4 */,
const Op5& /* op5 */, const Op6& /* op6 */,
const Op7& /* op7 */, const Op8& /* op8 */) noexcept {}

/** \ingroup type_trait
* Build the node to be stored on the autodiff graph.
Expand All @@ -148,6 +165,9 @@ class operands_and_partials {
internal::ops_partials_edge<double, std::decay_t<Op3>> edge3_;
internal::ops_partials_edge<double, std::decay_t<Op4>> edge4_;
internal::ops_partials_edge<double, std::decay_t<Op5>> edge5_;
internal::ops_partials_edge<double, std::decay_t<Op6>> edge6_;
internal::ops_partials_edge<double, std::decay_t<Op7>> edge7_;
internal::ops_partials_edge<double, std::decay_t<Op8>> edge8_;
};

} // namespace math
Expand Down
77 changes: 66 additions & 11 deletions stan/math/rev/functor/operands_and_partials.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class ops_partials_edge<double, var> {
: partial_(0), partials_(partial_), operand_(op) {}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
var operand_;
static constexpr int size() noexcept { return 1; }
Expand Down Expand Up @@ -109,15 +110,22 @@ inline void update_adjoints(StdVec1& x, const Vec2& y, const vari& z) {
* @tparam Op3 type of the third operand
* @tparam Op4 type of the fourth operand
* @tparam Op5 type of the fifth operand
* @tparam Op6 type of the sixth operand
* @tparam Op7 type of the seventh operand
* @tparam Op8 type of the eighth operand
*/
template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5>
class operands_and_partials<Op1, Op2, Op3, Op4, Op5, var> {
template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5,
typename Op6, typename Op7, typename Op8>
class operands_and_partials<Op1, Op2, Op3, Op4, Op5, Op6, Op7, Op8, var> {
public:
internal::ops_partials_edge<double, std::decay_t<Op1>> edge1_;
internal::ops_partials_edge<double, std::decay_t<Op2>> edge2_;
internal::ops_partials_edge<double, std::decay_t<Op3>> edge3_;
internal::ops_partials_edge<double, std::decay_t<Op4>> edge4_;
internal::ops_partials_edge<double, std::decay_t<Op5>> edge5_;
internal::ops_partials_edge<double, std::decay_t<Op6>> edge6_;
internal::ops_partials_edge<double, std::decay_t<Op7>> edge7_;
internal::ops_partials_edge<double, std::decay_t<Op8>> edge8_;

explicit operands_and_partials(const Op1& o1) : edge1_(o1) {}
operands_and_partials(const Op1& o1, const Op2& o2)
Expand All @@ -130,6 +138,35 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, var> {
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
const Op4& o4, const Op5& o5)
: edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4), edge5_(o5) {}
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
const Op4& o4, const Op5& o5, const Op6& o6)
: edge1_(o1),
edge2_(o2),
edge3_(o3),
edge4_(o4),
edge5_(o5),
edge6_(o6) {}
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
const Op4& o4, const Op5& o5, const Op6& o6,
const Op7& o7)
: edge1_(o1),
edge2_(o2),
edge3_(o3),
edge4_(o4),
edge5_(o5),
edge6_(o6),
edge7_(o7) {}
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
const Op4& o4, const Op5& o5, const Op6& o6,
const Op7& o7, const Op8& o8)
: edge1_(o1),
edge2_(o2),
edge3_(o3),
edge4_(o4),
edge5_(o5),
edge6_(o6),
edge7_(o7),
edge8_(o8) {}

/** \ingroup type_trait
* Build the node to be stored on the autodiff graph.
Expand All @@ -150,8 +187,11 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, var> {
operand2 = edge2_.operand(), partial2 = edge2_.partial(),
operand3 = edge3_.operand(), partial3 = edge3_.partial(),
operand4 = edge4_.operand(), partial4 = edge4_.partial(),
operand5 = edge5_.operand(),
partial5 = edge5_.partial()](const auto& vi) mutable {
operand5 = edge5_.operand(), partial5 = edge5_.partial(),
operand6 = edge6_.operand(), partial6 = edge6_.partial(),
operand7 = edge7_.operand(), partial7 = edge7_.partial(),
operand8 = edge8_.operand(),
partial8 = edge8_.partial()](const auto& vi) mutable {
if (!is_constant<Op1>::value) {
internal::update_adjoints(operand1, partial1, vi);
}
Expand All @@ -167,6 +207,15 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, var> {
if (!is_constant<Op5>::value) {
internal::update_adjoints(operand5, partial5, vi);
}
if (!is_constant<Op6>::value) {
internal::update_adjoints(operand6, partial6, vi);
}
if (!is_constant<Op7>::value) {
internal::update_adjoints(operand7, partial7, vi);
}
if (!is_constant<Op8>::value) {
internal::update_adjoints(operand8, partial8, vi);
}
});
}
};
Expand All @@ -186,7 +235,8 @@ class ops_partials_edge<double, std::vector<var>> {
operands_(op.begin(), op.end()) {}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
Op operands_;

Expand All @@ -207,7 +257,8 @@ class ops_partials_edge<double, Op, require_eigen_st<is_var, Op>> {
operands_(ops) {}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
arena_t<Op> operands_;
inline int size() const noexcept { return this->operands_.size(); }
Expand All @@ -228,7 +279,8 @@ class ops_partials_edge<double, var_value<Op>, require_eigen_t<Op>> {
operands_(ops) {}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
var_value<Op> operands_;

Expand Down Expand Up @@ -256,7 +308,8 @@ class ops_partials_edge<double, std::vector<Eigen::Matrix<var, R, C>>> {
}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
Op operands_;

Expand Down Expand Up @@ -286,7 +339,8 @@ class ops_partials_edge<double, std::vector<std::vector<var>>> {
}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
Op operands_;
inline int size() const noexcept {
Expand All @@ -311,7 +365,8 @@ class ops_partials_edge<double, std::vector<var_value<Op>>,
}

private:
template <typename, typename, typename, typename, typename, typename>
template <typename, typename, typename, typename, typename, typename,
typename, typename, typename>
friend class stan::math::operands_and_partials;
std::vector<var_value<Op>, arena_allocator<var_value<Op>>> operands_;

Expand Down
4 changes: 2 additions & 2 deletions test/unit/math/prim/functor/operands_and_partials_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ TEST(MathMetaPrim, OperandsAndPartials) {
operands_and_partials<double> o1(1.0);
operands_and_partials<double, double, double, double> o2(2.0, 3.0, 4.0, 5.0);

// This is size 10 because of the two empty broadcast arrays in each edge
EXPECT_EQ(10, sizeof(o2));
// This is size 16 because of the two empty broadcast arrays in each edge
EXPECT_EQ(16, sizeof(o2));

EXPECT_FLOAT_EQ(27.1, o1.build(27.1));
}
Loading