"""MatrixOps
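A program to count the number of arithmetic operations needed to multiply
two n x n matrices using two different algorithms: the standard
row-by-column algorithm and Strassen's "fast" divide-and-conquer algorithm.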

Usage: matrixops.py [n]
If n is not specified, matrixops will prompt for it.
"""
__author__ = "David Osolkowski (qid@wadny.com)"
# Keep __version__ and __date__ in sync with the newest __history__ entry.
__version__ = "1.2"
__date__ = "2003/04/28"
__copyright__ = "Copyright (c) 2003, David Osolkowski, all rights reserved."
__license__ = """
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
"""
__history__ = """
    1.0 - 2003/04/22 - created script, initial coding
    1.1 - 2003/04/26 - implemented "fast" multiplication algorithm
    1.2 - 2003/04/28 - final cleanup, commenting
"""

def fastAlgorithm(A, B):
    """multiply two matrices using Strassen's "fast" algorithm

    A, B - two n*n matrices (lists of row lists) of the same size n.
           Any n is accepted: an odd n greater than 1 is padded with one
           zero row and one zero column so it can be split into equal
           square quadrants, and the padding is trimmed from the result.

    Returns (operations, C) where C is the product A*B and operations is
    the sum of the operation counts reported by standardAdd,
    standardSubtract and standardAlgorithm for every step performed
    (work done on padded entries included).
    """
    n = len(A)

    if (n <= 1):
        # Can't recurse anymore, so just let the standard algorithm
        # handle the 1x1 (or empty) matrices.  Handling n == 0 here keeps
        # an empty matrix from being split into empty quadrants forever.
        return standardAlgorithm(A, B)

    if (n % 2 == 1):
        # An odd size can't be split into equal square quadrants; doing so
        # anyway mixes quadrant sizes and yields wrong results.  Padding
        # with zeros leaves the top-left n*n block of the product unchanged:
        # [[A, 0], [0, 0]] * [[B, 0], [0, 0]] = [[AB, 0], [0, 0]].
        (operations, C) = fastAlgorithm(_padWithZeros(A), _padWithZeros(B))
        return (operations, [row[:n] for row in C[:n]])

    operations = 0

    # Break the two matrices into quadrants
    (A11, A12, A21, A22) = fastBreak(A)
    (B11, B12, B21, B22) = fastBreak(B)

    # Calculate P1 = (A11 + A22)(B11 + B22)
    (tmp1, ops1) = standardAdd(A11, A22)
    (tmp2, ops2) = standardAdd(B11, B22)
    (ops3, P1)   = fastAlgorithm(tmp1, tmp2)
    operations += ops1 + ops2 + ops3

    # Calculate P2 = (A21 + A22) * B11
    (tmp1, ops1) = standardAdd(A21, A22)
    (ops2, P2)   = fastAlgorithm(tmp1, B11)
    operations += ops1 + ops2

    # Calculate P3 = A11 * (B12 - B22)
    (tmp1, ops1) = standardSubtract(B12, B22)
    (ops2, P3)   = fastAlgorithm(A11, tmp1)
    operations += ops1 + ops2

    # Calculate P4 = A22 * (B21 - B11)
    (tmp1, ops1) = standardSubtract(B21, B11)
    (ops2, P4)   = fastAlgorithm(A22, tmp1)
    operations += ops1 + ops2

    # Calculate P5 = (A11 + A12) * B22
    (tmp1, ops1) = standardAdd(A11, A12)
    (ops2, P5)   = fastAlgorithm(tmp1, B22)
    operations += ops1 + ops2

    # Calculate P6 = (A21 - A11) * (B11 + B12)
    (tmp1, ops1) = standardSubtract(A21, A11)
    (tmp2, ops2) = standardAdd(B11, B12)
    (ops3, P6)   = fastAlgorithm(tmp1, tmp2)
    operations += ops1 + ops2 + ops3

    # Calculate P7 = (A12 - A22) * (B21 + B22)
    (tmp1, ops1) = standardSubtract(A12, A22)
    (tmp2, ops2) = standardAdd(B21, B22)
    (ops3, P7)   = fastAlgorithm(tmp1, tmp2)
    operations += ops1 + ops2 + ops3

    # Calculate C11 = P1 + P4 - P5 + P7
    (tmp1, ops1) = standardAdd(P1, P4)
    (tmp2, ops2) = standardSubtract(tmp1, P5)
    (C11, ops3)  = standardAdd(tmp2, P7)
    operations += ops1 + ops2 + ops3

    # Calculate C12 = P3 + P5
    (C12, ops1) = standardAdd(P3, P5)
    operations += ops1

    # Calculate C21 = P2 + P4
    (C21, ops1) = standardAdd(P2, P4)
    operations += ops1

    # Calculate C22 = P1 + P3 - P2 + P6
    (tmp1, ops1) = standardAdd(P1, P3)
    (tmp2, ops2) = standardSubtract(tmp1, P2)
    (C22, ops3)  = standardAdd(tmp2, P6)
    operations += ops1 + ops2 + ops3

    # Assemble the full C: each top row is a C11 row followed by the
    # matching C12 row, each bottom row a C21 row followed by a C22 row.
    # Concatenating rows avoids the index arithmetic (n/2) that breaks
    # under true division.
    C = ([left + right for (left, right) in zip(C11, C12)] +
         [left + right for (left, right) in zip(C21, C22)])

    return (operations, C)


def _padWithZeros(M):
    """Return a copy of square matrix M with one extra zero row and column."""
    size = len(M)
    return [list(row) + [0] for row in M] + [[0] * (size + 1)]

def fastBreak(A):
    """Break a matrix into quadrants

    A - an n*n matrix (list of row lists).  n is expected to be even;
        for odd n the quadrants are not all square.

    Returns (A11, A12, A21, A22): the top-left, top-right, bottom-left and
    bottom-right quadrants, each built from new row lists.
    """
    n = len(A)
    # Floor division: `n / 2` is a float under true division (Python 3 or
    # `from __future__ import division`), which range() and slicing reject.
    half = n // 2
    top = A[:half]
    bottom = A[half:n]
    A11 = [row[:half] for row in top]
    A12 = [row[half:n] for row in top]
    A21 = [row[:half] for row in bottom]
    A22 = [row[half:n] for row in bottom]
    return (A11, A12, A21, A22)

def standardAdd(A, B):
    """Return (A + B, op_count), adding the matrices entry by entry.

    The size n is taken from len(A) and both matrices are treated as n*n;
    the operation count is one addition per entry, i.e. n**2.
    """
    size = len(A)
    span = range(size)
    total = [[A[r][c] + B[r][c] for c in span] for r in span]
    return (total, size * size)

def standardSubtract(A, B):
    """Return (A - B, op_count), subtracting the matrices entry by entry.

    The size n is taken from len(A) and both matrices are treated as n*n;
    the operation count is one subtraction per entry, i.e. n**2.
    """
    size = len(A)
    span = range(size)
    difference = [[A[r][c] - B[r][c] for c in span] for r in span]
    return (difference, size * size)

def standardAlgorithm(A, B):
    """Return (op_count, A*B) using the row-by-column definition.

    The size n is taken from len(A).  Each entry of the product is the
    running sum of n products, and every multiply-and-accumulate step is
    counted as 2 operations, giving 2*n**3 in total.
    """
    size = len(A)
    product = []
    for row in range(size):
        product.append([])
        for col in range(size):
            acc = 0
            for k in range(size):
                acc = acc + (A[row][k] * B[k][col])
            product[row].append(acc)
    # Two operations (one multiply, one add) per innermost step.
    operations = 2 * size * size * size
    return (operations, product)

if __name__ == "__main__":
    import sys
    from random import Random

    # raw_input() exists only in Python 2; Python 3 renamed it input().
    # Use whichever this interpreter provides.
    try:
        readLine = raw_input
    except NameError:
        readLine = input

    # Get a value for n from the command line, or prompt for it
    if (len(sys.argv) != 2):
        text = readLine("N:")
    else:
        text = sys.argv[1]
    try:
        n = int(text)
    except ValueError:
        sys.exit("ERROR: n must be an integer, got %r" % (text, ))
    if (n < 1):
        # n <= 0 would produce empty matrices; there is nothing to compare.
        sys.exit("ERROR: n must be at least 1, got %d" % (n, ))

    # Generate two random matrices
    rng = Random()
    A = [[rng.randint(1, 100) for i in range(n)] for i in range(n)]
    B = [[rng.randint(1, 100) for i in range(n)] for i in range(n)]
    # Each print gets a single pre-formatted string, so the output is the
    # same whether print is a statement (Python 2) or a function (Python 3).
    print("Matrix one: %s" % (A, ))
    print("Matrix two: %s" % (B, ))
    # The line reader only returns once Enter is pressed.
    readLine("Press Enter to continue...")

    # Multiply the two matrices using each algorithm
    (stdOps, stdResult) = standardAlgorithm(A, B)
    (fastOps, fastResult) = fastAlgorithm(A, B)

    if (fastResult != stdResult):
        print("Standard result: %s" % (stdResult, ))
        print("Fast result: %s" % (fastResult, ))
        print("ERROR: Results do not match!")

    print("The standard algorithm took %d operations." % (stdOps, ))
    print("The fast algorithm took %d operations." % (fastOps, ))