236 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			236 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
// Copyright 2011, Shuo Chen.  All rights reserved.
 | 
						|
// http://code.google.com/p/muduo/
 | 
						|
//
 | 
						|
// Use of this source code is governed by a BSD-style license
 | 
						|
// that can be found in the License file.
 | 
						|
//
 | 
						|
// Author: Shuo Chen (chenshuo at chenshuo dot com)
 | 
						|
 | 
						|
#include "examples/protobuf/codec/codec.h"
 | 
						|
 | 
						|
#include "muduo/base/Logging.h"
 | 
						|
#include "muduo/net/Endian.h"
 | 
						|
#include "muduo/net/protorpc/google-inl.h"
 | 
						|
 | 
						|
#include <google/protobuf/descriptor.h>
 | 
						|
 | 
						|
#include <zlib.h>  // adler32
 | 
						|
 | 
						|
using namespace muduo;
 | 
						|
using namespace muduo::net;
 | 
						|
 | 
						|
// Serializes `message` into the (empty) buffer `buf` as one framed packet.
//
// Resulting layout, in write order:
//   len       (int32, prepended last, converted with hostToNetwork32)
//   nameLen   (int32, counts the trailing '\0' of the type name)
//   typeName  (nameLen bytes, NUL-terminated)
//   payload   (protobuf-serialized message)
//   checkSum  (int32, adler32 over nameLen + typeName + payload)
//
// `buf` must contain no readable bytes on entry (asserted).
void ProtobufCodec::fillEmptyBuffer(Buffer* buf, const google::protobuf::Message& message)
{
  assert(buf->readableBytes() == 0);

  // Type name, length-prefixed; the length includes the terminating NUL.
  const std::string& name = message.GetTypeName();
  const int32_t nameLen = static_cast<int32_t>(name.size() + 1);
  buf->appendInt32(nameLen);
  buf->append(name.c_str(), nameLen);

  // The serialization steps below mirror MessageLite::SerializeToArray() and
  // MessageLite::SerializePartialToArray().
  GOOGLE_DCHECK(message.IsInitialized()) << InitializationErrorMessage("serialize", message);

  // ByteSize() was deprecated in Protocol Buffers v3.4.0, but only started to
  // carry '__attribute__((deprecated()))' around v3.11.0. v3.9.2 is therefore
  // used as the last version that still calls ByteSize(), to disturb existing
  // muduo code/projects as little as possible.
  // Note: this is inferred from
  //   1) https://github.com/protocolbuffers/protobuf/releases/tag/v3.4.0
  //   2) the macros in 'include/google/protobuf/port_def.inc', e.g.
  //      '#define PROTOBUF_DEPRECATED_MSG(msg) __attribute__((deprecated(msg)))'.
  // The use of ToIntSize() follows the implementation of ByteSize() in newer
  // Protocol Buffers versions.
#if GOOGLE_PROTOBUF_VERSION > 3009002
  const int payloadLen = google::protobuf::internal::ToIntSize(message.ByteSizeLong());
#else
  const int payloadLen = message.ByteSize();
#endif
  buf->ensureWritableBytes(payloadLen);

  // Serialize straight into the buffer's writable area.
  uint8_t* const first = reinterpret_cast<uint8_t*>(buf->beginWrite());
  uint8_t* const last = message.SerializeWithCachedSizesToArray(first);
  if (last - first != payloadLen)
  {
    // The size is queried a second time here so the error report can show
    // both the size measured before serialization and the current one.
#if GOOGLE_PROTOBUF_VERSION > 3009002
    ByteSizeConsistencyError(payloadLen,
                             google::protobuf::internal::ToIntSize(message.ByteSizeLong()),
                             static_cast<int>(last - first));
#else
    ByteSizeConsistencyError(payloadLen,
                             message.ByteSize(),
                             static_cast<int>(last - first));
#endif
  }
  buf->hasWritten(payloadLen);

  // Checksum over everything written so far (nameLen + typeName + payload).
  const Bytef* const checked = reinterpret_cast<const Bytef*>(buf->peek());
  const int32_t checkSum = static_cast<int32_t>(
      ::adler32(1, checked, static_cast<int>(buf->readableBytes())));
  buf->appendInt32(checkSum);
  assert(buf->readableBytes() == sizeof nameLen + nameLen + payloadLen + sizeof checkSum);

  // Finally put the total length in front of the packet.
  const int32_t len = sockets::hostToNetwork32(static_cast<int32_t>(buf->readableBytes()));
  buf->prepend(&len, sizeof len);
}
 | 
						|
 | 
						|
