2022-12-04 19:16:32 -06:00

657 lines
18 KiB
C++

/***
* Copyright (c) 2020 Duality Technologies, Inc.
* Licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License <https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode>
* See the LICENSE.md file for the full text of the license.
* If you share the Licensed Material (including in modified form) you must include the above attribution in the copy you share.
***/
/*
Implementation for the Chi-Square GWAS solution described in
"Secure large-scale genome-wide association studies using homomorphic encryption"
by Marcelo Blatt, Alexander Gusev, Yuriy Polyakov, and Shafi Goldwasser
Command to execute the prototype
./demo-chi2 --SNPdir "../data" --SNPfilename "random_sample" --pvalue "pvalue.txt" --runtime "result.txt" --samplesize="200" --snps="16384"
*/
#include <getopt.h>
#include <numeric>
#include <cmath>
#include <boost/mpi.hpp>
#include <boost/mpi/environment.hpp>
#include "palisade.h"
#include "demo_chi2.h"
#include "avail_mem.h"
namespace mpi = boost::mpi;
namespace mt = mpi::threading;
using namespace std;
using namespace lbcrypto;
/**
 * Closed-form tail approximation evaluated at z.
 *
 * With g = exp(-z^2/2) the result is
 *   sqrt(1 - g) * (31*g/200 - 341*g^2/8000) / sqrt(pi).
 * RunChi2 uses it as the p-value substitute when 1 - IncompleteGamma(...)
 * evaluates to exactly zero.
 *
 * @param z point at which the approximation is evaluated (only z*z is used,
 *          so the result is symmetric in z)
 * @return the value of the expression above
 */
double BS(double z) {
  const double gauss = std::exp(-z * z / 2);
  const double root = std::sqrt(1 - gauss);
  const double poly = 31 * gauss / 200 - 341 * gauss * gauss / 8000;
  return root * poly / std::sqrt(M_PI);
}
// Tolerance-based comparisons of doubles. EPSILON is supplied from outside
// this file section (presumably demo_chi2.h - confirm).

/** True when |a - b| is strictly smaller than EPSILON. */
static bool Equal(double a, double b) {
  const double gap = fabs(a - b);
  return gap < EPSILON;
}

/** True when a is smaller than b by more than EPSILON. */
static bool Less(double a, double b) {
  const double delta = a - b;
  return delta < (-EPSILON);
}

/** True when a is larger than b by more than EPSILON. */
static bool Greater(double a, double b) {
  const double delta = a - b;
  return delta > EPSILON;
}
/**
 * Program entry point.
 *
 * 1. Parses the command line (SNP directory/file name, output file names,
 *    sample count, SNP count).
 * 2. Starts Boost.MPI, reads one sample to size the per-node batches from the
 *    total system memory, and assigns every rank a contiguous slice of the
 *    samples (the first `total % size` ranks take one extra sample so that no
 *    sample is left out).
 * 3. Builds the CKKS crypto context, generates a key pair and calls RunChi2
 *    once per batch.
 * 4. Sums the per-rank timers on rank 0 and prints them.
 *
 * Non-numeric --samplesize still makes std::stoi throw; missing or
 * non-positive values are rejected with the usage text.
 *
 * @return 0 on success; exits with EXIT_FAILURE on bad arguments or when the
 *         SNP file yields no sample.
 */
