185 lines
5.7 KiB
C++
185 lines
5.7 KiB
C++
#include <chrono>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

#include <aes.hpp>
#include <catch_amalgamated.hpp>
#include <util.h>

#include "assistant.h"
|
|
|
|
// Size of each plaintext buffer used by the speed test: 100 KiB.
constexpr size_t BLOCK_SIZE = 100 * 1024;
// Size of the random value (presumably the IV) reserved in front of each payload.
constexpr size_t IV_SIZE = 16;
|
|
|
|
/// Result of one encryption/decryption speed run.
///
/// Every member has a default value, so a default-constructed SpeedRet is
/// well-defined even when the producer bails out before filling it in.
struct SpeedRet {
    std::string mode;          // Build configuration label ("Debug" or "Release").
    long long file_size{0};    // Size of the test file in whole MiB.
    long long encry_speed{0};  // Encryption throughput in MiB/s, truncated to an integer.
    long long decry_speed{0};  // Decryption throughput in MiB/s, truncated to an integer.
    bool verify{false};        // True when every decrypted block matched the original data.
};
|
|
|
|
bool test_speed(SpeedRet& ret)
|
|
{
|
|
std::string test_file("1.dat");
|
|
if (!random_file(test_file, 1024 * 1024 * 10)) {
|
|
std::cerr << "Failed to create test file" << std::endl;
|
|
return false;
|
|
}
|
|
|
|
ret.decry_speed = 0;
|
|
ret.encry_speed = 0;
|
|
ret.mode = "";
|
|
ret.verify = false;
|
|
|
|
std::shared_ptr<int> deleter(new int(1), [test_file](int* p) {
|
|
delete p;
|
|
if (fs::exists(test_file)) {
|
|
fs::remove(test_file);
|
|
}
|
|
});
|
|
|
|
if (!fs::exists(test_file)) {
|
|
std::cerr << "Input file not found: " << test_file << std::endl;
|
|
return false;
|
|
}
|
|
|
|
size_t file_size = fs::file_size(test_file);
|
|
ret.file_size = file_size / (1024 * 1024);
|
|
if (file_size == 0) {
|
|
std::cerr << "Input file is empty" << std::endl;
|
|
return false;
|
|
}
|
|
|
|
std::string key = "test_speed_key";
|
|
uint8_t ik[32]{};
|
|
hash(key.c_str(), ik);
|
|
|
|
fs::path decrypted_path = fs::path(test_file).replace_filename(
|
|
fs::path(test_file).stem().string() + "_decrypted" + fs::path(test_file).extension().string());
|
|
|
|
std::ofstream decrypted_file(decrypted_path, std::ios::binary);
|
|
if (!decrypted_file) {
|
|
std::cerr << "Failed to create decrypted file" << std::endl;
|
|
return false;
|
|
}
|
|
|
|
std::ifstream in_file(test_file, std::ios::binary);
|
|
if (!in_file) {
|
|
std::cerr << "Failed to open input file" << std::endl;
|
|
return false;
|
|
}
|
|
|
|
// 测试数据缓冲区(额外预留16字节空间)
|
|
std::vector<uint8_t> original_block(BLOCK_SIZE);
|
|
std::vector<uint8_t> processing_block(BLOCK_SIZE + IV_SIZE); // 加密/解密处理缓冲区
|
|
|
|
size_t total_bytes = 0;
|
|
size_t blocks_processed = 0;
|
|
bool verification_passed = true;
|
|
|
|
auto total_encrypt_time = std::chrono::microseconds(0);
|
|
auto total_decrypt_time = std::chrono::microseconds(0);
|
|
|
|
while (in_file) {
|
|
in_file.read(reinterpret_cast<char*>(original_block.data()), BLOCK_SIZE - IV_SIZE);
|
|
size_t bytes_read = in_file.gcount();
|
|
if (bytes_read == 0)
|
|
break;
|
|
|
|
memcpy(processing_block.data() + IV_SIZE, original_block.data(), bytes_read);
|
|
auto start_encrypt = std::chrono::high_resolution_clock::now();
|
|
if (!encrypt(ik, processing_block.data(), bytes_read + IV_SIZE)) {
|
|
std::cerr << "Encryption failed at block " << blocks_processed << std::endl;
|
|
verification_passed = false;
|
|
break;
|
|
}
|
|
auto end_encrypt = std::chrono::high_resolution_clock::now();
|
|
total_encrypt_time +=
|
|
std::chrono::duration_cast<std::chrono::microseconds>(end_encrypt - start_encrypt);
|
|
|
|
auto start_decrypt = std::chrono::high_resolution_clock::now();
|
|
if (!decrypt(ik, processing_block.data(), bytes_read + IV_SIZE)) {
|
|
std::cerr << "Decryption failed at block " << blocks_processed << std::endl;
|
|
verification_passed = false;
|
|
break;
|
|
}
|
|
auto end_decrypt = std::chrono::high_resolution_clock::now();
|
|
total_decrypt_time +=
|
|
std::chrono::duration_cast<std::chrono::microseconds>(end_decrypt - start_decrypt);
|
|
|
|
if (memcmp(original_block.data(), processing_block.data() + IV_SIZE, bytes_read) != 0) {
|
|
std::cerr << "Data mismatch at block " << blocks_processed << std::endl;
|
|
verification_passed = false;
|
|
break;
|
|
}
|
|
|
|
decrypted_file.write(reinterpret_cast<const char*>(processing_block.data() + IV_SIZE), bytes_read);
|
|
total_bytes += bytes_read;
|
|
blocks_processed++;
|
|
}
|
|
|
|
in_file.close();
|
|
decrypted_file.close();
|
|
|
|
#if !defined(NDEBUG) || defined(_DEBUG) || defined(DEBUG)
|
|
// Debug 模式
|
|
ret.mode = "Debug";
|
|
#else
|
|
// Release 模式
|
|
ret.mode = "Release";
|
|
#endif
|
|
|
|
// 计算吞吐量(只计算有效数据部分)
|
|
double encrypt_throughput =
|
|
(double)total_bytes / (1024 * 1024) / (total_encrypt_time.count() / 1000000.0);
|
|
double decrypt_throughput =
|
|
(double)total_bytes / (1024 * 1024) / (total_decrypt_time.count() / 1000000.0);
|
|
|
|
ret.encry_speed = encrypt_throughput;
|
|
ret.decry_speed = decrypt_throughput;
|
|
ret.verify = verification_passed;
|
|
|
|
fs::remove(decrypted_path);
|
|
return verification_passed;
|
|
}
|
|
|
|
bool correctness_test()
|
|
{
|
|
std::string key = "demokey";
|
|
uint8_t ik[32]{};
|
|
hash(key.c_str(), ik);
|
|
|
|
int offset = 16;
|
|
char* msg = new char[256]{};
|
|
std::shared_ptr<int> deleter(new int(), [msg](int* p) {
|
|
delete p;
|
|
delete[] msg;
|
|
});
|
|
|
|
char source[] = "hello world";
|
|
memset(msg, 0, 256);
|
|
auto len = std::snprintf(msg + offset, 256 - offset, "%s", source);
|
|
if (!encrypt(ik, (uint8_t*)msg, len + offset)) {
|
|
return false;
|
|
}
|
|
|
|
uint8_t ik2[32]{};
|
|
hash(key.c_str(), ik2);
|
|
if (!decrypt(ik2, (uint8_t*)msg, len + offset)) {
|
|
return false;
|
|
}
|
|
return std::memcmp(source, msg + offset, len) == 0;
|
|
}
|
|
|
|
// Encryption test suite: a throughput benchmark plus a small round-trip check.
TEST_CASE("transm encry part", "[encry]")
{
    SECTION("speed of encryption")
    {
        SpeedRet ret{};
        const bool ok = test_speed(ret);
        // Reported alongside the next assertion (visible on failure or with -s).
        UNSCOPED_INFO("Encryption mode: " << ret.mode);
        UNSCOPED_INFO("FileSize: " << ret.file_size << " MB");
        UNSCOPED_INFO("Encryption speed: " << ret.encry_speed << " MB/s");
        UNSCOPED_INFO("Decryption speed: " << ret.decry_speed << " MB/s");
        UNSCOPED_INFO("Verification: " << (ret.verify ? "passed" : "failed"));
        REQUIRE(ok);
    }

    SECTION("correctness of encryption")
    {
        // correctness_test() was previously defined but never run.
        REQUIRE(correctness_test());
    }
}