最近在复现一篇联邦学习的安全聚合论文,需要用到Shamir秘密共享,就用Java实现了一下,有兴趣的可以看看。
package com.duwei.crypto;
import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.Random;
/**
 * Shamir threshold secret sharing over the prime field of integers modulo {@code p}.
 *
 * <p>{@link #init(int)} must be called once before any other method: it creates the
 * randomness source and picks the prime modulus. Share {@code i} (0-based array index)
 * is the value of the sharing polynomial at {@code x = i + 1}.
 *
 * <p>All state is held in static fields and no synchronization is performed.
 */
public class SecretShare {
    /** Prime modulus of the field; assigned by {@link #init(int)}. */
    private static BigInteger p;

    /** Randomness source for the modulus and the polynomial coefficients; assigned by {@link #init(int)}. */
    private static Random random;

    /**
     * Splits {@code secret} into {@code m} shares such that the first {@code t} of them
     * reconstruct it via {@link #reconstruction(BigInteger[], int)}.
     *
     * <p>A polynomial of degree {@code t - 1} is built whose constant term is the secret and
     * whose remaining coefficients are random; the shares are its values at {@code 1..m}.
     * Shares are reduced modulo {@code p}, so a secret outside {@code [0, p)} is recovered
     * only as its residue modulo {@code p}.
     *
     * @param secret the value to share
     * @param m      number of shares to produce
     * @param t      number of coefficients of the polynomial (the reconstruction threshold)
     * @return an array of {@code m} shares, index {@code i} holding the share for {@code x = i + 1}
     * @throws IllegalArgumentException if {@code t < 1}, {@code m < 0}, or {@code m >= p}
     * @throws IllegalStateException    if {@link #init(int)} has not been called
     */
    public static BigInteger[] share(BigInteger secret, int m, int t) {
        requireInitialized();
        if (t < 1) {
            throw new IllegalArgumentException("threshold t must be at least 1, got " + t);
        }
        if (m < 0) {
            throw new IllegalArgumentException("share count m must not be negative, got " + m);
        }
        // The evaluation points 1..m must be distinct and non-zero modulo p; a point that is
        // congruent to 0 would hand out the secret itself as a share.
        if (BigInteger.valueOf(m).compareTo(p) >= 0) {
            throw new IllegalArgumentException("share count m must be smaller than the field modulus, got " + m);
        }

        BigInteger[] coefficients = new BigInteger[t];
        coefficients[0] = secret;
        for (int i = 1; i < t; i++) {
            coefficients[i] = generateRandomBigInteger();
        }

        BigInteger[] userShares = new BigInteger[m];
        for (int i = 0; i < m; i++) {
            userShares[i] = computeShare(coefficients, i + 1);
        }
        return userShares;
    }

    /**
     * Evaluates the polynomial {@code coefficients[0] + coefficients[1]*x + ...} at
     * {@code x = userIndex}, modulo {@code p}.
     *
     * @param coefficients polynomial coefficients, lowest degree first
     * @param userIndex    the evaluation point
     * @return the polynomial value in {@code [0, p)}; zero for an empty coefficient array
     * @throws IllegalStateException if {@link #init(int)} has not been called
     */
    public static BigInteger computeShare(BigInteger[] coefficients, int userIndex) {
        requireInitialized();
        BigInteger x = BigInteger.valueOf(userIndex);
        BigInteger result = BigInteger.ZERO;
        // Horner's scheme, reduced at every step so no operand grows beyond about twice the
        // bit length of p.
        for (int i = coefficients.length - 1; i >= 0; i--) {
            result = result.multiply(x).add(coefficients[i]).mod(p);
        }
        return result;
    }

    /**
     * Draws a random value in {@code [0, p)} by sampling {@code p.bitLength()} random bits and
     * retrying until the sample is below {@code p}. Zero is a possible result.
     *
     * @return a random value in {@code [0, p)}
     * @throws IllegalStateException if {@link #init(int)} has not been called
     */
    public static BigInteger generateRandomBigInteger() {
        requireInitialized();
        BigInteger result;
        do {
            result = new BigInteger(p.bitLength(), random);
        } while (result.compareTo(p) >= 0);
        return result;
    }

