/*=====================================================================*/
/*    serrano/prgm/project/bigloo/api/ssl/src/Posix/bglssl.c           */
/*    -------------------------------------------------------------    */
/*    Author      :  Manuel Serrano                                    */
/*    Creation    :  Wed Mar 23 16:54:42 2005                          */
/*    Last change :  Fri Aug  4 17:21:48 2006 (serrano)                */
/*    Copyright   :  2005-06 Manuel Serrano                            */
/*    -------------------------------------------------------------    */
/*    SSL socket client-side support                                   */
/*=====================================================================*/
#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <fcntl.h>

#if defined( _MSC_VER) || defined( _MINGW_VER )
#  define _BGL_WIN32_VER
#endif

#include <bigloo_config.h>
#include <sys/types.h>
#ifndef _BGL_WIN32_VER
#   include <sys/socket.h>
#   include <netinet/in.h>
#   include <arpa/inet.h>
#   include <netdb.h>
#   include <time.h>
#   if( BGL_HAVE_SELECT )
#     include <sys/time.h>
#     include <sys/types.h>
#     include <unistd.h>
#   endif
#else
#   if defined( _MINGW_VER )
#      include "windows.h"
#   endif
#   include <winsock2.h>
#   include <mswsock.h>
#   include <io.h>
#endif
#include <fcntl.h>
#include <limits.h>
#include <memory.h>
#include <errno.h>
#include <bigloo.h>

/* NOTE(review): this macro replaces every later use of the token       */
/* `socklen_t` (including the system type from <sys/socket.h>) with     */
/* `void`. It is not referenced in the visible part of this file —      */
/* confirm the override is still needed.                                */
#define socklen_t void

/* Test for an invalid socket descriptor: negative on POSIX,            */
/* INVALID_SOCKET under Winsock.                                        */
#ifndef _BGL_WIN32_VER
#   define BAD_SOCKET(s) ((s) < 0)
#else
#   define BAD_SOCKET(s) ((s) == INVALID_SOCKET)
#endif

/* Socket I/O buffer size constant (not referenced in the visible part  */
/* of this file).                                                       */
#define SOCKET_IO_BUFSIZE 1024

/*---------------------------------------------------------------------*/
/*    Imports                                                          */
/*---------------------------------------------------------------------*/
/* Runtime-wide Bigloo lock (defined elsewhere); used here only to      */
/* serialize the one-time initialization done by bgl_ssl_init.          */
extern obj_t bigloo_mutex;

/*---------------------------------------------------------------------*/
/*    SSL mutex                                                        */
/*---------------------------------------------------------------------*/
/* Serializes creation and destruction of SSL objects. BUNSPEC until    */
/* bgl_ssl_init allocates it.                                           */
static obj_t ssl_mutex = BUNSPEC;
/* Bigloo's DEFINE_STRING takes the character count of the string,      */
/* excluding the terminating NUL.                                       */
DEFINE_STRING( ssl_mutex_name, _1, "ssl-mutex", sizeof( "ssl-mutex" ) - 1 );

/*---------------------------------------------------------------------*/
/*    SSL socket and port close hooks                                  */
/*    -------------------------------------------------------------    */
/*    The C hooks are defined further below. Each is wrapped in a      */
/*    static Bigloo procedure object so bgl_make_ssl_client_socket can */
/*    install it as SOCKET_CHOOK or as a port's chook.                 */
/*---------------------------------------------------------------------*/
static obj_t socket_close_hook( obj_t, obj_t );
static obj_t input_close_hook( obj_t, obj_t );
static obj_t output_close_hook( obj_t, obj_t );

/* NOTE(review): the trailing 1 is presumably the procedure arity (one  */
/* argument: the socket or port); _2/_3/_4 name the static storage.     */
/* Confirm against DEFINE_STATIC_BGL_PROCEDURE in bigloo.h.             */
DEFINE_STATIC_BGL_PROCEDURE( ssl_socket_close_hook, _2, socket_close_hook, 0L, BUNSPEC, 1 );
DEFINE_STATIC_BGL_PROCEDURE( ssl_input_close_hook, _3, input_close_hook, 0L, BUNSPEC, 1 );
DEFINE_STATIC_BGL_PROCEDURE( ssl_output_close_hook, _4, output_close_hook, 0L, BUNSPEC, 1 );

