---
title: "jsCAT Validation"
author: "Wanjing Anya Ma"
date:  "`r Sys.Date()`"
output:
  bookdown::html_document2:
    toc: true
    toc_depth: 4
    theme: cosmo
    highlight: tango
    
---

# Data Preprocessing
## Load Packages
Let's first load the packages that we need for this chapter. 

```{r, message=FALSE, warning = FALSE, echo = FALSE}
library("knitr") # for rendering the RMarkdown file
library("tidyverse") # for data wrangling 
library(dplyr)
library(mirt)
library('ggpubr')
library(Metrics)
library(hash)
library(ggplot2)
library(GGally)
library(gridExtra)
library("catR")
library(corrplot)
library(patchwork)
```

## Settings 
```{r echo = FALSE}
# Print chunk output in the knitted document without a comment prefix.
opts_chunk$set(comment = "")

# Suppress dplyr's message about how summarise() regroups its result.
# Written as FALSE because `F` is an ordinary variable that can be reassigned.
options(dplyr.summarise.inform = FALSE)
```

## Data

# Theoretical Simulation
## create resp
```{r}
## Simulation settings. They start small so that everything can be inspected;
## make them bigger for downstream use.
ni <- 200              # number of items
np <- 100              # number of participants for validation
np.calibration <- 1000 # number of participants for model calibration

## Fix the seed before the first random draw, then simulate the true abilities
## of the calibration sample.
set.seed(1)
th <- rnorm(n = np.calibration, mean = 0, sd = 2)
```

```{r}
################################################################
## To put the pieces together we need the probability of a correct response
## for a person on an item. As in logistic regression, that probability is the
## inverse logit of the linear predictor.

## Inverse logit (logistic) function.
##
## x : numeric vector or matrix of logits
##
## Returns values in [0, 1] with the same shape as `x`.
inv_logit <- function(x) {
  # exp(x) / (1 + exp(x)) evaluates to Inf / Inf = NaN once exp(x) overflows
  # (x above roughly 709); plogis() computes the same function without
  # overflowing, so extreme logits map cleanly to 0 or 1.
  stats::plogis(x)
}
```

Randomly sample the item difficulty (b) parameters from a uniform distribution.
```{r}
# One difficulty per item, drawn uniformly between -4 and 4.
b.mat.random <- runif(n = ni, min = -4, max = 4)
# item.bank.random <- item.bank %>%
#    mutate(b = b.mat.random)
```

```{r}
## Simulate dichotomous (0/1) responses from person abilities and item
## difficulties: P(correct) = inv_logit(ability - difficulty).
##
## b  : item difficulties, one per item (length ni, or a single shared value)
## th : true person abilities, one per person (length np, or a single shared
##      value); in real data these are not observed
## np : number of persons (rows of the result)
## ni : number of items (columns of the result)
##
## Returns a data.frame with np rows and ni columns of 0/1 responses.
func.create.response <- function(b, th, np, ni) {
  # matrix() recycles silently, so a vector of the wrong length would pair
  # abilities or difficulties with the wrong persons or items. Fail loudly.
  stopifnot(
    length(th) %in% c(1L, np),
    length(b) %in% c(1L, ni)
  )

  # Abilities: one per row, repeated across the item columns.
  th.mat <- matrix(th, nrow = np, ncol = ni, byrow = FALSE)
  # Difficulties: one per column, repeated down the person rows.
  b.mat <- matrix(b, nrow = np, ncol = ni, byrow = TRUE)

  pr <- inv_logit(th.mat - b.mat)

  # Replace each probability with a Bernoulli draw, one item (column) at a
  # time so the order of random draws is unchanged.
  resp <- pr
  for (i in seq_len(ncol(resp))) {
    resp[, i] <- rbinom(nrow(resp), 1, resp[, i])
  }

  data.frame(resp)
}
```

```{r}
# Simulated responses of the calibration sample, generated from the true
# abilities (th) and the randomly drawn difficulties.
resp.calibration <- func.create.response(
  b = b.mat.random,
  th = th,
  np = np.calibration,
  ni = ni
)
```

```{r}
# Fit a one-factor Rasch model to the calibration responses.
mod <- mirt(data = data.frame(resp.calibration), model = 1, itemtype = "Rasch")

# Keep only the estimated item difficulties (column `items.b` of the
# flattened coefficient table).
params.b <- coef(mod, IRTpars = TRUE, simplify = TRUE) %>%
  data.frame() %>%
  pull(items.b)
```

