1654 lines
51 KiB
C++

/***
* Copyright (c) 2020 Duality Technologies, Inc.
* Licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License <https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode>
* See the LICENSE.md file for the full text of the license.
* If you share the Licensed Material (including in modified form) you must include the above attribution in the copy you share.
***/
/*
Implementation for the Logistic Regression Approximation GWAS solution described in
"Secure large-scale genome-wide association studies using homomorphic encryption"
by Marcelo Blatt, Alexander Gusev, Yuriy Polyakov, and Shafi Goldwasser
Command to execute the prototype
./demo-logistic --SNPdir "../data" --SNPfilename "random_sample" --pvalue "pvalue.txt" --runtime "result.txt" --samplesize="200" --snps="16384"
*/
#include <getopt.h>
#include <numeric>
#include <cmath>
#include "palisade.h"
using namespace std;
using namespace lbcrypto;
// ---- Forward declarations of the helpers used by this prototype ----
// Runs the whole pipeline (read data, key generation, encryption, encrypted computation,
// decryption, result files). All arguments are the raw command-line strings; SampleSize and
// SNPs are converted to integers inside the function.
void RunLogReg(const string &SNPDir, const string &SNPFileName, const string &pValue,
const string &Runtime, const string &SampleSize, const string &SNPs);
// Evaluates, on ciphertexts, the polynomial in (p - 0.5) and y that RunLogReg stores as "z".
Ciphertext<DCRTPoly> zExpand(const Ciphertext<DCRTPoly> p, const Ciphertext<DCRTPoly> y);
// Encrypted inverse-related computation for the k x k matrix packed in m.
// NOTE(review): b and cd look like output parameters (the caller passes freshly constructed
// ciphertexts and reads them afterwards); the three maps are the key sets used for rotations.
shared_ptr<std::vector<std::vector<Ciphertext<DCRTPoly>>>> MatrixInverse(const Ciphertext<DCRTPoly> m, size_t k, CiphertextImpl<DCRTPoly> &b, CiphertextImpl<DCRTPoly> &cd,
const std::map<usint, LPEvalKey<DCRTPoly>> &map, const std::map<usint, LPEvalKey<DCRTPoly>> &rotKeys,
const std::map<usint, LPEvalKey<DCRTPoly>> &evalSumRows);
// Definition not visible in this part of the file; presumably replicates the packed values
// of a ciphertext using the supplied rotation keys -- TODO confirm.
Ciphertext<DCRTPoly> CloneCiphertext(const Ciphertext<DCRTPoly> ciphertext, size_t size,
const std::map<usint,LPEvalKey<DCRTPoly>> &rotKeys, const std::map<usint,LPEvalKey<DCRTPoly>> &evalSumRows);
// Returns a vector of ciphertexts built from c; RunLogReg indexes the concatenated results by
// individual, so presumably one ciphertext per individual -- TODO confirm against the definition.
shared_ptr<std::vector<Ciphertext<DCRTPoly>>> SplitIntoSingle(const Ciphertext<DCRTPoly> c, size_t N, size_t k,
const std::map<usint, LPEvalKey<DCRTPoly>> &rotKeys);
// Reduces a vector of ciphertexts to a single ciphertext (presumably by pairwise addition,
// per its name). Takes the vector by non-const reference, so it may modify it.
Ciphertext<DCRTPoly> BinaryTreeAdd(std::vector<Ciphertext<DCRTPoly>> &vector);
// Modifies the evaluation keys in ek in place for the given level; presumably drops the parts
// of each key that are no longer needed at that level -- TODO confirm.
void CompressEvalKeys(std::map<usint, LPEvalKey<DCRTPoly>> &ek, size_t level);
// Reads the input file dataFileName and fills headers, dataColumns, x and y (all outputs).
// N and M are the requested numbers of individuals and SNPs.
void ReadSNPFile(vector<string>& headers, std::vector<std::vector<double>> & dataColumns,std::vector<std::vector<double>> &x, std::vector<double> &y,
string dataFileName, size_t N, size_t M);
// Precomputes the key-switching digits of a ciphertext once so that several automorphisms
// of the same ciphertext can reuse them (see the use in MatrixInverse).
shared_ptr<vector<DCRTPoly>> KeySwitchPrecompute(ConstCiphertext<DCRTPoly> cipherText);
// Applies the automorphism with the given index to cipherText using the precomputed digits
// and the evaluation key ek.
Ciphertext<DCRTPoly> HoistedAutomorphism(const LPEvalKey<DCRTPoly> ek,
ConstCiphertext<DCRTPoly> cipherText, const shared_ptr<vector<DCRTPoly>> digits, const usint index);
/// Cumulative distribution function of the standard normal distribution,
/// expressed through the complementary error function: 0.5 * erfc(-x / sqrt(2)).
double normalCFD(double value) {
    const double scaledArgument = -value * M_SQRT1_2;
    return 0.5 * erfc(scaledArgument);
}
/// Survival function (upper-tail probability, 1 - CDF) of the standard normal distribution.
///
/// Computed directly as 0.5 * erfc(x / sqrt(2)) rather than as 1 - CDF(x): the subtraction
/// cancels catastrophically in the upper tail and returns exactly 0 once the CDF rounds to 1
/// (around x ~ 8.3), whereas erfc stays accurate until the result underflows.
double sf(double value) { return 0.5 * erfc(value * M_SQRT1_2); }
/// Closed-form approximation in exp(-z^2/2); RunLogReg uses it as the p-value
/// whenever the erfc-based two-sided p-value evaluates to exactly zero.
/// Depends on z only through z*z, so the sign of z does not matter.
double BS(double z) {
    const double gaussTerm = exp(-z*z/2);
    const double polynomial = 31*gaussTerm/200 - 341*gaussTerm*gaussTerm/8000;
    return sqrt(1-gaussTerm) * polynomial / sqrt(M_PI);
}
/**
 * Program entry point: parses the command line and hands the raw option strings to RunLogReg.
 *
 * Every option takes an argument and is available in short and long form:
 *   -S/--SNPdir, -s/--SNPfilename, -p/--pvalue, -r/--runtime, -N/--samplesize, -M/--snps.
 * An unknown option, or a missing sample size / SNP count (both are parsed as integers later),
 * prints the usage text to stderr and exits with EXIT_FAILURE.
 */
