// openai-api/server.cxx
//
// 163 lines
// 5.0 KiB
// C++

#include "server.h"
#include "util.hpp"
#include <cstddef>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
/// Builds a server bound to the given io_context.
///
/// The acceptor is only constructed here; it is opened, bound and put into
/// listening mode later by start().
///
/// @param io_context  Context used for the acceptor and for accepted sockets.
/// @param port        TCP port that start() will listen on.
Server::Server(asio::io_context& io_context, short port)
    : io_context_(io_context),
      acceptor_(io_context)
{
    // Assigned in the body rather than in the initializer list, so this does
    // not depend on where port_ is declared relative to the other members.
    this->port_ = port;
}
/// Releases every client thread still registered in clients_.
///
/// Threads are detached rather than joined because each one may be blocked in
/// a socket read. The registry is walked under cli_mutex_, the same mutex the
/// accept handler and the per-client cleanup hold when they touch clients_.
///
/// NOTE(review): detached threads still reference this object; callers should
/// make sure the connections are finished before destroying the server.
Server::~Server()
{
    std::lock_guard<std::mutex> lock(cli_mutex_);
    for (auto& client : clients_) {
        // detach() throws on a non-joinable thread, which would terminate the
        // program when it happens inside a destructor.
        if (client.second.joinable()) {
            client.second.detach();
        }
    }
}
void Server::start()
{
asio::ip::tcp::endpoint endpoint(asio::ip::tcp::v4(), port_);
try {
acceptor_.open(endpoint.protocol());
// acceptor_.set_option(asio::socket_base::reuse_address(true));
acceptor_.bind(endpoint);
acceptor_.listen();
do_accept();
} catch (const std::exception& e) {
#ifdef _WIN32
std::cerr << ansi_to_u8(e.what()) << '\n';
#else
std::cerr << e.what() << '\n';
#endif
}
}
/// Shutdown hook; the body is currently empty, so calling it has no effect.
///
/// NOTE(review): the acceptor is not closed and client threads are not
/// signalled here. If that gets implemented, do_accept() must stop re-arming
/// itself once the acceptor is closed.
void Server::stop()
{
}
/// Installs the collaborators used to serve client requests.
///
/// @param worker  Object whose post() is called with the formatted request.
/// @param json    Helper used to format requests and parse responses.
///
/// Both arguments are taken by value and moved into the members, so storing
/// them costs no extra reference-count round trip.
void Server::set_worker(std::shared_ptr<COpenAI> worker, std::shared_ptr<CJsonOper> json)
{
    worker_ = std::move(worker);
    json_ = std::move(json);
}
void Server::set_token(long tokens)
{
tokens_ = tokens;
}
void Server::do_accept()
{
auto socket = std::make_shared<asio::ip::tcp::socket>(io_context_);
acceptor_.async_accept(*socket, [this, socket](const std::error_code& ec) {
if (!ec) {
auto endpoint = socket->remote_endpoint();
std::string client_key = endpoint.address().to_string() + ":" + std::to_string(endpoint.port());
std::unique_lock<std::mutex> lock(cli_mutex_);
client_map_[client_key] = std::make_shared<ClientCache>();
clients_.insert(std::make_pair(socket->remote_endpoint().address().to_string(),
std::thread([this, socket, client_key]() { th_client(socket, client_key); })));
}
do_accept();
});
}
/// Per-connection worker loop, run on the thread created by do_accept().
///
/// Reads bytes from the socket into the client's cache, extracts complete
/// frames with com_parse() and answers each TYPE_REQUEST frame by posting it
/// through worker_ and sending back the parsed reply. Once the used-token
/// counter exceeds the limit, every frame is answered with
/// TYPE_OUT_OF_LIMIT instead. The loop ends on the first read error,
/// including a clean close by the peer.
///
/// @param socket      Connected client socket.
/// @param client_key  "address:port" key identifying this client.
void Server::th_client(const std::shared_ptr<asio::ip::tcp::socket>& socket, const std::string& client_key)
{
    // Scope guard: the custom deleter runs when this function returns and
    // removes the client's cache and thread entry under cli_mutex_.
    std::shared_ptr<int> deleter(new int(0), [&](int* p) {
        std::unique_lock<std::mutex> lock(cli_mutex_);
        delete p;
        client_map_.erase(client_key);
        auto self = clients_.find(client_key);
        if (self != clients_.end()) {
            self->second.detach();
            clients_.erase(self);
        }
        std::cout << "th_client deleter client " << client_key << "exit." << std::endl;
    });

    asio::error_code error;
    std::shared_ptr<ClientCache> cache = nullptr;
    {
        std::unique_lock<std::mutex> lock(cli_mutex_);
        cache = client_map_[client_key];
    }
    if (!cache) {
        // No cache registered for this key; nothing can be read into.
        return;
    }

    while (true) {
        const auto len = socket->read_some(asio::buffer(cache->tmp_buf_), error);
        if (error) {
            break; // Connection closed cleanly by peer (eof) or some other error.
        }
        cache->buffer_.push(cache->tmp_buf_.data(), len);

        // Drain every complete frame currently buffered.
        while (true) {
            auto frame = com_parse(cache->buffer_);
            if (frame == nullptr) {
                break;
            }

            // NOTE(review): use_tokens_ is read here without ask_mutex_, while
            // it is updated under that mutex below; confirm it is atomic.
            if (use_tokens_ > tokens_) {
                std::cout << client_key << " tokens not enough" << std::endl;
                FrameData req;
                req.type = FrameType::TYPE_OUT_OF_LIMIT;
                send_frame(socket, req);
            } else {
                std::cout << client_key << " 's data." << std::endl;
                if (frame->type == FrameType::TYPE_REQUEST) {
                    // RAII lock: released even if post()/parse() throws.
                    std::lock_guard<decltype(ask_mutex_)> ask_lock(ask_mutex_);
                    const std::string recv_data(frame->data, frame->len);
                    std::string out{};
                    if (!worker_->post(post_data(recv_data), out)) {
                        std::cout << client_key << " data post error" << std::endl;
                        FrameData req;
                        req.type = FrameType::TYPE_RESPONSE_ERROR;
                        send_frame(socket, req);
                    } else {
                        auto parse = json_->parse(out);
                        FrameData req;
                        req.type = FrameType::TYPE_RESPONSE_SUCCESS;
                        // Payload is the message text plus a terminating NUL.
                        req.len = parse.message_content.size() + 1;
                        // NOTE(review): presumably FrameData releases `data`
                        // itself; verify, otherwise this buffer leaks.
                        req.data = new char[req.len];
                        req.protk = parse.prompt_tokens;
                        req.coptk = parse.completion_tokens;
                        use_tokens_ += req.protk;
                        use_tokens_ += req.coptk;
                        std::cout << "Already use " << use_tokens_ << " tokens.\n";
                        memcpy(req.data, parse.message_content.c_str(), parse.message_content.size());
                        req.data[req.len - 1] = '\0';
                        send_frame(socket, req);
                    }
                }
            }

            // Single release point: the out-of-limit path frees the frame too.
            delete frame;
        }
    }
}
/// Turns the raw text received from a client into the request body handed to
/// the worker, by delegating to json_->format_request().
///
/// @param data  Text taken from a client request frame.
/// @return      The result of json_->format_request(data).
std::string Server::post_data(const std::string& data)
{
    std::string request_body = json_->format_request(data);
    return request_body;
}
/// Serializes a frame with com_pack() and writes all of it to the socket.
///
/// Uses the non-throwing send overload, so a broken connection yields `false`
/// instead of an exception escaping the client thread. A short write is
/// continued until the whole packed buffer has gone out.
///
/// @param socket  Connected client socket to write to.
/// @param data    Frame to serialize and send.
/// @return        true when every packed byte was sent; false when packing
///                failed or the socket reported an error.
bool Server::send_frame(const std::shared_ptr<asio::ip::tcp::socket>& socket, FrameData& data)
{
    char* send_data{};
    int len{};
    if (!com_pack(&data, &send_data, len)) {
        return false;
    }

    // Owns the packed buffer so it is released on every return path.
    const std::unique_ptr<char[]> packed(send_data);
    if (len < 0) {
        return false;
    }

    const std::size_t total = static_cast<std::size_t>(len);
    std::size_t sent = 0;
    asio::error_code error;
    while (sent < total) {
        // send() may transmit only part of the buffer; continue from there.
        const std::size_t written = socket->send(asio::buffer(packed.get() + sent, total - sent), 0, error);
        if (error || written == 0) {
            return false;
        }
        sent += written;
    }
    return true;
}