前言

记录一手程序设计的实验作业,所使用的算法都是一些普通的算法,时间原因来不及做出太好的,只是稍微对几点想法说一下

函数体系及实现

核心是我们的头文件

#ifndef WLC_BN_H
#define WLC_BN_H

#include "wlc_types.h"
#include <stdlib.h>
#include <stdio.h>
#include <time.h>

#ifdef __cplusplus
extern "C" {
#endif
#ifndef MAX_BN_DIGS
#define MAX_BN_DIGS 1024
#endif

#ifndef dig_t
#if WBITS == 64
typedef uint64_t dig_t;
#else
typedef uint32_t dig_t; // 这个程序中WBITS是32
#endif
#endif

#define DIG_BYTES (WBITS / 8)
#define DIG_MASK ((1ULL << WBITS) - 1) // 掩码 0xFFFFFFFF

// 返回值(错误码)定义
typedef enum {
BN_SUCCESS = 0,
BN_ERR_NULL_PTR = -1,
BN_ERR_INVALID_SIZE = -2,
BN_ERR_BUFFER_OVERFLOW = -3,
BN_ERR_MEMORY = -4,
BN_ERR_MODULUS_EVEN = -5,
BN_ERR_INVALID_PARAM = -6,
BN_ERR_NO_MODULAR_INVERSE = -7,
BN_ERR_MONTGOMERY_INVALID = -8,
BN_ERR_NO_INVERSE = -9,
BN_ERR_ALL_NEGATIVE_RESULT = -10,
BN_ERR_FIRST_NEGATIVE_RESULT = -11,
BN_ERR_SECOND_NEGATIVE_RESULT = -12,
BN_ERR_RANDOM_ALLOC_FAIL = -13,
} bn_err_t;

// Resize函数能否截断有用数据
typedef enum {
BN_RESIZE_SAFE,
BN_RESIZE_TRUNCATE
} bn_resize_mode_t;

// BN结构体,封装大整数
typedef struct {
dig_t *data; // 数据数组
int used_digs; // 实际使用的dig_t个数,方便我们判断数的大小
int capacity; // 容量(dig_t个数),方便我们判断内存大小是否满足需要
} bn_t;

// 蒙哥马利上下文结构体
typedef struct {
bn_t N; // 模数
bn_t R; // R = 2^k
bn_t N_prime; // N' = -N⁻¹ mod R
bn_t R2; // R2 = R^2 mod N
int k; // R的位数
int k_digs; // k对应的dig_t数量
} mont_ctx_t;

// ==================== 内存管理 ====================
bn_err_t bn_init(bn_t *bn, int capacity);
bn_err_t bn_init_copy(bn_t *bn, const dig_t *data, int digs);
bn_err_t bn_free(bn_t *bn);
bn_err_t bn_resize(bn_t *bn, int new_capacity, bn_resize_mode_t mode);
bn_err_t bn_ensure_capacity(bn_t *bn, int required_digs);

// ==================== 工具函数 ====================
int bn_get_bits(const bn_t *bn);
int bn_cmp(const bn_t *a, const bn_t *b);
int bn_is_zero(const bn_t *bn);
int bn_is_even(const bn_t *bn);
int bn_is_one(const bn_t *bn);
bn_err_t bn_set_zero(bn_t *bn);
bn_err_t bn_set_one(bn_t *bn);
bn_err_t bn_copy(bn_t *dst, const bn_t *src);
bn_err_t bn_rand(bn_t *bn, int bits);
bn_err_t bn_rand_range(bn_t *bn, const bn_t *max);
bn_err_t bn_rand_security(bn_t *bn, int bits);
void bn_print(const bn_t *bn, const char *name);
void bn_print_hex(const bn_t *bn, const char *name);
bn_err_t bn_truncate_replace(bn_t *dst, const bn_t *src, int num_digs);
bn_err_t bn_truncate(bn_t *bn, int digs);
bn_err_t bn_truncate_bits(bn_t *bn, int bits);

// ==================== 基础运算 ====================
bn_err_t bn_add(bn_t *r, const bn_t *a, const bn_t *b);
bn_err_t bn_sub(bn_t *r, const bn_t *a, const bn_t *b);
bn_err_t bn_sub_signed(bn_t *r, const bn_t *a, const bn_t *b);
bn_err_t bn_sub_with_sign(bn_t *r, int *is_negative, const bn_t *a, const bn_t *b);
bn_err_t bn_mul(bn_t *r, const bn_t *a, const bn_t *b);
bn_err_t bn_sqr(bn_t *r, const bn_t *a);
bn_err_t bn_div(bn_t *q, bn_t *r, const bn_t *a, const bn_t *b);
bn_err_t bn_mod(bn_t *r, const bn_t *a, const bn_t *m);
bn_err_t bn_lsh(bn_t *r, const bn_t *a, int bits);
bn_err_t bn_rsh(bn_t *r, const bn_t *a, int bits);
dig_t bn_add_dig(bn_t *r, const bn_t *a, dig_t b);
dig_t bn_sub_dig(bn_t *r, const bn_t *a, dig_t b);

// ==================== 模运算 ====================
bn_err_t bn_mod_add(bn_t *r, const bn_t *a, const bn_t *b, const bn_t *m);
bn_err_t bn_mod_sub(bn_t *r, const bn_t *a, const bn_t *b, const bn_t *m);
bn_err_t bn_mod_mul(bn_t *r, const bn_t *a, const bn_t *b, const bn_t *m);
bn_err_t bn_mod_sqr(bn_t *r, const bn_t *a, const bn_t *m);
bn_err_t bn_mod_exp(bn_t *r, const bn_t *a, const bn_t *e, const bn_t *m);
bn_err_t bn_mod_inv(bn_t *r, const bn_t *a, const bn_t *m);
bn_err_t bn_mod_hlv(bn_t *r, const bn_t *a, const bn_t *m);

// ==================== 蒙哥马利算法 ====================
// 蒙哥马利上下文管理
bn_err_t mont_ctx_init(mont_ctx_t *ctx, const bn_t *N);
void mont_ctx_free(mont_ctx_t *ctx);
bn_err_t mont_ctx_compute(mont_ctx_t *ctx, const bn_t *N);

// 蒙哥马利预计算
bn_err_t mont_compute_R(bn_t *R, const bn_t *N, int *k);
bn_err_t mont_compute_N_prime(bn_t *N_prime, const bn_t *N, const bn_t *R);

// 蒙哥马利转换
bn_err_t mont_map(bn_t *a_mont, const bn_t *a, const mont_ctx_t *ctx);
bn_err_t mont_reduce(bn_t *a, const bn_t *a_mont, const mont_ctx_t *ctx);

// 蒙哥马利运算
bn_err_t mont_add(bn_t *r, const bn_t *a, const bn_t *b, const mont_ctx_t *ctx);
bn_err_t mont_sub(bn_t *r, const bn_t *a, const bn_t *b, const mont_ctx_t *ctx);
bn_err_t mont_mul(bn_t *r, const bn_t *a, const bn_t *b, const mont_ctx_t *ctx);
bn_err_t mont_sqr(bn_t *r, const bn_t *a, const mont_ctx_t *ctx);
bn_err_t mont_exp(bn_t *r, const bn_t *a, const bn_t *e, const mont_ctx_t *ctx);
bn_err_t mont_redc_internal(bn_t *r, const bn_t *T, const bn_t *N, const bn_t *N_prime, int k);
// 使用蒙哥马利的模运算(高级接口)
bn_err_t bn_mod_mul_mont(bn_t *r, const bn_t *a, const bn_t *b, const bn_t *m);
bn_err_t bn_mod_exp_mont(bn_t *r, const bn_t *a, const bn_t *e, const bn_t *m);

// ==================== 辅助函数 ====================
bn_err_t bn_from_hex(bn_t *bn, const char *hex_str);
bn_err_t bn_from_bytes(bn_t *bn, const unsigned char *bytes, int byte_len);
bn_err_t bn_to_bytes(unsigned char *bytes, int *byte_len, const bn_t *bn, int max_len);
bn_err_t bn_extend(bn_t *bn, int digs, dig_t fill);



#ifdef __cplusplus
}
#endif

#endif /* WLC_BN_H */

内存管理

鉴于我们是进行大整数运算,可能会申请大量的内存,因此我设立了一系列函数帮助我在堆上管理内存

bn_err_t bn_init(bn_t *bn, int capacity);
bn_err_t bn_init_copy(bn_t *bn, const dig_t *data, int digs);

bn_err_t bn_free(bn_t *bn);

bn_err_t bn_resize(bn_t *bn, int new_capacity);
bn_err_t bn_ensure_capacity(bn_t *bn, int required_digs);

