#!/usr/bin/python
# -*- coding: utf-8 -*-

#~ GEBVtest
#~ Copyright (C) 2013 Interbull Centre
#~
#~ This program is free software: you can redistribute it and/or modify
#~ it under the terms of the GNU General Public License as published by
#~ the Free Software Foundation, either version 3 of the License, or
#~ (at your option) any later version.
#~
#~ This program is distributed in the hope that it will be useful,
#~ but WITHOUT ANY WARRANTY; without even the implied warranty of
#~ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#~ GNU General Public License for more details.
#~
#~  http://www.gnu.org/licenses/

'''
Perform the GEBV validation tests for one breed-population combination,
for all traits.
'''

# Revision history:
# 2013.01.05 GJansen - original version
# 2014.06.02 MAN - The same weight is used for DD and GM (edcd)
# 2015.02.27 MAN - The weight edcd/(edcd+Lambda) was used instead of edcd.
# 2019.05.23 HBA - (b<1.2) rule implemented

import os
import sys
import argparse
import zipfile
from datetime import date
import numpy as np
import ibutils

#=====================
version = '2019.05.23'
#=====================

rundate = date.today().strftime('%Y%m%d')

# to see help summary: python gebvtest.py --help
epilog = ('See detailed instructions at: '
          'https://wiki.interbull.org/public/gebvtest_py?action=print')

# see http://docs.python.org/2.7/howto/argparse.html
parser = argparse.ArgumentParser(epilog=epilog)

# positional arguments, in command-line order: (name, help text)
for _name, _help in (
        ('brd', 'evaluation breed code (BSW/GUE/JER/HOL/RDC/SIM)'),
        ('pop', 'population code (same as country code except for '
                'CHR/DEA/DFS/FRR/FRM)'),
        ('datadir', 'absolute or relative path to data files')):
    parser.add_argument(_name, help=_help)

# optional arguments, in the order they appear in the help summary:
# (short flag, long flag, extra add_argument keywords, help text)
for _short, _long, _extra, _help in (
        ('-v', '--verbose', {'action': 'store_true'},
         'increase output verbosity'),
        ('-m', '--mergefiles', {'action': 'store_true'},
         'write merged data files (for independent data checks)'),
        ('-M', '--mergedir', {},
         'absolute or relative path for merged data files '
         '(default=DATADIR/merged)'),
        ('-Z', '--no-zip', {'action': 'store_true'},
         'do not create a zip file (eg. for preliminary testing '
         'or usage at ITBC)'),
        ('-C', '--cleanup', {'action': 'store_true'},
         'delete all files successfully added to the zip file')):
    parser.add_argument(_short, _long, help=_help, **_extra)

# parse the command line once; the rest of the script reads `args`
args = parser.parse_args()

brd = ibutils.check_breed(args.brd)
pop = args.pop.upper()
_POPBRD = '_' + pop + brd
# country is same as population, with a few exceptions ...
pop2cou = {'CHR':'CHE', 'DEA':'DEU', 'DFS':'DNK', 'FRR':'FRA', 'FRM':'FRA'}
cou = pop2cou.get(pop, pop)

# Resolve DATADIR to an absolute path *before* the chdir below: once the
# working directory has changed, a relative DATADIR no longer resolves to
# the directory the user meant.
abs_datadir = os.path.abspath(args.datadir)

if not os.path.exists(args.datadir):
    print('absolute DATADIR: ' + abs_datadir)
    print('%s: error: DATADIR does not exist or has incorrect permissions'
          % sys.argv[0])
    sys.exit(1)
if args.mergefiles:
    # -M overrides the default DATADIR/merged; made absolute here because
    # the merged files are written after the chdir
    mergedir = os.path.join(args.datadir, 'merged')
    mergedir = os.path.abspath(args.mergedir if args.mergedir else mergedir)
    if not os.path.exists(mergedir):
        os.makedirs(mergedir)

#============================================================
os.chdir(args.datadir)                  # NB! move to DATADIR
#============================================================

# all data files are named <prefix>_<POP><BRD> and live in DATADIR
file736 = 'file736' + _POPBRD
filetrt = 'traits' + _POPBRD
fileout = 'file735' + _POPBRD
# The log name is derived from the script's base name only, so that the log
# is always created inside DATADIR regardless of how the script was invoked
# (e.g. "python /some/dir/gebvtest.py" or a script without a ".py" suffix).
progname = os.path.splitext(os.path.basename(sys.argv[0]))[0]
filelog = '%s_log' % progname + _POPBRD
if args.verbose:
    print('%s: writing log to %s/%s' % (sys.argv[0], args.datadir, filelog))
