1 {-# LANGUAGE BangPatterns        #-}
    2 {-# LANGUAGE CPP                 #-}
    3 {-# LANGUAGE OverloadedStrings   #-}
    4 {-# LANGUAGE RankNTypes          #-}
    5 {-# LANGUAGE ScopedTypeVariables #-}
    6 module Snap.Internal.Http.Server.Socket
    7   ( bindSocket
    8   , bindSocketImpl
    9   , bindUnixSocket
   10   , httpAcceptFunc
   11   , haProxyAcceptFunc
   12   , sendFileFunc
   13   , acceptAndInitialize
   14   ) where
   15 
   16 ------------------------------------------------------------------------------
   17 import           Control.Exception                 (bracketOnError, finally, throwIO)
   18 import           Control.Monad                     (when)
   19 import           Data.Bits                         (complement, (.&.))
   20 import           Data.ByteString.Char8             (ByteString)
   21 import           Network.Socket                    (Socket, SocketOption (NoDelay, ReuseAddr), accept, close, getSocketName, listen, setSocketOption, socket)
   22 import qualified Network.Socket                    as N
   23 #ifdef HAS_SENDFILE
   24 import           Network.Socket                    (fdSocket)
   25 import           System.Posix.IO                   (OpenMode (..), closeFd, defaultFileFlags, openFd)
   26 import           System.Posix.Types                (Fd (..))
   27 import           System.SendFile                   (sendFile, sendHeaders)
   28 #else
   29 import           Data.ByteString.Builder           (byteString)
   30 import           Data.ByteString.Builder.Extra     (flush)
   31 import           Network.Socket.ByteString         (sendAll)
   32 #endif
   33 #ifdef HAS_UNIX_SOCKETS
   34 import           Control.Exception                 (bracket)
   35 import qualified Control.Exception                 as E (catch)
   36 import           System.FilePath                   (isRelative)
   37 import           System.IO.Error                   (isDoesNotExistError)
   38 import           System.Posix.Files                (accessModes, removeLink, setFileCreationMask)
   39 #endif
   40 
   41 ------------------------------------------------------------------------------
   42 import qualified System.IO.Streams                 as Streams
   43 ------------------------------------------------------------------------------
   44 import           Snap.Internal.Http.Server.Address (AddressNotSupportedException (..), getAddress, getSockAddr)
   45 import           Snap.Internal.Http.Server.Types   (AcceptFunc (..), SendFileHandler)
   46 import qualified System.IO.Streams.Network.HAProxy as HA
   47 
   48 
   49 ------------------------------------------------------------------------------
   50 bindSocket :: ByteString -> Int -> IO Socket
   51 bindSocket = bindSocketImpl setSocketOption N.bindSocket listen
   52 {-# INLINE bindSocket #-}
   53 
   54 
   55 ------------------------------------------------------------------------------
   56 bindSocketImpl
   57     :: (Socket -> SocketOption -> Int -> IO ()) -- ^ mock setSocketOption
   58     -> (Socket -> N.SockAddr -> IO ())          -- ^ bindSocket
   59     -> (Socket -> Int -> IO ())                 -- ^ listen
   60     -> ByteString
   61     -> Int
   62     -> IO Socket
   63 bindSocketImpl _setSocketOption _bindSocket _listen bindAddr bindPort = do
   64     (family, addr) <- getSockAddr bindPort bindAddr
   65     bracketOnError (socket family N.Stream 0) N.close $ \sock -> do
   66         _setSocketOption sock ReuseAddr 1
   67         _setSocketOption sock NoDelay 1
   68         _bindSocket sock addr
   69         _listen sock 150
   70         return $! sock
   71 
   72 bindUnixSocket :: Maybe Int -> String -> IO Socket
   73 #if HAS_UNIX_SOCKETS
   74 bindUnixSocket mode path = do
   75    when (isRelative path) $
   76       throwIO $ AddressNotSupportedException
   77                 $! "Refusing to bind unix socket to non-absolute path: " ++ path
   78 
   79    bracketOnError (socket N.AF_UNIX N.Stream 0) N.close $ \sock -> do
   80       E.catch (removeLink path) $ \e -> when (not $ isDoesNotExistError e) $ throwIO e
   81       case mode of
   82          Nothing -> N.bindSocket sock (N.SockAddrUnix path)
   83          Just mode' -> bracket (setFileCreationMask $ modeToMask mode')
   84                               setFileCreationMask
   85                               (const $ N.bindSocket sock (N.SockAddrUnix path))
   86       N.listen sock 150
   87       return $! sock
   88    where
   89      modeToMask p = accessModes .&. complement (fromIntegral p)
   90 #else
   91 bindUnixSocket _ path = throwIO (AddressNotSupportedException $ "unix:" ++ path)
   92 #endif
   93 
   94 ------------------------------------------------------------------------------
   95 -- TODO(greg): move buffer size configuration into config
   96 bUFSIZ :: Int
   97 bUFSIZ = 4064
   98 
   99 
  100 ------------------------------------------------------------------------------
  101 acceptAndInitialize :: Socket        -- ^ bound socket
  102                     -> (forall b . IO b -> IO b)
  103                     -> ((Socket, N.SockAddr) -> IO a)
  104                     -> IO a
  105 acceptAndInitialize boundSocket restore f =
  106     bracketOnError (restore $ accept boundSocket)
  107                    (close . fst)
  108                    f
  109 
  110 
  111 ------------------------------------------------------------------------------
  112 haProxyAcceptFunc :: Socket     -- ^ bound socket
  113                   -> AcceptFunc
  114 haProxyAcceptFunc boundSocket =
  115     AcceptFunc $ \restore ->
  116     acceptAndInitialize boundSocket restore $ \(sock, saddr) -> do
  117         (readEnd, writeEnd)      <- Streams.socketToStreamsWithBufferSize
  118                                         bUFSIZ sock
  119         localPInfo               <- HA.socketToProxyInfo sock saddr
  120         pinfo                    <- HA.decodeHAProxyHeaders localPInfo readEnd
  121         (localPort, localHost)   <- getAddress $ HA.getDestAddr pinfo
  122         (remotePort, remoteHost) <- getAddress $ HA.getSourceAddr pinfo
  123         let cleanup              =  Streams.write Nothing writeEnd
  124                                         `finally` close sock
  125         return $! ( sendFileFunc sock
  126                   , localHost
  127                   , localPort
  128                   , remoteHost
  129                   , remotePort
  130                   , readEnd
  131                   , writeEnd
  132                   , cleanup
  133                   )
  134 
  135 
  136 ------------------------------------------------------------------------------
  137 httpAcceptFunc :: Socket                     -- ^ bound socket
  138                -> AcceptFunc
  139 httpAcceptFunc boundSocket =
  140     AcceptFunc $ \restore ->
  141     acceptAndInitialize boundSocket restore $ \(sock, remoteAddr) -> do
  142         localAddr                <- getSocketName sock
  143         (localPort, localHost)   <- getAddress localAddr
  144         (remotePort, remoteHost) <- getAddress remoteAddr
  145         (readEnd, writeEnd)      <- Streams.socketToStreamsWithBufferSize bUFSIZ
  146                                                                           sock
  147         let cleanup              =  Streams.write Nothing writeEnd
  148                                       `finally` close sock
  149         return $! ( sendFileFunc sock
  150                   , localHost
  151                   , localPort
  152                   , remoteHost
  153                   , remotePort
  154                   , readEnd
  155                   , writeEnd
  156                   , cleanup
  157                   )
  158 
  159 
  160 ------------------------------------------------------------------------------
  161 sendFileFunc :: Socket -> SendFileHandler
  162 #ifdef HAS_SENDFILE
  163 sendFileFunc sock !_ builder fPath offset nbytes = bracket acquire closeFd go
  164   where
  165     sockFd    = Fd (fdSocket sock)
  166     acquire   = openFd fPath ReadOnly Nothing defaultFileFlags
  167     go fileFd = do sendHeaders builder sockFd
  168                    sendFile sockFd fileFd offset nbytes
  169 
  170 
  171 #else
  172 sendFileFunc sock buffer builder fPath offset nbytes =
  173     Streams.unsafeWithFileAsInputStartingAt (fromIntegral offset) fPath $
  174             \fileInput0 -> do
  175         fileInput <- Streams.takeBytes (fromIntegral nbytes) fileInput0 >>=
  176                      Streams.map byteString
  177         input     <- Streams.fromList [builder] >>=
  178                      flip Streams.appendInputStream fileInput
  179         output    <- Streams.makeOutputStream sendChunk >>=
  180                      Streams.unsafeBuilderStream (return buffer)
  181         Streams.supply input output
  182         Streams.write (Just flush) output
  183 
  184   where
  185     sendChunk (Just s) = sendAll sock s
  186     sendChunk Nothing  = return $! ()
  187 #endif