分类: LINUX
2008-04-28 16:30:37
下面是我的一个简单例子(基于 OpenSSL 和 WinSock/MFC 实现的 SSL 客户端 socket 类):
#include
#include
#include
#include
#include "openssl/rsa.h"
#include "openssl/crypto.h"
#include "openssl/x509.h"
#include "openssl/pem.h"
#include "openssl/ssl.h"
#include "openssl/err.h"
#include "openssl/rand.h"
// SSL/TLS client socket built on WinSock and OpenSSL.
//
// Typical use (as implemented in this file): Connect() opens a TCP connection
// and performs the SSL handshake, then switches the socket to non-blocking
// mode; Poll() must be called repeatedly to flush queued output and read
// input; SendData() queues outgoing bytes; a subclass implements the
// OnConnected / OnDisconnected / OnMessage callbacks. Incoming data is parsed
// into MXMessage objects (header lines, an empty line, then a body whose
// length comes from MXMessage::GetBodyLength()).
//
// NOTE(review): ctx, ssl, server_cert and meth are static, so all instances
// share a single SSL session; only one connected instance can work at a time.
class SSLClientSocket
{
public:
// NOTE(review): not referenced in this file; the parser uses
// MXMessage::HEADER_LINE_DELIMITER instead.
static const char HEADER_LINE_DELIMITER = '\n';
public:
SSLClientSocket ();
virtual ~SSLClientSocket ();
public:
// Starts WinSock 2.2, sets the certificate file names and builds the SSL
// context; returns 0 on success or a negative error code.
static int InitWinsock ();
// Builds the shared SSL_CTX (server or client method) and loads the
// certificate, private key and CA file; returns 0 or a negative error code.
static int InitCtx(bool bServer,int VerType);
public:
// True while a socket handle is held (0 means "no socket").
bool IsOpened () const;
// Resolves host, connects over TCP and performs the SSL handshake.
// Returns 0 on success, a negative error code otherwise.
int Connect (const CString& host, int port);
void Disconnect ();
// Non-blocking I/O pump; call periodically while connected.
int Poll ();
public:
virtual void OnConnected () = 0;
virtual void OnDisconnected () = 0;
virtual void OnMessage (MXMessage& msg) = 0;
// Queues data for sending by Poll(); returns the queued length or -1.
int SendData (const CString& buf);
private:
void ResetSocketState ();
int ParseHeaderLine (const CString &line);
int ProcessInputBuffer ();
int SSLrecv (char *buf, int len);
int SSLsend (char *buf, int len);
private:
SOCKET _socket;            // 0 when not connected
CString _input_buf;        // received plaintext not yet parsed
CString _output_buf;       // queued plaintext waiting to be written by Poll()
bool _is_reading_body;     // parser state: reading a message body
int _body_bytes_to_read;   // parser state: body bytes still expected
MXMessage _tmp_message;    // message currently being assembled
private:
int err;                   // last SSL_connect() result
static SSL_CTX* ctx;
static SSL* ssl;
static X509* server_cert;
static SSL_METHOD *meth;
static int seed_int[100]; /* random seed data passed to RAND_seed */
// NOTE(review): unused; InitWinsock() uses its own local WSADATA.
WSADATA wsaData;
public:
static CString strCERTF; /* client certificate (must be signed by the CA) */
static CString strKEYF; /* client private key (storing it encrypted is recommended) */
static CString strCACERT; /* CA certificate */
};
#endif // SSL_CLIENTSOCKET_H
#include "stdafx.h"
#include "ssl_clientsocket.h"
// Storage for the static members shared by every SSLClientSocket instance.
// All OpenSSL handles start out null; InitCtx() and Connect() fill them in.
SSL_CTX*    SSLClientSocket::ctx         = NULL;
SSL*        SSLClientSocket::ssl         = NULL;
X509*       SSLClientSocket::server_cert = NULL;
SSL_METHOD* SSLClientSocket::meth        = NULL;
int         SSLClientSocket::seed_int[100] = { 0 };  // random seed data for RAND_seed()

