1 {-# LANGUAGE BangPatterns       #-}
    2 {-# LANGUAGE CPP                #-}
    3 {-# LANGUAGE DeriveDataTypeable #-}
    4 {-# LANGUAGE MagicHash          #-}
    5 {-# LANGUAGE OverloadedStrings  #-}
    6 {-# LANGUAGE Rank2Types         #-}
    7 {-# LANGUAGE Trustworthy        #-}
    8 {-# LANGUAGE UnboxedTuples      #-}
    9 
   10 module Snap.Internal.Http.Server.Parser
   11   ( IRequest(..)
   12   , HttpParseException(..)
   13   , readChunkedTransferEncoding
   14   , writeChunkedTransferEncoding
   15   , parseRequest
   16   , parseFromStream
   17   , parseCookie
   18   , parseUrlEncoded
   19   , getStdContentLength
   20   , getStdHost
   21   , getStdTransferEncoding
   22   , getStdCookie
   23   , getStdContentType
   24   , getStdConnection
   25   ) where
   26 
   27 ------------------------------------------------------------------------------
   28 #if !MIN_VERSION_base(4,8,0)
   29 import           Control.Applicative              ((<$>))
   30 #endif
   31 import           Control.Exception                (Exception, throwIO)
   32 import qualified Control.Exception                as E
   33 import           Control.Monad                    (void, when)
   34 import           Data.Attoparsec.ByteString.Char8 (Parser, hexadecimal, skipWhile, take)
   35 import qualified Data.ByteString.Char8            as S
   36 import           Data.ByteString.Internal         (ByteString (..), c2w, memchr, w2c)
   37 #if MIN_VERSION_bytestring(0, 10, 6)
   38 import           Data.ByteString.Internal         (accursedUnutterablePerformIO)
   39 #else
   40 import           Data.ByteString.Internal         (inlinePerformIO)
   41 #endif
   42 import qualified Data.ByteString.Unsafe           as S
   43 #if !MIN_VERSION_io_streams(1,2,0)
   44 import           Data.IORef                       (newIORef, readIORef, writeIORef)
   45 #endif
   46 import           Data.List                        (sort)
   47 import           Data.Typeable                    (Typeable)
   48 import qualified Data.Vector                      as V
   49 import qualified Data.Vector.Mutable              as MV
   50 import           Foreign.ForeignPtr               (withForeignPtr)
   51 import           Foreign.Ptr                      (minusPtr, nullPtr, plusPtr)
   52 import           Prelude                          hiding (take)
   53 ------------------------------------------------------------------------------
   54 import           Blaze.ByteString.Builder.HTTP    (chunkedTransferEncoding, chunkedTransferTerminator)
   55 import           Data.ByteString.Builder          (Builder)
   56 import           System.IO.Streams                (InputStream, OutputStream)
   57 import qualified System.IO.Streams                as Streams
   58 import           System.IO.Streams.Attoparsec     (parseFromStream)
   59 ------------------------------------------------------------------------------
   60 import           Snap.Internal.Http.Types         (Method (..))
   61 import           Snap.Internal.Parsing            (crlf, parseCookie, parseUrlEncoded, unsafeFromNat, (<?>))
   62 import           Snap.Types.Headers               (Headers)
   63 import qualified Snap.Types.Headers               as H
   64 
   65 
   66 ------------------------------------------------------------------------------
   67 newtype StandardHeaders = StandardHeaders (V.Vector (Maybe ByteString))
   68 type MStandardHeaders = MV.IOVector (Maybe ByteString)
   69 
   70 
   71 ------------------------------------------------------------------------------
   72 contentLengthTag, hostTag, transferEncodingTag, cookieTag, contentTypeTag,
   73   connectionTag, nStandardHeaders :: Int
   74 contentLengthTag    = 0
   75 hostTag             = 1
   76 transferEncodingTag = 2
   77 cookieTag           = 3
   78 contentTypeTag      = 4
   79 connectionTag       = 5
   80 nStandardHeaders    = 6
   81 
   82 
   83 ------------------------------------------------------------------------------
   84 findStdHeaderIndex :: ByteString -> Int
   85 findStdHeaderIndex "content-length"    = contentLengthTag
   86 findStdHeaderIndex "host"              = hostTag
   87 findStdHeaderIndex "transfer-encoding" = transferEncodingTag
   88 findStdHeaderIndex "cookie"            = cookieTag
   89 findStdHeaderIndex "content-type"      = contentTypeTag
   90 findStdHeaderIndex "connection"        = connectionTag
   91 findStdHeaderIndex _                   = -1
   92 
   93 
   94 ------------------------------------------------------------------------------
   95 getStdContentLength, getStdHost, getStdTransferEncoding, getStdCookie,
   96     getStdConnection, getStdContentType :: StandardHeaders -> Maybe ByteString
   97 getStdContentLength    (StandardHeaders v) = V.unsafeIndex v contentLengthTag
   98 getStdHost             (StandardHeaders v) = V.unsafeIndex v hostTag
   99 getStdTransferEncoding (StandardHeaders v) = V.unsafeIndex v transferEncodingTag
  100 getStdCookie           (StandardHeaders v) = V.unsafeIndex v cookieTag
  101 getStdContentType      (StandardHeaders v) = V.unsafeIndex v contentTypeTag
  102 getStdConnection       (StandardHeaders v) = V.unsafeIndex v connectionTag
  103 
  104 
  105 ------------------------------------------------------------------------------
  106 newMStandardHeaders :: IO MStandardHeaders
  107 newMStandardHeaders = MV.replicate nStandardHeaders Nothing
  108 
  109 
  110 ------------------------------------------------------------------------------
  111 -- | an internal version of the headers part of an HTTP request
  112 data IRequest = IRequest
  113     { iMethod         :: !Method
  114     , iRequestUri     :: !ByteString
  115     , iHttpVersion    :: (Int, Int)
  116     , iRequestHeaders :: Headers
  117     , iStdHeaders     :: StandardHeaders
  118     }
  119 
  120 ------------------------------------------------------------------------------
  121 instance Eq IRequest where
  122     a == b =
  123         and [ iMethod a      == iMethod b
  124             , iRequestUri a  == iRequestUri b
  125             , iHttpVersion a == iHttpVersion b
  126             , sort (H.toList (iRequestHeaders a))
  127                   == sort (H.toList (iRequestHeaders b))
  128             ]
  129 
  130 ------------------------------------------------------------------------------
  131 instance Show IRequest where
  132     show (IRequest m u (major, minor) hdrs _) =
  133         concat [ show m
  134                , " "
  135                , show u
  136                , " "
  137                , show major
  138                , "."
  139                , show minor
  140                , " "
  141                , show hdrs
  142                ]
  143 
  144 
  145 ------------------------------------------------------------------------------
  146 data HttpParseException = HttpParseException String deriving (Typeable, Show)
  147 instance Exception HttpParseException
  148 
  149 
  150 ------------------------------------------------------------------------------
  151 {-# INLINE parseRequest #-}
  152 parseRequest :: InputStream ByteString -> IO IRequest
  153 parseRequest input = do
  154     line <- pLine input
  155     let (!mStr, !s)     = bSp line
  156     let (!uri, !vStr)   = bSp s
  157     let method          = methodFromString mStr
  158     let !version        = pVer vStr
  159     let (host, uri')    = getHost uri
  160     let uri''           = if S.null uri' then "/" else uri'
  161 
  162     stdHdrs <- newMStandardHeaders
  163     MV.unsafeWrite stdHdrs hostTag host
  164     hdrs    <- pHeaders stdHdrs input
  165     outStd  <- StandardHeaders <$> V.unsafeFreeze stdHdrs
  166     return $! IRequest method uri'' version hdrs outStd
  167 
  168   where
  169     getHost s | "http://" `S.isPrefixOf` s
  170                   = let s'            = S.unsafeDrop 7 s
  171                         (!host, !uri) = breakCh '/' s'
  172                     in (Just $! host, uri)
  173               | "https://" `S.isPrefixOf` s
  174                   = let s'            = S.unsafeDrop 8 s
  175                         (!host, !uri) = breakCh '/' s'
  176                     in (Just $! host, uri)
  177               | otherwise = (Nothing, s)
  178 
  179     pVer s = if "HTTP/" `S.isPrefixOf` s
  180                then pVers (S.unsafeDrop 5 s)
  181                else (1, 0)
  182 
  183     bSp   = splitCh ' '
  184 
  185     pVers s = (c, d)
  186       where
  187         (!a, !b)   = splitCh '.' s
  188         !c         = unsafeFromNat a
  189         !d         = unsafeFromNat b
  190 
  191 
  192 ------------------------------------------------------------------------------
  193 pLine :: InputStream ByteString -> IO ByteString
  194 pLine input = go []
  195   where
  196     throwNoCRLF =
  197         throwIO $
  198         HttpParseException "parse error: expected line ending in crlf"
  199 
  200     throwBadCRLF =
  201         throwIO $
  202         HttpParseException "parse error: got cr without subsequent lf"
  203 
  204     go !l = do
  205         !mb <- Streams.read input
  206         !s  <- maybe throwNoCRLF return mb
  207 
  208         let !i = elemIndex '\r' s
  209         if i < 0
  210           then noCRLF l s
  211           else case () of
  212                  !_ | i+1 >= S.length s           -> lastIsCR l s i
  213                     | S.unsafeIndex s (i+1) == 10 -> foundCRLF l s i
  214                     | otherwise                   -> throwBadCRLF
  215 
  216     foundCRLF l s !i1 = do
  217         let !i2 = i1 + 2
  218         let !a = S.unsafeTake i1 s
  219         when (i2 < S.length s) $ do
  220             let !b = S.unsafeDrop i2 s
  221             Streams.unRead b input
  222 
  223         -- Optimize for the common case: dl is almost always "id"
  224         let !out = if null l then a else S.concat (reverse (a:l))
  225         return out
  226 
  227     noCRLF l s = go (s:l)
  228 
  229     lastIsCR l s !idx = do
  230         !t <- Streams.read input >>= maybe throwNoCRLF return
  231         if S.null t
  232           then lastIsCR l s idx
  233           else do
  234             let !c = S.unsafeHead t
  235             if c /= 10
  236               then throwBadCRLF
  237               else do
  238                   let !a = S.unsafeTake idx s
  239                   let !b = S.unsafeDrop 1 t
  240                   when (not $ S.null b) $ Streams.unRead b input
  241                   let !out = if null l then a else S.concat (reverse (a:l))
  242                   return out
  243 
  244 
  245 ------------------------------------------------------------------------------
  246 splitCh :: Char -> ByteString -> (ByteString, ByteString)
  247 splitCh !c !s = if idx < 0
  248                   then (s, S.empty)
  249                   else let !a = S.unsafeTake idx s
  250                            !b = S.unsafeDrop (idx + 1) s
  251                        in (a, b)
  252   where
  253     !idx = elemIndex c s
  254 {-# INLINE splitCh #-}
  255 
  256 
  257 ------------------------------------------------------------------------------
  258 breakCh :: Char -> ByteString -> (ByteString, ByteString)
  259 breakCh !c !s = if idx < 0
  260                   then (s, S.empty)
  261                   else let !a = S.unsafeTake idx s
  262                            !b = S.unsafeDrop idx s
  263                        in (a, b)
  264   where
  265     !idx = elemIndex c s
  266 {-# INLINE breakCh #-}
  267 
  268 
  269 ------------------------------------------------------------------------------
  270 splitHeader :: ByteString -> (ByteString, ByteString)
  271 splitHeader !s = if idx < 0
  272                    then (s, S.empty)
  273                    else let !a = S.unsafeTake idx s
  274                         in (a, skipSp (idx + 1))
  275   where
  276     !idx = elemIndex ':' s
  277     l    = S.length s
  278 
  279     skipSp !i | i >= l    = S.empty
  280               | otherwise = let c = S.unsafeIndex s i
  281                             in if isLWS $ w2c c
  282                                  then skipSp $ i + 1
  283                                  else S.unsafeDrop i s
  284 
  285 {-# INLINE splitHeader #-}
  286 
  287 
  288 
  289 ------------------------------------------------------------------------------
  290 isLWS :: Char -> Bool
  291 isLWS c = c == ' ' || c == '\t'
  292 {-# INLINE isLWS #-}
  293 
  294 
  295 ------------------------------------------------------------------------------
  296 pHeaders :: MStandardHeaders -> InputStream ByteString -> IO Headers
  297 pHeaders stdHdrs input = do
  298     hdrs    <- H.unsafeFromCaseFoldedList <$> go []
  299     return hdrs
  300 
  301   where
  302     go !list = do
  303         line <- pLine input
  304         if S.null line
  305           then return list
  306           else do
  307             let (!k0,!v) = splitHeader line
  308             let !k = toLower k0
  309             vf <- pCont id
  310             let vs = vf []
  311             let !v' = S.concat (v:vs)
  312             let idx = findStdHeaderIndex k
  313             when (idx >= 0) $ MV.unsafeWrite stdHdrs idx $! Just v'
  314 
  315             let l' = ((k, v'):list)
  316             go l'
  317 
  318     trimBegin = S.dropWhile isLWS
  319 
  320     pCont !dlist = do
  321         mbS  <- Streams.peek input
  322         maybe (return dlist)
  323               (\s -> if not (S.null s)
  324                        then if not $ isLWS $ w2c $ S.unsafeHead s
  325                               then return dlist
  326                               else procCont dlist
  327                        else Streams.read input >> pCont dlist)
  328               mbS
  329 
  330     procCont !dlist = do
  331         line <- pLine input
  332         let !t = trimBegin line
  333         pCont (dlist . (" ":) . (t:))
  334 
  335 
  336 ------------------------------------------------------------------------------
  337 methodFromString :: ByteString -> Method
  338 methodFromString "GET"     = GET
  339 methodFromString "POST"    = POST
  340 methodFromString "HEAD"    = HEAD
  341 methodFromString "PUT"     = PUT
  342 methodFromString "DELETE"  = DELETE
  343 methodFromString "TRACE"   = TRACE
  344 methodFromString "OPTIONS" = OPTIONS
  345 methodFromString "CONNECT" = CONNECT
  346 methodFromString "PATCH"   = PATCH
  347 methodFromString s         = Method s
  348 
  349 
  350 ------------------------------------------------------------------------------
  351 readChunkedTransferEncoding :: InputStream ByteString
  352                             -> IO (InputStream ByteString)
  353 readChunkedTransferEncoding input =
  354     Streams.makeInputStream $ parseFromStream pGetTransferChunk input
  355 
  356 
  357 ------------------------------------------------------------------------------
  358 writeChunkedTransferEncoding :: OutputStream Builder
  359                              -> IO (OutputStream Builder)
  360 #if MIN_VERSION_io_streams(1,2,0)
  361 writeChunkedTransferEncoding os = Streams.makeOutputStream f
  362   where
  363     f Nothing = do
  364         Streams.write (Just chunkedTransferTerminator) os
  365         Streams.write Nothing os
  366     f x = Streams.write (chunkedTransferEncoding `fmap` x) os
  367 
  368 #else
  369 writeChunkedTransferEncoding os = do
  370     -- make sure we only send the terminator once.
  371     eof <- newIORef True
  372     Streams.makeOutputStream $ f eof
  373   where
  374     f eof Nothing = readIORef eof >>= flip when (do
  375         writeIORef eof True
  376         Streams.write (Just chunkedTransferTerminator) os
  377         Streams.write Nothing os)
  378     f _ x = Streams.write (chunkedTransferEncoding `fmap` x) os
  379 #endif
  380 
  381 
  382                              ---------------------
  383                              -- parse functions --
  384                              ---------------------
  385 
  386 ------------------------------------------------------------------------------
  387 -- We treat chunks larger than this from clients as a denial-of-service attack.
  388 -- 256kB should be enough buffer.
  389 mAX_CHUNK_SIZE :: Int
  390 mAX_CHUNK_SIZE = (2::Int)^(18::Int)
  391 
  392 
  393 ------------------------------------------------------------------------------
  394 pGetTransferChunk :: Parser (Maybe ByteString)
  395 pGetTransferChunk = parser <?> "pGetTransferChunk"
  396   where
  397     parser = do
  398         !hex <- hexadecimal <?> "hexadecimal"
  399         skipWhile (/= '\r') <?> "skipToEOL"
  400         void crlf <?> "linefeed"
  401         if hex >= mAX_CHUNK_SIZE
  402           then return $! E.throw $! HttpParseException $!
  403                "pGetTransferChunk: chunk of size " ++ show hex ++ " too long."
  404           else if hex <= 0
  405             then (crlf >> return Nothing) <?> "terminal crlf after 0 length"
  406             else do
  407                 -- now safe to take this many bytes.
  408                 !x <- take hex <?> "reading data chunk"
  409                 void crlf <?> "linefeed after data chunk"
  410                 return $! Just x
  411 
  412 
  413 ------------------------------------------------------------------------------
  414 toLower :: ByteString -> ByteString
  415 toLower = S.map lower
  416   where
  417     lower c0 = let !c = c2w c0
  418                in if 65 <= c && c <= 90
  419                     then w2c $! c + 32
  420                     else c0
  421 
  422 
  423 ------------------------------------------------------------------------------
  424 -- | A version of elemIndex that doesn't allocate a Maybe. (It returns -1 on
  425 -- not found.)
  426 elemIndex :: Char -> ByteString -> Int
  427 #if MIN_VERSION_bytestring(0, 10, 6)
  428 elemIndex c (PS !fp !start !len) = accursedUnutterablePerformIO $
  429 #else
  430 elemIndex c (PS !fp !start !len) = inlinePerformIO $
  431 #endif
  432                                    withForeignPtr fp $ \p0 -> do
  433     let !p = plusPtr p0 start
  434     q <- memchr p w8 (fromIntegral len)
  435     return $! if q == nullPtr then (-1) else q `minusPtr` p
  436   where
  437     !w8 = c2w c
  438 {-# INLINE elemIndex #-}