首先是

bn_err_t bn_init(bn_t *bn, int capacity);

帮助我们初始化bn_t这个结构体,得到一个至少容量为1的dig容量,并且calloc会将内存清零便于我们快速使用

bn_err_t bn_init(bn_t *bn, int capacity) {
if (!bn) return BN_ERR_NULL_PTR;
if (capacity <= 0) capacity = 1;

bn->data = (dig_t *)calloc(capacity, sizeof(dig_t));
if (!bn->data) return BN_ERR_MEMORY;

bn->capacity = capacity;
bn->used_digs = 0;
return BN_SUCCESS;
}

方便我们申请一个bn_t结构体进行后续使用

bn_t temp;
bn_init(&temp, 0);

然后是

bn_err_t bn_init_copy(bn_t *bn, const dig_t *data, int digs);

一样是初始化结构体,但是进一步实现了数的复制,并且可以通过参数设置复制的digs数

bn_err_t bn_init_copy(bn_t *bn, const dig_t *data, int digs) {
bn_err_t err = bn_init(bn, digs);
if (err != BN_SUCCESS) return err;

if (data && digs > 0) {
memcpy(bn->data, data, digs * sizeof(dig_t));
bn->used_digs = digs;
while (bn->used_digs > 0 && bn->data[bn->used_digs - 1] == 0) { // 假如复制数据使前面的digs全是0(也称前导0),则收缩
bn->used_digs--;
}
}
return BN_SUCCESS;
}

然后是内存释放函数

bn_err_t bn_free(bn_t *bn);

帮助我们释放内存

bn_err_t bn_free(bn_t *bn) {
if (bn && bn->data) {
free(bn->data);
bn->data = NULL; // 清空指针,防止UAF
bn->capacity = 0;
bn->used_digs = 0;
return BN_SUCCESS;
}
return BN_ERR_NULL_PTR;
}

然后是动态内存调整

bn_err_t bn_resize(bn_t *bn, int new_capacity, bn_resize_mode_t mode);

方便我们调整capacity的大小

bn_err_t bn_resize(bn_t *bn, int new_capacity, bn_resize_mode_t mode) {
if (!bn) return BN_ERR_NULL_PTR;
if (new_capacity <= 0) return BN_ERR_INVALID_SIZE;

if (new_capacity == bn->capacity) {
return BN_SUCCESS;
}

if (new_capacity < bn->capacity) {
if (mode == BN_RESIZE_SAFE) {
// 不允许截断非0数据
for (int i = new_capacity; i < bn->used_digs; i++) {
if (bn->data[i] != 0) {
return BN_ERR_TRUNCATION;
}
}
}
if (bn->used_digs > new_capacity) {
bn->used_digs = new_capacity; // 缩减大小
}
}

dig_t *new_data = (dig_t *)realloc(bn->data, new_capacity * sizeof(dig_t));
if (!new_data) return BN_ERR_MEMORY; // 失败也保存原来数据

bn->data = new_data;

// 清零新增大小的内存
if (new_capacity > bn->capacity) {
memset(bn->data + bn->capacity, 0, (new_capacity - bn->capacity) * sizeof(dig_t));
}

bn->capacity = new_capacity;
return BN_SUCCESS;
}

然后是capacity校验函数,防止出现内存不足的情况

bn_err_t bn_ensure_capacity(bn_t *bn, int required_digs);

这里不可能发生截断,所以使用了BN_RESIZE_TRUNCATE

bn_err_t bn_ensure_capacity(bn_t *bn, int required_digs) {
if (!bn) return BN_ERR_NULL_PTR;
if (required_digs <= 0) return BN_ERR_INVALID_SIZE;

if (required_digs <= bn->capacity) {
return BN_SUCCESS;
}

// 增长策略:取所需大小的1.5倍
int new_capacity = required_digs * 3 / 2 + 1;
return bn_resize(bn, new_capacity, BN_RESIZE_TRUNCATE);
}

工具函数

对于大整数运算中,去判断或是设置某一特定位是1还是0便于我们操作,比如说设置一个,或者是在蒙哥马利模幂中判断是否为1,或者是得到我们的位数都需要依赖这两个函数

因此我们先设置这两个函数

static int bn_get_bit(const bn_t *bn, int bit) {
if (!bn || bit < 0) return 0;
int word_idx = bit / WBITS;
int bit_idx = bit % WBITS;
if (word_idx >= bn->used_digs) return 0; // 零拓展
return (bn->data[word_idx] >> bit_idx) & 1; // 右移加&运算
}

static bn_err_t bn_set_bit(bn_t *bn, int bit, int value) {
if (!bn || bit < 0) return BN_ERR_NULL_PTR;
int word_idx = bit / WBITS;
int bit_idx = bit % WBITS;
if (word_idx >= bn->capacity) {
bn_err_t err = bn_ensure_capacity(bn, word_idx + 1);
if(err != BN_SUCCESS){
return err;
}
}
if (word_idx >= bn->used_digs) {
bn->used_digs = word_idx + 1;
}
if (value) {
bn->data[word_idx] |= (dig_t)1 << bit_idx; // 设为1
} else {
bn->data[word_idx] &= ~((dig_t)1 << bit_idx); // 设为0
// 如果清除的是最高位且 word 变为 0,更新 used_digs
if (word_idx == bn->used_digs - 1 && bn->data[word_idx] == 0) {
while (bn->used_digs > 0 && bn->data[bn->used_digs - 1] == 0) {
bn->used_digs--;
}
}
}
return BN_SUCCESS;
}

然后我们设置一个用来判断数字有多少位的函数

int bn_get_bits(const bn_t *bn) {
if (!bn || bn->used_digs == 0) return 0; // 空数据

int i = bn->used_digs - 1;
dig_t top = bn->data[i]; // 只需要取最高位的数字dig_t来判断

#if WBITS == 64
return i * WBITS + (64 - __builtin_clzll(top));
#else
return i * WBITS + (32 - __builtin_clz(top)); // 这个是32位程序使用的
#endif
}

判断是否为0,是否为偶数(奇数)这类函数也顺便实现一下

int bn_is_zero(const bn_t *bn) {
if (!bn) return 0;
return (bn->used_digs == 0) || (bn->used_digs == 1 && bn->data[0] == 0);
}

int bn_is_one(const bn_t *bn) {
if (!bn) return 0;
return (bn->used_digs == 1 && bn->data[0] == 1);
}

int bn_is_even(const bn_t *bn) {
if (!bn) return 0;
if (bn->used_digs == 0) return 1;
return (bn->data[0] & 1) == 0;
}
// 判断正确则返回1,否则返回0

比较函数对于相减和模计算比较重要

int bn_cmp(const bn_t *a, const bn_t *b) {
if (!a || !b) return 0;

if (a->used_digs != b->used_digs) { // 先比较dig_t的数目
return a->used_digs > b->used_digs ? 1 : -1;
}

for (int i = a->used_digs - 1; i >= 0; i--) {
if (a->data[i] > b->data[i]) return 1; // 再从高dig_t比较到低dig_t
if (a->data[i] < b->data[i]) return -1;
}
return 0;
}

下面是几个比较常用的赋值函数

bn_err_t bn_set_zero(bn_t *bn) {
if (!bn) return BN_ERR_NULL_PTR;

if (bn->data) {
memset(bn->data, 0, bn->capacity * sizeof(dig_t));
}
bn->used_digs = 0;
return BN_SUCCESS;
}

清零函数高效清除数据

再设置一个置1的

bn_err_t bn_set_one(bn_t *bn) {
if (!bn) return BN_ERR_NULL_PTR;

bn_err_t err = bn_ensure_capacity(bn, 1);
if (err != BN_SUCCESS) return err;

if (bn->used_digs > 0) {
memset(bn->data, 0, bn->capacity * sizeof(dig_t));
}

// 设置值为1
bn->data[0] = 1;
bn->used_digs = 1;

return BN_SUCCESS;
}

然后是复制操作

bn_err_t bn_copy(bn_t *dst, const bn_t *src) {
if (!dst || !src) return BN_ERR_NULL_PTR;
if (dst == src) return BN_SUCCESS;
if (src->used_digs == 0) {
return bn_set_zero(dst); // 清零
}
bn_err_t err = bn_ensure_capacity(dst, src->used_digs);
if (err != BN_SUCCESS) return err;

memcpy(dst->data, src->data, src->used_digs * sizeof(dig_t)); // 复制
dst->used_digs = src->used_digs;

if (dst->used_digs < dst->capacity) {
memset(dst->data + dst->used_digs, 0, (dst->capacity - dst->used_digs) * sizeof(dig_t)); // 清零dst高位
}

return BN_SUCCESS;
}

