#! /usr/bin/env python


##############################################################################
#    
#    This file is part of BATMON
#
#    BATMON (BAyes Test for MONotonicity) is a Bayesian Procedure for testing 
#    monotonicity of a regression
#    function. 
#    Copyright (C) 2012  JB Salomond 
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>
#
##############################################################################





 
"""
Needed packages

"""
from numpy import *
from scipy import *
from scipy.stats import poisson, norm
from scipy.stats import histogram
from scipy.stats import poisson
from scipy.stats import invgamma

import rpy as R
import sys
import numpy


def Hk(omega):
	"""
	Compute the statistic H(omega, k) (c.f. paper for details).

	omega : (vec.) the k heights of the jumps (one value per bin).

	return max over all pairs i < j of (omega[j] - omega[i]), i.e. the largest
	increase from an earlier piece to a later one. The value is negative when
	omega is strictly decreasing.

	Raises ValueError when fewer than two heights are given, since there is
	then no pair to compare.
	"""
	omega = numpy.asarray(omega, dtype=float)
	k = len(omega)
	if k < 2:
		raise ValueError("Hk needs at least two heights, got %d" % k)
	# For each j: max_{i<j}(omega[j] - omega[i]) = omega[j] - min(omega[:j]).
	# A running minimum yields every such term in O(k) instead of O(k^2).
	running_min = numpy.minimum.accumulate(omega[:-1])
	return numpy.max(omega[1:] - running_min)


def CreaTDC(n,k):
	"""
	Build the k x n bin-membership matrix for the fixed design case.

	Entry (i, j) equals 1 if the design point x_j lies in the bin I_i, else 0.
	The n points are split, in order, into k consecutive groups whose right
	edges are floor(j*n/k) for j = 1..k. Groups may be empty when k > n.

	n : (int.) number of design points
	k : (int.) number of bins (k >= 1)

	return an int array of shape (k, n) with exactly one 1 per column.
	"""
	n = int(n)
	k = int(k)
	# Exact integer arithmetic for the right edges floor(j*n/k); avoids float rounding.
	upper = numpy.arange(1, k + 1) * n // k
	lower = numpy.concatenate(([0], upper[:-1]))
	rows = numpy.repeat(numpy.arange(k), upper - lower)
	M0 = numpy.zeros((k, n), dtype=int)
	# Index with a (rows, cols) tuple. A *list* of index arrays is no longer
	# interpreted as a multi-axis index by recent numpy versions.
	M0[rows, numpy.arange(n)] = 1
	return M0


def TEST(n,y,k,X,Nsim,mu=1,alpha=1,beta=1,m=0,M0=2.33):
	"""
	Monte Carlo estimate of pi(H(omega,k) > M_n^k | k, Y^n) in the fixed design case.

	Takes as arguments
	n     : (int.) the length of the dataset
	y     : (vec.) the data points
	k     : (int.) the number of bins
	X     : (vec.) the design points (fixed design case); not used by this function
	Nsim  : (int.) number of simulations of (omega,sigma)|k,Y^n
	mu    : (flo.) hyperparameter for the variance of omega
	alpha : (flo.) hyperparameter for sigma
	beta  : (flo.) hyperparameter for sigma
	m     : (flo.) hyperparameter for the mean of omega
	M0    : (flo.) constant in the threshold M_n^k = M0 * mean(sigma draws) * sqrt(k/n)

	return the fraction of the Nsim posterior draws for which H(omega,k) > M_n^k
	"""
	tdc = CreaTDC(n,k)
	k = int(k)  # must stay an int: it is used as an array dimension below
	n = float(n)
	Nsim = int(Nsim)
	y = numpy.asarray(y, dtype=float)
	ni = tdc.sum(1)                 # number of points in each bin
	nif = ni.astype(float)
	Yi = numpy.dot(tdc, y)          # per-bin sums of y
	Yi2 = numpy.dot(tdc, y ** 2.0)  # per-bin sums of y^2
	ny = ni * mu * (Yi / ni - m) ** 2.0 / (ni + mu)
	btilde = beta + 0.5 * (numpy.sum(Yi2 - Yi ** 2.0 / nif) + numpy.sum(ny))
	# si: Nsim inverse-gamma draws, used below as the noise variance
	si = invgamma.rvs(alpha + 0.5 * n, scale=btilde, size=Nsim)
	M = M0 * numpy.mean(numpy.sqrt(si))
	postmean = (m * mu + Yi) / (nif + mu)
	# The covariance of omega given si[i] is diag(si[i] / (nif + mu)), so the
	# coordinates are independent normals and all draws can be taken at once.
	postsd = numpy.sqrt(si)[:, None] / numpy.sqrt(nif + mu)
	omega = postmean + postsd * numpy.random.standard_normal((Nsim, k))
	rest = numpy.apply_along_axis(Hk, 1, omega)

	mnk = M * numpy.sqrt(k) / numpy.sqrt(n)

	return numpy.mean(rest > mnk)
    

