---
title: "Saving and Using Saved Models"
date: "`r format(Sys.Date(), '%Y-%m-%d')`"
output: 
  rmarkdown::html_vignette:
    toc: true
    toc_depth: 2
    fig_width: 7
    fig_height: 5
    dpi: 600
vignette: >
  %\VignetteIndexEntry{Saving and Using Saved Models}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

## Saving a trained CISS-VAE model

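```{r}
#| eval: false
library(reticulate)
library(rCISSVAE)

# IMPORTANT: the Python environment must be active so that 'torch' can be
# imported. Activate it before training or saving.
reticulate::use_virtualenv("cissvae_environment", required = TRUE)

# Train a model on `data`, a data.frame containing missing values (NA)
res <- run_cissvae(data)

# Save the trained model (stored in res$model) to disk
save_cissvae_model(res$model, "trained_vae.pt")
```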
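
Because saving depends on `torch` being importable, you may want to check the Python environment before calling `save_cissvae_model()`. The sketch below is a convenience and is not part of rCISSVAE: `torch_ready()` is a hypothetical helper that relies only on reticulate's `py_module_available()`, and it returns `FALSE` instead of failing when reticulate itself is not installed.

```r
# Hypothetical helper (not part of rCISSVAE): TRUE when the active Python
# environment can import torch, FALSE otherwise.
torch_ready <- function() {
  requireNamespace("reticulate", quietly = TRUE) &&
    reticulate::py_module_available("torch")
}

# Warn early rather than failing inside save_cissvae_model()
if (!torch_ready()) {
  message("torch is not importable; activate the Python environment first.")
}
```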

## Loading a saved model and imputing data

```{r}
#| eval: false
library(rCISSVAE)
library(reticulate)

## Activate your Python environment 
reticulate::use_virtualenv("cissvae_environment", required = TRUE)

## Load full model object
model <- load_cissvae_model(
  file = "trained_vae.pt"
)

## Perform imputation on new data
# `data` must contain the column named in `index_col`, with missing values
# coded as NA; `clusters` must hold one cluster label per row of `data`
imputed_df <- impute_with_cissvae(
  model_py = model,
  data = data,
  index_col = "index",
  columns_ignore = NULL,
  clusters = clusters,
  imputable_matrix = NULL,
  binary_feature_mask = NULL,
  val_proportion = 0.1,
  replacement_value = 0,
  batch_size = 4000L,
  seed = 42
)

# `imputed_df` is returned to R as a data.frame
```
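
The call above assumes that `data` and `clusters` already exist. As a hedged illustration, the sketch below builds toy inputs of the expected shape and runs a few sanity checks before imputation. The values are made up, and the checks (a unique `index` column, one cluster label per row, at least one `NA` to impute) are assumptions drawn from the arguments above, not rules documented by rCISSVAE.

```r
# Toy inputs (made-up values) shaped for impute_with_cissvae()
data <- data.frame(
  index = 1:6,
  x1 = c(1.2, NA, 0.7, 2.1, NA, 1.5),
  x2 = c(NA, 3.4, 2.2, NA, 1.9, 2.8)
)
clusters <- c(0L, 0L, 1L, 1L, 2L, 2L)  # example cluster label for each row

# Assumed requirements, checked before calling impute_with_cissvae()
stopifnot(
  "index" %in% names(data),                    # matches index_col = "index"
  !anyDuplicated(data$index),                  # one unique id per row
  length(clusters) == nrow(data),              # one cluster label per row
  anyNA(data[setdiff(names(data), "index")])   # at least one value to impute
)
```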

If your dataset contains binary variables, set `binary_feature_mask` to identify them. The imputed values for these columns are probabilities, so convert them to {0, 1} using a threshold of your choice (for example, 0.5).
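
One way to apply that threshold is sketched below. `threshold_binary()` is a hypothetical helper, and `binary_cols` stands in for whichever columns you flagged as binary; the 0.5 cutoff is only a default, and you may prefer a different value for imbalanced variables.

```r
# Hypothetical helper: map imputed probabilities in the given binary columns
# to 0/1 integers (values >= threshold become 1, all others become 0).
threshold_binary <- function(df, binary_cols, threshold = 0.5) {
  for (col in binary_cols) {
    df[[col]] <- as.integer(df[[col]] >= threshold)
  }
  df
}

# Toy stand-in for `imputed_df`: `smoker` is binary, `age` is continuous
imputed_toy <- data.frame(
  index = 1:4,
  smoker = c(0.91, 0.12, 0.50, 0.49),
  age = c(52, 61, 47, 38)
)
out <- threshold_binary(imputed_toy, binary_cols = "smoker")
out$smoker  # 1 0 1 0
```

Only the listed columns are modified, so continuous variables such as `age` pass through unchanged.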