```{r}
# Validation sample: draw new abilities, then simulate responses from the
# calibrated difficulties (params.b).
th.sample <- rnorm(n = np, mean = 0, sd = 2)
resp.sample <- func.create.response(
  b = params.b,
  th = th.sample,
  np = np,
  ni = ni
)
```
```{r}
# Add a participant id ("p1", "p2", ...) as the first column. The ids are
# derived from the actual number of rows instead of a hard-coded 100, so this
# keeps working when `np` is changed.
resp.sample.with.pid <- resp.sample %>%
  mutate(pid = paste0("p", row_number()), .before = 1)
```

```{r}
#write.csv(resp.sample.with.pid, "../src/data/resp_sample_with_pid.csv")
```

```{r}
# Score the validation sample's response patterns with the calibrated model
# (method "ML"); the F1 and SE_F1 columns of the result are used below.
th_SE.sample <- fscores(
  mod,
  method = "ML",
  response.pattern = resp.sample,
  max_theta = 6,
  min_theta = -6,
  theta_lim = c(-6, 6)
) %>%
  as_tibble()
```

## catR simulation
### item bank
```{r}
# Item bank with columns a, b, c, d: only b (the calibrated difficulty)
# varies across items; a, c and d are held at 1, 0 and 1.
item.bank <- data.frame(
  a = 1,
  b = params.b,
  c = 0,
  d = 1
)
```
```{r}
# Round the difficulties to 8 digits and label the items "X1", "X2", ...
# The labels are derived from the actual number of rows instead of a
# hard-coded 200, so this keeps working when `ni` is changed.
item.bank.with.item.name <- item.bank %>%
  mutate(
    b = round(b, digits = 8),
    item = paste0("X", row_number())
  )
```
```{r}
#write.csv(item.bank.with.item.name, "../src/wordlist/theoretical.item.bank.csv")
```


### simulation
```{r}
## Run a full-length CAT for every respondent with randomCAT() and collect the
## provisional ability estimate and standard error after each item.
##
## resp      : data frame of responses, one row per respondent and one column
##             per item
## item.bank : item parameters, passed to randomCAT() as `itemBank`
## method    : item selection rule, passed to randomCAT() as `itemSelect`
##             (the file uses 'MFI' and 'random')
##
## Returns a data.frame with one row per respondent and trial:
##   pid           - row index of the respondent in `resp`
##   trialNumTotal - trial number, 1..ni within each respondent
##   F1            - provisional theta after that trial (res$thetaProv)
##   SE_F1         - its standard error (res$seProv)
func.catSim <- function(resp, item.bank, method) {
  ni <- length(resp)
  np <- nrow(resp)

  # CAT configuration: start with one item at theta = 0, estimate by ML within
  # [-6, 6], and stop only after all ni items have been given. The `.cfg`
  # names avoid shadowing base::stop().
  start.cfg <- list(nrItems = 1, theta = 0)
  test.cfg <- list(method = "ML", itemSelect = method, range = c(-6, 6))
  stop.cfg <- list(rule = "length", thr = ni)
  final.cfg <- list(method = "ML", range = c(-6, 6))

  # One simulation per respondent, keeping only the two series needed. This
  # replaces growing the result vectors with c() inside a loop, and
  # seq_len() handles a `resp` with zero rows.
  sims <- lapply(seq_len(np), function(pid) {
    res <- randomCAT(
      itemBank = item.bank,
      responses = as.numeric(resp[pid, ]),
      start = start.cfg,
      test = test.cfg,
      final = final.cfg,
      stop = stop.cfg
    )
    list(theta = res$thetaProv, se = res$seProv)
  })

  data.frame(
    pid = rep(seq_len(np), each = ni),
    trialNumTotal = rep(seq_len(ni), times = np),
    F1 = as.numeric(unlist(lapply(sims, function(s) s$theta))),
    SE_F1 = as.numeric(unlist(lapply(sims, function(s) s$se)))
  )
}
```
```{r}
# catR reference runs on the validation sample, one per item selection rule.
df.mfi <- func.catSim(
  resp = resp.sample,
  item.bank = item.bank,
  method = "MFI"
)
df.random <- func.catSim(
  resp = resp.sample,
  item.bank = item.bank,
  method = "random"
)
```


## jsCAT simulation
```{r}
# jsCAT simulation output. The columns used below are pid, trialNumTotal,
# variant, thetaEstimate and thetaSE.
jsCAT.simulation.df <- read.csv(file = "../data/jsCAT-simulation-data.csv")
```