def Cxk(n,Y,k,mu=1,alpha=1,beta=1,m=0,lamb = 0.5):
	"""
	Log posterior of the number of bins k given the data, up to an additive constant.

	Takes as arguments
	n     : (int.) the length of the dataset
	Y     : (vec.) the data points
	k     : (int.) the number of bins
	mu    : (flo.) hyperparameter for the variance of omega
	alpha : (flo.) hyperparameter for sigma
	beta  : (flo.) hyperparameter for sigma
	m     : (flo.) hyperparameter for the mean of omega
	lamb  : (flo.) hyperparameter for k

	return the (unnormalised) log posterior of k|Y^n. Values k <= 1 are outside
	the support and get the sentinel -exp(100).
	"""
	if k <= 1:
		# Same sentinel value as before, returned without building the design
		# matrix (which fails for k = 0).
		return -numpy.exp(100)
	tdc = CreaTDC(n,k)
	Y = numpy.asarray(Y, dtype=float)
	ni = tdc.sum(1)                        # number of points in each bin
	nif = ni.astype(float)
	Yi = numpy.dot(tdc, Y)                 # per-bin sums of Y
	resid = Y - numpy.repeat(Yi / nif, ni)  # deviations from the bin means
	ny = ni * mu * (Yi / ni - m) ** 2.0 / (ni + mu)
	btilde = beta + 0.5 * (numpy.sum(resid ** 2.0) + numpy.sum(ny))

	return (-numpy.log(btilde) * (alpha + 0.5 * n)
		+ 0.5 * k * numpy.log(mu)
		- 0.5 * numpy.sum(numpy.log(nif + mu))
		+ (k - 2) * numpy.log(1 - lamb))

# ===============================================================
# Testing on real Data
# ===============================================================

# NOTE(review): entries is presumably the data module; y, X, Nech and NKsim
# are used in this script but not defined in it.
from entries import * 
n = len(y) # data length 
Mres = zeros(Nech) # output vector: one 0/1 verdict per replicate

# Choice for the hyperparameters (computed from the data)
mu = 0.05
lam = 0.05
m = mean(y)
alpha = std(y) ** 2.0 + 1.0
beta = 1.0 * std(y) ** 4.0

	
# Loop invariants, hoisted out of the replicate loop (they do not depend on Z).
BURN = int(0.2*NKsim) # number of initial MCMC draws discarded as burn-in
blabla = ["H0 accepted, the function is monotone", "H0 rejected, the function is not monotone"]


def _sample_k_chain(n, y, n_total, mu, alpha, beta, m, lam):
	"""
	--------------------------------------
	= MCMC Sampler - Hasting Metropolis  =
	--------------------------------------
	Random-walk Metropolis sampler for the number of bins k given Y^n.

	Proposal: k' = k + s*(1 + Poisson(2)), with sign s = +1 or -1, each with
	probability 1/2. The step distribution is symmetric, so a move is accepted
	with probability min(1, exp(Cxk(k') - Cxk(k))).

	Proposals with k' < 2 are outside the support and are rejected. The former
	clamp to 2 made the proposal asymmetric at the boundary.

	return an int array of length n_total holding the chain (burn-in included)
	"""
	log_post = {} # memo: k -> Cxk(n, y, k, ...), so each k is evaluated at most once

	def log_post_of(kk):
		if kk not in log_post:
			log_post[kk] = Cxk(n,y,kk,mu,alpha,beta,m,lam)
		return log_post[kk]

	# Use int, not int8: k values above 127 would otherwise overflow.
	chain = numpy.empty(n_total, dtype=int)
	# Initialisation: n**std(y), at least 3. Capped at n because n**std(y)
	# grows very fast and k > n leaves bins without data.
	chain[0] = max(3, int(min(numpy.power(float(n), numpy.std(y)), n)))
	for i in range(1, n_total):
		current = int(chain[i-1])
		u = numpy.random.random_sample()
		step = int(numpy.random.poisson(2)) + 1
		proposal = current + step if u < 0.5 else current - step
		# Compare in log scale: equivalent to u < min(exp(diff), 1), without overflow.
		accepted = (proposal >= 2 and
			numpy.log(numpy.random.random_sample()) < log_post_of(proposal) - log_post_of(current))
		chain[i] = proposal if accepted else current
	return chain


for Z in range(Nech):
	KS = _sample_k_chain(n, y, NKsim+BURN, mu, alpha, beta, m, lam)

	# Trace - plot (full chain, burn-in included)
	R.r.par(mfrow=(2,1))
	R.r.plot(KS,type = 'l',xlab = "",ylab = "")
	R.r.plot(y,type = 'l',xlab = "",ylab = "")
	# Burnin: keep the last NKsim draws
	KS = KS[BURN:]

	frq = numpy.bincount(KS) # frq[k] = number of retained draws equal to k

	# computing pi(H(omega,k) > Mnk|Y^n): each sampled k contributes TEST(...),
	# computed from frq[k] draws and weighted by frq[k]
	res = 0.0
	for kk, cnt in enumerate(frq):
		if cnt > 0:
			res += TEST(n,y,kk,X,cnt,mu,alpha,beta,m)*cnt
	res = res/float(NKsim)

	Mres[Z] = 1 if res > 0.5 else 0
	print(res) # the posterior pi(H(omega,k) > Mnk|Y^n)

	sys.stdout.write("\r\r\r\r\r\r\r\r\r\r\r\r\r\r")
	sys.stdout.write(blabla[int(Mres[Z])])
	sys.stdout.write("\n")
	sys.stdout.flush()