int main(int argc, char *argv[]) {
  PalisadeParallelControls.Enable();

  static struct option long_options[] =
  {
    /* These options dont set a flag.
       We distinguish them by their indices. */
    {"SNPdir", required_argument, 0, 'S'},
    {"SNPfilename", required_argument, 0, 's'},
    {"pvalue", required_argument, 0, 'p'},
    {"runtime", required_argument, 0, 'r'},
    {"samplesize", required_argument, 0, 'N'},
    {"snps", required_argument, 0, 'M'},
    {0, 0, 0, 0}
  };
  /* getopt_long stores the option index here. */
  int option_index = 0;

  string SNPDir;
  string SNPFileName;
  string pValue;
  string Runtime;
  string SampleSize;
  string SNPs;

  // Shared by the parse-error path and the missing-argument check below.
  auto printUsage = [&argv]() {
    std::cerr<< "Usage: "<<argv[0]<<" <arguments> " <<std::endl
    << "arguments:" <<std::endl
    << " -S --SNPdir SNP file directory" <<std::endl
    << " -s --SNPfilename SNP file name" <<std::endl
    << " -p --pvalue p-values file" <<std::endl
    << " -r --runtime runtime output file name" <<std::endl
    << " -N --samplesize number of individuals" <<std::endl
    << " -M --snps number of SNPs" <<std::endl;
  };

  // Every short option takes an argument, so each letter (including the
  // trailing M) must be followed by ':'; otherwise optarg is null for "-M n".
  int opt;
  while ((opt = getopt_long(argc, argv, "S:s:p:r:N:M:", long_options, &option_index)) != -1) {
    switch (opt)
    {
      case 'S':
        SNPDir = optarg;
        break;
      case 's':
        SNPFileName = optarg;
        break;
      case 'p':
        pValue = optarg;
        break;
      case 'r':
        Runtime = optarg;
        break;
      case 'N':
        SampleSize = optarg;
        break;
      case 'M':
        SNPs = optarg;
        break;
      default: /* '?' */
        printUsage();
        exit(EXIT_FAILURE);
    }
  }

  // Without these the run cannot proceed: std::stoi("") would throw and the
  // SNP path would be meaningless.
  if (SNPDir.empty() || SNPFileName.empty() || SampleSize.empty() || SNPs.empty()) {
    printUsage();
    exit(EXIT_FAILURE);
  }
  const int total_samples = std::stoi(SampleSize);
  if (total_samples <= 0) {
    std::cerr << "--samplesize must be a positive integer" << std::endl;
    printUsage();
    exit(EXIT_FAILURE);
  }

  mt::level mt_level = mt::multiple;
  boost::mpi::environment env(argc, argv, mt_level);
  mt::level provided = env.thread_level();
  mpi::communicator world;
  if (world.rank() == 0){
    std::cout << "Asked the MPI environment to be created with threading level: "
    << mt_level << std::endl;
    std::cout << "MPI Environment was created with threading level: " << provided
    << std::endl;
  }

  // Zero-filled and one byte short so the name is always NUL-terminated,
  // even if gethostname truncates.
  char hammer_host[256] = {0};
  gethostname(hammer_host, sizeof(hammer_host) - 1);
  std::cout << "\nHello from process " << world.rank() << " of " << world.size()
  << " running on " << hammer_host << std::endl;
  world.barrier();

  // Probe read: one sample, one SNP, used only to size the batches.
  vector<string> headersS;
  std::vector<double> yData;
  std::vector<std::vector<double>> sData;
  ReadSNPFile(headersS,sData,yData,SNPDir + "/" + SNPFileName,1,1, 0);
  if (yData.empty() || sData.empty()) {
    std::cerr << "Node " << world.rank() << " could not read any sample from "
    << SNPDir + "/" + SNPFileName << ".csv" << std::endl;
    // Abort the whole job: returning from a single rank would leave the
    // others waiting in the next barrier.
    mpi::environment::abort(EXIT_FAILURE);
    return EXIT_FAILURE;
  }

  auto tot_sys_mem = getTotalSystemMemory();
  auto mem_test = (double) ((yData.size() * (sizeof(double_t)) + (sData.size() * (sData.back().size() * sizeof(double_t)))));
  std::cout << "Maximum Memory size of Node " << world.rank() << " is " << std::ceil(double(tot_sys_mem/1024.0/1024.0/1024.0) * 100.0)/100.0 << " GB" << std::endl;

  // Half of the system memory divided by the footprint of the probe sample.
  // Kept as a double first: on large-memory nodes the quotient exceeds INT_MAX.
  const double max_batch = std::floor((0.50*tot_sys_mem)/mem_test);
  // A batch never needs to exceed the total sample count, and must be >= 1
  // so the batch loop below terminates.
  int batch_size = (max_batch < static_cast<double>(total_samples)) ? static_cast<int>(max_batch) : total_samples;
  if (batch_size < 1){
    batch_size = 1;
  }

  // Contiguous partition of the samples: ranks [0, rem) take one extra sample.
  const int rem = total_samples % world.size();
  int node_samples = total_samples/world.size();
  int node_start = world.rank()*node_samples + rem;
  if (world.rank() < rem){
    node_start = world.rank() * (node_samples+1);
    node_samples += 1;
  }

  std::cout << "Node " << world.rank() << " can have a maximum batch size of " << static_cast<long long>(max_batch) << " samples." << std::endl;
  world.barrier();
  if (world.rank() == 0 ){
    std::cout << "Each node is responsible for " << node_samples << " samples.\n" << std::endl;
  }

  // Timers; RunChi2 accumulates into them by reference.
  double keyGenTime(0.0);
  double encryptionTime(0.0);
  double computationTime(0.0);
  double decryptionTime(0.0);
  double endToEndTime(0.0);
  double ioTime(0.0);

  //usint m = 16384;
  usint m = 65536;
  usint init_size = 4;
  usint dcrtBits = 54;
  CryptoContext<DCRTPoly> cc =
  CryptoContextFactory<DCRTPoly>::genCryptoContextCKKS(
    init_size-1,
    dcrtBits,
    0,
    HEStd_128_classic,
    m/2, /*ringDimension*/
    APPROXRESCALE,
    BV,
    3, /*numLargeDigits*/
    2, /*maxDepth*/
    dcrtBits, /*firstMod*/
    0,
    OPTIMIZED);
  cc->Enable(ENCRYPTION);
  cc->Enable(SHE);
  cc->Enable(LEVELEDSHE);

  TimeVar t;
  TIC(t);
  auto keyPair = cc->KeyGen();
  keyGenTime += TOC(t);

  // Process this rank's slice in batches of at most batch_size samples.
  const int num_batches = static_cast<int>(std::ceil(static_cast<double>(node_samples)/static_cast<double>(batch_size)));
  for (int i = 0; i < num_batches; i++)
  {
    const int remaining = node_samples - (batch_size*i);
    const int batch_gen = (remaining > batch_size) ? batch_size : remaining;
    const int file_start = node_start + (batch_size*i);
    std::cout << "Node " << world.rank() << " running batch " << i << " with batch size " << batch_gen << std::endl;
    RunChi2(SNPDir, SNPFileName, pValue, Runtime, std::to_string(batch_gen), SNPs, file_start, keyGenTime, encryptionTime, computationTime, decryptionTime, endToEndTime, ioTime, cc, keyPair, i, world);
  }
  std::cout << "Node " << world.rank() << " completed all batches." << std::endl;

  // Synchronise first, then exchange timers with blocking send/recv. The
  // barrier must come before the sends: a blocking send may wait for its
  // matching receive, which rank 0 only posts after the barrier.
  world.barrier();
  const int kTimers = 6;
  // Tag order: 0 keygen, 1 encryption, 2 computation, 3 decryption,
  // 4 end-to-end, 5 file IO.
  double totals[kTimers] = {keyGenTime, encryptionTime, computationTime, decryptionTime, endToEndTime, ioTime};
  if (world.rank() != 0){
    for (int tag = 0; tag < kTimers; ++tag){
      world.send(0, tag, totals[tag]);
    }
  } else {
    // Receive from every other rank explicitly so no contribution is missed.
    for (int source = 1; source < world.size(); ++source){
      for (int tag = 0; tag < kTimers; ++tag){
        double part = 0.0;
        world.recv(source, tag, part);
        totals[tag] += part;
      }
    }
    std::cout << "\nTotal Key Generation Time: \t\t" << totals[0]/1000 << " s" << std::endl;
    std::cout << "Total Encoding and Encryption Time: \t" << totals[1]/1000 << " s" << std::endl;
    std::cout << "Total Computation Time: \t\t" << totals[2]/1000 << " s" << std::endl;
    std::cout << "Total Decryption & Decoding Time: \t" << totals[3]/1000 << " s" << std::endl;
    std::cout << "Total File IO Time: \t\t\t" << totals[5]/1000 << " s" << std::endl;
    std::cout << "Total End-to-end Runtime: \t\t" << totals[4]/1000 << " s" << std::endl;
  }
  return 0;
}
/**
 * Runs the encrypted chi-square / odds-ratio computation for one batch of
 * samples and writes the results next to the input data.
 *
 * Steps: read up to N samples starting at row `file_start`, pack and encrypt
 * them, evaluate the chi-square and odds-ratio numerators/denominators under
 * encryption, decrypt, derive chi2 / p-value / odds per SNP, and write
 * "<rank>_<batch><pValue>", "<rank>_<batch>odds.txt", "<rank>_<batch>chi2.txt"
 * and "<rank>_<batch><Runtime>" into SNPDir.
 *
 * If the file yields fewer rows than requested, the batch is processed with
 * the rows actually read (a warning is printed); with no rows at all the
 * function returns without producing output files.
 *
 * @param SNPDir       directory holding the SNP file; also the output directory
 * @param SNPFileName  SNP file name without the ".csv" suffix
 * @param pValue       suffix of the p-value output file
 * @param Runtime      suffix of the runtime output file
 * @param SampleSize   number of samples in this batch, as decimal text
 * @param SNPs         number of SNPs to read, as decimal text
 * @param file_start   number of data rows to skip before reading
 * @param keyGenTime   key generation time, only reported here
 * @param encryptionTime, computationTime, decryptionTime, endToEndTime, ioTime
 *                     timers this call adds to
 * @param cc           crypto context used for all homomorphic operations
 * @param keyPair      key pair (publicKey / secretKey members are used)
 * @param batch_num    batch index, used in output file names
 * @param world        communicator; only rank() is used (file names)
 */
