1 {-# LANGUAGE CPP                 #-}
    2 {-# LANGUAGE DeriveDataTypeable  #-}
    3 {-# LANGUAGE OverloadedStrings   #-}
    4 {-# LANGUAGE ScopedTypeVariables #-}
    5 
    6 ------------------------------------------------------------------------------
    7 module Snap.Internal.Http.Server.TLS
    8   ( TLSException(..)
    9   , withTLS
   10   , bindHttps
   11   , httpsAcceptFunc
   12   , sendFileFunc
   13   ) where
   14 
   15 ------------------------------------------------------------------------------
   16 import           Data.ByteString.Char8             (ByteString)
   17 import qualified Data.ByteString.Char8             as S
   18 import           Data.Typeable                     (Typeable)
   19 import           Network.Socket                    (Socket)
   20 #ifdef OPENSSL
   21 import           Control.Exception                 (Exception, bracketOnError, finally, throwIO)
   22 import           Control.Monad                     (when)
   23 import           Data.ByteString.Builder           (byteString)
   24 import qualified Network.Socket                    as Socket
   25 import           OpenSSL                           (withOpenSSL)
   26 import           OpenSSL.Session                   (SSL, SSLContext)
   27 import qualified OpenSSL.Session                   as SSL
   28 import           Prelude                           (Bool, FilePath, IO, Int, Maybe (..), Monad (..), Show, flip, fromIntegral, fst, not, ($), ($!), (.))
   29 import           Snap.Internal.Http.Server.Address (getAddress, getSockAddr)
   30 import           Snap.Internal.Http.Server.Socket  (acceptAndInitialize)
   31 import qualified System.IO.Streams                 as Streams
   32 import qualified System.IO.Streams.SSL             as SStreams
   33 
   34 #else
   35 import           Control.Exception                 (Exception, throwIO)
   36 import           Prelude                           (Bool, FilePath, IO, Int, Show, id, ($))
   37 #endif
   38 ------------------------------------------------------------------------------
   39 import           Snap.Internal.Http.Server.Types   (AcceptFunc (..), SendFileHandler)
   40 ------------------------------------------------------------------------------
   41 
   42 data TLSException = TLSException S.ByteString
   43   deriving (Show, Typeable)
   44 instance Exception TLSException
   45 
   46 #ifndef OPENSSL
   47 type SSLContext = ()
   48 type SSL = ()
   49 
   50 ------------------------------------------------------------------------------
   51 sslNotSupportedException :: TLSException
   52 sslNotSupportedException = TLSException $ S.concat [
   53     "This version of snap-server was not built with SSL "
   54   , "support.\n"
   55   , "Please compile snap-server with -fopenssl to enable it."
   56   ]
   57 
   58 
   59 ------------------------------------------------------------------------------
   60 withTLS :: IO a -> IO a
   61 withTLS = id
   62 
   63 
   64 ------------------------------------------------------------------------------
   65 barf :: IO a
   66 barf = throwIO sslNotSupportedException
   67 
   68 
   69 ------------------------------------------------------------------------------
   70 bindHttps :: ByteString -> Int -> FilePath -> Bool -> FilePath
   71           -> IO (Socket, SSLContext)
   72 bindHttps _ _ _ _ _ = barf
   73 
   74 
   75 ------------------------------------------------------------------------------
   76 httpsAcceptFunc :: Socket -> SSLContext -> AcceptFunc
   77 httpsAcceptFunc _ _ = AcceptFunc $ \restore -> restore barf
   78 
   79 
   80 ------------------------------------------------------------------------------
   81 sendFileFunc :: SSL -> Socket -> SendFileHandler
   82 sendFileFunc _ _ _ _ _ _ _ = barf
   83 
   84 
   85 #else
   86 ------------------------------------------------------------------------------
   87 withTLS :: IO a -> IO a
   88 withTLS = withOpenSSL
   89 
   90 
   91 ------------------------------------------------------------------------------
   92 bindHttps :: ByteString
   93           -> Int
   94           -> FilePath
   95           -> Bool
   96           -> FilePath
   97           -> IO (Socket, SSLContext)
   98 bindHttps bindAddress bindPort cert chainCert key =
   99     withTLS $
  100     bracketOnError
  101         (do (family, addr) <- getSockAddr bindPort bindAddress
  102             sock <- Socket.socket family Socket.Stream 0
  103             return (sock, addr)
  104             )
  105         (Socket.close . fst)
  106         $ \(sock, addr) -> do
  107              Socket.setSocketOption sock Socket.ReuseAddr 1
  108              Socket.bindSocket sock addr
  109              Socket.listen sock 150
  110 
  111              ctx <- SSL.context
  112              SSL.contextSetPrivateKeyFile ctx key
  113              if chainCert
  114                then SSL.contextSetCertificateChainFile ctx cert
  115                else SSL.contextSetCertificateFile ctx cert
  116 
  117              certOK <- SSL.contextCheckPrivateKey ctx
  118              when (not certOK) $ do
  119                  throwIO $ TLSException certificateError
  120              return (sock, ctx)
  121   where
  122     certificateError =
  123       "OpenSSL says that the certificate doesn't match the private key!"
  124 
  125 
  126 ------------------------------------------------------------------------------
  127 httpsAcceptFunc :: Socket
  128                 -> SSLContext
  129                 -> AcceptFunc
  130 httpsAcceptFunc boundSocket ctx =
  131     AcceptFunc $ \restore ->
  132     acceptAndInitialize boundSocket restore $ \(sock, remoteAddr) -> do
  133         localAddr                <- Socket.getSocketName sock
  134         (localPort, localHost)   <- getAddress localAddr
  135         (remotePort, remoteHost) <- getAddress remoteAddr
  136         ssl                      <- restore (SSL.connection ctx sock)
  137 
  138         restore (SSL.accept ssl)
  139         (readEnd, writeEnd) <- SStreams.sslToStreams ssl
  140 
  141         let cleanup = (do Streams.write Nothing writeEnd
  142                           SSL.shutdown ssl $! SSL.Unidirectional)
  143                         `finally` Socket.close sock
  144 
  145         return $! ( sendFileFunc ssl
  146                   , localHost
  147                   , localPort
  148                   , remoteHost
  149                   , remotePort
  150                   , readEnd
  151                   , writeEnd
  152                   , cleanup
  153                   )
  154 
  155 
  156 ------------------------------------------------------------------------------
  157 sendFileFunc :: SSL -> SendFileHandler
  158 sendFileFunc ssl buffer builder fPath offset nbytes = do
  159     Streams.unsafeWithFileAsInputStartingAt (fromIntegral offset) fPath $ \fileInput0 -> do
  160         fileInput <- Streams.takeBytes (fromIntegral nbytes) fileInput0 >>=
  161                      Streams.map byteString
  162         input     <- Streams.fromList [builder] >>=
  163                      flip Streams.appendInputStream fileInput
  164         output    <- Streams.makeOutputStream sendChunk >>=
  165                      Streams.unsafeBuilderStream (return buffer)
  166         Streams.supply input output
  167         Streams.write Nothing output
  168 
  169   where
  170     sendChunk (Just s) = SSL.write ssl s
  171     sendChunk Nothing  = return $! ()
  172 #endif