-- stellar-veritas/bundled/Data/ByteString/Base64/Internal.hs
-- 2026-01-25 02:27:22 +01:00
--
-- 446 lines
-- 16 KiB
-- Haskell

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DoAndIfThenElse #-}
-- |
-- Module : Data.ByteString.Base64.Internal
-- Copyright : (c) 2010 Bryan O'Sullivan
--
-- License : BSD-style
-- Maintainer : bos@serpentine.com
-- Stability : experimental
-- Portability : GHC
--
-- Fast and efficient encoding and decoding of base64-encoded strings.
module Data.ByteString.Base64.Internal
( encodeWith
, decodeWithTable
, decodeLenientWithTable
, mkEncodeTable
, done
, peek8, poke8, peek8_32
, reChunkIn
, Padding(..)
, withBS
, mkBS
) where
import Data.Bits ((.|.), (.&.), shiftL, shiftR)
import qualified Data.ByteString as B
import Data.ByteString.Internal (ByteString(..), mallocByteString)
import Data.Word (Word8, Word16, Word32)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr, castPtr, minusPtr, plusPtr)
import Foreign.Storable (peek, peekElemOff, poke)
import System.IO.Unsafe (unsafePerformIO)
-- | Read the single byte stored at the given address.
peek8 :: Ptr Word8 -> IO Word8
peek8 ptr = peek ptr
-- | Store a single byte at the given address.
poke8 :: Ptr Word8 -> Word8 -> IO ()
poke8 ptr byte = poke ptr byte
-- | Read the single byte stored at the given address and widen it to a
-- 'Word32'.
peek8_32 :: Ptr Word8 -> IO Word32
peek8_32 ptr = fmap fromIntegral (peek8 ptr)
-- | How @=@ padding is treated by 'encodeWith' and 'decodeWithTable'.
data Padding
    = Padded
    -- ^ Encoding emits trailing @=@ characters; decoding only accepts
    -- input whose length is a multiple of 4.
    | Don'tCare
    -- ^ Encoding omits trailing @=@ characters; decoding accepts input
    -- with or without them and fills in missing ones itself.
    | Unpadded
    -- ^ Encoding omits trailing @=@ characters; decoding rejects input
    -- whose last byte is @=@.
    deriving (Eq)