log = open(filelog, 'w')

ibutils.dated_msg(sys.argv[0]+': start', log=log)

if args.verbose:
    log.write(sys.argv[0] + ' version=' + version + '\n')
    log.write('absolute DATADIR: %s\n' % abs_datadir)


log.write('Processing brd=%s pop=%s cou=%s datadir=%s\n'
      % (brd, pop, cou, args.datadir))

def read_file300(brd, pop, trt, dset, args):
    '''Read and store data for one dataset (Cf/Cr/Gr/Df) and one trait.

    The file "file300<dset>_<pop><brd>" is read from the current working
    directory.  Only records whose trait code (columns 13-15 of the line)
    equals trt are kept.

    Arguments:
        brd  -- breed code; also checked against the first line of the file
        pop  -- population code; also checked against the first line
        trt  -- trait code to select
        dset -- dataset suffix used in the file name (Cf/Cr/Gr/Df)
        args -- parsed command-line options (only args.verbose is used)

    Returns a dict mapping animal id to the tuple
    (top, off, sta, nd, nh, edc, rel, ebv), where edc is an int, rel and
    ebv are floats and the remaining items are kept as strings.

    On a malformed record the message is written to the log and to stderr
    and the program exits with status -1; on a breed or population mismatch
    it exits with status 99.
    '''

    def fail(msg, status):
        '''Report a fatal input error to the log and stderr, then exit.'''
        log.write(msg)
        sys.stderr.write(msg)
        sys.exit(status)

    records = {}
    file300 = 'file300' + dset + '_' + pop + brd
    if args.verbose:
        log.write('reading %s ...\n' % file300)
    with open(file300) as f300:
        for n, rec in enumerate(f300):
            if rec[12:15] != trt:
                continue
            # top = type of proof (11, 12, 21 etc)
            # off = officially publishable proof (Y/N)
            # sta = bull status
            # nd = n. daughters
            # nh = n. herds
            try:
                (_rectype, brd1, pop1, _trt1, aid, top, off, sta, nd, nh,
                 edc, rel, ebv) = rec.strip().split()
                # convert inside the try so that a non-numeric field is
                # reported like any other unparsable record
                values = (top, off, sta, nd, nh,
                          int(edc), float(rel), float(ebv))
            except ValueError:
                fail('%s: error: cannot parse record from file %s\n'
                     'record: %s\n' % (sys.argv[0], file300, rec), -1)
            # breed/pop codes are only verified on the very first line of
            # the file (and only if that line belongs to this trait)
            if n == 0 and brd1 != brd:
                fail('%s: error: bad breed code in file %s\n'
                     '%s instead of %s\n'
                     % (sys.argv[0], file300, brd1, brd), 99)
            if n == 0 and pop1 != pop:
                fail('%s: error: bad pop code in file %s\n'
                     '%s instead of %s\n'
                     % (sys.argv[0], file300, pop1, pop), 99)
            records[aid] = values
    if args.verbose:
        log.write('stored %6d records for trait %s from file %s\n'
                  % (len(records), trt, file300))
    return records

#------------------------------------------------------------------------------
# read and store birth years from file736
# byear maps animal id (columns 5-23 of a file736 line) to the 4-character
# birth year (columns 25-28), kept as a string
byear = {}
with open(file736) as f736:
    for rec in f736:
        aid = rec[4:23]
        byear[aid] = rec[24:28]
if args.verbose:
    log.write('stored %6d records from file %s\n' % (len(byear), file736))

