From 0ec71b78fb248eae6b6fbbb3f5735b2f9646fd62 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Fri, 15 Jan 2021 02:15:04 -0300
Subject: [PATCH] astc: Return zero on out of bound bits

Avoid out of bound reads on invalid ASTC textures.
Games can bind invalid textures that make us read or write out of bounds.
---
 src/video_core/textures/astc.cpp | 39 ++++++++++++++++++--------------
 1 file changed, 22 insertions(+), 17 deletions(-)

diff --git a/src/video_core/textures/astc.cpp b/src/video_core/textures/astc.cpp
index acd5bdd78..1f1c3bd3a 100644
--- a/src/video_core/textures/astc.cpp
+++ b/src/video_core/textures/astc.cpp
@@ -42,21 +42,24 @@ constexpr u32 Popcnt(u32 n) {
 
 class InputBitStream {
 public:
-    constexpr explicit InputBitStream(const u8* ptr, std::size_t start_offset = 0)
-        : cur_byte{ptr}, next_bit{start_offset % 8} {}
+    constexpr explicit InputBitStream(std::span<const u8> data, size_t start_offset = 0)
+        : cur_byte{data.data()}, total_bits{data.size()}, next_bit{start_offset % 8} {}
 
-    constexpr std::size_t GetBitsRead() const {
+    constexpr size_t GetBitsRead() const {
         return bits_read;
     }
 
     constexpr bool ReadBit() {
-        const bool bit = (*cur_byte >> next_bit++) & 1;
+        if (bits_read >= total_bits * 8) {
+            return 0;
+        }
+        const bool bit = ((*cur_byte >> next_bit) & 1) != 0;
+        ++next_bit;
         while (next_bit >= 8) {
             next_bit -= 8;
-            cur_byte++;
+            ++cur_byte;
         }
-
-        bits_read++;
+        ++bits_read;
         return bit;
     }
 
@@ -79,8 +82,9 @@ public:
 
 private:
     const u8* cur_byte;
-    std::size_t next_bit = 0;
-    std::size_t bits_read = 0;
+    size_t total_bits = 0;
+    size_t next_bit = 0;
+    size_t bits_read = 0;
 };
 
 class OutputBitStream {
@@ -200,8 +204,8 @@ using IntegerEncodedVector = boost::container::static_vector<
 
 static void DecodeTritBlock(InputBitStream& bits, IntegerEncodedVector& result, u32 nBitsPerValue) {
     // Implement the algorithm in section C.2.12
-    u32 m[5];
-    u32 t[5];
+    std::array<u32, 5> m;
+    std::array<u32, 5> t;
     u32 T;
 
     // Read the trit encoded block according to
@@ -866,7 +870,7 @@ public:
     }
 };
 
-static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nPartitions,
+static void DecodeColorValues(u32* out, std::span<u8> data, const u32* modes, const u32 nPartitions,
                               const u32 nBitsForColorData) {
     // First figure out how many color values we have
     u32 nValues = 0;
@@ -898,7 +902,7 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
     // We now have enough to decode our integer sequence.
     IntegerEncodedVector decodedColorValues;
 
-    InputBitStream colorStream(data);
+    InputBitStream colorStream(data, 0);
     DecodeIntegerSequence(decodedColorValues, colorStream, range, nValues);
 
     // Once we have the decoded values, we need to dequantize them to the 0-255 range
@@ -1441,7 +1445,7 @@ static void ComputeEndpos32s(Pixel& ep1, Pixel& ep2, const u32*& colorValues,
 
 static void DecompressBlock(std::span<const u8, 16> inBuf, const u32 blockWidth,
                             const u32 blockHeight, std::span<u32, 12 * 12> outBuf) {
-    InputBitStream strm(inBuf.data());
+    InputBitStream strm(inBuf);
     TexelWeightParams weightParams = DecodeBlockInfo(strm);
 
     // Was there an error?
@@ -1619,15 +1623,16 @@ static void DecompressBlock(std::span<const u8, 16> inBuf, const u32 blockWidth,
 
     // Make sure that higher non-texel bits are set to zero
     const u32 clearByteStart = (weightParams.GetPackedBitSize() >> 3) + 1;
-    if (clearByteStart > 0) {
+    if (clearByteStart > 0 && clearByteStart <= texelWeightData.size()) {
         texelWeightData[clearByteStart - 1] &=
             static_cast<u8>((1 << (weightParams.GetPackedBitSize() % 8)) - 1);
+        std::memset(texelWeightData.data() + clearByteStart, 0,
+                    std::min(16U - clearByteStart, 16U));
     }
-    std::memset(texelWeightData.data() + clearByteStart, 0, std::min(16U - clearByteStart, 16U));
 
     IntegerEncodedVector texelWeightValues;
 
-    InputBitStream weightStream(texelWeightData.data());
+    InputBitStream weightStream(texelWeightData);
 
     DecodeIntegerSequence(texelWeightValues, weightStream, weightParams.m_MaxWeight,
                           weightParams.GetNumWeightValues());