require(tidyverse)
require(reporttools)
require(ggpubr)
require(grid)
require(gridExtra)
require(cowplot)
require(ggh4x)


# Read in data -----------------------------------------------------------------

# When TRUE, people whose Xpert result is undetected (xpert_pos == 0) are
# re-coded as not having TB (TB = 0) before any of the analyses below.
exclude_undetected <- TRUE

dataset_master <- read.csv("dataset.csv")

# Keep rows that have a CAD score, are not from the AHCoS survey, and were
# recorded after the threshold change. `%in%` and `!is.na()` never return NA,
# so a missing survey / before_threshold_change value cannot turn a row into an
# all-NA placeholder row (which `[` does when given an NA logical index).
keep_row <- !is.na(dataset_master$cad_v5) &
  !(dataset_master$survey %in% "AHCoS") &
  dataset_master$before_threshold_change %in% 0
dataset_master <- dataset_master[keep_row, ]

if (isTRUE(exclude_undetected)) {
  dataset_master$TB[which(dataset_master$xpert_pos == 0)] <- 0
}

# CAD only ---------------------------------------------------------------------
# CXR-only screening: among diagnosed TB cases (TB == 1), keep those whose CAD
# score is at or above each threshold, then summarise CAD score and Xpert grade
# for asymptomatic (tbsymp == 0, "aTB") and symptomatic (tbsymp == 1, "sTB")
# TB, both for everyone ("All") and for rows with a non-zero `xpert` value
# ("xpert_pos").

# which() drops rows whose TB flag is NA; a plain logical index would keep them
# as all-NA rows and make median()/quantile() return NA or stop.
dataset <- dataset_master[which(dataset_master$TB == 1), ]

cad_thresholds <- c(25, 30, 40, 50, 60, 70)

# Median and interquartile range of the CAD scores `x`, returned as
# c(median, q25, q75) to fill the mean/lb/ub columns; all NA when `x` is empty.
summarise_cad <- function(x) {
  c(median(x), quantile(x, probs = c(0.25, 0.75), names = FALSE))
}

# Proportion of ones in the 0/1 vector `x` (missing values dropped) with the
# exact 95% confidence interval from binom.test(), returned as
# c(proportion, lower, upper). An empty group gives all NA instead of the
# error binom.test(0, 0) would raise.
summarise_trace <- function(x) {
  x <- x[!is.na(x)]
  if (length(x) == 0) {
    return(rep(NA_real_, 3))
  }
  test <- binom.test(sum(x), length(x))
  c(unname(test$estimate), test$conf.int[1], test$conf.int[2])
}

# Formatted p-values comparing the tbsymp groups in `d`: Wilcoxon rank-sum test
# on CAD score and logistic regression on above_trace. Both are NA when fewer
# than two tbsymp groups are present, where both models would stop with an
# error.
compare_symptom_groups <- function(d) {
  if (length(unique(na.omit(d$tbsymp))) < 2) {
    return(c(NA_character_, NA_character_))
  }
  cad_p <- wilcox.test(cad_v5 ~ tbsymp, data = d)$p.value
  trace_fit <- glm(above_trace ~ tbsymp, data = d, family = "binomial")
  trace_p <- summary(trace_fit)$coefficients[2, 4]
  c(format.pval(cad_p, digits = 2), format.pval(trace_p, digits = 2))
}

# Empty results table, one row per threshold, for one label combination.
new_summary_table <- function(symptoms, type, measure) {
  data.frame(
    Symptoms = symptoms, Type = type, Threshold = cad_thresholds,
    Measure = measure, mean = NA_real_, lb = NA_real_, ub = NA_real_
  )
}

