昆明賢邦網(wǎng)站建設(shè)百度站長(zhǎng)工具seo查詢
優(yōu)于立方復(fù)雜度的 Rust 中矩陣乘法
邁克·克維特
跟隨
更好的編程
143
中途:三次矩陣乘法
一、說(shuō)明
????????幾年前,我在 C++ 年編寫(xiě)了?Strassen 矩陣乘法算法的實(shí)現(xiàn),最近在?Rust 中重新實(shí)現(xiàn)了它,因?yàn)槲依^續(xù)學(xué)習(xí)該語(yǔ)言。這是學(xué)習(xí) Rust 性能特征和優(yōu)化技術(shù)的有用練習(xí),因?yàn)楸M管 Strassen 的算法復(fù)雜性優(yōu)于樸素方法,但它在算法結(jié)構(gòu)中的分配和遞歸開(kāi)銷中具有很高的常數(shù)因子。
- 通用算法
- 換位以獲得更好的性能
- 次立方:斯特拉森算法的工作原理
- 排比
- 標(biāo)桿
- 分析和性能優(yōu)化
二、通用算法
????????一般(樸素)矩陣乘法算法是每個(gè)人在他們的第一堂線性代數(shù)課上學(xué)習(xí)的三個(gè)嵌套循環(huán)方法,大多數(shù)人會(huì)將其識(shí)別為?O(n3)
pub fn
mult_naive (a: &Matrix, b: &Matrix) -> Matrix {if a.rows == b.cols {let m = a.rows;let n = a.cols;// preallocatelet mut c: Vec<f64> = Vec::with_capacity(m * m);for i in 0..m {for j in 0..m {let mut sum: f64 = 0.0;for k in 0..n {sum += a.at(i, k) * b.at(k, j);}c.push(sum);}}return Matrix::with_vector(c, m, m);} else {panic!("Matrix sizes do not match");}
}
????????這種算法很慢,不僅因?yàn)槿齻€(gè)嵌套循環(huán),還因?yàn)榘戳型ㄟ^(guò)而不是按行的內(nèi)部循環(huán)遍歷對(duì)于 CPU 緩存命中率來(lái)說(shuō)是可怕的。B
b.at(k, j)
三、換位以獲得更好的性能
? ? ? ? 轉(zhuǎn)置樸素方法允許 B 上的乘法迭代在行而不是列上運(yùn)行,將矩陣 B 的乘法步幅重新組織為更有利于緩存的格式。從而變成A x B
A x B^t
?????????它涉及一個(gè)新的矩陣分配(無(wú)論如何,在這個(gè)實(shí)現(xiàn)中)和一個(gè)完整的矩陣迭代(一個(gè) O(n2) 操作,更準(zhǔn)確地說(shuō),這種方法是 O(n3) + O(n2))——我將進(jìn)一步展示它的性能有多好。它如下所示:
fn multiply_transpose (A: Matrix, B: Matrix):C = new Matrix(A.num_rows, B.num_cols)// Construct transpose; requires allocation and iteration through BB’ = B.transpose()for i in 0 to A.num_rows:for j in 0 to B'.num_rows:sum = 0;for k in 0 to A.num_cols:// Sequential access of B'[j, k] is much faster than B[k, j]sum += A[i, k] * B'[j, k]C[i, j] = sumreturn C
四、次立方:斯特拉森算法的工作原理
????????要了解 Strassen 算法的工作原理(此處為 Rust 代碼),首先考慮矩陣如何用象限表示。要概念化它的外觀:
????????在樸素算法中使用此象限模型,結(jié)果矩陣?C?的四個(gè)象限中的每一個(gè)都是兩個(gè)子矩陣乘積的總和,總共產(chǎn)生 8 次乘法。
????????考慮到這八個(gè)乘法,每個(gè)乘法都在一個(gè)塊矩陣上運(yùn)行,其行和列跨度約為 A 和 B 大小的一半,復(fù)雜性相同:
????????斯特拉森算法定義了由這些象限組成的七個(gè)中間塊矩陣:
????????僅通過(guò)?7?次乘法而不是 8 次乘法計(jì)算。這些乘法可以是遞歸斯特拉森乘法,并可用于組成最終矩陣:
由此產(chǎn)生的亞立方復(fù)雜度:
五、排比
????????中間矩陣 M1 的計(jì)算 ...M7 是一個(gè)令人尷尬的并行問(wèn)題,因此也很容易檢測(cè)算法的并發(fā)變體(一旦你開(kāi)始理解?Rust 關(guān)于閉包的規(guī)則)。
/*** Execute a recursive strassen multiplication of the given vectors, * from a thread contained within the provided thread pool.*/
fn
_par_run_strassen (a: Vec<f64>, b: Vec<f64>, m: usize, pool: &ThreadPool) -> Arc<Mutex<Option<Matrix>>> {let m1: Arc<Mutex<Option<Matrix>>> = Arc::new(Mutex::new(None));let m1_clone = Arc::clone(&m1);pool.execute(move|| { // Recurse with non-parallel algorithm once we're // in a working threadlet result = mult_strassen(&mut Matrix::with_vector(a, m, m),&mut Matrix::with_vector(b, m, m));*m1_clone.lock().unwrap() = Some(result);});return m1;
}
六、標(biāo)桿
????????我編寫(xiě)了一些快速的基準(zhǔn)測(cè)試代碼,該代碼在不斷增加的矩陣維度范圍內(nèi)運(yùn)行四種算法中的每一種進(jìn)行幾次試驗(yàn),并報(bào)告每種算法的平均時(shí)間。
~/code/strassen ~>> ./strassen --lower 75 --upper 100 --factor 50 --trials 2running 50 groups of 2 trials with bounds between [75->3750, 100->5000]x y nxn naive transpose strassen par_strassen
75 100 7500 0.00ms 0.00ms 1.00ms 0.00ms
150 200 30000 6.50ms 4.00ms 4.00ms 1.00ms
225 300 67500 12.50ms 9.00ms 8.50ms 2.50ms
300 400 120000 26.50ms 22.00ms 18.00ms 5.50ms
[...]
3600 4800 17280000 131445.00ms 53683.50ms 21210.50ms 5660.00ms
3675 4900 18007500 141419.00ms 58530.00ms 28291.50ms 6811.00ms
3750 5000 18750000 154941.00ms 60990.00ms 26132.00ms 6613.00ms
????????然后,我通過(guò)以下方式可視化結(jié)果:pyplot
????????此圖顯示了矩陣從 7.5k 元素 () 到大約 19 萬(wàn) () 的乘法時(shí)間。你可以看到樸素算法在計(jì)算上變得不切實(shí)際的速度有多快,在高端需要兩分半鐘。N x M = 75 x 100
N x M = 3750 x 5000
????????相比之下,Strassen 算法的擴(kuò)展更平滑,并行算法計(jì)算兩個(gè) 19M 個(gè)元素的矩陣的結(jié)果,而樸素算法只處理 3.6M 個(gè)元素所花費(fèi)的時(shí)間。
????????對(duì)我來(lái)說(shuō)最有趣的是算法的性能。如前所述,緩存性能的改進(jìn)(以犧牲完整矩陣副本為代價(jià))在這些結(jié)果中得到了清楚地證明 - 即使使用與該方法漸近等效的算法也是如此。transpose
naive
七、分析和性能優(yōu)化
????????這個(gè)文檔是理解 Rust 性能基礎(chǔ)知識(shí)的絕佳資源。在?Mac OS 上啟動(dòng)并運(yùn)行儀器進(jìn)行分析是微不足道的,這要?dú)w功于貨運(yùn)儀器的 Rust 指南。這是調(diào)查分配行為、CPU 熱點(diǎn)和其他事情的絕佳工具。
在此過(guò)程中發(fā)生了一些變化:
- Strassen 代碼通過(guò)分而治之策略遞歸調(diào)用自己,但是一旦矩陣達(dá)到足夠小的大小,其高常數(shù)因子使其比一般矩陣算法慢。我發(fā)現(xiàn)這個(gè)點(diǎn)是大約?64?的行寬或列寬;通過(guò)提高吞吐量提高幾個(gè)因素來(lái)增加此閾值
2
- 斯特拉森算法要求矩陣填充到最接近的指數(shù) 2;減少這種情況以懶惰地確保矩陣只有偶數(shù)行和列?通過(guò)減少昂貴的大分配,將吞吐量提高了大約兩倍
- 將小矩陣回退算法從 更改為 導(dǎo)致大約 20% 的改進(jìn)
naive
transpose
- 添加和添加到?Cargo.toml?發(fā)布構(gòu)建標(biāo)志大約提高了 5%。有趣的是,性能持續(xù)惡化
codegen-units = 1
lto = "thin"
lto = “true”
- 一絲不茍地刪除所有可能的副本大約提高了~10%
Vec
- 提供一些提示并刪除隨機(jī)訪問(wèn)查找中的向量邊界檢查,又提高了大約 20%
#[inline]
/*** Returns the element at (i, j). Unsafe.*/#[inline]pub fn at (&self, i: usize, j: usize) -> f64 {unsafe {return *self.elements.get_unchecked(i * self.cols + j);}}
參考資料: