#include "box_rsa.h"

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <limits>
#include <memory>
#include <string>

#include "openssl/err.h"
#include "openssl/evp.h"
#include "openssl/pem.h"
#include <openssl/rsa.h>

// Shared size constant for this file (length of the error-message buffers).
constexpr size_t g_buffsize{2048};

namespace cppbox {

class CRSAOperatorImp
{
public:
    CRSAOperatorImp()
    {
        err_ = new char[g_buffsize + 1];
        ioerr_ = BIO_new(BIO_s_mem());
    }
    ~CRSAOperatorImp()
    {
        BIO_free(ioerr_);
        delete[] err_;
    }

public:
    bool encrypt_pub(const HData& public_pem, const HData& data, HData& result)
    {
        mem_ = BIO_new_mem_buf((void*)public_pem.data, public_pem.len);
        key_ = PEM_read_bio_PUBKEY(mem_, nullptr, nullptr, nullptr);
        if (key_ == nullptr) {
            ERR_print_errors(ioerr_);
            clear();
            return false;
        }
        ctx_ = EVP_PKEY_CTX_new(key_, nullptr);
        EVP_PKEY_encrypt_init(ctx_);
        if (EVP_PKEY_encrypt(ctx_, nullptr, &result.len, data.data, data.len) <= 0) {
            clear();
            ERR_print_errors(ioerr_);
            return false;
        }
        alloc_data(result);
        if (EVP_PKEY_encrypt(ctx_, result.data, &result.len, data.data, data.len) <= 0) {
            free_data(result);
            clear();
            ERR_print_errors(ioerr_);
            return false;
        }
        clear();
        return true;
    }

    bool decrypt_pri(const HData& private_pem, const HData& data, HData& result)
    {
        mem_ = BIO_new_mem_buf((void*)private_pem.data, private_pem.len);
        key_ = PEM_read_bio_PrivateKey(mem_, nullptr, nullptr, nullptr);
        if (key_ == nullptr) {
            ERR_print_errors(ioerr_);
            clear();
            return false;
        }
        ctx_ = EVP_PKEY_CTX_new(key_, nullptr);
        EVP_PKEY_decrypt_init(ctx_);
        if (EVP_PKEY_decrypt(ctx_, nullptr, &result.len, data.data, data.len) <= 0) {
            clear();
            ERR_print_errors(ioerr_);
            return false;
        }
        alloc_data(result);
        if (EVP_PKEY_decrypt(ctx_, result.data, &result.len, data.data, data.len) <= 0) {
            free_data(result);
            clear();
            ERR_print_errors(ioerr_);
            return false;
        }
        result.data[result.len] = '\0';
        clear();
        return true;
    }

    bool encrypt_pub(const char* pub_path, const HData& data, HData& result)
    {
        FILE* fp = fopen(pub_path, "r");
        if (fp == nullptr) {
            std::snprintf(err_, g_buffsize, "Read File %s Failed.", pub_path);
            return false;
        }

        HData file_data{};
        fseek(fp, 0, SEEK_END);
        file_data.len = ftell(fp);
        fseek(fp, 0, SEEK_SET);

        alloc_data(file_data);
        if (file_data.data == NULL) {
            fclose(fp);
            std::snprintf(err_, g_buffsize, "Alloc Mem Failed: %zd", file_data.len);
            return false;
        }
        file_data.len = fread(file_data.data, 1, file_data.len, fp);
        fclose(fp);
        bool ret = encrypt_pub(file_data, data, result);
        free(file_data.data);
        return ret;
    }

    bool decrypt_pri(const char* pri_path, const HData& data, HData& result)
    {
        FILE* fp = fopen(pri_path, "r");
        if (fp == nullptr) {
            std::snprintf(err_, g_buffsize, "Read File %s Failed.", pri_path);
            return false;
        }

        HData file_data{};
        fseek(fp, 0, SEEK_END);
        file_data.len = ftell(fp);
        fseek(fp, 0, SEEK_SET);

        alloc_data(file_data);
        if (file_data.data == NULL) {
            fclose(fp);
            std::snprintf(err_, g_buffsize, "Alloc Mem Failed: %zd", file_data.len);
            return false;
        }
        file_data.len = fread(file_data.data, 1, file_data.len, fp);
        fclose(fp);
        bool ret = decrypt_pri(file_data, data, result);
        free(file_data.data);
        return ret;
    }

    bool generate_keypair(const char* pub_path, const char* pri_path)
    {
        FILE* fp_pub = fopen(pub_path, "w");
        if (fp_pub == nullptr) {
            std::snprintf(err_, g_buffsize, "Open File Failed: %s", pub_path);
            return false;
        }
        FILE* fp_pri = fopen(pri_path, "w");
        if (fp_pri == nullptr) {
            std::snprintf(err_, g_buffsize, "Open File Failed: %s", pri_path);
            fclose(fp_pub);
            return false;
        }
        key_ = EVP_RSA_gen(g_buffsize / 2);
        if (key_ == nullptr) {
            ERR_print_errors(ioerr_);
            return false;
        }
        PEM_write_PUBKEY(fp_pub, key_);
        PEM_write_PrivateKey(fp_pri, key_, nullptr, nullptr, 0, nullptr, nullptr);
        clear();

        fclose(fp_pub);
        fclose(fp_pri);

        return true;
    }