还有随机数生成器

bn_err_t bn_rand(bn_t *bn, int bits) {
if (!bn) return BN_ERR_NULL_PTR;
if (bits <= 0) return BN_ERR_INVALID_SIZE;

int digs = (bits + WBITS - 1) / WBITS;
bn_err_t err = bn_ensure_capacity(bn, digs);
if (err != BN_SUCCESS) return err;

static int seeded = 0;
if (!seeded) {
srand(time(NULL));
seeded = 1;
}

for (int i = 0; i < digs; i++) {
bn->data[i] = (dig_t)rand() << 17 | (dig_t)rand() << 2 | (dig_t)(rand() & 0x3); // rand返回0~0x7FFF
}

// 处理最高位,确保位数正确
int bit_mask = bits % WBITS;
if (bit_mask > 0) {
dig_t mask = (1ULL << bit_mask) - 1;
bn->data[digs - 1] &= mask;
// 确保最高位为1
if (bit_mask > 1) {
bn->data[digs - 1] |= (1ULL << (bit_mask - 1));
}
}else {
if (digs > 0) {
bn->data[digs - 1] |= (1ULL << (WBITS - 1));
}
}

bn->used_digs = digs;
while (bn->used_digs > 0 && bn->data[bn->used_digs - 1] == 0) {
bn->used_digs--;
}

return BN_SUCCESS;
}

rand函数生成的随机数质量不高,我们使用BCryptGenRandom来生成高质量的随机数

bn_err_t bn_rand_security(bn_t *bn, int bits) {
if (!bn) return BN_ERR_NULL_PTR;
if (bits <= 0) return BN_ERR_INVALID_SIZE;

int digs = (bits + WBITS - 1) / WBITS;
bn_err_t err = bn_ensure_capacity(bn, digs);
if (err != BN_SUCCESS) return err;

size_t total_bytes = digs * sizeof(dig_t);

NTSTATUS status = BCryptGenRandom(
NULL,
(PUCHAR)bn->data,
(ULONG)total_bytes,
BCRYPT_USE_SYSTEM_PREFERRED_RNG
);

if (status != 0) {
return BN_ERR_RANDOM_ALLOC_FAIL;
}

bn_truncate_bits(bn,bits); // 截断
int bit_mask = bits % WBITS;
int last_word = digs - 1;

if (bit_mask > 0) {
// 设置最高位为1
if (bit_mask > 1) {
bn->data[last_word] |= (1ULL << (bit_mask - 1));
}
} else {
if (digs > 0) {
bn->data[last_word] |= (1ULL << (WBITS - 1));
}
}

bn->used_digs = digs;
while (bn->used_digs > 0 && bn->data[bn->used_digs - 1] == 0) {
bn->used_digs--;
}

return BN_SUCCESS;
}

然后是打印函数

void bn_print(const bn_t *bn, const char *name) {
if (!bn) {
printf("\n%s: (null)\n", name ? name : "bn");
return;
}

printf("\n%s: ", name ? name : "bn");
if (bn_is_zero(bn)) {
printf("0\n");
return;
}
printf("0x");
for (int i = bn->used_digs - 1; i >= 0; i--) {
#if WBITS == 64
if (i == bn->used_digs - 1) {
printf("%" PRIx64, (uint64_t)bn->data[i]);
} else {
printf("%016" PRIx64, (uint64_t)bn->data[i]);
}
#else
if (i == bn->used_digs - 1) {
printf("%" PRIx32, (uint32_t)bn->data[i]);
} else {
printf("%08" PRIx32, (uint32_t)bn->data[i]);
}
#endif
printf(" ");
}
printf(" (%d digs)\n", bn->used_digs);
}

AI帮我设计了一个debug的printf

void bn_print_debug(const bn_t *bn, const char *name) {
if (!bn) {
printf("%s: (null)\n", name ? name : "bn");
return;
}

printf("\n=== %s Debug Info ===\n", name ? name : "bn");

// 基本状态
printf(" used_digs: %d\n", bn->used_digs);
printf(" capacity: %d\n", bn->capacity);

// 安全检查
if (!bn->data) {
printf(" ERROR: data pointer is NULL!\n");
return;
}
if (bn->capacity <= 0) {
printf(" ERROR: invalid capacity (%d)!\n", bn->capacity);
return;
}
if (bn->used_digs > bn->capacity) {
printf(" ERROR: used_digs (%d) > capacity (%d)!\n",
bn->used_digs, bn->capacity);
}

// 所有分配的字
printf(" All allocated words (%d total):\n", bn->capacity);
for (int i = bn->capacity - 1; i >= 0; i--) {
printf(" [%3d]: 0x", i);

#if WBITS == 64
printf("%016" PRIx64, (uint64_t)bn->data[i]);
#else
printf("%08" PRIx32, (uint32_t)bn->data[i]);
#endif

// 状态标记
if (i >= bn->used_digs) {
printf(" [unused]");
if (bn->data[i] != 0) {
printf(" (WARNING: non-zero!)");
}
} else if (i == bn->used_digs - 1) {
printf(" [MSB]"); // Most Significant Bit
if (bn->data[i] == 0) {
printf(" (ERROR: leading zero!)");
}
} else if (bn->data[i] == 0) {
printf(" [zero]");
}

printf("\n");
}

// 标准格式(紧凑)
printf(" Compact hex: ");
if (bn->used_digs == 0 || (bn->used_digs == 1 && bn->data[0] == 0)) {
printf("0");
} else {
printf("0x");
int i = bn->used_digs - 1;

// 最高位不带前导零
#if WBITS == 64
printf("%" PRIx64, (uint64_t)bn->data[i]);
#else
printf("%" PRIx32, (uint32_t)bn->data[i]);
#endif

// 其他位带前导零
for (i--; i >= 0; i--) {
#if WBITS == 64
printf("%016" PRIx64, (uint64_t)bn->data[i]);
#else
printf("%08" PRIx32, (uint32_t)bn->data[i]);
#endif
}
}
printf("\n");

// 大小估算
size_t used_bytes = bn->used_digs * sizeof(dig_t);
size_t alloc_bytes = bn->capacity * sizeof(dig_t);
printf(" Memory: %zu/%zu bytes used (%.1f%%)\n",
used_bytes, alloc_bytes,
alloc_bytes ? (100.0 * used_bytes / alloc_bytes) : 0.0);

printf("================================\n\n");
}

然后是截断函数,对于模有奇效

这种是复制型的

bn_err_t bn_truncate_copy(bn_t *dst, const bn_t *src, int digs) {
if (!dst || !src) return BN_ERR_NULL_PTR;
if (digs <= 0) return BN_ERR_INVALID_PARAM;

// 确保dst有足够容量
bn_err_t err = bn_ensure_capacity(dst, digs);
if (err != BN_SUCCESS) return err;

// 清零
memset(dst->data, 0, dst->capacity * sizeof(dig_t));

// 计算实际复制数量
int copy_digs = (src->used_digs < digs) ? src->used_digs : digs;

// 复制低位
memcpy(dst->data, src->data, copy_digs * sizeof(dig_t));

dst->used_digs = copy_digs;

// 去除前导零
while (dst->used_digs > 0 && dst->data[dst->used_digs - 1] == 0) {
dst->used_digs--;
}

return BN_SUCCESS;
}

这种是对自身数据进行截断的

bn_err_t bn_truncate(bn_t *bn, int digs) {
if (!bn) return BN_ERR_NULL_PTR;
if (digs < 0) return BN_ERR_INVALID_SIZE;

if (digs < bn->used_digs) {
bn->used_digs = digs;
}

if (digs < bn->capacity) {
memset(bn->data + digs, 0, (bn->capacity - digs) * sizeof(dig_t));
}

return BN_SUCCESS;
}

还有一种是对位截断的

bn_err_t bn_truncate_bits(bn_t *bn, int bits) {
if (!bn) return BN_ERR_NULL_PTR;
if (bits <= 0) return BN_ERR_INVALID_SIZE;

// 计算需要保留的字数
int digs_needed = (bits + WBITS - 1) / WBITS;
int extra_bits = bits % WBITS;
// 如果原数位数少于需要的,不需要截断
if (bn->used_digs <= digs_needed) {
if (bn->used_digs == digs_needed) { // 只需处理最高字的掩码
if (extra_bits > 0) {
dig_t mask = ((dig_t)1 << extra_bits) - 1;
bn->data[digs_needed - 1] &= mask;
}
}
return BN_SUCCESS;
}

memset(bn->data + digs_needed, 0, (bn->used_digs - digs_needed) * sizeof(dig_t));
bn->used_digs = digs_needed;
// 处理最高字的掩码
if (extra_bits > 0) {
dig_t mask = ((dig_t)1 << extra_bits) - 1;
bn->data[digs_needed - 1] &= mask;
}

while (bn->used_digs > 0 && bn->data[bn->used_digs - 1] == 0) {
bn->used_digs--;
}

return BN_SUCCESS;
}

