qiushi 最近刚好在学Rcpp，就照着你的metropolis_hastings()代码写了一个Rcpp版本，练练手。如有错误之处，还请指正！
先定义一个头文件 metropolis_hastings.h，内容如下：
#ifndef MY_H_
#define MY_H_
#include <Rcpp.h>
using namespace Rcpp;
// Draw one Metropolis-Hastings proposal x* ~ Exp(rate = x_t), mirroring the
// R call rexp(1, rate = x_t).
// R::rexp() is parameterised by scale = 1 / rate, hence 1.0 / x_t.
// Declared inline because it is defined in a header: a non-inline definition
// violates the one-definition rule (duplicate-symbol link errors) as soon as
// the header is included by more than one .cpp file, e.g. in an R package.
inline double Rcpp_propose(double x_t) {
  return R::rexp(1.0 / x_t);
}
// Metropolis-Hastings ratio for a move from x_t to x_star, mirroring the R
// function p():
//
//   x_star * (1 - x_star) * q(x_t | x_star)
//   ---------------------------------------
//   x_t    * (1 - x_t)    * q(x_star | x_t)
//
// where q(y | x) is the Exp(rate = x) proposal density, evaluated as
// R::dexp(y, 1 / x, false) because R::dexp() takes scale = 1 / rate.
// The factor x * (1 - x) acts as the unnormalised target density.
// NOTE(review): the denominator is 0 when x_t is exactly 0 or 1, so the
// result is then +/-Inf or NaN; callers must tolerate that.
// Declared inline because it is defined in a header (one-definition rule).
inline double Rcpp_p(double x_star, double x_t) {
  double num = x_star * (1 - x_star) * R::dexp(x_t, 1 / x_star, false);
  double den = x_t * (1 - x_t) * R::dexp(x_star, 1 / x_t, false);
  return num / den;
}
#endif
然后再写一个 metropolis_hastings.cpp 文件，内容如下：
#include "metropolis_hastings.h"
#include <Rcpp.h>
#include <vector>
using namespace Rcpp;
// Metropolis-Hastings sampler: C++ port of the R function metropolis_hastings().
//
// Each iteration proposes x_star = Rcpp_propose(x_t), i.e. Exp(rate = x_t),
// and accepts it with probability min(1, Rcpp_p(x_star, x_t)).
//
// x_0:           starting value; must be positive because it is the
//                exponential rate of the first proposal.
// num_iteration: number of proposals; values <= 0 mean no proposals.
//
// Returns x_0 followed by every accepted proposal, in order.
// NOTE(review): a textbook MH chain also repeats x_t when a proposal is
// rejected. This port keeps only the accepted moves to match the R version,
// so the result length is random (at most num_iteration + 1).
//
// All draws come from R's RNG (R::rexp, R::runif). The wrapper generated for
// Rcpp::export saves and restores R's RNG state, so set.seed() in R makes the
// output reproducible. The draw order (one exponential, then one uniform per
// iteration) is the same as in the R version.
// [[Rcpp::export]]
NumericVector Rcpp_metropolis_hastings(double x_0, int num_iteration) {
  if (!(x_0 > 0)) {  // written this way so that NaN is rejected too
    Rcpp::stop("`x_0` must be a single positive number.");
  }

  // Collect into a std::vector: NumericVector::push_back reallocates and
  // copies the whole vector on every call, which made the loop quadratic.
  std::vector<double> samp;
  if (num_iteration > 0) {
    samp.reserve(static_cast<std::vector<double>::size_type>(num_iteration) + 1);
  }
  samp.push_back(x_0);

  double x_t = x_0;
  for (int i = 0; i < num_iteration; ++i) {
    const double x_star = Rcpp_propose(x_t);
    const double ratio = Rcpp_p(x_star, x_t);
    // min(1, ratio), written so that a NaN ratio stays NaN and is rejected by
    // the comparison below (as with Rcpp sugar min() and R's min(1, NaN));
    // std::min(1.0, ratio) would turn NaN into 1 and accept the move.
    const double accept_prob = ratio > 1.0 ? 1.0 : ratio;
    const double u = R::runif(0, 1);
    if (accept_prob > u) {
      samp.push_back(x_star);
      x_t = x_star;
    }
  }
  return NumericVector(samp.begin(), samp.end());
}
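由于我还不知道如何在Rcpp里设定随机数种子来检验两个版本结果的一致性，因此只能取个别值大致比较一下。（补充：Rcpp中的 R::rexp()、R::runif() 使用的就是R自身的随机数生成器，而通过 // [[Rcpp::export]] 导出的函数会自动读取并写回R的随机数状态。因此，只要在调用前于R中执行 set.seed()，就能固定Rcpp版本的结果。两个版本每次迭代都按“先指数分布、后均匀分布”的顺序抽样，所以用同一个种子时，结果应当完全相同。）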
library(Rcpp)
# Compile the C++ sampler; this defines Rcpp_metropolis_hastings() in R.
# The directory is machine-specific: point it at wherever the .cpp file lives.
sourceCpp(file = file.path(
  "C:/Users/Administrator/Documents",
  "metropolis_hastings.cpp"
))
#metropolis_hastings()
# Proposal step: draw a single candidate from the exponential distribution
# whose rate is the current state x_t.
propose <- function(x_t) rexp(n = 1, rate = x_t)
# Metropolis-Hastings ratio for a move from x_t to x_star.
# Each side weights a point by x * (1 - x) (the unnormalised target density)
# and multiplies by the exponential proposal density of the opposite move.
p <- function(x_star, x_t) {
  weighted_density <- function(at, other) {
    at * (1 - at) * dexp(other, rate = at)
  }
  weighted_density(x_star, x_t) / weighted_density(x_t, x_star)
}
# Metropolis-Hastings sampler (pure R version).
#
# Starting from x_0, each iteration proposes x_star ~ Exp(rate = x_t) via
# propose() and accepts it with probability min(1, p(x_star, x_t)).
#
# x_0: starting value. It must be a single positive number because it is the
#   exponential rate of the first proposal. Note that the default,
#   sample(0:10, 1), can draw 0 and then triggers the error below.
# num_iteration: number of proposals; values <= 0 mean no proposals.
#
# Returns x_0 followed by every accepted proposal, in order.
# NOTE(review): a textbook MH chain also repeats x_t when a proposal is
# rejected. Like Rcpp_metropolis_hastings(), this version keeps only the
# accepted moves, so the result length is random (at most num_iteration + 1).
metropolis_hastings <- function(x_0 = sample(0:10, size = 1),
                                num_iteration = 1000) {
  if (!is.numeric(x_0) || length(x_0) != 1L || is.na(x_0) || x_0 <= 0) {
    stop("`x_0` must be a single positive number.", call. = FALSE)
  }
  # seq_len() instead of 1:num_iteration, which iterates twice when
  # num_iteration is 0.
  n_proposals <- max(0, num_iteration)

  # Preallocate instead of calling append() on every acceptance, which copies
  # the whole vector each time (quadratic). Unused slots are dropped at the end.
  samp <- rep(x_0, n_proposals + 1)
  n_kept <- 1L
  x_t <- x_0
  for (i in seq_len(n_proposals)) {
    x_star <- propose(x_t)
    accept_prob <- min(1, p(x_star = x_star, x_t = x_t))
    u <- runif(1)
    if (accept_prob > u) {
      n_kept <- n_kept + 1L
      samp[n_kept] <- x_star
      x_t <- x_star
    }
  }
  samp[seq_len(n_kept)]
}
# result comparison
metropolis_hastings(2, 3)
#> [1] 2
Rcpp_metropolis_hastings(2, 3)
#> [1] 2
# result is the same
# The two unseeded calls above use different random draws, so they are only a
# rough check. Both versions draw from R's RNG in the same order (one
# exponential proposal, then one uniform per iteration), so with a shared seed
# they should return identical chains. This should print TRUE.
set.seed(1)
chain_r <- metropolis_hastings(0.5, 1e3)
set.seed(1)
chain_cpp <- Rcpp_metropolis_hastings(0.5, 1e3)
identical(chain_r, chain_cpp)
# speed comparison
# Time both implementations on the same workload; replications = 100 is
# rbenchmark's default, spelled out here for clarity.
library(rbenchmark)
benchmark(
  metropolis_hastings(5, 1e3),
  Rcpp_metropolis_hastings(5, 1e3),
  replications = 100
)
#> test replications elapsed relative user.self
#> 1 metropolis_hastings(5, 1000) 100 1.08 36 1.06
#> 2 Rcpp_metropolis_hastings(5, 1000) 100 0.03 1 0.02
#> sys.self user.child sys.child
#> 1 0.01 NA NA
#> 2 0.00 NA NA
<sup>Created on 2020-08-20 by the reprex package (v0.3.0)</sup>
可以看到，R版本的耗时约为Rcpp版本的36倍。