int main(int argc, char **argv) {
    static struct option long_options[] =
    {
        /* These options don't set a flag.
        We distinguish them by their indices. */
        {"SNPdir", required_argument, 0, 'S'},
        {"SNPfilename", required_argument, 0, 's'},
        {"pvalue", required_argument, 0, 'p'},
        {"runtime", required_argument, 0, 'r'},
        {"samplesize", required_argument, 0, 'N'},
        {"snps", required_argument, 0, 'M'},
        {0, 0, 0, 0}
    };

    // Prints the usage text and terminates the process.
    const auto usageAndExit = [&]() {
        std::cerr<< "Usage: "<<argv[0]<<" <arguments> " <<std::endl
        << "arguments:" <<std::endl
        << " -S --SNPdir SNP file directory" <<std::endl
        << " -s --SNPfilename SNP file name" <<std::endl
        << " -p --pvalue p-values file" <<std::endl
        << " -r --runtime runtime output file name" <<std::endl
        << " -N --samplesize number of individuals" <<std::endl
        << " -M --snps number of SNPs" <<std::endl;
        exit(EXIT_FAILURE);
    };

    /* getopt_long stores the option index here. */
    int option_index = 0;
    string SNPDir;
    string SNPFileName;
    string pValue;
    string Runtime;
    string SampleSize;
    string SNPs;

    int opt;
    // Every short option needs a trailing ':' so that getopt_long fills optarg for it;
    // without the colon after 'M', "-M 16384" would leave optarg null.
    while ((opt = getopt_long(argc, argv, "S:s:p:r:N:M:", long_options, &option_index)) != -1) {
        switch (opt)
        {
        case 'S':
            SNPDir = optarg;
            break;
        case 's':
            SNPFileName = optarg;
            break;
        case 'p':
            pValue = optarg;
            break;
        case 'r':
            Runtime = optarg;
            break;
        case 'N':
            SampleSize = optarg;
            break;
        case 'M':
            SNPs = optarg;
            break;
        default: /* '?' */
            usageAndExit();
        }
    }

    // RunLogReg converts these two with std::stoi, which throws on an empty string.
    if (SampleSize.empty() || SNPs.empty())
        usageAndExit();

    RunLogReg(SNPDir, SNPFileName, pValue, Runtime, SampleSize, SNPs);
    return 0;
}
// Runs the complete encrypted GWAS pipeline: reads the input file, creates the CKKS context and
// keys, packs and encrypts the data, evaluates the statistics homomorphically, decrypts, and
// writes the results.
//
// Parameters (all raw command-line strings):
//   SNPDir      directory of the input file; every output file is written there too
//   SNPFileName input file name inside SNPDir
//   pValue      name of the p-value output file
//   Runtime     name of the timing output file
//   SampleSize  requested number of individuals (parsed with std::stoi)
//   SNPs        requested number of SNPs (parsed with std::stoi)
//
// Besides the two named files, the function writes zvalue.txt, betavalue.txt, num.txt and
// den.txt into SNPDir and prints the timings to stdout.
// The "Level N" comments throughout are the original author's ciphertext-level bookkeeping.
void RunLogReg(const string &SNPDir, const string &SNPFileName, const string &pValue, const string &Runtime, const string &SampleSize, const string &SNPs) {
TimeVar t;
TimeVar tAll;
TIC(tAll);
// Per-stage timers; they are divided by 1000 and labelled "s" when reported,
// so TOC presumably returns milliseconds.
double keyGenTime(0.0);
double encryptionTime(0.0);
double computation1Time(0.0);
double computation2Time(0.0);
double computation3Time(0.0);
double computation4Time(0.0);
double computation5Time(0.0);
double computation6Time(0.0);
double computation7Time(0.0);
double computation8Time(0.0);
double computationTime(0.0);
double decryptionTime(0.0);
double endToEndTime(0.0);
std::cout << "\n======LOGISTIC REGRESSION SOLUTION========\n" << std::endl;
vector<string> headers1;
vector<string> headersS;
std::vector<std::vector<double>> xData;
std::vector<double> yData;
std::vector<std::vector<double>> sData;
// std::stoi throws if the strings are empty or not numeric.
size_t N = std::stoi(SampleSize);
size_t M = std::stoi(SNPs);
// scalingFactor multiplies the SNP values before encryption; scalingFactorD and scalingFactorN
// scale the denominator and numerator terms. All three are undone after decryption.
double scalingFactor = 1e-1;
double scalingFactorD = 1e-2;
double scalingFactorN = 1e-2;
ReadSNPFile(headersS,sData,xData,yData,SNPDir + "/" + SNPFileName,N,M);
// From here on N, M and k reflect what was actually read.
// NOTE(review): sData[0] and xData[0] are indexed without checking that the vectors are non-empty.
N = sData.size();
M = sData[0].size();
size_t k = xData[0].size();
usint m;
m = 65536;
// n = m/4 is the length of every packed vector built below.
size_t n = m/4;
usint init_size = 17;
usint dcrtBits = 50;
size_t k2 = k*k;
usint batchSize = k*k;
CryptoContext<DCRTPoly> cc =
CryptoContextFactory<DCRTPoly>::genCryptoContextCKKS(
init_size-1,
dcrtBits,
k,
HEStd_128_classic,
m/2, /*ringDimension*/
APPROXRESCALE,
BV,
3, /*numLargeDigits*/
4, /*maxDepth*/
dcrtBits, /*firstMod*/
0,
OPTIMIZED);
cc->Enable(ENCRYPTION);
cc->Enable(SHE);
cc->Enable(LEVELEDSHE);
std::cout << "\nNumber of Individuals = " << N << std::endl;
std::cout << "Number of SNPs = " << M << std::endl;
std::cerr << "Number of features = " << k << std::endl;
// ---- Key generation ----
TIC(t);
auto keyPair = cc->KeyGen();
cc->EvalMultKeysGen(keyPair.secretKey);
cc->EvalSumKeyGen(keyPair.secretKey);
auto evalSumRows = cc->EvalSumRowsKeyGen(keyPair.secretKey, nullptr, k);
auto evalSumCols = cc->EvalSumColsKeyGen(keyPair.secretKey, nullptr);
// EvalSum key is also used for rotations by 1 and 2
auto evalSum = cc->GetEvalSumKeyMap(keyPair.secretKey->GetKeyTag());
// Two copies of the public key with the last 10 / 11 elements dropped. They encrypt the SNP
// data (S) and the per-individual covariates (cX1) below, whose plaintexts are created with
// the matching arguments 10 and 11.
auto pubKeyS = LPPublicKey<DCRTPoly>(new LPPublicKeyImpl<DCRTPoly>(*keyPair.publicKey));
std::vector<DCRTPoly> pubElementsS = pubKeyS->GetPublicElements();
for (size_t i=0; i < pubElementsS.size(); i++)
pubElementsS[i].DropLastElements(10);
pubKeyS->SetPublicElements(pubElementsS);
auto pubKeyX = LPPublicKey<DCRTPoly>(new LPPublicKeyImpl<DCRTPoly>(*keyPair.publicKey));
auto pubElementsX = pubKeyX->GetPublicElements();
for (size_t i=0; i < pubElementsX.size(); i++)
pubElementsX[i].DropLastElements(11);
pubKeyX->SetPublicElements(pubElementsX);
// Rotation indices 3 .. k*k-1 except 4 and 8; MatrixInverse takes rotations 1, 2 from the
// EvalSum keys and 4, 8 from the EvalSumRows keys.
std::vector<int32_t> indicesM;
for (size_t i = 3; i < k*k; i++) {
if (!((i == 4) || (i == 8)))
indicesM.push_back(i);
}
cc->SetKeyGenLevel(5);
auto rotKeysM = cc->GetEncryptionAlgorithm()->EvalAtIndexKeyGen(nullptr,keyPair.secretKey, indicesM);
// Rotation indices m/4 - i for i = 4, 8, 16, ... (powers of two below m/4).
std::vector<int32_t> indicesConv;
for (size_t i = 4; i < m/4; i=2*i)
indicesConv.push_back(m/4-i);
cc->SetKeyGenLevel(8);
auto rotKeysConv = cc->GetEncryptionAlgorithm()->EvalAtIndexKeyGen(nullptr,keyPair.secretKey, indicesConv);
keyGenTime = TOC(t);
// ---- Encoding and encryption ----
TIC(t);
// Each individual occupies k*k consecutive slots, so one ciphertext holds n/k2 individuals.
uint32_t numCt = (uint32_t)std::ceil((double)(N*k*k)/(double)(n));
std::cerr << "Number of ciphertexts: " << numCt << std::endl;
size_t Nfull = n/k2;
// x: for every individual, the k covariates repeated k times (inner index p selects the covariate).
// N1 (recomputed in several loops below) is the number of individuals in ciphertext r:
// n/k2 for a full ciphertext, the remainder for the last one.
std::vector<std::vector<std::complex<double>>> x;
for (size_t r = 0; r < numCt; r++) {
std::vector<std::complex<double>> xTemp;
size_t N1;
if ((r+1)*n < N*k2)
N1 = n/(k2);
else
N1 = N - r*n/k2;
for (size_t j = 0; j < N1; j++)
for (size_t i = 0; i < k; i++)
for (size_t p = 0; p < k; p++)
xTemp.push_back(xData[j+r*Nfull][p]);
x.push_back(xTemp);
}
// xt: same layout, but each covariate is repeated k times in a row (outer index i selects it).
std::vector<std::vector<std::complex<double>>> xt;
for (size_t r = 0; r < numCt; r++) {
std::vector<std::complex<double>> xTemp;
size_t N1;
if ((r+1)*n < N*k2)
N1 = n/(k2);
else
N1 = N - r*n/k2;
for (size_t j = 0; j < N1; j++)
for (size_t i = 0; i < k; i++)
for (size_t p = 0; p < k; p++)
xTemp.push_back(xData[j+r*Nfull][i]);
xt.push_back(xTemp);
}
// y: the label of every individual repeated k*k times.
std::vector<std::vector<std::complex<double>>> y;
for (size_t r = 0; r < numCt; r++) {
size_t N1;
if ((r+1)*n < N*k2)
N1 = n/(k2);
else
N1 = N - r*n/k2;
std::vector<std::complex<double>> yTemp;
for (size_t j = 0; j < N1; j++)
for (size_t i = 0; i < k*k; i++) {
{
yTemp.push_back(yData[j+r*Nfull]);
}
}
y.push_back(yTemp);
}
vector<Plaintext> X(numCt);
vector<Plaintext> XT(numCt);
vector<Plaintext> Y(numCt);
for (size_t r = 0; r < numCt; r++) {
X[r] = cc->MakeCKKSPackedPlaintext(x[r]);
XT[r] = cc->MakeCKKSPackedPlaintext(xt[r]);
Y[r] = cc->MakeCKKSPackedPlaintext(y[r]);
}
std::vector<vector<Ciphertext<DCRTPoly>>> cX1(N);
// The SNP values of each individual are split into sizeS chunks of m/4 values and
// multiplied by scalingFactor; sDataArray is indexed [chunk][individual][slot].
size_t sizeS = (size_t)std::ceil((double)sData[0].size()/(m/4));
std::vector<std::vector<std::vector<std::complex<double>>>> sDataArray(sizeS);
for(size_t s = 0; s < sizeS; s++)
sDataArray[s] = std::vector<std::vector<std::complex<double>>>(sData.size());
for (size_t i=0; i < sData.size(); i++){
for(size_t s = 0; s < sizeS; s++)
sDataArray[s][i] = std::vector<std::complex<double>>(sData[i].size());
size_t counter = 0;
for (size_t j=0; j<sData[i].size(); j++) {
if ((j>0) && (j%(m/4)==0))
counter++;
sDataArray[counter][i][j%(m/4)] = scalingFactor*sData[i][j];
}
}
std::vector<std::vector<Ciphertext<DCRTPoly>>> S(sizeS);
for (size_t i = 0; i < sizeS; i++)
S[i] = std::vector<Ciphertext<DCRTPoly>>(N);
//Encryption of single-integer ciphertexts
// Per individual: one ciphertext per SNP chunk (S) and one ciphertext per covariate (cX1),
// the latter holding the covariate value replicated in all m/4 slots.
#pragma omp parallel for
for (size_t i=0; i<N; i++){
for (size_t s=0; s < sizeS; s++){
Plaintext sTemp = cc->MakeCKKSPackedPlaintext(sDataArray[s][i],1,10,pubElementsS[0].GetParams());
S[s][i] = cc->Encrypt(pubKeyS, sTemp);
}
std::vector<Ciphertext<DCRTPoly>> x1Temp;
for (size_t j=0; j<k; j++){
std::vector<std::complex<double>> xVector(m/4,xData[i][j]);
Plaintext xTemp = cc->MakeCKKSPackedPlaintext(xVector,1,11,pubElementsX[0].GetParams());
x1Temp.push_back(cc->Encrypt(pubKeyX, xTemp));
}
cX1[i] = x1Temp;
}
vector<Ciphertext<DCRTPoly>> cX(numCt);
vector<Ciphertext<DCRTPoly>> cXT(numCt);
vector<Ciphertext<DCRTPoly>> cY(numCt);
for (size_t r=0; r<numCt; r++){
cX[r] = cc->Encrypt(keyPair.publicKey, X[r]);
cXT[r] = cc->Encrypt(keyPair.publicKey, XT[r]);
cY[r] = cc->Encrypt(keyPair.publicKey, Y[r]);
}
encryptionTime = TOC(t);
// ---- Computation stage 1: p, w, z and M = X^T diag(w) X ----
TIC(t);
double alpha = 0.00358;
// Compute p1 = X^T (y - 0.5)
// (0.5 is subtracted as two steps of 0.25)
Ciphertext<DCRTPoly> cP1Sum, cP1;
for (size_t r=0; r<numCt; r++){
cP1 = cc->EvalMult(cX[r],cc->EvalSub(cc->EvalSub(cY[r],double(0.25)),double(0.25)));
cP1 = cc->EvalSumRows(cP1,k*k,*evalSumRows);
//cP1 = cc->ModReduce(cP1);
if (r==0)
cP1Sum = cP1;
else
cP1Sum = cc->EvalAdd(cP1Sum,cP1);
}
cP1Sum = cc->ModReduce(cP1Sum);
// Compute p2 = 0.15625*alpha*X
vector<Ciphertext<DCRTPoly>> cP2Arr(numCt);
for (size_t r=0; r<numCt; r++){
auto cP2 = cc->EvalMult(cX[r],double(0.15625*alpha));
cP2Arr[r] = cc->ModReduce(cP2);
}
//Compute p3 = p1*p2
vector<Ciphertext<DCRTPoly>> cP3Arr(numCt);
for (size_t r=0; r<numCt; r++){
auto cP3 = cc->EvalMult(cP1Sum,cP2Arr[r]);
cP3 = cc->EvalSumCols(cP3,k,*evalSumCols);
cP3 = cc->ModReduce(cP3);
cP3Arr[r] = cc->ModReduce(cP3);
}
//Compute p = p3 + 0.5
vector<Ciphertext<DCRTPoly>> cPArr(numCt);
for (size_t r=0; r<numCt; r++){
//auto cP = cc->EvalSub(cP3Arr[r],cP6Arr[r]);
auto cP = cc->EvalAdd(cP3Arr[r],double(0.25));
cPArr[r] = cc->EvalAdd(cP,double(0.25));
}
// Compute p^2
vector<Ciphertext<DCRTPoly>> cPSquareArr(numCt);
for (size_t r=0; r<numCt; r++){
auto cPSquare = cc->EvalMult(cPArr[r],cPArr[r]);
cPSquareArr[r] = cc->ModReduce(cPSquare);
}
// Compute w = p - p^2
vector<Ciphertext<DCRTPoly>> cWArr(numCt);
for (size_t r=0; r<numCt; r++){
auto cPReduced = cc->LevelReduce(cPArr[r],nullptr);
cWArr[r] = cc->EvalSub(cPReduced,cPSquareArr[r]);
}
// Computes zExpand
vector<Ciphertext<DCRTPoly>> cZArr(numCt);
for (size_t r=0; r<numCt; r++){
cZArr[r] = zExpand(cPArr[r], cY[r]);
}
// Compute x^T diag(w)
vector<Ciphertext<DCRTPoly>> cM1Arr(numCt);
for (size_t r=0; r<numCt; r++){
auto cXTNReduced = cc->LevelReduce(cXT[r],nullptr,4);
auto cM1 = cc->EvalMult(cXTNReduced,cWArr[r]);
cM1Arr[r] = cc->ModReduce(cM1); // Level 5
}
CompressEvalKeys(*evalSumRows,5);
//Compute M = (x^T diag(w)) X
Ciphertext<DCRTPoly> cMSum, cM;
vector<Ciphertext<DCRTPoly>> cXReducedArr(numCt);
for (size_t r=0; r<numCt; r++){
cXReducedArr[r] = cc->LevelReduce(cX[r],nullptr,5); //Level 5
cM = cc->EvalMult(cM1Arr[r],cXReducedArr[r]);
cM = cc->EvalSumRows(cM,k*k,*evalSumRows);
if (r==0)
cMSum = cM;
else
cMSum = cc->EvalAdd(cMSum,cM);
}
computation1Time = TOC(t);
// ---- Computation stage 2: matrix inversion ----
// cB and cd are passed by reference and used afterwards as B and d;
// cB1 is the returned k x k array of ciphertexts.
TIC(t);
Ciphertext<DCRTPoly> cB(new CiphertextImpl<DCRTPoly>(cc));
Ciphertext<DCRTPoly> cd(new CiphertextImpl<DCRTPoly>(cc));
CompressEvalKeys(evalSum,5);
auto cB1 = MatrixInverse(cMSum,k,*cB,*cd, evalSum, *rotKeysM, *evalSumRows);
computation2Time = TOC(t);
// ---- Computation stage 3: ztr = d*z - (XB)(X^T diag(w) z) ----
TIC(t);
cMSum = cc->ModReduce(cMSum);
for (size_t r=0; r<numCt; r++){
cM1Arr[r] = cc->LevelReduce(cM1Arr[r],nullptr,2); //Level 7
}
// Compute (X^T diag(w)) z
Ciphertext<DCRTPoly> cztr1Sum, cztr1;
for (size_t r=0; r<numCt; r++){
cztr1 = cc->EvalMult(cM1Arr[r],cZArr[r]);
cztr1 = cc->EvalSumRows(cztr1,k*k,*evalSumRows);
//cztr1 = cc->ModReduce(cztr1); // Level 8
if (r==0)
cztr1Sum = cztr1;
else
cztr1Sum = cc->EvalAdd(cztr1Sum,cztr1);
}
cztr1Sum = cc->ModReduce(cztr1Sum);
// Compute XB
vector<Ciphertext<DCRTPoly>> cztr2Arr(numCt);
for (size_t r=0; r<numCt; r++){
cXReducedArr[r] = cc->LevelReduce(cXReducedArr[r],nullptr,4); //Level 9
auto cztr2 = cc->EvalMult(cXReducedArr[r],cB);
cztr2 = cc->EvalSumCols(cztr2,k,*evalSumCols);
cztr2 = cc->ModReduce(cztr2);
cztr2Arr[r] = cc->ModReduce(cztr2); //Level 11
}
CompressEvalKeys(*rotKeysM,6);
CompressEvalKeys(*evalSumRows,6);
// Compute the product of (XB) and (X^T diag(w)) z
vector<Ciphertext<DCRTPoly>> cztr3Arr(numCt);
cztr1Sum = cc->LevelReduce(cztr1Sum,nullptr,3); //Level 11
for (size_t r=0; r<numCt; r++){
auto cztr3 = cc->EvalMult(cztr1Sum,cztr2Arr[r]);
// Add the rotations by 4, 8 and 12 to the product (rotations 4 and 8 use the
// EvalSumRows keys, rotation 12 uses rotKeysM).
auto cztr4 = cc->EvalAdd(cztr3,cc->GetEncryptionAlgorithm()->EvalAtIndex(cztr3,4,*evalSumRows));
auto cztr5 = cc->EvalAdd(cc->GetEncryptionAlgorithm()->EvalAtIndex(cztr3,8,*evalSumRows),cc->GetEncryptionAlgorithm()->EvalAtIndex(cztr3,12,*rotKeysM));
auto cztr6 = cc->EvalAdd(cztr4,cztr5);
cztr6 = cc->ModReduce(cztr6); //Level 12
std::vector<std::complex<double>> mask(m/4);
size_t N1;
if ((r+1)*n < N*k2)
N1 = n/(k2);
else
N1 = N - r*n/k2;
// Individuals in this ciphertext rounded up to a power of two.
size_t NPow2cur = 1<<(size_t)std::ceil(log2(N1));
// Keep only the first four slots of every batchSize-slot group.
for (size_t i = 0; i < NPow2cur*k*k; i++)
{
if ((i % batchSize == 0) || (i % batchSize == 1) || (i % batchSize == 2) || (i % batchSize == 3))
mask[i] = 1;
else
mask[i] = 0;
}
Plaintext plaintextMask = cc->MakeCKKSPackedPlaintext(mask,1);
auto cMask = cc->EvalMult(cztr6,plaintextMask);
// Add rotations by m/4-4 and then m/4-8 of the masked values.
cztr4 = cc->EvalAdd(cMask,cc->GetEncryptionAlgorithm()->EvalAtIndex(cMask,m/4-4,*rotKeysConv));
cztr3Arr[r] = cc->EvalAdd(cztr4,cc->GetEncryptionAlgorithm()->EvalAtIndex(cztr4,m/4-8,*rotKeysConv));
}
// These key sets are not used again in this function.
rotKeysM->clear();
evalSumRows->clear();
//Compute d*z
// NOTE(review): cdDen is not referenced again in this function.
auto cdDen = cc->LevelReduce(cd,nullptr,1); //Level 10
cd = cc->LevelReduce(cd,nullptr,3); //Level 12
vector<Ciphertext<DCRTPoly>> cztr4Arr(numCt);
for (size_t r=0; r<numCt; r++){
cZArr[r] = cc->LevelReduce(cZArr[r],nullptr,5); //Level 12
cztr4Arr[r] = cc->EvalMult(cZArr[r],cd);
}
//Compute ztr
vector<Ciphertext<DCRTPoly>> cztrArr(numCt);
for (size_t r=0; r<numCt; r++){
auto cztr = cc->EvalSub(cztr4Arr[r],cztr3Arr[r]);
cztrArr[r] = cc->ModReduce(cztr); //Level 13
}
computation3Time = TOC(t);
// ---- Computation stage 4: convert the packed w into the cWVector list indexed per individual ----
TIC(t);
vector<Ciphertext<DCRTPoly>> cWConvArr(numCt);
for (size_t r=0; r<numCt; r++){
cWArr[r] = cc->LevelReduce(cWArr[r],nullptr,4); //Level 8
size_t N1;
if ((r+1)*n < N*k2)
N1 = n/(k2);
else
N1 = N - r*n/k2;
size_t NPow2cur = 1<<(size_t)std::ceil(log2(N1));
// Zero every slot at or beyond NPow2cur*k*k, then fill the vector by repeatedly adding
// rotated copies (rotation amounts m/4 - j with j doubling each step).
std::vector<std::complex<double>> maskW(m/4);
for (size_t v = 0; v < NPow2cur*k*k; v++)
maskW[v] = 1;
for (size_t v = NPow2cur*k*k; v < m/4; v++)
maskW[v] = 0;
Plaintext plaintextW = cc->MakeCKKSPackedPlaintext(maskW,1);
auto cWConv = cc->EvalMult(cWArr[r],plaintextW);
for (size_t j = NPow2cur*k*k; j < m/4; j=j*2 ) {
cWConv = cc->EvalAdd(cWConv,cc->GetEncryptionAlgorithm()->EvalAtIndex(cWConv,m/4-j,*rotKeysConv));
}
cWConvArr[r] = cc->ModReduce(cWConv); //Level 9
}
CompressEvalKeys(*rotKeysConv,1);
vector<Ciphertext<DCRTPoly>> cWVector;
for (size_t r=0; r<numCt; r++){
size_t N1;
if ((r+1)*n < N*k2)
N1 = n/(k2);
else
N1 = N - r*n/k2;
auto temp = SplitIntoSingle(cWConvArr[r], N1, k, *rotKeysConv); //Level 10
cWVector.insert(cWVector.end(),temp->begin(),temp->end()); //Level 10
cWArr[r] = cc->LevelReduce(cWArr[r],nullptr,1); //Level 9
}
computation4Time = TOC(t);
// ---- Computation stage 5: str = d S - X B X^T (diag(w) S) ----
TIC(t);
std::vector<std::vector<Ciphertext<DCRTPoly>>> strVector1(sizeS);
for (size_t i = 0; i < sizeS; i++)
strVector1[i] = std::vector<Ciphertext<DCRTPoly>>(N);
// Compute diag(w) S
for (size_t s = 0; s < sizeS; s++) {
#pragma omp parallel for
for (size_t i = 0; i < N; i++)
{
strVector1[s][i] = cc->EvalMultNoRelin(cWVector[i],S[s][i]);
strVector1[s][i] = cc->ModReduce(strVector1[s][i]); // Level 11
}
}
// Compute X^T (diag(w) S)
std::vector<std::vector<Ciphertext<DCRTPoly>>> strVector2(sizeS);
for (size_t s = 0; s < sizeS; s++) {
strVector2[s] = std::vector<Ciphertext<DCRTPoly>>(k);
}
for (size_t j = 0; j < k; j++)
{
for (size_t s = 0; s < sizeS; s++) {
std::vector<Ciphertext<DCRTPoly>> tempVector(N);
#pragma omp parallel for
for (size_t i = 0; i < N; i++)
{
tempVector[i] = cc->EvalMultNoRelin(strVector1[s][i],cX1[i][j]);
}
//std::cerr << "passed mult" << std::endl;
strVector2[s][j] = BinaryTreeAdd(tempVector);
tempVector.clear();
//std::cerr << "passed binary tree" << std::endl;
strVector2[s][j] = cc->ModReduce(strVector2[s][j]); //Level 12
}
}
// Compute B X^T (diag(w) S)
std::vector<std::vector<Ciphertext<DCRTPoly>>> strVector3(sizeS);
for (size_t s = 0; s < sizeS; s++)
strVector3[s] = std::vector<Ciphertext<DCRTPoly>>(k);
#pragma omp parallel for
for(size_t i = 0; i < k; i++) {
(*cB1)[i][0] = cc->LevelReduce((*cB1)[i][0],nullptr,3);
for(size_t j=1; j < k; j++) {
(*cB1)[i][j] = cc->LevelReduce((*cB1)[i][j],nullptr,3);
}
}
for (size_t s = 0; s < sizeS; s++) {
#pragma omp parallel for
for(size_t i = 0; i < k; i++) {
auto temp = cc->EvalMultAndRelinearize((*cB1)[i][0],strVector2[s][0]);
for(size_t j=1; j < k; j++) {
temp = cc->EvalAdd(temp,cc->EvalMultAndRelinearize((*cB1)[i][j],strVector2[s][j]));
}
temp = cc->ModReduce(temp); //Level 13
strVector3[s][i] = temp;
}
}
#pragma omp parallel for
for(size_t i = 0; i < N; i++) {
cX1[i][0] = cc->LevelReduce(cX1[i][0],nullptr,2);
for(size_t j=1; j < k; j++) {
cX1[i][j] = cc->LevelReduce(cX1[i][j],nullptr,2);
}
}
// Compute X B X^T (diag(w) S)
// (the result overwrites strVector1)
for (size_t s = 0; s < sizeS; s++) {
#pragma omp parallel for
for(size_t i = 0; i < N; i++) {
auto temp = cc->EvalMultNoRelin(cX1[i][0],strVector3[s][0]);
for(size_t j=1; j < k; j++) {
temp = cc->EvalAdd(temp,cc->EvalMultNoRelin(cX1[i][j],strVector3[s][j]));
}
strVector1[s][i] = temp;
}
}
// Compute d S - X B X^T (diag(w) S)
//auto cd12 = cd;
cd = cc->LevelReduce(cd,nullptr,1); //Level 13
for (size_t s = 0; s < sizeS; s++) {
#pragma omp parallel for
for (size_t i = 0; i < N; i++)
{
S[s][i] = cc->LevelReduce(S[s][i],nullptr,3);
strVector1[s][i] = cc->EvalSub(cc->EvalMultNoRelin(cd,S[s][i]),strVector1[s][i]); // Level 13
strVector1[s][i] = cc->ModReduce(strVector1[s][i]); // Level 14
}
}
// Release the encrypted inputs.
// NOTE(review): both loops call clear() on the whole container rather than on element i
// (S[i].clear() / cX1[i].clear() was probably intended). The first iteration empties the
// container and ends the loop, so the net effect equals the single clear() that follows.
for (size_t i = 0; i < S.size(); i++)
S.clear();
S.clear();
for (size_t i = 0; i < cX1.size(); i++)
cX1.clear();
cX1.clear();
computation5Time = TOC(t);
// ---- Computation stage 6: denominator terms ----
TIC(t);
// Compute str * str
std::vector<std::vector<Ciphertext<DCRTPoly>>> invVarD(sizeS);
for (size_t s = 0; s < sizeS; s++)
invVarD[s] = std::vector<Ciphertext<DCRTPoly>>(N);
for (size_t s = 0; s < sizeS; s++) {
#pragma omp parallel for
for (size_t i = 0; i < N; i++)
{
invVarD[s][i] = cc->EvalMultNoRelin(strVector1[s][i],strVector1[s][i]); // Level 14
//invVarD[i] = cc->ModReduce(invVarD[i]); // Level 15
}
}
// Compute d*d
auto cd12 = cc->LevelReduce(cd,nullptr,1); // Level 14
auto cd2 = cc->EvalMult(cd12,cd12);
cd2 = cc->ModReduce(cd2); // Level 15
// Compute (w^T) (str*str)
#pragma omp parallel for
for (size_t i = 0; i < N; i++)
{
cWVector[i] = cc->LevelReduce(cWVector[i],nullptr,3); // Level 14
//scaling down
auto temp = cc->EvalMult(cWVector[i],double(scalingFactorD));
temp = cc->ModReduce(temp); // Level 15
//auto temp = cc->EvalMultNoRelin(cd2,cWVector[i]);
for (size_t s = 0; s < sizeS; s++) {
invVarD[s][i] = cc->EvalMultNoRelin(temp,invVarD[s][i]);
}
}
// Sum over individuals: one ciphertext per SNP chunk.
std::vector<Ciphertext<DCRTPoly>> invVar(sizeS);
for (size_t s = 0; s < sizeS; s++)
invVar[s] = BinaryTreeAdd(invVarD[s]);
for (size_t s = 0; s < sizeS; s++)
invVarD[s].clear();
invVarD.clear();
computation6Time = TOC(t);
// ---- Computation stage 7: w*ztr, converted into the betaVector list indexed per individual ----
TIC(t);
CompressEvalKeys(*rotKeysConv,4);
// Compute w*ztr
vector<Ciphertext<DCRTPoly>> beta1Arr(numCt);
for (size_t r=0; r<numCt; r++){
cWArr[r] = cc->LevelReduce(cWArr[r],nullptr,3); //Level 12
size_t N1;
if ((r+1)*n < N*k2)
N1 = n/(k2);
else
N1 = N - r*n/k2;
size_t NPow2cur = 1<<(size_t)std::ceil(log2(N1));
// clears the mask to prepare for the conversion
// (the kept slots are multiplied by scalingFactorN at the same time)
std::vector<std::complex<double>> maskW2(m/4);
for (size_t v = 0; v < NPow2cur*k*k; v++)
maskW2[v] = scalingFactorN;
for (size_t v = NPow2cur*k*k; v < m/4; v++)
maskW2[v] = 0;
Plaintext plaintextW2 = cc->MakeCKKSPackedPlaintext(maskW2,1);
auto cWConv2 = cc->EvalMult(cWArr[r],plaintextW2);
cWConv2 = cc->ModReduce(cWConv2);//Level 13
auto beta1 = cc->EvalMult(cWConv2,cztrArr[r]);
for (size_t j = NPow2cur*k*k; j < m/4; j=j*2 ) {
beta1 = cc->EvalAdd(beta1,cc->GetEncryptionAlgorithm()->EvalAtIndex(beta1,m/4-j,*rotKeysConv));
}
beta1Arr[r] = cc->ModReduce(beta1); //Level 14
}
CompressEvalKeys(*rotKeysConv,1);
//CompressEvalKeys(*rotKeysBabyGiant,6);
vector<Ciphertext<DCRTPoly>> betaVector;
for (size_t r=0; r<numCt; r++){
size_t N1;
if ((r+1)*n < N*k2)
N1 = n/(k2);
else
N1 = N - r*n/k2;
auto temp = SplitIntoSingle(beta1Arr[r], N1, k, *rotKeysConv); //Level 15
betaVector.insert(betaVector.end(),temp->begin(),temp->end());
}
rotKeysConv->clear();
computation7Time = TOC(t);
// ---- Computation stage 8: numerator terms ----
TIC(t);
// Compute (w*ztr)^T str
for (size_t s = 0; s < sizeS; s++) {
#pragma omp parallel for
for (size_t i = 0; i < N; i++)
{
strVector1[s][i] = cc->LevelReduce(strVector1[s][i],nullptr,1);// Level 15
strVector1[s][i] = cc->EvalMultNoRelin(betaVector[i],strVector1[s][i]);
}
}
betaVector.clear();
// Sum over individuals: one ciphertext per SNP chunk.
std::vector<Ciphertext<DCRTPoly>> beta(sizeS);
for (size_t s = 0; s < sizeS; s++)
beta[s] = BinaryTreeAdd(strVector1[s]);
for (size_t s = 0; s < sizeS; s++)
strVector1[s].clear();
strVector1.clear();
computation8Time = TOC(t);
// ---- Decryption ----
std::vector<Plaintext> pInvVar(sizeS);
std::vector<Plaintext> pBeta(sizeS);
Plaintext pD;
TIC(t);
cc->Decrypt(keyPair.secretKey, cd2 , &pD);
for (size_t s = 0; s < sizeS; s++) {
cc->Decrypt(keyPair.secretKey, invVar[s] , &(pInvVar[s]));
cc->Decrypt(keyPair.secretKey, beta[s] , &(pBeta[s]));
}
decryptionTime = TOC(t);
// ---- Post-processing in the clear ----
// For every SNP (slot i of chunk s): undo the scaling factors, then
//   beta = num / den,  z = num / sqrt(den * d^2),  p = 2 * sf(|z|),
// where d^2 is slot 0 of the decrypted cd2. If p evaluates to exactly 0, BS(z) is used instead.
std::vector<double> zval(headersS.size());
std::vector<double> pval(headersS.size());
std::vector<double> betaval(headersS.size());
std::vector<std::complex<double>> num(headersS.size());
std::vector<std::complex<double>> den(headersS.size());
for (size_t s = 0; s < sizeS; s++) {
for (size_t i = 0; i < m/4; i++) {
if (s*m/4 + i < headersS.size()) {
num[s*m/4 + i] = pow(scalingFactor,-1)*pow(scalingFactorN,-1)*pBeta[s]->GetCKKSPackedValue()[i];
den[s*m/4 + i] = pow(scalingFactor,-2)*pow(scalingFactorD,-1)*pInvVar[s]->GetCKKSPackedValue()[i];
betaval[s*m/4 + i] = num[s*m/4 + i].real()/den[s*m/4 + i].real();
zval[s*m/4 + i] = num[s*m/4 + i].real()/sqrt(den[s*m/4 + i].real()*pD->GetCKKSPackedValue()[0].real());
pval[s*m/4 + i] = 2*sf(abs(zval[s*m/4 + i]));
if (pval[s*m/4 + i] == 0)
pval[s*m/4 + i] = BS(zval[s*m/4 + i]);
}
}
}
// ---- Result files: one "<header>\t<value>" line per SNP ----
// NOTE(review): none of the output streams is checked after open(); a bad path fails silently.
ofstream myfile;
myfile.open(SNPDir + "/" + pValue);
myfile.precision(10);
for(uint32_t i = 0; i < headersS.size(); i++) {
myfile << headersS[i] << "\t" << pval[i] << std::endl;
}
myfile.close();
ofstream myfilez;
myfilez.open(SNPDir + "/" + "zvalue.txt");
myfilez.precision(10);
for(uint32_t i = 0; i < headersS.size(); i++) {
myfilez << headersS[i] << "\t" << zval[i] << std::endl;
}
myfilez.close();
ofstream myfileb;
myfileb.open(SNPDir + "/" + "betavalue.txt");
myfileb.precision(10);
for(uint32_t i = 0; i < headersS.size(); i++) {
myfileb << headersS[i] << "\t" << betaval[i] << std::endl;
}
myfileb.close();
// num and den are complex values and are written with the stream's complex formatting.
ofstream myfilenum;
myfilenum.open(SNPDir + "/" + "num.txt");
myfilenum.precision(10);
for(uint32_t i = 0; i < headersS.size(); i++) {
myfilenum << headersS[i] << "\t" << num[i] << std::endl;
}
myfilenum.close();
ofstream myfileden;
myfileden.open(SNPDir + "/" + "den.txt");
myfileden.precision(10);
for(uint32_t i = 0; i < headersS.size(); i++) {
myfileden << headersS[i] << "\t" << den[i] << std::endl;
}
myfileden.close();
// ---- Timing report (stdout and the Runtime file) ----
computationTime = computation1Time + computation2Time + computation3Time + computation4Time +
computation5Time + computation6Time + computation7Time + computation8Time;
std::cout << "\nKey Generation Time: \t\t" << keyGenTime/1000 << " s" << std::endl;
std::cout << "Encoding and Encryption Time: \t" << encryptionTime/1000 << " s" << std::endl;
std::cout << "Computation Time: \t\t" << computationTime/1000 << " s" << std::endl;
std::cout << "Decryption & Decoding Time: \t" << decryptionTime/1000 << " s" << std::endl;
endToEndTime = TOC(tAll);
std::cout << "\nEnd-to-end Runtime: \t\t" << endToEndTime/1000 << " s" << std::endl;
ofstream myfileRuntime;
myfileRuntime.open(SNPDir + "/" + Runtime);
myfileRuntime << "Key Generation Time: \t\t" << keyGenTime/1000 << " s" << std::endl;
myfileRuntime << "Encoding and Encryption Time: \t" << encryptionTime/1000 << " s" << std::endl;
myfileRuntime << "Computation Time: \t\t" << computationTime/1000 << " s" << std::endl;
myfileRuntime << "Decryption & Decoding Time: \t" << decryptionTime/1000 << " s" << std::endl;
myfileRuntime << "End-to-end Runtime: \t\t" << endToEndTime/1000 << " s" << std::endl;
myfileRuntime.close();
}
/**
 * Homomorphically evaluates, with u = p - 0.5,
 *
 *   z = (4y - 2) * (1 + 4u^2 + 16u^4 + 64u^6 + 256u^8)
 *       - (32/3) u^3 - (256/5) u^5 - (1536/7) u^7
 *
 * @param p  encrypted p values
 * @param y  encrypted labels
 * @return   ciphertext holding z; every term is brought to the same level before summation
 *           (the "level N" comments are the original author's level bookkeeping).
 *
 * The odd-power coefficients are written as floating-point divisions. The previous form,
 * double(32/3) etc., performed integer division first and silently used 10, 51 and 219.
 */
