"""
From a spreadsheet containing genotype data and optional phenotype and pedigree
information, this script saves the spreadsheet as a series of BEAGLE files, one
per chromosome, with the chromosome name (and, optionally, the population
stratum) appended to each file name. If there is no marker map, everything is
written to a single file.

Author: Autumn Laughbaum, Jesse Dupre, Christophe Lambert
Last Modified: 2013-12-17
"""

# NOTE(review): `ghi` is never imported in this file; presumably the hosting
# application injects it into the script namespace -- confirm.
ghi.requireVersion('7.4')
import os
# Separator placed between every token written to the BEAGLE files.
delim = ' '


def exportBeagle(dm,dmG,filePath,baseName,aff,ped,fid,mid,pop,cov,types,pMessage):
    """Write the genotype columns of dmG to BEAGLE (.bgl) files.

    The columns are walked in order and a new file is started (through
    new_file) every time the chromosome of the current column differs from
    that of the previous one.  Without a marker map every column is treated
    as having chromosome None, so everything ends up in a single file.
    Each genotype column becomes one 'M' line; the '_' inside each genotype
    value is replaced by the field separator.

    Args:
        dm: data model holding the phenotype/pedigree columns, or a false
            value when there is none.  Used here only for the log text and
            passed through to new_file.
        dmG: genotypic data model to export.
        filePath: output directory.
        baseName: file name prefix.
        aff, ped, fid, mid, pop: 1-based column numbers in dm, or None.
        cov: list of 1-based covariate column numbers in dm.
        types: BEAGLE line codes, parallel to cov.
        pMessage: progress dialog caption.  A true value also selects
            "batch" behaviour: the file count is returned instead of being
            shown in a message box.

    Returns:
        The number of files written when pMessage is true and the export was
        not canceled; None in every other case (including cancellation).
    """
    # ---- Log text -------------------------------------------------------
    logLines = ['Export files to beagle format\n',
                ' *Filepath: ' + filePath + '\n',
                ' *Basename: ' + baseName + '\n']
    if dm:
        dmHeaders = dm.colHeaders()
        dmIdx = dm.colIndexes()
        singles = (('Affection Status', aff),
                   ('Pedigree Identifier', ped),
                   ('Father Identifier', fid),
                   ('Mother Identifier', mid),
                   ('Population Stratum', pop))
        for label, col in singles:
            if col:
                logLines.append(' *' + label + ': [' + str(dmIdx[col-1]) + '] '
                                + dmHeaders[col-1] + '\n')
            else:
                logLines.append(' *No ' + label + ' used\n')
        if cov:
            for c in cov:
                logLines.append(' *Additional covariate: [' + str(dmIdx[c-1]) + '] '
                                + dmHeaders[c-1] + '\n')
        else:
            logLines.append(' *No additional covariates used\n')
    logLines.append('\nCreated the following files in ' + str(filePath) + ':\n')

    # ---- Chromosome of every genotype column ----------------------------
    numCols = dmG.numCols()
    if dmG.hasMarkerMap():
        chrvect = dmG.markerMapChromosomes()
    else:
        chrvect = [None] * numCols

    progress = ghi.progressDialog(pMessage or 'Writing beagle file(s)...', 100)
    progress.show()
    if progress.wasCanceled():
        return None

    # Denominator for the progress percentage (kept as in the original:
    # genotype columns plus a small allowance for the header rows).
    N = numCols + 5 + len(cov or [])
    count = 0

    curChr = chrvect[0]
    newFiles = 1
    f,tempM = new_file(None,curChr,filePath,baseName,aff,ped,fid,mid,pop,cov,types,dm,dmG)
    logLines.append(tempM)

    # The try/finally guarantees the file that is currently open gets closed
    # (and therefore flushed) on normal completion, on cancel and on error.
    try:
        for i in range(1,numCols+1):
            chrom = chrvect[i-1]
            if chrom != curChr:
                # NOTE(review): new_file opens its path in 'w' mode, so a
                # chromosome that shows up again in a later, non-adjacent run
                # of columns would overwrite its earlier file -- confirm that
                # columns are always grouped by chromosome.
                f,tempM = new_file(f,chrom,filePath,baseName,aff,ped,fid,mid,pop,cov,types,dm,dmG)
                curChr = chrom
                logLines.append(tempM)
                newFiles += 1

            # One 'M' line per genotype column.
            header = dmG.colHeader(i).replace(' ','_')
            column = [delim.join(geno.split('_')) for geno in dmG.col(i)]
            f.write('M' + delim + header + delim + delim.join(column) + '\n')

            if progress.wasCanceled():
                return None
            percent = 100.0*i/N
            if percent > count:
                count = percent
                progress.setProgress(count)
    finally:
        f.close()

    progress.setProgress(100)
    progress.finish()

    dmG.appendLog(''.join(logLines))
    if pMessage:
        return newFiles
    ghi.message('Created ' + str(newFiles) + ' files in ' + filePath)
    return None