void RunChi2(const string &SNPDir,
const string &SNPFileName, const string &pValue, const string &Runtime, const string &SampleSize, const string &SNPs,
size_t file_start,
double &keyGenTime, double &encryptionTime, double &computationTime, double &decryptionTime, double &endToEndTime, double &ioTime,
CryptoContext<DCRTPoly> &cc, auto &keyPair, int batch_num, mpi::communicator &world) {
  TimeVar t;
  TimeVar tAll;
  TIC(tAll);
  //std::cout << "\n======CHI-SQUARE SOLUTION========\n" << std::endl;
  vector<string> headersS;
  std::vector<double> yData;
  std::vector<std::vector<double>> sData;
  size_t N = std::stoi(SampleSize);
  size_t M = std::stoi(SNPs);
  usint m = 65536;
  // Number of values packed per plaintext; every chunk below uses this width.
  const usint slots = m/4;
  double scalingFactor = 2.5e-6;

  TIC(t);
  ReadSNPFile(headersS,sData,yData,SNPDir + "/" + SNPFileName,N,M, file_start);
  ioTime += TOC(t);

  // Everything below indexes sData/yData up to N, so N must not exceed the
  // number of rows that were actually read.
  const size_t rowsRead = (sData.size() < yData.size()) ? sData.size() : yData.size();
  if (rowsRead < N) {
    std::cerr << "Node " << world.rank() << " batch " << batch_num << ": requested " << N
    << " samples but read " << rowsRead << std::endl;
    N = rowsRead;
  }
  if (N == 0) {
    endToEndTime += TOC(tAll);
    return;
  }

  TIC(t);
  // Split every sample's SNP row into sizeS chunks of `slots` values each.
  size_t sizeS = (size_t)std::ceil((double)sData[0].size()/slots);
  std::vector<std::vector<std::vector<std::complex<double>>>> sDataArray(sizeS);
  for(size_t s = 0; s < sizeS; s++)
    sDataArray[s] = std::vector<std::vector<std::complex<double>>>(sData.size());
  for (size_t i=0; i < sData.size(); i++){
    for(size_t s = 0; s < sizeS; s++)
      sDataArray[s][i] = std::vector<std::complex<double>>(sData[i].size());
    for (size_t j=0; j<sData[i].size(); j++) {
      // Column j lands in chunk j/slots at position j%slots.
      sDataArray[j/slots][i][j%slots] = sData[i][j];
    }
  }
  //std::cerr << " sData = " << sDataArray << std::endl;

  std::vector<std::vector<Ciphertext<DCRTPoly>>> S(sizeS);
  std::vector<Ciphertext<DCRTPoly>> Y(N);
  for (size_t i = 0; i < sizeS; i++)
    S[i] = std::vector<Ciphertext<DCRTPoly>>(N);

  // Encrypt each sample: one ciphertext per SNP chunk, plus one ciphertext
  // holding the phenotype value replicated in every slot.
  #pragma omp parallel for
  for (size_t i=0; i<N; i++){
    for (size_t s=0; s < sizeS; s++){
      Plaintext sTemp = cc->MakeCKKSPackedPlaintext(sDataArray[s][i]);
      S[s][i] = cc->Encrypt(keyPair.publicKey, sTemp);
    }
    Plaintext sTemp2 = cc->MakeCKKSPackedPlaintext(std::vector<std::complex<double>>(slots,yData[i]));
    Y[i] = cc->Encrypt(keyPair.publicKey, sTemp2);
  }
  encryptionTime += TOC(t);

  TIC(t);
  // Plaintext constants 2N and 2N*scalingFactor in every slot.
  Plaintext d = cc->MakeCKKSPackedPlaintext(std::vector<std::complex<double>>(slots,2*N));
  d->SetFormat(EVALUATION);
  Plaintext dScaled = cc->MakeCKKSPackedPlaintext(std::vector<std::complex<double>>(slots,2*N*scalingFactor));
  dScaled->SetFormat(EVALUATION);

  // BinaryTreeAdd overwrites its argument, so Y is summed on a copy because
  // the individual Y[i] are still needed below.
  std::vector<Ciphertext<DCRTPoly>> ySum = Y;
  auto yU = BinaryTreeAdd(ySum);

  std::vector<Ciphertext<DCRTPoly>> chiD(sizeS);
  std::vector<Ciphertext<DCRTPoly>> chiN(sizeS);
  std::vector<Ciphertext<DCRTPoly>> orD(sizeS);
  std::vector<Ciphertext<DCRTPoly>> orN(sizeS);

  // NOTE(review): the ModReduce / LevelReduce calls below are kept in their
  // original order; the sequence looks level-sensitive, so do not reorder
  // without checking against the PALISADE CKKS documentation.
  for (size_t s = 0; s < sizeS; s++) {
    // n11 = sum over samples of S*Y
    std::vector<Ciphertext<DCRTPoly>> ySMult(N);
    #pragma omp parallel for
    for(size_t i = 0; i < N; i++) {
      ySMult[i] = cc->EvalMultNoRelin(S[s][i],Y[i]);
    }
    auto n11 = BinaryTreeAdd(ySMult);
    n11 = cc->ModReduce(n11);

    // c1 = sum over samples of S (S[s] is overwritten here; it is not read again)
    auto c1 = BinaryTreeAdd(S[s]);
    // r1 = 2 * sum over samples of Y
    auto r1 = cc->EvalAdd(yU,yU);
    auto r1Scaled = cc->EvalMult(r1,scalingFactor);
    r1Scaled = cc->ModReduce(r1Scaled);
    auto c1Scaled = cc->EvalMult(c1,scalingFactor);
    c1Scaled = cc->ModReduce(c1Scaled);
    r1 = cc->LevelReduce(r1,nullptr);
    c1 = cc->LevelReduce(c1,nullptr);

    // Chi2 computation
    // numerator
    auto mult1 = cc->EvalMult(n11,dScaled);
    auto mult2 = cc->EvalMultNoRelin(c1,r1Scaled);
    auto chiN1 = cc->EvalSub(mult1,mult2);
    chiN1 = cc->ModReduce(chiN1);
    chiN[s] = cc->EvalMultNoRelin(chiN1,chiN1);
    // denominator
    auto chiD1 = cc->EvalMultNoRelin(c1,cc->EvalSub(dScaled,c1Scaled));
    chiD1 = cc->ModReduce(chiD1);
    auto chiD2 = cc->EvalMultNoRelin(r1,cc->EvalSub(dScaled,r1Scaled));
    chiD2 = cc->ModReduce(chiD2);
    chiD[s] = cc->EvalMultNoRelin(chiD1,chiD2);

    // Odds Ratio Computation
    auto n11Scaled = cc->EvalMult(n11,scalingFactor);
    n11Scaled = cc->ModReduce(n11Scaled);
    // denominator
    auto or2 = cc->EvalSub(c1,n11);
    or2 = cc->LevelReduce(or2,nullptr);
    r1Scaled = cc->LevelReduce(r1Scaled,nullptr);
    auto or3 = cc->EvalSub(r1Scaled,n11Scaled);
    orD[s] = cc->EvalMultNoRelin(or2,or3);
    // numerator
    auto or1 = cc->EvalSub(n11Scaled,r1Scaled);
    c1Scaled = cc->LevelReduce(c1Scaled,nullptr);
    or1 = cc->EvalSub(or1,c1Scaled);
    or1 = cc->EvalAdd(or1,dScaled);
    n11 = cc->LevelReduce(n11,nullptr);
    orN[s] = cc->EvalMultNoRelin(n11,or1);
  }
  computationTime += TOC(t);

  TIC(t);
  std::vector<Plaintext> pN(sizeS);
  std::vector<Plaintext> pD(sizeS);
  std::vector<Plaintext> oddN(sizeS);
  std::vector<Plaintext> oddD(sizeS);
  for (size_t s = 0; s < sizeS; s++) {
    cc->Decrypt(keyPair.secretKey, chiN[s] , &(pN[s]));
    cc->Decrypt(keyPair.secretKey, chiD[s] , &(pD[s]));
    cc->Decrypt(keyPair.secretKey, orN[s] , &(oddN[s]));
    cc->Decrypt(keyPair.secretKey, orD[s] , &(oddD[s]));
  }
  decryptionTime += TOC(t);

  // Per-SNP statistics from the decrypted numerators and denominators.
  std::vector<double> chival(headersS.size());
  std::vector<double> pval(headersS.size());
  std::vector<double> odds(headersS.size());
  for (size_t s = 0; s < sizeS; s++) {
    for (size_t i = 0; i < slots; i++) {
      const size_t snp = s*slots + i;
      if (snp < headersS.size()) {
        chival[snp] = pN[s]->GetCKKSPackedValue()[i].real()*2*N/pD[s]->GetCKKSPackedValue()[i].real();
        if (chival[snp] < 0)
          chival[snp] = 0;
        pval[snp] = (double)1-IncompleteGamma(chival[snp]/2,0.5);
        if (pval[snp] < 0)
          pval[snp] = 1e-15;
        else
          if (pval[snp]==0)
            // 1 - IncompleteGamma rounded to zero: fall back to BS.
            pval[snp] = BS(sqrt(chival[snp]));
        odds[snp] = oddN[s]->GetCKKSPackedValue()[i].real()/oddD[s]->GetCKKSPackedValue()[i].real();
      }
    }
  }

  // All output files share the "<dir>/<rank>_<batch>" prefix.
  const string outPrefix = SNPDir + "/" + std::to_string(world.rank()) + "_" + std::to_string(batch_num);

  // Writes one "<header>\t<value>" line per SNP with 10 significant digits.
  auto writeColumn = [&headersS](const string &path, const std::vector<double> &values) {
    ofstream out(path);
    if (!out) {
      std::cerr << "RunChi2: cannot open " << path << " for writing" << std::endl;
      return;
    }
    out.precision(10);
    for (size_t i = 0; i < headersS.size(); i++) {
      out << headersS[i] << "\t" << values[i] << '\n';
    }
  };
  writeColumn(outPrefix + pValue, pval);
  writeColumn(outPrefix + "odds.txt", odds);
  writeColumn(outPrefix + "chi2.txt", chival);

  //std::cout << "\nKey Generation Time: \t\t" << keyGenTime/1000 << " s" << std::endl;
  //std::cout << "Encoding and Encryption Time: \t" << encryptionTime/1000 << " s" << std::endl;
  //std::cout << "Computation Time: \t\t" << computationTime/1000 << " s" << std::endl;
  //std::cout << "Decryption & Decoding Time: \t" << decryptionTime/1000 << " s" << std::endl;
  endToEndTime += TOC(tAll);
  //std::cout << "\nEnd-to-end Runtime: \t\t" << endToEndTime/1000 << " s" << std::endl;

  // Cumulative timers so far (they are accumulated across batches by the caller).
  ofstream myfileRuntime(outPrefix + Runtime);
  if (!myfileRuntime) {
    std::cerr << "RunChi2: cannot open " << outPrefix + Runtime << " for writing" << std::endl;
    return;
  }
  myfileRuntime << "Key Generation Time: \t\t" << keyGenTime/1000 << " s" << std::endl;
  myfileRuntime << "Encoding and Encryption Time: \t" << encryptionTime/1000 << " s" << std::endl;
  myfileRuntime << "Computation Time: \t\t" << computationTime/1000 << " s" << std::endl;
  myfileRuntime << "Decryption & Decoding Time: \t" << decryptionTime/1000 << " s" << std::endl;
  myfileRuntime << "End-to-end Runtime: \t\t" << endToEndTime/1000 << " s" << std::endl;
  myfileRuntime.close();
}
/**
 * Sums all ciphertexts of `vector` with a pairwise (tree-shaped) reduction.
 *
 * The reduction is done in place: entries of `vector` are overwritten with
 * partial sums, and element 0 ends up holding the total. Callers that still
 * need the individual ciphertexts must pass a copy.
 *
 * @param vector ciphertexts to add; modified by the call
 * @return a fresh copy of the total, or a null Ciphertext when `vector` is
 *         empty (the original dereferenced element 0 unconditionally)
 */