aTB_all_CAD <- new_summary_table("aTB", "All", "CAD")
aTB_all_xpert <- new_summary_table("aTB", "All", "xpert")
aTB_xpos_CAD <- new_summary_table("aTB", "xpert_pos", "CAD")
aTB_xpos_xpert <- new_summary_table("aTB", "xpert_pos", "xpert")
sTB_all_CAD <- new_summary_table("sTB", "All", "CAD")
sTB_all_xpert <- new_summary_table("sTB", "All", "xpert")
sTB_xpos_CAD <- new_summary_table("sTB", "xpert_pos", "CAD")
sTB_xpos_xpert <- new_summary_table("sTB", "xpert_pos", "xpert")

CADonly_pvalue <- data.frame(
  threshold = cad_thresholds,
  CAD = NA_character_,
  above_trace = NA_character_
)

stat_cols <- c("mean", "lb", "ub")
for (r in seq_along(cad_thresholds)) {
  above <- dataset[which(dataset$cad_v5 >= cad_thresholds[r]), ]
  CADonly_pvalue[r, c("CAD", "above_trace")] <- compare_symptom_groups(above)

  asymp <- above[which(above$tbsymp == 0), ]
  symp <- above[which(above$tbsymp == 1), ]
  # NOTE(review): `xpert` is not a column name used elsewhere in this script
  # (elsewhere it is `xpert_pos`); if the CSV has no exact `xpert` column,
  # data-frame `$` partially matches a unique column starting with "xpert".
  # Confirm the intended column.
  asymp_xpos <- asymp[which(asymp$xpert != 0), ]
  symp_xpos <- symp[which(symp$xpert != 0), ]

  aTB_all_CAD[r, stat_cols] <- summarise_cad(asymp$cad_v5)
  aTB_all_xpert[r, stat_cols] <- summarise_trace(asymp$above_trace)
  aTB_xpos_CAD[r, stat_cols] <- summarise_cad(asymp_xpos$cad_v5)
  aTB_xpos_xpert[r, stat_cols] <- summarise_trace(asymp_xpos$above_trace)
  sTB_all_CAD[r, stat_cols] <- summarise_cad(symp$cad_v5)
  sTB_all_xpert[r, stat_cols] <- summarise_trace(symp$above_trace)
  sTB_xpos_CAD[r, stat_cols] <- summarise_cad(symp_xpos$cad_v5)
  sTB_xpos_xpert[r, stat_cols] <- summarise_trace(symp_xpos$above_trace)
}

all_data <- rbind(aTB_all_CAD, aTB_all_xpert, aTB_xpos_CAD, aTB_xpos_xpert,
                  sTB_all_CAD, sTB_all_xpert, sTB_xpos_CAD, sTB_xpos_xpert)

all_data$Threshold <- factor(all_data$Threshold, levels = cad_thresholds)

data_CADonly <- all_data








# CAD and symptom in parallel --------------------------------------------------

set.seed(7)
boot_num <- 10000
# Multiplier applied to the symptomatic CAD < 25 TB prevalence when imputing
# asymptomatic CAD < 25 cases (1 = same prevalence as symptomatic people).
lower_prev_asymp <- 1

parallel_pvalue <- as.data.frame(matrix(NA, ncol = 3, nrow = 7))
colnames(parallel_pvalue) <- c("threshold", "CAD", "above_trace")
parallel_pvalue[, 1] <- c("Test all", 25, 30, 40, 50, 60, 70)

# Test all ---------------------------------------------------------------------
# Bootstrap the "test everyone" scenario: asymptomatic TB with CAD < 25 is
# imputed by borrowing the diagnosed-TB prevalence and the CAD / Xpert values
# of symptomatic people with CAD < 25, then compared with symptomatic TB.

# Draw `size` values from `x` with replacement. For length(x) > 1 this makes
# exactly the same draws as sample(x, size, replace = TRUE), but it never
# switches to sampling from 1:x when `x` holds a single number.
resample <- function(x, size) {
  x[sample.int(length(x), size, replace = TRUE)]
}

# which() drops rows whose filter columns are NA rather than keeping them as
# all-NA rows.
dataset_all <- dataset_master[!is.na(dataset_master$cad_v5) &
                                dataset_master$before_threshold_change %in% 0, ]
