繰り返し二乗法について
Published on 2023-08-05Last Modified 2023-12-01
Table Of Contents
はじめに
この度私のライブラリに冪乗の余りを求める、いわゆるmodPow
関数を追加しました。
丁度いい機会ということで、本稿ではmodpowの仕組みを説明し、詳細な実装方法と実装例を示します。
数学をガバっていたらすみません。
modpowとは
以下、$\times{}$は整数の積を表します。
まず、剰余を定める重要な定理を紹介します。
$\forall{} n, m \in{} \mathbb{Z} ~ (m \neq{} 0)$に対して、 $q, r \in{} \mathbb{Z}$が存在して、次を満たす。 $$ n = qm+r, ~ 0 \leq{} r < |m| $$ また、このような$q, r$は$m, n$に対してただ一つに定まる。 :::
(証明略)
ここで、$q, r$をそれぞれ、「$n$を$m$で割った商」、「$n$を$m$で割った非負最小剰余」と定めます。
次に、$\forall{}x \in{} \mathbb{Z}$と非負整数$n$に対して、冪乗$x^n$を次で定めます。 $$ x^n = \begin{cases} x \times{} x^{n-1} & \text{if $n \geq{} 1$} \\ 1 & \text{if $n = 0$} \end{cases} $$
本稿で紹介するmodPow
関数で求める値は、$a^x$をMOD
で割った非負最小剰余です。
(数式苦手な方へ)これは何を求めているのかをできるだけ形式的に定めているだけなので、あまり気にしなくても良いです。
「$a$ を $x$ 回掛け算して、MOD
で割った余りを求める。」でOKです。
(注意): 便宜上、$0^0=1$ としています。
計算原理
以下、割る数は$0$を含まない正整数とします。
割る数とはプログラムにおけるMOD
のことです。
(定義を考えれば負の整数に対しても非負最小剰余はただ一つに定まりますが、実用上正整数を用いることが多いのかな?)
通常、冪乗は比較的簡単に非常に大きな値になります。 例えば、$2^{10}=1024$ ですが、$2^{64}=18446744073709551616$ であり、 上の定義をそのまま計算するだけでは非常に厳しいことがわかります。 そこで、まず次の定理が大事になります。
$(a ~ \text{mod} ~ m)$と書けば、$a$を$m$で割った非負最小剰余を表すとする。
このとき、整数$x, y$に対して次が成立する。 $$ ((x \times{} y) ~ \text{mod} ~ m) = ( ((x ~ \text{mod} ~ m) \times{} (y ~ \text{mod} ~ m)) ~ \text{mod} ~ m) $$
証明:
上で示した定理により、次が成立する。
ある整数 $q_{1}, q_{2}, r_{1}, r_{2}$ が存在して、$x = q_{1}m + r_{1}, ~ y = q_{2}m + r_{2}$
また、 $$ \begin{split} x \times{} y &= (q_{1}m + r_{1}) \times{} (q_{2}m + r_{2}) \\ &= q_{1}q_{2}m^2 + (q_{1}r_{2} + q_{2}r_{1})m + r_{1}r_{2} \\ &= (q_{1}q_{2}m + (q_{1}r_{2} + q_{2}r_{1}))m + r_{1}r_{2} \\ \end{split} $$ であるから、 $$ \begin{split} (左辺) &= ( ((q_{1}q_{2}m + (q_{1}r_{2} + q_{2}r_{1}))m + r_{1}r_{2}) ~ \text{mod} ~ m ) \\ &= ((r_{1}r_{2}) ~ \text{mod} ~ m) \end{split} $$ 一方、 $$ (x ~ \text{mod} ~ m) \times{} (y ~ \text{mod} ~ m) = r_{1}r_{2} $$ であるから、 $$ (右辺) = ((r_{1}r_{2}) ~ \text{mod} ~ m) $$ 証明終わり。
この定理は、積に関してなら どのタイミングで非負最小剰余に変換しても最後は同じ値になるということを示しています。 また、定理では右辺は両方非負最小剰余に変換していますが、もちろん片方だけで行っても同様の結果が得られます。
これで $a^x$ の大きさに関する問題は解決します。 なぜなら、一回 $a$
を掛け算するたびに非負最小剰余に逐次変換していけば良いので、 $(割る数)
\times{} (割る数)$さえ正しく計算できれば $a^x$
の非負最小剰余もまた計算可能になるからです。
プログラム的には、用いる整数の型が $k$ bitであるとき、MOD
の大きさが
$k/2$ bit程度であれば計算できます。
したがって、次のようなプログラムは正しく動作します。
long long modPow (long long a, long long x, const int MOD) {
// 不正な入力を弾く
assert(0 <= x);
assert(1 <= MOD);
// aを正規化
a %= MOD; a += MOD; a %= MOD;
long long res = 1;
for (long long i = 0; i < x; i++) {
res *= a;
res %= MOD;
}
return res % MOD;
}
$a$
の正規化というのは、底を最初の段階で非負最小剰余に変換しているということです。
上のコードは積や剰余演算を$O(1)$でできると仮定すれば、全体で$O(x)$で抑えられます。
しかし、(競技プログラミングなど)実用上はx
は$10^9$程度であったりすることが多く、このままでは使えません。
そこで、通常は次に説明するような高速化を施します。
まず、以下の定理を導入します。
2以上の整数 $p$ を固定する。任意の非負整数 $x$ に対して、 長さ$1$以上の整数列 ${c_{i}} ~ (c_{i} \in{} {0, 1, \cdots{}, p-1})$ が存在し、 $$ x = \sum_{i=0} c_{i}p^{i} $$ と表すことができる。 また、このような表し方(すなわち、数列 ${c_i}$)は $x=0$ を特別扱いするとき、先頭の余計な $0$ の項を除いて一意に定まる。
証明:
まず、任意の非負整数が上の形で表されることを示す。
(1) $x=0$ のとき、すべての項が $0$ であるような整数列 ${c_{i}}$ を用いれば良い。
(2) $x=1$ のとき、整数列 ${c_{i}}$ であって、 $$ c_{i} = \begin{cases} 1 & \text{if $i=0$} \\ 0 & \text{otherwise} \end{cases} $$ であるようなものを用いれば良い。
(3) $x>1$ に対して $x$ 未満の正整数が上の形で表されると仮定する。 このとき、$x = qp + r$ と表示すると、
(i) $q=0$ のとき、$0 \leq{} r < p$ であるから、整数列 ${c_{i}}$ であって、 $$ c_{i} = \begin{cases} r & \text{if $i=0$} \\ 0 & \text{otherwise} \end{cases} $$ であるようなものを用いれば良い。
(ii) $q \neq{} 0$ のとき、$q < x$ であるから、仮定より長さ1以上の整数列 ${d_{i}}$ が存在して $$ q = \sum_{i=0} d_{i} p^{i} $$ よって、$x=qp+r$としていることを思い出せば $$ pq = \sum_{i=0} d_{i} p^{i+1} $$ であるから、ここで新しく整数列 ${c_{i}}$ を次のように定める。 $$ c_{i} = \begin{cases} r & \text{if $i=0$} \\ d_{i-1} & \text{otherwise} \end{cases} $$ すると、 $$ x = pq+r = \sum_{i=0}^{m} c_{i} p^{i} $$ と表すことができる。
数学的帰納法により、任意の整数を表すことができる。
次に、このような表し方が非負整数$x$に対して一意に定まることを示す。
まず、$x=0$ のとき、すべての $c_i$ が $0$ である。 これは「先頭の余分な $0$」のみで構成されている唯一の数字なので、これを特別扱いして、一意に定まっているとみなす。
(1) $0 < x < p$ のとき
$$ x = \sum_{i=0} c_i p^i $$ とすると、$1\leq{}i$ なる $i$ に対して $c_i \neq{} 0$ であるとき、明らかに $p \leq{} x$ となる。これは仮定に反する。 したがって、 $$ c_i = \begin{cases} x & \text{if $i=0$} \\ 0 & \text{if $i \neq{} 0$} \\ \end{cases} $$ 以外の表し方は存在しない。
(2) $p \leq{} x$ のとき
$x$ 未満の数は、すべて一意的に表されると仮定する。 このとき、 $$ \begin{split} x &= \sum_{i=0} c_i p^i \\ &= \left( \sum_{i=1} c_i p^{i-1} \right) p + c_0 \end{split} $$ とできる。定理1により、$\sum_{i=1} c_i p^{i-1}$ と $c_0$ はただ一つに定まる。 ここで、明らかに $\sum_{i=1} c_i p^{i-1} \leq{} x$ である。 仮定より、$c_i ~ (1\leq{}i)$ は一意に定まる。
以上より、数学的帰納法により任意の非負整数 $x$ は一意に表される。
証明終わり。
少々天下り的ですが、$p=2$ を選んで、$x = \sum_{i=0} 2^i c_i ~ (c_i \in{} {0, 1})$ とします。 定理より、非負整数 $x$ はこのような表示を先頭の $0$ の自由を除いて一意に持ちます。 このとき、 $$ \begin{split} a^x &= a^{\sum_{i=0} 2^i c_i} \\ &= a^{c_0 + 2 c_1 + 2^2 c_2 + \dots{}} \\ &= a^{c_0} \times{} a^{2 c_1} \times{} a^{2^2 c_2} \times{} \dots{} \\ &= \prod_{i=0} a^{ 2^i c_i} \\ &= \prod_{i=0} \left( a^{2^i} \right) ^{c_i} \end{split} $$ となります。 (式変形がわからない人はこちらを参考にしてください。)
実は、この形まで変形すると非常に高速に計算できるようになっています! 先程の定義通りに計算する方法では、$x$ が大きいときに時間がかかるということが問題でした。 上の変形は、視覚的に説明するなら、 $$ a^{100} = a^{64} \times{} a^{32} \times{} a^{4} $$ のように分解しているということです。 定理3は、このような分解をしたときに、右辺の指数が必ず $2^k$ と表せるということを保証しているとも言えます。 この定理のおかげで $a^{2^k}$ の形の数さえ高速に列挙できれば良くなり、結果的に冪乗を求めるのも高速になるということです。
また、$a^{2^k} = a^{2^{k-1}} \times{} a^{2^{k-1}}$ という関係が成立するため、$a^{2^k}$ の形の数は簡単に求められます。 以上より、次のアルゴリズムを得ます。
- $x$ を2進法展開する。つまり、$x = \sum_{i=0} c_i 2^i$ と表す。
- $base \leftarrow{} (a ~ \text{mod} ~ m), ~ ans \leftarrow{} (1 ~ \text{mod} ~ m), ~ i \leftarrow{} 1$ とする。
- すべての $k ~ (i \leq{} k)$ に対して $c_k = 0$ であれば、$ans$ を出力し終了。
- $c_i = 1$ であれば $ans \leftarrow{} ((ans \times{} base) ~ \text{mod} ~ m)$ とする。
- $base \leftarrow{} ((base \times{} base) ~ \text{mod} ~ m), ~ i \leftarrow{} i+1$ とする。
- 3に戻る。
このアルゴリズムの時間計算量は、手順3から手順6のループ1回あたり $O(1)$ とみなせば、 $x < 2^k$ となった時点で停止するため($k$ 以降の数 $l$ において $c_l = 0$ となるから)、全体で $O(log(x))$ になります。
元々の $O(x)$ のアルゴリズムからかなり改善されました!
実装
実は、$x$ を2進法展開するのはほとんどのプログラミング言語で必要ないです。 というのは、コンピュータは内部的に整数を2進法で表現しているからです。 ビット演算と言われるような機能を持つプログラミング言語ならこの過程を飛ばすことができます。
具体的には、$x$ を2進法展開したときの $c_0$
は、プログラム上ではx&1
で得ることができます。
更に、x >>= 1
などで「ビットシフト」をすると、次にx&1
をしたときには
$c_1$ が得られます。
(ただし、このあたりはプログラミング言語によります。)
負の数が関わると2の補数表現など少しややこしくなりますが、$0\leq{} x$ を仮定しているので問題ありません。
以下にC言語、C++、D言語、python3での実装例を示します。
C言語/C++
#include <assert.h>
// C言語なら<assert.h>
// C++なら<cassert.h>
long long modPow (long long a, long long x, const int MOD) {
// assertion
assert(0 <= x);
assert(1 <= MOD);
// normalize
a %= MOD; a += MOD; a %= MOD;
// calculate
long long ans = 1L % MOD;
long long base = a;
while (x != 0) {
if ((x&1) != 0) {
ans *= base; ans %= MOD;
}
base = base*base; base %= MOD;
x >>= 1;
}
return ans;
}
D言語
long modPow (long a, long x, const int MOD) {
// assertion
assert(0 <= x);
assert(1 <= MOD);
// normalize
a %= MOD; a += MOD; a %= MOD;
// calculate
long ans = 1L % MOD;
long base = a % MOD;
while (x != 0) {
if ((x&1) != 0) {
ans *= base; ans %= MOD;
}
base = base*base; base %= MOD;
x >>= 1;
}
return ans;
}
python3
def modPow (a, x, MOD):
# assertion
assert 0 <= x, "x must be an integer greater than or equal to 0"
assert 1 <= MOD, "MOD must be an integer greater or equal to 1"
# normalize
a %= MOD
# calculate
ans = 1 % MOD
base = a
while x != 0:
if (x&1) != 0:
ans *= base
ans %= MOD
base = base*base
base %= MOD
x >>= 1
return ans
終わりに
まとめるのすごく大変だった...
内容に不備があれば著者のtwitterに連絡していただければ助かります。 どれだけ些細な内容でも大歓迎です。
参考文献
本稿における主張、紹介した定理及びその証明はほとんど以下に依ります。
- 尾関和彦、 情報技術のための離散型数学入門、 共立出版(2023)
主に第5章、第6章を参考にしました。