をモンテカルロ積分で求める際、重点サンプリングのサンプルの違いによる収束の様子を確認してみた。
凡例にも示した通りだが以下の3種類のPDF(確率密度関数)について、収束の状況を比較した。
確率密度の係数の求め方
確率密度関数の積分は、を満たすので、 係数をとして方程式を解く。 より、が求まる。
被積分関数 と同じ次数のPDFを用いれば、サンプリング回数は1度で収束する?
混合分布
混合分布の場合はどうかなと思ったので、 について0.5ずつ混ぜた混合分布の収束についても確認してみる。まず、 混合分布のPDFは以下となる。*1
ちなみに、 の確率密度をもつ乱数 を求めると、 となる。 なお、この式は以下に対する逆関数によって求まる。
結果は以下。よりは収束が速くなっているが、よりは収束が遅いことが確認でき、かつ同じ値に収束することが確認できたと思う。
コード
確認に使ったpythonのコードも貼っておく。
import math import random import matplotlib.pyplot as plt def random_double(): return random.random() def random_double(min_v, max_v): return min_v + (max_v - min_v) * random.random() def montecarlo1(n): s = 0.0 seq = [] for i in range(n): x = random_double(0, 2) s += x * x seq.append(2 * s / (i + 1)) #print("{} = {}".format(2 * s / n, 8/3)) return seq def importance_sampling1(n): s = 0.0 # 一様 def pdf(x): return 0.5 seq = [] for i in range(1, n+1): x = random_double(0, 2) s += x * x / pdf(x) seq.append(s / i) return seq def importance_sampling2(n): s = 0.0 def pdf(x): return 0.5 * x seq = [] for i in range(1, n+1): x = math.sqrt(random_double(0, 4)) while x == 0.0: x = math.sqrt(random_double(0, 4)) s += x * x / pdf(x) seq.append(s / i) return seq def importance_sampling3(n): s = 0.0 def pdf(x): return 3/8 * x*x seq = [] for i in range(1, n+1): x = math.pow(random_double(0, 8), 1/3) while x == 0.0: x = math.pow(random_double(0, 8), 1/3) s += x * x / pdf(x) seq.append(s / i) return seq def importance_sampling4(n): s = 0.0 def pdf(x): return 1/4 * (x + 1) seq = [] for i in range(1, n+1): x = -1 + math.sqrt(1+random_double(0, 8)) while x == 0.0: x = -1 + math.sqrt(1+random_double(0, 8)) s += x * x / pdf(x) seq.append(s / i) return seq def main(): n = 10000 sampling=100 for i in range(sampling): if i == 0: label='1/2' else: label=None plt.scatter(range(1, n+1), importance_sampling1(n), c='#1f77b4', s=1, marker='.', label=label) for i in range(sampling): if i == 0: label='1/2x' else: label=None plt.scatter(range(1, n+1), importance_sampling2(n), c='#ff7f0e', s=1, marker='.', label=label) for i in range(sampling): if i == 0: label='3/8x^2' else: label=None plt.scatter(range(1, n+1), importance_sampling3(n), c='#2ca02c', s=1, marker='.', label=label) for i in range(sampling): if i == 0: label='1/4x+1/4' else: label=None plt.scatter(range(1, n+1), importance_sampling4(n), c='#2ca02c', s=1, marker='.', label=label) plt.xscale('log') plt.grid(which='major',color='black',linestyle='-') plt.grid(which='minor',color='gray',linestyle='-') plt.legend() plt.savefig('chart.png') main()
続きも書いた。