Ciphertext<DCRTPoly> zExpand(const Ciphertext<DCRTPoly> p, const Ciphertext<DCRTPoly> y) {
    CryptoContext<DCRTPoly> cc = p->GetCryptoContext();

    // Coefficients of the odd powers of (p - 0.5).
    const double coeffCubic = 32.0 / 3.0;
    const double coeffQuintic = 256.0 / 5.0;
    const double coeffSeptic = 1536.0 / 7.0;

    //Compute p-0.5 (subtracted as two steps of 0.25, as elsewhere in this file)
    auto pAdj = cc->EvalSub(p,double(0.25));
    pAdj = cc->EvalSub(pAdj,double(0.25));
    //Compute (p-0.5)^2; level 4
    auto p2 = cc->EvalMult(pAdj,pAdj);
    p2 = cc->ModReduce(p2);
    //Compute (p-0.5)^3; level 5
    auto p1 = cc->LevelReduce(pAdj,nullptr);
    auto p3 = cc->EvalMult(p2,p1);
    p3 = cc->ModReduce(p3);
    //Compute (p-0.5)^4; level 5
    auto p4 = cc->EvalMult(p2,p2);
    p4 = cc->ModReduce(p4);
    //Compute (p-0.5)^5; level 6
    p1 = cc->LevelReduce(p1,nullptr);
    auto p5 = cc->EvalMult(p4,p1);
    p5 = cc->ModReduce(p5);
    //Compute (p-0.5)^6; level 6
    p2 = cc->LevelReduce(p2,nullptr);
    auto p6 = cc->EvalMult(p4,p2);
    p6 = cc->ModReduce(p6);
    //Compute (p-0.5)^7; level 6
    auto p7 = cc->EvalMult(p4,p3);
    p7 = cc->ModReduce(p7);
    //Compute (p-0.5)^8; level 6
    auto p8 = cc->EvalMult(p4,p4);
    p8 = cc->ModReduce(p8);

    //Compute -2 + 4y
    auto factor = cc->EvalMult(y,4);
    factor = cc->ModReduce(factor); //level 1
    factor = cc->EvalSub(factor,2);
    auto t0 = cc->LevelReduce(factor,nullptr,6); //level 7

    //Compute (-8 + 16y)*(p-0.5)^2
    auto t1 = cc->EvalMult(factor,4);
    t1 = cc->ModReduce(t1); // level 2
    t1 = cc->LevelReduce(t1,nullptr,3); //level 5
    t1 = cc->EvalMult(t1,p2);
    t1 = cc->ModReduce(t1); //level 6
    t1 = cc->LevelReduce(t1,nullptr); //level 7

    //Compute (32/3)*(p-0.5)^3
    auto t2 = cc->EvalMult(p3,coeffCubic);
    t2 = cc->ModReduce(t2); // level 6
    t2 = cc->LevelReduce(t2,nullptr); //level 7

    //Compute (-32 + 64y)*(p-0.5)^4
    auto t3 = cc->EvalMult(factor,16);
    t3 = cc->ModReduce(t3); // level 2
    t3 = cc->LevelReduce(t3,nullptr,3); //level 5
    t3 = cc->EvalMult(t3,p4);
    t3 = cc->ModReduce(t3); //level 6
    t3 = cc->LevelReduce(t3,nullptr); //level 7

    //Compute (256/5)*(p-0.5)^5
    auto t4 = cc->EvalMult(p5,coeffQuintic);
    t4 = cc->ModReduce(t4); // level 7

    //Compute (-128 + 256y)*(p-0.5)^6
    auto t5 = cc->EvalMult(factor,64);
    t5 = cc->ModReduce(t5); // level 2
    t5 = cc->LevelReduce(t5,nullptr,4); //level 6
    t5 = cc->EvalMult(t5,p6);
    t5 = cc->ModReduce(t5); //level 7

    //Compute (1536/7)*(p-0.5)^7
    auto t6 = cc->EvalMult(p7,coeffSeptic);
    t6 = cc->ModReduce(t6); // level 7

    //Compute (-512 + 1024y)*(p-0.5)^8
    auto t7 = cc->EvalMult(factor,256);
    t7 = cc->ModReduce(t7); // level 2
    t7 = cc->LevelReduce(t7,nullptr,4); //level 6
    t7 = cc->EvalMult(t7,p8);
    t7 = cc->ModReduce(t7); //level 7

    // Even-power terms are added, odd-power terms are subtracted.
    auto z = cc->EvalAdd(t0,t1);
    z = cc->EvalSub(z,t2);
    z = cc->EvalAdd(z,t3);
    z = cc->EvalSub(z,t4);
    z = cc->EvalAdd(z,t5);
    z = cc->EvalSub(z,t6);
    z = cc->EvalAdd(z,t7);
    return z;
}
shared_ptr<std::vector<std::vector<Ciphertext<DCRTPoly>>>> MatrixInverse(const Ciphertext<DCRTPoly> cM, size_t k,
CiphertextImpl<DCRTPoly> &B, CiphertextImpl<DCRTPoly> &d, const std::map<usint, LPEvalKey<DCRTPoly>> &evalSum,
const std::map<usint, LPEvalKey<DCRTPoly>> &rotKeys, const std::map<usint, LPEvalKey<DCRTPoly>> &evalSumRows) {
auto cc = cM->GetCryptoContext();
const shared_ptr<LPCryptoParameters<DCRTPoly>> cryptoParams = cM->GetCryptoParameters();
const auto elementParams = cryptoParams->GetElementParams();
usint m = elementParams->GetCyclotomicOrder();
size_t kSquare = k*k;
std::vector<std::complex<double>> mask(m/4);
for (size_t i = 0; i < mask.size(); i++)
{
if (i % kSquare == 0)
mask[i] = 1;
else
mask[i] = 0;
}
Plaintext plaintext = cc->MakeCKKSPackedPlaintext(mask,1);
std::vector<Ciphertext<DCRTPoly>> cMRotations(k*k-1);
auto precomputedcM = KeySwitchPrecompute(cM);
#pragma omp parallel for
for (size_t i = 1; i < k*k; i++) {
usint autoIndex = FindAutomorphismIndex2nComplex(i,m);
if (i < 3)
cMRotations[i-1] = HoistedAutomorphism(evalSum.find(autoIndex)->second,cM,precomputedcM,autoIndex);
else if ((i == 4) || (i==8))
cMRotations[i-1] = HoistedAutomorphism(evalSumRows.find(autoIndex)->second,cM,precomputedcM,autoIndex);
else
cMRotations[i-1] = HoistedAutomorphism(rotKeys.find(autoIndex)->second,cM,precomputedcM,autoIndex);
cMRotations[i-1] = cc->ModReduce(cMRotations[i-1]);
// clear all values that are not used
cMRotations[i-1] = cc->EvalMult(cMRotations[i-1],plaintext);
cMRotations[i-1] = cc->ModReduce(cMRotations[i-1]);
}
auto cMReduced = cc->ModReduce(cM);
cMReduced = cc->LevelReduce(cMReduced,nullptr);
auto a11a22 = cc->EvalMult(cMReduced,cMRotations[4]);
a11a22 = cc->ModReduce(a11a22);
auto a11a23 = cc->EvalMult(cMReduced,cMRotations[5]);
a11a23 = cc->ModReduce(a11a23);
auto a11a24 = cc->EvalMult(cMReduced,cMRotations[6]);
a11a24 = cc->ModReduce(a11a24);
auto a12a12 = cc->EvalMult(cMRotations[0],cMRotations[0]);
a12a12 = cc->ModReduce(a12a12);
auto a12a13 = cc->EvalMult(cMRotations[0],cMRotations[1]);
a12a13 = cc->ModReduce(a12a13);
auto a12a14 = cc->EvalMult(cMRotations[0],cMRotations[2]);
a12a14 = cc->ModReduce(a12a14);
auto a13a13 = cc->EvalMult(cMRotations[1],cMRotations[1]);
a13a13 = cc->ModReduce(a13a13);
auto a13a14 = cc->EvalMult(cMRotations[1],cMRotations[2]);
a13a14 = cc->ModReduce(a13a14);
auto a14a14 = cc->EvalMult(cMRotations[2],cMRotations[2]);
a14a14 = cc->ModReduce(a14a14);
auto a22a33 = cc->EvalMult(cMRotations[4],cMRotations[9]);
a22a33 = cc->ModReduce(a22a33);
auto a22a34 = cc->EvalMult(cMRotations[4],cMRotations[10]);
a22a34 = cc->ModReduce(a22a34);
auto a22a44 = cc->EvalMult(cMRotations[4],cMRotations[14]);
a22a44 = cc->ModReduce(a22a44);
auto a23a23 = cc->EvalMult(cMRotations[5],cMRotations[5]);
a23a23 = cc->ModReduce(a23a23);
auto a23a24 = cc->EvalMult(cMRotations[5],cMRotations[6]);
a23a24 = cc->ModReduce(a23a24);
auto a23a34 = cc->EvalMult(cMRotations[5],cMRotations[10]);
a23a34 = cc->ModReduce(a23a34);
auto a23a44 = cc->EvalMult(cMRotations[5],cMRotations[14]);
a23a44 = cc->ModReduce(a23a44);
auto a24a24 = cc->EvalMult(cMRotations[6],cMRotations[6]);
a24a24 = cc->ModReduce(a24a24);
auto a24a33 = cc->EvalMult(cMRotations[6],cMRotations[9]);
a24a33 = cc->ModReduce(a24a33);
auto a24a34 = cc->EvalMult(cMRotations[6],cMRotations[10]);
a24a34 = cc->ModReduce(a24a34);
auto a33a44 = cc->EvalMult(cMRotations[9],cMRotations[14]);
a33a44 = cc->ModReduce(a33a44);
auto a34a34 = cc->EvalMult(cMRotations[10],cMRotations[10]);
a34a34 = cc->ModReduce(a34a34);
/*
* det = a[1,4]*a[1,4]*a[2,3]*a[2,3] - 2*a[1,3]*a[1,4]*a[2,3]*a[2,4] + a[1,3]*a[1,3]*a[2,4]*a[2,4] -
a[1,4]*a[1,4]*a[2,2]*a[3,3] + 2*a[1,2]*a[1,4]*a[2,4]*a[3,3] - a[1,1]*a[2,4]*a[2,4]*a[3,3] +
2*a[1,3]*a[1,4]*a[2,2]*a[3,4] - 2*a[1,2]*a[1,4]*a[2,3]*a[3,4] - 2*a[1,2]*a[1,3]*a[2,4]*a[3,4] +
2*a[1,1]*a[2,3]*a[2,4]*a[3,4] + a[1,2]*a[1,2]*a[3,4]*a[3,4] - a[1,1]*a[2,2]*a[3,4]*a[3,4] -
a[1,3]*a[1,3]*a[2,2]*a[4,4] + 2*a[1,2]*a[1,3]*a[2,3]*a[4,4] - a[1,1]*a[2,3]*a[2,3]*a[4,4] -
a[1,2]*a[1,2]*a[3,3]*a[4,4] + a[1,1]*a[2,2]*a[3,3]*a[4,4]
*/
auto cd = cc->EvalMultNoRelin(a14a14,a23a23);
auto temp = cc->EvalMultNoRelin(a13a14,a23a24);
temp = cc->EvalAdd(temp,temp);
cd = cc->EvalSub(cd,temp);
cd = cc->EvalAdd(cd,cc->EvalMultNoRelin(a13a13,a24a24));
cd = cc->EvalSub(cd,cc->EvalMultNoRelin(a14a14,a22a33));
//std::cerr << *cd << std::endl;
temp = cc->EvalMultNoRelin(a12a14,a24a33);
temp = cc->EvalAdd(temp,temp);
cd = cc->EvalAdd(cd,temp);
cd = cc->EvalSub(cd,cc->EvalMultNoRelin(a11a24,a24a33));
temp = cc->EvalMultNoRelin(a13a14,a22a34);
temp = cc->EvalAdd(temp,temp);
cd = cc->EvalAdd(cd,temp);
temp = cc->EvalMultNoRelin(a12a14,a23a34);
temp = cc->EvalAdd(temp,temp);
cd = cc->EvalSub(cd,temp);
temp = cc->EvalMultNoRelin(a12a13,a24a34);
temp = cc->EvalAdd(temp,temp);
cd = cc->EvalSub(cd,temp);
temp = cc->EvalMultNoRelin(a11a23,a24a34);
temp = cc->EvalAdd(temp,temp);
cd = cc->EvalAdd(cd,temp);
cd = cc->EvalAdd(cd,cc->EvalMultNoRelin(a12a12,a34a34));
cd = cc->EvalSub(cd,cc->EvalMultNoRelin(a11a22,a34a34));
cd = cc->EvalSub(cd,cc->EvalMultNoRelin(a13a13,a22a44));
temp = cc->EvalMultNoRelin(a12a13,a23a44);
temp = cc->EvalAdd(temp,temp);
cd = cc->EvalAdd(cd,temp);
cd = cc->EvalSub(cd,cc->EvalMultNoRelin(a11a23,a23a44));
cd = cc->EvalSub(cd,cc->EvalMultNoRelin(a12a12,a33a44));
cd = cc->EvalAdd(cd,cc->EvalMultNoRelin(a11a22,a33a44));
cd = cc->Relinearize(cd);
Ciphertext<DCRTPoly> cdOld;
for (size_t i = 1; i < k*k; i = i*2)
{
cdOld = cd;
if (i < 3)
cd = cc->GetEncryptionAlgorithm()->EvalAtIndex(cdOld,i,evalSum);
else if ( (i == 4) || (i == 8) )
cd = cc->GetEncryptionAlgorithm()->EvalAtIndex(cdOld,i,evalSumRows);
else
cd = cc->GetEncryptionAlgorithm()->EvalAtIndex(cdOld,i,rotKeys);
cd = cc->EvalAdd(cdOld,cd);
}
cd = cc->ModReduce(cd);
d = *cd;
/*
# Adjoint of a 4 by 4 symmetric matrix
adjoin_4by4_sim_matrix <- function(a){
b11 = -a[4,4]*a[2,3]*a[2,3] + 2*a[2,4]*a[3,4]*a[2,3] - a[2,2]*a[3,4]*a[3,4] - a[2,4]*a[2,4]*a[3,3] + a[2,2]*a[3,3]*a[4,4]
b12 = a[1,2]*a[3,4]*a[3,4] - a[1,4]*a[2,3]*a[3,4] - a[1,3]*a[2,4]*a[3,4] + a[1,4]*a[2,4]*a[3,3] + a[1,3]*a[2,3]*a[4,4] - a[1,2]*a[3,3]*a[4,4]
b13 = a[1,3]*a[2,4]*a[2,4] - a[1,4]*a[2,3]*a[2,4] - a[1,2]*a[3,4]*a[2,4] + a[1,4]*a[2,2]*a[3,4] - a[1,3]*a[2,2]*a[4,4] + a[1,2]*a[2,3]*a[4,4]
b14 = a[1,4]*a[2,3]*a[2,3] - a[1,3]*a[2,4]*a[2,3] - a[1,2]*a[3,4]*a[2,3] - a[1,4]*a[2,2]*a[3,3] + a[1,2]*a[2,4]*a[3,3] + a[1,3]*a[2,2]*a[3,4]
b22 = -a[4,4]*a[1,3]*a[1,3] + 2*a[1,4]*a[3,4]*a[1,3] - a[1,1]*a[3,4]*a[3,4] - a[1,4]*a[1,4]*a[3,3] + a[1,1]*a[3,3]*a[4,4]
b23 = a[2,3]*a[1,4]*a[1,4] - a[1,3]*a[2,4]*a[1,4] - a[1,2]*a[3,4]*a[1,4] + a[1,1]*a[2,4]*a[3,4] + a[1,2]*a[1,3]*a[4,4] - a[1,1]*a[2,3]*a[4,4]
b24 = a[2,4]*a[1,3]*a[1,3] - a[1,4]*a[2,3]*a[1,3] - a[1,2]*a[3,4]*a[1,3] + a[1,2]*a[1,4]*a[3,3] - a[1,1]*a[2,4]*a[3,3] + a[1,1]*a[2,3]*a[3,4]
b33 = -a[4,4]*a[1,2]*a[1,2] + 2*a[1,4]*a[2,4]*a[1,2] - a[1,1]*a[2,4]*a[2,4] - a[1,4]*a[1,4]*a[2,2] + a[1,1]*a[2,2]*a[4,4]
b34 = a[3,4]*a[1,2]*a[1,2] - a[1,4]*a[2,3]*a[1,2] - a[1,3]*a[2,4]*a[1,2] + a[1,3]*a[1,4]*a[2,2] + a[1,1]*a[2,3]*a[2,4] - a[1,1]*a[2,2]*a[3,4]
b44 = -a[3,3]*a[1,2]*a[1,2] + 2*a[1,3]*a[2,3]*a[1,2] - a[1,1]*a[2,3]*a[2,3] - a[1,3]*a[1,3]*a[2,2] + a[1,1]*a[2,2]*a[3,3]
b <- matrix(c(b11,b12,b13,b14,b12,b22,b23,b24,b13,b23,b33,b34,b14,b24,b34,b44), ncol=4, byrow=TRUE)
*/
for (size_t i = 1; i < k*k; i++) {
cMRotations[i-1] = cc->LevelReduce(cMRotations[i-1],nullptr);
}
cMReduced = cc->LevelReduce(cMReduced,nullptr);
temp = cc->EvalMultNoRelin(cMRotations[6],a23a34);
temp = cc->EvalAdd(temp,temp);
auto b11 = temp;
// We can use a binary tree approach here and below
b11 = cc->EvalSub(b11,cc->EvalMultNoRelin(cMRotations[14],a23a23));
b11 = cc->EvalSub(b11,cc->EvalMultNoRelin(cMRotations[4],a34a34));
b11 = cc->EvalSub(b11,cc->EvalMultNoRelin(cMRotations[9],a24a24));
b11 = cc->EvalAdd(b11,cc->EvalMultNoRelin(cMRotations[4],a33a44));
b11 = cc->Relinearize(b11);
auto b12 = cc->EvalMultNoRelin(cMRotations[0],a34a34);
b12 = cc->EvalSub(b12,cc->EvalMultNoRelin(cMRotations[2],a23a34));
b12 = cc->EvalSub(b12,cc->EvalMultNoRelin(cMRotations[1],a24a34));
b12 = cc->EvalAdd(b12,cc->EvalMultNoRelin(cMRotations[2],a24a33));
b12 = cc->EvalAdd(b12,cc->EvalMultNoRelin(cMRotations[1],a23a44));
b12 = cc->EvalSub(b12,cc->EvalMultNoRelin(cMRotations[0],a33a44));
b12 = cc->Relinearize(b12);
auto b13 = cc->EvalMultNoRelin(cMRotations[1],a24a24);
b13 = cc->EvalSub(b13,cc->EvalMultNoRelin(cMRotations[2],a23a24));
b13 = cc->EvalSub(b13,cc->EvalMultNoRelin(cMRotations[0],a24a34));
b13 = cc->EvalAdd(b13,cc->EvalMultNoRelin(cMRotations[2],a22a34));
b13 = cc->EvalSub(b13,cc->EvalMultNoRelin(cMRotations[1],a22a44));
b13 = cc->EvalAdd(b13,cc->EvalMultNoRelin(cMRotations[0],a23a44));
b13 = cc->Relinearize(b13);
auto b14 = cc->EvalMultNoRelin(cMRotations[2],a23a23);
b14 = cc->EvalSub(b14,cc->EvalMultNoRelin(cMRotations[1],a23a24));
b14 = cc->EvalSub(b14,cc->EvalMultNoRelin(cMRotations[0],a23a34));
b14 = cc->EvalSub(b14,cc->EvalMultNoRelin(cMRotations[2],a22a33));
b14 = cc->EvalAdd(b14,cc->EvalMultNoRelin(cMRotations[0],a24a33));
b14 = cc->EvalAdd(b14,cc->EvalMultNoRelin(cMRotations[1],a22a34));
b14 = cc->Relinearize(b14);
temp = cc->EvalMultNoRelin(cMRotations[10],a13a14);
temp = cc->EvalAdd(temp,temp);
auto b22 = temp;
b22 = cc->EvalSub(b22,cc->EvalMultNoRelin(cMRotations[14],a13a13));
b22 = cc->EvalSub(b22,cc->EvalMultNoRelin(cMReduced,a34a34));
b22 = cc->EvalSub(b22,cc->EvalMultNoRelin(cMRotations[9],a14a14));
b22 = cc->EvalAdd(b22,cc->EvalMultNoRelin(cMReduced,a33a44));
b22 = cc->Relinearize(b22);
auto b23 = cc->EvalMultNoRelin(cMRotations[5],a14a14);
b23 = cc->EvalSub(b23,cc->EvalMultNoRelin(cMRotations[6],a13a14));
b23 = cc->EvalSub(b23,cc->EvalMultNoRelin(cMRotations[10],a12a14));
b23 = cc->EvalAdd(b23,cc->EvalMultNoRelin(cMReduced,a24a34));
b23 = cc->EvalAdd(b23,cc->EvalMultNoRelin(cMRotations[14],a12a13));
b23 = cc->EvalSub(b23,cc->EvalMultNoRelin(cMRotations[14],a11a23));
b23 = cc->Relinearize(b23);
// We can use a binary tree approach here
auto b24 = cc->EvalMultNoRelin(cMRotations[6],a13a13);
b24 = cc->EvalSub(b24,cc->EvalMultNoRelin(cMRotations[5],a13a14));
b24 = cc->EvalSub(b24,cc->EvalMultNoRelin(cMRotations[10],a12a13));
b24 = cc->EvalAdd(b24,cc->EvalMultNoRelin(cMRotations[9],a12a14));
b24 = cc->EvalSub(b24,cc->EvalMultNoRelin(cMRotations[9],a11a24));
b24 = cc->EvalAdd(b24,cc->EvalMultNoRelin(cMRotations[10],a11a23));
b24 = cc->Relinearize(b24);
temp = cc->EvalMultNoRelin(cMRotations[6],a12a14);
temp = cc->EvalAdd(temp,temp);
auto b33 = temp;
b33 = cc->EvalSub(b33,cc->EvalMultNoRelin(cMRotations[14],a12a12));
b33 = cc->EvalSub(b33,cc->EvalMultNoRelin(cMReduced,a24a24));
b33 = cc->EvalSub(b33,cc->EvalMultNoRelin(cMRotations[4],a14a14));
b33 = cc->EvalAdd(b33,cc->EvalMultNoRelin(cMReduced,a22a44));
b33 = cc->Relinearize(b33);
// We can use a binary tree approach here
auto b34 = cc->EvalMultNoRelin(cMRotations[10],a12a12);
b34 = cc->EvalSub(b34,cc->EvalMultNoRelin(cMRotations[5],a12a14));
b34 = cc->EvalSub(b34,cc->EvalMultNoRelin(cMRotations[6],a12a13));
b34 = cc->EvalAdd(b34,cc->EvalMultNoRelin(cMRotations[4],a13a14));
b34 = cc->EvalAdd(b34,cc->EvalMultNoRelin(cMReduced,a23a24));
b34 = cc->EvalSub(b34,cc->EvalMultNoRelin(cMReduced,a22a34));
b34 = cc->Relinearize(b34);
temp = cc->EvalMultNoRelin(cMRotations[5],a12a13);
temp = cc->EvalAdd(temp,temp);
auto b44 = temp;
// We can use a binary tree approach here
b44 = cc->EvalSub(b44,cc->EvalMultNoRelin(cMRotations[9],a12a12));
b44 = cc->EvalSub(b44,cc->EvalMultNoRelin(cMReduced,a23a23));
b44 = cc->EvalSub(b44,cc->EvalMultNoRelin(cMRotations[4],a13a13));
b44 = cc->EvalAdd(b44,cc->EvalMultNoRelin(cMReduced,a22a33));
b44 = cc->Relinearize(b44);
shared_ptr<std::vector<std::vector<Ciphertext<DCRTPoly>>>> B1(new std::vector<std::vector<Ciphertext<DCRTPoly>>>(k));
// We can use a binary tree approach here
auto b = cc->EvalAdd(b11,cc->GetEncryptionAlgorithm()->EvalAtIndex(b12,k*k-1,rotKeys));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b13,k*k-2,rotKeys));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b14,k*k-3,rotKeys));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b12,k*k-4,rotKeys));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b22,k*k-5,rotKeys));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b23,k*k-6,rotKeys));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b24,k*k-7,rotKeys));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b13,k*k-8,evalSumRows));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b23,k*k-9,rotKeys));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b33,k*k-10,rotKeys));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b34,k*k-11,rotKeys));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b14,k*k-12,evalSumRows));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b24,k*k-13,rotKeys));
//b = cc->EvalAdd(b,cc->EvalAtIndex(b34,k*k-14));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b34,k*k-14,evalSum));
//b = cc->EvalAdd(b,cc->EvalAtIndex(b44,k*k-15));
b = cc->EvalAdd(b,cc->GetEncryptionAlgorithm()->EvalAtIndex(b44,k*k-15,evalSum));
b = cc->ModReduce(b);
size_t k2 = k*k;
for (size_t i = 0; i<k; i++)
(*B1)[i] = std::vector<Ciphertext<DCRTPoly>>(k);
(*B1)[0][0] = cc->ModReduce(CloneCiphertext(b11,k2,rotKeys, evalSumRows));
(*B1)[0][1] = cc->ModReduce(CloneCiphertext(b12,k2,rotKeys, evalSumRows));
(*B1)[0][2] = cc->ModReduce(CloneCiphertext(b13,k2,rotKeys, evalSumRows));
(*B1)[0][3] = cc->ModReduce(CloneCiphertext(b14,k2,rotKeys, evalSumRows));
(*B1)[1][0] = cc->ModReduce(CloneCiphertext(b12,k2,rotKeys, evalSumRows));
(*B1)[1][1] = cc->ModReduce(CloneCiphertext(b22,k2,rotKeys, evalSumRows));
(*B1)[1][2] = cc->ModReduce(CloneCiphertext(b23,k2,rotKeys, evalSumRows));
(*B1)[1][3] = cc->ModReduce(CloneCiphertext(b24,k2,rotKeys, evalSumRows));
(*B1)[2][0] = cc->ModReduce(CloneCiphertext(b13,k2,rotKeys, evalSumRows));
(*B1)[2][1] = cc->ModReduce(CloneCiphertext(b23,k2,rotKeys, evalSumRows));
(*B1)[2][2] = cc->ModReduce(CloneCiphertext(b33,k2,rotKeys, evalSumRows));
(*B1)[2][3] = cc->ModReduce(CloneCiphertext(b34,k2,rotKeys, evalSumRows));
(*B1)[3][0] = cc->ModReduce(CloneCiphertext(b14,k2,rotKeys, evalSumRows));
(*B1)[3][1] = cc->ModReduce(CloneCiphertext(b24,k2,rotKeys, evalSumRows));
(*B1)[3][2] = cc->ModReduce(CloneCiphertext(b34,k2,rotKeys, evalSumRows));
(*B1)[3][3] = cc->ModReduce(CloneCiphertext(b44,k2,rotKeys, evalSumRows));
B = *b;
return B1;
}
/*
 * Accumulates rotated copies of a ciphertext onto itself in a doubling pattern.
 *
 * Work is done on a deep copy of the input, so the caller's ciphertext is not
 * modified. For step = 1, 2, 4, ... (while step < size) the running result is
 * rotated by index (size - step) and added to itself.
 *
 * ciphertext   - ciphertext to start from
 * size         - loop bound and base of the rotation index (size - step)
 * rotKeys      - keys used for every rotation index other than 4 and 8
 * evalSumRows  - keys used when the rotation index is exactly 4 or 8
 *
 * Returns the accumulated ciphertext.
 * NOTE(review): presumably this replicates a value across `size` slots - confirm against callers.
 */
