1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
use blas::fortran as backend;

use format::Conventional;
use operation::{Multiply, MultiplyInto, ScaleSelf};

impl Multiply<[f64], Conventional<f64>> for Conventional<f64> {
    /// Computes `self * right` and returns the product as a new matrix.
    ///
    /// `right` is passed to BLAS as a column-major matrix with `self.columns`
    /// rows; its number of columns is derived from its length.
    ///
    /// # Panics
    ///
    /// Panics if `self.columns` is zero (division by zero) or if the length
    /// of `right` is not a multiple of `self.columns`.
    #[inline]
    fn multiply(&self, right: &[f64]) -> Self {
        let (m, p) = (self.rows, self.columns);
        let n = right.len() / p;
        // SAFETY: the storage is handed to BLAS with `beta = 0.0`, in which
        // case the reference BLAS contract states that the output need not be
        // set on input and is fully overwritten; when the output is empty,
        // there is nothing left uninitialized.
        let mut result = unsafe { Conventional::with_uninitialized((m, n)) };
        multiply(1.0, &self.values, right, 0.0, &mut result.values, m, p, n);
        result
    }
}

impl MultiplyInto<Conventional<f64>, [f64]> for Conventional<f64> {
    /// Accumulates `self * right` into `result`, viewing `right` as a slice
    /// of its values.
    #[inline(always)]
    fn multiply_into(&self, right: &Self, result: &mut [f64]) {
        MultiplyInto::multiply_into(self, &*right as &[f64], result)
    }
}

impl MultiplyInto<Vec<f64>, [f64]> for Conventional<f64> {
    /// Accumulates `self * right` into `result`, viewing `right` as a slice.
    #[inline(always)]
    fn multiply_into(&self, right: &Vec<f64>, result: &mut [f64]) {
        MultiplyInto::multiply_into(self, &*right as &[f64], result)
    }
}

impl MultiplyInto<[f64], [f64]> for Conventional<f64> {
    /// Accumulates `self * right` into `result`, that is,
    /// `result += self * right`.
    ///
    /// # Panics
    ///
    /// Panics if `self.columns` is zero (division by zero), if the length of
    /// `right` is not a multiple of `self.columns`, or if the length of
    /// `result` does not match the dimensions of the product.
    #[inline]
    fn multiply_into(&self, right: &[f64], result: &mut [f64]) {
        let (m, p) = (self.rows, self.columns);
        let n = right.len() / p;
        multiply(1.0, &self.values, right, 1.0, result, m, p, n)
    }
}

impl ScaleSelf<f64> for [f64] {
    /// Multiplies every element of the slice by `alpha` in place.
    #[inline]
    fn scale_self(&mut self, alpha: f64) {
        // BLAS takes the length as a 32-bit signed integer. Longer slices are
        // processed in pieces so that the length is never truncated.
        const MAX_CHUNK: usize = 0x7fff_ffff;
        for chunk in self.chunks_mut(MAX_CHUNK) {
            backend::dscal(chunk.len() as i32, alpha, chunk, 1);
        }
    }
}

/// Converts a dimension to the 32-bit signed integer expected by BLAS.
///
/// # Panics
///
/// Panics if the value cannot be represented, instead of silently truncating.
#[inline]
fn to_blas_int(value: usize) -> i32 {
    let converted = value as i32;
    assert!(
        converted >= 0 && converted as usize == value,
        "the dimension {} exceeds the range of BLAS integers",
        value
    );
    converted
}

/// Computes `c = alpha * a * b + beta * c` via BLAS, where `a` is `m × p`,
/// `b` is `p × n`, and `c` is `m × n`, all stored contiguously and passed to
/// BLAS without transposition, with leading dimensions `m`, `p`, and `m`,
/// respectively.
///
/// A single-column `b` is dispatched to `dgemv`; anything else goes to
/// `dgemm`.
///
/// # Panics
///
/// Panics if a slice length disagrees with the given dimensions or if a
/// dimension does not fit into a BLAS integer. The checks are performed in
/// release builds as well, since BLAS would otherwise read or write outside
/// of the slices.
fn multiply(alpha: f64, a: &[f64], b: &[f64], beta: f64, c: &mut [f64], m: usize, p: usize,
            n: usize) {

    assert!(
        m.checked_mul(p) == Some(a.len()),
        "the left operand should have m * p elements"
    );
    assert!(
        p.checked_mul(n) == Some(b.len()),
        "the right operand should have p * n elements"
    );
    assert!(
        m.checked_mul(n) == Some(c.len()),
        "the output should have m * n elements"
    );

    // With an empty output there is nothing to compute, and BLAS would
    // reject the zero leading dimensions.
    if c.is_empty() {
        return;
    }

    // Both callers derive `n` as `right.len() / p`, so `p` is nonzero here.
    let (m, p, n) = (to_blas_int(m), to_blas_int(p), to_blas_int(n));
    if n == 1 {
        backend::dgemv(b'N', m, p, alpha, a, m, b, 1, beta, c, 1);
    } else {
        backend::dgemm(b'N', b'N', m, n, p, alpha, a, m, b, p, beta, c, m);
    }
}

#[cfg(test)]
mod tests {
    use prelude::*;

    #[test]
    fn multiply() {
        let matrix = Conventional::from_vec((2, 3), vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]);
        let right = Conventional::from_vec((3, 4), vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ]);

        assert_eq!(matrix.multiply(&right), Conventional::from_vec((2, 4), vec![
            22.0, 28.0, 49.0, 64.0, 76.0, 100.0, 103.0, 136.0,
        ]));
    }

    #[test]
    fn multiply_into() {
        let matrix = Conventional::from_vec((2, 3), vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]);
        let right = Conventional::from_vec((3, 4), vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ]);
        let mut result = Conventional::from_vec((2, 4), vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
        ]);

        matrix.multiply_into(&right, &mut result);

        assert_eq!(result, Conventional::from_vec((2, 4), vec![
            23.0, 30.0, 52.0, 68.0, 81.0, 106.0, 110.0, 144.0,
        ]));
    }

    #[test]
    fn scale_self() {
        let mut matrix = Conventional::from_vec(2, vec![21.0, 21.0, 21.0, 21.0]);

        matrix.scale_self(2.0);

        assert_eq!(matrix, Conventional::from_vec(2, vec![42.0, 42.0, 42.0, 42.0]));
    }
}