Ciphertext<DCRTPoly> BinaryTreeAdd(std::vector<Ciphertext<DCRTPoly>> &vector) {
  if (vector.empty())
    return Ciphertext<DCRTPoly>();

  auto cc = vector[0]->GetCryptoContext();
  const size_t count = vector.size();
  // At each level, add the element `stride` positions to the right into the
  // left element of every pair; the stride doubles per level.
  for (size_t stride = 1; stride < count; stride *= 2) {
    for (size_t left = 0; left + stride < count; left += 2*stride) {
      vector[left] = cc->EvalAdd(vector[left], vector[left + stride]);
    }
  }
  // Return a copy so the result does not alias vector[0].
  Ciphertext<DCRTPoly> result(new CiphertextImpl<DCRTPoly>(*(vector[0])));
  return result;
}
/**
 * Regularized lower incomplete gamma function P(p, val).
 *
 * Uses the power series when val <= 1 or val < p, and a continued fraction
 * for the upper tail (returned as 1 - upper) otherwise. Comparisons go
 * through the file's EPSILON-tolerant helpers Equal / Less / Greater.
 *
 * @param val upper integration limit; values not greater than 0 return 0
 * @param p   shape parameter; values not greater than 0 return 0
 * @return P(p, val) in [0, 1]. When the prefactor val^p * e^-val / Gamma(p)
 *         underflows, the limit of the active branch is returned: 0 for the
 *         series branch, 1 for the continued-fraction branch.
 */
