折叠表达式

1. 折叠表达式

折叠表达式C++17 引入的新特性,它允许在编译时对参数包中的所有参数应用二元运算符。这使得处理可变参数模板变得更加简洁和直观。

1.1 为什么需要折叠表达式

对于初学者来说,可变参数模板本身的语法就比较晦涩,而展开参数包往往需要写大量的递归代码。折叠表达式的出现,就是为了解决这个痛点。折叠通常指的是将一个列表中的元素通过某种操作(比如加法、逻辑与)两两结合,最终得到一个单一结果的过程。例如,有一排数字:1, 2, 3, 4,如果你想要它们的和,会这样计算:

1
(((1 + 2) + 3) + 4)

这就是折叠。C++17 允许我们在编译期,对模板参数包进行这种操作。下面是一个典型的 C++ 可变参数模板递归求和 实现。让我详细解释它的工作原理。

1
2
3
4
5
6
7
8
9
10
11
12
// 1. 终止递归的基准函数
int sum()
{
return 0;
}

// 2. 递归展开参数包的函数模板
template <typename T, typename... Args>
int sum(T first, Args... rest)
{
return first + sum(rest...); // 必须写一个递归调用
}

详细执行过程(以 sum(1, 2, 3) 为例):

  • 第一次递归展开:此时计算:1 + sum(2, 3)

    1
    2
    3
    4
    5
    6
    7
    template <typename T, typename... Args>
    int sum<int, int, int>(int first, int... rest)
    {
    // first = 1
    // rest... = 2, 3
    return 1 + sum(2, 3); // 递归调用
    }
  • 第二次递归展开:此时计算:1 + (2 + sum(3))

    1
    2
    3
    4
    5
    6
    7
    template <typename T, typename... Args>
    int sum<int, int>(int first, int... rest)
    {
    // first = 2
    // rest... = 3
    return 2 + sum(3); // 递归调用
    }
  • 第三次递归展开:此时计算:1 + (2 + (3 + sum()))

    1
    2
    3
    4
    5
    6
    7
    template <typename T, typename... Args>
    int sum<int>(int first, int... rest)
    {
    // first = 3
    // rest... = (空参数包)
    return 3 + sum(); // 递归调用
    }
  • 终止递归,最终结果:1 + (2 + (3 + 0)) = 6

    1
    2
    3
    4
    int sum() 
    {
    return 0;
    }

