241 lines
5.7 KiB
C++
241 lines
5.7 KiB
C++
#include "muduo/base/Logging.h"
|
|
#include "muduo/net/EventLoop.h"
|
|
#include "muduo/net/InetAddress.h"
|
|
#include "muduo/net/TcpClient.h"
|
|
#include "muduo/net/TcpServer.h"
|
|
|
|
#include <queue>
|
|
#include <utility>
|
|
|
|
#include <stdio.h>
|
|
#include <unistd.h>
|
|
|
|
using namespace muduo;
|
|
using namespace muduo::net;
|
|
|
|
typedef std::shared_ptr<TcpClient> TcpClientPtr;
|
|
|
|
// const int kMaxConns = 1;
|
|
const size_t kMaxPacketLen = 255;
|
|
const size_t kHeaderLen = 3;
|
|
|
|
const uint16_t kListenPort = 9999;
|
|
const char* socksIp = "127.0.0.1";
|
|
const uint16_t kSocksPort = 7777;
|
|
|
|
struct Entry
|
|
{
|
|
int connId;
|
|
TcpClientPtr client;
|
|
TcpConnectionPtr connection;
|
|
Buffer pending;
|
|
};
|
|
|
|
class DemuxServer : noncopyable
|
|
{
|
|
public:
|
|
DemuxServer(EventLoop* loop, const InetAddress& listenAddr, const InetAddress& socksAddr)
|
|
: loop_(loop),
|
|
server_(loop, listenAddr, "DemuxServer"),
|
|
socksAddr_(socksAddr)
|
|
{
|
|
server_.setConnectionCallback(
|
|
std::bind(&DemuxServer::onServerConnection, this, _1));
|
|
server_.setMessageCallback(
|
|
std::bind(&DemuxServer::onServerMessage, this, _1, _2, _3));
|
|
}
|
|
|
|
void start()
|
|
{
|
|
server_.start();
|
|
}
|
|
|
|
void onServerConnection(const TcpConnectionPtr& conn)
|
|
{
|
|
if (conn->connected())
|
|
{
|
|
if (serverConn_)
|
|
{
|
|
conn->shutdown();
|
|
}
|
|
else
|
|
{
|
|
serverConn_ = conn;
|
|
LOG_INFO << "onServerConnection set serverConn_";
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if (serverConn_ == conn)
|
|
{
|
|
serverConn_.reset();
|
|
socksConns_.clear();
|
|
|
|
LOG_INFO << "onServerConnection reset serverConn_";
|
|
}
|
|
}
|
|
}
|
|
|
|
void onServerMessage(const TcpConnectionPtr& conn, Buffer* buf, Timestamp)
|
|
{
|
|
while (buf->readableBytes() > kHeaderLen)
|
|
{
|
|
int len = static_cast<uint8_t>(*buf->peek());
|
|
if (buf->readableBytes() < len + kHeaderLen)
|
|
{
|
|
break;
|
|
}
|
|
else
|
|
{
|
|
int connId = static_cast<uint8_t>(buf->peek()[1]);
|
|
connId |= (static_cast<uint8_t>(buf->peek()[2]) << 8);
|
|
|
|
if (connId != 0)
|
|
{
|
|
assert(socksConns_.find(connId) != socksConns_.end());
|
|
TcpConnectionPtr& socksConn = socksConns_[connId].connection;
|
|
if (socksConn)
|
|
{
|
|
assert(socksConns_[connId].pending.readableBytes() == 0);
|
|
socksConn->send(buf->peek() + kHeaderLen, len);
|
|
}
|
|
else
|
|
{
|
|
socksConns_[connId].pending.append(buf->peek() + kHeaderLen, len);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
string cmd(buf->peek() + kHeaderLen, len);
|
|
doCommand(cmd);
|
|
}
|
|
buf->retrieve(len + kHeaderLen);
|
|
}
|
|
}
|
|
}
|
|
|
|
void doCommand(const string& cmd)
|
|
{
|
|
static const string kConn = "CONN ";
|
|
|
|
int connId = atoi(&cmd[kConn.size()]);
|
|
bool isUp = cmd.find(" IS UP") != string::npos;
|
|
LOG_INFO << "doCommand " << connId << " " << isUp;
|
|
if (isUp)
|
|
{
|
|
assert(socksConns_.find(connId) == socksConns_.end());
|
|
char connName[256];
|
|
snprintf(connName, sizeof connName, "SocksClient %d", connId);
|
|
Entry entry;
|
|
entry.connId = connId;
|
|
entry.client.reset(new TcpClient(loop_, socksAddr_, connName));
|
|
entry.client->setConnectionCallback(
|
|
std::bind(&DemuxServer::onSocksConnection, this, connId, _1));
|
|
entry.client->setMessageCallback(
|
|
std::bind(&DemuxServer::onSocksMessage, this, connId, _1, _2, _3));
|
|
// FIXME: setWriteCompleteCallback
|
|
socksConns_[connId] = entry;
|
|
entry.client->connect();
|
|
}
|
|
else
|
|
{
|
|
assert(socksConns_.find(connId) != socksConns_.end());
|
|
TcpConnectionPtr& socksConn = socksConns_[connId].connection;
|
|
if (socksConn)
|
|
{
|
|
socksConn->shutdown();
|
|
}
|
|
else
|
|
{
|
|
socksConns_.erase(connId);
|
|
}
|
|
}
|
|
}
|
|
|
|
void onSocksConnection(int connId, const TcpConnectionPtr& conn)
|
|
{
|
|
assert(socksConns_.find(connId) != socksConns_.end());
|
|
if (conn->connected())
|
|
{
|
|
socksConns_[connId].connection = conn;
|
|
Buffer& pendingData = socksConns_[connId].pending;
|
|
if (pendingData.readableBytes() > 0)
|
|
{
|
|
conn->send(&pendingData);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if (serverConn_)
|
|
{
|
|
char buf[256];
|
|
int len = snprintf(buf, sizeof(buf), "DISCONNECT %d\r\n", connId);
|
|
Buffer buffer;
|
|
buffer.append(buf, len);
|
|
sendServerPacket(0, &buffer);
|
|
}
|
|
else
|
|
{
|
|
socksConns_.erase(connId);
|
|
}
|
|
}
|
|
}
|
|
|
|
void onSocksMessage(int connId, const TcpConnectionPtr& conn, Buffer* buf, Timestamp)
|
|
{
|
|
assert(socksConns_.find(connId) != socksConns_.end());
|
|
while (buf->readableBytes() > kMaxPacketLen)
|
|
{
|
|
Buffer packet;
|
|
packet.append(buf->peek(), kMaxPacketLen);
|
|
buf->retrieve(kMaxPacketLen);
|
|
sendServerPacket(connId, &packet);
|
|
}
|
|
if (buf->readableBytes() > 0)
|
|
{
|
|
sendServerPacket(connId, buf);
|
|
}
|
|
}
|
|
|
|
void sendServerPacket(int connId, Buffer* buf)
|
|
{
|
|
size_t len = buf->readableBytes();
|
|
LOG_DEBUG << len;
|
|
assert(len <= kMaxPacketLen);
|
|
uint8_t header[kHeaderLen] = {
|
|
static_cast<uint8_t>(len),
|
|
static_cast<uint8_t>(connId & 0xFF),
|
|
static_cast<uint8_t>((connId & 0xFF00) >> 8)
|
|
};
|
|
buf->prepend(header, kHeaderLen);
|
|
if (serverConn_)
|
|
{
|
|
serverConn_->send(buf);
|
|
}
|
|
}
|
|
|
|
EventLoop* loop_;
|
|
TcpServer server_;
|
|
TcpConnectionPtr serverConn_;
|
|
const InetAddress socksAddr_;
|
|
std::map<int, Entry> socksConns_;
|
|
};
|
|
|
|
int main(int argc, char* argv[])
|
|
{
|
|
LOG_INFO << "pid = " << getpid();
|
|
EventLoop loop;
|
|
InetAddress listenAddr(kListenPort);
|
|
if (argc > 1)
|
|
{
|
|
socksIp = argv[1];
|
|
}
|
|
InetAddress socksAddr(socksIp, kSocksPort);
|
|
DemuxServer server(&loop, listenAddr, socksAddr);
|
|
|
|
server.start();
|
|
|
|
loop.loop();
|
|
}
|
|
|