double IncompleteGamma(double val, double p)
{
  if( !Greater(val, 0) || !Greater(p, 0) )
    return 0;

  // Same branch selection as used for the actual evaluation below.
  const bool useSeries = !Greater(val, 1) || Less(val, p);

  double expValue = p*log(val) - val - lgamma(p);
  if( Less(expValue, log(1.0E-37)) ) // underflow of the prefactor
  {
    // Series result is prefactor * sum / p -> 0; continued-fraction result is
    // 1 - prefactor * fraction -> 1. Returning 0 in the latter case made
    // 1 - IncompleteGamma(...) equal to 1 for very large val.
    return useSeries ? 0 : 1;
  }
  double factor = exp(expValue);

  if( useSeries )
  {
    // Power series: sum of val^i / ((p+1)...(p+i)), i >= 0.
    double igamma = 1;
    double term = 1;
    for( int i = 1; Greater(term, EPSILON); ++i )
    {
      term *= (val/(p+i));
      igamma += term;
    }
    return (igamma*factor/p);
  }

  // Continued fraction for the upper tail; pn holds two consecutive
  // numerator/denominator pairs plus the newly computed pair in pn[4], pn[5].
  double pn[6] = { 1, val, val+1, val*(2+val-p) };
  double upperIncGamma = pn[2]/pn[3];
  for( int j = 1; ; ++j )
  {
    double a = (j+1)*2 + val- p;
    double b = (1 + j - p)*j;
    pn[4] = a*pn[2] - b*pn[0];
    pn[5] = a*pn[3] - b*pn[1];
    if( !Equal(pn[5], 0) )
    {
      double rn = pn[4]/pn[5];
      double diff = fabs(upperIncGamma - rn);
      // Converged when the change is within both the absolute and the
      // relative tolerance.
      if( !Greater(diff, EPSILON) && !Greater(diff, (EPSILON*rn)) )
        return (1 - factor*upperIncGamma);
      upperIncGamma = rn;
    }
    // Shift the window: the two newest pairs become the two oldest.
    for( int i = 0; i < 4; i++ )
      pn[i] = pn[i+2];
    if( !Greater(1.0E+37, fabs(pn[4])) ) // overflow
    {
      // Rescale all terms; the ratio pn[4]/pn[5] is unaffected.
      for( int i = 0; i < 4; i++ )
        pn[i] = pn[i] / 1.0E+37;
    }
  }
}
/**
 * Reads a block of samples from "<dataFileName>.csv".
 *
 * File layout as consumed here: one header line, then one line per sample.
 * Column 1 is the phenotype value; columns 5 and onward are SNP values.
 * Header names of columns 5 .. M+4 (non-empty ones) are appended to
 * `headers`, and that many SNP columns are read from every sample line.
 *
 * @param headers      receives the SNP column names (appended)
 * @param dataColumns  receives one vector of SNP values per sample (appended)
 * @param y            receives the phenotype value per sample (appended)
 * @param dataFileName path without the ".csv" suffix
 * @param N            maximum number of lines to consume after skipping;
 *                     lines with two or fewer fields are consumed but ignored
 * @param M            maximum number of SNP columns to read
 * @param from         number of data lines to skip after the header
 *
 * If the file cannot be opened a message is written to std::cerr and the
 * outputs are left untouched. std::stod throws (std::invalid_argument /
 * std::out_of_range) on a missing or non-numeric cell.
 */