# Text written for a binary (affection-style) value; None means missing.
_BEAGLE_BINARY = {False:'0',True:'1',None:'?'}


def _beagle_missing(value):
    """Return '?' for a missing (None) value, otherwise str(value)."""
    if value is None:
        return '?'
    return str(value)


def _write_beagle_row(f,code,header,values):
    """Write one BEAGLE line: code, header, then every value twice in a row."""
    pairs = [v + delim + v for v in values]
    f.write(code + delim + header.replace(' ','_') + delim + delim.join(pairs) + '\n')


def new_file(f,chr,filePath,baseName,aff,ped,fid,mid,pop,cov,types,dm,dmG):
    """Close f (if any), open the next BEAGLE file and write its header rows.

    The file is named '<baseName>_chr<chr>.bgl', or '<baseName>_unmapped.bgl'
    when chr is None, inside filePath.  The rows written are, in order: the
    'I' id row built from dmG's row labels, then one row for each of aff
    ('A'), ped ('P'), fid ('FID'), mid ('MID'), pop ('S') that is set, then
    one row per covariate using the matching code from types.  Every
    per-sample value is written twice.

    Args:
        f: previously open file object to close first, or None.
        chr: chromosome name used in the file name, or None.
        filePath: output directory.
        baseName: file name prefix.
        aff, ped, fid, mid, pop: 1-based column numbers in dm, or None.
        cov: list of 1-based covariate column numbers in dm.
        types: BEAGLE line codes, parallel to cov.
        dm: data model the column numbers above refer to.
        dmG: genotypic data model (supplies the row labels).

    Returns:
        (file, logLine): the newly opened file, still open for the caller to
        append marker rows to, and the ' - <name>\\n' line for the log.
    """
    if f:
        f.close()

    if chr is not None:
        fileName = baseName + '_chr' + str(chr) + '.bgl'
    else:
        fileName = baseName + '_unmapped.bgl'
    tempM = ' - ' + fileName + '\n'

    f = open(os.path.join(filePath,fileName),'w')
    try:
        # Label identifier
        _write_beagle_row(f,'I','id',[l.replace(' ','_') for l in dmG.rowLabels()])
        if aff:
            _write_beagle_row(f,'A',dm.colHeader(aff),
                              [_BEAGLE_BINARY[a] for a in dm.col(aff)])
        if ped:
            _write_beagle_row(f,'P',dm.colHeader(ped),
                              [str(a) for a in dm.col(ped)])
        if fid:
            _write_beagle_row(f,'FID',dm.colHeader(fid),
                              [_beagle_missing(a) for a in dm.col(fid)])
        if mid:
            _write_beagle_row(f,'MID',dm.colHeader(mid),
                              [_beagle_missing(a) for a in dm.col(mid)])
        if pop:
            _write_beagle_row(f,'S',dm.colHeader(pop),
                              [str(a) for a in dm.col(pop)])
        if cov:
            for idx,c in enumerate(cov):
                code = types[idx]
                if code == 'A':
                    values = [_BEAGLE_BINARY[a] for a in dm.col(c)]
                else:
                    values = [_beagle_missing(a) for a in dm.col(c)]
                _write_beagle_row(f,code,dm.colHeader(c),values)
    except:
        # Cleanup only: do not leak the handle when a header row cannot be
        # written; the original exception is re-raised unchanged.
        f.close()
        raise

    return f,tempM