//
 | 
						|
// no more google code after this
 | 
						|
//
 | 
						|
 | 
						|
//
 | 
						|
// FIXME: merge with RpcCodec
 | 
						|
//
 | 
						|
 | 
						|
namespace
 | 
						|
{
 | 
						|
  const string kNoErrorStr = "NoError";
 | 
						|
  const string kInvalidLengthStr = "InvalidLength";
 | 
						|
  const string kCheckSumErrorStr = "CheckSumError";
 | 
						|
  const string kInvalidNameLenStr = "InvalidNameLen";
 | 
						|
  const string kUnknownMessageTypeStr = "UnknownMessageType";
 | 
						|
  const string kParseErrorStr = "ParseError";
 | 
						|
  const string kUnknownErrorStr = "UnknownError";
 | 
						|
}
 | 
						|
 | 
						|
// Maps an ErrorCode to its textual name.
// Any value without a dedicated name yields "UnknownError".
// The returned reference points at a file-local constant string.
const string& ProtobufCodec::errorCodeToString(ErrorCode errorCode)
{
  // Start from the fallback and override it for every known code.
  const string* text = &kUnknownErrorStr;
  switch (errorCode)
  {
    case kNoError:
      text = &kNoErrorStr;
      break;
    case kInvalidLength:
      text = &kInvalidLengthStr;
      break;
    case kCheckSumError:
      text = &kCheckSumErrorStr;
      break;
    case kInvalidNameLen:
      text = &kInvalidNameLenStr;
      break;
    case kUnknownMessageType:
      text = &kUnknownMessageTypeStr;
      break;
    case kParseError:
      text = &kParseErrorStr;
      break;
    default:
      break;
  }
  return *text;
}
 | 
						|
 | 
						|
// Default error handler: logs the error code by name and, if the connection
// exists and is still connected, shuts it down. The buffer and timestamp
// arguments are not used.
void ProtobufCodec::defaultErrorCallback(const muduo::net::TcpConnectionPtr& conn,
                                         muduo::net::Buffer* buf,
                                         muduo::Timestamp,
                                         ErrorCode errorCode)
{
  LOG_ERROR << "ProtobufCodec::defaultErrorCallback - " << errorCodeToString(errorCode);

  // Nothing to shut down without a live connection.
  if (!conn || !conn->connected())
  {
    return;
  }
  conn->shutdown();
}
 | 
						|
 | 
						|
int32_t asInt32(const char* buf)
 | 
						|
{
 | 
						|
  int32_t be32 = 0;
 | 
						|
  ::memcpy(&be32, buf, sizeof(be32));
 | 
						|
  return sockets::networkToHost32(be32);
 | 
						|
}
 | 
						|
 | 
						|
