1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
use blas::fortran as backend;

use format::Conventional;
use operation::{Multiply, MultiplyInto, ScaleSelf};

impl Multiply<[f64], Conventional<f64>> for Conventional<f64> {
    /// Computes the product `self * right` and returns it as a new `rows × n` matrix.
    ///
    /// `right` is read as a column-major matrix with `self.columns` rows; its column count
    /// `n` is inferred as `right.len() / self.columns`.
    ///
    /// # Panics
    ///
    /// Panics if `self.columns` is zero, since the column count of `right` cannot be
    /// inferred in that case.
    #[inline]
    fn multiply(&self, right: &[f64]) -> Self {
        let (m, p) = (self.rows, self.columns);
        let n = right.len() / p;
        // Start from a zero-filled buffer instead of uninitialized memory: exposing
        // uninitialized floats as `&mut [f64]` is undefined behavior in Rust, and the
        // O(m * n) fill is negligible next to the O(m * p * n) product.
        let mut result = Conventional::from_vec((m, n), vec![0.0; m * n]);
        multiply(1.0, &self.values, right, 0.0, &mut result.values, m, p, n);
        result
    }
}

impl MultiplyInto<Conventional<f64>, [f64]> for Conventional<f64> {
    /// Accumulates `self * right` into `result`, forwarding the matrix `right` through its
    /// `Deref` view as a plain `[f64]` slice.
    #[inline(always)]
    fn multiply_into(&self, right: &Self, result: &mut [f64]) {
        let values: &[f64] = right;
        <Self as MultiplyInto<[f64], [f64]>>::multiply_into(self, values, result)
    }
}

impl MultiplyInto<Vec<f64>, [f64]> for Conventional<f64> {
    /// Accumulates `self * right` into `result`, passing the vector's contents on as a slice.
    #[inline(always)]
    fn multiply_into(&self, right: &Vec<f64>, result: &mut [f64]) {
        <Self as MultiplyInto<[f64], [f64]>>::multiply_into(self, right.as_slice(), result)
    }
}

impl MultiplyInto<[f64], [f64]> for Conventional<f64> {
    /// Adds the product `self * right` to the existing contents of `result`
    /// (i.e. `result += self * right`).
    ///
    /// `right` is read as a column-major matrix with `self.columns` rows; its column count
    /// is inferred from its length.
    #[inline]
    fn multiply_into(&self, right: &[f64], result: &mut [f64]) {
        let rows = self.rows;
        let inner = self.columns;
        let columns = right.len() / inner;
        // beta = 1.0 keeps what is already in `result` and adds the product on top.
        multiply(1.0, &self.values, right, 1.0, result, rows, inner, columns)
    }
}

impl ScaleSelf<f64> for [f64] {
    /// Multiplies every element of the slice by `alpha` in place, using BLAS `dscal` with
    /// unit stride.
    #[inline]
    fn scale_self(&mut self, alpha: f64) {
        // `dscal` takes an `i32` length. Casting `self.len()` directly would silently
        // truncate for slices longer than `i32::MAX` and scale only part of the data, so the
        // slice is processed in chunks whose lengths always fit. An empty slice yields no
        // chunks, so `dscal` is simply not called.
        let limit = ::std::i32::MAX as usize;
        for chunk in self.chunks_mut(limit) {
            backend::dscal(chunk.len() as i32, alpha, chunk, 1);
        }
    }
}

/// Computes `c = alpha * a * b + beta * c` through BLAS.
///
/// `a` is `m × p`, `b` is `p × n` and `c` is `m × n`, each stored column-major with the
/// row count as its leading dimension. A single-column `b` (`n == 1`) is dispatched to
/// `dgemv`; everything else goes through `dgemm`.
///
/// # Panics
///
/// Panics if a slice length does not match its stated dimensions, or if a dimension does
/// not fit in a BLAS (`i32`) integer.
fn multiply(alpha: f64, a: &[f64], b: &[f64], beta: f64, c: &mut [f64], m: usize, p: usize,
            n: usize) {
    // These must hold in release builds too: BLAS performs no bounds checking, so a short
    // slice (e.g. an undersized `result` passed to `multiply_into`) would let it read or
    // write out of bounds. The checks are O(1) next to the product itself.
    assert_eq!(a.len(), m * p);
    assert_eq!(b.len(), p * n);
    assert_eq!(c.len(), m * n);

    // An empty output needs no work, and BLAS rejects a leading dimension of zero.
    if m == 0 || n == 0 {
        return;
    }

    let limit = ::std::i32::MAX as usize;
    assert!(m <= limit && p <= limit && n <= limit,
            "matrix dimensions must fit in a BLAS integer");
    let (m, p, n) = (m as i32, p as i32, n as i32);
    // BLAS requires every leading dimension to be at least 1, even when `b` has no rows.
    let ldb = ::std::cmp::max(1, p);

    if n == 1 {
        backend::dgemv(b'N', m, p, alpha, a, m, b, 1, beta, c, 1);
    } else {
        backend::dgemm(b'N', b'N', m, n, p, alpha, a, m, b, ldb, beta, c, m);
    }
}

#[cfg(test)]
mod tests {
    use prelude::*;

    // Shared 2x3 left operand; values are laid out column by column.
    fn lhs() -> Conventional<f64> {
        Conventional::from_vec((2, 3), vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ])
    }

    // Shared 3x4 right operand; values are laid out column by column.
    fn rhs() -> Conventional<f64> {
        Conventional::from_vec((3, 4), vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ])
    }

    #[test]
    fn multiply() {
        let expected = Conventional::from_vec((2, 4), vec![
            22.0, 28.0, 49.0, 64.0, 76.0, 100.0, 103.0, 136.0,
        ]);

        assert_eq!(lhs().multiply(&rhs()), expected);
    }

    #[test]
    fn multiply_into() {
        let mut accumulator = Conventional::from_vec((2, 4), vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
        ]);

        lhs().multiply_into(&rhs(), &mut accumulator);

        // The product is added to the existing contents rather than overwriting them.
        let expected = Conventional::from_vec((2, 4), vec![
            23.0, 30.0, 52.0, 68.0, 81.0, 106.0, 110.0, 144.0,
        ]);
        assert_eq!(accumulator, expected);
    }

    #[test]
    fn scale_self() {
        let mut matrix = Conventional::from_vec(2, vec![21.0; 4]);
        matrix.scale_self(2.0);
        assert_eq!(matrix, Conventional::from_vec(2, vec![42.0; 4]));
    }
}