#include "RelayServer.h"

#include <exception>
#include <memory>
#include <sstream>
#include <type_traits>

#include <Communicate.h>
#include <InfoCommunicate.hpp>

// Default construction does no work; the listening socket is created later by Init().
RemoteServer::RemoteServer() = default;

bool RemoteServer::Init(const wxString& ip, unsigned short port)
{
    thRun_= true;
    wxIPV4address addr;

    if (!addr.Hostname(ip)) {
        wxLogError(wxT("Invalid IP address: %s"), ip);
        return false;
    }

    addr.Service(port);
    server_ = std::make_unique<wxSocketServer>(addr);
    if (!server_->IsOk()) {
        wxLogError(wxT("Failed to create server socket."));
        return false;
    }
    if (!server_->GetLocal(addr)) {
        wxLogError(wxT("Failed to get local address."));
        return false;
    }
    wxLogMessage(wxT("Server socket created on %s:%d"), addr.IPAddress(), addr.Service());
    //wxLogInfo(wxT("Server socket created on %s:%d"), addr.IPAddress(), addr.Service());

    serverId_ = wxNewId();
    server_->SetFlags(wxSOCKET_WAITALL);
    server_->SetEventHandler(*this, serverId_);

    server_->SetNotify(wxSOCKET_CONNECTION_FLAG | wxSOCKET_LOST_FLAG);
    server_->Notify(true);
    Bind(wxEVT_SOCKET, &RemoteServer::OnServerEvent, this, serverId_);

    return true;
}

int RemoteServer::Run()
{
    wxEventLoop loop;
    return loop.Run();
}

/**
 * @brief Handles socket events bound to serverId_ in Init().
 *
 * - wxSOCKET_CONNECTION: accepts the pending client, stores a TranClient
 *   under the key "ip:port" in clients_ and starts thClientThread for it.
 * - wxSOCKET_LOST: erases the "ip:port" entry from clients_ and detaches and
 *   erases its thread in threads_.
 *
 * @param event The socket event delivered by wxWidgets.
 */
void RemoteServer::OnServerEvent(wxSocketEvent& event)
{
    switch (event.GetSocketEvent()) {
    case wxSOCKET_CONNECTION: {
        auto newer = std::shared_ptr<wxSocketBase>(server_->Accept(false));
        if (!newer) {
            wxLogError(wxT("Failed to accept client connection."));
            return;
        }
        // wxSocketServer::Accept() gives the new socket the server's flags
        // (wxSOCKET_WAITALL, set in Init). With WAITALL every
        // Read(gBufferSize) in thClientThread would wait for a full buffer,
        // stalling frames smaller than that. The worker reassembles frames
        // from a byte stream itself, so let Read() return whatever arrived;
        // wxSOCKET_BLOCK keeps the socket from yielding to the GUI event
        // loop, since it is read from a worker thread.
        newer->SetFlags(wxSOCKET_BLOCK);

        wxIPV4address addr;
        newer->GetPeer(addr);
        const wxString id = wxString::Format("%s:%d", addr.IPAddress(), addr.Service());
        wxLogMessage(wxT("Client connected: %s"), id);

        // Build the record before taking the lock to keep the critical section short.
        auto client = std::make_shared<TranClient>();
        client->wxSock = newer;
        client->onlineTime =
            std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
        client->lastRecvTime = std::chrono::high_resolution_clock::now();

        std::unique_lock<std::shared_mutex> lock(clientsMutex_);
        clients_[id] = client;
        // Move-assigning into a std::thread that is still joinable calls
        // std::terminate. An earlier entry for the same "ip:port" may still be
        // here (ports get reused), so release it first, as the
        // wxSOCKET_LOST branch does.
        auto& slot = threads_[id];
        if (slot.joinable()) {
            slot.detach();
        }
        slot = std::thread(&RemoteServer::thClientThread, this, newer, id);
        break;
    }
    case wxSOCKET_LOST: {
        // NOTE(review): only the listening socket gets an event handler in
        // this file (Init); accepted client sockets never call
        // SetEventHandler/Notify, so this branch presumably does not fire for
        // ordinary client disconnects — confirm.
        auto* sock = event.GetSocket();
        wxIPV4address addr;
        sock->GetPeer(addr);
        const wxString id = wxString::Format("%s:%d", addr.IPAddress(), addr.Service());
        wxLogMessage(wxT("Client disconnected: %s"), id);

        std::unique_lock<std::shared_mutex> lock(clientsMutex_);
        clients_.erase(id);  // erasing a missing key is a no-op
        if (const auto it = threads_.find(id); it != threads_.end()) {
            // detach() throws std::system_error on a non-joinable thread.
            if (it->second.joinable()) {
                it->second.detach();
            }
            threads_.erase(it);
        }
        break;
    }
    default:
        break;
    }
}

void RemoteServer::thClientThread(const std::shared_ptr<wxSocketBase>& wxSock, const wxString& id)
{
    wxLogMessage(wxT("Client thread started: %s"), id);
    std::shared_ptr<TranClient> client = nullptr;

    {
        std::shared_lock<std::shared_mutex> lock(clientsMutex_);
        client = clients_[id];
    }

    InfoCommunicate info;
    while (thRun_) {
        wxSock->Read(client->buf.data(), gBufferSize);
        auto br = wxSock->LastCount();
        if (br == 0) {
            wxLogMessage(wxT("Client disconnected: %s"), id);
            break;
        } else if (wxSock->Error()) {
            wxLogMessage(wxT("%s Client error: %s"), id, wxSock->LastError());
            break;
        }
        client->buffer.Push(client->buf.data(), br);
        while (true) {
            auto* frame = Communicate::ParseBuffer(client->buffer);
            if (!frame) {
                break;
            }
            std::stringstream ss;
            ss.write(frame->data, frame->len);
            cereal::BinaryInputArchive inputArchive(ss);
            inputArchive(info);
            delete frame;
        }
    }
}