上面是 C++11/14 中处理参数包的常见方法:

  • 递归分解:每次取出第一个参数(first),剩下的放入参数包(rest...
  • 递归调用sum(rest...) 将参数包展开传递给下一次调用
  • 终止条件:当参数包为空时,调用无参数的 sum() 终止递归

这种写法虽然能工作,但优点不明显:

  1. 代码冗长:必须显式定义一个终止条件的重载函数。
  2. 编译慢:递归实例化会增加编译器的负担。
  3. 易读性差:逻辑分散在两个函数中。

所以我们可以使用现代 C++ 的改进写法 – 折叠表达式来解决上述问题:

1
2
3
4
5
6
7
template <typename... Args>
auto sum(Args... args)
{
return (args + ...); // 一元右折叠
// 或者 return (... + args); // 一元左折叠
}
// sum(1, 2, 3) -> 展开为 1 + (2 + 3)

是不是像魔法一样整齐划一?折叠表达式是 C++17 引入的强大特性,用于简化可变参数模板的参数包展开。它可以直接对参数包执行操作,无需递归。但是想要彻底搞明白其中含义,我们必须要学习一下折叠表达式的语法。

1.2 折叠表达式的语法

折叠表达式的核心语法包含四个部分:

1
2
3
4
( 参数包 op ... )           // 一元右折叠
( ... op 参数包 ) // 一元左折叠
( 参数包 op ... op 初始值 ) // 二元右折叠
( 初始值 op ... op 参数包 ) // 二元左折叠

其中 op 是运算符,几乎所有的二元运算符都支持折叠:

  • 算术运算符:+, -, *, /, %
  • 位运算符:&, |, ^, <<, >>
  • 逻辑运算符:&&, ||
  • 比较运算符:==, !=, <, >, <=, >=
  • 成员访问运算符:.*, ->*
  • 逗号运算符:,
  • 赋值与复合赋值运算符:=+=-=*=/=%=&=|=^=<<=>>=

1.2.1 一元右折叠

1
2
3
4
5
6
// (pack op ...)
template<typename... Args>
auto sum(Args... args)
{
return (args + ...); // 从右向左结合
}

展开规则

  • (args + ...)arg1 + (arg2 + (... + argN))
  • 等价于:arg1 + (arg2 + (arg3 + ...))

示例

1
2
3
sum(1, 2, 3, 4);
// 展开:1 + (2 + (3 + 4)) = 10
// 执行顺序:先计算 3+4=7,再 2+7=9,再 1+9=10

1.2.2 一元左折叠

1
2
3
4
5
6
// (... op pack)
template<typename... Args>
auto sum(Args... args)
{
return (... + args); // 从左向右结合
}

展开规则

  • (... + args)((arg1 + arg2) + ...) + argN
  • 等价于:((arg1 + arg2) + arg3) + ...

示例

1
2
3
sum(1, 2, 3, 4);
// 展开:((1 + 2) + 3) + 4 = 10
// 执行顺序:先计算 1+2=3,再 3+3=6,再 6+4=10

1.2.3 二元右折叠

1
2
3
4
5
6
// ( pack op ... op init )
template<typename... Args>
auto sum_with_init(Args... args)
{
return (args + ... + 0); // 二元右折叠
}

展开规则

  • (args + ... + init) → arg1 + (arg2 + (... + (argN + init)))
  • 等价于:arg1 + (arg2 + (arg3 + ... + init))

示例

1
2
3
sum_with_init(1, 2, 3, 4);
// 展开:1 + (2 + (3 + (4 + 0))) = 10
// 相当于:一元右折叠 + init

1.2.4 二元左折叠

1
2
3
4
5
6
// ( init op ... op pack )
template<typename... Args>
auto sum_with_init2(Args... args)
{
return (0 + ... + args); // 二元左折叠
}

展开规则

  • (init + ... + args) → (((init + arg1) + arg2) + ...) + argN
  • 等价于:(((init + arg1) + arg2) + arg3) + ...

示例

1
2
3
sum_with_init2(1, 2, 3, 4);
// 展开:(((0 + 1) + 2) + 3) + 4 = 10
// 相当于:init + 一元左折叠

1.3 空参数包的处理

当涉及到 空参数包 的处理时,折叠表达式的行为取决于你使用的是哪种形式的折叠(一元折叠还是二元折叠),因为并非所有的运算符都支持在没有操作数的情况下进行计算。以下是关于空参数包处理的详细规则:

  1. 核心规则:合法与非法

    C++ 标准明确规定,如果折叠表达式展开时涉及空参数包,除了以下例外情况,通常会导致格式错误

    • 合法且必须指定初始值的运算符(以下 3 种):
      • 逻辑与 (&&):空包结果为 true
      • **逻辑或 **(||):空包结果为 false
      • **逗号 **(,):空包结果为 void()(即什么都不做)。
    • 非法(不能用于空包)的运算符:
      • 算术运算符(+, -, *, /, % 等):加法或乘法没有操作数是无意义的。
      • 比较、位运算、指针运算等。
  2. 折叠形式与空包处理

    虽然一元折叠(简单的 op ... argsargs ... op)很吸引人,但它们通常不支持空包(除了上述三种运算符)。为了安全地处理空参数包,二元折叠(带初始值) 是推荐的解决方案。

    • 一元折叠

      语法: ( ... op Pack )( Pack op ... )

      对于空包:

      • &&|| 有默认值(truefalse)。
      • , 是合法的(空操作)。
      • 其他所有运算符在不支持空包的情况下编译失败。

      示例(合法 - 逻辑):

      1
      2
      3
      4
      5
      6
      template<typename... Args>
      bool all_true(Args... args)
      {
      // 如果 args 为空,结果是 true
      return (args && ...);
      }

      示例(非法 - 算术):

      1
      2
      3
      4
      5
      6
      template<typename... Args>
      auto sum(Args... args)
      {
      // 如果 args 为空,这里会报错!因为 + 不能作用于空包
      return (args + ...);
      }
    • 二元折叠

      语法: ( Value op ... op Pack )( Pack op ... op Value )

      • 在这种形式下,即使 Pack 为空,表达式也会退化为 Value op Value(通常编译器会优化为直接返回 Value)。
      • 这是处理算术运算空包的标准方法。

      示例(处理空包):

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      #include <iostream>

      template<typename... Args>
      auto sum(Args... args)
      {
      // 即使 args 为空,结果也是 0
      return (0 + ... + args);
      }

      template<typename... Args>
      void print_all(Args... args)
      {
      // 基于 << 运算符的二元左折叠
      (std::cout << ... << args) << std::endl;
      }

      int main()
      {
      std::cout << sum() << std::endl; // 输出: 0
      std::cout << sum(1, 2, 3) << std::endl; // 输出: 6

      print_all(); // 输出: (换行)
      print_all("A", "B"); // 输出: AB (换行)

      // 检查是否所有为真(使用二元折叠处理空包更通用)
      auto check = [](auto... conditions) {
      // 空包默认返回 true
      return (true && ... && conditions);
      };

      std::cout << std::boolalpha << check() << std::endl; // 输出: true
      return 0;
      }

      关于print_all函数的函数体(std::cout << ... << args) 实际上是一个基于 <<(流插入) 运算符的二元左折叠。它可以展开为嵌套的左移调用:((((std::cout << arg1) << arg2) << arg3 ) ... )。这个二元折叠,有一个初始值 std::cout

      • 如果 args 为空,表达式就退化只剩下 std::cout
      • 所以后面接上的 << std::endl 实际上是 std::cout << std::endl,因此是安全的,会输出一个换行符。

2. 详细示例和测试代码

2.1 基本算术运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <iostream>
#include <string>

// 一元右折叠:求和
template<typename... Args>
auto sum_right(Args... args)
{
return (args + ...); // 相当于:arg1 + (arg2 + (arg3 + ...))
}

// 一元左折叠:求和
template<typename... Args>
auto sum_left(Args... args)
{
return (... + args); // 相当于:((arg1 + arg2) + arg3) + ...
}

// 二元右折叠:带初始值的求和
template<typename... Args>
auto sum_with_init_right(Args... args)
{
return (args + ... + 0); // 相当于:arg1 + (arg2 + (... + (argN + 0)))
}

// 二元左折叠:带初始值的求和
template<typename... Args>
auto sum_with_init_left(Args... args)
{
return (0 + ... + args); // 相当于:(((0 + arg1) + arg2) + ...) + argN
}

void test_basic_arithmetic()
{
std::cout << "=== 基本算术运算测试 ===" << std::endl;

// 测试一元折叠
std::cout << "sum_right(1, 2, 3, 4, 5) = " << sum_right(1, 2, 3, 4, 5) << std::endl;
std::cout << "sum_left(1, 2, 3, 4, 5) = " << sum_left(1, 2, 3, 4, 5) << std::endl;

// 对于加法,左右折叠结果相同(加法满足结合律)
std::cout << "sum_right(1.1, 2.2, 3.3) = " << sum_right(1.1, 2.2, 3.3) << std::endl;

// 测试二元折叠(带初始值)
std::cout << "sum_with_init_right(1, 2, 3) = " << sum_with_init_right(1, 2, 3) << std::endl;
std::cout << "sum_with_init_left(1, 2, 3) = " << sum_with_init_left(1, 2, 3) << std::endl;

// 空参数包测试(需要二元折叠)
std::cout << "sum_with_init_right() = " << sum_with_init_right() << std::endl;
std::cout << "sum_with_init_left() = " << sum_with_init_left() << std::endl;

// 注意:一元折叠不能用于空参数包,会导致编译错误
// auto result = sum_right(); // 编译错误!

std::cout << std::endl;
}

2.2 逻辑运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// 检查所有参数是否都为 true
template<typename... Args>
bool all_true(Args... args)
{
return (... && args); // 相当于:arg1 && arg2 && ... && argN
}

// 检查是否有参数为 true
template<typename... Args>
bool any_true(Args... args)
{
return (... || args); // 相当于:arg1 || arg2 || ... || argN
}

void test_logical_operations()
{
std::cout << "=== 逻辑运算测试 ===" << std::endl;

std::cout << "all_true(true, true, true) = "
<< std::boolalpha << all_true(true, true, true) << std::endl;
std::cout << "all_true(true, false, true) = "
<< all_true(true, false, true) << std::endl;

std::cout << "any_true(false, false, true) = "
<< any_true(false, false, true) << std::endl;
std::cout << "any_true(false, false, false) = "
<< any_true(false, false, false) << std::endl;

// 空参数包的特殊情况
std::cout << "all_true() = " << all_true() << std::endl; // 返回 true
std::cout << "any_true() = " << any_true() << std::endl; // 返回 false

std::cout << std::endl;
}

2.3 字符串连接

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <sstream>

// 使用折叠表达式连接字符串
template<typename... Args>
std::string concatenate(Args... args)
{
std::ostringstream oss;
(oss << ... << args); // 二元左折叠:(((oss << arg1) << arg2) << ...) << argN
return oss.str();
}

// 使用逗号分隔参数
template<typename... Args>
void print_with_comma(Args... args)
{
((std::cout << args << ", "), ...) << std::endl; // 使用逗号运算符
}

void test_string_operations()
{
std::cout << "=== 字符串操作测试 ===" << std::endl;

std::string result = concatenate("Hello", " ", "World", "!", " ", 2024);
std::cout << "concatenate(\"Hello\", \" \", \"World\", \"!\", \" \", 2024) = "
<< result << std::endl;

std::cout << "print_with_comma(1, 2, 3, \"apple\", \"banana\"): ";
print_with_comma(1, 2, 3, "apple", "banana");

std::cout << std::endl;
}

逗号分隔打印:一元右折叠(最常用模式)

1
2
3
4
5
template<typename... Args>
void print_with_comma(Args... args)
{
((std::cout << args << ", "), ...) << std::endl;
}

对于逗号运算符 (A, B) 的逻辑是:先执行 A,再执行 B,整个表达式的值是 B 的值。所以下面这行的代码

1
print_with_comma(1, 2, 3)

展开如下:

1
(std::cout << 1 << ", "), ( (std::cout << 2 << ", "), (std::cout << 3 << ", ") )

或者更直观地看执行顺序(逗号运算符从左向右执行):

  1. std::cout << 1 << ", "
  2. std::cout << 2 << ", "
  3. std::cout << 3 << ", "

2.4 比较运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <limits>
// 检查所有参数是否按顺序递增
template<typename... Args>
bool is_increasing(Args... args)
{
// 使用逗号运算符和逻辑与组合比较
bool result = true;
auto last = std::numeric_limits<std::common_type_t<Args...>>::lowest();

// 折叠表达式:检查每个参数是否大于前一个
((result = result && (args > last), last = args), ...);
return result;
}

// 简化版:检查相邻元素是否递增
template<typename First, typename... Rest>
bool are_increasing(First first, Rest... rest)
{
// 使用右折叠:first < arg1 && (arg1 < arg2 && (... && argN-1 < argN))
if constexpr (sizeof...(rest) == 0)
{
return true;
}
else
{
return ((first < rest) && ...);
}
}

void test_comparison_operations()
{
std::cout << "=== 比较运算测试 ===" << std::endl;

std::cout << "is_increasing(1, 2, 3, 4, 5) = "
<< std::boolalpha << is_increasing(1, 2, 3, 4, 5) << std::endl;
std::cout << "is_increasing(1, 3, 2, 4, 5) = "
<< is_increasing(1, 3, 2, 4, 5) << std::endl;

std::cout << "are_increasing(1.0, 2.0, 3.0) = "
<< are_increasing(1.0, 2.0, 3.0) << std::endl;
std::cout << "are_increasing(5, 4, 3) = "
<< are_increasing(5, 4, 3) << std::endl;

std::cout << std::endl;
}
  • std::numeric_limits<T>::lowest():获取模板参数 T 中所有类型的公共类型的最小可表示值。

  • std::common_type_t<Args...>:它会推断出 Args... 中所有参数都能隐式转换成的那个公共类型。

    • 如果 Args...<int, double>,公共类型是 double
    • 如果 Args...<short, int, long>,公共类型通常是 long(取决于具体平台和实现,通常是能容纳所有数的最宽类型)。
    • 如果 Args...<int, std::string>,这行代码会编译报错,因为没有公共类型。
  • ((result = result && (args > last), last = args), ...);这是一个 一元右折叠

    假设我们传入 is_increasing(10, 20, 30),参数包展开为 10, 20, 30。这行代码实际上会被编译器展开为类似下面的链式调用:

    1
    2
    3
    4
    5
    6
    // 伪代码
    (
    ( result = result && (10 > last_initial), last = 10 ), // 处理第一个参数 10
    ( result = result && (20 > 10), last = 20 ), // 处理第二个参数 20
    ( result = result && (30 > 20), last = 30 ) // 处理第三个参数 30
    );

2.5 复杂应用 - 调用函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <vector>
#include <functional>

// 使用折叠表达式调用多个函数
template<typename... Funcs>
void execute_all(Funcs... funcs)
{
(funcs(), ...); // 依次调用所有函数
}

// 将多个值插入容器
template<typename Container, typename... Args>
void insert_all(Container& container, Args... args)
{
(container.push_back(args), ...); // 使用逗号运算符
}

void test_function_calls()
{
std::cout << "=== 函数调用测试 ===" << std::endl;

std::cout << "execute_all 测试:" << std::endl;
execute_all(
[]() { std::cout << "Function 1" << std::endl; },
[]() { std::cout << "Function 2" << std::endl; },
[]() { std::cout << "Function 3" << std::endl; }
);

std::cout << "\ninsert_all 测试:" << std::endl;
std::vector<int> vec;
insert_all(vec, 1, 2, 3, 4, 5);

std::cout << "Vector contains: ";
for (int num : vec)
{
std::cout << num << " ";
}
std::cout << std::endl << std::endl;
}