// Certificate file paths, empty by default; InitWinsock() supplies default file names.
CString SSLClientSocket::strCERTF("");   // client certificate (must be signed by the CA)
CString SSLClientSocket::strKEYF("");    // client private key (storing it encrypted is recommended)
CString SSLClientSocket::strCACERT("");  // CA certificate
int SSLClientSocket::InitWinsock ()
{
WORD sockVersion;
WSADATA wsaData;
sockVersion = MAKEWORD (2, 2);
TRACE0 ("Initializing WinSock...\n");
// Initialize Winsock as before
if (WSAStartup (sockVersion, &wsaData) != NO_ERROR)
{
TRACE0 ("WinSock failed to initialize !!!\n");
return -1;
}
TRACE0 ("WinSock initialized successfully\n");
strCERTF = "gameclient-cert.pem"; /*客户端的证书(需经CA签名)*/
strKEYF = "gameclient-key.pem"; /*客户端的私钥(建议加密存储)*/
strCACERT = "cacert.pem"; /*CA 的证书*/
int nret = InitCtx (false,0);
TRACE1("The Return of Ctx initialization is %d\n",nret);
return nret;
}
int SSLClientSocket::InitCtx(bool bServer,int VerType)
{
OpenSSL_add_ssl_algorithms(); /*初始化*/
SSL_load_error_strings(); /*为打印调试信息作准备*/
if (bServer) /*采用什么协议(SSLv2/SSLv3/TLSv1)在此指定*/
meth = TLSv1_server_method();
else
meth = TLSv1_client_method();
ctx = SSL_CTX_new (meth);
if (ctx == NULL)
return -2;
SSL_CTX_set_verify (ctx, SSL_VERIFY_PEER, NULL); /*验证与否*/
SSL_CTX_load_verify_locations (ctx, strCACERT.GetBuffer(), NULL); /*若验证,则放置CA证书*/
if (SSL_CTX_use_certificate_file (ctx, strCERTF.GetBuffer(), SSL_FILETYPE_PEM) <= 0) {
//ERR_print_errors_fp(stderr);
return -3;
}
if (SSL_CTX_use_PrivateKey_file(ctx, strKEYF.GetBuffer(), SSL_FILETYPE_PEM) <= 0) {
//ERR_print_errors_fp(stderr);
return -4;
}
if (!SSL_CTX_check_private_key(ctx)) {
TRACE0("Private key does not match the certificate public key\n");
return -5;
}
/*构建随机数生成机制,WIN32平台必需*/
srand((unsigned)time (NULL));
for( int i = 0; i < 100; i++ )
seed_int[i] = rand();
RAND_seed(seed_int, sizeof(seed_int));
return 0;
}
// Creates an unconnected client. _socket uses 0 as its "no socket" marker,
// and the parser and output buffers start empty. WinSock and the SSL
// context are not initialized here.
SSLClientSocket::SSLClientSocket ()
    : _socket (0)
{
    ResetSocketState ();
}
// Closes the connection, if one is still open, when the object goes away.
// OnDisconnected() is deliberately not invoked from here.
SSLClientSocket::~SSLClientSocket ()
{
    if (!IsOpened ())
        return;

    TRACE0 ("SSLClientSocket being destructed. Closing socket...\n");
    Disconnect ();
}
//
// Public Methods
//
bool SSLClientSocket::IsOpened () const
{
return this->_socket != 0;
}
int SSLClientSocket::Poll ()
{
if (this->_socket == 0)
return -1;
fd_set read_fds;
fd_set write_fds;
FD_ZERO (&read_fds);
FD_ZERO (&write_fds);
FD_SET (this->_socket, &read_fds);
if (!this->_output_buf.IsEmpty ())
{
FD_SET (this->_socket, &write_fds);
}
timeval timeout = { 0, 0 };
int nret = select (1, &read_fds, &write_fds, 0, &timeout);
if (nret == SOCKET_ERROR)
return -2;
if (FD_ISSET (this->_socket, &write_fds))
{
int len = SSLsend (this->_output_buf.GetBuffer (), this->_output_buf.GetLength ());
if (len > 0)
{
if (len == this->_output_buf.GetLength ())
{
this->_output_buf.Empty ();
}
else
{
this->_output_buf = this->_output_buf.Mid (len);
}
}
else if (len == SOCKET_ERROR)
{
return -3;
}
}
if (FD_ISSET (this->_socket, &read_fds))
{
const int BUF_SIZE = 1024;
char buf[BUF_SIZE];
int len = SSLrecv (buf, BUF_SIZE);
if (len > 0)
{
this->_input_buf.Append (buf, len);
this->ProcessInputBuffer ();
}
else if (len == 0 || len == SOCKET_ERROR)
{
// TODO: Check for connection reset
this->Disconnect ();
this->OnDisconnected ();
return -4;
}
}
return 0;
}
void SSLClientSocket::Disconnect ()
{
if (this->_socket != 0)
{
shutdown (this->_socket, 2);
closesocket (this->_socket);
return;
this->_socket = 0;
this->ResetSocketState ();
if (ssl != NULL) SSL_free (ssl);
if (ctx != NULL) SSL_CTX_free (ctx);
}
}
int SSLClientSocket::Connect(const CString& host, int port)
{
if (this->_socket != 0)
{
TRACE0 ("Attempt to connect with existing connection. Closing socket...\n");
this->Disconnect ();
}
_socket = socket (AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (_socket == INVALID_SOCKET)
{
int nret = WSAGetLastError ();
return -255;
}
LPHOSTENT hostEntry;
// Specifying the server by its name;
// another option is gethostbyaddr() (see below)
hostEntry = gethostbyname (host);
if (!hostEntry) {
int nret = WSAGetLastError ();
//ReportError(nret, "gethostbyname()"); // Report the error as before
return -256;
}
// Fill a SOCKADDR_IN struct with address information
SOCKADDR_IN serverInfo;
serverInfo.sin_family = AF_INET;
// At this point, we've successfully retrieved vital information about the server,
// including its hostname, aliases, and IP addresses. Wait; how could a single
// computer have multiple addresses, and exactly what is the following line doing?
// See the explanation below.
serverInfo.sin_addr = *((LPIN_ADDR)*hostEntry->h_addr_list);
// Change to network-byte order and
// insert into port field
serverInfo.sin_port = htons(port);
// Connect to the server
int nret = connect (this->_socket, (LPSOCKADDR)&serverInfo, sizeof(struct sockaddr));
if (nret == SOCKET_ERROR)
{
nret = WSAGetLastError ();
return -257;
}
/* TCP 链接已建立.开始 SSL 握手过程.......................... */
nret = this->InitWinsock ();
if (nret != 0)
return -257 + nret;
TRACE0("Begin SSL negotiation \n");
ssl = SSL_new (ctx);
if (ssl == NULL) {
TRACE0("Failed to create a new ssl \n");
return -1;
}
SSL_set_fd (ssl, (int)_socket);
err = SSL_connect (ssl);
if (err <= 0) {
TRACE0 ("Failed to connect by ssl\n");
return -263;
}
/*打印所有加密算法的信息(可选)*/
TRACE1 ("SSL connection using %s\n", SSL_get_cipher (ssl));
/*得到服务端的证书并打印些信息(可选) */
server_cert = SSL_get_peer_certificate (ssl);
if (server_cert == NULL) {
TRACE0 ("Failed to get certification\n");
return -264;
}
TRACE0 ("Server certificate:\n");
CString str = X509_NAME_oneline (X509_get_subject_name (server_cert),0,0);
TRACE1 ("\t subject: %s\n", str);
//free (str);
str = X509_NAME_oneline (X509_get_issuer_name (server_cert),0,0);
TRACE1 ("\t issuer: %s\n", str);
//free (str);
X509_free (server_cert); /*如不再需要,需将证书释放 */
// Set socket to non-blocking mode
unsigned long arg = 1;
ioctlsocket (this->_socket, FIONBIO, &arg);
// Disable SO_LINGER on socket
int arg2 = 1;
setsockopt (this->_socket, SOL_SOCKET, SO_DONTLINGER, (char*)&arg2, sizeof(int));
this->OnConnected ();
return 0;
}
//
// Helper Method
//
int SSLClientSocket::SendData (const CString& buf)
{
if (this->_socket == 0)
return -1;
this->_output_buf.Append (buf);
return this->_output_buf.GetLength ();
}
void SSLClientSocket::ResetSocketState ()
{
this->_input_buf = "";
this->_output_buf = "";
this->_is_reading_body = false;
this->_body_bytes_to_read = 0;
}
// Consumes as much of _input_buf as possible. Messages are assembled in
// _tmp_message, and each complete one is delivered through OnMessage().
//
// Format as parsed below:
//  - header lines, each terminated by MXMessage::HEADER_LINE_DELIMITER and
//    handed to ParseHeaderLine();
//  - an empty line (a delimiter at the current offset), which ends the headers;
//  - exactly _tmp_message.GetBodyLength() body bytes.
// The parser state (_is_reading_body, _body_bytes_to_read) persists across
// calls, so a message may be split over several Poll() reads.
//
// Returns 0 normally, or -1 on a negative body length or an unparsable header
// line. NOTE(review): on the -1 paths the consumed prefix is NOT removed from
// _input_buf, so the same bad data is parsed again on the next call.
int SSLClientSocket::ProcessInputBuffer ()
{
int stop = 0;
int offset = 0;
while (stop == 0 && offset < this->_input_buf.GetLength ())
{
if (!this->_is_reading_body)
{
// Header phase: consume complete header lines.
while (stop == 0 && offset < this->_input_buf.GetLength ())
{
if (this->_input_buf[offset] == MXMessage::HEADER_LINE_DELIMITER)
{
// Empty line: end of headers; the body length is now known.
this->_body_bytes_to_read = this->_tmp_message.GetBodyLength ();
if (this->_body_bytes_to_read < 0)
{
TRACE1 ("Bad body length: %d\n", this->_body_bytes_to_read);
return -1;
}
else
{
TRACE1 ("Body length is %d bytes\n", this->_body_bytes_to_read);
}
this->_is_reading_body = true;
offset++;
if (this->_body_bytes_to_read == 0)
{
// Header-only message: deliver it now and keep parsing headers.
TRACE0 ("No body to read. Begin processing request\n");
this->OnMessage (this->_tmp_message);
this->_tmp_message.Clear ();
this->_is_reading_body = false;
}
else
{
// Leave the header loop; the outer loop continues in body mode.
TRACE1 ("Will need to read body of %d bytes.\n", this->_body_bytes_to_read);
break;
}
}
else
{
// Non-empty header line: parse it up to the next delimiter.
int i = this->_input_buf.Find (MXMessage::HEADER_LINE_DELIMITER, offset);
if (i != -1)
{
// Parse header key-value
if (this->ParseHeaderLine (this->_input_buf.Mid (offset, i - offset)) != 0)
{
return -1;
}
offset = i + 1;
}
else // if (i == -1)
{
/* this->_input_buf does not contain MXMessage::HEADER_LINE_DELIMITER,
so stop processing input */
stop = 1;
break;
}
}
}
}
else // if (this->_is_reading_body)
{
// Body phase.
if (this->_body_bytes_to_read > 0)
{
int len = this->_input_buf.GetLength () - offset;
if (len < this->_body_bytes_to_read)
{
// Body incomplete: take everything available and wait for more input.
this->_tmp_message.GetBody () += this->_input_buf.Mid (offset);
this->_body_bytes_to_read -= len;
offset = this->_input_buf.GetLength ();
stop = 1;
TRACE2 ("Read %d bytes for body. %d bytes remaining to read for body\n", len, this->_body_bytes_to_read);
}
else // if (len >= this->_body_bytes_to_read)
{
// The rest of the body is available: complete and deliver the message.
TRACE1 ("Read %d bytes for body. 0 byte remaining to read for body\n", this->_body_bytes_to_read);
this->_tmp_message.GetBody () += this->_input_buf.Mid (offset, this->_body_bytes_to_read);
offset += this->_body_bytes_to_read;
this->_body_bytes_to_read = 0;
this->OnMessage (this->_tmp_message);
this->_tmp_message.Clear ();
this->_is_reading_body = false;
}
}
else
{
// NOTE(review): appears unreachable -- the header branch never leaves
// _is_reading_body true with a zero byte count.
// TODO: process_request(incoming_temp_message)
this->OnMessage (this->_tmp_message);
this->_tmp_message.Clear ();
this->_is_reading_body = false;
}
}
}
// Drop the consumed prefix; keep any incomplete trailing data.
if (offset < this->_input_buf.GetLength ())
{
this->_input_buf.Delete (0, offset);
}
else
{
this->_input_buf.Empty ();
}
TRACE1 ("Input buffer now contains %d bytes\n", this->_input_buf.GetLength ());
return 0;
}
int SSLClientSocket::ParseHeaderLine (const CString &line)
{
int pos = line.Find (MXMessage::HEADER_DELIMITER);
if (pos != -1)
{
if (pos == 0)
{
TRACE1 ("Empty header key: %s\n", line);
return -1;
}
else if (pos == line.GetLength () - 1)
{
CString key = line.Left (pos);
// TRACE1 ("Adding %s -> empty to header table\n", key);
this->_tmp_message.GetHeaderMap ().SetAt (key, "");
}
else
{
CString key = line.Left (pos);
CString value = line.Mid (pos + 1);
// TRACE2 ("Adding %s -> %s to header table\n", key, value);
this->_tmp_message.GetHeaderMap ().SetAt (key, value);
}
return 0;
}
else
{
TRACE1 ("Cannot parse header line: %s\n", line);
return -1;
}
}
//
// Virtual Methods
//
int SSLClientSocket::SSLrecv (char *buf, int len)
{
int k=0;
do
{
while (1)
{
k = SSL_read(SSLClientSocket::ssl,buf,len);
if (k <= 0)
{
if (BIO_sock_should_retry(k))
{
Sleep(100);
continue;//重试
}
return k;//错误退出
}
break;//成功
}
}while (SSL_pending(SSLClientSocket::ssl));
return k;
}
int SSLClientSocket::SSLsend (char *buf, int len)
{
if(ssl==NULL)
return -1;
int k=0; //发送数量
int offset=0; //偏移
while (0!=len) //maximum record size of 16kB for SSLv3/TLSv1
{
k = SSL_write(SSLClientSocket::ssl, buf+offset, len);
if (k <= 0)
{
if (BIO_sock_should_retry(k))
{
Sleep(100);
continue; //重试
}
return k; //出错
}
offset+=k;
len-=k;
}
return offset;
}
上面的代码不仅仅可以用于客户端,也可以用于服务端:其中 InitCtx 通过 bServer 参数同时支持客户端和服务端的 SSL_CTX 初始化(服务端传 true);SSLClientSocket 类本身只实现了客户端的连接逻辑。