#!/usr/bin/python
# -*- coding: utf-8 -*-

#~ trendtest
#~ Copyright (C) 2013 Interbull Centre
#~
#~ This program is free software: you can redistribute it and/or modify
#~ it under the terms of the GNU General Public License as published by
#~ the Free Software Foundation, either version 3 of the License, or
#~ (at your option) any later version.
#~
#~ This program is distributed in the hope that it will be useful,
#~ but WITHOUT ANY WARRANTY; without even the implied warranty of
#~ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#~ GNU General Public License for more details.
#~
#~  http://www.gnu.org/licenses/
#
# program trendtest3.py

'''
Perform trend validation by method 3 for one breed-population combination,
for all traits.
'''

# Revision history:
# 2013.10.21 GJansen - original version

import os
import sys
import argparse
from datetime import date
import numpy as np
import ibutils

# Keep np.float usable on every numpy version: it was an alias of the
# builtin float that newer numpy releases no longer define, so install the
# builtin under that name whenever the module namespace lacks it.
if 'float' not in vars(np):
    setattr(np, 'float', float)

# today's date (YYYYMMDD); written as the test date into each output record
testdate = date.today().strftime('%Y%m%d')

# to see help summary: python trendtest3.py --help
epilog = ('See detailed instructions at: '
          'https://wiki.interbull.org/public/TrendTest_Software?action=print')

# see http://docs.python.org/2.7/howto/argparse.html
parser = argparse.ArgumentParser(epilog=epilog)
parser.add_argument('brd',
                    help='evaluation breed code (BSW/GUE/JER/HOL/RDC/SIM)')
parser.add_argument('pop',
                    help='population code (same as country code except for'
                    ' CHR/DEA/DFS/FRR/FRM)')
parser.add_argument('datadir',
                    help='absolute or relative path to data files')
parser.add_argument('-v', '--verbose', action='store_true',
                    help='increase output verbosity')
# type=int: a non-numeric value is rejected here with a usage message
# instead of failing later with a traceback in the middle of processing
parser.add_argument('-s', '--samples', type=int, default=1000,
                    help='number of bootstrap samples (default=1000)')
parser.add_argument('-c', '--controlfile',
                    help='path/name of the control file (default=DATADIR/'
                    'file305_POPBRD)')
parser.add_argument('-m', '--mergefiles', action='store_true',
                    help='write merged data files (for independent data'
                    ' checks)')
parser.add_argument('-M', '--mergedir',
                    help='absolute or relative path for merged data files'
                    ' (default=DATADIR/merged3)')
args = parser.parse_args()
# the bootstrap confidence interval needs at least one sample
if args.samples < 1:
    parser.error('argument -s/--samples: must be a positive integer')

brd = ibutils.check_breed(args.brd)
pop = args.pop.upper()
# suffix shared by all data file names of this run (e.g. file300_POPBRD)
_POPBRD = '_' + pop + brd

# DATADIR must be a directory: the script changes into it before reading
# any data, so a regular file would otherwise crash os.chdir with a traceback
if not os.path.isdir(args.datadir):
    print('absolute DATADIR: ' + os.path.abspath(args.datadir))
    print('%s: error: DATADIR does not exist, is not a directory or has '
          'incorrect permissions' % sys.argv[0])
    sys.exit(1)
if args.mergefiles:
    # resolved to an absolute path now, before the working directory changes,
    # so a relative --mergedir refers to the directory the script started in
    mergedir = os.path.abspath(args.mergedir or
                               os.path.join(args.datadir, 'merged3'))
    if not os.path.exists(mergedir):
        os.makedirs(mergedir)

# absolute DATADIR, taken BEFORE the chdir below: afterwards a relative
# DATADIR would be resolved against itself (e.g. data -> .../data/data)
abs_datadir = os.path.abspath(args.datadir)

#============================================================
os.chdir(args.datadir)                  # NB! move to DATADIR
#============================================================

# File names below are relative to DATADIR (the current directory from here
# on); this includes a relative path given with -c/--controlfile.
filebd = 'bdate' + _POPBRD
file305 = args.controlfile if args.controlfile else ('file305' + _POPBRD)
fileout = 'file313' + _POPBRD
filelog = 'tt3' + _POPBRD + '.log'
if args.verbose:
    print('%s: writing log to %s/%s' % (sys.argv[0], args.datadir, filelog))
log = open(filelog, 'w')

ibutils.dated_msg(sys.argv[0]+': start', log=log)
log.write(sys.argv[0] + ' version=' + ibutils.version + '\n')

if args.verbose:
    log.write('absolute DATADIR: %s\n' % abs_datadir)


