矩阵链相乘(matrix-chain multiplication)
步骤
第一步
Characterize the structure of an optimal solution
数学定义:Ai..jAi..j : 矩阵乘—–AiAi+1··AjAiAi+1··Aj
存在k 使得i≤k<ji≤k<j
把Ai··jAi··j 划分为 Ai··kAi··k 和A(k+1)··jA(k+1)··j
假设此时的k正好为最优划分,使得Ai··jAi··j矩阵相乘的规模最小
反证: 如果此时k对原来矩阵序列的划分不是最优的,则肯定存在另一个k,使得划分为最优:与假设矛盾。
第二步
Recursively define the value of an optimal solution
现在我们定义递归求解这个子问题:Ai··jAi··j 为: 矩阵AiAi+1··AjAiAi+1··Aj (1≤i≤j≤n)(1≤i≤j≤n)
数学定义:M[i,j]M[i,j] = Ai··jAi··j 的最小矩阵乘规模
定义数组P,使得矩阵AiAi的纬度为 = p[i−1]∗p[i]p[i−1]∗p[i]。p[i−1]p[i−1]代表矩阵的行, p[i]代表矩阵的列)
最优子结构定义如下
![屏幕快照 2019-05-26 23.26.27](屏幕快照 2019-05-26 23.26.27.png)
伪代码
![屏幕快照 2019-05-26 23.27.34](屏幕快照 2019-05-26 23.27.34.png)
C语言实现
//
// Created by 安炳旭 on 2019-05-26.
//
#include <cstdio>
#include <vector>
#define MAX 100000 //定义最大值
#define SIZE 100 //SIZE为最多支持矩阵乘法的数量 + 1
int m[SIZE][SIZE], s[SIZE][SIZE];//用于存储中间过程的数组m 以及用于存储分割点的数组s
//初始化
void init(int n) {
for (int i = 1; i <= n; ++i) {
m[i][i] = 0;
}
}
//函数功能 计算一个矩阵链乘乘法的最小乘法次数并返回 以及把最优的分割方法存储在m数组中
int matrixChainMult(int array[], int numbersOfMatrix) {
//array待进行矩阵链乘的数组(数组下标从1开始) numbersOfMatrix代表矩阵元素的个数
//第一层循环 :区间长度
for (int partLength = 1; partLength <= numbersOfMatrix - 1; ++partLength) {
//第二层 :区间起点
for (int partStart = 1; partStart <= numbersOfMatrix - partLength; ++partStart) {
//第三层: 区间的终点 = 起点+长度
int partEnd = partStart + partLength;
m[partStart][partEnd] = MAX;
for (int k = partStart; k <= partEnd; ++k) {
//k为区间划分点
int temp = m[partStart][k] + m[k + 1][partEnd] + array[partStart] * array[partEnd + 1] * array[k + 1];
if (temp < m[partStart][partEnd]) {
//更新m 和s
m[partStart][partEnd] = temp;
s[partStart][partEnd] = k;
}
}
}
}
return m[1][numbersOfMatrix];
}
//打印一个数组的最终分法
void printResult(int i, int j) {
//i j 代表左右边界
if (i == j) {
printf("\tA%d\t",i);
} else {
printf("(");
printResult(i, s[i][j]);
printResult(s[i][j] + 1, j);
printf(")");
}
}
int main() {
//p和n为待输入的内容
// int p[SIZE+1] = {0,3, 5, 2, 1, 10};
int p[SIZE+1] = {0,10, 3, 15, 12, 7,2};
int n = 5; //n为矩阵的个数 也就说==p实际长度-1
init(SIZE+1);
int h = matrixChainMult(p, n);
printResult(1,n);
printf("h:%d",h);
}