Chinaunix首页 | 论坛 | 博客
  • 博客访问: 96142
  • 博文数量: 14
  • 博客积分: 1445
  • 博客等级: 上尉
  • 技术积分: 240
  • 用 户 组: 普通用户
  • 注册时间: 2008-01-16 19:33
文章分类

全部博文(14)

文章存档

2011年(1)

2009年(2)

2008年(11)

我的朋友

分类: LINUX

2008-04-28 16:30:37

OpenSSL 的应用已经相当广泛,最主要的原因也许是它是开源的。

下面是我的一个简单例子:

#include
#include
#include
#include

#include "openssl/rsa.h"
#include "openssl/crypto.h"
#include "openssl/x509.h"
#include "openssl/pem.h"
#include "openssl/ssl.h"
#include "openssl/err.h"
#include "openssl/rand.h"

/*
 * SSLClientSocket: an abstract TCP client that speaks SSL/TLS via OpenSSL and
 * frames incoming data into MXMessage objects (header lines, then a body).
 * Derived classes implement OnConnected / OnDisconnected / OnMessage.
 *
 * NOTE(review): ctx, ssl, server_cert and meth are static, so every instance
 * shares one SSL session pointer; only one connection at a time is safe.
 */
class SSLClientSocket
{
public:
    // NOTE(review): not referenced by the methods in this file; they use
    // MXMessage::HEADER_LINE_DELIMITER instead.
    static const char HEADER_LINE_DELIMITER = '\n';

public:
    SSLClientSocket ();
   
    virtual ~SSLClientSocket ();

public:
    // Starts Winsock and builds the SSL context (calls InitCtx).
    static int InitWinsock ();
   
    // Creates the shared SSL_CTX and loads the certificate/key/CA files.
 static int InitCtx(bool bServer,int VerType);

public:
    bool IsOpened () const;

    // Opens a TCP connection and performs the SSL handshake.
    int Connect (const CString& host, int port);

    void Disconnect ();

    // Non-blocking service routine: flushes queued output, reads input.
    int Poll ();

public:
    virtual void OnConnected () = 0;

    virtual void OnDisconnected () = 0;

    // Called once per fully received message.
    virtual void OnMessage (MXMessage& msg) = 0;
   
    // Queues data; the actual write happens in Poll().
    int SendData (const CString& buf);

private:
    void ResetSocketState ();

    int ParseHeaderLine (const CString &line);

    int ProcessInputBuffer ();

    // SSL_read / SSL_write wrappers with retry on a not-ready socket.
 int SSLrecv (char *buf, int len);
 int SSLsend (char *buf, int len);
private:
    SOCKET _socket;          // 0 means "no socket"
    CString _input_buf;      // received bytes not yet parsed
    CString _output_buf;     // queued bytes not yet sent

    bool _is_reading_body;   // parser state: reading headers (false) or body (true)
    int _body_bytes_to_read; // body bytes still expected for _tmp_message

    MXMessage _tmp_message;  // message currently being assembled

private:
 int err;                          // last SSL_connect() result
 static SSL_CTX* ctx;
 static SSL*     ssl;
 static X509*    server_cert;
 static SSL_METHOD *meth;
 static int       seed_int[100]; /* holds the random seed sequence */

 // NOTE(review): not used by any method shown here (InitWinsock uses a local).
 WSADATA wsaData;

public:
 static CString strCERTF;  /* client certificate (must be signed by the CA) */
 static CString strKEYF;  /* client private key (storing it encrypted is recommended) */
 static CString strCACERT;  /* CA certificate */
};

#endif // SSL_CLIENTSOCKET_H


#include "stdafx.h"
#include "ssl_clientsocket.h"

SSL_CTX* SSLClientSocket::ctx;
SSL*  SSLClientSocket::ssl;
X509*  SSLClientSocket::server_cert;
SSL_METHOD* SSLClientSocket::meth;
int   SSLClientSocket::seed_int[100]; /*存放随机序列*/

CString  SSLClientSocket::strCERTF="";  /*客户端的证书(需经CA签名)*/
CString  SSLClientSocket::strKEYF="";  /*客户端的私钥(建议加密存储)*/
CString  SSLClientSocket::strCACERT="";  /*CA 的证书*/


int SSLClientSocket::InitWinsock ()
{
 WORD sockVersion;
 WSADATA wsaData;

 sockVersion = MAKEWORD (2, 2);

    TRACE0 ("Initializing WinSock...\n");

 // Initialize Winsock as before
 if (WSAStartup (sockVersion, &wsaData) != NO_ERROR)
    {
        TRACE0 ("WinSock failed to initialize !!!\n");
        return -1;
    }
    TRACE0 ("WinSock initialized successfully\n");

 strCERTF = "gameclient-cert.pem";  /*客户端的证书(需经CA签名)*/
 strKEYF = "gameclient-key.pem";   /*客户端的私钥(建议加密存储)*/
 strCACERT = "cacert.pem";    /*CA 的证书*/

 int nret = InitCtx (false,0);
 TRACE1("The Return of Ctx initialization is %d\n",nret);

    return nret;
}

int SSLClientSocket::InitCtx(bool bServer,int VerType)
{
 OpenSSL_add_ssl_algorithms();     /*初始化*/
 SSL_load_error_strings();      /*为打印调试信息作准备*/

 if (bServer)         /*采用什么协议(SSLv2/SSLv3/TLSv1)在此指定*/
  meth = TLSv1_server_method();  
 else
  meth = TLSv1_client_method();
 ctx = SSL_CTX_new (meth);
 if (ctx == NULL)
  return -2;

 SSL_CTX_set_verify (ctx, SSL_VERIFY_PEER, NULL);   /*验证与否*/
 SSL_CTX_load_verify_locations (ctx, strCACERT.GetBuffer(), NULL); /*若验证,则放置CA证书*/

 if (SSL_CTX_use_certificate_file (ctx, strCERTF.GetBuffer(), SSL_FILETYPE_PEM) <= 0) {
  //ERR_print_errors_fp(stderr);
  return -3;
 }
 if (SSL_CTX_use_PrivateKey_file(ctx, strKEYF.GetBuffer(), SSL_FILETYPE_PEM) <= 0) {
  //ERR_print_errors_fp(stderr);
  return -4;
 }

 if (!SSL_CTX_check_private_key(ctx)) {
  TRACE0("Private key does not match the certificate public key\n");
  return -5;
 } 

 /*构建随机数生成机制,WIN32平台必需*/
 srand((unsigned)time (NULL));
 for( int i = 0; i < 100; i++ )
  seed_int[i] = rand();
 RAND_seed(seed_int, sizeof(seed_int));
 return 0;
}

/* Creates an unconnected object: socket handle 0 and freshly reset
   buffer/parser state. Winsock and OpenSSL are not initialised here. */
SSLClientSocket::SSLClientSocket ()
    : _socket (0)
{
    ResetSocketState ();
}

/* Closes a still-open connection via Disconnect(). OnDisconnected() is not
   called here; it is pure virtual and a virtual call made from a destructor
   cannot reach a derived override. */
SSLClientSocket::~SSLClientSocket ()
{
    if (this->_socket == 0)
        return;

    TRACE0 ("SSLClientSocket being destructed. Closing socket...\n");
    this->Disconnect ();
}

//
// Public Methods
//

bool SSLClientSocket::IsOpened () const
{
    return this->_socket != 0;
}

int SSLClientSocket::Poll ()
{
    if (this->_socket == 0)
        return -1;

    fd_set read_fds;
    fd_set write_fds;

    FD_ZERO (&read_fds);
    FD_ZERO (&write_fds);

    FD_SET (this->_socket, &read_fds);
    if (!this->_output_buf.IsEmpty ())
    {
        FD_SET (this->_socket, &write_fds);
    }

    timeval timeout = { 0, 0 };
    int nret = select (1, &read_fds, &write_fds, 0, &timeout);
    if (nret == SOCKET_ERROR)
        return -2;

    if (FD_ISSET (this->_socket, &write_fds))
    {
        int len = SSLsend (this->_output_buf.GetBuffer (), this->_output_buf.GetLength ());
        if (len > 0)
        {
            if (len == this->_output_buf.GetLength ())
            {
                this->_output_buf.Empty ();
            }
            else
            {
                this->_output_buf = this->_output_buf.Mid (len);
            }
        }
        else if (len == SOCKET_ERROR)
        {
            return -3;
        }
    }
    if (FD_ISSET (this->_socket, &read_fds))
    {
        const int BUF_SIZE = 1024;
        char buf[BUF_SIZE];

        int len = SSLrecv (buf, BUF_SIZE);
        if (len > 0)
        {
            this->_input_buf.Append (buf, len);

            this->ProcessInputBuffer ();
        }
        else if (len == 0 || len == SOCKET_ERROR)
        {
            // TODO: Check for connection reset
            this->Disconnect ();
            this->OnDisconnected ();
            return -4;
        }
    }

    return 0;
}

void SSLClientSocket::Disconnect ()
{
    if (this->_socket != 0)
    {
        shutdown (this->_socket, 2);
        closesocket (this->_socket);
  return;
        this->_socket = 0;
        this->ResetSocketState ();
  if (ssl != NULL) SSL_free (ssl);
  if (ctx != NULL) SSL_CTX_free (ctx); 
    }
}

int SSLClientSocket::Connect(const CString& host, int port)
{
    if (this->_socket != 0)
    {
        TRACE0 ("Attempt to connect with existing connection. Closing socket...\n");
        this->Disconnect ();
    }

    _socket = socket (AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (_socket == INVALID_SOCKET)
    {
        int nret = WSAGetLastError ();
        return -255;
    }

 LPHOSTENT hostEntry;

 // Specifying the server by its name;
    // another option is gethostbyaddr() (see below)
 hostEntry = gethostbyname (host);
 if (!hostEntry) {
  int nret = WSAGetLastError ();
  //ReportError(nret, "gethostbyname()"); // Report the error as before
  return -256;
 }

    // Fill a SOCKADDR_IN struct with address information
 SOCKADDR_IN serverInfo;

 serverInfo.sin_family = AF_INET;

 // At this point, we've successfully retrieved vital information about the server,
 // including its hostname, aliases, and IP addresses.  Wait; how could a single
 // computer have multiple addresses, and exactly what is the following line doing?
 // See the explanation below.
   
 serverInfo.sin_addr = *((LPIN_ADDR)*hostEntry->h_addr_list);
    // Change to network-byte order and
    // insert into port field
 serverInfo.sin_port = htons(port);

 // Connect to the server
 int nret = connect (this->_socket, (LPSOCKADDR)&serverInfo, sizeof(struct sockaddr));
    if (nret == SOCKET_ERROR)
    {
        nret = WSAGetLastError ();
        return -257;
    }

  /* TCP 链接已建立.开始 SSL 握手过程.......................... */
 nret = this->InitWinsock ();
    if (nret != 0)
        return -257 + nret;

 TRACE0("Begin SSL negotiation \n");

 ssl = SSL_new (ctx);
 if (ssl == NULL) {
  TRACE0("Failed to create a new ssl \n");
  return -1;
 }

 SSL_set_fd (ssl, (int)_socket);
 err = SSL_connect (ssl);
 if (err <= 0) {
  TRACE0 ("Failed to connect by ssl\n");
  return -263;
 }

 /*打印所有加密算法的信息(可选)*/
 TRACE1 ("SSL connection using %s\n", SSL_get_cipher (ssl));

 /*得到服务端的证书并打印些信息(可选) */
 server_cert = SSL_get_peer_certificate (ssl);      
 if (server_cert == NULL) {
  TRACE0 ("Failed to get certification\n");
  return -264;
 }
 TRACE0 ("Server certificate:\n");

 CString str = X509_NAME_oneline (X509_get_subject_name (server_cert),0,0);
 TRACE1 ("\t subject: %s\n", str);
 //free (str);

 str = X509_NAME_oneline (X509_get_issuer_name  (server_cert),0,0);
 TRACE1 ("\t issuer: %s\n", str);
 //free (str);

 X509_free (server_cert);  /*如不再需要,需将证书释放 */

    // Set socket to non-blocking mode
    unsigned long arg = 1;
    ioctlsocket (this->_socket, FIONBIO, &arg);

    // Disable SO_LINGER on socket
    int arg2 = 1;
    setsockopt (this->_socket, SOL_SOCKET, SO_DONTLINGER, (char*)&arg2, sizeof(int));

    this->OnConnected ();

    return 0;
}

//
// Helper Method
//

int SSLClientSocket::SendData (const CString& buf)
{
    if (this->_socket == 0)
        return -1;
    this->_output_buf.Append (buf);
    return this->_output_buf.GetLength ();
}

void SSLClientSocket::ResetSocketState ()
{
    this->_input_buf = "";
    this->_output_buf = "";
    this->_is_reading_body = false;
    this->_body_bytes_to_read = 0;
}

/*
 * Consumes as much of _input_buf as possible and dispatches every complete
 * message to OnMessage().
 *
 * Framing, as implemented below:
 *   - header mode: lines terminated by MXMessage::HEADER_LINE_DELIMITER are
 *     passed to ParseHeaderLine(); a delimiter at the start of a line (an
 *     empty line) ends the header section;
 *   - the body length is then taken from _tmp_message.GetBodyLength()
 *     (presumably derived from a parsed header; not visible here);
 *   - body mode: that many bytes are appended to _tmp_message's body,
 *     possibly across several calls.
 * Consumed bytes are removed from _input_buf; an incomplete trailing header
 * line or partial body is kept (or accumulated) for the next call.
 *
 * Returns 0 normally, -1 on a negative body length or an unparsable header
 * line. NOTE(review): the -1 paths return before consumed input is removed
 * from _input_buf, and Poll() in this file ignores the return value.
 */
int SSLClientSocket::ProcessInputBuffer ()
{
    int stop = 0;     // set to 1 when more input is needed
    int offset = 0;   // index of the first unconsumed byte in _input_buf
    while (stop == 0 && offset < this->_input_buf.GetLength ())
    {
        if (!this->_is_reading_body)
        {
            // Header mode: consume complete header lines.
            while (stop == 0 && offset < this->_input_buf.GetLength ())
            {
                if (this->_input_buf[offset] == MXMessage::HEADER_LINE_DELIMITER)
                {
                    // Empty line: end of the header section.
                    this->_body_bytes_to_read = this->_tmp_message.GetBodyLength ();
                    if (this->_body_bytes_to_read < 0)
                    {
                        TRACE1 ("Bad body length: %d\n", this->_body_bytes_to_read);
                        return -1;
                    }
                    else
                    {
                        TRACE1 ("Body length is %d bytes\n", this->_body_bytes_to_read);
                    }
                    
                    this->_is_reading_body = true;
                    offset++;   // skip the terminating delimiter

                    if (this->_body_bytes_to_read == 0)
                    {
                        // Header-only message: dispatch now and stay in header mode.
                        TRACE0 ("No body to read. Begin processing request\n");

                        this->OnMessage (this->_tmp_message);

                        this->_tmp_message.Clear ();
                        this->_is_reading_body = false;
                    }
                    else
                    {
                        // Leave the header loop; the outer loop continues in body mode.
                        TRACE1 ("Will need to read body of %d bytes.\n", this->_body_bytes_to_read);
                        break;
                    }
                }
                else
                {
                    int i = this->_input_buf.Find (MXMessage::HEADER_LINE_DELIMITER, offset);
                    if (i != -1)
                    {
                        // Parse header key-value
                        if (this->ParseHeaderLine (this->_input_buf.Mid (offset, i - offset)) != 0)
                        {
                            return -1;
                        }
                        offset = i + 1;   // continue after the delimiter
                    }
                    else // if (i == -1)
                    {
                        /* this->_input_buf does not contain MXMessage::HEADER_LINE_DELIMITER,
                           so stop processing input */
                        stop = 1;
                        break;
                    }
                }
            }
        }
        else // if (this->_is_reading_body)
        {
            if (this->_body_bytes_to_read > 0)
            {
                int len = this->_input_buf.GetLength () - offset;   // bytes available now
                if (len < this->_body_bytes_to_read)
                {
                    // Partial body: take everything and wait for more input.
                    this->_tmp_message.GetBody () += this->_input_buf.Mid (offset);
                    this->_body_bytes_to_read -= len;
                    offset = this->_input_buf.GetLength ();
                    stop = 1;

                    TRACE2 ("Read %d bytes for body. %d bytes remaining to read for body\n", len, this->_body_bytes_to_read);
                }
                else // if (len >= this->_body_bytes_to_read)
                {
                    // Rest of the body is here: complete and dispatch the message.
                    TRACE1 ("Read %d bytes for body. 0 byte remaining to read for body\n", this->_body_bytes_to_read);
                   
                    this->_tmp_message.GetBody () += this->_input_buf.Mid (offset, this->_body_bytes_to_read);
                    offset += this->_body_bytes_to_read;
                    this->_body_bytes_to_read = 0;

                    this->OnMessage (this->_tmp_message);

                    this->_tmp_message.Clear ();
                    this->_is_reading_body = false;
                }
            }
            else
            {
                // Body mode with nothing left to read: dispatch the message.
                // TODO: process_request(incoming_temp_message)
                this->OnMessage (this->_tmp_message);

                this->_tmp_message.Clear ();
                this->_is_reading_body = false;
            }
        }
    }

    // Drop the consumed prefix of the input buffer.
    if (offset < this->_input_buf.GetLength ())
    {
        this->_input_buf.Delete (0, offset);
    }
    else
    {
        this->_input_buf.Empty ();
    }
    TRACE1 ("Input buffer now contains %d bytes\n", this->_input_buf.GetLength ());

    return 0;
}

int SSLClientSocket::ParseHeaderLine (const CString &line)
{
    int pos = line.Find (MXMessage::HEADER_DELIMITER);
    if (pos != -1)
    {
        if (pos == 0)
        {
            TRACE1 ("Empty header key: %s\n", line);
            return -1;
        }
        else if (pos == line.GetLength () - 1)
        {
            CString key = line.Left (pos);
            // TRACE1 ("Adding %s -> empty to header table\n", key);
            this->_tmp_message.GetHeaderMap ().SetAt (key, "");
        }
        else
        {
            CString key = line.Left (pos);
            CString value = line.Mid (pos + 1);
            // TRACE2 ("Adding %s -> %s to header table\n", key, value);
            this->_tmp_message.GetHeaderMap ().SetAt (key, value);
        }
        return 0;
    }
    else
    {
        TRACE1 ("Cannot parse header line: %s\n", line);
        return -1;
    }
}

//
// Virtual Methods
//

int SSLClientSocket::SSLrecv (char *buf, int len)
{
 int k=0;
 do
 {
  while (1)
  {
   k = SSL_read(SSLClientSocket::ssl,buf,len);
   if (k <= 0)
   {
    if (BIO_sock_should_retry(k))
    {
     Sleep(100);
     continue;//重试
    }
    return k;//错误退出
   }
   break;//成功
  }
 }while (SSL_pending(SSLClientSocket::ssl));

 return k;
}

int SSLClientSocket::SSLsend (char *buf, int len)
{
 if(ssl==NULL)
  return -1;
 int k=0;   //发送数量
 int offset=0;  //偏移
 while (0!=len)  //maximum record size of 16kB for SSLv3/TLSv1
 {
  k = SSL_write(SSLClientSocket::ssl, buf+offset, len);
  if (k <= 0)
  {
   if (BIO_sock_should_retry(k))
   {
    Sleep(100);
    continue; //重试
   }
   return k;  //出错
  }
  offset+=k;
  len-=k;
 }
 return offset;
}

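上面的代码主要用于客户端。InitCtx 的 bServer 参数可以为服务端创建 SSL_CTX,但 Connect 使用的是 connect/SSL_connect;若要用于服务端,还需改用 accept 和 SSL_accept 来建立连接。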

阅读(2588) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~