基础运算

加法就是简单的实现

bn_err_t bn_add(bn_t *r, const bn_t *a, const bn_t *b) {
if (!r || !a || !b) return BN_ERR_NULL_PTR;

// 添加原地操作支持
if (r == a || r == b) {
bn_t temp;
// 初始化时指定容量
int max_digs = MAX(a->used_digs, b->used_digs);
bn_init(&temp, max_digs + 1);

bn_err_t err = bn_add(&temp, a, b);
if (err != BN_SUCCESS) {
bn_free(&temp);
return err;
}
err = bn_copy(r, &temp);
bn_free(&temp);
return err;
}

int max_digs = MAX(a->used_digs, b->used_digs);
bn_err_t err = bn_ensure_capacity(r, max_digs + 1);
if (err != BN_SUCCESS) return err;

memset(r->data, 0, r->capacity * sizeof(dig_t));

uint64_t carry = 0;
int i;
for (i = 0; i < max_digs; i++) {
uint64_t a_val = (i < a->used_digs) ? a->data[i] : 0;
uint64_t b_val = (i < b->used_digs) ? b->data[i] : 0;
carry += a_val + b_val;
r->data[i] = (dig_t)(carry & DIG_MASK);
carry >>= WBITS;
}

if (carry > 0) {
r->data[i] = (dig_t)carry;
r->used_digs = max_digs + 1;
} else {
r->used_digs = max_digs;
}

while (r->used_digs > 0 && r->data[r->used_digs - 1] == 0) {
r->used_digs--;
}

return BN_SUCCESS;
}

还有减法(a>b)

bn_err_t bn_sub(bn_t *r, const bn_t *a, const bn_t *b) {
if (!r || !a || !b) return BN_ERR_NULL_PTR;

int cmp = bn_cmp(a, b);
if (cmp < 0) {
return BN_ERR_ALL_NEGATIVE_RESULT;
}if (cmp == 0) {
bn_set_zero(r);
return BN_SUCCESS;
}
// 原地操作处理
if (r == a || r == b) {
bn_t temp;
bn_init(&temp, 0);
bn_err_t err = bn_sub(&temp, a, b);
if (err != BN_SUCCESS) {
bn_free(&temp);
return err;
}
err = bn_copy(r, &temp);
bn_free(&temp);
return err;
}

int max_digs = a->used_digs;
bn_err_t err = bn_ensure_capacity(r, max_digs);
if (err != BN_SUCCESS) return err;

// 清零
memset(r->data, 0, r->capacity * sizeof(dig_t));

uint64_t borrow = 0;

// 逐位相减
for (int i = 0; i < max_digs; i++) {
uint64_t a_val = (uint64_t)a->data[i];
uint64_t b_val = (i < b->used_digs) ? (uint64_t)b->data[i] : 0;
b_val += borrow; // 加上之前的借位

if (a_val < b_val) {
// 需要借位
a_val += ((uint64_t)1 << WBITS);
borrow = 1;
} else {
borrow = 0;
}

r->data[i] = (dig_t)(a_val - b_val);
}

r->used_digs = max_digs;

while (r->used_digs > 0 && r->data[r->used_digs - 1] == 0) {
r->used_digs--;
}

return BN_SUCCESS;
}

既然有得到正数的减法,如何实现一个得到负数的减法,毕竟拓展欧几里得算法中间是会出现负数的

于是灵光一闪,使用错误码返回代表是正数是负数,BN_SUCCESS代表非负数,BN_ERR_ALL_NEGATIVE_RESULT代表负数

bn_err_t bn_sub_signed(bn_t *r, const bn_t *a, const bn_t *b) {
if (!r || !a || !b) return BN_ERR_NULL_PTR;
int cmp = bn_cmp(a, b);
if (cmp == 0) {
bn_set_zero(r);
return BN_SUCCESS;
}

if (r == a || r == b) {
bn_t temp;
bn_init(&temp, 0);
bn_err_t err = bn_sub_signed(&temp, a, b);
if (err != BN_SUCCESS && err != BN_ERR_ALL_NEGATIVE_RESULT) {
bn_free(&temp);
return err;
}
bn_copy(r, &temp);
bn_free(&temp);
return err; // 保持原返回值
}

int max_digs = MAX(a->used_digs, b->used_digs);
bn_err_t err = bn_ensure_capacity(r, max_digs);
if (err != BN_SUCCESS) return err;

memset(r->data, 0, r->capacity * sizeof(dig_t));

uint64_t borrow = 0;

if (cmp >= 0) {
for (int i = 0; i < a->used_digs; i++) {
uint64_t a_val = (uint64_t)a->data[i];
uint64_t b_val = (i < b->used_digs) ? (uint64_t)b->data[i] : 0;
b_val += borrow;
if (a_val < b_val) {
a_val += ((uint64_t)1 << WBITS);
borrow = 1;
} else {
borrow = 0;
}
r->data[i] = (dig_t)(a_val - b_val);
}
} else {
for (int i = 0; i < b->used_digs; i++) {
uint64_t b_val = (uint64_t)b->data[i];
uint64_t a_val = (i < a->used_digs) ? (uint64_t)a->data[i] : 0;
a_val += borrow;
if (b_val < a_val) {
b_val += ((uint64_t)1 << WBITS);
borrow = 1;
} else {
borrow = 0;
}
r->data[i] = (dig_t)(b_val - a_val);
}
}

r->used_digs = max_digs;
while (r->used_digs > 0 && r->data[r->used_digs - 1] == 0) {
r->used_digs--;
}

if (cmp >= 0) {
return BN_SUCCESS; // 结果为正
} else {
return BN_ERR_ALL_NEGATIVE_RESULT; // 结果为负
}
}

假如不使用错误码判断呢,那再拓展一下,传入一个变量来判断是否是正数

bn_err_t bn_sub_with_sign(bn_t *r, int *is_negative, const bn_t *a, const bn_t *b) {
if (!r || !is_negative || !a || !b) return BN_ERR_NULL_PTR;

bn_err_t err = bn_sub_signed(r, a, b);

if (err == BN_SUCCESS) {
*is_negative = 0;
return BN_SUCCESS;
} else if (err == BN_ERR_ALL_NEGATIVE_RESULT) {
*is_negative = 1;
return BN_SUCCESS;
} else {
return err;
}
}

然后是乘法

bn_err_t bn_mul(bn_t *r, const bn_t *a, const bn_t *b) {
if (!r || !a || !b) return BN_ERR_NULL_PTR;
if (bn_is_zero(a) || bn_is_zero(b)){
return bn_set_zero(r);
}
if (bn_is_one(a)) {
return bn_copy(r, b);
}
if (bn_is_one(b)) {
return bn_copy(r, a);
}
if (r == a || r == b) {
bn_t temp;
bn_init(&temp, 0);
bn_err_t err = bn_mul(&temp, a, b);
if (err == BN_SUCCESS) {
err = bn_copy(r, &temp);
}
bn_free(&temp);
return err;
}

int result_digs = a->used_digs + b->used_digs;
bn_err_t err = bn_ensure_capacity(r, result_digs);
if (err != BN_SUCCESS) return err;

memset(r->data, 0, r->capacity * sizeof(dig_t));

for (int i = 0; i < a->used_digs; i++) {
if (a->data[i] == 0) continue;

uint64_t carry = 0;
for (int j = 0; j < b->used_digs; j++) {
carry += (uint64_t)r->data[i + j] + (uint64_t)a->data[i] * b->data[j];
r->data[i + j] = (dig_t)(carry & DIG_MASK);
carry >>= WBITS;
}

// 处理最后的进位
int k = i + b->used_digs;
while (carry > 0 && k < result_digs) {
carry += r->data[k];
r->data[k] = (dig_t)(carry & DIG_MASK);
carry >>= WBITS;
k++;
}
}

r->used_digs = result_digs;
while (r->used_digs > 0 && r->data[r->used_digs - 1] == 0) {
r->used_digs--;
}

return BN_SUCCESS;
}

平方运算

bn_err_t bn_sqr(bn_t *r, const bn_t *a) {
return bn_mul(r, a, a);
}

左移和右移