log.write('Processing brd=%s pop=%s datadir=%s\n'
          % (brd, pop, args.datadir))

def bomb(msg, rc):
    '''Report a fatal error and terminate the program.

    msg is written unchanged first to the log file, then to stderr; the
    process then exits with status rc.
    '''
    for stream in (log, sys.stderr):
        stream.write(msg)
    sys.exit(rc)

def read_file30X(brd, pop, trt, filename, args):
    '''Read and store the data of one trait from a file300 or file303.

    The record type ('300' or '303') is taken from characters 5-7 of
    filename (e.g. 'file300_POPBRD').  Only lines carrying trait code trt in
    columns 13-15 are used, and each of them must start with that record
    type.  The breed and population codes of the FIRST record for trt are
    checked against brd and pop.  Only args.verbose is read from args.

    Returns a dict keyed by animal id with values
      file300: (top, off, sta, nd, nh, edc, rel, ebv)
      file303: (top, nd, ebv, n1, n2, n3, n4, year1d)

    Fatal errors are reported through bomb(): wrong record type (rc 99),
    unparsable record (98), wrong breed code (97), wrong pop code (96).
    '''
    records = {}
    if args.verbose:
        log.write('reading %s ...\n' % filename)
    rec_type = filename[4:7]
    codes_checked = False
    with open(filename) as fin:
        for rec in fin:
            if rec[12:15] != trt:
                continue
            if rec[:3] != rec_type:
                bomb('%s: error: bad rec_type in file %s\n%s instead of %s\n'
                     % (sys.argv[0], filename, rec[:3], rec_type), 99)
            parsed = True
            try:
                if rec_type == '300':
                    # top = type of proof (11, 12, 21 etc)
                    # off = officially publishable proof (Y/N)
                    # sta = bull status
                    # nd = n. daughters
                    # nh = n. herds
                    (_, brd1, pop1, _, aid, top, off, sta, nd, nh, edc, rel,
                     ebv) = rec.split()
                    value = (top, off, sta, int(nd), int(nh), int(edc),
                             float(rel), float(ebv))
                else:  # rec_type 303
                    (_, brd1, pop1, _, aid, _, top, nd, ebv, n1, n2, n3, n4,
                     year1d) = rec.split()
                    value = (top, int(nd), float(ebv), int(n1), int(n2),
                             int(n3), int(n4), int(year1d))
            except ValueError:
                # wrong number of fields or a non-numeric count/value
                parsed = False
            if not parsed:
                bomb('%s: error: cannot parse record from file %s\nrecord: %s\n'
                     % (sys.argv[0], filename, rec), 98)
            if not codes_checked:
                # check the first record of THIS trait; it need not be the
                # first line of the file, which may belong to another trait
                if brd1 != brd:
                    bomb('%s: error: bad breed code in file %s\n%s instead of %s\n'
                         % (sys.argv[0], filename, brd1, brd), 97)
                if pop1 != pop:
                    bomb('%s: error: bad pop code in file %s\n%s instead of %s\n'
                         % (sys.argv[0], filename, pop1, pop), 96)
                codes_checked = True
            records[aid] = value
    if args.verbose:
        log.write('stored %6d records for trait %s from file %s\n'
                  % (len(records), trt, filename))
    return records

#------------------------------------------------------------------------------
# read and store birth years: animal id in columns 1-19, birth year in
# columns 21-24 (blanks in the year are replaced by '0')
byear = {}
with open(filebd) as fbd:
    for rec in fbd:
        byear[rec[:19]] = rec[20:24].replace(' ', '0')
if args.verbose:
    log.write('stored %6d records from file %s\n' % (len(byear), filebd))


#------------------------------------------------------------------------------
def _bootstrap_ci(X, y, w, n_samples):
    '''Return bootstrap 95% confidence limits (lower, upper) for delta.

    delta is coefficient 2 (the TIME variate) of the weighted regression
    fitted by ibutils.simple_wls(X, y, w).  The n rows of X/y/w are
    resampled with replacement n_samples times (n_samples >= 1) and the
    regression is re-fitted on each sample.  A sample for which simple_wls
    raises numpy.linalg.LinAlgError (singular LHS, possible for very small
    n) is discarded and redrawn.
    '''
    n = len(y)
    delta = np.zeros(n_samples, float)
    i = 0
    while i < n_samples:
        sample = np.random.randint(n, size=n)
        try:
            delta[i] = ibutils.simple_wls(X[sample], y[sample],
                                          w[sample])[0][2]
        except np.linalg.LinAlgError:
            # skip samples with singular LHS (for very small n)
            continue
        i += 1
    delta.sort()
    # 2.5th and 97.5th percentiles of the sorted estimates.  The index is
    # clamped at 0: for small n_samples round(0.025*n_samples) is 0, and
    # index -1 would silently pick the LARGEST estimate as the lower limit.
    lower = delta[max(int(round(0.025 * n_samples)) - 1, 0)]
    upper = delta[max(int(round(0.975 * n_samples)) - 1, 0)]
    return lower, upper