Ciphertext<DCRTPoly> CloneCiphertext(const Ciphertext<DCRTPoly> ciphertext, size_t size,
        const std::map<usint,LPEvalKey<DCRTPoly>> &rotKeys, const std::map<usint,LPEvalKey<DCRTPoly>> &evalSumRows) {

    auto cc = ciphertext->GetCryptoContext();
    Ciphertext<DCRTPoly> replicated(new CiphertextImpl<DCRTPoly>(*ciphertext));

    size_t step = 1;
    while (step < size) {
        const size_t shift = size - step;
        // Indices 4 and 8 are looked up in evalSumRows; all others in rotKeys.
        const bool useRowKeys = (shift == 4) || (shift == 8);
        const std::map<usint,LPEvalKey<DCRTPoly>> &keys = useRowKeys ? evalSumRows : rotKeys;

        auto rotated = cc->GetEncryptionAlgorithm()->EvalAtIndex(replicated, shift, keys);
        replicated = cc->EvalAdd(replicated, rotated);

        step *= 2;
    }

    return replicated;
}
/*
 * Splits one packed ciphertext into N ciphertexts, one per block of k*k slots.
 *
 * For each i in [0, N):
 *   1. a 0/1 mask over (cyclotomic order / 4) slots is built that keeps the slots
 *      whose index modulo k*k*2^ceil(log2(N)) lies in [i*k*k, (i+1)*k*k);
 *   2. c is multiplied by that mask;
 *   3. the product is repeatedly added to its own rotation by
 *      (slot count - j) for j = k*k, 2*k*k, 4*k*k, ... below the modulus above;
 *   4. the result is ModReduce'd and stored at position i of the output.
 *
 * c        - packed input ciphertext
 * N        - number of output ciphertexts
 * k        - block side; each block spans k*k slots
 * rotKeys  - rotation keys used for every EvalAtIndex call
 *
 * Returns a shared vector of N ciphertexts. The per-i work runs under
 * "#pragma omp parallel for"; each iteration writes only its own output element.
 */