dataset <- dataset_all[which(dataset_all$TB == 1), ]

# Symptomatic people with CAD < 25; a missing TB status counts as not diagnosed.
# Only the TB column is recoded so the other fields keep their recorded values.
s_lowCAD <- dataset_all[which(dataset_all$tbsymp == 1 &
                                dataset_all$cad_v5 < 25), ]
s_lowCAD <- select(s_lowCAD, c("TB", "cad_v5", "above_trace", "xpert_pos"))
s_lowCAD$TB[is.na(s_lowCAD$TB)] <- 0
s_lowCAD_TB <- s_lowCAD[s_lowCAD$TB == 1, ]

# Asymptomatic people with CAD < 25 (only their count is used below).
a_lowCAD <- dataset_all[which(dataset_all$tbsymp != 1 &
                                dataset_all$cad_v5 < 25), ]
a_lowCAD <- select(a_lowCAD, c("TB", "cad_v5", "above_trace"))
a_lowCAD$TB[is.na(a_lowCAD$TB)] <- 0

# Diagnosed asymptomatic TB with CAD >= 25 (observed directly).
a_highCAD <- dataset_all[which(dataset_all$tbsymp != 1 &
                                 dataset_all$cad_v5 >= 25 &
                                 dataset_all$TB == 1), ]
a_highCAD <- select(a_highCAD, c("TB", "cad_v5", "above_trace"))

num_a_lowCAD <- nrow(a_lowCAD)

# Observed symptomatic TB values and group sizes.
s_cad <- dataset$cad_v5[which(dataset$tbsymp == 1)]
s_trace <- dataset$above_trace[which(dataset$tbsymp == 1)]
num_s <- length(s_cad)
num_a <- sum(dataset$tbsymp == 0, na.rm = TRUE)

boot_cols <- c("prev_s", "num_a", "mean_CAD", "mean_trace",
               "boot_CAD_sTB", "boot_CAD_aTB", "boot_trace_sTB",
               "boot_trace_aTB")
boot_mat <- matrix(NA_real_, nrow = boot_num, ncol = length(boot_cols),
                   dimnames = list(NULL, boot_cols))

# The draws below are made in the same order on every iteration, so the random
# stream (and hence the results) is fixed by set.seed() above.
for (b in seq_len(boot_num)) {
  # Prevalence of diagnosed TB among symptomatic people with CAD < 25.
  draw_prev <- mean(resample(s_lowCAD$TB, nrow(s_lowCAD)))
  # Implied number of asymptomatic TB cases with CAD < 25.
  draw_n_low <- round(draw_prev * num_a_lowCAD * lower_prev_asymp)
  # aTB if everyone were tested: observed CAD >= 25 cases plus imputed CAD < 25
  # cases. The "mean_CAD" column stores the median CAD score.
  draw_cad <- median(c(a_highCAD$cad_v5,
                       resample(s_lowCAD_TB$cad_v5, draw_n_low)))
  draw_trace <- mean(c(a_highCAD$above_trace,
                       resample(s_lowCAD_TB$above_trace, draw_n_low)))

  # Null distribution: pool aTB and sTB values and redraw both groups from the
  # pool, i.e. assume no difference between asymptomatic and symptomatic TB.
  pooled_cad <- c(a_highCAD$cad_v5,
                  resample(s_lowCAD_TB$cad_v5, draw_n_low),
                  s_cad)
  null_cad_s <- median(resample(pooled_cad, num_s))
  null_cad_a <- median(resample(pooled_cad, num_a + draw_n_low))

  pooled_trace <- c(a_highCAD$above_trace,
                    resample(s_lowCAD_TB$above_trace, draw_n_low),
                    s_trace)
  null_trace_s <- mean(resample(pooled_trace, num_s))
  null_trace_a <- mean(resample(pooled_trace, num_a + draw_n_low))

  boot_mat[b, ] <- c(draw_prev, draw_n_low, draw_cad, draw_trace,
                     null_cad_s, null_cad_a, null_trace_s, null_trace_a)
}
boot_data <- as.data.frame(boot_mat)