#------------------------------------------------------------------------------
def simple_wls(X, y, weights):
    '''Simple weighted least squares.

    Arguments:
        X       -- regressors.  A 1-D sequence of length n is treated as a
                   single covariate and an intercept column is added in
                   front of it; a 2-D (n, p) array is used as the complete
                   design matrix (no intercept is added).
        y       -- 1-D sequence of n observations
        weights -- 1-D sequence of n weights

    Returns (b, bse, rsquared):
        b        -- estimated coefficients (intercept first for 1-D X)
        bse      -- standard errors of b, using the residual variance
                    sse / (n - number of columns)
        rsquared -- 1 - sse/sst, with sst the weighted sum of squares of y
                    about its weighted mean
    '''
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    weights = np.asarray(weights, dtype=float)

    # scale rows by sqrt(weight) so that ordinary least squares on the
    # scaled data gives the weighted solution
    w = np.sqrt(weights)
    if X.ndim == 1:
        Xw = np.column_stack((w, X*w))
    else:
        # w must be a column here so that each *row* of X is scaled
        Xw = X * w[:, np.newaxis]
    yw = y*w
    XWXinv = np.linalg.inv(np.dot(Xw.transpose(), Xw))
    XWy = np.dot(Xw.transpose(), yw)
    b = np.dot(XWXinv, XWy)
    wmean_y = np.dot(weights, y) / weights.sum()
    sst = np.sum(weights * (y - wmean_y)**2)
    wresid = yw - np.dot(Xw, b)
    sse = np.dot(wresid, wresid)
    bse = np.sqrt(np.diag(XWXinv) * sse / (len(y) - Xw.shape[1]))
    rsquared = 1. - sse / sst
    return b, bse, rsquared

#------------------------------------------------------------------------------
# process files trait by trait, in the order listed in the trait info file
if args.verbose:
    log.write('opening output file %s ...\n' % fileout)
f735 = open(fileout, 'w')
f735.write('735 brd pop trt evaldate m ntest    mean_y     std_y dv    '
           'mean_x     std_x        b0     se_b0      b1   se_b1 '
           'ncand   i_est  Exp_b1   R2 fb year test1_5   pass\n')

# read the trait info file up front so that its handle is closed right away
with open(filetrt) as ftrt:
    trait_recs = ftrt.readlines()

# placeholder for a bull that has no record in one of the other datasets
# - edc/rel/"ebv" are zero if no record is found
empty = ('??', '?', '??', 0, 0, 0., 0., 0.)

