#include "util.h"

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <thread>

#include <of_util.h>

// Constructor and destructor are defaulted: no setup or teardown is
// performed for CTransProtocal in this file.
CTransProtocal::CTransProtocal() = default;

CTransProtocal::~CTransProtocal() = default;

/*
 * [ transm TCP frame protocol ]
 *
 * Multi-byte integers are copied with memcpy, i.e. in the host's native byte
 * order, so both peers are assumed to share the same endianness.
 *
 *   field   bytes   notes
 *   header     2    0xFF 0xFE
 *   type       2    int16, cast to FrameType
 *   mark       1
 *   from      32    sender id, NUL-padded (may fill all 32 bytes, no NUL)
 *   to        32    receiver id, NUL-padded (may fill all 32 bytes, no NUL)
 *   len        4    int32 payload length, must be >= 0
 *   data     len    payload
 *   tail       2    0xFF 0xFF
 *
 * Fixed overhead: 2 + 2 + 1 + 32 + 32 + 4 + 2 = 75 bytes.
 */

/**
 * @brief Try to extract one complete frame from the front of @p buffer.
 *
 * Any bytes preceding the first header are discarded together with the frame
 * once a complete frame has been decoded.
 *
 * @param buffer Accumulated receive buffer; consumed bytes are removed from it.
 * @return A newly allocated frame (the caller owns it and must delete it), or
 *         nullptr when no complete, well-formed frame is available yet.
 *
 * If more than MAX_FRAME_SIZE bytes have accumulated (1 MB according to the
 * original comment), the buffer is cleared and the calling thread is
 * deliberately put to sleep for 600 s, on the assumption that the peer is not
 * a valid client.
 */
CFrameBuffer* CTransProtocal::parse(CMutBuffer& buffer)
{
    // Field offsets relative to the first header byte.
    constexpr int kIdSize = 32;
    constexpr int kTypeOffset = 2;
    constexpr int kMarkOffset = kTypeOffset + 2;
    constexpr int kFromOffset = kMarkOffset + 1;
    constexpr int kToOffset = kFromOffset + kIdSize;
    constexpr int kLenOffset = kToOffset + kIdSize;
    constexpr int kDataOffset = kLenOffset + 4;  // 73: everything before the payload
    constexpr int kTailSize = 2;

    const unsigned char header[] = {0xFF, 0xFE};
    const unsigned char tail[] = {0xFF, 0xFF};

    // If this much data still cannot be parsed, treat the peer as an invalid
    // client.
    const auto cur_len = static_cast<size_t>(buffer.get_len());
    if (cur_len > MAX_FRAME_SIZE) {
        buffer.clear();
        // Intentional delay (throttles the bad peer). This blocks the calling thread.
        std::this_thread::sleep_for(std::chrono::seconds(600));
        return nullptr;
    }

    const int find = buffer.index_of(reinterpret_cast<const char*>(header), sizeof(header));
    if (find < 0) {
        return nullptr;
    }

    // Bytes available from the header onward. The whole fixed-size prefix must
    // be buffered before any field is read, otherwise the reads below would run
    // past the end of the buffer.
    const int64_t available = static_cast<int64_t>(buffer.get_len()) - find;
    if (available < kDataOffset) {
        return nullptr;
    }
    const char* frame = buffer.get_data() + find;

    // memcpy instead of reinterpret_cast: the fields are not guaranteed to be
    // aligned, and unaligned access faults on some architectures.
    int16_t type{};
    char mark{};
    int32_t len{};
    std::memcpy(&type, frame + kTypeOffset, sizeof(type));
    std::memcpy(&mark, frame + kMarkOffset, sizeof(mark));
    std::memcpy(&len, frame + kLenOffset, sizeof(len));

    if (len < 0) {
        return nullptr;
    }
    // 64-bit arithmetic so that a hostile len near INT32_MAX cannot overflow.
    const int64_t frame_size = static_cast<int64_t>(kDataOffset) + len + kTailSize;
    if (available < frame_size) {
        return nullptr;  // frame not fully received yet
    }
    if (std::memcmp(frame + kDataOffset + len, tail, sizeof(tail)) != 0) {
        return nullptr;
    }

    // Owned by a smart pointer until everything that can throw has succeeded.
    std::unique_ptr<CFrameBuffer> result(new CFrameBuffer());
    if (len > 0) {
        result->data_ = new char[len];
        std::memcpy(result->data_, frame + kDataOffset, len);
    }
    result->len_ = len;

    // Ids are NUL-padded but may occupy all 32 bytes: never read past the field.
    const char* from = frame + kFromOffset;
    const char* to = frame + kToOffset;
    result->fid_ = std::string(from, std::find(from, from + kIdSize, '\0'));
    result->tid_ = std::string(to, std::find(to, to + kIdSize, '\0'));
    result->mark_ = mark;
    result->type_ = static_cast<FrameType>(type);

    // Drop everything up to and including this frame's tail.
    buffer.remove_of(0, find + static_cast<int>(frame_size));
    return result.release();
}

