446 lines
16 KiB
Haskell
446 lines
16 KiB
Haskell
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DoAndIfThenElse #-}
-- |
-- Module      : Data.ByteString.Base64.Internal
-- Copyright   : (c) 2010 Bryan O'Sullivan
--
-- License     : BSD-style
-- Maintainer  : bos@serpentine.com
-- Stability   : experimental
-- Portability : GHC
--
-- Fast and efficient encoding and decoding of base64-encoded strings.

module Data.ByteString.Base64.Internal
  ( encodeWith
  , decodeWithTable
  , decodeLenientWithTable
  , mkEncodeTable
  , done
  , peek8, poke8, peek8_32
  , reChunkIn
  , Padding(..)
  , withBS
  , mkBS
  ) where

import Data.Bits ((.|.), (.&.), shiftL, shiftR)
import qualified Data.ByteString as B
import Data.ByteString.Internal (ByteString(..), mallocByteString)
import Data.Word (Word8, Word16, Word32)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr, castPtr, minusPtr, plusPtr)
import Foreign.Storable (peek, peekElemOff, poke)
import System.IO.Unsafe (unsafePerformIO)
|
|
|
|
-- | Read the single byte stored at the given address ('peek' specialised
-- to 'Word8').
peek8 :: Ptr Word8 -> IO Word8
peek8 = peek
|
|
|
|
-- | Store a single byte at the given address ('poke' specialised to
-- 'Word8').
poke8 :: Ptr Word8 -> Word8 -> IO ()
poke8 = poke
|
|
|
|
-- | Read the single byte stored at the given address and widen it to a
-- 'Word32', so that callers can shift and combine several bytes without
-- overflowing.
peek8_32 :: Ptr Word8 -> IO Word32
peek8_32 = fmap fromIntegral . peek8
|
|
|
|
|
|
-- | How @=@ padding is treated by 'encodeWith' and 'decodeWithTable'.
data Padding
  = Padded
    -- ^ Encoding emits trailing @=@ characters; decoding only accepts
    -- input whose length is a multiple of 4.
  | Don'tCare
    -- ^ Decoding accepts input with or without padding and fills in the
    -- missing @=@ characters itself. Encoding emits no padding.
  | Unpadded
    -- ^ Encoding emits no padding; decoding rejects input whose last byte
    -- is @=@.
  deriving Eq
|
|
|
|
-- | Encode a string into base64 form using the given 'EncodeTable'.
--
-- With 'Padded' the output is filled up with @=@ characters and is always
-- a multiple of 4 bytes in length. With any other 'Padding' value the
-- trailing @=@ characters are omitted, so the output length may be 2 or 3
-- more than a multiple of 4.
--
-- Calls 'error' if the input is longer than @maxBound \`div\` 4@ bytes.
encodeWith :: Padding -> EncodeTable -> ByteString -> ByteString
encodeWith !padding (ET alfaFP encodeTable) !bs = withBS bs go
  where
    go !sptr !slen
      | slen > maxBound `div` 4 =
          error "Data.ByteString.Base64.encode: input too long"
      | otherwise = do
          -- Room for the padded worst case: 4 output bytes for every
          -- (started) group of 3 input bytes.
          let dlen = (slen + 2) `div` 3 * 4
          dfp <- mallocByteString dlen
          withForeignPtr alfaFP $ \aptr ->
            withForeignPtr encodeTable $ \ep -> do
              let aidx n = peek8 (aptr `plusPtr` n)
                  sEnd = sptr `plusPtr` slen
                  finish !n = return $ mkBS dfp n

                  -- Main loop: while at least 3 input bytes remain, combine
                  -- them into a 24-bit word and emit two 16-bit table
                  -- entries, one for the high and one for the low 12 bits.
                  fill !dp !sp !n
                    | sp `plusPtr` 2 >= sEnd = complete (castPtr dp) sp n
                    | otherwise = {-# SCC "encode/fill" #-} do
                        i <- peek8_32 sp
                        j <- peek8_32 (sp `plusPtr` 1)
                        k <- peek8_32 (sp `plusPtr` 2)
                        let w = i `shiftL` 16 .|. j `shiftL` 8 .|. k
                            enc = peekElemOff ep . fromIntegral
                        poke dp =<< enc (w `shiftR` 12)
                        poke (dp `plusPtr` 2) =<< enc (w .&. 0xfff)
                        fill (dp `plusPtr` 4) (sp `plusPtr` 3) (n + 4)

                  -- Tail: zero, one or two input bytes are left. They are
                  -- encoded one character at a time through the plain
                  -- alphabet, followed by @=@ padding when requested.
                  complete dp sp n
                    | sp == sEnd = finish n
                    | otherwise = {-# SCC "encode/complete" #-} do
                        let peekSP m f = (f . fromIntegral) `fmap` peek8 (sp `plusPtr` m)
                            twoMore = sp `plusPtr` 2 == sEnd
                            equals = 0x3d :: Word8
                            doPad = padding == Padded
                            {-# INLINE equals #-}
                        -- Top 6 bits of the first byte, and its low 2 bits
                        -- moved into the upper part of the next sextet.
                        !a <- peekSP 0 ((`shiftR` 2) . (.&. 0xfc))
                        !b <- peekSP 0 ((`shiftL` 4) . (.&. 0x03))

                        poke8 dp =<< aidx a

                        if twoMore
                          then do
                            !b' <- peekSP 1 ((.|. b) . (`shiftR` 4) . (.&. 0xf0))
                            !c <- aidx =<< peekSP 1 ((`shiftL` 2) . (.&. 0x0f))
                            poke8 (dp `plusPtr` 1) =<< aidx b'
                            poke8 (dp `plusPtr` 2) c

                            if doPad
                              then poke8 (dp `plusPtr` 3) equals >> finish (n + 4)
                              else finish (n + 3)
                          else do
                            poke8 (dp `plusPtr` 1) =<< aidx b

                            if doPad
                              then do
                                poke8 (dp `plusPtr` 2) equals
                                poke8 (dp `plusPtr` 3) equals
                                finish (n + 4)
                              else finish (n + 2)

              withForeignPtr dfp (\dptr -> fill (castPtr dptr) sptr 0)
|
|
|
|
-- | The pair of lookup tables consumed by 'encodeWith'.
data EncodeTable
  = ET
      !(ForeignPtr Word8)
      -- ^ the alphabet itself, indexed by a single 6-bit value
      !(ForeignPtr Word16)
      -- ^ two-character entries, indexed by a 12-bit value
|
|
|
|
-- | Build the lookup tables used by 'encodeWith' from an alphabet.
--
-- The encoding table is constructed such that the expansion of a 12-bit
-- block to a 16-bit block can be done by a single 'Word16' copy from the
-- corresponding table entry to the target address. Entry @j * 64 + k@
-- consists of the bytes @alphabet[j]@, @alphabet[k]@ in that order, i.e.
-- the 16-bit blocks are stored in big-endian order, as the indices into
-- the table are built in big-endian order.
--
-- The alphabet must hold at least 64 bytes; 'B.index' fails otherwise.
mkEncodeTable :: ByteString -> EncodeTable
#if MIN_VERSION_bytestring(0,11,0)
mkEncodeTable alphabet@(BS afp _) =
    case table of BS fp _ -> ET afp (castForeignPtr fp)
#else
mkEncodeTable alphabet =
    -- A 'PS' value carries an offset next to its pointer, but 'ET' stores
    -- bare pointers and 'encodeWith' indexes the alphabet from the pointer
    -- itself. Copying first yields a freshly allocated string, so that a
    -- sliced alphabet is not read from the wrong position. This relies on
    -- freshly built strings having offset 0, the same assumption that is
    -- made for @table@.
    case (B.copy alphabet, table) of
      (PS afp _ _, PS fp _ _) -> ET afp (castForeignPtr fp)
#endif
  where
    ix    = B.index alphabet
    table = B.pack [ c | j <- [0..63], k <- [0..63], c <- [ix j, ix k] ]
|
|
|
|
-- | Decode a base64-encoded string. This function strictly follows
-- the specification in <http://tools.ietf.org/rfc/rfc4648 RFC 4648>.
--
-- This function takes the decoding table (for @base64@ or @base64url@) as
-- the first parameter after the padding mode.
--
-- The empty string always decodes to the empty string. Otherwise the
-- input length modulo 4 decides what happens:
--
-- * 'Padded': only a remainder of 0 is accepted.
--
-- * 'Don'tCare': missing @=@ characters are appended before decoding; an
--   input of remainder 3 that already ends in @=@ is rejected.
--
-- * 'Unpadded': missing @=@ characters are appended before decoding; any
--   input that ends in @=@ is rejected.
--
-- A remainder of 1 is never a valid size.
--
-- For validation of padding properties, see note: $Validation
--
decodeWithTable :: Padding -> ForeignPtr Word8 -> ByteString -> Either String ByteString
decodeWithTable padding !decodeFP bs
  | B.null bs = Right B.empty
  | otherwise = case padding of
      Padded
        | r == 0 -> decodeAsIs
        | r == 1 -> Left "Base64-encoded bytestring has invalid size"
        | otherwise -> Left "Base64-encoded bytestring is unpadded or has invalid padding"
      Don'tCare
        | r == 0 -> decodeAsIs
        | r == 2 -> decodeWithPad 2
        | r == 3 -> validateLastPad bs invalidPad (decodeWithPad 1)
        | otherwise -> Left "Base64-encoded bytestring has invalid size"
      Unpadded
        | r == 0 -> validateLastPad bs noPad decodeAsIs
        | r == 2 -> validateLastPad bs noPad (decodeWithPad 2)
        | r == 3 -> validateLastPad bs noPad (decodeWithPad 1)
        | otherwise -> Left "Base64-encoded bytestring has invalid size"
  where
    !r = B.length bs `rem` 4

    noPad = "Base64-encoded bytestring required to be unpadded"
    invalidPad = "Base64-encoded bytestring has invalid padding"

    -- Decode the input exactly as given.
    decodeAsIs = withBS bs go

    -- Decode the input after appending @k@ @=@ (0x3d) characters.
    decodeWithPad k = withBS (B.append bs (B.replicate k 0x3d)) go

    -- Allocate 3 output bytes per 4 input bytes and run the strict loop.
    go !sptr !slen = do
      dfp <- mallocByteString (slen `quot` 4 * 3)
      withForeignPtr decodeFP (\ !decptr ->
        withForeignPtr dfp (\dptr ->
          decodeLoop decptr sptr dptr (sptr `plusPtr` slen) dfp))
|
|
|
|
-- | The strict decoding loop. It consumes the source 4 bytes at a time and
-- writes 3 bytes per group to the destination. The group that reaches the
-- source end pointer is handled separately, because it is the only one in
-- which padding is admissible.
--
-- Every source byte is translated through the decoding table. A table
-- value of @0x63@ is treated as padding and @0xff@ as an invalid
-- character. All failures are reported as 'Left' with the offset of the
-- offending byte relative to the source pointer.
--
-- The caller is expected to pass a non-empty source whose length is a
-- multiple of 4; this function always reads 4 bytes per step.
decodeLoop
    :: Ptr Word8
        -- ^ decoding table pointer
    -> Ptr Word8
        -- ^ source pointer
    -> Ptr Word8
        -- ^ destination pointer
    -> Ptr Word8
        -- ^ source end pointer
    -> ForeignPtr Word8
        -- ^ destination foreign pointer (used for finalizing string)
    -> IO (Either String ByteString)
decodeLoop !dtable !sptr !dptr !end !dfp = go dptr sptr
  where
    -- Fail with a message that ends in the offset of @p@ from the start
    -- of the source.
    failAt msg p = return (Left (msg ++ show (p `minusPtr` sptr)))

    err p = failAt "invalid character at offset: " p

    padErr p = failAt "invalid padding at offset: " p

    canonErr p = failAt "non-canonical encoding detected at offset: " p

    -- Translate the byte at @p@ through the decoding table.
    look :: Ptr Word8 -> IO Word32
    look !p = do
      !i <- peek p
      !v <- peekElemOff dtable (fromIntegral i)
      return (fromIntegral v)

    -- Read one group of 4 table values and dispatch on whether it is the
    -- last group of the input.
    go !dst !src = do
      !a <- look src
      !b <- look (src `plusPtr` 1)
      !c <- look (src `plusPtr` 2)
      !d <- look (src `plusPtr` 3)
      if plusPtr src 4 >= end
        then finalChunk dst src a b c d
        else decodeChunk dst src a b c d

    -- Decodes chunks of 4 bytes at a time, recombining into
    -- 3 bytes. Note that in the inner loop stage, no padding
    -- characters are admissible.
    --
    decodeChunk !dst !src !a !b !c !d
      | a == 0x63 = padErr src
      | b == 0x63 = padErr (plusPtr src 1)
      | c == 0x63 = padErr (plusPtr src 2)
      | d == 0x63 = padErr (plusPtr src 3)
      | a == 0xff = err src
      | b == 0xff = err (plusPtr src 1)
      | c == 0xff = err (plusPtr src 2)
      | d == 0xff = err (plusPtr src 3)
      | otherwise = do
          let !w = (shiftL a 18
                .|. shiftL b 12
                .|. shiftL c 6
                .|. d) :: Word32

          poke8 dst (fromIntegral (shiftR w 16))
          poke8 (plusPtr dst 1) (fromIntegral (shiftR w 8))
          poke8 (plusPtr dst 2) (fromIntegral w)
          go (plusPtr dst 3) (plusPtr src 4)

    -- Decode the final 4 bytes in the string, recombining into
    -- 3 bytes. Note that in this stage, we can have padding chars
    -- but only in the final 2 positions.
    --
    finalChunk !dst !src a b c d
      | a == 0x63 = padErr src
      | b == 0x63 = padErr (plusPtr src 1)
      | c == 0x63 && d /= 0x63 = err (plusPtr src 3) -- make sure padding is coherent.
      | a == 0xff = err src
      | b == 0xff = err (plusPtr src 1)
      | c == 0xff = err (plusPtr src 2)
      | d == 0xff = err (plusPtr src 3)
      | otherwise = do
          let !w = (shiftL a 18
                .|. shiftL b 12
                .|. shiftL c 6
                .|. d) :: Word32

          poke8 dst (fromIntegral (shiftR w 16))

          if c == 0x63 && d == 0x63
            then
              -- Two padding characters: one output byte. The low 4 bits of
              -- the second sextet are unused and must be zero.
              if sanityCheckPos b mask_4bits
                then return $ Right $ mkBS dfp (1 + (dst `minusPtr` dptr))
                else canonErr (plusPtr src 1)
            else if d == 0x63
              then
                -- One padding character: two output bytes. The low 2 bits
                -- of the third sextet are unused and must be zero.
                if sanityCheckPos c mask_2bits
                  then do
                    poke8 (plusPtr dst 1) (fromIntegral (shiftR w 8))
                    return $ Right $ mkBS dfp (2 + (dst `minusPtr` dptr))
                  else canonErr (plusPtr src 2)
              else do
                poke8 (plusPtr dst 1) (fromIntegral (shiftR w 8))
                poke8 (plusPtr dst 2) (fromIntegral w)
                return $ Right $ mkBS dfp (3 + (dst `minusPtr` dptr))
|
|
|
|
|
|
-- | Decode a base64-encoded string. This function is lenient in
-- following the specification from
-- <http://tools.ietf.org/rfc/rfc4648 RFC 4648>, and will not
-- generate parse errors no matter how poor its input. This function
-- takes the decoding table (for @base64@ or @base64url@) as the first
-- parameter.
--
-- Input bytes that the table maps to the invalid marker 'x' are skipped.
-- Decoding stops at the end of the input, or as soon as a group cannot be
-- completed because a padding marker ('done') takes the place of one of
-- its characters.
decodeLenientWithTable :: ForeignPtr Word8 -> ByteString -> ByteString
decodeLenientWithTable !decodeFP !bs = withBS bs go
  where
    go !sptr !slen
      | dlen <= 0 = return B.empty
      | otherwise = do
          dfp <- mallocByteString dlen
          withForeignPtr decodeFP $ \ !decptr -> do
            let finish dbytes
                  | dbytes > 0 = return $ mkBS dfp dbytes
                  | otherwise = return B.empty
                sEnd = sptr `plusPtr` slen
                fill !dp !sp !n
                  | sp >= sEnd = finish n
                  | otherwise = {-# SCC "decodeLenientWithTable/fill" #-}
                      -- @look skipPad p0 f@ scans forward from @p0@ for the
                      -- next usable table value and passes it to @f@ along
                      -- with the position after it. Invalid bytes are always
                      -- skipped; padding is skipped only when @skipPad@ is
                      -- set. Running off the end yields 'done'.
                      let look :: Bool -> Ptr Word8
                               -> (Ptr Word8 -> Word32 -> IO ByteString)
                               -> IO ByteString
                          {-# INLINE look #-}
                          look skipPad p0 f = go' p0
                            where
                              go' p
                                | p >= sEnd = f (sEnd `plusPtr` (-1)) done
                                | otherwise = {-# SCC "decodeLenient/look" #-} do
                                    ix <- fromIntegral `fmap` peek8 p
                                    v <- peek8 (decptr `plusPtr` ix)
                                    if v == x || (v == done && skipPad)
                                      then go' (p `plusPtr` 1)
                                      else f (p `plusPtr` 1) (fromIntegral v)
                      in look True sp $ \ !aNext !aValue ->
                         look True aNext $ \ !bNext !bValue ->
                           if aValue == done || bValue == done
                             then finish n
                             else
                               look False bNext $ \ !cNext !cValue ->
                               look False cNext $ \ !dNext !dValue -> do
                                 let w = aValue `shiftL` 18 .|. bValue `shiftL` 12 .|.
                                         cValue `shiftL` 6 .|. dValue
                                 poke8 dp $ fromIntegral (w `shiftR` 16)
                                 if cValue == done
                                   then finish (n + 1)
                                   else do
                                     poke8 (dp `plusPtr` 1) $ fromIntegral (w `shiftR` 8)
                                     if dValue == done
                                       then finish (n + 2)
                                       else do
                                         poke8 (dp `plusPtr` 2) $ fromIntegral w
                                         fill (dp `plusPtr` 3) dNext (n+3)
            withForeignPtr dfp $ \dptr -> fill dptr sptr 0
      where
        -- Upper bound for the output: 3 bytes per (started) group of 4
        -- input bytes.
        !dlen = (slen + 3) `div` 4 * 3
|
|
|
|
-- | Marker value (255, @0xff@) for a byte that is not part of the
-- alphabet. 'decodeLenientWithTable' skips input bytes that the decoding
-- table maps to this value.
x :: Integral a => a
x = 255
{-# INLINE x #-}
|
|
|
|
-- | Marker value (99, @0x63@) for padding. 'decodeLenientWithTable' also
-- uses it to signal that the end of the input has been reached.
done :: Integral a => a
done = 99
{-# INLINE done #-}
|
|
|
|
-- | This takes a list of ByteStrings, and returns a list in which each
-- (apart from possibly the last) has length that is a multiple of @n@.
--
-- The concatenation of the result equals the concatenation of the input.
-- The result may contain empty chunks, e.g. when an input chunk is shorter
-- than @n@. @n@ is used as a divisor, so it must not be 0.
reChunkIn :: Int -> [ByteString] -> [ByteString]
reChunkIn !n = go
  where
    -- Emit the longest prefix of the next chunk whose length is a multiple
    -- of @n@ and carry the remainder over into 'fixup'.
    go [] = []
    go (y : ys)
      | leftover == 0 = y : go ys
      | otherwise =
          let (whole, rest) = B.splitAt (groups * n) y
          in whole : fixup rest ys
      where
        (groups, leftover) = B.length y `divMod` n

    -- Top up the carried-over bytes from the following chunks until they
    -- form exactly @n@ bytes, or the input runs out.
    fixup acc [] = [acc]
    fixup acc (z : zs)
      | B.length acc' /= n = fixup acc' zs -- z is used up completely
      | B.null rest = acc' : go zs
      | otherwise = acc' : go (rest : zs)
      where
        (taken, rest) = B.splitAt (n - B.length acc) z
        acc' = acc `B.append` taken
|
|
|
|
-- $Validation
--
-- This function checks that the last char of a bytestring is '='
-- and, if true, fails with a message or completes some io action.
--
-- This is necessary to check when decoding permissively (i.e. filling in padding chars).
-- Consider the following 4 cases of a string of length l:
--
-- l = 0 mod 4: No pad chars are added, since the input is assumed to be good.
-- l = 1 mod 4: Never an admissible length in base64
-- l = 2 mod 4: 2 padding chars are added. If padding chars are present in the last 4 chars of the string,
-- they will fail to decode as final quanta.
-- l = 3 mod 4: 1 padding char is added. In this case a string is of the form <body> + <padchar>. If adding the
-- pad char "completes" the string so that it is `l = 0 mod 4`, then this may possibly form corrupted data.
-- This case is degenerate and should be disallowed.
--
-- Hence, permissive decodes should only fill in padding chars when it makes sense to add them. That is,
-- if an input is degenerate, it should never succeed when we add padding chars. We need the following invariant to hold:
--
-- @
-- B64U.decodeUnpadded <|> B64U.decodePadded ~ B64U.decodePadded
-- @
--
-- This means the only char we need to check is the last one, and only to disallow `l = 3 mod 4`.
--
-- An empty input has no last char and is passed straight through. The
-- continuation is only evaluated when the check succeeds.
--
validateLastPad
    :: ByteString
        -- ^ input to validate
    -> String
        -- ^ error msg
    -> Either String ByteString
        -- ^ result to return when the input does not end in '='
    -> Either String ByteString
validateLastPad bs err io
  | not (B.null bs) && B.last bs == 0x3d = Left err
  | otherwise = io
{-# INLINE validateLastPad #-}
|
|
|
|
-- | Sanity check an index against a bitmask to make sure
-- it's coherent. If pos & mask == 0, we're good. If not, we should fail.
--
-- Only the low 8 bits of the index take part in the comparison, because
-- it is narrowed to 'Word8' before the mask is applied.
--
sanityCheckPos :: Word32 -> Word8 -> Bool
sanityCheckPos pos mask = fromIntegral pos .&. mask == 0
{-# INLINE sanityCheckPos #-}
|
|
|
|
-- | Mask selecting the lowest 2 bits of a value.
--
mask_2bits :: Word8
mask_2bits = 3 -- (1 << 2) - 1
{-# NOINLINE mask_2bits #-}
|
|
|
|
-- | Mask selecting the lowest 4 bits of a value.
--
mask_4bits :: Word8
mask_4bits = 15 -- (1 << 4) - 1
{-# NOINLINE mask_4bits #-}
|
|
|
|
-- | Back-compat shim for bytestring >=0.11. Constructs a
-- bytestring from a foreign ptr and a length. Offset is 0.
--
-- The buffer is not copied: the result refers to the memory behind the
-- given pointer, so the first @n@ bytes must be initialised and must not
-- be written to afterwards.
--
mkBS :: ForeignPtr Word8 -> Int -> ByteString
#if MIN_VERSION_bytestring(0,11,0)
mkBS dfp n = BS dfp n
#else
mkBS dfp n = PS dfp 0 n
#endif
{-# INLINE mkBS #-}
|
|
|
|
-- | Back-compat shim for bytestring >=0.11. Unwraps the foreign ptr of
-- a bytestring, executing an IO action as a function of the underlying
-- pointer and some starting length.
--
-- The pointer handed to the action addresses the first byte of the
-- string's contents (any offset is already applied) and is only valid
-- while the action runs.
--
-- Note: in `unsafePerformIO`. The action is run as part of a pure
-- computation, so it must not have side effects that an observer could
-- notice.
--
withBS :: ByteString -> (Ptr Word8 -> Int -> IO a) -> a
#if MIN_VERSION_bytestring(0,11,0)
withBS (BS !sfp !slen) f = unsafePerformIO $
  withForeignPtr sfp $ \p -> f p slen
#else
withBS (PS !sfp !soff !slen) f = unsafePerformIO $
  withForeignPtr sfp $ \p -> f (plusPtr p soff) slen
#endif
{-# INLINE withBS #-}
|