#------------------------------------------------------------------------------
# process files trait by trait, in the order listed in the control file
n_samples = int(args.samples)   # bootstrap samples, same for every trait
if args.verbose:
    log.write('opening output file %s ...\n' % fileout)
with open(fileout, 'w') as fout, open(file305) as fctl:
    first_rec = True    # column header goes before the first output record
    for rec in fctl:
        if rec.startswith('#') or not rec.strip():
            continue    # skip header/comment lines and blank lines
        # byr1, yy_mon and chg are part of the control-file layout but are
        # not used in this script
        try:
            (trtg, trt, evdate, herit, siresd, merit, type2x, min_hrd,
             min_dgh, byr1, miny, maxy, R, yy_mon, chg) = rec.strip().split()
            yyyy = int(evdate[:4])  # yyyy is year of current evaluation
            h2 = float(herit)       # herit itself stays text for the log
            siresd = float(siresd)
            min_hrd = int(min_hrd)
            min_dgh = int(min_dgh)
            R = float(R)
        except ValueError:
            print('error: could not parse this line from file ' + file305)
            print('line ::', rec)
            print('please review the documentation')
            sys.exit(9)

        merit = merit.upper()
        if merit not in ('B+', 'B-', 'T+', 'T-'):
            print('error: genetic merit must be B+/B-/T+/T- in file305')
            print('line ::', rec)
            sys.exit(9)
        if h2 <= 0.:
            # the variance ratio k below divides by the heritability
            print('error: heritability must be positive in file305')
            print('line ::', rec)
            sys.exit(9)
        bvta = 'BV' if merit[0] == 'B' else 'TA'
        # siresd from MACE is actually        genetic variance for EBV
        # countries and 1/4 of genetic variance for PTA countries
        # VP+HJ 20131022 -- hence the factor 2 applied for TA traits
        sdg = siresd * (2. if bvta == 'TA' else 1.)
        k = (4. - h2) / h2      # variance ratio

        if args.verbose:
            log.write('\n\n' + '-'*80 + '\n')
            log.write('file305  trt miny maxy corr herds daus type2x bvta  herit'
                      '    sdg\n')
            log.write('inputs   %s %s %s%5.2f %5d%5d    %s    %s  %s  %s\n' %
                      (trt, miny, maxy, R, min_hrd, min_dgh, type2x, bvta,
                       herit, sdg))
            log.write('-'*80 + '\n')
        #----------------------------------------------------------------------
        # read and store all data for this trait
        data300 = read_file30X(brd, pop, trt, 'file300' + _POPBRD, args)
        data303 = read_file30X(brd, pop, trt, 'file303' + _POPBRD, args)

        #======================================================================
        # edit data and merge proofs
        y = []      # current EBV (file300)
        x = []      # EBV from YYYY-4 (file303)
        t = []      # time covariate
        w = []      # regression weight
        b = []      # birth year
        ninc = 0    # bulls whose daughter counts do not add up
        nbig = 0    # ... of which the mismatch is 10 or more (bull excluded)
        fmerge = None
        if args.mergefiles:
            # create a merged dataset for additional checks with SAS, R, etc.;
            # opened once per trait so that no edited bull is overwritten
            fmerge = open(os.path.join(mergedir, trt + '.csv'), 'w')
        # process bulls on 303 file and look for stored records from 300 dataset
        for aid in data303:
            byr = byear.get(aid, '0000')
            # first edit: skip any bulls born before min_byear or after
            # max_byear (years are compared as strings)
            if byr < miny or byr > maxy:
                continue
            # fetch data from YYYY-4 and newly added daughters each year
            topx, nx, ebvx, n1, n2, n3, n4, year1d = data303[aid]
            # fetch current data
            if aid not in data300:
                log.write('warning: bull in file303 but not file300: %s\n' % aid)
                continue
            # top = type of proof (11, 12, 21 etc)
            # off = officially publishable proof (Y/N)
            # sta = bull status
            # nh = n. herds
            # ny = n. daughters in yyyy
            # nx = n. daughters in yyyy-4
            top, off, sta, ny, nh, edc, rel, ebvy = data300[aid]

            # additional data edits
            # -- is bull a first crop domestic AI bull
            #    (or foreign bull for small pops)
            first_crop = ((topx == '11' and sta != '20') or
                          (type2x == 'Y' and topx == '21'))
            # -- does bull meet minimums for daus/herds/edc
            #    note: the CoP calls for min 10 herds in YYYY-4, but that edit
            #    cannot be applied because nh is missing in (most) 04x files
            enough = nh >= min_hrd and ny >= min_dgh and nx >= min_dgh
            keep = 'Y' if first_crop and enough else 'N'

            nj = np.array([n1, n2, n3, n4])
            newall = nj.sum()
            # -- at least one daughter must be added
            if newall <= 0 or ny <= nx:  # 2015.04.22 GJ/VP - revised edit
                continue
            # check that added daughters are consistent with total daughters
            if ny != nx + newall:
                ninc += 1
                if ninc <= 10 or abs(ny - (nx + newall)) >= 10:
                    log.write('MISMATCH: %s %s nx %d+%d+%d+%d+%d=%d != %d ny\n' %
                              (aid, trt, nx, n1, n2, n3, n4, nx+newall, ny))
            if abs(ny - (nx + newall)) >= 10:  # skip big mismatches
                nbig += 1
                keep = 'N'

            if keep == 'Y':
                # okay, process good record and store in data vectors
                # compute time covariate
                # -- substitute mean year of calving of first crop daus by
                #    byear+4 (because user supplied values are incomplete or
                #    unreliable?)
                year1d = float(byr) + 4.
                mj = (yyyy - 4.) + np.array([1, 2, 3, 4]) - year1d
                t_i = sum(nj * mj) / ny
                # compute weight
                w_i = (ny + k)**2 / (nx*(nx+k)*(1.-R**2) +
                                     (ny-nx)*(k + (ny-nx)*k/(nx+k)))
                y.append(ebvy)
                x.append(ebvx)
                t.append(t_i)
                w.append(w_i)
                b.append(float(byr))
            else:
                t_i = 0.
                w_i = 1.

            if fmerge is not None:
                fmerge.write(','.join((aid, byr, keep, top, off, sta,
                                       'f300', str(ny), str(nh), str(edc),
                                       str(rel), str(ebvy),
                                       'f303', topx, str(nx), str(ebvx),
                                       str(n1), str(n2), str(n3), str(n4),
                                       str(year1d),
                                       '%0.3f,%0.3f' % (t_i, w_i))) + '\n')

        if fmerge is not None:
            fmerge.close()
        if args.verbose and ninc:
            log.write('%d bulls with inconsistent daughter counts for trait %s'
                      ' (%d with a mismatch of 10 or more were excluded)\n'
                      % (ninc, trt, nbig))

        n = len(y)
        if n == 0:
            if args.verbose:
                log.write('warning: no merged records after edits for trait %s\n\n'
                          % trt)
            continue
        if n < 4:
            # the model below has 3 coefficients: n < 4 leaves no residual
            # degrees of freedom (and the bootstrap may never find a
            # non-singular sample)
            log.write('skipping test for trait %s: only %d bulls passed edits\n\n'
                      % (trt, n))
            continue

        if not args.verbose:
            log.write('\n')
        log.write('\nSummary statistics for trait "%s" (N=%d bulls)\n' % (trt, n))
        log.write('-'*62 + '\n')
        log.write('Trait Variable              Mean       Std      Min      Max\n')
        log.write('-'*62 + '\n')
        fmt = '%s   %s %9.3f %9.3f %8.2f %8.2f\n'
        y = np.array(y)
        sdy = y.std(ddof=1)
        log.write(fmt % (trt, 'EBV CURRENT  (y)', y.mean(), sdy, y.min(), y.max()))
        x = np.array(x)
        sdx = x.std(ddof=1)
        log.write(fmt % (trt, 'EBV YYYY-4   (x)', x.mean(), sdx, x.min(), x.max()))
        t = np.array(t)
        log.write(fmt % (trt, 'TIME VARIATE (t)', t.mean(), t.std(ddof=1),
                         t.min(), t.max()))
        w = np.array(w)
        log.write(fmt % (trt, 'WEIGHTS      (w)', w.mean(), w.std(ddof=1),
                         w.min(), w.max()))
        b = np.array(b)
        log.write(fmt % (trt, 'BIRTH YEAR      ', b.mean(), b.std(ddof=1),
                         b.min(), b.max()))
        log.write('-'*62 + '\n\n')

        # check ranges of std's for current and old proofs and supplied SD
        warn_list = []
        if n < 10:
            log.write('WARNING for trait %s: too few bulls for meaningful test'
                      ' (N=%d)\n' % (trt, n))
            warn_list.append('FEW_BULLS')
        if sdx/sdy < 0.85 or sdx/sdy > 1.15:
            log.write('WARNING for trait %s: XY_SCALE_WARNING\n'
                      ' => Ratio SD(x)/SD(y) outside expected range (0.85 to 1.15).\n'
                      ' => Current and old evaluations need to be on the same scale.\n'
                      % trt)
            warn_list.append('XY_SCALE_WARNING')
        if bvta == 'BV' and (sdy/sdg < 0.7 or sdy/sdg > 1.4):
            log.write('WARNING for trait %s: SDG_BV_WARNING\n'
                      ' => Ratio SDy/SDg outside expected range (0.7 to 1.4).\n'
                      ' => SD of proofs on BV scale should be roughly the same as SDg.\n'
                      % trt)
            warn_list.append('SDG_BV_WARNING')
        elif bvta == 'TA' and (sdy/sdg < 0.35 or sdy/sdg > 0.7):
            log.write('WARNING for trait %s: SDG_TA_WARNING\n'
                      ' => Ratio SDy/SDg outside expected range (0.35 to 0.7).\n'
                      ' => SD of proofs on TA scale should be roughly half of SDg.\n'
                      % trt)
            warn_list.append('SDG_TA_WARNING')

        #----------------------------------------------------------------------
        # calculate test stats
        # model: y = b0 + b1*x + b2*t +  e
        X = np.column_stack((np.ones(n), x, t))
        coef, bse, R2, rmse = ibutils.simple_wls(X, y, w)

        log.write('Regression of current EBV (y) on previous EBV (x) and TIME '
                  'variate (t)\n')
        log.write('-'*72 + '\n')
        log.write('         _____ x _____     _____ t _____\n')
        log.write('Trait    Slope    s.e.     Slope    s.e.    %R^2     RMSE\n')
        log.write('-'*72 + '\n')
        fmt = '%s   %8.3f %7.3f  %8.3f %7.3f  %6.1f %8.3f\n'
        log.write(fmt % (trt, coef[1], bse[1], coef[2], bse[2], 100.*R2, rmse))
        log.write('-'*72 + '\n')
        log.write('Regression 95%% C.I. for delta [%.3f, %.3f]\n' %
                  (coef[2] - 1.96 * bse[2], coef[2] + 1.96 * bse[2]))

        # 95% confidence interval by bootstrap
        lower, upper = _bootstrap_ci(X, y, w, n_samples)

        # method 3 pass or fail
        # -- statistical validation test (95% C.I. contains 0)
        pass3a = 'PASS' if lower < 0. and upper > 0. else 'FAIL'
        log.write('Bootstrap  95%% C.I. for delta [%.3f, %.3f] ==> %s for %s\n' %
                  (lower, upper, pass3a, trt))
        # -- biological test: |delta| relative to SDg, scale-specific limit
        testval = abs(coef[2]) / sdg
        limit = 0.02 if bvta == 'BV' else 0.01
        pass3b = 'PASS' if testval < limit else 'FAIL'
        log.write('Biol. test: abs(%0.3f)/%0.3f = %0.4f (%s)  ==> %s for %s\n' %
                  (coef[2], sdg, testval, bvta, pass3b, trt))

        # -- overall: passing either test is sufficient
        pass3 = 'PASS' if 'PASS' in (pass3a, pass3b) else 'FAIL'

        if first_rec:
            fout.write('rec brd pop tgrp trt testdate pass   '
                       'delta    lower   upper stat '
                       'testval biol      SDg bv '
                       'bulls    std_y    std_x '
                       'x yyyy miny maxy  herit corr mh md nsamp warnings\n')
            first_rec = False
        # prepare output record in parts
        p1 = '313 %s %s %s %s %s %s' % (brd, pop, trtg, trt, testdate, pass3)
        p2 = ' %7.3f %8.3f %7.3f %s' % (coef[2], lower, upper, pass3a)
        p3 = ' %7.3f %s %8.3f %s' % (coef[2]/sdg, pass3b, sdg, bvta)
        p4 = ' %5d %8.3f %8.3f' % (n, sdy, sdx)
        p5 = ' %s %4d %s %s %6.4f %4.2f %2d %2d %5d' % (
            type2x, yyyy, miny, maxy, h2, R, min_hrd, min_dgh, n_samples)
        p6 = ' ' + ','.join(warn_list) if warn_list else ' none'
        fout.write(p1 + p2 + p3 + p4 + p5 + p6 + '\n')

# closing message via ibutils.dated_msg (presumably time-stamped, like the
# start message), then release the log file
ibutils.dated_msg('%s: end' % sys.argv[0], log=log)
log.close()