/*---------------------------------------------------------------------*/
/*    The global SSL context                                           */
/*    -------------------------------------------------------------    */
/*    Shared by every client socket created in this file. NULL until   */
/*    bgl_ssl_init has run, and stays NULL if SSL_CTX_new fails.       */
/*---------------------------------------------------------------------*/
static SSL_CTX *ctx;

/*---------------------------------------------------------------------*/
/*    static void                                                      */
/*    bgl_ssl_init ...                                                 */
/*    -------------------------------------------------------------    */
/*    Lazily performs the one-time SSL setup:                          */
/*      - allocates ssl_mutex;                                         */
/*      - initializes the OpenSSL library and its error strings;       */
/*      - creates the shared client context CTX.                       */
/*    The work is guarded by bigloo_mutex and a static flag, so it is  */
/*    safe to call this before every socket creation.                  */
/*---------------------------------------------------------------------*/
static void
bgl_ssl_init( void ) {
   /* Explicit int: implicit-int declarations are invalid since C99. */
   static int initialized = 0;

   bgl_mutex_lock( bigloo_mutex );

   if( !initialized ) {
      initialized = 1;

      /* the SSL dedicated lock */
      ssl_mutex = bgl_make_mutex( ssl_mutex_name );

      /* Initialize SSL context */
      SSL_library_init();
      SSL_load_error_strings();
      ctx = SSL_CTX_new( SSLv23_client_method() );
   }

   bgl_mutex_unlock( bigloo_mutex );
}

/*---------------------------------------------------------------------*/
/*    static long                                                      */
/*    sslread ...                                                      */
/*    -------------------------------------------------------------    */
/*    Input-port read hook. Reads at most SIZE*NMEMB bytes into PTR    */
/*    from the SSL connection stored in PORT's userdata.               */
/*    Returns SSL_read's result: a byte count > 0, or <= 0 on EOF or   */
/*    error. A read interrupted by a signal (EINTR) is retried.        */
/*---------------------------------------------------------------------*/
static long
sslread( char *ptr, size_t size, size_t nmemb, obj_t port ) {
   SSL *ssl = (SSL *)PORT( port ).userdata;
   size_t want = size * nmemb;
   /* SSL_read takes an int length. A short read is always legal, so   */
   /* clamp oversized requests instead of letting the conversion wrap. */
   int len = (want > (size_t)INT_MAX) ? INT_MAX : (int)want;
   int r, err, reason;

   do {
      errno = 0;
      r = SSL_read( ssl, ptr, len );
      err = errno;   /* save before any other call can clobber it */
      reason = (r < 0) ? SSL_get_error( ssl, r ) : SSL_ERROR_NONE;
      /* OpenSSL's socket BIO treats EINTR as non-fatal: an interrupted  */
      /* read surfaces as SSL_ERROR_WANT_READ/WANT_WRITE with errno ==   */
      /* EINTR, never as SSL_ERROR_SSL. Only that case is retried, so    */
      /* other WANT_* results (e.g. a socket timeout) reach the caller.  */
   } while( (err == EINTR) &&
	    ((reason == SSL_ERROR_WANT_READ) ||
	     (reason == SSL_ERROR_WANT_WRITE)) );

   return r;
}

