@@ -77,6 +77,7 @@ set(SRC
polys/uintpoly.cpp
polys/uratpoly.cpp
pow.cpp
+ mod.cpp
prime_sieve.cpp
printers/codegen.cpp
printers/mathml.cpp
@@ -213,6 +214,7 @@ set(HEADERS
polys/usymenginepoly.h
polys/msymenginepoly.h
pow.h
+ mod.h
prime_sieve.h
printers/codegen.h
printers/mathml.h
@@ -498,6 +498,18 @@ void DiffVisitor::bvisit(const Pow &self)
}
}
+void DiffVisitor::bvisit(const Mod &md)
+{
+ RCP<const Basic> m = md.get_divisor();
+
+ // 如果模数是常数,导数返回1
+ if (is_a_Number(*m) or is_a<Constant>(*m)) {
+ result_ = one;
+ } else {
+ result_ = Derivative::create(md.rcp_from_this(), {x});
+ }
+}
+
void DiffVisitor::bvisit(const Sin &self)
{
apply(self.get_arg());
@@ -77,6 +77,7 @@ public:
void bvisit(const Add &self);
void bvisit(const Mul &self);
void bvisit(const Pow &self);
+ void bvisit(const Mod &self);
void bvisit(const Sin &self);
void bvisit(const Cos &self);
void bvisit(const Tan &self);
@@ -1966,6 +1966,8 @@ bool Derivative::is_canonical(const RCP<const Basic> &arg,
return true;
} else if (is_a<Abs>(*arg)) {
return true;
+ } else if (is_a<Mod>(*arg)) {
+ return true;
} else if (is_a<FunctionWrapper>(*arg)) {
return true;
} else if (is_a<PolyGamma>(*arg) or is_a<Zeta>(*arg)
@@ -112,6 +112,13 @@ public:
mp_pow_ui(tmp, i, mp_get_ui(other.i));
return make_rcp<const Integer>(std::move(tmp));
}
+
+ inline RCP<const Number> modint(const Integer &other) const
+ {
+ integer_class tmp;
+ tmp = i % other.i;
+ return make_rcp<const Integer>(std::move(tmp));
+ }
//! \return negative of self.
inline RCP<const Integer> neg() const
{
@@ -179,6 +186,20 @@ public:
{
throw NotImplementedError("Not Implemented");
};
+
+ RCP<const Number> mod(const Number &other) const override
+ {
+ if (is_a<Integer>(other)) {
+ return modint(down_cast<const Integer &>(other));
+ } else {
+ return other.rmod(*this);
+ }
+ };
+
+ RCP<const Number> rmod(const Number &other) const override
+ {
+ return other.mod(*this);
+ };
};
//! less operator (<) for Integers
new file mode 100644
@@ -0,0 +1,106 @@
+#include <symengine/mod.h>
+#include <symengine/add.h>
+#include <symengine/test_visitors.h>
+
+namespace SymEngine
+{
+
+Mod::Mod(const RCP<const Basic> ÷nd, const RCP<const Basic> &divisor)
+ : dividend_{dividend}, divisor_{divisor}
+{
+ SYMENGINE_ASSERT(!dividend.is_null());
+ SYMENGINE_ASSERT(!divisor.is_null());
+ SYMENGINE_ASSIGN_TYPEID()
+ SYMENGINE_ASSERT(is_canonical(*dividend, *divisor))
+}
+
+bool Mod::is_canonical(const Basic ÷nd, const Basic &divisor) const
+{
+ // 模数不能为 0
+ if (is_a<Integer>(divisor) && down_cast<const Integer &>(divisor).is_zero())
+ return false;
+
+ // 避免 a mod 1 (应化简为 0)
+ if (is_a<Integer>(divisor) && down_cast<const Integer &>(divisor).is_one())
+ return false;
+
+ return true;
+}
+
+hash_t Mod::__hash__() const
+{
+ hash_t seed = SYMENGINE_MOD;
+ hash_combine<Basic>(seed, *dividend_);
+ hash_combine<Basic>(seed, *divisor_);
+ return seed;
+}
+
+bool Mod::__eq__(const Basic &o) const
+{
+ if (is_a<Mod>(o) and eq(*dividend_, *(down_cast<const Mod &>(o).dividend_))
+ and eq(*divisor_, *(down_cast<const Mod &>(o).divisor_)))
+ return true;
+
+ return false;
+}
+
+int Mod::compare(const Basic &o) const
+{
+ SYMENGINE_ASSERT(is_a<Mod>(o))
+ const Mod &s = down_cast<const Mod &>(o);
+ int dividend_cmp = dividend_->__cmp__(*s.dividend_);
+ if (dividend_cmp == 0)
+ return divisor_->__cmp__(*s.divisor_);
+ else
+ return dividend_cmp;
+}
+
+RCP<const Basic> mod(const RCP<const Basic> ÷nd,
+ const RCP<const Basic> &divisor)
+{
+ if (is_number_and_zero(*divisor)) {
+ throw SymEngineException("divisor must not zero.");
+ }
+
+ // 0 mod xx = 0
+ if (is_a<Integer>(*dividend)
+ && down_cast<const Integer &>(*dividend).is_zero()) {
+ return zero;
+ }
+
+ // a mod 1 = 0
+ if (is_a<Integer>(*divisor)
+ && down_cast<const Integer &>(*divisor).is_one()) {
+ return zero;
+ }
+
+ // 数值情况直接计算
+ if (is_a_Number(*dividend) && is_a_Number(*divisor)) {
+ RCP<const Number> a_num = rcp_static_cast<const Number>(dividend);
+ RCP<const Number> m_num = rcp_static_cast<const Number>(divisor);
+ return a_num->mod(*m_num);
+ }
+
+ // 处理乘法表达式 (k * x) mod m → Mod((k mod m) * (x mod m), m)
+ if (is_a<Mul>(*dividend)) {
+ RCP<const Mul> mul_a = rcp_static_cast<const Mul>(dividend);
+ RCP<const Basic> coeff = mul_a->get_coef();
+ map_basic_basic dict = mul_a->get_dict();
+ RCP<const Basic> mod_coeff = mod(coeff, divisor);
+ if (is_a_Number(*mod_coeff)
+ && down_cast<const Integer &>(*mod_coeff).is_zero()) {
+ return zero;
+ }
+ RCP<const Basic> mod_x = Mul::from_dict(one, std::move(dict));
+ return make_rcp<const Mod>(mul(mod_coeff, mod_x), divisor);
+ }
+
+ return make_rcp<const Mod>(dividend, divisor);
+}
+
+vec_basic Mod::get_args() const
+{
+ return {dividend_, divisor_};
+}
+
+} // namespace SymEngine
new file mode 100644
@@ -0,0 +1,62 @@
+/**
+ * \file mod.h
+ * Moder Class
+ *
+ **/
+#ifndef SYMENGINE_MOD_H
+#define SYMENGINE_MOD_H
+
+#include <symengine/functions.h>
+#include <symengine/mul.h>
+#include <symengine/ntheory.h>
+#include <symengine/constants.h>
+
+namespace SymEngine
+{
+
+class Mod : public Basic
+{
+private:
+ RCP<const Basic> dividend_, divisor_;
+
+public:
+ IMPLEMENT_TYPEID(SYMENGINE_MOD)
+
+ //! Mod Constructor
+ Mod(const RCP<const Basic> ÷nd, const RCP<const Basic> &divisor);
+
+ //! \return Size of the hash
+ hash_t __hash__() const override;
+
+ /*! Equality comparator
+ * \param o - Object to be compared with
+ * \return whether the 2 objects are equal
+ * */
+ bool __eq__(const Basic &o) const override;
+
+ int compare(const Basic &o) const override;
+
+ //! \return `true` if canonical
+ bool is_canonical(const Basic ÷nd, const Basic &divisor) const;
+
+ //! \return `dividend` of `dividend%divisor`
+ inline RCP<const Basic> get_dividend() const
+ {
+ return dividend_;
+ }
+
+ //! \return `divisor` of `dividend%divisor`
+ inline RCP<const Basic> get_divisor() const
+ {
+ return divisor_;
+ }
+
+ vec_basic get_args() const override;
+};
+
+//! \return Mod from `a` and `b`
+RCP<const Basic> mod(const RCP<const Basic> ÷nd,
+ const RCP<const Basic> &divisor);
+} // namespace SymEngine
+
+#endif
@@ -7,6 +7,7 @@
#include <symengine/functions.h>
#include <symengine/add.h>
#include <symengine/pow.h>
+#include <symengine/mod.h>
namespace SymEngine
{
@@ -63,6 +63,17 @@ public:
virtual RCP<const Number> pow(const Number &other) const = 0;
virtual RCP<const Number> rpow(const Number &other) const = 0;
+ virtual RCP<const Number> mod(const Number &other) const
+ {
+ throw NotImplementedError(std::string("mod not implemented for type ")
+ + type_code_name(this->get_type_code()));
+ };
+ virtual RCP<const Number> rmod(const Number &other) const
+ {
+ throw NotImplementedError(std::string("mod not implemented for type ")
+ + type_code_name(this->get_type_code()));
+ };
+
vec_basic get_args() const override
{
return {};
@@ -44,8 +44,8 @@
// Unqualified %code blocks.
#line 22 "parser.yy"
-#include "symengine/basic.h"
#include "symengine/pow.h"
+#include "symengine/mod.h"
#include "symengine/logic.h"
#include "symengine/parser/parser.h"
#include "symengine/utilities/stream_fmt.h"
@@ -56,6 +56,7 @@ using SymEngine::vec_basic;
using SymEngine::rcp_static_cast;
using SymEngine::mul;
using SymEngine::pow;
+using SymEngine::mod;
using SymEngine::add;
using SymEngine::sub;
using SymEngine::Lt;
@@ -87,7 +88,7 @@ void parser::error(const std::string &msg)
}
-#line 91 "parser.tab.cc"
+#line 92 "parser.tab.cc"
#ifndef YY_
@@ -160,7 +161,7 @@ void parser::error(const std::string &msg)
#define YYRECOVERING() (!!yyerrstatus_)
namespace yy {
-#line 164 "parser.tab.cc"
+#line 165 "parser.tab.cc"
/// Build a parser object.
parser::parser (SymEngine::Parser &p_yyarg)
@@ -822,40 +823,46 @@ namespace yy {
switch (yyn)
{
case 2: // st_expr: expr
-#line 104 "parser.yy"
+#line 105 "parser.yy"
{
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ();
p.res = yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > ();
}
-#line 831 "parser.tab.cc"
+#line 832 "parser.tab.cc"
break;
case 3: // expr: expr '+' expr
-#line 112 "parser.yy"
+#line 113 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = add(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()); }
-#line 837 "parser.tab.cc"
+#line 838 "parser.tab.cc"
break;
case 4: // expr: expr '-' expr
-#line 115 "parser.yy"
+#line 116 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = sub(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()); }
-#line 843 "parser.tab.cc"
+#line 844 "parser.tab.cc"
break;
case 5: // expr: expr '*' expr
-#line 118 "parser.yy"
+#line 119 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = mul(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()); }
-#line 849 "parser.tab.cc"
+#line 850 "parser.tab.cc"
break;
- case 6: // expr: expr '/' expr
-#line 121 "parser.yy"
+ case 6: // expr: expr '%' expr
+#line 122 "parser.yy"
+ { yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = mod(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()); }
+#line 856 "parser.tab.cc"
+ break;
+
+ case 7: // expr: expr '/' expr
+#line 125 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = div(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()); }
-#line 855 "parser.tab.cc"
+#line 862 "parser.tab.cc"
break;
- case 7: // expr: IMPLICIT_MUL POW expr
-#line 126 "parser.yy"
+ case 8: // expr: IMPLICIT_MUL POW expr
+#line 130 "parser.yy"
{
auto tup = p.parse_implicit_mul(yystack_[2].value.as < std::string > ());
if (neq(*std::get<1>(tup), *one)) {
@@ -864,165 +871,165 @@ namespace yy {
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = pow(std::get<0>(tup), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ());
}
}
-#line 868 "parser.tab.cc"
+#line 875 "parser.tab.cc"
break;
- case 8: // expr: expr POW expr
-#line 136 "parser.yy"
+ case 9: // expr: expr POW expr
+#line 140 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = pow(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()); }
-#line 874 "parser.tab.cc"
+#line 881 "parser.tab.cc"
break;
- case 9: // expr: expr '<' expr
-#line 139 "parser.yy"
+ case 10: // expr: expr '<' expr
+#line 143 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(Lt(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ())); }
-#line 880 "parser.tab.cc"
+#line 887 "parser.tab.cc"
break;
- case 10: // expr: expr '>' expr
-#line 142 "parser.yy"
+ case 11: // expr: expr '>' expr
+#line 146 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(Gt(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ())); }
-#line 886 "parser.tab.cc"
+#line 893 "parser.tab.cc"
break;
- case 11: // expr: expr NE expr
-#line 145 "parser.yy"
+ case 12: // expr: expr NE expr
+#line 149 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(Ne(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ())); }
-#line 892 "parser.tab.cc"
+#line 899 "parser.tab.cc"
break;
- case 12: // expr: expr LE expr
-#line 148 "parser.yy"
+ case 13: // expr: expr LE expr
+#line 152 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(Le(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ())); }
-#line 898 "parser.tab.cc"
+#line 905 "parser.tab.cc"
break;
- case 13: // expr: expr GE expr
-#line 151 "parser.yy"
+ case 14: // expr: expr GE expr
+#line 155 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(Ge(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ())); }
-#line 904 "parser.tab.cc"
+#line 911 "parser.tab.cc"
break;
- case 14: // expr: expr EQ expr
-#line 154 "parser.yy"
+ case 15: // expr: expr EQ expr
+#line 158 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(Eq(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > (), yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ())); }
-#line 910 "parser.tab.cc"
+#line 917 "parser.tab.cc"
break;
- case 15: // expr: expr '|' expr
-#line 157 "parser.yy"
+ case 16: // expr: expr '|' expr
+#line 161 "parser.yy"
{
set_boolean s;
s.insert(rcp_static_cast<const Boolean>(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > ()));
s.insert(rcp_static_cast<const Boolean>(yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()));
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(logical_or(s));
}
-#line 921 "parser.tab.cc"
+#line 928 "parser.tab.cc"
break;
- case 16: // expr: expr '&' expr
-#line 165 "parser.yy"
+ case 17: // expr: expr '&' expr
+#line 169 "parser.yy"
{
set_boolean s;
s.insert(rcp_static_cast<const Boolean>(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > ()));
s.insert(rcp_static_cast<const Boolean>(yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()));
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(logical_and(s));
}
-#line 932 "parser.tab.cc"
+#line 939 "parser.tab.cc"
break;
- case 17: // expr: expr '^' expr
-#line 173 "parser.yy"
+ case 18: // expr: expr '^' expr
+#line 177 "parser.yy"
{
vec_boolean s;
s.push_back(rcp_static_cast<const Boolean>(yystack_[2].value.as < SymEngine::RCP<const SymEngine::Basic> > ()));
s.push_back(rcp_static_cast<const Boolean>(yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()));
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(logical_xor(s));
}
-#line 943 "parser.tab.cc"
+#line 950 "parser.tab.cc"
break;
- case 18: // expr: '(' expr ')'
-#line 181 "parser.yy"
+ case 19: // expr: '(' expr ')'
+#line 185 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = yystack_[1].value.as < SymEngine::RCP<const SymEngine::Basic> > (); }
-#line 949 "parser.tab.cc"
+#line 956 "parser.tab.cc"
break;
- case 19: // expr: '-' expr
-#line 184 "parser.yy"
+ case 20: // expr: '-' expr
+#line 188 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = neg(yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()); }
-#line 955 "parser.tab.cc"
+#line 962 "parser.tab.cc"
break;
- case 20: // expr: '+' expr
-#line 187 "parser.yy"
+ case 21: // expr: '+' expr
+#line 191 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > (); }
-#line 961 "parser.tab.cc"
+#line 968 "parser.tab.cc"
break;
- case 21: // expr: '~' expr
-#line 190 "parser.yy"
+ case 22: // expr: '~' expr
+#line 194 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(logical_not(rcp_static_cast<const Boolean>(yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()))); }
-#line 967 "parser.tab.cc"
+#line 974 "parser.tab.cc"
break;
- case 22: // expr: leaf
-#line 193 "parser.yy"
+ case 23: // expr: leaf
+#line 197 "parser.yy"
{ yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = rcp_static_cast<const Basic>(yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ()); }
-#line 973 "parser.tab.cc"
+#line 980 "parser.tab.cc"
break;
- case 23: // leaf: IDENTIFIER
-#line 198 "parser.yy"
+ case 24: // leaf: IDENTIFIER
+#line 202 "parser.yy"
{
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = p.parse_identifier(yystack_[0].value.as < std::string > ());
}
-#line 981 "parser.tab.cc"
+#line 988 "parser.tab.cc"
break;
- case 24: // leaf: IMPLICIT_MUL
-#line 203 "parser.yy"
+ case 25: // leaf: IMPLICIT_MUL
+#line 207 "parser.yy"
{
auto tup = p.parse_implicit_mul(yystack_[0].value.as < std::string > ());
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = mul(std::get<0>(tup), std::get<1>(tup));
}
-#line 990 "parser.tab.cc"
+#line 997 "parser.tab.cc"
break;
- case 25: // leaf: NUMERIC
-#line 209 "parser.yy"
+ case 26: // leaf: NUMERIC
+#line 213 "parser.yy"
{
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = p.parse_numeric(yystack_[0].value.as < std::string > ());
}
-#line 998 "parser.tab.cc"
+#line 1005 "parser.tab.cc"
break;
- case 26: // leaf: func
-#line 214 "parser.yy"
+ case 27: // leaf: func
+#line 218 "parser.yy"
{
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ();
}
-#line 1006 "parser.tab.cc"
+#line 1013 "parser.tab.cc"
break;
- case 27: // leaf: pwise
-#line 219 "parser.yy"
+ case 28: // leaf: pwise
+#line 223 "parser.yy"
{
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ();
}
-#line 1014 "parser.tab.cc"
+#line 1021 "parser.tab.cc"
break;
- case 28: // func: IDENTIFIER '(' expr_list ')'
-#line 226 "parser.yy"
+ case 29: // func: IDENTIFIER '(' expr_list ')'
+#line 230 "parser.yy"
{
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = p.functionify(yystack_[3].value.as < std::string > (), yystack_[1].value.as < SymEngine::vec_basic > ());
}
-#line 1022 "parser.tab.cc"
+#line 1029 "parser.tab.cc"
break;
- case 29: // epair: '(' expr ',' expr ')'
-#line 234 "parser.yy"
+ case 30: // epair: '(' expr ',' expr ')'
+#line 238 "parser.yy"
{
auto logical_expr = yystack_[1].value.as < SymEngine::RCP<const SymEngine::Basic> > ();
if (!SymEngine::is_a_sub<Boolean>(*logical_expr)) {
@@ -1031,54 +1038,54 @@ namespace yy {
}
yylhs.value.as < std::pair<SymEngine::RCP<const SymEngine::Basic>, SymEngine::RCP<const SymEngine::Boolean>> > () = std::make_pair(yystack_[3].value.as < SymEngine::RCP<const SymEngine::Basic> > (), rcp_static_cast<const Boolean>(logical_expr));
}
-#line 1035 "parser.tab.cc"
+#line 1042 "parser.tab.cc"
break;
- case 30: // piecewise_list: piecewise_list ',' epair
-#line 246 "parser.yy"
+ case 31: // piecewise_list: piecewise_list ',' epair
+#line 250 "parser.yy"
{
yylhs.value.as < SymEngine::PiecewiseVec > () = yystack_[2].value.as < SymEngine::PiecewiseVec > ();
yylhs.value.as < SymEngine::PiecewiseVec > () .push_back(yystack_[0].value.as < std::pair<SymEngine::RCP<const SymEngine::Basic>, SymEngine::RCP<const SymEngine::Boolean>> > ());
}
-#line 1044 "parser.tab.cc"
+#line 1051 "parser.tab.cc"
break;
- case 31: // piecewise_list: epair
-#line 252 "parser.yy"
+ case 32: // piecewise_list: epair
+#line 256 "parser.yy"
{
yylhs.value.as < SymEngine::PiecewiseVec > () = SymEngine::PiecewiseVec(1, yystack_[0].value.as < std::pair<SymEngine::RCP<const SymEngine::Basic>, SymEngine::RCP<const SymEngine::Boolean>> > ());
}
-#line 1052 "parser.tab.cc"
+#line 1059 "parser.tab.cc"
break;
- case 32: // pwise: PIECEWISE '(' piecewise_list ')'
-#line 259 "parser.yy"
+ case 33: // pwise: PIECEWISE '(' piecewise_list ')'
+#line 263 "parser.yy"
{
assert(yystack_[3].value.as < std::string > () == "Piecewise");
yylhs.value.as < SymEngine::RCP<const SymEngine::Basic> > () = piecewise(std::move(yystack_[1].value.as < SymEngine::PiecewiseVec > ()));
}
-#line 1061 "parser.tab.cc"
+#line 1068 "parser.tab.cc"
break;
- case 33: // expr_list: expr_list ',' expr
-#line 268 "parser.yy"
+ case 34: // expr_list: expr_list ',' expr
+#line 272 "parser.yy"
{
yylhs.value.as < SymEngine::vec_basic > () = yystack_[2].value.as < SymEngine::vec_basic > (); // TODO : should make copy?
yylhs.value.as < SymEngine::vec_basic > () .push_back(yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ());
}
-#line 1070 "parser.tab.cc"
+#line 1077 "parser.tab.cc"
break;
- case 34: // expr_list: expr
-#line 274 "parser.yy"
+ case 35: // expr_list: expr
+#line 278 "parser.yy"
{
yylhs.value.as < SymEngine::vec_basic > () = vec_basic(1, yystack_[0].value.as < SymEngine::RCP<const SymEngine::Basic> > ());
}
-#line 1078 "parser.tab.cc"
+#line 1085 "parser.tab.cc"
break;
-#line 1082 "parser.tab.cc"
+#line 1089 "parser.tab.cc"
default:
break;
@@ -1267,126 +1274,130 @@ namespace yy {
- const signed char parser::yypact_ninf_ = -13;
+ const signed char parser::yypact_ninf_ = -14;
const signed char parser::yytable_ninf_ = -1;
const short
parser::yypact_[] =
{
- 29, 28, 50, -13, 54, 29, 29, 29, 29, 77,
- 110, -13, -13, -13, 68, 29, 29, 71, 71, 72,
- -13, -13, 29, 29, 29, 29, 29, 29, 29, 29,
- 29, 29, 29, 29, 29, 29, 29, -13, -12, 110,
- -11, 71, -13, 125, 139, 152, 25, 163, 173, -10,
- 181, 32, 53, 53, 71, 71, 71, 51, -13, 68,
- -13, 29, 29, -13, 110, 91, -13
+ 30, 29, 49, -14, 52, 30, 30, 30, 30, 77,
+ 115, -14, -14, -14, 54, 30, 30, 55, 55, 75,
+ -14, -14, 30, 30, 30, 30, 30, 30, 30, 30,
+ 30, 30, 30, 30, 30, 30, 30, 30, -14, -13,
+ 115, -12, 55, -14, 131, 146, 160, 173, 25, 184,
+ 194, -11, 33, 200, 200, 55, 55, 55, 55, 53,
+ -14, 54, -14, 30, 30, -14, 115, 95, -14
};
const signed char
parser::yydefact_[] =
{
- 0, 0, 23, 25, 24, 0, 0, 0, 0, 0,
- 2, 22, 26, 27, 0, 0, 0, 19, 20, 0,
- 21, 1, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 31, 0, 34,
- 0, 7, 18, 15, 17, 16, 14, 10, 9, 11,
- 12, 13, 4, 3, 5, 6, 8, 0, 32, 0,
- 28, 0, 0, 30, 33, 0, 29
+ 0, 0, 24, 26, 25, 0, 0, 0, 0, 0,
+ 2, 23, 27, 28, 0, 0, 0, 20, 21, 0,
+ 22, 1, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 32, 0,
+ 35, 0, 8, 19, 16, 18, 17, 15, 11, 10,
+ 12, 13, 14, 4, 3, 5, 7, 6, 9, 0,
+ 33, 0, 29, 0, 0, 31, 34, 0, 30
};
const signed char
parser::yypgoto_[] =
{
- -13, -13, -5, -13, -13, 36, -13, -13, -13
+ -14, -14, -5, -14, -14, 19, -14, -14, -14
};
const signed char
parser::yydefgoto_[] =
{
- 0, 9, 10, 11, 12, 37, 38, 13, 40
+ 0, 9, 10, 11, 12, 38, 39, 13, 41
};
const signed char
parser::yytable_[] =
{
- 17, 18, 19, 20, 29, 30, 31, 32, 33, 34,
- 39, 41, 35, 58, 60, 59, 61, 43, 44, 45,
- 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
- 56, 57, 1, 2, 3, 4, 26, 27, 28, 29,
- 30, 31, 32, 33, 34, 5, 6, 35, 31, 32,
- 33, 34, 14, 7, 35, 8, 64, 65, 22, 23,
+ 17, 18, 19, 20, 30, 31, 32, 33, 34, 35,
+ 40, 42, 36, 60, 62, 61, 63, 44, 45, 46,
+ 47, 48, 49, 50, 51, 52, 53, 54, 55, 56,
+ 57, 58, 59, 1, 2, 3, 4, 27, 28, 29,
+ 30, 31, 32, 33, 34, 35, 5, 6, 36, 31,
+ 32, 33, 34, 35, 14, 7, 36, 8, 66, 67,
+ 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+ 32, 33, 34, 35, 15, 16, 36, 21, 36, 37,
+ 65, 64, 22, 23, 24, 25, 26, 27, 28, 29,
+ 30, 31, 32, 33, 34, 35, 0, 0, 36, 0,
+ 0, 43, 22, 23, 24, 25, 26, 27, 28, 29,
+ 30, 31, 32, 33, 34, 35, 0, 0, 36, 0,
+ 0, 68, 22, 23, 24, 25, 26, 27, 28, 29,
+ 30, 31, 32, 33, 34, 35, 0, 0, 36, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
- 34, 33, 34, 35, 15, 35, 16, 21, 62, 22,
- 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
- 33, 34, 36, 35, 35, 63, 0, 42, 22, 23,
- 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
- 34, 0, 0, 35, 0, 0, 66, 22, 23, 24,
+ 34, 35, 0, 0, 36, 24, 25, 26, 27, 28,
+ 29, 30, 31, 32, 33, 34, 35, 0, 0, 36,
25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
- 0, 0, 35, 23, 24, 25, 26, 27, 28, 29,
- 30, 31, 32, 33, 34, 0, 0, 35, 24, 25,
- 26, 27, 28, 29, 30, 31, 32, 33, 34, 0,
- 0, 35, 25, 26, 27, 28, 29, 30, 31, 32,
- 33, 34, 0, 0, 35, 27, 28, 29, 30, 31,
- 32, 33, 34, 0, 0, 35, 28, 29, 30, 31,
- 32, 33, 34, 0, 0, 35, 30, 31, 32, 33,
- 34, 0, 0, 35
+ 35, 0, 0, 36, 26, 27, 28, 29, 30, 31,
+ 32, 33, 34, 35, 0, 0, 36, 28, 29, 30,
+ 31, 32, 33, 34, 35, 0, 0, 36, 29, 30,
+ 31, 32, 33, 34, 35, 0, 0, 36, 33, 34,
+ 35, 0, 0, 36
};
const signed char
parser::yycheck_[] =
{
- 5, 6, 7, 8, 14, 15, 16, 17, 18, 19,
- 15, 16, 22, 25, 25, 27, 27, 22, 23, 24,
+ 5, 6, 7, 8, 15, 16, 17, 18, 19, 20,
+ 15, 16, 23, 26, 26, 28, 28, 22, 23, 24,
25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
- 35, 36, 3, 4, 5, 6, 11, 12, 13, 14,
- 15, 16, 17, 18, 19, 16, 17, 22, 16, 17,
- 18, 19, 24, 24, 22, 26, 61, 62, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
- 19, 18, 19, 22, 24, 22, 22, 0, 27, 7,
- 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
- 18, 19, 24, 22, 22, 59, -1, 25, 7, 8,
+ 35, 36, 37, 3, 4, 5, 6, 12, 13, 14,
+ 15, 16, 17, 18, 19, 20, 16, 17, 23, 16,
+ 17, 18, 19, 20, 25, 25, 23, 27, 63, 64,
+ 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 25, 23, 23, 0, 23, 25,
+ 61, 28, 7, 8, 9, 10, 11, 12, 13, 14,
+ 15, 16, 17, 18, 19, 20, -1, -1, 23, -1,
+ -1, 26, 7, 8, 9, 10, 11, 12, 13, 14,
+ 15, 16, 17, 18, 19, 20, -1, -1, 23, -1,
+ -1, 26, 7, 8, 9, 10, 11, 12, 13, 14,
+ 15, 16, 17, 18, 19, 20, -1, -1, 23, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
- 19, -1, -1, 22, -1, -1, 25, 7, 8, 9,
+ 19, 20, -1, -1, 23, 9, 10, 11, 12, 13,
+ 14, 15, 16, 17, 18, 19, 20, -1, -1, 23,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
- -1, -1, 22, 8, 9, 10, 11, 12, 13, 14,
- 15, 16, 17, 18, 19, -1, -1, 22, 9, 10,
- 11, 12, 13, 14, 15, 16, 17, 18, 19, -1,
- -1, 22, 10, 11, 12, 13, 14, 15, 16, 17,
- 18, 19, -1, -1, 22, 12, 13, 14, 15, 16,
- 17, 18, 19, -1, -1, 22, 13, 14, 15, 16,
- 17, 18, 19, -1, -1, 22, 15, 16, 17, 18,
- 19, -1, -1, 22
+ 20, -1, -1, 23, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, -1, -1, 23, 13, 14, 15,
+ 16, 17, 18, 19, 20, -1, -1, 23, 14, 15,
+ 16, 17, 18, 19, 20, -1, -1, 23, 18, 19,
+ 20, -1, -1, 23
};
const signed char
parser::yystos_[] =
{
- 0, 3, 4, 5, 6, 16, 17, 24, 26, 29,
- 30, 31, 32, 35, 24, 24, 22, 30, 30, 30,
- 30, 0, 7, 8, 9, 10, 11, 12, 13, 14,
- 15, 16, 17, 18, 19, 22, 24, 33, 34, 30,
- 36, 30, 25, 30, 30, 30, 30, 30, 30, 30,
- 30, 30, 30, 30, 30, 30, 30, 30, 25, 27,
- 25, 27, 27, 33, 30, 30, 25
+ 0, 3, 4, 5, 6, 16, 17, 25, 27, 30,
+ 31, 32, 33, 36, 25, 25, 23, 31, 31, 31,
+ 31, 0, 7, 8, 9, 10, 11, 12, 13, 14,
+ 15, 16, 17, 18, 19, 20, 23, 25, 34, 35,
+ 31, 37, 31, 26, 31, 31, 31, 31, 31, 31,
+ 31, 31, 31, 31, 31, 31, 31, 31, 31, 31,
+ 26, 28, 26, 28, 28, 34, 31, 31, 26
};
const signed char
parser::yyr1_[] =
{
- 0, 28, 29, 30, 30, 30, 30, 30, 30, 30,
- 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
- 30, 30, 30, 31, 31, 31, 31, 31, 32, 33,
- 34, 34, 35, 36, 36
+ 0, 29, 30, 31, 31, 31, 31, 31, 31, 31,
+ 31, 31, 31, 31, 31, 31, 31, 31, 31, 31,
+ 31, 31, 31, 31, 32, 32, 32, 32, 32, 33,
+ 34, 35, 35, 36, 37, 37
};
const signed char
parser::yyr2_[] =
{
0, 2, 1, 3, 3, 3, 3, 3, 3, 3,
- 3, 3, 3, 3, 3, 3, 3, 3, 3, 2,
- 2, 2, 1, 1, 1, 1, 1, 1, 4, 5,
- 3, 1, 4, 3, 1
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 2, 2, 2, 1, 1, 1, 1, 1, 1, 4,
+ 5, 3, 1, 4, 3, 1
};
@@ -1398,9 +1409,9 @@ namespace yy {
{
"END_OF_FILE", "error", "\"invalid token\"", "PIECEWISE", "IDENTIFIER",
"NUMERIC", "IMPLICIT_MUL", "'|'", "'^'", "'&'", "EQ", "'>'", "'<'", "NE",
- "LE", "GE", "'-'", "'+'", "'*'", "'/'", "UMINUS", "UPLUS", "POW", "NOT",
- "'('", "')'", "'~'", "','", "$accept", "st_expr", "expr", "leaf", "func",
- "epair", "piecewise_list", "pwise", "expr_list", YY_NULLPTR
+ "LE", "GE", "'-'", "'+'", "'*'", "'/'", "'%'", "UMINUS", "UPLUS", "POW",
+ "NOT", "'('", "')'", "'~'", "','", "$accept", "st_expr", "expr", "leaf",
+ "func", "epair", "piecewise_list", "pwise", "expr_list", YY_NULLPTR
};
#endif
@@ -1409,10 +1420,10 @@ namespace yy {
const short
parser::yyrline_[] =
{
- 0, 103, 103, 111, 114, 117, 120, 125, 135, 138,
- 141, 144, 147, 150, 153, 156, 164, 172, 180, 183,
- 186, 189, 192, 197, 202, 208, 213, 218, 225, 233,
- 245, 251, 258, 267, 273
+ 0, 104, 104, 112, 115, 118, 121, 124, 129, 139,
+ 142, 145, 148, 151, 154, 157, 160, 168, 176, 184,
+ 187, 190, 193, 196, 201, 206, 212, 217, 222, 229,
+ 237, 249, 255, 262, 271, 277
};
void
@@ -1454,8 +1465,8 @@ namespace yy {
0, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2, 2, 2, 2, 2, 9, 2,
- 24, 25, 18, 17, 27, 16, 2, 19, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 20, 9, 2,
+ 25, 26, 18, 17, 28, 16, 2, 19, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
12, 2, 11, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
@@ -1463,7 +1474,7 @@ namespace yy {
2, 2, 2, 2, 8, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2, 7, 2, 26, 2, 2, 2,
+ 2, 2, 2, 2, 7, 2, 27, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
@@ -1477,7 +1488,7 @@ namespace yy {
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 1, 2, 3, 4,
- 5, 6, 10, 13, 14, 15, 20, 21, 22, 23
+ 5, 6, 10, 13, 14, 15, 21, 22, 23, 24
};
// Last valid token kind.
const int code_max = 269;
@@ -1491,5 +1502,5 @@ namespace yy {
}
} // yy
-#line 1495 "parser.tab.cc"
+#line 1506 "parser.tab.cc"
@@ -474,7 +474,7 @@ namespace yy {
{
enum symbol_kind_type
{
- YYNTOKENS = 28, ///< Number of tokens.
+ YYNTOKENS = 29, ///< Number of tokens.
S_YYEMPTY = -2,
S_YYEOF = 0, // END_OF_FILE
S_YYerror = 1, // error
@@ -496,23 +496,24 @@ namespace yy {
S_17_ = 17, // '+'
S_18_ = 18, // '*'
S_19_ = 19, // '/'
- S_UMINUS = 20, // UMINUS
- S_UPLUS = 21, // UPLUS
- S_POW = 22, // POW
- S_NOT = 23, // NOT
- S_24_ = 24, // '('
- S_25_ = 25, // ')'
- S_26_ = 26, // '~'
- S_27_ = 27, // ','
- S_YYACCEPT = 28, // $accept
- S_st_expr = 29, // st_expr
- S_expr = 30, // expr
- S_leaf = 31, // leaf
- S_func = 32, // func
- S_epair = 33, // epair
- S_piecewise_list = 34, // piecewise_list
- S_pwise = 35, // pwise
- S_expr_list = 36 // expr_list
+ S_20_ = 20, // '%'
+ S_UMINUS = 21, // UMINUS
+ S_UPLUS = 22, // UPLUS
+ S_POW = 23, // POW
+ S_NOT = 24, // NOT
+ S_25_ = 25, // '('
+ S_26_ = 26, // ')'
+ S_27_ = 27, // '~'
+ S_28_ = 28, // ','
+ S_YYACCEPT = 29, // $accept
+ S_st_expr = 30, // st_expr
+ S_expr = 31, // expr
+ S_leaf = 32, // leaf
+ S_func = 33, // func
+ S_epair = 34, // epair
+ S_piecewise_list = 35, // piecewise_list
+ S_pwise = 36, // pwise
+ S_expr_list = 37 // expr_list
};
};
@@ -1389,7 +1390,7 @@ switch (yykind)
/// Constants.
enum
{
- yylast_ = 203, ///< Last index in yytable_.
+ yylast_ = 223, ///< Last index in yytable_.
yynnts_ = 9, ///< Number of nonterminal symbols.
yyfinal_ = 21 ///< Termination state number.
};
@@ -1402,7 +1403,7 @@ switch (yykind)
} // yy
-#line 1406 "parser.tab.hh"
+#line 1407 "parser.tab.hh"
@@ -21,6 +21,7 @@
%code // *.cpp
{
#include "symengine/pow.h"
+#include "symengine/mod.h"
#include "symengine/logic.h"
#include "symengine/parser/parser.h"
#include "symengine/utilities/stream_fmt.h"
@@ -31,6 +32,7 @@ using SymEngine::vec_basic;
using SymEngine::rcp_static_cast;
using SymEngine::mul;
using SymEngine::pow;
+using SymEngine::mod;
using SymEngine::add;
using SymEngine::sub;
using SymEngine::Lt;
@@ -79,7 +81,7 @@ void parser::error(const std::string &msg)
%left LE
%left GE
%left '-' '+'
-%left '*' '/'
+%left '*' '/' '%'
%right UMINUS
%right UPLUS
%right POW
@@ -115,6 +117,9 @@ expr:
|
expr '*' expr
{ $$ = mul($1, $3); }
+|
+ expr '%' expr
+ { $$ = mod($1, $3); }
|
expr '/' expr
{ $$ = div($1, $3); }
@@ -52,7 +52,7 @@ int Tokenizer::lex(yy::parser::semantic_type *yylval)
}
if (yych <= '>') {
if (yych <= '*') {
- if (yych <= '%') {
+ if (yych <= '$') {
if (yych <= 0x00)
goto yy1;
if (yych <= 0x1F)
@@ -28,9 +28,10 @@ int Tokenizer::lex(yy::parser::semantic_type* yylval)
whitespace = [ \t\v\n\r]+;
dig = [0-9];
char = [\x80-\xff] | [a-zA-Z_];
- operators = "-"|"+"|"/"|"("|")"|"*"|","|"^"|"~"|"<"|">"|"&"|"|";
+ operators = "-"|"+"|"/"|"("|")"|"*"|","|"^"|"~"|"<"|">"|"&"|"|"|"%";
pows = "**"|"@";
+ mods = "%";
le = "<=";
ge = ">=";
ne = "!=";
@@ -47,6 +48,7 @@ int Tokenizer::lex(yy::parser::semantic_type* yylval)
// FIXME:
operators { return tok[0]; }
pows { return yy::parser::token::yytokentype::POW; }
+ mods { return yy::parser::token::yytokentype::MOD; }
le { return yy::parser::token::yytokentype::LE; }
ge { return yy::parser::token::yytokentype::GE; }
ne { return yy::parser::token::yytokentype::NE; }
@@ -45,6 +45,11 @@ void Precedence::bvisit(const Pow &x)
precedence = PrecedenceEnum::Pow;
}
+void Precedence::bvisit(const Mod &x)
+{
+ precedence = PrecedenceEnum::Mod;
+}
+
void Precedence::bvisit(const GaloisField &x)
{
// iterators need to be implemented
@@ -564,6 +569,14 @@ void StrPrinter::_print_pow(std::ostringstream &o, const RCP<const Basic> &a,
}
}
+void StrPrinter::_print_mod(std::ostringstream &o, const RCP<const Basic> &a,
+ const RCP<const Basic> &b)
+{
+ o << parenthesizeLE(a, PrecedenceEnum::Mod);
+ o << "%";
+ o << parenthesizeLE(b, PrecedenceEnum::Mod);
+}
+
void StrPrinter::bvisit(const Mul &x)
{
std::ostringstream o, o2;
@@ -655,6 +668,13 @@ void StrPrinter::bvisit(const Pow &x)
str_ = o.str();
}
+void StrPrinter::bvisit(const Mod &x)
+{
+ std::ostringstream o;
+ _print_mod(o, x.get_dividend(), x.get_divisor());
+ str_ = o.str();
+}
+
template <typename T>
char _print_sign(const T &i)
{
@@ -10,7 +10,7 @@ namespace SymEngine
std::string print_double(double d);
std::vector<std::string> init_str_printer_names();
-enum class PrecedenceEnum { Relational, Add, Mul, Pow, Atom };
+enum class PrecedenceEnum { Relational, Add, Mul, Pow, Mod, Atom};
class Precedence : public BaseVisitor<Precedence>
{
@@ -22,6 +22,7 @@ public:
void bvisit(const Mul &x);
void bvisit(const Relational &x);
void bvisit(const Pow &x);
+ void bvisit(const Mod &x);
template <typename Poly>
void bvisit_upoly(const Poly &x)
{
@@ -114,6 +115,8 @@ protected:
virtual bool split_mul_coef();
virtual void _print_pow(std::ostringstream &o, const RCP<const Basic> &a,
const RCP<const Basic> &b);
+ virtual void _print_mod(std::ostringstream &o, const RCP<const Basic> &a,
+ const RCP<const Basic> &b);
virtual std::string print_div(const std::string &num,
const std::string &den, bool paren);
virtual std::string get_imag_symbol();
@@ -154,6 +157,7 @@ public:
void bvisit(const Add &x);
void bvisit(const Mul &x);
void bvisit(const Pow &x);
+ void bvisit(const Mod &x);
void bvisit(const UIntPoly &x);
void bvisit(const MIntPoly &x);
void bvisit(const URatPoly &x);
@@ -483,6 +483,55 @@ public:
throw NotImplementedError("Not Implemented");
}
}
+
+ /*! Raise `other` to mod RealDouble
+ * \param other of type Integer
+ * */
+ RCP<const Number> modreal(const Integer &other) const
+ {
+ double a = i;
+ double b = mp_get_d(other.as_integer_class());
+ return number(a - b * (std::floor(a / b)));
+ }
+
+ /*! Raise `other` to mod RealDouble
+ * \param other of type Rational
+ * */
+ RCP<const Number> modreal(const Rational &other) const
+ {
+ double a = i;
+ double b = mp_get_d(other.as_rational_class());
+ return number(a - b * (std::floor(a / b)));
+ }
+
+ /*! Raise `other` to mod RealDouble
+ * \param other of type Complex
+ * */
+ RCP<const Number> modreal(const Complex &other) const
+ {
+ auto a = i;
+ auto b = std::complex<double>(mp_get_d(other.real_),
+ mp_get_d(other.imaginary_)).real();
+ return number(a - b * (std::floor(a / b)));
+ }
+
+ RCP<const Number> mod(const Number &other) const override
+ {
+ if (is_a<Rational>(other)) {
+ return modreal(down_cast<const Rational &>(other));
+ } else if (is_a<Integer>(other)) {
+ return modreal(down_cast<const Integer &>(other));
+ } else if (is_a<Complex>(other)) {
+ return modreal(down_cast<const Complex &>(other));
+ } else {
+ throw NotImplementedError("Not Implemented");
+ }
+ }
+
+ RCP<const Number> rmod(const Number &other) const override
+ {
+ return other.mod(*this);
+ }
};
RCP<const RealDouble> real_double(double x);
@@ -51,6 +51,60 @@ void SimplifyVisitor::bvisit(const Mul &x)
result_ = Mul::from_dict(x.get_coef(), std::move(map));
}
+void SimplifyVisitor::bvisit(const Mod &x)
+{
+ RCP<const Basic> a = x.get_dividend();
+ RCP<const Basic> b = x.get_divisor();
+ apply(a); // 递归化简 a
+ apply(b); // 递归化简 b
+
+ // --- 规则 1:乘法分配律 (k * x) mod m → Mod((k mod m) * (x mod m), m) ---
+ if (is_a<Mul>(*a)) {
+ RCP<const Mul> mul_a = rcp_static_cast<const Mul>(a);
+ RCP<const Basic> new_coef = mod(mul_a->get_coef(), b);
+ map_basic_basic new_dict;
+ for (const auto &term : mul_a->get_dict()) {
+ new_dict[term.first] = mod(term.second, b);
+ }
+ result_ = mod(mul(new_coef, Mul::from_dict(one, std::move(new_dict))), b);
+ return;
+ }
+
+ // --- 规则 2:加法分配律 (a + b) mod m → Mod((a mod m) + (b mod m), m) ---
+ if (is_a<Add>(*a)) {
+ RCP<const Add> add_a = rcp_static_cast<const Add>(a);
+ RCP<const Basic> r = zero; // 剩余部分 r
+ RCP<const Number> k = zero; // 系数 k
+
+ // 遍历加法项,提取 k 和 r
+ for (const auto &term : add_a->get_dict()) {
+ if (is_a<Mul>(*term.first)) {
+ RCP<const Mul> mul_term = rcp_static_cast<const Mul>(term.first);
+ if (eq(*mul_term->get_coef(), *b)) { // 项是 k*b
+ k = addnum(k, rcp_static_cast<const Number>(term.second));
+ } else {
+ r = add(r, mul(term.first, term.second));
+ }
+ } else if (eq(*term.first, *b)) { // 项是 b
+ k = addnum(k, rcp_static_cast<const Number>(term.second));
+ } else {
+ r = add(r, mul(term.first, term.second));
+ }
+ }
+ // 处理常数项
+ r = add(r, add_a->get_coef());
+
+ // 化简为 mod(r, b)
+ if (!k->is_zero()) {
+ result_ = mod(r, b);
+ return;
+ }
+ }
+
+ // 默认情况:无法优化,返回原式
+ result_ = x.rcp_from_this();
+}
+
RCP<const Basic> simplify(const RCP<const Basic> &x,
const Assumptions *assumptions)
{
@@ -21,6 +21,7 @@ public:
void bvisit(const Mul &x);
void bvisit(const Pow &x);
+ void bvisit(const Mod &x);
void bvisit(const OneArgFunction &x);
};
@@ -136,6 +136,16 @@ public:
}
}
+ void bvisit(const Mod &x) {
+ RCP<const Basic> dividend_new = apply(x.get_dividend());
+ RCP<const Basic> divisor_new = apply(x.get_divisor());
+ if (dividend_new == x.get_dividend() and divisor_new == x.get_divisor()) {
+ result_ = x.rcp_from_this();
+ } else {
+ result_ = mod(dividend_new, divisor_new);
+ }
+ }
+
void bvisit(const OneArgFunction &x)
{
apply(x.get_arg());
@@ -386,6 +396,28 @@ public:
}
}
+ void bvisit(const Mod &x) {
+ // 1. 替换被模数 a
+ RCP<const Basic> a = apply(x.get_dividend());
+
+ // 2. 替换模数 m
+ RCP<const Basic> m = apply(x.get_divisor());
+
+ // 3. 如果替换后 a 和 m 都是数值,直接计算
+ if (is_a_Number(*a) && is_a_Number(*m)) {
+ RCP<const Number> a_num = rcp_static_cast<const Number>(a);
+ RCP<const Number> m_num = rcp_static_cast<const Number>(m);
+ if (m_num->is_zero()) {
+ throw std::runtime_error("Modulo by zero");
+ }
+ result_ = a_num->mod(*m_num);
+ return;
+ }
+
+ // 4. 默认情况:重建 Mod 对象
+ result_ = mod(a, m);
+ }
+
void bvisit(const Derivative &x)
{
RCP<const Symbol> s;
@@ -543,6 +543,19 @@ void RealVisitor::bvisit(const Pow &x)
this->check_power(x.get_base(), x.get_exp());
}
+void RealVisitor::bvisit(const Mod &x)
+{
+ tribool b = tribool::tritrue;
+ for (const auto &arg : x.get_args()) {
+ arg->accept(*this);
+ b = andwk_tribool(b, is_real_);
+ if (is_indeterminate(b)) {
+ break;
+ }
+ }
+ is_real_ = b;
+}
+
tribool RealVisitor::apply(const Basic &b)
{
b.accept(*this);
@@ -786,6 +799,25 @@ void PolynomialVisitor::bvisit(const Pow &x)
check_power(*x.get_base(), *x.get_exp());
}
+void PolynomialVisitor::bvisit(const Mod &x)
+{
+ RCP<const Basic> p = x.get_dividend(); // 被模数 p(x)
+ RCP<const Basic> q = x.get_divisor(); // 模数 q(x)
+
+ bool q_is_poly = is_polynomial(*q, variables_);
+
+ // 情况1:模数 q 是常数,且 p 是多项式
+ if (q_is_poly && is_a_Number(*q)) {
+ RCP<const Number> q_num = rcp_static_cast<const Number>(q);
+ if (q_num->is_zero()) {
+ throw std::runtime_error("Modulo by zero");
+ }
+ is_polynomial_ = true;
+ return;
+ }
+ is_polynomial_ = false;
+}
+
void PolynomialVisitor::bvisit(const Symbol &x)
{
if (variables_allowed_)
@@ -201,6 +201,7 @@ public:
void bvisit(const Constant &x);
void bvisit(const Add &x);
void bvisit(const Mul &x);
+ void bvisit(const Mod &x);
void bvisit(const Pow &x);
tribool apply(const Basic &b);
@@ -335,6 +336,7 @@ public:
void bvisit(const Add &x);
void bvisit(const Mul &x);
void bvisit(const Pow &x);
+ void bvisit(const Mod &x);
void bvisit(const Set &x)
{
is_polynomial_ = false;
@@ -4,6 +4,11 @@ add_executable(${PROJECT_NAME} test_basic.cpp)
target_link_libraries(${PROJECT_NAME} symengine catch)
add_test(${PROJECT_NAME} ${PROJECT_BINARY_DIR}/${PROJECT_NAME})
+add_executable(test_mod test_mod.cpp)
+target_link_libraries(test_mod symengine catch)
+add_test(test_mod ${PROJECT_BINARY_DIR}/test_mod)
+add_compile_options(-O0 -ggdb3 -Wall -Wextra -Wpedantic)
+
add_executable(test_arit test_arit.cpp)
target_link_libraries(test_arit symengine catch)
add_test(test_arit ${PROJECT_BINARY_DIR}/test_arit)
new file mode 100644
@@ -0,0 +1,269 @@
+#include "catch.hpp"
+#include <symengine/basic.h>
+#include <symengine/visitor.h>
+#include <symengine/simplify.h>
+#include <symengine/test_visitors.h>
+#include <symengine/parser.h>
+#include <symengine/parser/parser.h>
+
+using SymEngine::Assumptions;
+using SymEngine::Basic;
+using SymEngine::Complex;
+using SymEngine::down_cast;
+using SymEngine::E;
+using SymEngine::integer;
+using SymEngine::Integer;
+using SymEngine::is_a;
+using SymEngine::map_basic_basic;
+using SymEngine::Mod;
+using SymEngine::Number;
+using SymEngine::one;
+using SymEngine::parse;
+using SymEngine::pi;
+using SymEngine::Rational;
+using SymEngine::RCP;
+using SymEngine::reals;
+using SymEngine::Symbol;
+using SymEngine::symbol;
+using SymEngine::zero;
+
+//测试用例分类
+// 1.基础测试用例
+// 2.parse测试用例
+// 3.Simplify测试用例
+// 4.subs测试用例
+// 5.visitors测试用例
+TEST_CASE("mod: Basic", "[basic]")
+{
+ RCP<const Basic> x = symbol("x");
+ RCP<const Basic> y = symbol("y");
+ RCP<const Basic> two = integer(2);
+ RCP<const Basic> three = integer(3);
+ REQUIRE(mod(integer(5), integer(3))->compare(*two) == 0);
+ REQUIRE(mod(mul(integer(6), x), two) == zero);
+ REQUIRE(mod(mul(integer(2), x), three)->__str__() == "(2*x)%3");
+ REQUIRE(mod(add(x, two), three)->__str__() == "(2 + x)%3");
+ REQUIRE(mod(pow(x, two), three)->__str__() == "(x**2)%3");
+ REQUIRE(mod(x, integer(2))->__str__() == "x%2");
+ REQUIRE(mod(mul(x, integer(6)), integer(2)) == zero);
+ REQUIRE(mod(add(x, integer(1)), integer(1)) == zero);
+ REQUIRE_THROWS_AS(mod(x, zero), SymEngine::SymEngineException);
+ auto exp1 = mod(mul(symbol("R"),symbol("nio")), integer(16));
+ REQUIRE(exp1->__str__() == "(R*nio)%16");
+ auto exp2 = mod(mul(mul(integer(3), symbol("R")),symbol("nio")), integer(16));
+ REQUIRE(exp2->__str__() == "(3*R*nio)%16");
+ auto exp3 = mod(mul(integer(3), symbol("R")), integer(16));
+ REQUIRE(exp3->__str__() == "(3*R)%16");
+ auto exp4 = mod(mul(integer(48), symbol("R")), integer(16));
+ REQUIRE(exp4->__str__() == "0");
+ auto exp6 = mod(mul(integer(33), symbol("R")), integer(16));
+ REQUIRE(exp6->__str__() == "R%16");
+ auto exp7 = mod(mul(integer(35), symbol("R")), integer(16));
+ REQUIRE(exp7->__str__() == "(3*R)%16");
+}
+
+TEST_CASE("mod: parse", "[parse]")
+{
+ // apt-get install bison re2c # bison最好是3.8.2,re2c最好是3.0版本
+ // bison -Wcounterexamples -d parser.yy -o parser.tab.cc
+ // re2c -W --no-generation-date -b tokenizer.re -o tokenizer.cpp
+ std::string s;
+ RCP<const Basic> res;
+ RCP<const Basic> x = symbol("x");
+ RCP<const Basic> y = symbol("y");
+ RCP<const Basic> w = symbol("w1");
+ RCP<const Basic> l = symbol("l0ngn4me");
+
+ s = "x%(3+w1)-2/y";
+ res = parse(s);
+ REQUIRE(eq(*res, *add(mod(x, add(integer(3), w)), div(integer(-2), y))));
+ REQUIRE(eq(*res, *parse(res->__str__())));
+
+ s = "l0ngn4me - w1*y + 2%(x)";
+ res = parse(s);
+ REQUIRE(eq(*res, *add(add(l, neg(mul(w, y))), mod(integer(2), x))));
+ REQUIRE(eq(*res, *parse(res->__str__())));
+
+ s = "x % --y";
+ res = parse(s);
+ REQUIRE(eq(*res, *mod(x, y)));
+ REQUIRE(eq(*res, *parse(res->__str__())));
+}
+
+TEST_CASE("Simplify Mod", "[simplify]")
+{
+ RCP<const Basic> x = symbol("x");
+ RCP<const Basic> y = symbol("y");
+
+ // 常量折叠
+ REQUIRE(eq(*simplify(mod(integer(7), integer(3))), *integer(1)));
+
+ // 特殊规则
+ REQUIRE(eq(*simplify(mod(x, integer(1))), *zero));
+ REQUIRE(eq(*simplify(mod(zero, x)), *zero));
+
+ // 线性化简
+ RCP<const Basic> expr1 = mod(add(x, mul(integer(2), y)), y);
+ REQUIRE(eq(*simplify(expr1), *mod(x, y)));
+
+ // 无法化简的情况
+ RCP<const Basic> expr2 = mod(x, y);
+ REQUIRE(eq(*simplify(expr2), *expr2));
+}
+
+TEST_CASE("Mod: subs", "[subs]")
+{
+ RCP<const Symbol> x = symbol("x");
+ RCP<const Symbol> y = symbol("y");
+ RCP<const Basic> z = symbol("z");
+ RCP<const Basic> w = symbol("w");
+ RCP<const Basic> i2 = integer(2);
+ RCP<const Basic> i3 = integer(3);
+ RCP<const Basic> i4 = integer(4);
+
+ RCP<const Basic> r1 = mod(x, y);
+ RCP<const Basic> r2 = mod(y, y);
+ map_basic_basic d;
+ d[x] = y;
+ REQUIRE(eq(*r1->subs(d), *r2));
+
+ d[x] = z;
+ d[y] = w;
+ r1 = mod(x, y);
+ r2 = mod(z, w);
+ REQUIRE(eq(*r1->subs(d), *r2));
+
+ r1 = mod(x, i2);
+ r2 = mod(z, i2);
+ REQUIRE(eq(*r1->subs(d), *r2));
+
+ d.clear();
+ d[mod(x, y)] = z;
+ r1 = mod(x, y);
+ r2 = z;
+ REQUIRE(eq(*r1->subs(d), *r2));
+
+ d.clear();
+ d[mod(E, x)] = z;
+ r1 = mod(E, mul(x, x));
+ r2 = r1->subs(d);
+ REQUIRE(is_a<Mod>(*r2));
+ REQUIRE(eq(*down_cast<const Mod &>(*r2).get_dividend(), *E));
+ REQUIRE(eq(*down_cast<const Mod &>(*r2).get_divisor(), *mul(x, x)));
+
+ r2 = r1->xreplace(d);
+ REQUIRE(eq(*r1, *r2));
+
+ r1 = mod(E, mul(i2, x));
+
+ r2 = r1->xreplace(d);
+ REQUIRE(eq(*r1, *r2));
+
+ r1 = mod(E, add(i2, x));
+ r2 = r1->subs(d);
+ REQUIRE(eq(*r1, *r2));
+}
+
+TEST_CASE("Mod: visitors -> is_real", "[visitors]")
+{
+ // todo: Complex mod not support
+ // is_real
+ RCP<const Number> i1 = integer(0);
+ RCP<const Number> i2 = integer(3);
+ RCP<const Number> c1 = Complex::from_two_nums(*i1, *i2);
+ RCP<const Basic> x = symbol("x");
+ RCP<const Basic> y = symbol("y");
+ RCP<const Basic> e5 = mod(integer(2), x);
+ RCP<const Basic> e6 = mod(integer(-1), x);
+ RCP<const Basic> e7 = mul(x, y);
+ // RCP<const Basic> e8 = mul(c1, x);
+ // RCP<const Basic> e9 = mod(i2, c1);
+ REQUIRE(is_indeterminate(is_real(*e5)));
+ REQUIRE(is_indeterminate(is_real(*e6)));
+ // REQUIRE(is_indeterminate(is_real(*e9)));
+
+ const auto a4 = Assumptions({reals()->contains(x)});
+ REQUIRE(is_true(is_real(*e5, &a4)));
+}
+
+TEST_CASE("Functions", "[ccode]")
+{
+ auto x = symbol("x");
+ auto p = function_symbol("f", mod(integer(2), x));
+ REQUIRE(ccode(*p) == "f(2%x)");
+ p = function_symbol("f", mod(integer(2), x));
+ REQUIRE(jscode(*p) == "f(2%x)");
+}
+
+TEST_CASE("Mod: visitors -> diff", "[visitors]")
+{
+ RCP<const Symbol> x = symbol("x");
+ RCP<const Symbol> y = symbol("y");
+ RCP<const Integer> m = integer(3);
+
+ // 情况1:模数是常数
+ RCP<const Basic> expr1 = mod(x, m);
+ REQUIRE(diff(expr1, x, true) == one);
+
+ // 情况2:模数是变量
+ RCP<const Basic> expr2 = mod(x, y);
+ REQUIRE(diff(expr2, x, true)->__str__() == "Derivative(x%y, x)");
+}
+
+TEST_CASE("Mod: visitors -> is_polynomial", "[visitors]")
+{
+ RCP<const Basic> x = symbol("x");
+ RCP<const Basic> y = symbol("y");
+ RCP<const Basic> z = symbol("z");
+ RCP<const Number> i1 = integer(1);
+ RCP<const Number> i2 = integer(3);
+ RCP<const Number> i3 = integer(-2);
+ RCP<const Basic> rat1 = Rational::from_two_ints(*integer(5), *integer(6));
+ RCP<const Basic> e7 = mod(x, y);
+ RCP<const Basic> e8 = mod(x, i2);
+ RCP<const Basic> e9 = mod(x, rat1);
+ RCP<const Basic> e10 = mod(integer(2), x);
+ RCP<const Basic> e11 = div(i1, x);
+ RCP<const Basic> e12 = mod(x, i3);
+ RCP<const Basic> e13 = mod(x, x);
+ RCP<const Basic> e14 = mod(x, mul(i3, y));
+ RCP<const Basic> e15 = mod(mul(i2, x), y);
+ RCP<const Basic> e16 = add(add(mul(x, x), mul(i2, x)), i3);
+ RCP<const Basic> e18 = mod(sqrt(x), i2);
+ RCP<const Basic> e21
+ = add(add(mul(mul(x, x), mul(y, y)), mul(x, mod(y, i2))), exp(i3));
+ RCP<const Basic> e22
+ = add(add(mul(mul(x, x), mul(y, y)), mul(x, mod(y, i2))), exp(x));
+
+ REQUIRE(is_polynomial(*x));
+ REQUIRE(is_polynomial(*i1));
+ REQUIRE(is_polynomial(*rat1));
+ REQUIRE(is_polynomial(*pi));
+ REQUIRE(!is_polynomial(*e7));
+ REQUIRE(is_polynomial(*e8));
+ REQUIRE(is_polynomial(*e9));
+ REQUIRE(!is_polynomial(*e10));
+ REQUIRE(!is_polynomial(*e11));
+ REQUIRE(is_polynomial(*rat1, {x, y, z}));
+ REQUIRE(is_polynomial(*pi, {x, y, z}));
+ REQUIRE(is_polynomial(*x, {x}));
+ REQUIRE(is_polynomial(*x, {y}));
+ REQUIRE(is_polynomial(*e8, {x}));
+ REQUIRE(is_polynomial(*e12, {x}));
+ REQUIRE(is_polynomial(*e12, {y}));
+ REQUIRE(!is_polynomial(*e10, {x}));
+ REQUIRE(!is_polynomial(*e10, {y}));
+ REQUIRE(!is_polynomial(*e7, {x}));
+ REQUIRE(!is_polynomial(*e7, {y}));
+ REQUIRE(!is_polynomial(*e13, {x}));
+ REQUIRE(!is_polynomial(*e14, {x}));
+ REQUIRE(!is_polynomial(*e15, {x}));
+ REQUIRE(is_polynomial(*e16, {x}));
+ REQUIRE(is_polynomial(*e16, {y}));
+ REQUIRE(is_polynomial(*e16));
+ REQUIRE(is_polynomial(*e18, {x}));
+ REQUIRE(is_polynomial(*e21));
+ REQUIRE(!is_polynomial(*e22));
+ REQUIRE(is_polynomial(*e21, {x, y}));
+ REQUIRE(!is_polynomial(*e22, {x, y}));
+}
\ No newline at end of file
@@ -137,3 +137,4 @@ SYMENGINE_ENUM(SYMENGINE_CONJUGATEMATRIX, ConjugateMatrix)
SYMENGINE_ENUM(SYMENGINE_TRANSPOSE, Transpose)
// Transpose must be the last MatrixExpr
SYMENGINE_ENUM(SYMENGINE_UNEVALUATED_EXPR, UnevaluatedExpr)
+SYMENGINE_ENUM(SYMENGINE_MOD, Mod)
\ No newline at end of file
@@ -179,6 +179,11 @@ void TransformVisitor::bvisit(const Pow &x)
}
}
+void TransformVisitor::bvisit(const Mod &x)
+{
+ result_ = mod(apply(x.get_dividend()), apply(x.get_divisor()));
+}
+
void TransformVisitor::bvisit(const OneArgFunction &x)
{
auto farg = x.get_arg();
@@ -286,6 +291,13 @@ void CountOpsVisitor::bvisit(const Pow &x)
apply(*x.get_base());
}
+void CountOpsVisitor::bvisit(const Mod &x)
+{
+ count++;
+ apply(*x.get_dividend());
+ apply(*x.get_divisor());
+}
+
void CountOpsVisitor::bvisit(const Number &x) {}
void CountOpsVisitor::bvisit(const ComplexBase &x)
@@ -182,6 +182,17 @@ public:
}
}
+ void bvisit(const Mod &x)
+ {
+ if (eq(*x.get_dividend(), *x_) and eq(*x.get_divisor(), *n_)) {
+ coeff_ = one;
+ } else if (neq(*x.get_dividend(), *x_) and eq(*zero, *n_)) {
+ coeff_ = x.rcp_from_this();
+ } else {
+ coeff_ = zero;
+ }
+ }
+
void bvisit(const Symbol &x)
{
if (eq(x, *x_) and eq(*one, *n_)) {
@@ -247,6 +258,7 @@ public:
void bvisit(const Add &x);
void bvisit(const Mul &x);
void bvisit(const Pow &x);
+ void bvisit(const Mod &x);
void bvisit(const OneArgFunction &x);
template <class T>
@@ -328,6 +340,7 @@ public:
void bvisit(const Mul &x);
void bvisit(const Add &x);
void bvisit(const Pow &x);
+ void bvisit(const Mod &x);
void bvisit(const Number &x);
void bvisit(const ComplexBase &x);
void bvisit(const Symbol &x);
@@ -75,7 +75,7 @@ RCP<const Basic> mod(const RCP<const Basic> ÷nd,
}
// 数值情况直接计算
- if (is_a_Number(*dividend) && is_a_Number(*divisor)) {
+ if (is_a<Integer>(*dividend) && is_a<Integer>(*divisor)) {
RCP<const Number> a_num = rcp_static_cast<const Number>(dividend);
RCP<const Number> m_num = rcp_static_cast<const Number>(divisor);
return a_num->mod(*m_num);
@@ -5,6 +5,7 @@
#include <symengine/test_visitors.h>
#include <symengine/parser.h>
#include <symengine/parser/parser.h>
+#include <symengine/rational.h>
using SymEngine::Assumptions;
using SymEngine::Basic;
@@ -26,6 +27,7 @@ using SymEngine::reals;
using SymEngine::Symbol;
using SymEngine::symbol;
using SymEngine::zero;
+using SymEngine::rational;
//测试用例分类
// 1.基础测试用例
@@ -60,6 +62,11 @@ TEST_CASE("mod: Basic", "[basic]")
REQUIRE(exp6->__str__() == "R%16");
auto exp7 = mod(mul(integer(35), symbol("R")), integer(16));
REQUIRE(exp7->__str__() == "(3*R)%16");
+
+ RCP<const Basic> ra = rational(2, 10);
+ RCP<const Basic> rb = rational(3, 10);
+ auto rmod = mod(ra, rb);
+ REQUIRE(rmod->__str__() == "(1/5)%(3/10)");
}
TEST_CASE("mod: parse", "[parse]")
@@ -75,7 +75,8 @@ RCP<const Basic> mod(const RCP<const Basic> ÷nd,
}
// 数值情况直接计算
- if (is_a<Integer>(*dividend) && is_a<Integer>(*divisor)) {
+ if ((is_a<Integer>(*dividend) || is_a<Rational>(*dividend))
+ && is_a<Integer>(*divisor)) {
RCP<const Number> a_num = rcp_static_cast<const Number>(dividend);
RCP<const Number> m_num = rcp_static_cast<const Number>(divisor);
return a_num->mod(*m_num);
@@ -318,6 +318,31 @@ public:
{
return integer(SymEngine::get_den(i));
}
+
+ RCP<const Number> modrational(const Integer &other) const {
+ const integer_class &num = SymEngine::get_num(i);
+ const integer_class &den = SymEngine::get_den(i);
+
+ const integer_class &m = other.as_integer_class();
+
+ integer_class new_mod = den * m;
+ integer_class r = num % new_mod;
+ return Rational::from_mpq(rational_class(r, den));
+ }
+
+ RCP<const Number> mod(const Number &other) const override
+ {
+ if (is_a<Integer>(other)) {
+ return modrational(down_cast<const Integer &>(other));
+ } else {
+ return other.rmod(*this);
+ }
+ };
+
+ RCP<const Number> rmod(const Number &other) const override
+ {
+ return other.mod(*this);
+ };
};
//! returns the `num` and `den` of rational `rat` as `RCP<const Integer>`
@@ -49,6 +49,7 @@ TEST_CASE("mod: Basic", "[basic]")
REQUIRE(mod(x, integer(2))->__str__() == "x%2");
REQUIRE(mod(mul(x, integer(6)), integer(2)) == zero);
REQUIRE(mod(add(x, integer(1)), integer(1)) == zero);
+ REQUIRE(mod(rational(1, 8), integer(8))->__str__() == "1/8");
REQUIRE_THROWS_AS(mod(x, zero), SymEngine::SymEngineException);
auto exp1 = mod(mul(symbol("R"),symbol("nio")), integer(16));
REQUIRE(exp1->__str__() == "(R*nio)%16");