    bool generate_keypair(HData& pub, HData& pri)
    {
        key_ = EVP_RSA_gen(g_buffsize / 2);
        if (key_ == nullptr) {
            ERR_print_errors_fp(stderr);
            return false;
        }
        BIO* mem_pub = BIO_new(BIO_s_mem());
        if (mem_pub == nullptr) {
            std::snprintf(err_, g_buffsize, "Alloc public BIO Failed.");
            return false;
        }
        BIO* mem_pri = BIO_new(BIO_s_mem());
        if (mem_pri == nullptr) {
            BIO_free(mem_pub);
            std::snprintf(err_, g_buffsize, "Alloc private BIO Failed.");
            return false;
        }
        PEM_write_bio_PUBKEY(mem_pub, key_);
        PEM_write_bio_PrivateKey(mem_pri, key_, nullptr, nullptr, 0, nullptr, nullptr);

        pub.len = BIO_ctrl_pending(mem_pub);
        pri.len = BIO_ctrl_pending(mem_pri);

        alloc_data(pub);
        alloc_data(pri);

        pub.len = BIO_read(mem_pub, pub.data, pub.len);
        pri.len = BIO_read(mem_pri, pri.data, pri.len);

        clear();
        return true;
    }

    void free_data(HData& data)
    {
        free(data.data);
        data.len = 0;
    }
    void alloc_data(HData& data)
    {
        if (data.len < 1) {
            data.data = nullptr;
            return;
        }
        data.data = static_cast<unsigned char*>(malloc(data.len + 1));
    }

    void get_last_error(char* buf, int len)
    {
        int read_len = BIO_read(ioerr_, err_, g_buffsize);
        std::snprintf(buf, len, err_, read_len);
        buf[read_len] = '\0';
    }

private:
    void clear()
    {
        BIO_free(mem_);
        EVP_PKEY_free(key_);
        EVP_PKEY_CTX_free(ctx_);

        mem_ = nullptr;
        ctx_ = nullptr;
        key_ = nullptr;
    }

private:
    EVP_PKEY*     key_{};
    EVP_PKEY_CTX* ctx_{};
    BIO*          mem_{};
    BIO*          ioerr_{};
    char*         err_{};
};

CRSAOperator::CRSAOperator()
{
    // Hold the implementation in a unique_ptr until every allocation has
    // succeeded: if `new char[]` throws, this object's destructor never runs,
    // so a bare `new CRSAOperatorImp()` would leak.
    std::unique_ptr<CRSAOperatorImp> imp(new CRSAOperatorImp());
    // Value-initialised so the buffer is a valid empty C string from the start.
    err_ = new char[g_buffsize]();
    imp_ = imp.release();
}
CRSAOperator::~CRSAOperator()
{
    // Both members were allocated in the constructor; release them in reverse order.
    delete[] err_;
    delete imp_;
}

bool CRSAOperator::encrypt_pub(const HData& public_pem, const HData& data, HData& result)
{
    assert(imp_);
    return imp_->encrypt_pub(public_pem, data, result);
}
bool CRSAOperator::encrypt_pub(const char* pub_path, const HData& data, HData& result)
{
    assert(imp_);
    return imp_->encrypt_pub(pub_path, data, result);
}
bool CRSAOperator::decrypt_pri(const HData& private_pem, const HData& data, HData& result)
{
    assert(imp_);
    return imp_->decrypt_pri(private_pem, data, result);
}
bool CRSAOperator::decrypt_pri(const char* pri_path, const HData& data, HData& result)
{
    assert(imp_);
    return imp_->decrypt_pri(pri_path, data, result);
}
bool CRSAOperator::generate_keypair(const char* pub_path, const char* pri_path)
{
    assert(imp_);
    return imp_->generate_keypair(pub_path, pri_path);
}
bool CRSAOperator::generate_keypair(HData& pub, HData& pri)
{
    // Forwards to the implementation, which fills pub/pri with PEM text.
    assert(imp_);
    CRSAOperatorImp& impl = *imp_;
    return impl.generate_keypair(pub, pri);
}
const char* CRSAOperator::get_last_error() const
{
    assert(imp_);
    imp_->get_last_error(err_, g_buffsize);
    return err_;
}

void CRSAOperator::free_hdata(HData& data)
{
    // Buffers handed out by this class must be released by the implementation.
    assert(imp_);
    CRSAOperatorImp& impl = *imp_;
    impl.free_data(data);
}

}   // namespace cppbox