def loop_by_strata(dm,dmG,filePath,baseName,aff,ped,fid,mid,pop,cov,types):
    """Run exportBeagle once per distinct value of the population column.

    For every population value (in order of first appearance) only the rows
    belonging to it are left active in the underlying spreadsheet, fresh data
    models are taken from it and exported with
    '<baseName>_population=<value>' as base name ('None' replaces '?').
    The spreadsheet's row states are put back afterwards, whether the export
    finished, was canceled or raised.

    Args:
        dm: data model holding the phenotype/pedigree columns.
        dmG: genotypic data model (not read here; per-population models are
            rebuilt from the spreadsheet).
        filePath: output directory.
        baseName: file name prefix.
        aff, ped, fid, mid, pop: 1-based column numbers in dm, or None;
            pop must be set.
        cov: list of 1-based covariate column numbers in dm.
        types: BEAGLE line codes, parallel to cov.

    Returns:
        None.  A message box with the total file count is shown unless the
        user canceled one of the exports.
    """
    # Spreadsheet row number behind each data-model row.  The spreadsheet's
    # setRowState needs these, not positions within dm, otherwise the wrong
    # rows are selected as soon as some rows were inactive beforehand.
    # (Assumes dm.col() and dm.rowIndexes() list rows in the same order --
    # TODO confirm against the API docs.)
    ssRows = list(dm.rowIndexes())

    populations = []
    rowIdx = {}
    for i,p in enumerate(dm.col(pop)):
        if p not in rowIdx:
            rowIdx[p] = []
            populations.append(p)
        rowIdx[p].append(ssRows[i])

    ss = ghi.getObject(dm.nodeID())
    allRows = list(range(1,ss.numRows()+1))
    prevInactive = list(set(allRows) - set(ssRows))

    # Same column filter that check() uses for dm, so that the column numbers
    # (computed against dm) also address the right columns of the temporary
    # per-population model.
    phenoFilter = ghi.const.FilterBinary|ghi.const.FilterCategorical|ghi.const.FilterQuantitative

    newFiles = 0
    try:
        for p in populations:
            ss.setRowState(allRows,ghi.const.StateInactive)
            ss.setRowState(rowIdx[p],ghi.const.StateActive)
            ss.setRowState(prevInactive,ghi.const.StateInactive)
            dmTemp = ss.dataModel(phenoFilter)
            dmGTemp = ss.dataModel(ghi.const.FilterGenotypic)
            if p=='?':
                baseNameTemp = baseName + '_population=None'
                pMessage = 'Writing beagle file(s) for population=None...'
            else:
                baseNameTemp = baseName+ '_population=' + p
                pMessage = 'Writing beagle file(s) for population='+p+'...'
            newFileTemp = exportBeagle(dmTemp,dmGTemp,filePath,baseNameTemp,aff,ped,fid,mid,pop,cov,types,pMessage)
            if newFileTemp is None:
                # Canceled by the user: stop quietly (rows restored below).
                return
            newFiles += newFileTemp
    finally:
        # Always hand the spreadsheet back in the state we found it.
        ss.setRowState(allRows,ghi.const.StateActive)
        ss.setRowState(prevInactive,ghi.const.StateInactive)

    ghi.message('Created ' + str(newFiles) + ' files in ' + filePath)



def check(ss):
    """Build the data models the export needs from spreadsheet ss.

    Returns:
        (dm, dmG) where dmG is the genotypic data model and dm the data model
        of the binary/categorical/quantitative columns, or None for dm when
        that model cannot be built.  Returns None (after showing a message)
        when the genotypic model cannot be built.
    """
    try:
        dmG = ss.dataModel(ghi.const.FilterGenotypic)
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit are not mistaken for "no genotype data".
        ghi.message("Spreadsheet must contain active genotypic data.")
        return None

    try:
        dm = ss.dataModel(ghi.const.FilterBinary|ghi.const.FilterCategorical|ghi.const.FilterQuantitative)
    except Exception:
        # It's fine to have no additional columns; the optional
        # phenotype/pedigree rows simply cannot be exported then.
        dm = None

    return dm,dmG

    
