"""
This script performs nonparametric (rank-based) association tests between each numeric column and a binary dependent column.
It uses the Python library scipy.stats (Mann-Whitney U test or Wilcoxon rank-sum test).

Author: Autumn Laughbaum
Last Revised: 2013-08-13
"""
# Host-environment version guard. `ghi` is supplied by the host application
# (it is not imported in this file); presumably this stops the script when
# the host is older than version 7.4 -- TODO confirm against the GHI docs.
ghi.requireVersion('7.4')
import numpy as n
import scipy.stats as s
from collections import Counter

def _rank_test(cases, controls, wilcoxon, nTests):
    """Run one rank-based two-sample test and derive the reported values.

    Returns a tuple (statistic, p, -log10(p), Bonferroni p).  All four are
    None when both groups consist of one and the same value.  -log10(p) is
    None when p is exactly 0.  The Bonferroni p is p * nTests clamped to 1.0.
    """
    casesU = list(set(cases))
    controlsU = list(set(controls))
    # Both groups hold a single identical value: there is nothing to rank,
    # so the column is reported with blank cells instead of being tested.
    if len(controlsU) == 1 and len(casesU) == 1 and controlsU[0] == casesU[0]:
        return (None, None, None, None)

    if wilcoxon:
        out = s.ranksums(cases, controls)
    else:
        out = s.mannwhitneyu(cases, controls)
    statistic = float(out[0])
    p = float(out[1])

    # log10(0) would be -inf; leave that cell blank instead.
    negLog10 = -float(n.log10(p)) if p != 0.0 else None
    bonf = min(nTests * p, 1.0)
    return (statistic, p, negLog10, bonf)


def nonparametric(ss,depDM,test,correct,negLog):
    """Test every numeric column of a spreadsheet against a binary dependent.

    Rows are split into cases (dependent value == True) and controls
    (dependent value == False).  Rows with any other dependent value are
    not placed in either group.  Each active binary, real, or integer column
    of ``ss`` is then compared between the two groups with a rank-based test.
    The results are written to a new spreadsheet, which is shown.

    Parameters
    ----------
    ss : host spreadsheet object
        Source spreadsheet; its ``dataModel`` supplies the tested columns.
    depDM : host data model
        Data model whose column 1 is the binary dependent column.
    test : str
        "Wilcoxon Rank-sum test (Z-statistic)" selects scipy.stats.ranksums.
        Any other value selects scipy.stats.mannwhitneyu.
    correct : bool
        If true, add a Bonferroni-adjusted p-value column ('Bonf-P').
    negLog : bool
        If true, add a '-log10(P)' column.

    Returns None.  It returns early, without producing output, when there are
    fewer than twenty cases or twenty controls (after showing a message), or
    when the user cancels the progress dialog.
    """
    wilcoxon = test == "Wilcoxon Rank-sum test (Z-statistic)"

    y = depDM.col(1)
    dm = ss.dataModel(ghi.const.FilterBinary|ghi.const.FilterReal|ghi.const.FilterInt|ghi.const.FilterActiveOnly)
    nC = dm.numCols()

    caseCounts = Counter(y)
    if caseCounts[False] < 20 or caseCounts[True] < 20: #Number of cases or controls is less than 20
        ghi.message("Must have at least twenty cases and twenty controls")
        return

    # Explicit == comparisons on purpose: values that are neither True nor
    # False (e.g. missing) must fall into neither group.  The +1 converts
    # enumerate's 0-based index to the 1-based indices passed to subsetRows.
    controlRowIdxList = [idx+1 for idx, r in enumerate(y) if r == False]
    caseRowIdxList = [idx+1 for idx, r in enumerate(y) if r == True]

    dmCases = dm.subsetRows(caseRowIdxList)
    dmControls = dm.subsetRows(controlRowIdxList)
    colHeaders = dm.colHeaders()  # loop-invariant, fetched once

    newRowLabs = []
    stat = []
    pVal = []
    negLogP = []
    bonfP = []

    progress = ghi.progressDialog("Calculating...",nC)
    # finally guarantees the dialog is closed on normal completion, on
    # cancellation, and if a test raises.
    try:
        for column in range(1, nC+1):
            if progress.wasCanceled():
                return
            progress.setProgress(column)
            newRowLabs.append(colHeaders[column-1])

            cases = dmCases.col(column)
            controls = dmControls.col(column)
            # nC is used as the number of tests for the Bonferroni factor.
            colStat, colP, colNegLog, colBonf = _rank_test(cases, controls, wilcoxon, nC)
            stat.append(colStat)
            pVal.append(colP)
            negLogP.append(colNegLog)
            bonfP.append(colBonf)
    finally:
        progress.finish()

    if wilcoxon:
        builder = ghi.dataSetBuilder('Wilcoxon Rank Sum Test Results',nC)
        builder.addRealColumn("Test Statistic (Z-Statistic)",stat)
        # scipy.stats.ranksums reports a two-sided p-value.
        pLabel = 'Two-sided P-value'
    else:
        builder = ghi.dataSetBuilder('Mann-Whitney Rank Sum Test Results',nC)
        builder.addRealColumn("Test Statistic (U-Statistic)",stat)
        # NOTE(review): scipy < 1.7 mannwhitneyu defaults to a one-sided
        # p-value, scipy >= 1.7 defaults to two-sided -- confirm the scipy
        # version bundled with the host before trusting this label.
        pLabel = 'One-sided P-value'
    builder.addRowLabels('Columns',newRowLabs)
    builder.addRealColumn(pLabel,pVal)
    if negLog:
        builder.addRealColumn('-log10(P)',negLogP)
    if correct:
        builder.addRealColumn('Bonf-P',bonfP)

    sheet = builder.finish(ss.getID())
    # Carry the source's marker map over when it is column-oriented; the
    # result sheet has one row per tested column, hence columnOriented=0.
    if dm.markerMapOrientation()==ghi.const.MapOrientationColumns:
        sheet.setMarkerMap(dm,offset=dm.markerMapOffset(),columnOriented=0)

    sheet.show()
    
# Script entry point: check that exactly one binary dependent column is set,
# ask the user which test and which extra output columns to produce, then run
# the analysis.  Any failure is reported to the user through ghi.error().
try:
    ssIn = ghi.getCurrentObject()
    depDMIn = ssIn.dataModel(ghi.const.FilterDependent)
    if depDMIn.numCols() != 1 or depDMIn.colType(1) != ghi.const.TypeBinary:
        ghi.message("One dependent binary column must be specified in order to perform association tests")
    else:
        prompt = ghi.promptDialog(
            [{"name":"test","label":"Choose a rank based test:","type":"combo","list":["Mann-Whitney Rank test (U-statistic)","Wilcoxon Rank-sum test (Z-statistic)"]},
             {"name":"mult","label":"Bonferroni Adjustment for multiple tests","type":"check","default":1},
             {"name":"neglog","label":"Output -log10(P)","type":"check","default":1}],
            title="Nonparametric Association Tests")

        # A falsy result presumably means the dialog was dismissed -- confirm.
        if prompt:
            nonparametric(ssIn,depDMIn,prompt['test'],prompt['mult'],prompt['neglog'])

except Exception:
    ghi.error()