-- | Encode a string into base64 form using the given encoding table.
--
-- With 'Padded' the output is completed with @=@ characters, so its length
-- is always a multiple of 4 bytes. With any other 'Padding' value the
-- trailing @=@ characters are left out, so an input whose length is not a
-- multiple of 3 yields an output that is 1 or 2 bytes shorter than that.
--
-- Calls 'error' when the input is longer than @maxBound \`div\` 4@ bytes.
encodeWith :: Padding -> EncodeTable -> ByteString -> ByteString
encodeWith !padding (ET alfaFP encodeTable) !bs = withBS bs go
  where
    go !sptr !slen
      | slen > maxBound `div` 4 =
          error "Data.ByteString.Base64.encode: input too long"
      | otherwise = do
          -- Room for the fully padded encoding; the unpadded variants only
          -- use a prefix of this buffer.
          let dlen = (slen + 2) `div` 3 * 4
          dfp <- mallocByteString dlen
          withForeignPtr alfaFP $ \aptr ->
            withForeignPtr encodeTable $ \ep -> do
              let -- Alphabet character for a single 6-bit value.
                  aidx n = peek8 (aptr `plusPtr` n)
                  sEnd = sptr `plusPtr` slen
                  finish !n = return $ mkBS dfp n
                  doPad = padding == Padded
                  equals = 0x3d :: Word8
                  {-# INLINE equals #-}

                  -- Main loop: turn 3 input bytes into a 24-bit word and
                  -- write it as two 16-bit table entries (4 output bytes).
                  fill !dp !sp !n
                    | sp `plusPtr` 2 >= sEnd = complete (castPtr dp) sp n
                    | otherwise = {-# SCC "encode/fill" #-} do
                        i <- peek8_32 sp
                        j <- peek8_32 (sp `plusPtr` 1)
                        k <- peek8_32 (sp `plusPtr` 2)
                        let w = i `shiftL` 16 .|. j `shiftL` 8 .|. k
                            enc = peekElemOff ep . fromIntegral
                        poke dp =<< enc (w `shiftR` 12)
                        poke (dp `plusPtr` 2) =<< enc (w .&. 0xfff)
                        fill (dp `plusPtr` 4) (sp `plusPtr` 3) (n + 4)

                  -- Tail: encode the remaining 0, 1 or 2 input bytes one
                  -- character at a time and add padding if requested.
                  complete dp sp n
                    | sp == sEnd = finish n
                    | otherwise = {-# SCC "encode/complete" #-} do
                        let peekSP m f = (f . fromIntegral) `fmap` peek8 (sp `plusPtr` m)
                            twoMore = sp `plusPtr` 2 == sEnd
                        !a <- peekSP 0 ((`shiftR` 2) . (.&. 0xfc))
                        !b <- peekSP 0 ((`shiftL` 4) . (.&. 0x03))
                        poke8 dp =<< aidx a
                        if twoMore
                          then do
                            !b' <- peekSP 1 ((.|. b) . (`shiftR` 4) . (.&. 0xf0))
                            !c <- aidx =<< peekSP 1 ((`shiftL` 2) . (.&. 0x0f))
                            poke8 (dp `plusPtr` 1) =<< aidx b'
                            poke8 (dp `plusPtr` 2) c
                            if doPad
                              then poke8 (dp `plusPtr` 3) equals >> finish (n + 4)
                              else finish (n + 3)
                          else do
                            poke8 (dp `plusPtr` 1) =<< aidx b
                            if doPad
                              then do
                                poke8 (dp `plusPtr` 2) equals
                                poke8 (dp `plusPtr` 3) equals
                                finish (n + 4)
                              else finish (n + 2)
              withForeignPtr dfp (\dptr -> fill (castPtr dptr) sptr 0)
-- | Lookup tables used by 'encodeWith'.
data EncodeTable
    = ET
        -- Alphabet, indexed by a single 6-bit value.
        !(ForeignPtr Word8)
        -- Table of 16-bit entries, indexed by a 12-bit value.
        !(ForeignPtr Word16)
-- The encoding table is constructed such that the expansion of a 12-bit
-- block to a 16-bit block can be done by a single Word16 copy from the
-- corresponding table entry to the target address. The 16-bit blocks are
-- stored in big-endian order, as the indices into the table are built in
-- big-endian order.

-- | Build the 'EncodeTable' for an alphabet.
--
-- The alphabet must contain at least 64 bytes: every index from 0 to 63 is
-- looked up with 'B.index' while the table is built.
mkEncodeTable :: ByteString -> EncodeTable
#if MIN_VERSION_bytestring(0,11,0)
mkEncodeTable alphabet@(BS afp _) =
    case table of BS fp _ -> ET afp (castForeignPtr fp)
#else
mkEncodeTable alphabet0 =
    case (alphabet, table) of
      (PS afp _ _, PS fp _ _) -> ET afp (castForeignPtr fp)
#endif
  where
#if !MIN_VERSION_bytestring(0,11,0)
    -- The old 'PS' constructor carries an offset into its buffer, which the
    -- pattern above discards. Copy the alphabet first so that the pointer
    -- stored in the table addresses its first byte even when the caller
    -- passed a slice of a larger buffer.
    -- NOTE(review): relies on 'B.copy' (like 'B.pack', used for the table)
    -- producing a ByteString with offset 0.
    alphabet = B.copy alphabet0
#endif
    ix = fromIntegral . B.index alphabet
    -- One two-byte entry per 12-bit value: the characters for its high and
    -- low 6 bits, in that order.
    table = B.pack $ concat [ [ix j, ix k] | j <- [0..63], k <- [0..63] ]
-- | Decode a base64-encoded string, strictly following the specification
-- in <http://tools.ietf.org/rfc/rfc4648 RFC 4648>.
--
-- The decoding table (for @base64@ or @base64url@) is passed as the second
-- argument. The 'Padding' argument selects how the input length and
-- trailing @=@ characters are treated:
--
-- * 'Padded': the length must be a multiple of 4.
-- * 'Don'tCare': missing padding is filled in before decoding.
-- * 'Unpadded': the input must not end in @=@; padding is filled in
--   before decoding.
--
-- An empty input always decodes to an empty result.
--
-- For validation of padding properties, see note: $Validation
--
decodeWithTable :: Padding -> ForeignPtr Word8 -> ByteString -> Either String ByteString
decodeWithTable padding !decodeFP bs
    | B.null bs = Right B.empty
    | otherwise = case (padding, r) of
        (Padded, 0) -> decode bs
        (Padded, 1) -> Left badSize
        (Padded, _) -> Left "Base64-encoded bytestring is unpadded or has invalid padding"
        (Don'tCare, 0) -> decode bs
        (Don'tCare, 2) -> decode (padded 2)
        (Don'tCare, 3) -> validateLastPad bs invalidPad (decode (padded 1))
        (Don'tCare, _) -> Left badSize
        (Unpadded, 0) -> validateLastPad bs noPad (decode bs)
        (Unpadded, 2) -> validateLastPad bs noPad (decode (padded 2))
        (Unpadded, 3) -> validateLastPad bs noPad (decode (padded 1))
        (Unpadded, _) -> Left badSize
  where
    -- Number of characters in the trailing, incomplete group of 4.
    !r = B.length bs `rem` 4

    badSize = "Base64-encoded bytestring has invalid size"
    noPad = "Base64-encoded bytestring required to be unpadded"
    invalidPad = "Base64-encoded bytestring has invalid padding"

    -- The input with @k@ padding characters appended.
    padded k = B.append bs (B.replicate k 0x3d)

    decode input = withBS input go

    -- Allocate 3 output bytes per 4 input bytes and run the decoding loop.
    go !sptr !slen = do
        dfp <- mallocByteString (slen `quot` 4 * 3)
        withForeignPtr decodeFP $ \ !decptr ->
          withForeignPtr dfp $ \dptr ->
            decodeLoop decptr sptr dptr (sptr `plusPtr` slen) dfp
-- | Strict decoding loop: consumes the source 4 bytes at a time and writes
-- 3 bytes per group to the destination. The last group may end in one or
-- two padding characters and is checked for a canonical encoding.
--
-- The source is read in whole groups of 4 starting at the source pointer,
-- so the caller must pass a non-empty source whose length is a multiple
-- of 4.
decodeLoop
    :: Ptr Word8
    -- ^ decoding table pointer
    -> Ptr Word8
    -- ^ source pointer
    -> Ptr Word8
    -- ^ destination pointer
    -> Ptr Word8
    -- ^ source end pointer
    -> ForeignPtr Word8
    -- ^ destination foreign pointer (used for finalizing string)
    -> IO (Either String ByteString)
decodeLoop !dtable !sptr !dptr !end !dfp = step dptr sptr
  where
    -- Table values that are reported as padding and as an invalid
    -- character respectively.
    padMark, badMark :: Word32
    padMark = 0x63
    badMark = 0xff

    -- Fail with a message followed by the source offset of the given
    -- position.
    failAt :: String -> Ptr Word8 -> IO (Either String ByteString)
    failAt msg p = return (Left (msg ++ show (p `minusPtr` sptr)))

    err, padErr, canonErr :: Ptr Word8 -> IO (Either String ByteString)
    err = failAt "invalid character at offset: "
    padErr = failAt "invalid padding at offset: "
    canonErr = failAt "non-canonical encoding detected at offset: "

    -- Finish successfully; the result holds everything written before
    -- @dst@ plus @extra@ bytes of the final group.
    succeed :: Ptr Word8 -> Int -> IO (Either String ByteString)
    succeed dst extra =
        return (Right (mkBS dfp (extra + (dst `minusPtr` dptr))))

    -- Table value for the source byte at the given address.
    sextet :: Ptr Word8 -> IO Word32
    sextet !p = do
        !i <- peek p
        !v <- peekElemOff dtable (fromIntegral i)
        return (fromIntegral v)

    -- Combine four 6-bit values into one 24-bit word.
    pack24 :: Word32 -> Word32 -> Word32 -> Word32 -> Word32
    pack24 a b c d = shiftL a 18 .|. shiftL b 12 .|. shiftL c 6 .|. d

    step !dst !src = do
        !a <- sextet src
        !b <- sextet (src `plusPtr` 1)
        !c <- sextet (src `plusPtr` 2)
        !d <- sextet (src `plusPtr` 3)
        if src `plusPtr` 4 >= end
          then finalChunk dst src a b c d
          else decodeChunk dst src a b c d

    -- Decodes an inner group of 4 bytes, recombining into 3 bytes. No
    -- padding characters are admissible at this stage.
    decodeChunk !dst !src !a !b !c !d
      | a == padMark = padErr src
      | b == padMark = padErr (src `plusPtr` 1)
      | c == padMark = padErr (src `plusPtr` 2)
      | d == padMark = padErr (src `plusPtr` 3)
      | a == badMark = err src
      | b == badMark = err (src `plusPtr` 1)
      | c == badMark = err (src `plusPtr` 2)
      | d == badMark = err (src `plusPtr` 3)
      | otherwise = do
          let !w = pack24 a b c d
          poke8 dst (fromIntegral (w `shiftR` 16))
          poke8 (dst `plusPtr` 1) (fromIntegral (w `shiftR` 8))
          poke8 (dst `plusPtr` 2) (fromIntegral w)
          step (dst `plusPtr` 3) (src `plusPtr` 4)

    -- Decodes the final group of 4 bytes, recombining into up to 3 bytes.
    -- Padding characters are allowed here, but only in the final 2
    -- positions.
    finalChunk !dst !src a b c d
      | a == padMark = padErr src
      | b == padMark = padErr (src `plusPtr` 1)
      -- make sure padding is coherent.
      | c == padMark && d /= padMark = err (src `plusPtr` 3)
      | a == badMark = err src
      | b == badMark = err (src `plusPtr` 1)
      | c == badMark = err (src `plusPtr` 2)
      | d == badMark = err (src `plusPtr` 3)
      | otherwise = do
          let !w = pack24 a b c d
          poke8 dst (fromIntegral (w `shiftR` 16))
          case () of
            _ | c == padMark && d == padMark ->
                  -- One output byte: the low 4 bits of b must be unused.
                  if sanityCheckPos b mask_4bits
                    then succeed dst 1
                    else canonErr (src `plusPtr` 1)
              | d == padMark ->
                  -- Two output bytes: the low 2 bits of c must be unused.
                  if sanityCheckPos c mask_2bits
                    then do
                      poke8 (dst `plusPtr` 1) (fromIntegral (w `shiftR` 8))
                      succeed dst 2
                    else canonErr (src `plusPtr` 2)
              | otherwise -> do
                  poke8 (dst `plusPtr` 1) (fromIntegral (w `shiftR` 8))
                  poke8 (dst `plusPtr` 2) (fromIntegral w)
                  succeed dst 3
-- | Decode a base64-encoded string. This function is lenient in
-- following the specification from
-- <http://tools.ietf.org/rfc/rfc4648 RFC 4648>, and will not
-- generate parse errors no matter how poor its input. This function
-- takes the decoding table (for @base64@ or @base64url@) as the first
-- parameter.
--
-- Source bytes that the table maps to 'x' are skipped. Decoding stops at
-- the end of the input, or at a byte mapped to 'done' found where the
-- third or fourth character of a group is expected.
decodeLenientWithTable :: ForeignPtr Word8 -> ByteString -> ByteString
decodeLenientWithTable !decodeFP !bs = withBS bs go
  where
    go !sptr !slen
      | dlen <= 0 = return B.empty
      | otherwise = do
          dfp <- mallocByteString dlen
          withForeignPtr decodeFP $ \ !decptr -> do
            let sEnd = sptr `plusPtr` slen

                finish dbytes
                  | dbytes > 0 = return $ mkBS dfp dbytes
                  | otherwise = return B.empty

                -- Find the next usable table value at or after @p0@ and
                -- hand it, with the address following it, to @k@. When
                -- @skipPad@ is set, bytes mapped to 'done' are skipped
                -- too. At the end of the input @k@ receives 'done' and
                -- the address of the last source byte.
                next :: Bool -> Ptr Word8
                     -> (Ptr Word8 -> Word32 -> IO ByteString)
                     -> IO ByteString
                {-# INLINE next #-}
                next skipPad p0 k = scan p0
                  where
                    scan p
                      | p >= sEnd = k (sEnd `plusPtr` (-1)) done
                      | otherwise = {-# SCC "decodeLenient/look" #-} do
                          ix <- fromIntegral `fmap` peek8 p
                          v <- peek8 (decptr `plusPtr` ix)
                          if v == x || (v == done && skipPad)
                            then scan (p `plusPtr` 1)
                            else k (p `plusPtr` 1) (fromIntegral v)

                -- Decode one group of up to 4 characters per iteration;
                -- @n@ counts the bytes written so far.
                fill !dp !sp !n
                  | sp >= sEnd = finish n
                  | otherwise = {-# SCC "decodeLenientWithTable/fill" #-}
                      next True sp $ \ !aNext !aValue ->
                      next True aNext $ \ !bNext !bValue ->
                        if aValue == done || bValue == done
                          then finish n
                          else
                            next False bNext $ \ !cNext !cValue ->
                            next False cNext $ \ !dNext !dValue -> do
                              let w = aValue `shiftL` 18 .|. bValue `shiftL` 12 .|.
                                      cValue `shiftL` 6 .|. dValue
                              poke8 dp $ fromIntegral (w `shiftR` 16)
                              if cValue == done
                                then finish (n + 1)
                                else do
                                  poke8 (dp `plusPtr` 1) $ fromIntegral (w `shiftR` 8)
                                  if dValue == done
                                    then finish (n + 2)
                                    else do
                                      poke8 (dp `plusPtr` 2) $ fromIntegral w
                                      fill (dp `plusPtr` 3) dNext (n + 3)
            withForeignPtr dfp $ \dptr -> fill dptr sptr 0
      where
        -- Upper bound on the output size: 3 bytes per started group of 4.
        !dlen = (slen + 3) `div` 4 * 3
-- | Table value that 'decodeLenientWithTable' treats as a byte to skip.
x :: Integral a => a
x = 0xff
{-# INLINE x #-}
-- | Table value that 'decodeLenientWithTable' treats as the end-of-data
-- (padding) marker.
done :: Integral a => a
done = 0x63
{-# INLINE done #-}
-- | Regroup a list of chunks so that every chunk of the result, except
-- possibly the last one, has a length that is a multiple of @n@.
--
-- Bytes left over at the end of a chunk are joined with the beginning of
-- the following chunks. The result may contain empty chunks. @n@ is used
-- as a divisor and therefore must not be 0.
reChunkIn :: Int -> [ByteString] -> [ByteString]
reChunkIn !n = go
  where
    -- Emit the part of the first chunk made of whole blocks of n bytes and
    -- carry the remainder over.
    go [] = []
    go (chunk : rest)
      | leftover == 0 = chunk : go rest
      | otherwise = case B.splitAt (blocks * n) chunk of
          (whole, partial) -> whole : fixup partial rest
      where
        (blocks, leftover) = B.length chunk `divMod` n

    -- Complete the carried-over bytes to exactly n bytes using the
    -- following chunks.
    fixup acc [] = [acc]
    fixup acc (z : zs)
      | B.length acc' == n = acc' : go (if B.null suffix then zs else suffix : zs)
      | otherwise = fixup acc' zs -- suffix must be null
      where
        (prefix, suffix) = B.splitAt (n - B.length acc) z
        acc' = acc `B.append` prefix
-- $Validation
--
-- 'validateLastPad' checks whether the last char of a bytestring is @=@.
-- If it is, validation fails with the supplied message; otherwise the
-- supplied decoding result is returned.
--
-- This is necessary to check when decoding permissively (i.e. filling in padding chars).
-- Consider the following 4 cases of a string of length l:
--
-- l = 0 mod 4: No pad chars are added, since the input is assumed to be good.
-- l = 1 mod 4: Never an admissible length in base64
-- l = 2 mod 4: 2 padding chars are added. If padding chars are present in the last 4 chars of the string,
-- they will fail to decode as final quanta.
-- l = 3 mod 4: 1 padding char is added. In this case a string is of the form <body> + <padchar>. If adding the
-- pad char "completes" the string so that it is `l = 0 mod 4`, then this may possibly form corrupted data.
-- This case is degenerate and should be disallowed.
--
-- Hence, permissive decodes should only fill in padding chars when it makes sense to add them. That is,
-- if an input is degenerate, it should never succeed when we add padding chars. We need the following invariant to hold:
--
-- @
-- B64U.decodeUnpadded <|> B64U.decodePadded ~ B64U.decodePadded
-- @
--
-- This means the only char we need to check is the last one, and only to disallow `l = 3 mod 4`.
--

-- | Reject an input whose last byte is the padding character @=@.
--
-- Returns @Left err@ in that case and the given decoding result
-- otherwise. An empty input has no last byte and passes validation. The
-- decoding result is only evaluated when validation passes, so no decoding
-- work is done for rejected input.
validateLastPad
    :: ByteString
    -- ^ input to validate
    -> String
    -- ^ error msg
    -> Either String ByteString
    -- ^ decoding result to return when validation passes
    -> Either String ByteString
validateLastPad !bs err io
    | not (B.null bs) && B.last bs == 0x3d = Left err
    | otherwise = io
{-# INLINE validateLastPad #-}
-- | Check that none of the bits selected by the mask are set in the given
-- value: the result is 'True' when @pos & mask == 0@ and 'False' otherwise.
--
sanityCheckPos :: Word32 -> Word8 -> Bool
sanityCheckPos pos mask = (pos .&. fromIntegral mask) == 0
{-# INLINE sanityCheckPos #-}
-- | Bitmask selecting the two lowest bits, i.e. @(1 << 2) - 1@.
--
mask_2bits :: Word8
mask_2bits = 0x03
{-# NOINLINE mask_2bits #-}
-- | Bitmask selecting the four lowest bits, i.e. @(1 << 4) - 1@.
--
mask_4bits :: Word8
mask_4bits = 0x0f
{-# NOINLINE mask_4bits #-}
-- | Compatibility shim across the bytestring 0.11 constructor change:
-- builds a bytestring from a foreign pointer and a length. With the older
-- three-field constructor the offset is set to 0.
--
mkBS :: ForeignPtr Word8 -> Int -> ByteString
#if MIN_VERSION_bytestring(0,11,0)
mkBS = BS
#else
mkBS fp len = PS fp 0 len
#endif
{-# INLINE mkBS #-}
-- | Compatibility shim across the bytestring 0.11 constructor change:
-- runs an IO action on the address of a bytestring's first byte and on its
-- length, and returns the action's result as a pure value.
--
-- With the older three-field constructor the stored offset is added to the
-- buffer pointer before the action is called.
--
-- Note: the action is run with 'unsafePerformIO'.
--
withBS :: ByteString -> (Ptr Word8 -> Int -> IO a) -> a
#if MIN_VERSION_bytestring(0,11,0)
withBS (BS !fp !len) action =
    unsafePerformIO (withForeignPtr fp (\ptr -> action ptr len))
#else
withBS (PS !fp !off !len) action =
    unsafePerformIO (withForeignPtr fp (\ptr -> action (ptr `plusPtr` off) len))
#endif
{-# INLINE withBS #-}