From 27c8bb9615039eefc837ade5d91d5733c738f4a0 Mon Sep 17 00:00:00 2001
From: Ameer J <52414509+ameerj@users.noreply.github.com>
Date: Sun, 30 Jul 2023 12:41:52 -0400
Subject: [PATCH] astc_decoder: Flatten unquantized texel weights array

---
 src/video_core/host_shaders/astc_decoder.comp | 87 ++++++++++---------
 1 file changed, 44 insertions(+), 43 deletions(-)

diff --git a/src/video_core/host_shaders/astc_decoder.comp b/src/video_core/host_shaders/astc_decoder.comp
index b84ddd67d..f720df6d2 100644
--- a/src/video_core/host_shaders/astc_decoder.comp
+++ b/src/video_core/host_shaders/astc_decoder.comp
@@ -140,8 +140,6 @@ int result_index = 0;
 EncodingData texel_vector[144];
 int texel_vector_index = 0;
 
-uint unquantized_texel_weights[2][144];
-
 // Replicates low num_bits such that [(to_bit - 1):(to_bit - 1 - from_bit)]
 // is the same as [(num_bits - 1):0] and repeats all the way down.
 uint Replicate(uint val, uint num_bits, uint to_bit) {
@@ -879,58 +877,60 @@ uint UnquantizeTexelWeight(EncodingData val) {
     return result;
 }
 
-void UnquantizeTexelWeights(bool dual_plane, uvec2 size) {
-    uint weight_idx = 0;
-    uint unquantized[2][144];
-    uint area = size.x * size.y;
-    for (uint itr = 0; itr < texel_vector_index; itr++) {
-        unquantized[0][weight_idx] = UnquantizeTexelWeight(texel_vector[itr]);
-        if (dual_plane) {
-            ++itr;
-            unquantized[1][weight_idx] = UnquantizeTexelWeight(texel_vector[itr]);
-            if (itr == texel_vector_index) {
-                break;
-            }
-        }
-        if (++weight_idx >= (area))
-            break;
-    }
-
+void UnquantizeTexelWeights(bool is_dual_plane, uvec2 size, out uint unquantized_texel_weights[2 * 144]) {
     const uint Ds = uint((block_dims.x * 0.5f + 1024) / (block_dims.x - 1));
     const uint Dt = uint((block_dims.y * 0.5f + 1024) / (block_dims.y - 1));
-    const uint k_plane_scale = dual_plane ? 2 : 1;
-    for (uint plane = 0; plane < k_plane_scale; plane++) {
+    const uint num_planes = is_dual_plane ? 2 : 1;
+    const uint area = size.x * size.y;
+    const uint loop_count = min(result_index, area * num_planes);
+    uint unquantized[2 * 144];
+    for (uint itr = 0; itr < loop_count; ++itr) {
+        unquantized[itr] = UnquantizeTexelWeight(texel_vector[itr]);
+    }
+    for (uint plane = 0; plane < num_planes; ++plane) {
         for (uint t = 0; t < block_dims.y; t++) {
             for (uint s = 0; s < block_dims.x; s++) {
-                uint cs = Ds * s;
-                uint ct = Dt * t;
-                uint gs = (cs * (size.x - 1) + 32) >> 6;
-                uint gt = (ct * (size.y - 1) + 32) >> 6;
-                uint js = gs >> 4;
-                uint fs = gs & 0xF;
-                uint jt = gt >> 4;
-                uint ft = gt & 0x0F;
-                uint w11 = (fs * ft + 8) >> 4;
-                uint w10 = ft - w11;
-                uint w01 = fs - w11;
-                uint w00 = 16 - fs - ft + w11;
-                uvec4 w = uvec4(w00, w01, w10, w11);
-                uint v0 = jt * size.x + js;
+                const uint cs = Ds * s;
+                const uint ct = Dt * t;
+                const uint gs = (cs * (size.x - 1) + 32) >> 6;
+                const uint gt = (ct * (size.y - 1) + 32) >> 6;
+                const uint js = gs >> 4;
+                const uint fs = gs & 0xF;
+                const uint jt = gt >> 4;
+                const uint ft = gt & 0x0F;
+                const uint w11 = (fs * ft + 8) >> 4;
+                const uint w10 = ft - w11;
+                const uint w01 = fs - w11;
+                const uint w00 = 16 - fs - ft + w11;
+                const uvec4 w = uvec4(w00, w01, w10, w11);
+                const uint v0 = jt * size.x + js;
 
                 uvec4 p = uvec4(0);
+
+#define VectorIndicesFromBase(offset_base)                                                         \
+    const uint offset = is_dual_plane ? 2 * offset_base + plane : offset_base;                     \
+
                 if (v0 < area) {
-                    p.x = unquantized[plane][v0];
+                    const uint offset_base = v0;
+                    VectorIndicesFromBase(offset_base);
+                    p.x = unquantized[offset];
                 }
                 if ((v0 + 1) < (area)) {
-                    p.y = unquantized[plane][v0 + 1];
+                    const uint offset_base = v0 + 1;
+                    VectorIndicesFromBase(offset_base);
+                    p.y = unquantized[offset];
                 }
                 if ((v0 + size.x) < (area)) {
-                    p.z = unquantized[plane][(v0 + size.x)];
+                    const uint offset_base = v0 + size.x;
+                    VectorIndicesFromBase(offset_base);
+                    p.z = unquantized[offset];
                 }
                 if ((v0 + size.x + 1) < (area)) {
-                    p.w = unquantized[plane][(v0 + size.x + 1)];
+                    const uint offset_base = v0 + size.x + 1;
+                    VectorIndicesFromBase(offset_base);
+                    p.w = unquantized[offset];
                 }
-                unquantized_texel_weights[plane][t * block_dims.x + s] = (uint(dot(p, w)) + 8) >> 4;
+                unquantized_texel_weights[plane * 144 + t * block_dims.x + s] = (uint(dot(p, w)) + 8) >> 4;
             }
         }
     }
@@ -1208,7 +1208,8 @@ void DecompressBlock(ivec3 coord) {
     texel_flag = true; // use texel "vector" and bit stream in integer decoding
     DecodeIntegerSequence(params.max_weight, GetNumWeightValues(params.size, params.dual_plane));
 
-    UnquantizeTexelWeights(params.dual_plane, params.size);
+    uint unquantized_texel_weights[2 * 144];
+    UnquantizeTexelWeights(params.dual_plane, params.size, unquantized_texel_weights);
 
     for (uint j = 0; j < block_dims.y; j++) {
         for (uint i = 0; i < block_dims.x; i++) {
@@ -1220,10 +1221,10 @@ void DecompressBlock(ivec3 coord) {
             const uvec4 C0 = ReplicateByteTo16(endpoints[local_partition][0]);
             const uvec4 C1 = ReplicateByteTo16(endpoints[local_partition][1]);
             const uint weight_offset = (j * block_dims.x + i);
-            const uint primary_weight = unquantized_texel_weights[weight_offset][0];
+            const uint primary_weight = unquantized_texel_weights[weight_offset];
             uvec4 weight_vec = uvec4(primary_weight);
             if (params.dual_plane) {
-                const uint secondary_weight = unquantized_texel_weights[weight_offset][1];
+                const uint secondary_weight = unquantized_texel_weights[weight_offset + 144];
                 for (uint c = 0; c < 4; c++) {
                     const bool is_secondary = ((plane_index + 1u) & 3u) == c;
                     weight_vec[c] = is_secondary ? secondary_weight : primary_weight;