bn_err_t bn_lsh(bn_t *r, const bn_t *a, int bits) {
if (!r || !a) return BN_ERR_NULL_PTR;
if (bits < 0) return BN_ERR_INVALID_PARAM;
if (bits == 0) return bn_copy(r, a);

if (r == a) {
bn_t temp;
bn_init(&temp, 0);
bn_err_t err = bn_lsh(&temp, a, bits);
if (err == BN_SUCCESS) {
err = bn_copy(r, &temp);
}
bn_free(&temp);
return err;
}

int dig_shift = bits / WBITS;
int bit_shift = bits % WBITS;
int result_digs = a->used_digs + dig_shift + (bit_shift > 0 ? 1 : 0);

bn_err_t err = bn_ensure_capacity(r, result_digs);
if (err != BN_SUCCESS) return err;

memset(r->data, 0, r->capacity * sizeof(dig_t));

if (bit_shift == 0) {
memcpy(r->data + dig_shift, a->data, a->used_digs * sizeof(dig_t));
} else {
uint64_t carry = 0;
for (int i = 0; i < a->used_digs; i++) {
uint64_t val = ((uint64_t)a->data[i] << bit_shift) | carry;
r->data[i + dig_shift] = (dig_t)(val & DIG_MASK);
carry = val >> WBITS;
}
if (carry > 0) {
r->data[a->used_digs + dig_shift] = (dig_t)carry;
}
}

r->used_digs = result_digs;
while (r->used_digs > 0 && r->data[r->used_digs - 1] == 0) {
r->used_digs--;
}

return BN_SUCCESS;
}

bn_err_t bn_rsh(bn_t *r, const bn_t *a, int bits) {
if (!r || !a) return BN_ERR_NULL_PTR;
if (bits < 0) return BN_ERR_INVALID_PARAM;

if (bits == 0) return bn_copy(r, a);

if (bn_is_zero(a)) {
return bn_set_zero(r);
}

int dig_shift = bits / WBITS;
int bit_shift = bits % WBITS;

if (dig_shift >= a->used_digs) {
return bn_set_zero(r);
}

if (r == a) {
bn_t temp;
bn_init(&temp, 0);
bn_err_t err = bn_rsh(&temp, a, bits);
if (err == BN_SUCCESS) {
err = bn_copy(r, &temp);
}
bn_free(&temp);
return err;
}

int max_result_digs = a->used_digs - dig_shift;

bn_err_t err = bn_ensure_capacity(r, max_result_digs);
if (err != BN_SUCCESS) return err;

memset(r->data, 0, r->capacity * sizeof(dig_t));

if (bit_shift == 0) {
memcpy(r->data, a->data + dig_shift, max_result_digs * sizeof(dig_t));
r->used_digs = max_result_digs;
} else {
// 从最低的有效位置开始处理
for (int src_idx = dig_shift; src_idx < a->used_digs; src_idx++) {
int dst_idx = src_idx - dig_shift;
// 获取当前字
dig_t current_word = a->data[src_idx];
// 获取下一个(更高)字用于进位
dig_t next_word = (src_idx + 1 < a->used_digs) ? a->data[src_idx + 1] : 0;
// 当前字的低位部分(右移后保留的部分)
dig_t low_part = current_word >> bit_shift;
// 下一个字的高位部分(作为进位)需要左移 (WBITS - bit_shift) 位
int carry_shift = WBITS - bit_shift;
dig_t carry_part = 0;
if (carry_shift < WBITS) { // 避免移位超过字长
carry_part = next_word << carry_shift;
}
// 组合结果
r->data[dst_idx] = low_part | carry_part;
}

r->used_digs = max_result_digs;
}

while (r->used_digs > 0 && r->data[r->used_digs - 1] == 0) {
r->used_digs--;
}

if (r->used_digs == 0) {
r->used_digs = 1;
r->data[0] = 0;
}

return BN_SUCCESS;
}

实现了大精度加法,为了加快小精度的速度,添加了对一个dig的特殊操作

dig_t bn_add_dig(bn_t *r, const bn_t *a, dig_t b) {
if (!r || !a) return 0;
if (b == 0) {
bn_copy(r, a);
return 0;
}
if (bn_is_zero(a)){
if(!(r->data)) bn_init(r, 1);
memset(r->data, 0, r->capacity * sizeof(dig_t));
r->data[0] = b;
r->used_digs = 1;
return 0;
}
bn_err_t err = bn_ensure_capacity(r, a->used_digs + 1);
if(err != BN_SUCCESS) return 0;
if(r != a){
err = bn_copy(r, a);
if(err != BN_SUCCESS) return 0;
}
uint64_t carry = b;
for (int i = 0; i < r->used_digs && carry > 0; i++) {
carry += r->data[i];
r->data[i] = (dig_t)(carry & DIG_MASK);
carry >>= WBITS;
}

if (carry > 0) {
r->data[r->used_digs] = (dig_t)carry;
r->used_digs++;
}

while (r->used_digs > 0 && r->data[r->used_digs - 1] == 0) {
r->used_digs--;
}

return (dig_t)carry;
}

dig_t bn_sub_dig(bn_t *r, const bn_t *a, dig_t b) {
if (!r || !a) return 1;
if (a->used_digs ==0 ) return 0;
if ( a->used_digs == 1 && a->data[0] <= b ){
if(a->data[0] == b) {
bn_set_zero(r);
return 0;
}
return 1;
}

if (b == 0) {
bn_copy(r, a);
return 0;
}

if(r != a){
bn_err_t err = bn_copy(r, a);
if(err != BN_SUCCESS) return 0;
}

uint64_t borrow = b;
for (int i = 0; i < r->used_digs && borrow > 0; i++) {
if (r->data[i] < borrow) {
r->data[i] = (dig_t)((1ULL << WBITS) + r->data[i] - borrow);
borrow = 1;
} else {
r->data[i] -= (dig_t)borrow;
borrow = 0;
break;
}
}

while (r->used_digs > 0 && r->data[r->used_digs - 1] == 0) {
r->used_digs--;
}

return (dig_t)borrow;
}

然后是除法运算,这个主要是针对相差不大的数字

bn_err_t bn_div(bn_t *q, bn_t *r, const bn_t *a, const bn_t *b) {
if (!q || !r || !a || !b) return BN_ERR_NULL_PTR;
if (bn_is_zero(b)) return BN_ERR_INVALID_PARAM;

int cmp = bn_cmp(a, b);
if (cmp < 0) {
bn_set_zero(q);
bn_copy(r, a);
return BN_SUCCESS;
}
if (cmp == 0) {
bn_set_one(q);
bn_set_zero(r);
return BN_SUCCESS;
}

bn_t a1, q1, temp;

int cap = a->capacity + 1;
bn_init(&a1, cap);
bn_init(&q1, cap);
bn_init(&temp, cap);

bn_copy(&a1, a);
bn_set_zero(q);

int b_bits = bn_get_bits(b);
while (bn_cmp(&a1, b) >= 0) {
bn_set_one(&q1);
int shift = bn_get_bits(&a1) - b_bits;
while (shift >= 0) {
bn_lsh(&temp, b, shift);
if (bn_cmp(&temp, &a1) <= 0) {
break;
}
shift--;
}
if(shift < 0) break;
bn_lsh(&q1, &q1, shift);
bn_add(q, q, &q1);

bn_sub(&a1, &a1, &temp);
}

// while (bn_cmp(&a1, b) >= 0) { // 相差比较大的时候开销就很大了
// bn_set_one(&q1);
// int shift = 0;
// while (1) {
// bn_lsh(&temp, b, shift+1);
// if (bn_cmp(&temp, &a1) > 0) {
// break;
// }
// shift++;
// }
// bn_lsh(&q1, &q1, shift);
// bn_add(q, q, &q1);

// bn_rsh(r, &temp, 1);
// bn_sub(&a1, &a1, r);
// }

bn_copy(r, &a1);

bn_free(&a1);
bn_free(&q1);
bn_free(&temp);

return BN_SUCCESS;
}

还有取模运算