shared_ptr<std::vector<Ciphertext<DCRTPoly>>> SplitIntoSingle(const Ciphertext<DCRTPoly> c, size_t N, size_t k,
        const std::map<usint, LPEvalKey<DCRTPoly>> &rotKeys){

    auto cc = c->GetCryptoContext();

    const usint cyclotomicOrder = c->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder();
    const size_t slotCount = cyclotomicOrder / 4;

    const size_t blockSize = k * k;
    // Block size times the smallest power of two that is >= N.
    const size_t blockPeriod = blockSize * (1 << static_cast<size_t>(std::ceil(std::log2(N))));

    shared_ptr<std::vector<Ciphertext<DCRTPoly>>> pieces(new std::vector<Ciphertext<DCRTPoly>>(N));

#pragma omp parallel for
    for (size_t idx = 0; idx < N; idx++) {
        const size_t windowBegin = idx * blockSize;
        const size_t windowEnd = windowBegin + blockSize;

        std::vector<std::complex<double>> mask(slotCount);
        for (size_t slot = 0; slot < mask.size(); slot++) {
            const size_t pos = slot % blockPeriod;
            mask[slot] = ((pos >= windowBegin) && (pos < windowEnd)) ? 1 : 0;
        }

        Plaintext maskPlain = cc->MakeCKKSPackedPlaintext(mask, 1);
        auto piece = cc->EvalMult(c, maskPlain);

        for (size_t stride = blockSize; stride < blockPeriod; stride *= 2) {
            auto rotated = cc->GetEncryptionAlgorithm()->EvalAtIndex(piece, slotCount - stride, rotKeys);
            piece = cc->EvalAdd(piece, rotated);
        }

        (*pieces)[idx] = cc->ModReduce(piece);
    }

    return pieces;
}
/*
 * Sums all ciphertexts of a vector with a pairwise (binary-tree) reduction.
 *
 * With gap = 1, 2, 4, ... each element at a multiple of 2*gap absorbs the
 * element `gap` positions to its right, when that element exists.
 * The partial sums overwrite entries of the input vector, so the caller's
 * vector is modified. The total ends up in element 0 and a deep copy of it
 * is returned.
 *
 * vector - ciphertexts to add; must contain at least one element, since
 *          element 0 is accessed unconditionally.
 */