    /**
     * Creates the randomness source and selects a probable prime of {@code bitLen} bits as the
     * field modulus. Calling it again replaces both, which invalidates shares produced earlier.
     *
     * @param bitLen bit length of the prime modulus
     */
    public static void init(int bitLen) {
        // Share secrecy rests on the coefficients being unpredictable, so use a
        // cryptographically strong generator rather than java.util.Random.
        random = new SecureRandom();
        p = BigInteger.probablePrime(bitLen, random);
    }

    /**
     * Recovers the secret from the first {@code t} entries of {@code shares} by summing their
     * Lagrange terms at {@code x = 0}. Entry {@code i} is taken to be the share for
     * {@code x = i + 1}, i.e. the array must be ordered as returned by
     * {@link #share(BigInteger, int, int)}.
     *
     * @param shares the shares, in the order produced by {@code share}
     * @param t      the number of shares to use (the threshold used when sharing)
     * @return the reconstructed value in {@code [0, p)}
     * @throws Exception             if fewer than {@code t} shares were supplied
     * @throws IllegalStateException if {@link #init(int)} has not been called
     */
    public static BigInteger reconstruction(BigInteger[] shares, int t) throws Exception {
        requireInitialized();
        int n = shares.length;
        if (t > n) {
            throw new Exception("你当前收集的秘密份额不足以恢复秘密");
        }
        BigInteger result = BigInteger.ZERO;
        for (int i = 0; i < t; i++) {
            result = result.add(interpolation(shares, i + 1, t));
        }
        return result.mod(p);
    }

    /**
     * Computes one Lagrange term at {@code x = 0}: the share at point {@code xK} multiplied by
     * the product over every other point {@code x_i} in {@code 1..t} of
     * {@code (0 - x_i) / (xK - x_i)}, with division done as a modular inverse.
     *
     * @param values the shares; {@code values[xK - 1]} is the share for point {@code xK}
     * @param xK     the evaluation point of the share this term belongs to
     * @param t      number of points {@code 1..t} taking part in the interpolation
     * @return the term reduced to {@code [0, p)}
     * @throws ArithmeticException   if the denominator has no inverse modulo {@code p}
     * @throws IllegalStateException if {@link #init(int)} has not been called
     */
    public static BigInteger interpolation(BigInteger[] values, int xK, int t) {
        requireInitialized();
        BigInteger xk = BigInteger.valueOf(xK);
        BigInteger numerator = BigInteger.ONE;
        BigInteger denominator = BigInteger.ONE;
        for (int i = 1; i <= t; i++) {
            if (i == xK) {
                continue;
            }
            BigInteger xi = BigInteger.valueOf(i);
            numerator = numerator.multiply(xi.negate()).mod(p);
            denominator = denominator.multiply(xk.subtract(xi)).mod(p);
        }
        return numerator.multiply(denominator.modInverse(p)).multiply(values[xK - 1]).mod(p);
    }

    /**
     * Self-test: shares and reconstructs 1000 random secrets with random share counts and
     * thresholds, printing a message for every mismatch.
     *
     * @param args unused
     * @throws Exception if reconstruction reports too few shares
     */
    public static void main(String[] args) throws Exception {
        init(1024);
        int times = 1000;
        System.out.println("测试开始.....");
        for (int i = 0; i < times; i++) {
            BigInteger secret = generateRandomBigInteger();
            int m = (int) (Math.random() * 100) + 5;
            int t = (int) (Math.random() * 50) + 1;
            // Re-roll until the threshold does not exceed the number of shares.
            while (t > m) {
                t = (int) (Math.random() * 50) + 1;
            }
            BigInteger[] shares = share(secret, m, t);
            BigInteger reconstruction = reconstruction(shares, t);
            if (reconstruction.compareTo(secret) != 0) {
                System.out.println("秘密值恢复错误");
            }
        }
        System.out.println("测试结束.....");
    }

    /** Fails with a descriptive error when {@link #init(int)} has not set up the modulus yet. */
    private static void requireInitialized() {
        if (p == null || random == null) {
            throw new IllegalStateException("SecretShare.init(int) must be called before use");
        }
    }
}
|