/**
 * @brief Encode @p buf as one wire frame, using the layout described in the
 *        protocol comment above CTransProtocal::parse.
 *
 * @param buf     Frame to encode. Side effect: if buf->data_ is null,
 *                buf->len_ is reset to 0.
 * @param out_buf On success, receives a new[]-allocated buffer that the
 *                caller must release with delete[].
 * @param len     On success, receives the size of *out_buf (buf->len_ + 75).
 * @return true on success. Returns false if buf is null, or if buf->len_ is
 *         negative or so large that the frame size would overflow int. On
 *         failure, *out_buf and len are left untouched.
 *
 * The from/to ids occupy fixed 32-byte fields. Longer ids are truncated to
 * 32 bytes, so they can no longer spill into the following fields or past the
 * allocation. Shorter ids are NUL-padded.
 */
bool CTransProtocal::pack(CFrameBuffer* buf, char** out_buf, int& len)
{
    // Field offsets within the frame.
    constexpr int kIdSize = 32;
    constexpr int kTypeOffset = 2;
    constexpr int kMarkOffset = kTypeOffset + 2;
    constexpr int kFromOffset = kMarkOffset + 1;
    constexpr int kToOffset = kFromOffset + kIdSize;
    constexpr int kLenOffset = kToOffset + kIdSize;
    constexpr int kDataOffset = kLenOffset + 4;
    constexpr int kOverhead = kDataOffset + 2;  // 75 bytes: everything but the payload

    if (buf == nullptr) {
        return false;
    }
    if (buf->data_ == nullptr) {
        buf->len_ = 0;
    }
    if (buf->len_ < 0 || buf->len_ > std::numeric_limits<int>::max() - kOverhead) {
        return false;
    }

    const unsigned char header[] = {0xFF, 0xFE};
    const unsigned char tail[] = {0xFF, 0xFF};

    const int total = static_cast<int>(buf->len_) + kOverhead;
    // Value-initialized (all zero bytes), which provides the NUL padding of the ids.
    char* out = new char[total]();
    std::memcpy(out, header, sizeof(header));
    std::memcpy(out + kTypeOffset, &buf->type_, 2);
    std::memcpy(out + kMarkOffset, &buf->mark_, 1);
    std::memcpy(out + kFromOffset, buf->fid_.data(),
                std::min<std::size_t>(buf->fid_.size(), kIdSize));
    std::memcpy(out + kToOffset, buf->tid_.data(),
                std::min<std::size_t>(buf->tid_.size(), kIdSize));
    std::memcpy(out + kLenOffset, &buf->len_, 4);
    if (buf->data_ != nullptr && buf->len_ > 0) {
        std::memcpy(out + kDataOffset, buf->data_, static_cast<std::size_t>(buf->len_));
    }
    std::memcpy(out + total - 2, tail, sizeof(tail));

    *out_buf = out;
    len = total;
    return true;
}

/**
 * @brief Draw a one-line text progress bar on stdout, for example
 *        "[=========>                            ] 25 %".
 *
 * The line ends with '\r' instead of '\n', so each call overwrites the
 * previous bar.
 *
 * @param percent Completion fraction in [0, 1]. Values outside that range,
 *                including NaN, are ignored.
 */
void CTransProtocal::display_progress(float percent)
{
    // Written as a positive range test so that NaN (for which every comparison
    // is false) is rejected too; converting NaN to int would be undefined.
    if (!(percent >= 0.0f && percent <= 1.0f)) {
        return;
    }
    const int barWidth = 38;
    const int pos = static_cast<int>(barWidth * percent);

    // Cells before pos are done ('='), pos itself is the arrow head ('>').
    std::string bar;
    bar.reserve(barWidth);
    for (int i = 0; i < barWidth; ++i) {
        if (i < pos) {
            bar += '=';
        } else if (i == pos) {
            bar += '>';
        } else {
            bar += ' ';
        }
    }
    // '\r' returns the cursor to the start of the line.
    std::cout << "[" << bar << "] " << static_cast<int>(percent * 100.0f) << " %\r";
    std::cout.flush();
}

// Defaulted constructor: no extra initialization is performed here.
CFrameBuffer::CFrameBuffer() = default;

// Releases the payload buffer. data_ is expected to come from new[], as it
// does in CTransProtocal::parse.
// NOTE(review): data_ is a raw owning pointer. Unless the header deletes or
// defines the copy operations, copying a CFrameBuffer would double-delete
// data_. Confirm against the class declaration.
CFrameBuffer::~CFrameBuffer()
{
    len_ = 0;
    delete[] data_;
}

/**
 * @brief Serialize @p msg_info into a length-prefixed byte buffer.
 *
 * Each string is first converted with localtou8(). The fields cmd, uuid, str
 * and o are then written in that order. Each field is a native-endian int
 * byte count followed by that many bytes.
 *
 * @param msg_info Message to serialize.
 * @param out_buf  Receives a new[]-allocated buffer that the caller must
 *                 delete[]. Set to nullptr if the message is too large.
 * @param len      Receives the buffer size in bytes (0 if too large).
 *
 * If the encoded size would not fit in an int, nothing is allocated. This
 * guards against the size wrapping and the copies writing past the buffer.
 */