Ciphertext<DCRTPoly> BinaryTreeAdd(std::vector<Ciphertext<DCRTPoly>> &vector) {

    auto cc = vector[0]->GetCryptoContext();
    const size_t count = vector.size();

    size_t gap = 1;
    while (gap < count) {
        // Once left + gap runs past the end, no later pair exists for this gap.
        for (size_t left = 0; left + gap < count; left += 2*gap)
            vector[left] = cc->EvalAdd(vector[left], vector[left + gap]);
        gap *= 2;
    }

    return Ciphertext<DCRTPoly>(new CiphertextImpl<DCRTPoly>(*(vector[0])));
}
/*
 * Reads a comma-separated SNP data file named <dataFileName>.csv.
 *
 * Header line: the non-empty names at zero-based field positions 5 .. M+4 are
 * appended to `headers`; their count is the number of SNP values read per row.
 *
 * Data lines: at most N further lines are consumed. A line with two or fewer
 * comma-separated fields is skipped but still counts toward N. For each other line:
 *   field 0      - ignored
 *   field 1      - appended to `y`
 *   fields 2..4  - form a three-element row appended to `x`
 *   fields 5..   - SNP values, forming a row appended to `dataColumns`
 * Every numeric field is parsed with std::stod, so a missing or non-numeric
 * field raises std::stod's usual exception.
 *
 * After reading, 1.0 is inserted at the front of every row of `x` (the intercept).
 * If the file cannot be opened nothing is appended to any output.
 *
 * The file name is logged to std::cerr and a completion message to std::cout.
 */