# Two-sided Monte Carlo p-values: share of null draws whose absolute group
# difference exceeds the observed one. The CAD statistic is a difference of
# medians in both the null draws and the observed value.
# NOTE(review): a p-value of exactly 0 should be reported as < 1 / boot_num.
boot_data$CAD_diff <- abs(boot_data$boot_CAD_sTB - boot_data$boot_CAD_aTB)
mean_CAD_diff <- abs(mean(boot_data$mean_CAD - median(s_cad)))
boot_data$CAD_diff_greater <-
  as.numeric((boot_data$CAD_diff > mean_CAD_diff) %in% TRUE)

boot_data$trace_diff <- abs(boot_data$boot_trace_sTB - boot_data$boot_trace_aTB)
mean_trace_diff <- abs(mean(boot_data$mean_trace - mean(s_trace)))
boot_data$trace_diff_greater <-
  as.numeric((boot_data$trace_diff > mean_trace_diff) %in% TRUE)

parallel_pvalue$CAD[1] <- format.pval(mean(boot_data$CAD_diff_greater),
                                      digits = 2)
parallel_pvalue$above_trace[1] <- format.pval(mean(boot_data$trace_diff_greater),
                                              digits = 2)

# Median and 95% interval of the imputed number of CAD < 25 asymptomatic cases.
quantile(boot_data$num_a, probs = c(0.025, 0.5, 0.975))

# With CAD screening -----------------------------------------------------------

# CAD scores
# CAD_data has one row per (Scenario, HIV_ART) pair, with the four HIV/ART
# groups repeated inside each scenario:
#   rows 1-4   symptomatic TB ("sTB"), any CAD score;
#   rows 5-8   asymptomatic TB if everyone were tested ("aTB (no cutoff)");
#              only the "All" row (8) is filled, from boot_data;
#   rows 9-32  asymptomatic TB with CAD >= each cut-off.
# The "mean" column holds the median CAD score and lb/ub the quartiles.

hiv_groups <- c("HIV-", "HIV+ART-", "HIV+ART+", "All")
cad_cutoffs <- c(25, seq(30, 70, 10))

CAD_data <- as.data.frame(matrix(NA, ncol = 6, nrow = 4 * 8))
colnames(CAD_data) <- c("Symptoms", "HIV_ART", "Scenario", "mean", "lb", "ub")
CAD_data$HIV_ART <- rep(hiv_groups, 8)
CAD_data$Scenario <- rep(
  c("sTB", "aTB (no cutoff)", paste0("aTB (cut-off ", cad_cutoffs, ")")),
  each = 4
)
CAD_data$Symptoms <- rep(c("sTB", rep("aTB", 7)), each = 4)

# Median and interquartile range of `x` as c(median, q25, q75); NA if empty.
median_iqr <- function(x) {
  c(median(x), quantile(x, probs = c(0.25, 0.75), names = FALSE))
}

# Logical index of the rows of `d` in HIV/ART group `group`. "All" keeps every
# row; a missing hivart value never matches a specific group (no NA indices).
in_hiv_group <- function(d, group) {
  if (group == "All") rep(TRUE, nrow(d)) else d$hivart %in% group
}

stat_cols <- c("mean", "lb", "ub")

# Rows 1-4: symptomatic TB, by HIV/ART group.
symp <- dataset[which(dataset$tbsymp == 1), ]
for (i in seq_along(hiv_groups)) {
  CAD_data[i, stat_cols] <-
    median_iqr(symp$cad_v5[in_hiv_group(symp, hiv_groups[i])])
}

# Rows 9-32: asymptomatic TB at or above each cut-off, by HIV/ART group.
HIV_list <- rep(hiv_groups, 6)
cutoff_list <- rep(cad_cutoffs, each = 4)