bn_err_t bn_mod(bn_t *r, const bn_t *a, const bn_t *m) {
if (!r || !a || !m) return BN_ERR_NULL_PTR;
if (bn_is_zero(m)) return BN_ERR_INVALID_PARAM;

int cmp = bn_cmp(a, m);
if (cmp < 0) {
bn_copy(r, a);
return BN_SUCCESS;
}
if (cmp == 0) {
bn_set_zero(r);
return BN_SUCCESS;
}

if (r == a || r == m){
bn_t temp;
bn_init(&temp, a->used_digs);
bn_err_t err = bn_mod(&temp, a, m);
if (err == BN_SUCCESS) {
err = bn_copy(r, &temp);
}
bn_free(&temp);
return err;
}

bn_t a1, temp;

int cap = a->capacity + 1;
bn_init(&a1, cap);
bn_init(&temp, cap);

bn_copy(&a1, a);

int m_bits = bn_get_bits(m);


while (bn_cmp(&a1, m) >= 0) {
int a1_bits = bn_get_bits(&a1);
int shift = a1_bits - m_bits;
while (shift >= 0) {
bn_lsh(&temp, m, shift);
if (bn_cmp(&temp, &a1) <= 0) {
break;
}
shift--;
}
if(shift < 0) break;
bn_sub(&a1, &a1, &temp);
}
// while (bn_cmp(&a1, m) >= 0) { // 一样是改进了这个
// bn_set_one(&q1);
// int shift = 0;
// while (1) {
// bn_lsh(&temp, m, shift+1);
// if (bn_cmp(&temp, &a1) > 0) {
// break;
// }
// shift++;
// }
// bn_rsh(r, &temp, 1);
// bn_sub(&a1, &a1, r);
// }
bn_copy(r, &a1);

bn_free(&a1);
bn_free(&temp);

return BN_SUCCESS;
// bn_t q;
// bn_init(&q, a->capacity);

// bn_err_t err = bn_div(&q, r, a, m);
// bn_free(&q);

//return err;
}

模运算

一般的模加、模减、模乘

这里我没使用bn_mod是因为数比较接近,使用bn_mod的时间开销大,于是便使用了时间开销更小的bn_sub

bn_err_t bn_mod_add(bn_t *r, const bn_t *a, const bn_t *b, const bn_t *m) {
if (!r || !a || !b || !m) return BN_ERR_NULL_PTR;
if (bn_is_zero(m)) return BN_ERR_INVALID_PARAM;

if (bn_is_zero(a)) return bn_mod_sub(r, b, a, m);
if (bn_is_zero(b)) return bn_mod_sub(r, a, b, m);


bn_t sum;
int max_capacity = (a->used_digs > b->used_digs ? a->used_digs : b->used_digs) + 1;
bn_init(&sum, max_capacity);

bn_err_t err = bn_add(&sum, a, b);
if (err != BN_SUCCESS) {
bn_free(&sum);
return err;
}

//bn_mod(&sum,&sum,m);
while (bn_cmp(&sum, m) >= 0) {
err = bn_sub(&sum, &sum, m);
if (err != BN_SUCCESS) {
bn_free(&sum);
return err;
}
}

if (r != &sum) {
err = bn_copy(r, &sum);
}
bn_free(&sum);
return err ? err : BN_SUCCESS;
}

bn_err_t bn_mod_sub(bn_t *r, const bn_t *a, const bn_t *b, const bn_t *m) {
if (!r || !a || !b || !m) return BN_ERR_NULL_PTR;
if (bn_is_zero(m)) return BN_ERR_INVALID_PARAM;

bn_t a_mod, b_mod;
bn_init(&a_mod, m->used_digs);
bn_init(&b_mod, m->used_digs);

//bn_mod(&a_mod, a, m);
bn_copy(&a_mod,a);
while(bn_cmp(a,m)>=0){
bn_sub(&a_mod,&a_mod,m);
}

//bn_mod(&b_mod, b, m);
bn_copy(&b_mod,b);
while(bn_cmp(b,m)>=0){
bn_sub(&b_mod,&b_mod,m);
}

int cmp = bn_cmp(&a_mod, &b_mod);
if (cmp >= 0) {
bn_sub(r, &a_mod, &b_mod);
} else {
bn_t temp;
bn_init(&temp, m->used_digs);
bn_sub(&temp, m, &b_mod);
bn_add(r, &a_mod, &temp);

bn_free(&temp);
if (bn_cmp(r, m) >= 0) {
bn_sub(r, r, m);
}
}

bn_free(&a_mod);
bn_free(&b_mod);
return BN_SUCCESS;
}

bn_err_t bn_mod_mul(bn_t *r, const bn_t *a, const bn_t *b, const bn_t *m) {
if (!r || !a || !b || !m) return BN_ERR_NULL_PTR;
if (bn_is_zero(m)) return BN_ERR_INVALID_PARAM;
if (bn_is_even(m)) return BN_ERR_MODULUS_EVEN;
bn_t temp;
bn_err_t err = bn_init(&temp,a->capacity+b->used_digs);
if(err!=BN_SUCCESS) return BN_ERR_MEMORY;
bn_mul(&temp,a,b);
if(err!=BN_SUCCESS) return BN_ERR_INVALID_PARAM;
bn_mod(r,&temp,m);
if(err!=BN_SUCCESS) return BN_ERR_INVALID_PARAM;
bn_free(&temp);
return err;
}

bn_err_t bn_mod_sqr(bn_t *r, const bn_t *a, const bn_t *m) {
return bn_mod_mul(r, a, a, m);
}

bn_err_t bn_mod_exp(bn_t *r, const bn_t *a, const bn_t *e, const bn_t *m) {
if (!r || !a || !e || !m) return BN_ERR_NULL_PTR;
if (bn_is_zero(m)) return BN_ERR_INVALID_PARAM;

if (bn_is_zero(e)) {
return bn_set_one(r);
}
if (bn_is_one(e)) {
bn_t one;
bn_init(&one, 0);
bn_set_one(&one);
bn_err_t err = bn_mod_mul_mont(r, a, &one, m);
bn_free(&one);
return err;
}
bn_t result, base, exp;
bn_init(&result, 0);
bn_init(&base, 0);
bn_init(&exp, 0);
bn_err_t err = bn_copy(&base, a);
if (err) goto cleanup;
err = bn_copy(&exp, e);
if (err) goto cleanup;
bn_set_one(&result);
while (!bn_is_zero(&exp)) {
if (bn_is_even(&exp)) {
bn_t temp;
bn_init(&temp, 0);
err = bn_mod_mul(&temp, &base, &base, m);
if (err) {
bn_free(&temp);
goto cleanup;
}
bn_copy(&base, &temp);
bn_free(&temp);

bn_rsh(&exp, &exp, 1);
} else {
bn_t temp;
bn_init(&temp, 0);
err = bn_mod_mul(&temp, &result, &base, m);
if (err) {
bn_free(&temp);
goto cleanup;
}
bn_copy(&result, &temp);
bn_free(&temp);

bn_t one;
bn_init(&one, 0);
bn_set_one(&one);
err = bn_sub(&exp, &exp, &one);
bn_free(&one);
if (err) goto cleanup;
}
}

err = bn_copy(r, &result);
cleanup:
bn_free(&result);
bn_free(&base);
bn_free(&exp);
return err;
}

还有模折半

bn_err_t bn_mod_hlv(bn_t *r, const bn_t *a, const bn_t *m) {
if (!r || !a || !m) return BN_ERR_NULL_PTR;
if (bn_is_zero(m)) return BN_ERR_INVALID_PARAM;

bn_t a_reduced;
bn_init(&a_reduced, a->used_digs);
bn_copy(&a_reduced, a);

// while (bn_cmp(&a_reduced, m) >= 0) {
// bn_sub(&a_reduced, &a_reduced, m);
// }

bn_mod(&a_reduced, a, m);
bn_err_t err;

if (bn_is_even(&a_reduced)) {
err = bn_rsh(r, &a_reduced, 1);
if (err != BN_SUCCESS) {
bn_free(&a_reduced);
return err;
}
} else {
bn_t sum;
bn_init(&sum, (a_reduced.used_digs > m->used_digs ?
a_reduced.used_digs : m->used_digs) + 1);
err = bn_add(&sum, &a_reduced, m);
if (err != BN_SUCCESS) {
bn_free(&a_reduced);
bn_free(&sum);
return err;
}
err = bn_rsh(r, &sum, 1);
bn_free(&sum);
if (err != BN_SUCCESS) {
bn_free(&a_reduced);
return err;
}
}

bn_free(&a_reduced);
return BN_SUCCESS;
}

进阶模运算

求逆元是模运算中比较常用的

有拓展欧几里得算法,但是这个好像在我的大整数运算中开销特别大,于是后来我改进成了牛顿迭代法,毕竟完成的模乘和模幂等等只需要求的逆元,使用牛顿迭代法特别快

首先是拓展欧几里得算法,因为我还没实现负数的运算,只能使用返回错误码来判断回溯是两个值的正负性

   BN_SUCCESS = 0,
BN_ERR_ALL_NEGATIVE_RESULT = -10,
BN_ERR_FIRST_NEGATIVE_RESULT = -11,
BN_ERR_SECOND_NEGATIVE_RESULT = -12,

也算是比较巧思了吧