void ReadSNPFile(std::vector<std::string>& headers, std::vector<std::vector<double>> & dataColumns, std::vector<std::vector<double>> &x, std::vector<double> &y,
        std::string dataFileName, size_t N, size_t M)
{
    const uint32_t covariateCount = 3;  // fields 2..4 of each data line
    const uint32_t leadingFields = 5;   // fields that precede the SNP columns
    uint32_t snpCount = 0;

    const std::string csvPath = dataFileName + ".csv";
    std::cerr << "file name = " << csvPath << std::endl;

    std::ifstream input(csvPath);
    std::string line;

    // Header: collect the SNP column names.
    if (input.good()) {
        std::getline(input, line);
        const uint32_t headerFields = std::count(line.begin(), line.end(), ',') + 1;
        std::stringstream headerStream(line);

        size_t kept = 0;
        for (uint32_t col = 0; col < headerFields; col++) {
            std::string name;
            std::getline(headerStream, name, ',');
            if (name.empty() || (col < leadingFields) || !(col < M + leadingFields))
                continue;
            headers.push_back(name);
            kept++;
        }
        snpCount = static_cast<uint32_t>(kept);
    }

    // Data: every consumed line, including skipped ones, counts toward N.
    for (size_t linesRead = 0; input.good() && (linesRead < N); linesRead++) {
        std::getline(input, line);
        const uint32_t fieldCount = std::count(line.begin(), line.end(), ',') + 1;
        if (fieldCount <= 2)
            continue;

        std::stringstream fields(line);

        std::vector<double> covariates(covariateCount);
        for (uint32_t col = 0; col < leadingFields; col++) {
            // A fresh string per field: it stays empty if the line has run out.
            std::string token;
            std::getline(fields, token, ',');
            switch (col) {
            case 0:
                break;
            case 1:
                y.push_back(std::stod(token));
                break;
            default:
                covariates[col - 2] = std::stod(token);
                break;
            }
        }
        x.push_back(covariates);

        std::vector<double> snps(snpCount);
        for (uint32_t col = leadingFields; col < snpCount + leadingFields; col++) {
            std::string token;
            std::getline(fields, token, ',');
            if (col < M + leadingFields)
                snps[col - leadingFields] = std::stod(token);
        }
        dataColumns.push_back(snps);
    }

    input.close();

    // insert the intercept
    for (auto &covariateRow : x)
        covariateRow.insert(covariateRow.begin(), 1.0);

    std::cout << "Read in data: " << dataFileName << std::endl;
}
/*
 * Shrinks every evaluation key of a key map in place.
 *
 * For each key, the A and B vectors are copied, DropLastElements(level) is
 * applied to every polynomial of both copies, the key's stored vectors are
 * cleared, and the shortened copies are moved back in.
 *
 * ek     - map of evaluation keys to compress (modified through its values)
 * level  - count passed to DropLastElements for every polynomial
 *
 * NOTE(review): presumably `level` is the number of CRT towers to drop - confirm.
 * The A and B vectors are indexed together, so they are expected to have equal length.
 */