def get_col(idx,dm):
    """Translate idx into a 1-based column number within dm.

    idx is looked up in dm.colIndexes(); the result is its position there
    plus one, or None when idx is not present.
    """
    try:
        return dm.colIndexes().index(idx) + 1
    except ValueError:
        # .index() signals "not found" with ValueError; anything else is a
        # real error and should surface instead of being turned into None.
        return None
    

def get_type(idx,dm): #Expects dmIdx
    """Return the BEAGLE line code for column idx of dm.

    'C' for a categorical column, 'A' for a binary one, 'T' for an integer
    or real one, and None for any other column type.
    """
    colType = dm.colType(idx)
    if colType == ghi.const.TypeCategorical:
        return 'C'
    if colType == ghi.const.TypeBinary:
        return 'A'
    if colType in (ghi.const.TypeInteger,ghi.const.TypeReal):
        return 'T'
    return None
    
    

def prompt():
    """Ask the user for the export options and run the export.

    Works on the current object (expected to be a spreadsheet): validates it
    with check(), shows a dialog for the output directory, base file name and
    the optional phenotype/pedigree columns, translates the chosen columns
    to column numbers of the phenotype data model, and then calls either
    loop_by_strata (population column chosen and the user asks for one set
    of files per population) or exportBeagle.

    Returns:
        None.
    """
    ss = ghi.getCurrentObject()
    result = check(ss)
    if result is None:
        return
    dm,dmG = result

    ssID = ss.nodeID()

    def columnItem(name,label,colTypes,kind="column"):
        # One optional column-picker entry of the dialog.
        return {"name":name,"label":label,"type":kind,"spreadsheet":ssID,
                "types":colTypes,"required":False}

    pdl = [{"name":"outdir","label":"Please select a save location...","type":"dir"},
           {"name":"baseName","label":"Specify the base file name:","type":"string","default":ss.nodeName().replace(" ","_")},
           {'type':'group','label':'Optional Output','items':[
                columnItem("aff","Affection Status Column:",[ghi.const.TypeBinary]),
                columnItem("ped","Pedigree Identifier:",[ghi.const.TypeCategorical]),
                columnItem("fid","Father Identifier:",[ghi.const.TypeCategorical,ghi.const.TypeInteger]),
                columnItem("mid","Mother Identifier:",[ghi.const.TypeCategorical,ghi.const.TypeInteger]),
                columnItem("pop","Population Stratum:",[ghi.const.TypeCategorical]),
                columnItem("cov","Additional Covariates:",
                           [ghi.const.TypeCategorical,ghi.const.TypeBinary,ghi.const.TypeInteger,ghi.const.TypeReal],
                           kind="columns")]}]

    answers = ghi.promptDialog(pdl, title = "Export BEAGLE by Chromosome", width = 400)
    if not answers:
        return

    filePath = answers['outdir']
    if filePath[-4:] == '.bgl':
        filePath = filePath[0:-4]
    baseName = answers['baseName']

    # Single optional columns: spreadsheet column -> data-model column number
    # (None when nothing was chosen or the column is not part of dm).
    resolved = []
    for name in ('aff','ped','fid','mid','pop'):
        choice = answers[name]
        if choice and dm is not None:
            resolved.append(get_col(choice,dm))
        else:
            resolved.append(None)
    aff,ped,fid,mid,pop = resolved

    cov = []
    types = []
    if dm is not None:
        for c in answers['cov'] or []:
            idx = get_col(c,dm)
            if idx is None:
                # Column is not part of dm: skip it, the same way an
                # unresolved single column above is treated as "not used",
                # instead of passing None on and failing later.
                continue
            cov.append(idx)
            types.append(get_type(idx,dm))

    # If Population Stratum is selected, optionally output a separate set of
    # files for each stratum.
    if pop and ghi.question('Would you like to output a separate set of files for each population strata?'):
        loop_by_strata(dm,dmG,filePath,baseName,aff,ped,fid,mid,pop,cov,types)
    else:
        exportBeagle(dm,dmG,filePath,baseName,aff,ped,fid,mid,pop,cov,types,None)
            
            
# Script entry point: the export dialog runs as soon as the script is executed.
# Every exception is routed to ghi.error() (presumably reports it to the
# user -- TODO confirm); the bare except is deliberate at this top-level
# boundary.
try:
    prompt()
except:
    ghi.error()