bn_err_t extended_gcd(bn_t *gcd, bn_t *x, bn_t *y, const bn_t *a, const bn_t *b) {
if (!gcd || !x || !y || !a || !b) return BN_ERR_NULL_PTR;

if (bn_is_zero(b)) {
bn_copy(gcd, a);
bn_set_one(x);
bn_set_zero(y);
return BN_SUCCESS;
}

bn_t q, r, x1, y1, product, temp;

int cap = a->capacity > b->capacity ? a->capacity : b->capacity;
cap += 2;

bn_init(&q, cap);
bn_init(&r, cap);
bn_init(&x1, cap);
bn_init(&y1, cap);
bn_init(&product, cap);
bn_init(&temp, cap);
bn_set_zero(&x1);
bn_set_zero(&y1);
bn_set_zero(&q);
bn_set_zero(&r);
bn_div(&q, &r, a, b);

bn_err_t err = extended_gcd(gcd, &x1, &y1, b, &r);

bn_copy(x, &y1);
bn_mul(&product, &q, &y1);

// printf("err before %d ",err);
if(err == BN_ERR_FIRST_NEGATIVE_RESULT){
bn_add(y,&product,&x1);
err=BN_ERR_SECOND_NEGATIVE_RESULT;
}
else if(err == BN_ERR_SECOND_NEGATIVE_RESULT){
bn_add(y,&product,&x1);
err = BN_ERR_FIRST_NEGATIVE_RESULT;
}
else if(err == BN_SUCCESS){
if(bn_cmp(&x1, &product) >= 0){
bn_sub(y, &x1, &product);
err = BN_SUCCESS;
}else{
bn_sub(y, &product, &x1);
err = BN_ERR_SECOND_NEGATIVE_RESULT;
}
}
else{
if(bn_cmp(&x1, &product) <= 0){
bn_sub(y, &product, &x1);
err = BN_ERR_FIRST_NEGATIVE_RESULT;
}else{
bn_sub(y, &x1, &product);
err = BN_ERR_ALL_NEGATIVE_RESULT;
}
}
// printf("err after %d ",err);
// bn_print(&x1,"x1 ");
// bn_print(&y1,"y1 ");
// bn_print(&q,"q ");
// bn_print(&r,"r ");
// bn_print(&product,"product ");

bn_free(&q);
bn_free(&r);
bn_free(&x1);
bn_free(&y1);
bn_free(&product);
bn_free(&temp);

return err;
}

bn_err_t bn_mod_inv(bn_t *inv, const bn_t *a, const bn_t *m) {
if (!inv || !a || !m) return BN_ERR_NULL_PTR;
if (bn_is_zero(m) || bn_is_zero(a)) return BN_ERR_INVALID_PARAM;


bn_t a_mod_m;
bn_init(&a_mod_m, m->capacity + 1);
bn_mod(&a_mod_m, a, m);

if (bn_is_zero(&a_mod_m)) {
bn_free(&a_mod_m);
return BN_ERR_NO_INVERSE;
}
int cap = m->capacity + 2;
bn_t one;
bn_init(&one, cap);
bn_set_one(&one);

if (bn_cmp(&a_mod_m, &one) == 0) {
bn_copy(inv, &one);
bn_free(&one);
bn_free(&a_mod_m);
return BN_SUCCESS;
}

bn_t gcd, x, y;


bn_init(&gcd, cap);
bn_init(&x, cap);
bn_init(&y, cap);
bn_set_zero(&gcd);
bn_set_zero(&x);
bn_set_zero(&y);

bn_err_t err = extended_gcd(&gcd, &x, &y, &a_mod_m, m);
if(bn_cmp(&gcd,&one)!=0){
err = BN_ERR_NO_INVERSE;
goto cleanup_all;
}

bn_t temp1;
bn_init(&temp1,cap);
bn_set_zero(&temp1);

if(err==BN_ERR_FIRST_NEGATIVE_RESULT || err == BN_ERR_ALL_NEGATIVE_RESULT){
bn_sub(inv,m,&x);
err = BN_SUCCESS;
}
else{
bn_copy(inv,&x);
err = BN_SUCCESS;
}
bn_free(&temp1);

cleanup_all:
bn_free(&gcd);
bn_free(&x);
bn_free(&y);
bn_free(&one);
bn_free(&a_mod_m);

return err;
}

牛顿迭代法

bn_err_t bn_mod_inv_usedNewton(bn_t *inv, const bn_t *N, const bn_t *R, const int k){
if (!inv || !N || !R) return BN_ERR_NULL_PTR;
if (bn_is_zero(N) || bn_is_zero(R)) return BN_ERR_INVALID_PARAM;
bn_t inv_0, temp_t, one, two, temp_1;
bn_init(&inv_0, 2*R->used_digs);
bn_init(&temp_t, 2*R->used_digs);
bn_init(&temp_1, 2*R->used_digs);
bn_init(&one, 1);
bn_init(&two, 1);
bn_set_one(&one);
bn_set_one(&inv_0);
bn_set_one(&temp_t);
bn_set_zero(&temp_1);
(&two)->data[0] = 2;
(&two)->used_digs = 1;
for(int i = 1; i<k ;i <<= 1){
bn_mul(&temp_t, N, &inv_0);
if(bn_cmp(&two,&temp_t)>=0){
bn_sub(&temp_t,&two, &temp_t);
}else{
bn_sub(&temp_t,&temp_t,&two);
bn_truncate_bits(&temp_t, MIN(k, 2*i));
bn_lsh(&temp_1, &one, MIN(k,2*i));
bn_sub(&temp_t,&temp_1, &temp_t);
}
bn_mul(&inv_0, &inv_0, &temp_t);
bn_truncate_bits(&inv_0, MIN(k, 2*i));
}
bn_copy(inv,&inv_0);
bn_free(&inv_0);
bn_free(&temp_1);
bn_free(&temp_t);
bn_free(&two);
bn_free(&one);
return BN_SUCCESS;
}
蒙哥马利算法

使用一个结构体储存我们预计算的蒙哥马利上下文

bn_err_t mont_ctx_init(mont_ctx_t *ctx, const bn_t *N) {
if (!ctx || !N) return BN_ERR_NULL_PTR;
if (bn_is_even(N)) return BN_ERR_MODULUS_EVEN;

memset(ctx, 0, sizeof(mont_ctx_t));

bn_err_t err = bn_copy(&ctx->N, N);
if (err != BN_SUCCESS) return err;

ctx->k = bn_get_bits(N);
ctx->k_digs = (ctx->k + WBITS - 1) / WBITS;

err = mont_compute_R(&ctx->R, N, &ctx->k);
if (err != BN_SUCCESS) {
bn_free(&ctx->N);
return err;
}

err = mont_compute_N_prime(&ctx->N_prime, N, &ctx->R);
if (err != BN_SUCCESS) {
bn_free(&ctx->N);
bn_free(&ctx->R);
return err;
}

bn_t R2_pre;
bn_init(&R2_pre, (ctx->k_digs)*2);
bn_mul(&R2_pre, &ctx->R, &ctx->R);
err = bn_mod(&ctx->R2, &R2_pre, N);
bn_free(&R2_pre);

if (err != BN_SUCCESS) {
bn_free(&ctx->N);
bn_free(&ctx->R);
bn_free(&ctx->N_prime);
return err;
}
return BN_SUCCESS;
}

void mont_ctx_free(mont_ctx_t *ctx) {
if (ctx) {
bn_free(&ctx->N);
bn_free(&ctx->R);
bn_free(&ctx->N_prime);
bn_free(&ctx->R2);
memset(ctx, 0, sizeof(mont_ctx_t));
}
}

bn_err_t mont_ctx_compute(mont_ctx_t *ctx, const bn_t *N) {
mont_ctx_free(ctx);
return mont_ctx_init(ctx, N);
}

首先是计算我们的R的位数,使用bn_get_bits就可以了

bn_err_t mont_compute_R(bn_t *R, const bn_t *N, int *k) {
if (!R || !N) return BN_ERR_NULL_PTR;

*k = bn_get_bits(N);
bn_err_t err = bn_set_one(R);
if (err != BN_SUCCESS) return err;

return bn_lsh(R, R, *k);
}

然后是构造我们的R,使用bn_set_bit设置我们上面得到的位为1

然后是另一个很重要的参数 先计算出逆元再相减

bn_err_t mont_compute_N_prime(bn_t *N_prime, const bn_t *N, const bn_t *R) {
if (!N_prime || !N || !R) return BN_ERR_NULL_PTR;

int k = bn_get_bits(N);
int k_digs = (k + WBITS - 1) / WBITS;

bn_t inv;
bn_init(&inv, 0);

bn_err_t err = bn_mod_inv_usedNewton(&inv, N, R, k);
if (err) goto cleanup;

err = bn_sub(N_prime, R, &inv);

cleanup:
bn_free(&inv);
return err;
}

