1 {-# LANGUAGE BangPatterns        #-}
    2 {-# LANGUAGE CPP                 #-}
    3 {-# LANGUAGE OverloadedStrings   #-}
    4 {-# LANGUAGE RankNTypes          #-}
    5 {-# LANGUAGE ScopedTypeVariables #-}
    6 
    7 module Snap.Internal.Http.Server.TimeoutManager
    8   ( TimeoutManager
    9   , TimeoutThread
   10   , initialize
   11   , stop
   12   , register
   13   , tickle
   14   , set
   15   , modify
   16   , cancel
   17   ) where
   18 
   19 ------------------------------------------------------------------------------
   20 import           Control.Exception                (evaluate, finally)
   21 import qualified Control.Exception                as E
   22 import           Control.Monad                    (Monad ((>>=), return), mapM_, void)
   23 import qualified Data.ByteString.Char8            as S
   24 import           Data.IORef                       (IORef, newIORef, readIORef, writeIORef)
   25 import           Prelude                          (Bool, Double, IO, Int, Show (..), const, fromIntegral, max, max, min, null, otherwise, round, ($), ($!), (+), (++), (-), (.), (<=), (==))
   26 ------------------------------------------------------------------------------
   27 import           Control.Concurrent               (MVar, newEmptyMVar, putMVar, readMVar, takeMVar, tryPutMVar)
   28 ------------------------------------------------------------------------------
   29 import           Snap.Internal.Http.Server.Clock  (ClockTime)
   30 import qualified Snap.Internal.Http.Server.Clock  as Clock
   31 import           Snap.Internal.Http.Server.Common (atomicModifyIORef', eatException)
   32 import qualified Snap.Internal.Http.Server.Thread as T
   33 
   34 
   35 ------------------------------------------------------------------------------
   36 type State = ClockTime
   37 
   38 canceled :: State
   39 canceled = 0
   40 
   41 isCanceled :: State -> Bool
   42 isCanceled = (== 0)
   43 
   44 
   45 ------------------------------------------------------------------------------
   46 data TimeoutThread = TimeoutThread {
   47       _thread     :: !T.SnapThread
   48     , _state      :: !(IORef State)
   49     , _hGetTime   :: !(IO ClockTime)
   50     }
   51 
   52 instance Show TimeoutThread where
   53     show = show . _thread
   54 
   55 
   56 ------------------------------------------------------------------------------
   57 -- | Given a 'State' value and the current time, apply the given modification
   58 -- function to the amount of time remaining.
   59 --
   60 smap :: ClockTime -> (ClockTime -> ClockTime) -> State -> State
   61 smap now f deadline | isCanceled deadline = deadline
   62                     | otherwise = t'
   63   where
   64     remaining    = max 0 (deadline - now)
   65     newremaining = f remaining
   66     t'           = now + newremaining
   67 
   68 
   69 ------------------------------------------------------------------------------
   70 data TimeoutManager = TimeoutManager {
   71       _defaultTimeout :: !ClockTime
   72     , _pollInterval   :: !ClockTime
   73     , _getTime        :: !(IO ClockTime)
   74     , _threads        :: !(IORef [TimeoutThread])
   75     , _morePlease     :: !(MVar ())
   76     , _managerThread  :: !(MVar T.SnapThread)
   77     }
   78 
   79 
   80 ------------------------------------------------------------------------------
   81 -- | Create a new TimeoutManager.
   82 initialize :: Double            -- ^ default timeout
   83            -> Double            -- ^ poll interval
   84            -> IO ClockTime      -- ^ function to get current time
   85            -> IO TimeoutManager
   86 initialize defaultTimeout interval getTime = E.uninterruptibleMask_ $ do
   87     conns <- newIORef []
   88     mp    <- newEmptyMVar
   89     mthr  <- newEmptyMVar
   90 
   91     let tm = TimeoutManager (Clock.fromSecs defaultTimeout)
   92                             (Clock.fromSecs interval)
   93                             getTime
   94                             conns
   95                             mp
   96                             mthr
   97 
   98     thr <- T.fork "snap-server: timeout manager" $ managerThread tm
   99     putMVar mthr thr
  100     return tm
  101 
  102 
  103 ------------------------------------------------------------------------------
  104 -- | Stop a TimeoutManager.
  105 stop :: TimeoutManager -> IO ()
  106 stop tm = readMVar (_managerThread tm) >>= T.cancelAndWait
  107 
  108 
  109 ------------------------------------------------------------------------------
  110 wakeup :: TimeoutManager -> IO ()
  111 wakeup tm = void $ tryPutMVar (_morePlease tm) $! ()
  112 
  113 
  114 ------------------------------------------------------------------------------
  115 -- | Register a new thread with the TimeoutManager.
  116 register :: TimeoutManager                        -- ^ manager to register
  117                                                   --   with
  118          -> S.ByteString                          -- ^ thread label
  119          -> ((forall a . IO a -> IO a) -> IO ())  -- ^ thread action to run
  120          -> IO TimeoutThread
  121 register tm label action = do
  122     now <- getTime
  123     let !state = now + defaultTimeout
  124     stateRef <- newIORef state
  125     th <- E.uninterruptibleMask_ $ do
  126         t <- T.fork label action
  127         let h = TimeoutThread t stateRef getTime
  128         atomicModifyIORef' threads (\x -> (h:x, ())) >>= evaluate
  129         return $! h
  130     wakeup tm
  131     return th
  132 
  133   where
  134     getTime        = _getTime tm
  135     threads        = _threads tm
  136     defaultTimeout = _defaultTimeout tm
  137 
  138 
  139 ------------------------------------------------------------------------------
  140 -- | Tickle the timeout on a connection to be at least N seconds into the
  141 -- future. If the existing timeout is set for M seconds from now, where M > N,
  142 -- then the timeout is unaffected.
  143 tickle :: TimeoutThread -> Int -> IO ()
  144 tickle th = modify th . max
  145 {-# INLINE tickle #-}
  146 
  147 
  148 ------------------------------------------------------------------------------
  149 -- | Set the timeout on a connection to be N seconds into the future.
  150 set :: TimeoutThread -> Int -> IO ()
  151 set th = modify th . const
  152 {-# INLINE set #-}
  153 
  154 
  155 ------------------------------------------------------------------------------
  156 -- | Modify the timeout with the given function.
  157 modify :: TimeoutThread -> (Int -> Int) -> IO ()
  158 modify th f = do
  159     now   <- getTime
  160     state <- readIORef stateRef
  161     let !state' = smap now f' state
  162     writeIORef stateRef state'
  163 
  164   where
  165     f' !x    = Clock.fromSecs $! fromIntegral $ f $ round $ Clock.toSecs x
  166     getTime  = _hGetTime th
  167     stateRef = _state th
  168 {-# INLINE modify #-}
  169 
  170 
  171 ------------------------------------------------------------------------------
  172 -- | Cancel a timeout.
  173 cancel :: TimeoutThread -> IO ()
  174 cancel h = E.uninterruptibleMask_ $ do
  175     T.cancel $ _thread h
  176     writeIORef (_state h) canceled
  177 {-# INLINE cancel #-}
  178 
  179 
  180 ------------------------------------------------------------------------------
  181 managerThread :: TimeoutManager -> (forall a. IO a -> IO a) -> IO ()
  182 managerThread tm restore = restore loop `finally` cleanup
  183   where
  184     cleanup = E.uninterruptibleMask_ $
  185               eatException (readIORef threads >>= destroyAll)
  186 
  187     --------------------------------------------------------------------------
  188     getTime      = _getTime tm
  189     morePlease   = _morePlease tm
  190     pollInterval = _pollInterval tm
  191     threads      = _threads tm
  192 
  193     --------------------------------------------------------------------------
  194     loop = do
  195         now <- getTime
  196         nextWakeup <- E.uninterruptibleMask $ \restore' -> do
  197             handles <- atomicModifyIORef' threads (\x -> ([], x))
  198             if null handles
  199               then do restore' $ takeMVar morePlease
  200                       return now
  201               else do
  202                 (handles', next) <- processHandles now handles
  203                 atomicModifyIORef' threads (\x -> (handles' ++ x, ()))
  204                     >>= evaluate
  205                 return $! next
  206         now' <- getTime
  207         Clock.sleepFor $ max 0 (nextWakeup - now')
  208         loop
  209 
  210     --------------------------------------------------------------------------
  211     processHandles now handles = go handles [] (now + pollInterval)
  212       where
  213         go [] !kept !nextWakeup = return $! (kept, nextWakeup)
  214 
  215         go (x:xs) !kept !nextWakeup = do
  216             !state <- readIORef $ _state x
  217             (!kept', !next) <-
  218                 if isCanceled state
  219                   then do b <- T.isFinished (_thread x)
  220                           return $! if b
  221                                       then (kept, nextWakeup)
  222                                       else ((x:kept), nextWakeup)
  223                   else do t <- if state <= now
  224                                  then do T.cancel (_thread x)
  225                                          writeIORef (_state x) canceled
  226                                          return nextWakeup
  227                                  else return (min nextWakeup state)
  228                           return ((x:kept), t)
  229             go xs kept' next
  230 
  231     --------------------------------------------------------------------------
  232     destroyAll xs = do
  233         mapM_ (T.cancel . _thread) xs
  234         mapM_ (T.wait . _thread) xs