data_asymp <- dataset[which(dataset$tbsymp == 0), ]
for (r in seq_along(HIV_list)) {
  keep <- which(in_hiv_group(data_asymp, HIV_list[r]) &
                  data_asymp$cad_v5 >= cutoff_list[r])
  CAD_data[r + 8, stat_cols] <- median_iqr(data_asymp$cad_v5[keep])
}

# Row 8: median and quartiles of the bootstrapped median aTB CAD score when all
# are tested.
# NOTE(review): unlike the other rows, lb/ub here are quartiles of the bootstrap
# medians (uncertainty in the median), not quartiles of individual CAD scores.
CAD_data[8, stat_cols] <- median_iqr(boot_data$mean_CAD)

# Xpert trace
# trace_data mirrors CAD_data's row layout for the proportion of TB cases whose
# Xpert result is above trace (`above_trace`): "mean" holds the proportion and
# lb/ub its 95% confidence interval (row 8: 2.5%/97.5% bootstrap percentiles).

dataset <- dataset_master[which(dataset_master$TB == 1), ]

hiv_groups <- c("HIV-", "HIV+ART-", "HIV+ART+", "All")
cad_cutoffs <- c(25, seq(30, 70, 10))

trace_data <- as.data.frame(matrix(NA, ncol = 6, nrow = 4 * 8))
colnames(trace_data) <- c("Symptoms", "HIV_ART", "Scenario", "mean", "lb", "ub")
trace_data$HIV_ART <- rep(hiv_groups, 8)
trace_data$Scenario <- rep(
  c("sTB", "aTB (no cutoff)", paste0("aTB (cut-off ", cad_cutoffs, ")")),
  each = 4
)
trace_data$Symptoms <- rep(c("sTB", rep("aTB", 7)), each = 4)

# Proportion of ones in the 0/1 vector `x` (missing values dropped) with its
# exact binomial 95% CI from binom.test(), as c(proportion, lower, upper); all
# NA for an empty group. The exact interval always lies in [0, 1] (a
# normal-approximation interval can exceed it and is NA for n = 1), matching
# the binom.test() intervals used for the same measure in the CAD-only panel.
proportion_ci <- function(x) {
  x <- x[!is.na(x)]
  if (length(x) == 0) {
    return(rep(NA_real_, 3))
  }
  test <- binom.test(sum(x), length(x))
  c(unname(test$estimate), test$conf.int[1], test$conf.int[2])
}

# Logical index of the rows of `d` in HIV/ART group `group`. "All" keeps every
# row; a missing hivart value never matches a specific group (no NA indices).
in_hiv_group <- function(d, group) {
  if (group == "All") rep(TRUE, nrow(d)) else d$hivart %in% group
}

stat_cols <- c("mean", "lb", "ub")

# Rows 1-4: symptomatic TB, by HIV/ART group.
symp <- dataset[which(dataset$tbsymp == 1), ]
for (i in seq_along(hiv_groups)) {
  trace_data[i, stat_cols] <-
    proportion_ci(symp$above_trace[in_hiv_group(symp, hiv_groups[i])])
}

# Rows 9-32: asymptomatic TB at or above each cut-off, by HIV/ART group.
HIV_list <- rep(hiv_groups, 6)
cutoff_list <- rep(cad_cutoffs, each = 4)

data_asymp <- dataset[which(dataset$tbsymp == 0), ]
for (r in seq_along(HIV_list)) {
  keep <- which(in_hiv_group(data_asymp, HIV_list[r]) &
                  data_asymp$cad_v5 >= cutoff_list[r])
  trace_data[r + 8, stat_cols] <- proportion_ci(data_asymp$above_trace[keep])
}