void ReadSNPFile(std::vector<std::string>& headers, std::vector<std::vector<double>> & dataColumns, std::vector<double> &y,
std::string dataFileName, size_t N, size_t M, size_t from)
{
  const std::string fileName = dataFileName + ".csv";
  //std::cerr << "file name = " << fileName << std::endl;
  std::ifstream file(fileName);
  if (!file.is_open()) {
    std::cerr << "ReadSNPFile: cannot open " << fileName << std::endl;
    return;
  }

  std::string line;
  std::string cell;

  // Header line: collect the SNP names and count how many columns to read.
  size_t cols = 0;
  if (std::getline(file, line)) {
    // Tolerate CRLF files so the last header name does not end in '\r'.
    if (!line.empty() && line.back() == '\r')
      line.pop_back();
    const size_t headerFields = std::count(line.begin(), line.end(), ',') + 1;
    std::stringstream ss(line);
    for (size_t i = 0; i < headerFields; i++) {
      cell.clear();
      std::getline(ss, cell, ',');
      if (!cell.empty() && (i > 4) && (i < M + 5)) {
        headers.push_back(cell);
        cols++;
      }
    }
  }

  if (N == 0)
    return;

  // Skip the rows that precede this block.
  for (size_t skipped = 0; skipped < from; skipped++) {
    if (!std::getline(file, line))
      return; // fewer than `from` rows: nothing left to read
  }

  for (size_t consumed = 0; consumed < N; consumed++) {
    // Checking the read itself (rather than the stream state beforehand)
    // prevents a stale `line` from being parsed again at end of file.
    if (!std::getline(file, line))
      break;
    if (!line.empty() && line.back() == '\r')
      line.pop_back();

    const size_t fields = std::count(line.begin(), line.end(), ',') + 1;
    if (fields <= 2)
      continue; // blank or malformed line: counted towards N but ignored

    std::stringstream ss(line);
    // First five columns are metadata; only column 1 (phenotype) is kept.
    for (size_t i = 0; i < 5; i++) {
      cell.clear();
      std::getline(ss, cell, ',');
      if (i == 1)
        y.push_back(std::stod(cell));
    }
    // cols never exceeds M (see header parsing), so every cell is stored.
    std::vector<double> row(cols);
    for (size_t c = 0; c < cols; c++) {
      cell.clear();
      std::getline(ss, cell, ',');
      row[c] = std::stod(cell);
    }
    dataColumns.push_back(row);
  }
  //std::cout << "Read in data: ";
  //std::cout << dataFileName << std::endl;
}