void serialize(const CMessageInfo& msg_info, char** out_buf, int& len)
{
    CMessageInfo info(msg_info);
    info.cmd = localtou8(info.cmd);
    info.uuid = localtou8(info.uuid);
    info.str = localtou8(info.str);
    info.o = localtou8(info.o);

    // Wire order of the fields; deserialize() reads them back in this order.
    const std::string* const fields[] = {&info.cmd, &info.uuid, &info.str, &info.o};

    // Sum in size_t so the total cannot overflow before it is range-checked.
    std::size_t total = 0;
    for (const std::string* field : fields) {
        total += sizeof(int) + field->size();
    }
    if (total > static_cast<std::size_t>(std::numeric_limits<int>::max())) {
        *out_buf = nullptr;
        len = 0;
        return;
    }

    len = static_cast<int>(total);
    *out_buf = new char[total];  // caller is responsible for delete[]

    char* ptr = *out_buf;
    for (const std::string* field : fields) {
        const int size = static_cast<int>(field->size());
        std::memcpy(ptr, &size, sizeof(size));
        ptr += sizeof(size);
        std::memcpy(ptr, field->data(), field->size());
        ptr += field->size();
    }
}

/**
 * @brief Parse a buffer produced by serialize() back into a CMessageInfo.
 *
 * Expects four fields (cmd, uuid, str, o). Each is a native-endian int byte
 * count followed by that many bytes. Every string is converted with
 * u8tolocal() after parsing.
 *
 * @param data     Input bytes (may come from the network; treated as untrusted).
 * @param len      Number of bytes available at @p data.
 * @param msg_info Overwritten only on success.
 * @return false if @p data is null, the input is truncated, or a length
 *         prefix is negative or exceeds the remaining bytes.
 */
bool deserialize(const char* data, int len, CMessageInfo& msg_info)
{
    if (data == nullptr) {
        return false;
    }

    const char* ptr = data;
    int remaining = len;

    // Reads one [int size][size bytes] field into out and advances ptr.
    // Negative sizes are rejected explicitly: they would pass the
    // "remaining < size" test and then be converted to a huge length.
    auto read_field = [&ptr, &remaining](std::string& out) -> bool {
        int size = 0;
        if (remaining < static_cast<int>(sizeof(size))) {
            return false;
        }
        std::memcpy(&size, ptr, sizeof(size));
        ptr += sizeof(size);
        remaining -= static_cast<int>(sizeof(size));
        if (size < 0 || remaining < size) {
            return false;
        }
        out.assign(ptr, static_cast<std::size_t>(size));
        ptr += size;
        remaining -= size;
        return true;
    };

    CMessageInfo info;
    if (!read_field(info.cmd) || !read_field(info.uuid) || !read_field(info.str) ||
        !read_field(info.o)) {
        return false;
    }

    info.cmd = u8tolocal(info.cmd);
    info.uuid = u8tolocal(info.uuid);
    info.str = u8tolocal(info.str);
    info.o = u8tolocal(info.o);
    msg_info = info;

    return true;
}

// Convert a UTF-8 string to the local narrow encoding.
// On Windows this goes through CCodec::u8_to_ansi. On other platforms the
// input is returned unchanged, presumably because the local encoding is
// already UTF-8 there.
std::string u8tolocal(const std::string& str)
{
#if !defined(_WIN32)
    return str;
#else
    return CCodec::u8_to_ansi(str);
#endif
}

// Convert a string in the local narrow encoding to UTF-8.
// On Windows this goes through CCodec::ansi_to_u8. On other platforms the
// input is returned unchanged, presumably because the local encoding is
// already UTF-8 there.
std::string localtou8(const std::string& str)
{
#if !defined(_WIN32)
    return str;
#else
    return CCodec::ansi_to_u8(str);
#endif
}

// Copy constructor: member-wise copy of cmd, uuid, str and o.
// Uses a member-initializer list, so each member is copy-constructed directly
// instead of being default-constructed and then assigned. A self-check is
// unnecessary: an object cannot be copy-constructed from itself.
// NOTE(review): user-declared copy operations suppress the implicit move
// operations (unless the header declares them), so moving a CMessageInfo
// copies. Consider the Rule of Zero in the header.
CMessageInfo::CMessageInfo(const CMessageInfo& info)
    : cmd(info.cmd), uuid(info.uuid), str(info.str), o(info.o)
{
}

// Copy assignment: member-wise copy of cmd, uuid, str and o.
// Self-assignment is detected and leaves the object untouched.
CMessageInfo& CMessageInfo::operator=(const CMessageInfo& info)
{
    if (this != &info) {
        cmd = info.cmd;
        uuid = info.uuid;
        str = info.str;
        o = info.o;
    }
    return *this;
}