2024-03-08 14:03:37 +08:00

237 lines
5.1 KiB
C++

#include "muduo/base/CountDownLatch.h"
#include "muduo/base/Logging.h"
#include "muduo/net/EventLoopThread.h"
#include "muduo/net/TcpClient.h"
#include <boost/tokenizer.hpp>
#include "examples/wordcount/hash.h"
#include <fstream>
#include <iostream>
#include <stdio.h>
#define __STDC_FORMAT_MACROS
#include <inttypes.h>
using namespace muduo;
using namespace muduo::net;
// Number of buffered bytes at which SendThrottler::send() hands its buffer to
// the connection. main() may override it from the BATCH_SIZE environment
// variable.
size_t g_batchSize = 64 * 1024;

// processFile() stops reading and sends the current batch as soon as its
// local word map holds more than this many distinct words.
const size_t kMaxHashSize = 10000000;
class SendThrottler : muduo::noncopyable
{
public:
SendThrottler(EventLoop* loop, const InetAddress& addr)
: client_(loop, addr, "Sender"),
connectLatch_(1),
disconnectLatch_(1),
cond_(mutex_),
congestion_(false)
{
LOG_INFO << "SendThrottler [" << addr.toIpPort() << "]";
client_.setConnectionCallback(
std::bind(&SendThrottler::onConnection, this, _1));
}
void connect()
{
client_.connect();
connectLatch_.wait();
}
void disconnect()
{
if (buffer_.readableBytes() > 0)
{
LOG_DEBUG << "send " << buffer_.readableBytes() << " bytes";
conn_->send(&buffer_);
}
conn_->shutdown();
disconnectLatch_.wait();
}
void send(const string& word, int64_t count)
{
buffer_.append(word);
// FIXME: use LogStream
char buf[64];
snprintf(buf, sizeof buf, "\t%" PRId64 "\r\n", count);
buffer_.append(buf);
if (buffer_.readableBytes() >= g_batchSize)
{
throttle();
LOG_TRACE << "send " << buffer_.readableBytes();
conn_->send(&buffer_);
}
}
private:
void onConnection(const TcpConnectionPtr& conn)
{
if (conn->connected())
{
conn->setHighWaterMarkCallback(
std::bind(&SendThrottler::onHighWaterMark, this), 1024*1024);
conn->setWriteCompleteCallback(
std::bind(&SendThrottler::onWriteComplete, this));
conn_ = conn;
connectLatch_.countDown();
}
else
{
conn_.reset();
disconnectLatch_.countDown();
}
}
void onHighWaterMark()
{
MutexLockGuard lock(mutex_);
congestion_ = true;
}
void onWriteComplete()
{
MutexLockGuard lock(mutex_);
bool oldCong = congestion_;
congestion_ = false;
if (oldCong)
{
cond_.notify();
}
}
void throttle()
{
MutexLockGuard lock(mutex_);
while (congestion_)
{
LOG_DEBUG << "wait ";
cond_.wait();
}
}
TcpClient client_;
TcpConnectionPtr conn_;
CountDownLatch connectLatch_;
CountDownLatch disconnectLatch_;
Buffer buffer_;
MutexLock mutex_;
Condition cond_;
bool congestion_;
};
class WordCountSender : muduo::noncopyable
{
public:
explicit WordCountSender(const std::string& receivers);
void connectAll()
{
for (size_t i = 0; i < buckets_.size(); ++i)
{
buckets_[i]->connect();
}
LOG_INFO << "All connected";
}
void disconnectAll()
{
for (size_t i = 0; i < buckets_.size(); ++i)
{
buckets_[i]->disconnect();
}
LOG_INFO << "All disconnected";
}
void processFile(const char* filename);
private:
EventLoopThread loopThread_;
EventLoop* loop_;
std::vector<std::unique_ptr<SendThrottler>> buckets_;
};
// Starts the event loop thread and creates one SendThrottler per "ip:port"
// entry of `receivers` (entries are separated by ',' or ' ').
//
// An entry is rejected when it has no ':' or when the text after the ':' is
// not a decimal port in 1..65535. Rejected entries are logged; debug builds
// additionally assert, release builds skip the entry.
WordCountSender::WordCountSender(const std::string& receivers)
  : loop_(loopThread_.startLoop())
{
  typedef boost::tokenizer<boost::char_separator<char> > tokenizer;
  boost::char_separator<char> sep(", ");
  tokenizer tokens(receivers, sep);
  for (tokenizer::iterator tok_iter = tokens.begin();
       tok_iter != tokens.end(); ++tok_iter)
  {
    const std::string ipport = *tok_iter;
    const size_t colon = ipport.find(':');
    bool valid = false;
    if (colon != std::string::npos)
    {
      // strtol (unlike atoi) lets us detect non-numeric text and values that
      // would be silently truncated by the cast to uint16_t.
      const char* portText = ipport.c_str() + colon + 1;
      char* end = nullptr;
      const long port = strtol(portText, &end, 10);
      if (end != portText && *end == '\0' && port > 0 && port <= 65535)
      {
        InetAddress addr(ipport.substr(0, colon), static_cast<uint16_t>(port));
        buckets_.emplace_back(new SendThrottler(loop_, addr));
        valid = true;
      }
    }
    if (!valid)
    {
      LOG_ERROR << "Invalid receiver address '" << ipport.c_str()
                << "', expected ip:port";
      assert(0 && "Invalid address");
    }
  }
  if (buckets_.empty())
  {
    LOG_ERROR << "No valid receiver address in '" << receivers.c_str() << "'";
  }
}
void WordCountSender::processFile(const char* filename)
{
LOG_INFO << "processFile " << filename;
WordCountMap wordcounts;
// FIXME: use mmap to read file
std::ifstream in(filename);
string word;
// FIXME: make local hash optional.
std::hash<string> hash;
while (in)
{
wordcounts.clear();
while (in >> word)
{
wordcounts[word] += 1;
if (wordcounts.size() > kMaxHashSize)
{
break;
}
}
LOG_INFO << "send " << wordcounts.size() << " records";
for (WordCountMap::iterator it = wordcounts.begin();
it != wordcounts.end(); ++it)
{
size_t idx = hash(it->first) % buckets_.size();
buckets_[idx]->send(it->first, it->second);
}
}
}
int main(int argc, char* argv[])
{
if (argc < 3)
{
printf("Usage: %s addresses_of_receivers input_file1 [input_file2]* \n", argv[0]);
printf("Example: %s 'ip1:port1,ip2:port2,ip3:port3' input_file1 input_file2 \n", argv[0]);
}
else
{
const char* batchSize = ::getenv("BATCH_SIZE");
if (batchSize)
{
g_batchSize = atoi(batchSize);
}
WordCountSender sender(argv[1]);
sender.connectAll();
for (int i = 2; i < argc; ++i)
{
sender.processFile(argv[i]);
}
sender.disconnectAll();
}
}