24 Improving performance
Introduction
We should forget about small efficiencies, say about 97% of the time: premature optimization is the root of all evil. Yet we should not pass up our opportunities in that critical 3%. A good programmer will not be lulled into complacency by such reasoning, he will be wise to look carefully at the critical code; but only after that code has been identified.
— Donald Knuth
本节介绍四种通用的优化工具和一个通用的性能优化策略,帮助你确保优化后的代码仍然结果正确。但要注意:优化需建立在对实际瓶颈的识别之上,避免在非关键部分浪费精力,同时抓住真正影响性能的核心环节。如果你想更多地了解R语言的性能特点,推荐Evaluating the Design of the R Language这本书,该书通过将一个经过修改的R解释器与大量实际应用中的代码相结合,得出了相关结论。
Outline
- 24.2节:介绍如何组织代码,使优化尽可能简单、无bug。
- 24.3节:提醒你去寻找已有的解决方案。
- 24.4节:强调“懒惰”的重要性:使函数运行快的最简单方法就是让函数做最简单的事。
- 24.5节:介绍向量化,并展示如何最大限度地利用内置函数。
- 24.6节:讨论复制数据的性能风险。
- 24.7节:将所有片段整合成一个案例研究,展示如何将重复t检验的速度提高约1000倍。
- 24.8节:提供了更多帮助你编写快速代码资源的指针。
Prerequisites
Code organisation
在尝试优化代码时,通常会掉入两个陷阱中:
- 代码运行快速但不正确。
- 你认为运行快速,但实际效果并不好(没有进行基准测试)。
下面的策略会帮助你避免这些陷阱。
在进行多种方法的基准测试前,可以将这个方法打包成一个函数。因为函数内的环境是独立的,不会存在干扰,也方便检查返回的结果是否正确。例如,对比两种不同计算均值的方法:
建议你记录所有尝试的内容,甚至包括失败。如果将来发生类似的问题,查看你尝试过的所有内容会很有用。推荐使用RMarkdown或quarto, 这使得将代码与详细的注释和说明混合在一起变得容易。
接下来,生成一个具有代表性的测试示例。这个示例应该足够大,以捕捉问题的本质,但又要足够小,运行最多只需要几秒钟。你不希望花费太长时间,因为你需要多次运行测试示例来比较方法。另一方面,你也不希望示例太小,因为那样结果可能无法扩展到真正的问题。这里使用100,000个数字来进行测试:
x <- runif(1e5)
现在使用bench::mark()
来精确比较变量。bench::mark()
会自动检查所有调用是否返回相同类型的值。这并不能保证函数对所有输入的行为都相同,因此在理想情况下,还需要进行单元测试,以确保不会意外地改变函数的行为。
从结果上看,mean()
意外地要比sum(x) / length(x)
慢一些。这是因为mean()
在计算时,会进行一些额外的步骤,来提升结果地精度。如果你对这种计算策略感兴趣,可以查看:
- http://stackoverflow.com/questions/22515525#22518603
- http://stackoverflow.com/questions/22515175#22515856
- http://stackoverflow.com/questions/3476015#22511936
Checking for existing solutions
当你尝试过自己很多种想法后,仍然很难解决问题时,你可以检查是否已经有成熟地解决方案了。下面是两个好的检索开始:
CRAN task views,根据任务收集CRAN上的包。
在Rcpp的CRAN主页上,可以找到一些使用Rcpp的包,这些包都使用C++语言编写,可能会更快些。
除此之外,你需要将你的问题描述清楚,并使用搜索引擎(现在用AI啦😊)搜索。同时,你要广泛地阅读相关书籍,积攒的专业知识有助于你更快速的检索并理解答案。将自己解决问题的过程和最终答案记录下来,长时间的积累后,可以使用某些工具进行构建自己的知识库以便日后查阅。
Doing as little as possible
尽可能将函数的功能限定在某个范围内,接受特定的输入,输出特定的结果。例如:
rowSums()
,colSums()
,rowMeans()
,colMeans()
要比应用apply()
来计算快很多。any(x == 10)
要比10 %in% x
更快。
某些函数的输入要求特定类型,当输入不符合时,函数可能会执行额外的类型转换工作。例如,应用apply()
到data.frame
时,会自动将data.frame
转换为matrix
。
如果提供更多问题的信息,某些函数会减少一些工作量。例如:
read.csv()
中使用colClasses
指定已知列类型。factor()
使用levels
参数指定已知因子级别。cut()
设置labels = FALSE
可以避免产生标签。unlist(x, use.names = FALSE)
要比unlist(x)
更快。interaction()
设置drop = TRUE
可以丢掉不必要的因子水平。
下面以mean()
和as.data.frame()
为例,展示如何使用这种策略来提高性能。
mean()
由于R的大多数函数使用了S3或S4面向对象,因此,我们可以通过避免方法派发来提高性能。这在一个大型循环任务中会很有效。
S3,可以直接调用
generic.class()
函数。S4,需要使用
selectMethod()
函数获取方法,然后赋值给环境变量进行调用。
例如,mean.default()
计算小型数值向量时要比mean()
快上些:
x <- runif(1e2)
bench::mark(
mean(x),
mean.default(x)
)[c("expression", "min", "median", "itr/sec", "n_gc")]
#> # A tibble: 2 × 4
#> expression min median `itr/sec`
#> <bch:expr> <bch:tm> <bch:tm> <dbl>
#> 1 mean(x) 3.9µs 4.3µs 202369.
#> 2 mean.default(x) 1.4µs 1.8µs 499333.
这种优化方式存在一定风险,当x
不是数值向量时,mean.default()
会报错。你甚至可以直接调用.Internal()
函数来极大的提升性能,同时有也将引入更大的风险——无法对NA值进行处理。
x <- runif(1e2)
bench::mark(
mean(x),
mean.default(x),
.Internal(mean(x))
)[c("expression", "min", "median", "itr/sec", "n_gc")]
#> # A tibble: 3 × 4
#> expression min median `itr/sec`
#> <bch:expr> <bch:tm> <bch:tm> <dbl>
#> 1 mean(x) 3.9µs 4.2µs 202416.
#> 2 mean.default(x) 1.4µs 1.7µs 543593.
#> 3 .Internal(mean(x)) 100ns 200ns 3989468.
注意:这些差异之所以出现,是因为x
很小。如果你增加大小,这些差异基本上就会消失,因为大部分时间都用在计算平均值上,而不是进行方法派发。这很好地提醒了我们,输入的大小很重要,你应该根据真实的数据来进行优化。
x <- runif(1e5)
bench::mark(
mean(x),
mean.default(x),
.Internal(mean(x))
)[c("expression", "min", "median", "itr/sec", "n_gc")]
#> # A tibble: 3 × 4
#> expression min median `itr/sec`
#> <bch:expr> <bch:tm> <bch:tm> <dbl>
#> 1 mean(x) 136µs 137µs 6607.
#> 2 mean.default(x) 132µs 142µs 6712.
#> 3 .Internal(mean(x)) 131µs 144µs 6567.
as.data.frame()
能够确定输入的数据类型是另外一种加开代码运行的方式。例如,as.data.frame()
的转换过程分两步,先将每个元素强制转换为数据框,然后再使用rbind()
将结果拼接起来。如果你已经知道列表有name属性且元素等长,那么你可以直接将其转换为数据框(R中的所有数据结构都是向量,只是属性class不同):
quickdf <- function(l) {
class(l) <- "data.frame"
attr(l, "row.names") <- .set_row_names(length(l[[1]]))
l
}
l <- lapply(1:26, function(i) runif(1e3))
names(l) <- letters
bench::mark(
as.data.frame = as.data.frame(l),
quick_df = quickdf(l)
)[c("expression", "min", "median", "itr/sec", "n_gc")]
#> # A tibble: 2 × 4
#> expression min median `itr/sec`
#> <bch:expr> <bch:tm> <bch:tm> <dbl>
#> 1 as.data.frame 1.04ms 1.11ms 843.
#> 2 quick_df 6µs 6.7µs 129486.
当然,这种快速的方法牺牲的是对结果正确性的保证。如果你的输入错误,那么你将会得到错误的结果:
quickdf(list(x = 1, y = 1:2))
#> Warning in format.data.frame(if (omit) x[seq_len(n0), , drop = FALSE] else
#> x, : corrupt data frame: columns will be truncated or padded with NAs
#> x y
#> 1 1 1
为了得到这个最小化方法,作者仔细阅读并重写了as.data.frame.list()
和data.frame()
的源代码,并做了许多小的修改,每次都检查是否破坏了现有的行为;经过几个小时的工作,能够分离出上面显示的最小化代码。这是一种非常有用的技术:大多数base R函数是为了灵活性和功能性而编写的,而不是为了性能。因此,根据特定需求重写通常可以带来显著的改进。要做到这一点,需要阅读源代码,它可能很复杂和令人困惑,但不要放弃!
Vectorise
如果你使用过R一段时间,你可能听说过这样的话——“向量化你的代码”。但是究竟什么是“向量化”呢?“向量化”不仅仅只是避免使用for循环,而是一种整体化解决问题的思路,即你要处理的是一个向量,而不是向量中的每个标量。一个“向量化”的函数通常有两个关键特点:
简化了问题逻辑:从“逐个处理”到“整体处理”。
提升了运行速度:底层使用C语言而非R。
在实践中,除了使用map()
或lapply()
来实现“向量化”,也可以使用已经“向量化”的函数。base R提供了许多已经“向量化”的函数:
-
rowSums()
,colSums()
,rowMeans()
,colMeans()
:可以使用它们构建新的“向量化”函数: 向量化提取自己可以极大地提升运行速度(见4.5节):可以一步进行提取赋值多个值,例如,当
x
是向量、矩阵、数据框时,x[is.na(x)] <- 0
会替换所有缺失值为0。可以是使用
cut
和findInterval()
函数来将连续变量离散化。
线性代数的运行通常是向量化的,它们的循环使用了外部库,如BLAS
。如果你的问题可以使用线性代数来解决,那么运行速度通常会很快。
“向量化”的缺点是:很难预测性能,无法简单地进行线性估算。如下例,查询100个字符的运行时间并不是处理单个字符的100倍运行时间,而仅是10倍。这背后的逻辑是:“向量化”会动态的切换策略——操作量高于某个阈值时,会采用耗时的“初始化”+不耗时的“处理”策略。
lookup <- setNames(as.list(sample(100, 26)), letters)
x1 <- "j"
x10 <- sample(letters, 10)
x100 <- sample(letters, 100, replace = TRUE)
bench::mark(
lookup[x1],
lookup[x10],
lookup[x100],
check = FALSE
)[c("expression", "min", "median", "itr/sec", "n_gc")]
#> # A tibble: 3 × 4
#> expression min median `itr/sec`
#> <bch:expr> <bch:tm> <bch:tm> <dbl>
#> 1 lookup[x1] 300ns 500ns 1875223.
#> 2 lookup[x10] 1µs 1.2µs 766783.
#> 3 lookup[x100] 2.4µs 3.5µs 246153.
向量化并不能解决所有问题,而且与其费力地将现有算法强行改成使用向量化的方法,不如使用C++ 编写自己的向量化函数。我们将在第25章学习如何做到这一点。
Avoiding copies
R 代码运行缓慢的一个究极原因是在for循环中不断创建额外的对象。当你使用c()
,append()
,cbind()
,rbind()
,paste()
组合创建新的对象时,R必须首先创建一个新的对象,然后将旧对象的内容复制到新的对象中。当你在for循环中使用这些函数时,就会不断地创建额外对象。
下面是一个示例:collapse()
函数使用for循环将多个字符串连接成一个字符串;对比直接使用paste()
函数中的参数collapse
。
random_string <- function() {
paste(sample(letters, 50, replace = TRUE), collapse = "")
}
strings10 <- replicate(10, random_string())
strings100 <- replicate(100, random_string())
collapse <- function(xs) {
out <- ""
for (x in xs) {
out <- paste0(out, x)
}
out
}
bench::mark(
loop10 = collapse(strings10),
loop100 = collapse(strings100),
vec10 = paste(strings10, collapse = ""),
vec100 = paste(strings100, collapse = ""),
check = FALSE
)[c("expression", "min", "median", "itr/sec", "n_gc")]
#> # A tibble: 4 × 4
#> expression min median `itr/sec`
#> <bch:expr> <bch:tm> <bch:tm> <dbl>
#> 1 loop10 21.2µs 24.5µs 36165.
#> 2 loop100 569µs 590.6µs 1561.
#> 3 vec10 4.1µs 4.6µs 196385.
#> 4 vec100 24µs 26.7µs 35106.
因为“修改后复制”的机制,x[i] <- y
也会触发复制,详见第2章。
Case study: t-test
下面,我们使用上述介绍的方法来加快“t-test”中t统计量的批量计算。
假设我们有1000次实验(行),每次实验有50个样本(列),前25个样本为一组,后25个样本为另一组,生成测试数据:
有两种方法来批量计算t-test的t统计量:
system.time(
for (i in 1:m) {
t.test(X[i, ] ~ grp)$statistic
}
)
#> user system elapsed
#> 0.53 0.00 0.53
system.time(
for (i in 1:m) {
t.test(X[i, grp == 1], X[i, grp == 2])$statistic
}
)
#> user system elapsed
#> 0.13 0.00 0.13
当然,我们也可以使用map_dbl()
来批量计算:
compT <- function(i) {
t.test(X[i, grp == 1], X[i, grp == 2])$statistic
}
system.time(t1 <- purrr::map_dbl(1:m, compT))
#> user system elapsed
#> 0.16 0.00 0.16
首先我们可以使用减少函数额外工作的策略优化,查看stats::t.test.default()
的源码,你会发现它不仅计算了t统计量,还计算了p值和打印输出。我们可以只计算t统计量:
my_t <- function(x, grp) {
t_stat <- function(x) {
m <- mean(x)
n <- length(x)
var <- sum((x - m)^2) / (n - 1)
list(m = m, n = n, var = var)
}
g1 <- t_stat(x[grp == 1])
g2 <- t_stat(x[grp == 2])
se_total <- sqrt(g1$var / g1$n + g2$var / g2$n)
(g1$m - g2$m) / se_total
}
system.time(t2 <- purrr::map_dbl(1:m, ~ my_t(X[., ], grp)))
#> user system elapsed
#> 0.03 0.00 0.03
stopifnot(all.equal(t1, t2))
针对上面计的for循环计算策略,我们可以使用向量化函数来优化。
rowtstat <- function(X, grp) {
t_stat <- function(X) {
m <- rowMeans(X)
n <- ncol(X)
var <- rowSums((X - m)^2) / (n - 1)
list(m = m, n = n, var = var)
}
g1 <- t_stat(X[, grp == 1])
g2 <- t_stat(X[, grp == 2])
se_total <- sqrt(g1$var / g1$n + g2$var / g2$n)
(g1$m - g2$m) / se_total
}
system.time(t3 <- rowtstat(X, grp))
#> user system elapsed
#> 0.02 0.00 0.01
stopifnot(all.equal(t1, t3))
Other techniques
写出运行速度快的代码是成为优秀程序员的一部分。除了本章介绍的策略外,你可以通过下面的方式来提升自己的通用编程技巧: