{-# LANGUAGE BangPatterns       #-}
{-# LANGUAGE CPP                #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MagicHash          #-}
{-# LANGUAGE OverloadedStrings  #-}
{-# LANGUAGE Rank2Types         #-}
{-# LANGUAGE Trustworthy        #-}
{-# LANGUAGE UnboxedTuples      #-}

-- | Internal HTTP request parsing for the server: the request line and
-- header block ('parseRequest'), plus reading and writing the chunked
-- transfer coding.
module Snap.Internal.Http.Server.Parser
  ( IRequest(..)
  , HttpParseException(..)
  , readChunkedTransferEncoding
  , writeChunkedTransferEncoding
  , parseRequest
  , parseFromStream
  , parseCookie
  , parseUrlEncoded
  , getStdContentLength
  , getStdHost
  , getStdTransferEncoding
  , getStdCookie
  , getStdContentType
  , getStdConnection
  ) where

------------------------------------------------------------------------------
#if !MIN_VERSION_base(4,8,0)
import           Control.Applicative              ((<$>))
#endif
import           Control.Exception                (Exception, throwIO)
import qualified Control.Exception                as E
import           Control.Monad                    (void, when)
import           Data.Attoparsec.ByteString.Char8 (Parser, hexadecimal, skipWhile, take)
import qualified Data.ByteString.Char8            as S
import           Data.ByteString.Internal         (ByteString (..), c2w, memchr, w2c)
#if MIN_VERSION_bytestring(0, 10, 6)
import           Data.ByteString.Internal         (accursedUnutterablePerformIO)
#else
import           Data.ByteString.Internal         (inlinePerformIO)
#endif
import qualified Data.ByteString.Unsafe           as S
#if !MIN_VERSION_io_streams(1,2,0)
import           Data.IORef                       (newIORef, readIORef, writeIORef)
#endif
import           Data.List                        (sort)
import           Data.Typeable                    (Typeable)
import qualified Data.Vector                      as V
import qualified Data.Vector.Mutable              as MV
import           Foreign.ForeignPtr               (withForeignPtr)
import           Foreign.Ptr                      (minusPtr, nullPtr, plusPtr)
import           Prelude                          hiding (take)
------------------------------------------------------------------------------
import           Blaze.ByteString.Builder.HTTP    (chunkedTransferEncoding, chunkedTransferTerminator)
import           Data.ByteString.Builder          (Builder)
import           System.IO.Streams                (InputStream, OutputStream)
import qualified System.IO.Streams                as Streams
import           System.IO.Streams.Attoparsec     (parseFromStream)
------------------------------------------------------------------------------
import           Snap.Internal.Http.Types         (Method (..))
import           Snap.Internal.Parsing            (crlf, parseCookie, parseUrlEncoded, unsafeFromNat, (<?>))
import           Snap.Types.Headers               (Headers)
import qualified Snap.Types.Headers               as H


------------------------------------------------------------------------------
-- | Immutable table holding the values of a fixed set of frequently-used
-- headers, one slot per tag constant below. A slot is 'Nothing' when no value
-- was recorded for it.
newtype StandardHeaders = StandardHeaders (V.Vector (Maybe ByteString))

-- | Mutable counterpart of 'StandardHeaders', written while headers are
-- being parsed and then frozen.
type MStandardHeaders = MV.IOVector (Maybe ByteString)


------------------------------------------------------------------------------
-- | Slot indices into a 'StandardHeaders' table.
contentLengthTag, hostTag, transferEncodingTag, cookieTag, contentTypeTag,
  connectionTag :: Int
contentLengthTag    = 0
hostTag             = 1
transferEncodingTag = 2
cookieTag           = 3
contentTypeTag      = 4
connectionTag       = 5

-- | Number of slots in a 'StandardHeaders' table; must stay one more than the
-- largest tag above.
nStandardHeaders :: Int
nStandardHeaders = 6


------------------------------------------------------------------------------
-- | Map a header name to its 'StandardHeaders' slot, or @-1@ when the name is
-- not one of the standard headers. Matching is exact, so the name must
-- already be lower-cased.
findStdHeaderIndex :: ByteString -> Int
findStdHeaderIndex name =
    case name of
      "content-length"    -> contentLengthTag
      "host"              -> hostTag
      "transfer-encoding" -> transferEncodingTag
      "cookie"            -> cookieTag
      "content-type"      -> contentTypeTag
      "connection"        -> connectionTag
      _                   -> -1


------------------------------------------------------------------------------
-- | Read one slot of a 'StandardHeaders' table. Each accessor yields
-- 'Nothing' when no value was recorded for that header.
getStdContentLength, getStdHost, getStdTransferEncoding, getStdCookie,
    getStdConnection, getStdContentType :: StandardHeaders -> Maybe ByteString
getStdContentLength    (StandardHeaders hs) = hs `V.unsafeIndex` contentLengthTag
getStdHost             (StandardHeaders hs) = hs `V.unsafeIndex` hostTag
getStdTransferEncoding (StandardHeaders hs) = hs `V.unsafeIndex` transferEncodingTag
getStdCookie           (StandardHeaders hs) = hs `V.unsafeIndex` cookieTag
getStdContentType (StandardHeaders v) = V.unsafeIndex v contentTypeTag
getStdConnection  (StandardHeaders v) = V.unsafeIndex v connectionTag


------------------------------------------------------------------------------
-- | Allocate a mutable standard-header table with every slot set to
-- 'Nothing'.
newMStandardHeaders :: IO MStandardHeaders
newMStandardHeaders = MV.replicate nStandardHeaders Nothing


------------------------------------------------------------------------------
-- | An internal version of the headers part of an HTTP request, as produced
-- by 'parseRequest'.
data IRequest = IRequest
    { iMethod         :: !Method
      -- ^ request method
    , iRequestUri     :: !ByteString
      -- ^ request target with any leading @http://authority@ or
      -- @https://authority@ removed; @\"/\"@ if nothing is left
    , iHttpVersion    :: (Int, Int)
      -- ^ @(major, minor)@; @(1, 0)@ when the request line has no @HTTP/@
      -- version token
    , iRequestHeaders :: Headers
      -- ^ every header line, with ASCII-lower-cased names
    , iStdHeaders     :: StandardHeaders
      -- ^ selected header values (and a host taken from an absolute-form
      -- target) for fast lookup
    }

------------------------------------------------------------------------------
-- | Requests are equal when method, URI, version and the header lists (as
-- sorted lists, so order-insensitive) agree. 'iStdHeaders' is not compared.
instance Eq IRequest where
    a == b =
        and [ iMethod a      == iMethod b
            , iRequestUri a  == iRequestUri b
            , iHttpVersion a == iHttpVersion b
            , sort (H.toList (iRequestHeaders a))
                  == sort (H.toList (iRequestHeaders b))
            ]

------------------------------------------------------------------------------
-- | Debug rendering: @method uri major.minor headers@.
instance Show IRequest where
    show (IRequest m u (major, minor) hdrs _) =
        concat [ show m
               , " "
               , show u
               , " "
               , show major
               , "."
               , show minor
               , " "
               , show hdrs
               ]


------------------------------------------------------------------------------
-- | Thrown by this module's parsers on the protocol violations they detect,
-- e.g. bad line endings, end of input in the middle of a line, or an
-- oversized chunk.
data HttpParseException = HttpParseException String deriving (Typeable, Show)
instance Exception HttpParseException


------------------------------------------------------------------------------
-- | Parse an HTTP request line and the header block that follows it.
--
-- The request line is split on single spaces into method, target and
-- version. For an absolute-form target starting with @http://@ or
-- @https://@, the authority is stored in the Host slot of the standard
-- headers and removed from the URI.
--
-- Throws 'HttpParseException' (via 'pLine') on malformed line endings or
-- premature end of input.
parseRequest :: InputStream ByteString -> IO IRequest
{-# INLINE parseRequest #-}
parseRequest input = do
    line <- pLine input
    let (!mStr, !s)   = bSp line
    let (!uri, !vStr) = bSp s
    let method        = methodFromString mStr
    let !version      = pVer vStr
    let (host, uri')  = getHost uri
    let uri''         = if S.null uri' then "/" else uri'

    stdHdrs <- newMStandardHeaders
    -- NOTE(review): pHeaders runs after this write, so a Host header line
    -- overwrites the authority taken from an absolute-form target. RFC 9112
    -- section 3.2.2 asks origin servers to prefer the target's authority;
    -- confirm whether this precedence is intended.
    MV.unsafeWrite stdHdrs hostTag host
    hdrs <- pHeaders stdHdrs input
    outStd <- StandardHeaders <$> V.unsafeFreeze stdHdrs
    return $! IRequest method uri'' version hdrs outStd

  where
    getHost s
      | "http://"  `S.isPrefixOf` s = splitAuthority 7 s
      | "https://" `S.isPrefixOf` s = splitAuthority 8 s
      | otherwise                   = (Nothing, s)

    -- Drop a scheme prefix of n bytes, then split the authority from the
    -- path. The path keeps its leading '/'.
    splitAuthority n s =
        let (!host, !uri) = breakCh '/' (S.unsafeDrop n s)
        in (Just $! host, uri)

    pVer s = if "HTTP/" `S.isPrefixOf` s
               then pVers (S.unsafeDrop 5 s)
               else (1, 0)

    bSp = splitCh ' '

    pVers s = (c, d)
      where
        (!a, !b) = splitCh '.' s
        !c       = unsafeFromNat a
        !d       = unsafeFromNat b


------------------------------------------------------------------------------
-- | Read one CRLF-terminated line from the stream and return it without the
-- CRLF. Any bytes after the CRLF in the same chunk are pushed back with
-- 'Streams.unRead'.
--
-- Throws 'HttpParseException' if the stream ends before a CRLF is seen, or
-- if a CR is followed by anything other than LF. This function imposes no
-- limit on line length.
pLine :: InputStream ByteString -> IO ByteString
pLine input = go []
  where
    throwNoCRLF =
        throwIO $
        HttpParseException "parse error: expected line ending in crlf"

    throwBadCRLF =
        throwIO $
        HttpParseException "parse error: got cr without subsequent lf"

    -- Join the final piece of the line with the earlier chunks, which are
    -- held newest-first in l. Optimize for the common case: the whole line
    -- arrived in one chunk, l is empty, and no copy is needed.
    assemble l a = if null l then a else S.concat (reverse (a:l))

    -- l: earlier chunks of the current line (newest first), none of which
    -- contains a CR.
    go !l = do
        !mb <- Streams.read input
        !s  <- maybe throwNoCRLF return mb

        let !i = elemIndex '\r' s
        if i < 0
          then noCRLF l s
          else case () of
            !_ | i+1 >= S.length s           -> lastIsCR l s i
               | S.unsafeIndex s (i+1) == 10 -> foundCRLF l s i
               | otherwise                   -> throwBadCRLF

    -- CRLF found at index i1 of s: keep what precedes it, push back what
    -- follows it.
    foundCRLF l s !i1 = do
        let !i2 = i1 + 2
        let !a = S.unsafeTake i1 s
        when (i2 < S.length s) $ do
            let !b = S.unsafeDrop i2 s
            Streams.unRead b input
        return $! assemble l a

    noCRLF l s = go (s:l)

    -- The CR at index idx is the last byte of s, so the LF must be the first
    -- byte of the next non-empty chunk.
    lastIsCR l s !idx = do
        !t <- Streams.read input >>= maybe throwNoCRLF return
        if S.null t
          then lastIsCR l s idx
          else do
            let !c = S.unsafeHead t
            if c /= 10
              then throwBadCRLF
              else do
                let !a = S.unsafeTake idx s
                let !b = S.unsafeDrop 1 t
                when (not $ S.null b) $ Streams.unRead b input
                return $! assemble l a


------------------------------------------------------------------------------
-- | Split at the first occurrence of the character, dropping it. If the
-- character is absent, returns the whole input and an empty remainder.
splitCh :: Char -> ByteString -> (ByteString, ByteString)
splitCh !c !s = if idx < 0
                  then (s, S.empty)
                  else let !a = S.unsafeTake idx s
                           !b = S.unsafeDrop (idx + 1) s
                       in (a, b)
  where
    !idx = elemIndex c s
{-# INLINE splitCh #-}


------------------------------------------------------------------------------
-- | Like 'splitCh', but the character stays at the front of the remainder.
breakCh :: Char -> ByteString -> (ByteString, ByteString)
breakCh !c !s = if idx < 0
                  then (s, S.empty)
                  else let !a = S.unsafeTake idx s
                           !b = S.unsafeDrop idx s
                       in (a, b)
  where
    !idx = elemIndex c s
{-# INLINE breakCh #-}


------------------------------------------------------------------------------
-- | Split a header line at its first ':' into (name, value). Leading spaces
-- and tabs are removed from the value; the name is not trimmed. A line
-- without ':' yields @(line, empty)@.
splitHeader :: ByteString -> (ByteString, ByteString)
splitHeader !s = if idx < 0
                   then (s, S.empty)
                   else let !a = S.unsafeTake idx s
                        in (a, skipSp (idx + 1))
  where
    !idx = elemIndex ':' s
    l = S.length s

    skipSp !i | i >= l    = S.empty
              | otherwise = let c = S.unsafeIndex s i
                            in if isLWS $ w2c c
                                 then skipSp $ i + 1
                                 else S.unsafeDrop i s

{-# INLINE splitHeader #-}


------------------------------------------------------------------------------
-- | Linear whitespace: a space or a horizontal tab.
isLWS :: Char -> Bool
isLWS c = c == ' ' || c == '\t'
{-# INLINE isLWS #-}


------------------------------------------------------------------------------
-- | Read header lines up to and including the empty line that ends the
-- header block.
--
-- * Header names are ASCII-lower-cased.
-- * A following line that starts with a space or tab is a continuation. It
--   is appended to the previous value, trimmed of leading whitespace and
--   joined by a single space.
-- * Values of the standard headers are also written into @stdHdrs@; a later
--   occurrence of the same header overwrites an earlier one.
pHeaders :: MStandardHeaders -> InputStream ByteString -> IO Headers
pHeaders stdHdrs input = H.unsafeFromCaseFoldedList <$> go []
  where
    -- list accumulates (name, value) pairs, newest first.
    go !list = do
        line <- pLine input
        if S.null line
          then return list
          else do
            let (!k0,!v) = splitHeader line
            let !k = toLower k0
            vf <- pCont id
            let vs = vf []
            let !v' = S.concat (v:vs)
            let idx = findStdHeaderIndex k
            when (idx >= 0) $ MV.unsafeWrite stdHdrs idx $! Just v'

            let l' = ((k, v'):list)
            go l'

    trimBegin = S.dropWhile isLWS

    -- Collect continuation lines as a difference list. Empty chunks seen
    -- while peeking are discarded.
    pCont !dlist = do
        mbS <- Streams.peek input
        maybe (return dlist)
              (\s -> if not (S.null s)
                       then if not $ isLWS $ w2c $ S.unsafeHead s
                              then return dlist
                              else procCont dlist
                       else Streams.read input >> pCont dlist)
              mbS

    procCont !dlist = do
        line <- pLine input
        let !t = trimBegin line
        pCont (dlist . (" ":) . (t:))


------------------------------------------------------------------------------
-- | Map the standard method names (matched case-sensitively) to their
-- constructors. Any other name is kept verbatim in 'Method'.
methodFromString :: ByteString -> Method
methodFromString "GET"     = GET
methodFromString "POST"    = POST
methodFromString "HEAD"    = HEAD
methodFromString "PUT"     = PUT
methodFromString "DELETE"  = DELETE
methodFromString "TRACE"   = TRACE
methodFromString "OPTIONS" = OPTIONS
methodFromString "CONNECT" = CONNECT
methodFromString "PATCH"   = PATCH
methodFromString s         = Method s


------------------------------------------------------------------------------
-- | Decode a chunked request body. Each read yields the data of the next
-- chunk, and the zero-length terminating chunk produces end-of-stream.
-- Parse failures are raised by 'parseFromStream'.
readChunkedTransferEncoding :: InputStream ByteString
                            -> IO (InputStream ByteString)
readChunkedTransferEncoding input =
    Streams.makeInputStream $ parseFromStream pGetTransferChunk input


------------------------------------------------------------------------------
-- | Wrap an output stream so that every 'Builder' written is framed as an
-- HTTP chunk. On end-of-stream, the chunked terminator is written and
-- end-of-stream is forwarded to the underlying stream.
writeChunkedTransferEncoding :: OutputStream Builder
                             -> IO (OutputStream Builder)
#if MIN_VERSION_io_streams(1,2,0)
-- NOTE(review): this branch has no explicit "terminate once" guard; the CPP
-- split suggests io-streams >= 1.2 'makeOutputStream' already ignores writes
-- after end-of-stream -- confirm.
writeChunkedTransferEncoding os = Streams.makeOutputStream f
  where
    f Nothing = do
        Streams.write (Just chunkedTransferTerminator) os
        Streams.write Nothing os
    f x = Streams.write (chunkedTransferEncoding `fmap` x) os

#else
writeChunkedTransferEncoding os = do
    -- Make sure we only send the terminator once. 'pending' stays True until
    -- the first end-of-stream has been handled, and is then cleared. (It
    -- used to be reset to True, so the terminator was re-sent on every
    -- Nothing.)
    pending <- newIORef True
    Streams.makeOutputStream $ f pending
  where
    f pending Nothing = readIORef pending >>= flip when (do
        writeIORef pending False
        Streams.write (Just chunkedTransferTerminator) os
        Streams.write Nothing os)
    f _ x = Streams.write (chunkedTransferEncoding `fmap` x) os
#endif


---------------------
-- parse functions --
---------------------

------------------------------------------------------------------------------
-- We treat chunks of this size or larger from clients as a denial-of-service
-- attack. 256kB should be enough buffer.
mAX_CHUNK_SIZE :: Int
mAX_CHUNK_SIZE = (2::Int)^(18::Int)


------------------------------------------------------------------------------
-- | Parse one chunk of a chunked request body.
--
-- Returns @Just bytes@ for a data chunk and 'Nothing' for the zero-size
-- terminating chunk. The terminating chunk must be followed directly by
-- CRLF, so trailer fields are not accepted. Chunk extensions after the size
-- are skipped.
--
-- A declared size of 'mAX_CHUNK_SIZE' or more throws 'HttpParseException'.
--
-- The size is read by 'pChunkSizeLine', which saturates instead of
-- overflowing. Attoparsec's @hexadecimal@ accumulates into 'Int' by shifting
-- with no overflow check. On a 64-bit 'Int' that made
-- @ffffffffffffffff@ parse as -1 and @10000000000000000@ as 0, both of which
-- were taken as the end of the body, and longer sizes wrapped to arbitrary
-- small values.
pGetTransferChunk :: Parser (Maybe ByteString)
pGetTransferChunk = parser <?> "pGetTransferChunk"
  where
    parser = do
        !hex <- pChunkSizeLine
        if hex >= mAX_CHUNK_SIZE
          then return $! E.throw $! HttpParseException $!
               "pGetTransferChunk: chunk of size " ++ show hex ++ " too long."
          else if hex <= 0
            then (crlf >> return Nothing) <?> "terminal crlf after 0 length"
            else do
                -- now safe to take this many bytes.
                !x <- take hex <?> "reading data chunk"
                void crlf <?> "linefeed after data chunk"
                return $! Just x


------------------------------------------------------------------------------
-- | Parse a chunk-size line and return the size. The line is one or more hex
-- digits, then anything up to the next CR (chunk extensions, whitespace),
-- then CRLF.
--
-- The value saturates at 'maxBound' rather than wrapping around. A size too
-- large to represent is therefore still rejected, and its error message
-- reports 'maxBound'.
pChunkSizeLine :: Parser Int
pChunkSizeLine = do
    c0 <- anyChar8 <?> "hexadecimal"
    case hexDigitValue c0 of
      Nothing -> fail "hexadecimal"
      Just d0 -> digits d0
  where
    -- take always returns exactly the requested number of bytes, so the
    -- head of a 1-byte result is safe.
    anyChar8 = S.head <$> take 1

    digits !acc = do
        c <- anyChar8 <?> "hexadecimal"
        case hexDigitValue c of
          Just d  -> digits (accumulate acc d)
          Nothing -> endOfSizeLine c >> return acc

    -- The byte that ended the digit run has already been consumed: if it was
    -- the CR, only the LF remains; otherwise skip to the CR, then take CRLF.
    endOfSizeLine '\r' = expectLF
    endOfSizeLine _    = do
        skipWhile (/= '\r') <?> "skipToEOL"
        void crlf <?> "linefeed"

    expectLF = do
        c <- anyChar8 <?> "linefeed"
        when (c /= '\n') $ fail "linefeed"

    -- acc * 16 + d cannot exceed maxBound while acc <= (maxBound - 15) / 16.
    -- Past that point, pin to maxBound instead of wrapping.
    accumulate !acc !d
      | acc > (maxBound - 15) `quot` 16 = maxBound
      | otherwise                        = acc * 16 + d


------------------------------------------------------------------------------
-- | Value of a single hexadecimal digit (either case), or 'Nothing' for any
-- other character.
hexDigitValue :: Char -> Maybe Int
hexDigitValue c
  | c >= '0' && c <= '9' = Just $! n - 48
  | c >= 'a' && c <= 'f' = Just $! n - 87
  | c >= 'A' && c <= 'F' = Just $! n - 55
  | otherwise            = Nothing
  where
    n = fromEnum c
{-# INLINE hexDigitValue #-}


------------------------------------------------------------------------------
-- | ASCII-only lower-casing: bytes @A@..@Z@ (65..90) are mapped to
-- @a@..@z@; every other byte is left unchanged.
toLower :: ByteString -> ByteString
toLower = S.map lower
  where
    lower c0 = let !c = c2w c0
               in if 65 <= c && c <= 90
                    then w2c $! c + 32
                    else c0


------------------------------------------------------------------------------
-- | A version of elemIndex that doesn't allocate a Maybe. (It returns -1 on
-- not found.)
--
-- The lookup only reads the bytes of an immutable 'ByteString' via @memchr@.
-- Its result therefore depends only on the arguments, which is why running
-- it outside 'IO' is acceptable here.
elemIndex :: Char -> ByteString -> Int
#if MIN_VERSION_bytestring(0, 10, 6)
elemIndex c (PS !fp !start !len) = accursedUnutterablePerformIO $
#else
elemIndex c (PS !fp !start !len) = inlinePerformIO $
#endif
    withForeignPtr fp $ \p0 -> do
        let !p = plusPtr p0 start
        q <- memchr p w8 (fromIntegral len)
        return $! if q == nullPtr then (-1) else q `minusPtr` p
  where
    !w8 = c2w c
{-# INLINE elemIndex #-}