# Row 8: mean and 95% percentile interval of the bootstrapped proportion.
trace_data$mean[8] <- mean(boot_data$mean_trace)
trace_data$lb[8] <- quantile(boot_data$mean_trace, probs = c(0.025))
trace_data$ub[8] <- quantile(boot_data$mean_trace, probs = c(0.975))


# Tag each summary table with its measure, keep only the rows for all HIV/ART
# groups combined, and stack them (CAD rows first) for the parallel figure.
CAD_data$type <- "CAD"
trace_data$type <- "trace"
plot_data <- do.call(
  rbind,
  lapply(list(CAD_data, trace_data), function(tbl) tbl[tbl$HIV_ART == "All", ])
)

# Order the scenarios for the x axis and shorten their labels: symptomatic TB,
# "Test all", then the bare CAD cut-off values.
cutoff_text <- c("25", "30", "40", "50", "60", "70")
plot_data$Scenario <- factor(
  plot_data$Scenario,
  levels = c("sTB", "aTB (no cutoff)", paste0("aTB (cut-off ", cutoff_text, ")")),
  labels = c("sTB", "Test all", cutoff_text)
)
data_parallel_all <- plot_data

# p-values
# Parallel screen at each CAD cut-off: symptomatic TB at any CAD score plus
# asymptomatic TB with CAD >= cut-off. Fills rows 2-7 of parallel_pvalue with a
# Wilcoxon rank-sum test on CAD score and a logistic-regression test on
# above-trace Xpert, each comparing tbsymp groups.
screen_cutoffs <- c(25, 30, 40, 50, 60, 70)
for (k in seq_along(screen_cutoffs)) {
  screened <- dataset[dataset$cad_v5 >= screen_cutoffs[k] | dataset$tbsymp == 1, ]

  cad_test <- wilcox.test(cad_v5 ~ tbsymp, data = screened)
  trace_fit <- glm(above_trace ~ tbsymp, data = screened, family = "binomial")

  out_row <- k + 1
  parallel_pvalue$CAD[out_row] <- format.pval(cad_test$p.value, digits = 2)
  parallel_pvalue$above_trace[out_row] <-
    format.pval(summary(trace_fit)$coefficients[2, 4], digits = 2)
}



# Figure -----------------------------------------------------------------------

# Per-panel y-axis ranges for ggh4x::facetted_pos_scales(), in panel order:
# the first facet (CAD score) spans 0-105, the second (proportion) 0-1.
scales <- list(
  scale_y_continuous(limits = c(0, 105)),
  scale_y_continuous(limits = c(0, 1))
)

parallel_facet_titles <- c(
  "CAD" = "i) Median CAD score",
  "trace" = "ii) Proportion with Xpert\ngreater than trace"
)

# Panel a): bars (with error bars) for sTB, "Test all" and each CAD cut-off,
# faceted into CAD-score and Xpert-grade panels.
plot_parallel_all <- ggplot(data = data_parallel_all) +
  theme_bw() +
  geom_col(aes(x = Scenario, y = mean, fill = Symptoms), position = "dodge") +
  geom_errorbar(aes(x = Scenario, ymin = lb, ymax = ub, group = Symptoms),
                position = "dodge") +
  theme(
    axis.text.x = element_text(angle = 90),
    legend.position = "none",
    axis.title.y = element_blank(),
    axis.title.x = element_text(margin = margin(t = 7, r = 0, b = 0, l = 0))
  ) +
  labs(title = "a) CXR and symptom screen in parallel") +
  scale_x_discrete(name = "CAD score threshold") +
  facet_wrap(~type, scales = "free",
             labeller = labeller(type = parallel_facet_titles)) +
  facetted_pos_scales(y = scales)

# Significance stars over the CAD panel bars at x positions 6-8, i.e. the
# "50", "60" and "70" cut-offs in the Scenario factor order. Only the first
# star depends on whether Xpert-undetected results were excluded.
# NOTE(review): these labels are hard-coded rather than derived from
# parallel_pvalue; re-check them whenever the data or analysis change.
first_star <- if (exclude_undetected == TRUE) "**" else "*"
plot_parallel_all_text <- data.frame(
  label = c(first_star, "***", "***"),
  type = rep("CAD", 3),
  x = c(6, 7, 8),
  y = seq(92, 100, by = 4)
)

