// Source listing: 399 lines, 9.9 KiB, C++
#include "util.h"

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <thread>

#include <aes.hpp>
#include <of_util.h>
// Nonce length used by encrypt()/decrypt() (one 16-byte AES block).
// serialize()/deserialize() reserve kz + 1 leading bytes: a mark byte plus the nonce.
constexpr uint8_t kz = 16;

// Whether serialize() encrypts its output; changed through set_encrypt().
static bool use_encrypt = true;

CTransProtocal::CTransProtocal() = default;
CTransProtocal::~CTransProtocal() = default;
|
|
|
|
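/*
    【 transm TCP 数据协议 】 (transm TCP frame format)

    Integers are copied with memcpy, i.e. they are in host byte order.

    offset   size   field
    0        2      header: 0xFF 0xFE
    2        2      type   (FrameType)
    4        1      mark
    5        32     from   (sender id, NUL-padded)
    37       32     to     (receiver id, NUL-padded)
    69       4      len    (int32, payload length in bytes)
    73       len    data
    73+len   2      tail: 0xFF 0xFF

    Total frame size = len + 75 bytes.
*/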
|
|
// Extracts the first complete frame from `buffer`.
//
// Returns a newly allocated CFrameBuffer (the caller owns it) and removes the
// consumed bytes, including any garbage preceding the header. Returns nullptr when
// no header is present, when the frame is not fully received yet, or when the
// frame is malformed (negative length, wrong tail).
//
// Frame layout: header(2) type(2) mark(1) from(32) to(32) len(4) data(len) tail(2).
CFrameBuffer* CTransProtocal::parse(CMutBuffer& buffer)
{
    constexpr int kIdLen = 32;
    constexpr int kTypeOff = 2;
    constexpr int kMarkOff = kTypeOff + 2;
    constexpr int kFromOff = kMarkOff + 1;
    constexpr int kToOff = kFromOff + kIdLen;
    constexpr int kLenOff = kToOff + kIdLen;
    constexpr int kDataOff = kLenOff + 4;  // 73 fixed bytes precede the payload
    constexpr int kTailLen = 2;

    const unsigned char header[] = {0xFF, 0xFE};
    const unsigned char tail[] = {0xFF, 0xFF};

    // If more than MAX_FRAME_SIZE bytes are buffered without a parsable frame,
    // treat the peer as an invalid client and drop everything.
    if (static_cast<size_t>(buffer.get_len()) > MAX_FRAME_SIZE) {
        buffer.clear();
        // Deliberate delay (original design).
        // NOTE(review): this blocks the calling thread for 10 minutes; confirm the caller's I/O loop tolerates that.
        std::this_thread::sleep_for(std::chrono::seconds(600));
        return nullptr;
    }

    const int find = buffer.index_of(reinterpret_cast<const char*>(header), sizeof(header));
    if (find < 0) {
        return nullptr;
    }

    const char* const base = buffer.get_data();
    const int64_t buf_len = static_cast<int64_t>(buffer.get_len());

    // The fixed-size part must be fully present before any field is read.
    if (buf_len - find < kDataOff) {
        return nullptr;
    }

    // Copy via memcpy instead of casting pointers: the fields are not aligned.
    int16_t type{};
    char mark{};
    int32_t len{};
    std::memcpy(&type, base + find + kTypeOff, sizeof(type));
    std::memcpy(&mark, base + find + kMarkOff, sizeof(mark));
    std::memcpy(&len, base + find + kLenOff, sizeof(len));

    // Reject a negative length before using it in arithmetic. The tail offset is
    // computed in 64 bits so a huge len cannot overflow it.
    if (len < 0) {
        return nullptr;
    }
    const int64_t tail_index = static_cast<int64_t>(find) + kDataOff + len;
    if (tail_index + kTailLen > buf_len) {
        return nullptr;  // frame not completely received yet
    }
    if (std::memcmp(base + tail_index, tail, sizeof(tail)) != 0) {
        return nullptr;
    }

    // Reads a fixed-width, NUL-padded id field without running past its width
    // (a full-width id has no terminating NUL).
    auto read_id = [](const char* field, size_t width) -> std::string {
        const auto* nul = static_cast<const char*>(std::memchr(field, '\0', width));
        return nul != nullptr ? std::string(field, nul) : std::string(field, width);
    };

    auto* result = new CFrameBuffer();
    result->len_ = len;
    result->fid_ = read_id(base + find + kFromOff, kIdLen);
    result->tid_ = read_id(base + find + kToOff, kIdLen);
    result->mark_ = mark;
    result->type_ = static_cast<FrameType>(type);
    if (len > 0) {
        result->data_ = new char[len];
        std::memcpy(result->data_, base + find + kDataOff, static_cast<size_t>(len));
    }

    // Drop the consumed frame together with any bytes that preceded its header.
    buffer.remove_of(0, static_cast<int>(tail_index + kTailLen));
    return result;
}
|
|
|
|
// Serializes `buf` into a newly allocated frame.
//
// On success *out_buf receives a new[]-allocated buffer (the caller must delete[]
// it), `len` receives its size (buf->len_ + 75), and true is returned.
// Returns false for a null argument or a payload length that cannot be framed
// (negative, or so large that len_ + 75 would overflow int).
// Note: if buf->data_ is null, buf->len_ is reset to 0 (the input is modified).
//
// Frame layout: header(2) type(2) mark(1) from(32) to(32) len(4) data(len) tail(2).
bool CTransProtocal::pack(CFrameBuffer* buf, char** out_buf, int& len)
{
    constexpr std::size_t kIdLen = 32;
    constexpr int kTypeOff = 2;
    constexpr int kMarkOff = 4;
    constexpr int kFromOff = 5;
    constexpr int kToOff = 37;
    constexpr int kLenOff = 69;
    constexpr int kDataOff = 73;
    constexpr int kOverhead = kDataOff + 2;  // 75 bytes of framing around the payload

    if (buf == nullptr || out_buf == nullptr) {
        return false;
    }
    if (buf->data_ == nullptr) {
        buf->len_ = 0;
    }
    if (buf->len_ < 0 || buf->len_ > std::numeric_limits<int>::max() - kOverhead) {
        return false;
    }

    const unsigned char header[] = {0xFF, 0xFE};
    const unsigned char tail[] = {0xFF, 0xFF};

    len = buf->len_ + kOverhead;
    // Value-initialized, so the id fields are NUL-padded.
    *out_buf = new char[len]{};
    char* const frame = *out_buf;

    std::memcpy(frame, header, 2);
    std::memcpy(frame + kTypeOff, &buf->type_, 2);
    std::memcpy(frame + kMarkOff, &buf->mark_, 1);
    // Ids longer than their 32-byte field are truncated instead of overflowing
    // into the following fields (or past the end of the buffer).
    std::memcpy(frame + kFromOff, buf->fid_.data(), std::min<std::size_t>(buf->fid_.size(), kIdLen));
    std::memcpy(frame + kToOff, buf->tid_.data(), std::min<std::size_t>(buf->tid_.size(), kIdLen));
    std::memcpy(frame + kLenOff, &buf->len_, 4);
    if (buf->data_ != nullptr && buf->len_ > 0) {
        std::memcpy(frame + kDataOff, buf->data_, static_cast<std::size_t>(buf->len_));
    }
    std::memcpy(frame + len - 2, tail, 2);
    return true;
}
|
|
|
|
// Prints a 38-column text progress bar, e.g. "[=====>      ] 13 %", followed by
// '\r' so the next call overwrites the same console line.
// Values outside [0, 1] are ignored.
void CTransProtocal::display_progress(float percent)
{
    if (percent < 0.0 || percent > 1.0) {
        return;
    }

    constexpr int kBarWidth = 38;
    const int head = static_cast<int>(kBarWidth * percent);

    std::cout << "[";
    for (int col = 0; col < kBarWidth; ++col) {
        const char glyph = (col < head) ? '=' : (col == head ? '>' : ' ');
        std::cout << glyph;
    }
    // '\r' moves the cursor back to the start of the line.
    std::cout << "] " << static_cast<int>(percent * 100.0f) << " %\r";
    std::cout.flush();
}
|
|
|
|
CFrameBuffer::CFrameBuffer() = default;

// Releases the payload. data_ is freed with delete[], so it must either be
// nullptr or come from new[] (as parse() allocates it).
// NOTE(review): unless copy operations are deleted in the header, copying a
// CFrameBuffer would lead to a double delete[] here — confirm.
CFrameBuffer::~CFrameBuffer()
{
    len_ = 0;
    delete[] data_;
}
|
|
|
|
// Serializes msg_info into *out_buf and stores the total byte count in `len`.
//
// Output layout: [mark:1][nonce:kz][id][uuid][str][data]. Each field is written
// as an int size (host byte order) followed by its bytes. When encryption is
// enabled, bytes [1, len) are encrypted in place (encrypt() writes the nonce
// into [1, kz+1)) and mark = 0x01; otherwise mark = 0x00.
//
// Side effect: msg_info.id, .uuid and .str are converted in place with localtou8().
//
// Memory contract (chosen for efficiency):
//  - *out_buf == nullptr: a new buffer is allocated (the caller must delete[] it).
//  - *out_buf != nullptr and !reuse_mem: the old buffer is delete[]d and a new one allocated.
//  - *out_buf != nullptr and reuse_mem: the existing buffer is reused as is; the
//    caller guarantees it holds at least `len` bytes. This suits frequent calls
//    whose size never exceeds the allocation. For infrequent calls or variable
//    sizes, free the buffer each time and reset it to nullptr.
void serialize(CMessageInfo& msg_info, char** out_buf, int& len, bool reuse_mem)
{
    msg_info.id = localtou8(msg_info.id);
    msg_info.uuid = localtou8(msg_info.uuid);
    msg_info.str = localtou8(msg_info.str);

    const int prefix = kz + 1;
    len = static_cast<int>(sizeof(int) * 4 + msg_info.id.size() + msg_info.uuid.size() + msg_info.str.size() +
                           msg_info.data.size() + prefix);

    const bool need_fresh_buffer = (*out_buf == nullptr) || !reuse_mem;
    if (need_fresh_buffer) {
        delete[] *out_buf;  // no-op when *out_buf is nullptr
        *out_buf = new char[len];
    }

    char* const base = *out_buf;
    std::memset(base, 0x0, prefix);

    char* cursor = base + prefix;
    // Appends one [int size][bytes] field at cursor.
    auto put_field = [&cursor](const void* bytes, int count) {
        std::memcpy(cursor, &count, sizeof(int));
        cursor += sizeof(int);
        std::memcpy(cursor, bytes, count);
        cursor += count;
    };
    put_field(msg_info.id.data(), static_cast<int>(msg_info.id.size()));
    put_field(msg_info.uuid.data(), static_cast<int>(msg_info.uuid.size()));
    put_field(msg_info.str.data(), static_cast<int>(msg_info.str.size()));
    put_field(msg_info.data.data(), static_cast<int>(msg_info.data.size()));

    if (!use_encrypt) {
        base[0] = 0x00;
        return;
    }

    // The key is derived from the (already UTF-8) id.
    uint8_t key[32]{};
    hash(msg_info.id.c_str(), key);
    encrypt(key, reinterpret_cast<uint8_t*>(base + 1), len - 1);
    base[0] = 0x01;
}
|
|
|
|
bool deserialize(char* data, int len, CMessageInfo& msg_info)
|
|
{
|
|
if (len < (kz + 1)) {
|
|
return false;
|
|
}
|
|
|
|
auto& info = msg_info;
|
|
char* ptr = data + kz + 1;
|
|
uint8_t mark = data[0];
|
|
int remaining = len;
|
|
|
|
if (mark != 0x00) {
|
|
uint8_t ik[32]{};
|
|
hash(msg_info.id.c_str(), ik);
|
|
if (!decrypt(ik, (uint8_t*)(data + 1), len - 1)) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// 反序列化 cmd
|
|
if (remaining < static_cast<int>(sizeof(int))) {
|
|
return false;
|
|
}
|
|
|
|
int id_size{};
|
|
memcpy(&id_size, ptr, sizeof(int));
|
|
ptr += sizeof(int);
|
|
remaining -= sizeof(int);
|
|
if (remaining < id_size) {
|
|
return false;
|
|
}
|
|
|
|
info.id.assign(ptr, id_size);
|
|
ptr += id_size;
|
|
remaining -= id_size;
|
|
|
|
// 反序列化 uuid
|
|
if (remaining < static_cast<int>(sizeof(int))) {
|
|
return false;
|
|
}
|
|
|
|
int uuid_size{};
|
|
memcpy(&uuid_size, ptr, sizeof(int));
|
|
ptr += sizeof(int);
|
|
remaining -= sizeof(int);
|
|
if (remaining < uuid_size) {
|
|
return false;
|
|
}
|
|
|
|
info.uuid.assign(ptr, uuid_size);
|
|
ptr += uuid_size;
|
|
remaining -= uuid_size;
|
|
|
|
// 反序列化 str
|
|
if (remaining < static_cast<int>(sizeof(int))) {
|
|
return false;
|
|
}
|
|
|
|
int str_size{};
|
|
memcpy(&str_size, ptr, sizeof(int));
|
|
ptr += sizeof(int);
|
|
remaining -= sizeof(int);
|
|
if (remaining < str_size) {
|
|
return false;
|
|
}
|
|
|
|
info.str.assign(ptr, str_size);
|
|
ptr += str_size;
|
|
remaining -= str_size;
|
|
|
|
// 反序列化 o
|
|
if (remaining < static_cast<int>(sizeof(int))) {
|
|
return false;
|
|
}
|
|
int o_size{};
|
|
memcpy(&o_size, ptr, sizeof(int));
|
|
ptr += sizeof(int);
|
|
remaining -= sizeof(int);
|
|
if (remaining < o_size) {
|
|
return false;
|
|
}
|
|
info.data.assign(ptr, o_size);
|
|
|
|
info.id = u8tolocal(info.id);
|
|
info.uuid = u8tolocal(info.uuid);
|
|
info.str = u8tolocal(info.str);
|
|
return true;
|
|
}
|
|
|
|
// Converts a UTF-8 string to the local encoding.
// On Windows this delegates to CCodec::u8_to_ansi. Elsewhere the input is returned unchanged.
std::string u8tolocal(const std::string& str)
{
#if defined(_WIN32)
    return CCodec::u8_to_ansi(str);
#else
    const std::string& unchanged = str;
    return unchanged;
#endif
}
|
|
|
|
// Converts a string in the local encoding to UTF-8.
// On Windows this delegates to CCodec::ansi_to_u8. Elsewhere the input is returned unchanged.
std::string localtou8(const std::string& str)
{
#if defined(_WIN32)
    return CCodec::ansi_to_u8(str);
#else
    const std::string& unchanged = str;
    return unchanged;
#endif
}
|
|
|
|
// Derives a 32-byte key from a NUL-terminated string.
// The string is hashed with DJB2 (h = h * 33 + c, seed 5381). The 4 bytes of the
// 32-bit digest, least significant first, are then repeated 8 times.
// NOTE(review): the resulting key carries only 32 bits of entropy.
void hash(const char* data, uint8_t k[32])
{
    // uint32_t arithmetic wraps modulo 2^32. Each char is converted with its
    // platform signedness, exactly as in the (h << 5) + h + c form.
    uint32_t digest = 5381;
    for (const char* cursor = data; *cursor != '\0'; ++cursor) {
        digest = digest * 33u + static_cast<uint32_t>(*cursor);
    }

    for (int word = 0; word < 8; ++word) {
        for (int byte = 0; byte < 4; ++byte) {
            k[word * 4 + byte] = static_cast<uint8_t>((digest >> (byte * 8)) & 0xFFu);
        }
    }
}
|
|
|
|
// Fills o[0, size) with random bytes drawn directly from std::random_device.
//
// encrypt() uses this to build the AES-CTR nonce, and a CTR nonce must never
// repeat under the same key. The previous version seeded std::mt19937 with one
// 32-bit random_device value. That capped the nonce entropy at 32 bits, so
// collisions became likely after roughly 2^16 messages. Each random_device call
// now supplies 4 output bytes.
// NOTE(review): some old MinGW libstdc++ builds implement random_device
// deterministically; verify on the target toolchains.
void rdm(uint8_t* o, size_t size)
{
    std::random_device rd;
    size_t i = 0;
    while (i < size) {
        const auto word = static_cast<uint32_t>(rd());
        for (unsigned shift = 0; shift < 32 && i < size; shift += 8) {
            o[i++] = static_cast<uint8_t>(word >> shift);
        }
    }
}
|
|
|
|
bool encrypt(const uint8_t* k, uint8_t* m, size_t len)
|
|
{
|
|
if (len < kz) {
|
|
return false;
|
|
}
|
|
|
|
uint8_t nonce[kz]{};
|
|
rdm(nonce, sizeof(nonce) - 4);
|
|
memcpy(m, nonce, kz);
|
|
|
|
struct AES_ctx ctx;
|
|
AES_init_ctx_iv(&ctx, k, nonce);
|
|
AES_CTR_xcrypt_buffer(&ctx, m + kz, len - kz);
|
|
return true;
|
|
}
|
|
|
|
bool decrypt(const uint8_t* k, uint8_t* m, size_t len)
|
|
{
|
|
if (len < kz) {
|
|
return false;
|
|
}
|
|
|
|
uint8_t nonce[kz]{};
|
|
memcpy(nonce, m, kz);
|
|
|
|
struct AES_ctx ctx;
|
|
AES_init_ctx_iv(&ctx, k, nonce);
|
|
AES_CTR_xcrypt_buffer(&ctx, m + kz, len - kz);
|
|
return true;
|
|
}
|
|
|
|
void set_encrypt(bool encrypt)
|
|
{
|
|
use_encrypt = encrypt;
|
|
}
|
|
|
|
bool get_encrypt_status()
|
|
{
|
|
return use_encrypt;
|
|
}
|
|
|
|
CMessageInfo::CMessageInfo(const std::string& id) : id(id)
|
|
{
|
|
}
|
|
|
|
// Copy constructor: copies id, uuid, str and data.
// The members are copy-initialized directly instead of being default-constructed
// and then assigned. The old `&info == this` test is dropped: an object cannot be
// copy-constructed from itself.
CMessageInfo::CMessageInfo(const CMessageInfo& info)
    : id(info.id), uuid(info.uuid), str(info.str), data(info.data.begin(), info.data.end())
{
}
|
|
|
|
// Copy assignment: copies id, uuid, str and data. Self-assignment is a no-op.
CMessageInfo& CMessageInfo::operator=(const CMessageInfo& info)
{
    if (this != &info) {
        id = info.id;
        uuid = info.uuid;
        str = info.str;
        data.assign(info.data.begin(), info.data.end());
    }
    return *this;
}
|