R2就是,便于我们进行蒙哥马利域转换

转换为蒙哥马利形式

bn_err_t mont_map(bn_t *a_mont, const bn_t *a, const mont_ctx_t *ctx) {
if (!a_mont || !a || !ctx) return BN_ERR_NULL_PTR;

bn_t aR2;
bn_init(&aR2, (ctx->k_digs)*2);
bn_set_zero(&aR2);
bn_err_t err = bn_mul(&aR2, a, &ctx->R2);
if (err != BN_SUCCESS) {
bn_free(&aR2);
return err;
}
err = mont_redc_internal(a_mont, &aR2, &ctx->N, &ctx->N_prime, ctx->k);

bn_free(&aR2);
return err;
}

使用模约减返回 调用原来的上下文就可以避免重复计算

bn_err_t mont_reduce(bn_t *a, const bn_t *a_mont, const mont_ctx_t *ctx) {
if (!a || !a_mont || !ctx) return BN_ERR_NULL_PTR;
return mont_redc_internal(a, a_mont, &ctx->N, &ctx->N_prime, ctx->k);
}

调用的函数是

bn_err_t mont_redc_internal(bn_t *r, const bn_t *T, const bn_t *N, const bn_t *N_prime, int k) {
if (!r || !T || !N || !N_prime) return BN_ERR_NULL_PTR;

int k_digs = (k + WBITS - 1) / WBITS;
bn_err_t err;

bn_t temp_T,temp, R;
int T_capacity=T->capacity;
int N_capacity=N->capacity;
int max = (T_capacity>=N_capacity)?T_capacity:N_capacity;
bn_init(&temp_T, max*2);
bn_init(&temp, max*2);
bn_init(&R, max*2);
bn_set_zero(&temp_T);
bn_set_zero(&temp);
bn_set_zero(&R);
bn_set_bit(&R,k,1);

err = bn_copy(&temp_T, T);
if (err) {
bn_free(&temp_T);
return err;
}

bn_mul(&temp, N_prime, &temp_T);
bn_truncate_bits(&temp,k);
while(bn_cmp(&temp,&R)>=0){
bn_sub(&temp,&temp,&R);
}
bn_mul(&temp_T, N, &temp);
bn_add(&temp,&temp_T,T);

err = bn_rsh(&temp_T, &temp, k);

if (bn_cmp(&temp_T, N) >= 0) {
err = bn_sub(r, &temp_T, N);
}else{
bn_copy(r,&temp_T);
}
bn_free(&temp_T);
bn_free(&temp);
return err;
}

由于我们是按位截取的,所以循环减开销时间比较短,比起bn_mod速度快不少

然后就是我们的模乘和模幂的具体代码,都是经典算法

bn_err_t bn_mod_mul_mont(bn_t *r, const bn_t *a, const bn_t *b, const bn_t *m) {
if (!r || !a || !b || !m) return BN_ERR_NULL_PTR;
if (bn_is_zero(m)) return BN_ERR_INVALID_PARAM;
if (bn_is_even(m)) return BN_ERR_MODULUS_EVEN;

mont_ctx_t ctx;
bn_err_t err = mont_ctx_init(&ctx, m);
if (err != BN_SUCCESS) return err;

bn_t a_mont, b_mont;
bn_init(&a_mont, (&ctx)->k_digs);
bn_init(&b_mont, (&ctx)->k_digs);

err = mont_map(&a_mont, a, &ctx);
if (err) goto cleanup;

err = mont_map(&b_mont, b, &ctx);
if (err) goto cleanup;

bn_t result_mont;
bn_init(&result_mont, 0);
err = mont_mul(&result_mont, &a_mont, &b_mont, &ctx);
if (err) goto cleanup2;

err = mont_reduce(r, &result_mont, &ctx);

cleanup2:
bn_free(&result_mont);
cleanup:
bn_free(&a_mont);
bn_free(&b_mont);
mont_ctx_free(&ctx);
return err;
}

bn_err_t mont_mul(bn_t *r, const bn_t *a, const bn_t *b, const mont_ctx_t *ctx) {
if (!r || !a || !b || !ctx) return BN_ERR_NULL_PTR;

bn_t T;
bn_init(&T, 0);
bn_err_t err = bn_mul(&T, a, b);
if (err != BN_SUCCESS) {
bn_free(&T);
return err;
}

err = mont_redc_internal(r, &T, &ctx->N, &ctx->N_prime, ctx->k);

bn_free(&T);
return err;
}

模幂算法也是

bn_err_t bn_mod_exp_mont(bn_t *r, const bn_t *a, const bn_t *e, const bn_t *m) {
if (!r || !a || !e || !m) return BN_ERR_NULL_PTR;
if (bn_is_zero(m)) return BN_ERR_INVALID_PARAM;
if (bn_is_even(m)) return BN_ERR_MODULUS_EVEN;

mont_ctx_t ctx;
bn_err_t err = mont_ctx_init(&ctx, m);
if (err != BN_SUCCESS) return err;

bn_t a_mont;
bn_init(&a_mont, 0);
err = mont_map(&a_mont, a, &ctx);
if (err) goto cleanup;

err = mont_exp(r, &a_mont, e, &ctx);
if (err) goto cleanup2;

err = mont_reduce(&a_mont, r, &ctx);
if (err == BN_SUCCESS) {
bn_copy(r, &a_mont);
}

cleanup2:
bn_free(&a_mont);
cleanup:
mont_ctx_free(&ctx);
return err;
}

bn_err_t mont_exp(bn_t *r, const bn_t *a, const bn_t *e, const mont_ctx_t *ctx) {
if (!r || !a || !e || !ctx) return BN_ERR_NULL_PTR;

bn_t result, base;
bn_init(&result, 0);
bn_init(&base, 0);

bn_err_t err = bn_copy(&base, a);
if (err) goto cleanup;

bn_t one;
bn_init(&one, 0);
bn_set_one(&one);
err = mont_map(&result, &one, ctx);
bn_free(&one);
if (err) goto cleanup;
bn_t temp;
bn_init(&temp, 0);
int bits = bn_get_bits(e);
for (int i = bits - 1; i >= 0; i--) {
err = mont_sqr(&temp, &result, ctx);
if (err) {
bn_free(&temp);
goto cleanup;
}
bn_copy(&result, &temp);

if (bn_get_bit(e, i)) {
err = mont_mul(&temp, &result, &base, ctx);
if (err) {
bn_free(&temp);
goto cleanup;
}
bn_copy(&result, &temp);
}
}

err = bn_copy(r, &result);

cleanup:
bn_free(&temp);
bn_free(&result);
bn_free(&base);
return err;
}

这就是我的代码啦

经过测试,速度如下,每一项都是随机的数据

+------------------------------------------------------------------------------+
+--- Bench module ---> bn
BENCH: bn_set_zero = 0.0024 us
BENCH: bn_rand = 0.0864 us
BENCH: bn_rand_security = 0.0656 us
BENCH: bn_copy = 0.0063 us
BENCH: bn_get_bits = 0.0017 us
BENCH: bn_get_bit = 0.0021 us
BENCH: bn_set_bit = 0.0065 us
BENCH: bn_truncate = 0.0013 us
BENCH: bn_truncate_bits = 0.0009 us
BENCH: bn_add = 0.0085 us
BENCH: bn_add_dig = 0.0056 us
BENCH: bn_sub = 0.0096 us
BENCH: bn_sub_dig = 0.0055 us
BENCH: bn_mul = 0.0423 us
BENCH: bn_sqr = 0.0420 us
BENCH: bn_div = 0.0080 us
BENCH: bn_mod = 0.1465 us
BENCH: bn_cmp = 0.0014 us
BENCH: bn_lsh = 0.0072 us
BENCH: bn_rsh = 0.0029 us
BENCH: bn_mod_add = 0.1829 us
BENCH: bn_mod_sub = 0.1042 us
BENCH: bn_mod_hlv = 0.0898 us
BENCH: mont_compute_R = 0.1209 us
BENCH: mod_inv = 139.5866 us
BENCH: mod_inv_useNewton = 1.7169 us
BENCH: mont_compute_N_prime = 2.4169 us
BENCH: mont_ctx_init = 13.0413 us
BENCH: bn_mod_rdc = 0.1674 us
BENCH: bn_mod_mul = 14.3162 us
BENCH: bn_mod_exp = 142.5800 us
+------------------------------------------------------------------------------+

感觉还是有点慢了,可以优化的地方还可以多一些,不过这里就告一段落了