/*---------------------------------------------------------------------*/
/*    static size_t                                                    */
/*    sslwrite ...                                                     */
/*    -------------------------------------------------------------    */
/*    Output-port write hook. Sends SIZE*NMEMB bytes from PTR over the */
/*    SSL connection stored in PORT's userdata.                        */
/*    Returns SSL_write's result cast to size_t, so a negative error   */
/*    result becomes a very large value. A write interrupted by a      */
/*    signal (EINTR) is retried with the same buffer and length.       */
/*---------------------------------------------------------------------*/
static size_t
sslwrite( void *ptr, size_t size, size_t nmemb, obj_t port ) {
   SSL *ssl = (SSL *)PORT( port ).userdata;
   /* NOTE(review): lengths above INT_MAX do not fit SSL_write's int. */
   int len = (int)(size * nmemb);
   int r, err, reason;

   /* OpenSSL leaves SSL_write with a zero length undefined. There is */
   /* nothing to send, so report zero bytes without calling it.       */
   if( (size == 0) || (nmemb == 0) ) return 0;

   do {
      errno = 0;
      r = SSL_write( ssl, ptr, len );
      err = errno;   /* save before any other call can clobber it */
      reason = (r < 0) ? SSL_get_error( ssl, r ) : SSL_ERROR_NONE;
      /* An interrupted send is reported as a retryable WANT_* with      */
      /* errno == EINTR; OpenSSL requires repeating the same call.       */
   } while( (err == EINTR) &&
	    ((reason == SSL_ERROR_WANT_READ) ||
	     (reason == SSL_ERROR_WANT_WRITE)) );

   return (size_t)r;
}

/*---------------------------------------------------------------------*/
/*    static int                                                       */
/*    sslputc ...                                                      */
/*    -------------------------------------------------------------    */
/*    Output-port single-character hook. Sends C, converted to a char, */
/*    over the SSL connection stored in PORT's userdata.               */
/*    Returns 1 when SSL_write reports exactly one byte written,       */
/*    0 otherwise.                                                     */
/*---------------------------------------------------------------------*/
static int
sslputc( int c, obj_t port ) {
   SSL *conn = (SSL *)PORT( port ).userdata;
   char byte = (char)c;
   int sent = SSL_write( conn, &byte, 1 );

   return (sent == 1) ? 1 : 0;
}

/*---------------------------------------------------------------------*/
/*    static obj_t                                                     */
/*    sslflush ...                                                     */
/*    -------------------------------------------------------------    */
/*    Output-port flush hook. sslwrite and sslputc hand their bytes to */
/*    SSL_write immediately, so this layer holds nothing to flush; the */
/*    hook always reports success.                                     */
/*---------------------------------------------------------------------*/
static obj_t
sslflush( obj_t port ) {
   (void)port;   /* unused: required by the sysflush hook signature */
   return BTRUE;
}

/*---------------------------------------------------------------------*/
/*    static obj_t                                                     */
/*    socket_close_hook ...                                            */
/*    -------------------------------------------------------------    */
/*    Socket close hook. Shuts down and frees the SSL object stored in */
/*    the socket's userdata, then returns the socket S.                */
/*    The BIO was created with BIO_NOCLOSE, so freeing the SSL object  */
/*    leaves the descriptor itself open.                               */
/*---------------------------------------------------------------------*/
static obj_t
socket_close_hook( obj_t env, obj_t s ) {
   SSL *conn = (SSL *)SOCKET( s ).userdata;

   /* SSL objects are created and released under ssl_mutex. */
   bgl_mutex_lock( ssl_mutex );
   {
      SSL_shutdown( conn );   /* send the TLS close_notify alert */
      SSL_free( conn );       /* also releases the BIO set by SSL_set_bio */
   }
   bgl_mutex_unlock( ssl_mutex );

   return s;
}

/*---------------------------------------------------------------------*/
/*    static obj_t                                                     */
/*    input_close_hook ...                                             */
/*    -------------------------------------------------------------    */
/*    Close hook for the SSL socket's input port, whose sysclose is    */
/*    cleared when the port is rewired. Closes the port's stream as a  */
/*    FILE * (if any) and returns the port IP.                         */
/*---------------------------------------------------------------------*/
static obj_t
input_close_hook( obj_t env, obj_t ip ) {
   FILE *stream = (FILE *)(PORT( ip ).stream);

   /* fclose( NULL ) is undefined behavior. */
   if( stream ) fclose( stream );

   /* A non-void hook must return a value; mirror socket_close_hook. */
   return ip;
}