### Validate theta estimate and standard error of measurement 
```{r}
# Final jsCAT estimates of the "mfi" variant: the row recorded after the last
# of the `ni` items, with columns renamed to match the mirt output.
jsCAT.simulation.df.final.estimate <- jsCAT.simulation.df %>%
  filter(trialNumTotal == ni, variant == "mfi") %>%
  mutate(F1 = thetaEstimate, SE_F1 = thetaSE) %>%
  select(pid, trialNumTotal, F1, SE_F1)

# The two sources are paired by row position, so they must at least have the
# same number of rows.
# NOTE(review): this also assumes the jsCAT rows are in the same participant
# order as resp.sample -- confirm against how the CSV was generated.
stopifnot(
  nrow(jsCAT.simulation.df.final.estimate) == nrow(th_SE.sample)
)

df.estimate.compare <- tibble(
  theta.mirt = th_SE.sample$F1,
  theta.jsCAT = jsCAT.simulation.df.final.estimate$F1,
  se.mirt = th_SE.sample$SE_F1,
  se.jsCAT = jsCAT.simulation.df.final.estimate$SE_F1
)

# Theta estimates: mirt vs. jsCAT, with the identity line for reference.
g.1 <- ggplot(df.estimate.compare,
              mapping = aes(x = theta.mirt, y = theta.jsCAT)) +
  geom_point() +
  labs(title = "Theta estimate",
       x = "Theta estimate from mirt",
       y = "Theta estimate from jsCAT") +
  geom_abline(color = "gray") +
  xlim(-6, 6) +
  ylim(-6, 6) +
  coord_equal()

# Standard errors: mirt vs. jsCAT, with the identity line for reference.
g.2 <- ggplot(df.estimate.compare,
              mapping = aes(x = se.mirt, y = se.jsCAT)) +
  geom_point() +
  labs(title = "Theta standard error",
       x = "Standard error from mirt",
       y = "Standard error from jsCAT") +
  geom_abline(color = "gray") +
  coord_equal()

# Name the combined figure and pass it to ggsave() explicitly, so the saved
# files do not depend on whichever plot happened to be created last.
g.validation.1 <- g.1 + g.2
g.validation.1

ggsave("../plots/jsCAT_validation_1.png", plot = g.validation.1)
ggsave("../plots/jsCAT_validation_1.pdf", plot = g.validation.1)
```

### Validate MFI efficiency 
```{r}
# Mean standard error of measurement per item-selection variant and trial,
# computed separately for jsCAT and catR and then stacked. bind_rows() with
# explicit inputs replaces the nested rbind()-inside-a-pipe, and
# `.groups = "drop"` returns ungrouped tables.
df.software.compare <- bind_rows(
  # jsCAT: already in long format with a `variant` column
  jsCAT.simulation.df %>%
    group_by(variant, trialNumTotal) %>%
    summarise(sem = mean(thetaSE), .groups = "drop") %>%
    add_column(software = "jsCAT"),
  # catR: label the two runs with their variant while stacking them
  bind_rows(mfi = df.mfi, random = df.random, .id = "variant") %>%
    group_by(variant, trialNumTotal) %>%
    summarise(sem = mean(SE_F1), .groups = "drop") %>%
    add_column(software = "catR")
)
  
```

```{r}
# Mean SEM by test length for the Rasch simulation: colour distinguishes the
# software, point shape the item-selection variant. The legend is hidden on
# this panel.
g.software.compare.1 <- ggplot(
  data = df.software.compare,
  mapping = aes(
    x = trialNumTotal,
    y = sem,
    color = software,
    shape = variant
  )
) +
  geom_point() +
  labs(
    title = "jsCAT vs. catR (Rasch model)",
    x = "Number of test items",
    y = "Standard error of measurement"
  ) +
  ylim(0.2, 1) +
  theme(legend.position = "none")
```

# Post-hoc Simulation

## Data Loading
```{r}
# Load the saved response data.
# NOTE(review): presumably this provides the `df` object used in the next
# chunk -- confirm against the .Rdata file.
load(file = "../data/state_c1_2007_7_responses.Rdata")
```

```{r}
# Grade 7: keep only the items whose name starts with "mc".
df.mc <- df %>%
  filter(substring(item, first = 1, last = 2) == "mc")

# Participants to keep: those with exactly 55 non-missing responses.
df.mc.keep.pid <- df.mc %>%
  filter(!is.na(resp)) %>%
  count(id) %>%
  filter(n == 55)

# Restrict to the kept participants and reshape to one row per participant
# and one column per item.
df.reading.resp.with.pid <- df.mc %>%
  semi_join(df.mc.keep.pid, by = "id") %>%
  pivot_wider(names_from = item, values_from = resp) %>%
  dplyr::rename(pid = id)

# Response matrix without the participant id.
df.reading.resp <- df.reading.resp.with.pid %>%
  select(-pid)
```