void CompressEvalKeys(std::map<usint, LPEvalKey<DCRTPoly>> &ek, size_t level) {

    for (auto &entry : ek) {
        const LPEvalKey<DCRTPoly> &key = entry.second;

        std::vector<DCRTPoly> shrunkA = key->GetAVector();
        std::vector<DCRTPoly> shrunkB = key->GetBVector();

        for (size_t idx = 0; idx < shrunkA.size(); idx++) {
            shrunkA[idx].DropLastElements(level);
            shrunkB[idx].DropLastElements(level);
        }

        key->ClearKeys();
        key->SetAVector(std::move(shrunkA));
        key->SetBVector(std::move(shrunkB));
    }
}
/*
 * Precomputes the digit decomposition of a ciphertext's second element.
 *
 * Reads the relinearization window from the ciphertext's CKKS crypto
 * parameters and returns CRTDecompose(relinWindow) of element 1 of the
 * ciphertext. The result can be passed as the `digits` argument of
 * HoistedAutomorphism.
 *
 * cipherText - ciphertext whose element 1 is decomposed; it must have at least
 *              two elements and carry LPCryptoParametersCKKS parameters (the
 *              dynamic cast result is dereferenced without a null check).
 *
 * Returns a shared vector holding the decomposed digits.
 *
 * The previous version also called cipherText->CloneEmpty() into a local that
 * was never used; that wasted allocation has been removed.
 */
shared_ptr<vector<DCRTPoly>> KeySwitchPrecompute(ConstCiphertext<DCRTPoly> cipherText)
{
    const shared_ptr<LPCryptoParametersCKKS<DCRTPoly>> cryptoParamsLWE = std::dynamic_pointer_cast<LPCryptoParametersCKKS<DCRTPoly>>(cipherText->GetCryptoParameters());

    const std::vector<DCRTPoly> &c = cipherText->GetElements();

    const uint32_t relinWindow = cryptoParamsLWE->GetRelinWindow();

    return std::make_shared<std::vector<DCRTPoly>>(c[1].CRTDecompose(relinWindow));
}
/*
 * Applies an automorphism to a ciphertext using precomputed digits.
 *
 * With sigma = AutomorphismTransform(index), d_i the supplied digits and
 * (a_i, b_i) the evaluation-key vectors, the result has elements
 *     ct0 = sigma(c[0]) + sum_i sigma(d_i) * b_i
 *     ct1 =               sum_i sigma(d_i) * a_i
 * Depth, level and scaling factor are copied from the input ciphertext.
 *
 * ek          - evaluation key; treated as an LPEvalKeyRelinImpl (static cast)
 * cipherText  - input ciphertext; only its element 0 is read here
 * digits      - precomputed digits (e.g. the output of KeySwitchPrecompute);
 *               must be non-empty, since digit 0 is used unconditionally
 * index       - automorphism index passed to AutomorphismTransform
 *
 * The key vectors are copied and truncated with DropLastElements by the
 * difference between the number of parameter entries of b[0] and of c[0];
 * the key is expected to have at least as many as the ciphertext.
 *
 * Compared with the previous version, the unused dynamic_pointer_cast of the
 * crypto parameters is gone, and the transformed digits are built directly
 * instead of deep-copying the digit vector and then overwriting every element.
 */
Ciphertext<DCRTPoly> HoistedAutomorphism(const LPEvalKey<DCRTPoly> ek,
        ConstCiphertext<DCRTPoly> cipherText, const shared_ptr<vector<DCRTPoly>> digits, const usint index)
{
    LPEvalKeyRelin<DCRTPoly> evalKey = std::static_pointer_cast<LPEvalKeyRelinImpl<DCRTPoly>>(ek);

    const std::vector<DCRTPoly> &c = cipherText->GetElements();

    // Local copies: the key material is truncated below and must not alter the stored key.
    std::vector<DCRTPoly> b = evalKey->GetBVector();
    std::vector<DCRTPoly> a = evalKey->GetAVector();

    const size_t towersToDrop = b[0].GetParams()->GetParams().size() - c[0].GetParams()->GetParams().size();

    for (size_t k = 0; k < b.size(); k++) {
        a[k].DropLastElements(towersToDrop);
        b[k].DropLastElements(towersToDrop);
    }

    // Apply the automorphism to each precomputed digit.
    std::vector<DCRTPoly> rotatedDigits;
    rotatedDigits.reserve(digits->size());
    for (const DCRTPoly &digit : *digits)
        rotatedDigits.push_back(digit.AutomorphismTransform(index));

    DCRTPoly ct0(c[0].AutomorphismTransform(index));
    DCRTPoly ct1 = rotatedDigits[0] * a[0];
    ct0 += rotatedDigits[0] * b[0];

    for (size_t i = 1; i < rotatedDigits.size(); ++i)
    {
        ct0 += rotatedDigits[i] * b[i];
        ct1 += rotatedDigits[i] * a[i];
    }

    Ciphertext<DCRTPoly> newCiphertext = cipherText->CloneEmpty();
    newCiphertext->SetElements({ ct0, ct1 });
    newCiphertext->SetDepth(cipherText->GetDepth());
    newCiphertext->SetLevel(cipherText->GetLevel());
    newCiphertext->SetScalingFactor(cipherText->GetScalingFactor());

    return newCiphertext;
}