star_layer <- geom_text(
  data = plot_parallel_all_text,
  mapping = aes(x = x, y = y, label = label),
  size = 5
)
plot_parallel_all <- plot_parallel_all + star_layer

cadonly_facet_titles <- c(
  "CAD" = "i) Median CAD score",
  "xpert" = "ii) Proportion with Xpert\ngreater than trace"
)

# Panel b): CXR screening only, using the "All" rows of data_CADonly, with one
# bar pair (aTB vs sTB) per CAD threshold in each measure facet.
cadonly_all_rows <- data_CADonly[data_CADonly$Type == "All", ]
plot_CADonly_all <- ggplot(data = cadonly_all_rows) +
  theme_bw() +
  geom_col(aes(x = Threshold, y = mean, fill = Symptoms), position = "dodge") +
  geom_errorbar(aes(x = Threshold, ymin = lb, ymax = ub, group = Symptoms),
                position = "dodge") +
  theme(
    axis.text.x = element_text(angle = 90),
    legend.position = "none",
    axis.title.y = element_blank(),
    axis.title.x = element_text(margin = margin(t = 25, r = 0, b = 0, l = 0))
  ) +
  labs(title = "b) CXR screening only") +
  scale_x_discrete(name = "CAD score threshold") +
  facet_wrap(~Measure, scales = "free",
             labeller = labeller(Measure = cadonly_facet_titles)) +
  facetted_pos_scales(y = scales)


# Dummy points whose only purpose is to create a "Significance" shape legend.
p_legend <- data.frame(
  value = c(1, 2, 3),
  col2 = c(1, 2, 3),
  Significance = c("* p < 0.1", "** p < 0.05", "*** p < 0.01")
)

# Throw-away plot built only so its legends (Symptoms fill + Significance
# shape) can be extracted and placed beside the two figure panels.
# NOTE(review): all three significance levels use the same "*" glyph and the
# labels override the star counts, so the key does not show "**" / "***" as
# drawn on the plot; confirm this reads as intended.
star_glyphs <- rep("*", 3)
plot_legend <- ggplot(data = data_parallel_all) +
  theme_bw() +
  geom_col(aes(x = Scenario, y = mean, fill = Symptoms), position = "dodge") +
  geom_point(data = p_legend, aes(x = value, y = col2, shape = Significance)) +
  geom_errorbar(aes(x = Scenario, ymin = lb, ymax = ub, group = Symptoms),
                position = "dodge") +
  theme(axis.text.x = element_text(angle = 90)) +
  facet_wrap(~type, scales = "free") +
  scale_shape_manual(
    name = "Significance",
    labels = c("p < 0.1", "p < 0.05", "p < 0.01"),
    values = star_glyphs
  )
legend <- cowplot::get_legend(plot_legend)


# Write the combined figure (panels a and b plus the shared legend) to a PDF
# whose name records whether Xpert-undetected results were excluded.
figure_file <- if (exclude_undetected == TRUE) {
  "Figure_screening_tool_comparison.pdf"
} else {
  "Figure_screening_tool_comparison-include xpert neg.pdf"
}

pdf(figure_file, width = 11, height = 5, onefile = FALSE)
# print() is needed: a ggarrange() result is only drawn when printed, and
# top-level values are not auto-printed when the script is run via source(),
# which would otherwise leave the PDF blank. `finally` closes the device even
# if drawing fails.
tryCatch(
  print(
    ggarrange(
      ggarrange(plot_parallel_all, plot_CADonly_all, legend,
                ncol = 3, widths = c(2, 2, 0.4)),
      nrow = 1, heights = c(1)
    )
  ),
  finally = dev.off()
)