## Fit 3PL 
```{r}
# Fit a one-factor 3PL model to the complete reading responses.
mod.2 <- mirt(
  data = data.frame(df.reading.resp),
  model = 1,
  itemtype = "3PL"
)
```
```{r}
# Flatten the fitted coefficients into one data frame; its items.a, items.b,
# items.g and items.u columns are used to build the item bank.
params.2 <- coef(mod.2, IRTpars = TRUE, simplify = TRUE) %>%
  data.frame()
```


## Create item bank
```{r}
# Build the post-hoc item bank: item name (from the row names) followed by
# a, b, c, d, each rounded to 8 digits. c and d come from the model's g and u
# columns.
posthoc.item.bank <- params.2 %>%
  tibble::rownames_to_column(var = "item") %>%
  mutate(
    a = round(items.a, digits = 8),
    b = round(items.b, digits = 8),
    c = round(items.g, digits = 8),
    d = round(items.u, digits = 8)
  ) %>%
  select(-c(means, F1, items.a, items.b, items.g, items.u))
```

```{r}
#write.csv(posthoc.item.bank, "../src/wordlist/posthoc.item.bank.csv")
```

## Create resp
```{r}
set.seed(123) # Setting seed for reproducibility
# Draw 200 participants at random and renumber them 1, 2, ... in sampled
# order. The ids come from the sampled row count, so the sample size is
# stated in one place only.
df.reading.resp.with.pid.200 <- df.reading.resp.with.pid %>%
  sample_n(200) %>%
  mutate(pid = row_number())
```

```{r}
#write.csv(df.reading.resp.with.pid.200, "../src/data/resp_posthoc_with_pid.csv")
```

```{r}
# Score the sampled participants' response patterns with the 3PL model
# (method "ML").
th_SE.posthoc <- fscores(
  mod.2,
  method = "ML",
  response.pattern = df.reading.resp.with.pid.200 %>% select(-pid),
  max_theta = 6,
  min_theta = -6,
  theta_lim = c(-6, 6)
) %>%
  as_tibble()
```
## catR simulation
```{r}
# catR reference runs on the sampled real responses, one per item selection
# rule. The id and item-name columns are dropped before the call.
df.mfi.posthoc <- func.catSim(
  resp = df.reading.resp.with.pid.200 %>% select(-pid),
  item.bank = posthoc.item.bank %>% select(-item),
  method = "MFI"
)
df.random.posthoc <- func.catSim(
  resp = df.reading.resp.with.pid.200 %>% select(-pid),
  item.bank = posthoc.item.bank %>% select(-item),
  method = "random"
)
```

## jsCAT simulation
```{r}
# jsCAT output for the post-hoc simulation. The columns used below are
# variant, trialNumTotal and thetaSE.
jsCAT.simulation.posthoc.df <- read.csv(
  file = "../data/jsCAT-simulation-data-posthoc.csv"
)
```

### Validate MFI efficiency 
```{r}
# Mean standard error of measurement per item-selection variant and trial,
# computed separately for jsCAT and catR and then stacked. bind_rows() with
# explicit inputs replaces the nested rbind()-inside-a-pipe, and
# `.groups = "drop"` returns ungrouped tables.
df.software.compare.posthoc <- bind_rows(
  # jsCAT: already in long format with a `variant` column
  jsCAT.simulation.posthoc.df %>%
    group_by(variant, trialNumTotal) %>%
    summarise(sem = mean(thetaSE), .groups = "drop") %>%
    add_column(software = "jsCAT"),
  # catR: label the two runs with their variant while stacking them
  bind_rows(
    mfi = df.mfi.posthoc,
    random = df.random.posthoc,
    .id = "variant"
  ) %>%
    group_by(variant, trialNumTotal) %>%
    summarise(sem = mean(SE_F1), .groups = "drop") %>%
    add_column(software = "catR")
)
  
```

```{r fig.width=4,fig.height=2}
# Mean SEM by test length for the 3PL post-hoc simulation: colour
# distinguishes the software, point shape the item-selection variant. This
# panel carries the legend.
g.software.compare.2 <- ggplot(
  df.software.compare.posthoc,
  mapping = aes(
    x = trialNumTotal,
    y = sem,
    color = software,
    shape = variant
  )
) +
  geom_point() +
  ylim(0.2, 1) +
  labs(x = "Number of test items",
       y = "Standard error of measurement",
       title = "jsCAT vs. catR (3PL model)") +
  theme(legend.position = "bottom",
        legend.title = element_blank(),
        legend.key.size = unit(1, "lines"))

# Name the side-by-side figure and pass it to ggsave() explicitly, so the
# saved files do not depend on whichever plot happened to be created last.
g.software.compare <- g.software.compare.1 + g.software.compare.2
g.software.compare

ggsave("../plots/jsCAT_validation_2.png", plot = g.software.compare)
ggsave("../plots/jsCAT_validation_2.pdf", plot = g.software.compare)
```