for rec in trait_recs:
    if not rec.strip():
        continue    # tolerate blank lines (e.g. at the end of the file)
    trt, h2, evdate, depvar, min_byear, type2x = rec.strip().split()

    #--------------------------------------------------------------------------
    # read and store all data for this trait
    dataCf = read_file300(brd, pop, trt, 'Cf', args)
    if not dataCf: # skip traits listed in the trait file but not in file300Cf
        continue
    dataDf = read_file300(brd, pop, trt, 'Df', args)
    dataCr = read_file300(brd, pop, trt, 'Cr', args)
    dataGr = read_file300(brd, pop, trt, 'Gr', args)

    if args.verbose:
        log.write('\n' + '-'*79 + '\n')
        log.write('processing trt=%s h2=%s evdate=%s depvar=%s byear=%s '
                  'type2x=%s\n' % (trt, h2, evdate, depvar, min_byear, type2x))
        log.write('-'*79 + '\n')
    Lambda = 4./float(h2) - 1.

    #==========================================================================
    # CREATE DATA FOR CANDIDATE AND TEST BULLS
    x1=[]; x2=[]; y=[]; w=[]; ebvc=[]; ebvt=[]
    merged_rows = []
    # Counters for the warnings written after the loop.  They must be
    # initialised once per trait, not once per bull, otherwise only the last
    # bull processed would be counted.
    no_Df = no_Cr = 0
    # process bulls on Cf file and look for stored records from other datasets
    for aid in dataCf:
        byr = byear.get(aid, '0000')

        # skip any bulls born before min_byear (4-character string compare)
        if byr < min_byear:
            continue
        # top = type of proof (11, 12, 21 etc)
        # off = officially publishable proof (Y/N)
        # sta = bull status
        # nd = n. daughters
        # nh = n. herds
        top, off, sta, nd, nh, edc, rel, ebv = dataCf[aid]
        topd, offd, stad, ndd, nhd, edcd, reld, dpgm = dataDf.get(aid, empty)
        topr, offr, star, ndr, nhr, edcr, relr, ebvr = dataCr.get(aid, empty)
        topg, offg, stag, ndg, nhg, edcg, relg, gebv = dataGr.get(aid, empty)

        # - candidates are young proven bulls from the full dataset (Cf) that
        #   have no daughters in the reduced dataset (Cr)
        # - candidates are usually only domestic bulls, but if they are
        #   fewer than 50 or so the user may optionally include foreign bulls
        # NOTE(review): the parentheses below only spell out Python's
        # and/or precedence, i.e. sta >= '20' excludes a bull whatever the
        # value of type2x -- confirm that this is the intended rule.
        candidate = 'Y'
        if edc < 20 or edcr > 0:
            candidate = 'N'
        elif (type2x == 'N' and top >= '20') or sta >= '20':
            candidate = 'N'

        # - test bulls are the subset of candidate bulls that also have
        #   a GEBV record (Gr) with non-zero reliability and
        #   a DD/DPGM record (Df) and a parent average record (Cr)
        flag = testbull = 'N'
        if candidate == 'Y':
            ebvc.append(ebv)
            if relg > 0.:
                # bull has gebv record (file300Gr) with reliability > 0
                flag = 'Y'
                has_Df = aid in dataDf
                has_Cr = aid in dataCr
                if has_Df and has_Cr:
                    testbull = 'Y'
                    ebvt.append(ebv)
                    x1.append(gebv)
                    x2.append(ebvr)
                    y.append(dpgm)
                    # weight is edcd/(edcd+Lambda) rather than edcd itself
                    w.append(edcd/(edcd+Lambda))
                else:
                    if not has_Df:
                        # bull has no DD/D_PGM record (file300Df)
                        no_Df += 1
                    if not has_Cr:
                        # bull has no EBVr record (file300Cr)
                        no_Cr += 1

        if args.mergefiles:
            # collect a merged dataset for additional checks with SAS, R, etc.
            merged_rows.append(
                ','.join((aid.replace(' ', '~'), byr, flag, candidate,
                          testbull, top, off, sta, nd, nh,
                          'Cf', str(edc), str(rel), str(ebv),
                          'Df', str(edcd), str(reld), str(dpgm),
                          'Cr', str(edcr), str(relr), str(ebvr),
                          'Gr', str(edcg), str(relg), str(gebv))) + '\n')
    if args.mergefiles:
        with open(os.path.join(mergedir, trt + '.csv'), 'w') as fmerge:
            fmerge.writelines(merged_rows)

    if no_Df > 0:
        log.write('%s: WARNING: there were %d potential TEST bulls with no'
                  ' Df record!\n' % (trt, no_Df))
    if no_Cr > 0:
        log.write('%s: WARNING: there were %d potential TEST bulls with no'
                  ' Cr record!\n' % (trt, no_Cr))
    if len(ebvc) == 0:
        log.write('%s: no candidate bulls found for this trait\n' % trt)
        continue
    if len(ebvt) == 0:
        log.write('%s: no test bulls found for this trait\n' % trt)
        continue

    zc = np.array(ebvc); ncb = len(zc)
    zt = np.array(ebvt); ntb = len(zt)
    y = np.array(y)
    x1 = np.array(x1)
    x2 = np.array(x2)

    log.write('\nSummary statistics on candidate bulls (CB) and test bulls'
              ' (TB)\n')
    log.write('-'*62 + '\n')
    log.write('Trait Variable         N      Mean       Std      Min     '
              ' Max\n')
    log.write('-'*62 + '\n')
    fmt = '%s   %s %6d %9.3f %9.3f %8.2f %8.2f\n'
    for label, values in (('CB EBV     ', zc),
                          ('TB EBV     ', zt),
                          ('TB DPGM(y) ', y),
                          ('TB GEBV(x1)', x1),
                          ('TB EBVr(x2)', x2)):
        log.write(fmt % (trt, label, len(values), values.mean(),
                         values.std(ddof=1), values.min(), values.max()))
    log.write('-'*62 + '\n')

    #---------------------------------------------------------------------------
    # calculate test stats
    # model 1: D_PGM = b0 + b1*GEBV + e
    b1, bse1, R2_1 = simple_wls(x1, y, w)
    # model 2: D_PGM = b0 + b1*EBVr + e
    b2, bse2, R2_2 = simple_wls(x2, y, w)

    # estimate selection differential and search for corresponding p and x
    i_est = (zt.mean() - zc.mean()) / zc.std(ddof=1)
    if args.verbose:
        log.write('\nDetails of GEBVtest calculations\n')
        log.write('%s i_est = (%0.3f - %0.3f) / %0.3f = %0.3f\n'
                  % (trt, zt.mean(), zc.mean(), zc.std(ddof=1), i_est))
    i_est = abs(i_est)
    if i_est < 0.0001:
        x = 0.; p = 1.
    else:
        # Tabulate the standard normal on a grid of truncation points from
        # 5.000 down to -4.999 (step 0.001): cumulative proportion selected
        # and mean of the selected (upper) part.  The first grid point whose
        # selected mean has dropped to i_est gives p and x.
        grid = 0.001 * (np.arange(5000, -5000, -1))
        dens = np.exp(-grid**2/2) / np.sqrt(2*np.pi)
        cum_p = 0.001 * np.cumsum(dens)
        sel_mean = np.cumsum(grid * dens) / np.cumsum(dens)
        # argmax of a boolean array is the index of its first True
        first = int(np.argmax(sel_mean <= i_est))
        p = cum_p[first]
        x = grid[first]

    # calculate E(b1) given selection
    k = i_est * (i_est - x)
    R2b = R2_1 / (1. - k + k * R2_1)
    E_b1 = (1. - k) / (1. - k * R2b)
    t_val = (b1[1] - E_b1) / bse1[1]
    # pass or fail???
    pass1 = 'Y' if abs(t_val) < 2. else 'N'
    pass2 = 'Y' if abs(b1[1] - E_b1) < 0.1 else 'N'
    pass3 = 'Y' if b1[1] >= E_b1 else 'N'
    pass4 = 'Y' if R2_1 > R2_2 else 'N'
    pass5 = 'Y' if b1[1] < 1.2 else 'N'
    # overall: any one of tests 1-3, plus both test 4 and test 5
    if ((pass1 == 'Y' or pass2 == 'Y' or pass3 == 'Y')
            and pass4 == 'Y' and pass5 == 'Y'):
        pass6 = 'PASS'
    else:
        pass6 = 'FAIL'

    log.write('%s p=%0.3f  x=%6.3f  i=%0.3f k=%0.3f  R2b=%0.3f   E(b1)=%0.3f\n'
              % (trt, p, x, i_est, k, R2b, E_b1))
    log.write('%s b1=%0.3f  se=%0.3f  E(b1)=%0.3f  t=%0.2f  R2_1=%0.1f  '
              'R2_2=%0.1f\n'
              % (trt, b1[1], bse1[1], E_b1, t_val, 100.*R2_1, 100.*R2_2))
    log.write('%s passes t-test=%s  bio-test=%s b1>1=%s R2-test=%s b1<1.2=%s overall=%s\n\n'
          % (trt, pass1, pass2, pass3, pass4, pass5, pass6))

    # one output line per model; the date column holds the run date
    fmt = '735 %s %s %s %s %d%6d%10.4f%10.4f %s%10.4f%10.4f%10.4f%10.4f'\
        '%8.4f%8.4f%6d%8.4f%8.4f%5.1f  %s %s %s-%s-%s-%s-%s %s\n'
    f735.write(fmt % (brd, pop, trt, rundate, 1, ntb, y.mean(), y.std(ddof=1),
                      depvar, x1.mean(), x1.std(ddof=1), b1[0], bse1[0],
                      b1[1], bse1[1], ncb, i_est, E_b1, 100.*R2_1,
                      type2x, min_byear, pass1, pass2, pass3, pass4, pass5, pass6))
    f735.write(fmt % (brd, pop, trt, rundate, 2, ntb, y.mean(), y.std(ddof=1),
                      depvar, x2.mean(), x2.std(ddof=1), b2[0], bse2[0],
                      b2[1], bse2[1], ncb, 0.0, 1.0, 100.*R2_2,
                      type2x, min_byear, '-', '-', '-', '-', '-', '----'))
f735.close()

ibutils.dated_msg(sys.argv[0]+': end', log=log)
log.close()

# end here if no zip file is to be created (-Z option)
if args.no_zip:
    sys.exit(0)

# prepare the zip file: the four file300 inputs, file736, the trait file,
# the file735 results and the log
files = [name + _POPBRD for name in
         ['file300Cf', 'file300Df', 'file300Cr', 'file300Gr',
          'file736', 'traits', 'file735']] + [filelog]
# zip name carries the YYMM part of the run date
filezip = 'gt%s%s.zip' % (rundate[2:6], _POPBRD)
# the context manager closes the archive even if adding a file fails
with zipfile.ZipFile(filezip, 'w', zipfile.ZIP_DEFLATED) as z:
    for fname in files:
        z.write(fname)
if args.verbose:
    print('%s: files zipped to %s' %
          (sys.argv[0], os.path.join(args.datadir, filezip)))

if args.cleanup:
    # delete all files added to the zip file; this point is only reached
    # when every file was added without error
    for fname in files:
        os.unlink(fname)