void ProtobufCodec::onMessage(const TcpConnectionPtr& conn,
 | 
						|
                              Buffer* buf,
 | 
						|
                              Timestamp receiveTime)
 | 
						|
{
 | 
						|
  while (buf->readableBytes() >= kMinMessageLen + kHeaderLen)
 | 
						|
  {
 | 
						|
    const int32_t len = buf->peekInt32();
 | 
						|
    if (len > kMaxMessageLen || len < kMinMessageLen)
 | 
						|
    {
 | 
						|
      errorCallback_(conn, buf, receiveTime, kInvalidLength);
 | 
						|
      break;
 | 
						|
    }
 | 
						|
    else if (buf->readableBytes() >= implicit_cast<size_t>(len + kHeaderLen))
 | 
						|
    {
 | 
						|
      ErrorCode errorCode = kNoError;
 | 
						|
      MessagePtr message = parse(buf->peek()+kHeaderLen, len, &errorCode);
 | 
						|
      if (errorCode == kNoError && message)
 | 
						|
      {
 | 
						|
        messageCallback_(conn, message, receiveTime);
 | 
						|
        buf->retrieve(kHeaderLen+len);
 | 
						|
      }
 | 
						|
      else
 | 
						|
      {
 | 
						|
        errorCallback_(conn, buf, receiveTime, errorCode);
 | 
						|
        break;
 | 
						|
      }
 | 
						|
    }
 | 
						|
    else
 | 
						|
    {
 | 
						|
      break;
 | 
						|
    }
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
// Creates a new, empty message object for the protobuf type `typeName`.
//
// The type is looked up in the generated descriptor pool and its prototype
// is fetched from the generated message factory. Returns NULL if either
// lookup fails; otherwise returns the result of prototype->New().
google::protobuf::Message* ProtobufCodec::createMessage(const std::string& typeName)
{
  const google::protobuf::Descriptor* const desc =
    google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(typeName);
  if (!desc)
  {
    return NULL;
  }

  const google::protobuf::Message* const proto =
    google::protobuf::MessageFactory::generated_factory()->GetPrototype(desc);
  if (!proto)
  {
    return NULL;
  }

  return proto->New();
}
 | 
						|
 | 
						|
// Decodes one frame body (everything after the leading length field).
//
// Expected layout of buf[0, len):
//   nameLen   (int32, read via asInt32; counts the trailing '\0')
//   typeName  (nameLen bytes)
//   payload   (protobuf-serialized message)
//   checkSum  (int32, adler32 over all preceding bytes)
//
// @param buf    start of the frame body
// @param len    number of bytes in the frame body
// @param error  receives the outcome; always written
// @return the decoded message. Empty when the length, checksum, name length
//         or type lookup fails. When the type is known but the payload does
//         not parse, the (non-empty) message object is still returned and
//         *error is kParseError, so callers must check *error.
MessagePtr ProtobufCodec::parse(const char* buf, int len, ErrorCode* error)
{
  MessagePtr message;

  // A body shorter than the checksum field cannot be examined at all:
  // `buf + len - kHeaderLen` would point before `buf` and the checksum
  // length would be negative. Report it instead of reading out of bounds.
  if (len < kHeaderLen)
  {
    *error = kInvalidLength;
    return message;
  }

  // Verify the trailing checksum against the bytes that precede it.
  const int32_t expectedCheckSum = asInt32(buf + len - kHeaderLen);
  const int32_t checkSum = static_cast<int32_t>(
      ::adler32(1,
                reinterpret_cast<const Bytef*>(buf),
                static_cast<int>(len - kHeaderLen)));
  if (checkSum != expectedCheckSum)
  {
    *error = kCheckSumError;
    return message;
  }

  // The name needs at least one character plus the NUL, and must leave room
  // for the nameLen and checkSum fields.
  const int32_t nameLen = asInt32(buf);
  if (nameLen < 2 || nameLen > len - 2*kHeaderLen)
  {
    *error = kInvalidNameLen;
    return message;
  }

  // Build the type name without its trailing NUL and instantiate the message.
  const std::string typeName(buf + kHeaderLen, buf + kHeaderLen + nameLen - 1);
  message.reset(createMessage(typeName));
  if (!message)
  {
    *error = kUnknownMessageType;
    return message;
  }

  // Whatever lies between the type name and the checksum is the payload.
  const char* data = buf + kHeaderLen + nameLen;
  const int32_t dataLen = len - nameLen - 2*kHeaderLen;
  *error = message->ParseFromArray(data, dataLen) ? kNoError : kParseError;
  return message;
}
 |