/*---------------------------------------------------------------------*/
/*    static obj_t                                                     */
/*    output_close_hook ...                                            */
/*    -------------------------------------------------------------    */
/*    Close hook for the SSL socket's output port, whose sysclose is   */
/*    cleared when the port is rewired. Closes the port's stream as a  */
/*    FILE * (if any) and returns the port OP.                         */
/*---------------------------------------------------------------------*/
static obj_t
output_close_hook( obj_t env, obj_t op ) {
   FILE *stream = (FILE *)(PORT( op ).stream);

   /* fclose( NULL ) is undefined behavior. */
   if( stream ) fclose( stream );

   /* A non-void hook must return a value; mirror socket_close_hook. */
   return op;
}

/*---------------------------------------------------------------------*/
/*    obj_t                                                            */
/*    bgl_make_ssl_client_socket ...                                   */
/*    -------------------------------------------------------------    */
/*    Opens a client socket with make_client_socket, then:             */
/*      - wraps its descriptor in an SSL object built from CTX;        */
/*      - performs the client handshake (SSL_connect);                 */
/*      - rewires the socket's input/output ports and close hooks to   */
/*        go through that SSL object.                                  */
/*    If the SSL object cannot be created or the handshake fails, the  */
/*    socket is closed and an I/O error is raised.                     */
/*---------------------------------------------------------------------*/
BGL_RUNTIME_DEF obj_t
bgl_make_ssl_client_socket( obj_t hostname, int port, char bufp, int ms ) {
   obj_t s = make_client_socket( hostname, port, bufp, ms );
   obj_t ip, op;
   SSL *ssl = NULL;
   BIO *sbio;

   bgl_ssl_init();

   bgl_mutex_lock( ssl_mutex );

   /* BIO_NOCLOSE: the descriptor stays owned by the Bigloo socket. */
   sbio = BIO_new_socket( SOCKET( s ).fd, BIO_NOCLOSE );
   /* SSL_new returns NULL on failure, including when CTX is NULL    */
   /* because SSL_CTX_new failed in bgl_ssl_init.                    */
   if( sbio ) ssl = SSL_new( ctx );

   if( ssl ) {
      /* From here on ssl owns sbio: SSL_free releases both. */
      SSL_set_bio( ssl, sbio, sbio );
      SSL_set_mode( ssl, SSL_MODE_AUTO_RETRY );
   } else if( sbio ) {
      BIO_free( sbio );
   }

   bgl_mutex_unlock( ssl_mutex );

   if( !ssl ) {
      socket_close( s );
      C_SYSTEM_FAILURE( BGL_IO_ERROR,
			"make-client-ssl-socket",
			"Cannot create SSL connection",
			BUNSPEC );
   }

   if( SSL_connect( ssl ) <= 0 ) {
      /* Release the SSL object, not sbio alone: after SSL_set_bio the */
      /* BIO belongs to ssl, and SSL_free frees both.                  */
      bgl_mutex_lock( ssl_mutex );
      SSL_free( ssl );
      bgl_mutex_unlock( ssl_mutex );

      socket_close( s );
      C_SYSTEM_FAILURE( BGL_IO_ERROR,
			"make-client-ssl-socket",
			"Cannot create socket",
			BUNSPEC );
   }

   ip = SOCKET_INPUT( s );
   op = SOCKET_OUTPUT( s );

   /* Input side: read through SSL; the close hook closes the stream. */
   PORT( ip ).userdata = ssl;
   PORT( ip ).chook = ssl_input_close_hook;
   PORT( ip ).sysclose = 0L;
   INPUT_PORT( ip ).sysread = &sslread;

   /* Output side: write/putc/flush through SSL. */
   PORT( op ).userdata = ssl;
   PORT( op ).sysclose = 0L;
   PORT( op ).chook = ssl_output_close_hook;
   OUTPUT_PORT( op ).syswrite = &sslwrite;
   OUTPUT_PORT( op ).sysputc = &sslputc;
   OUTPUT_PORT( op ).sysflush = &sslflush;

   /* The socket's close hook shuts down and frees the SSL object. */
   SOCKET( s ).userdata = ssl;
   SOCKET_CHOOK( s ) = ssl_socket_close_hook;

   return s;
}


