/*
 * Decompiled with CFR 0.152.
 */
package me.jellysquid.mods.lithium.mixin.chunk.block_counting;

import me.jellysquid.mods.lithium.common.block.BlockCountingSection;
import me.jellysquid.mods.lithium.common.block.BlockStateFlagHolder;
import me.jellysquid.mods.lithium.common.block.BlockStateFlags;
import me.jellysquid.mods.lithium.common.block.TrackedBlockStatePredicate;
import net.minecraft.class_2540;
import net.minecraft.class_2680;
import net.minecraft.class_2826;
import net.minecraft.class_2841;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Shadow;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.Redirect;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;
import org.spongepowered.asm.mixin.injection.callback.LocalCapture;

@Mixin(value={class_2826.class})
public abstract class ChunkSectionMixin
implements BlockCountingSection {
    @Unique
    private short[] countsByFlag = new short[BlockStateFlags.NUM_FLAGS];

    @Shadow
    public abstract void method_12253();

    @Override
    public boolean anyMatch(TrackedBlockStatePredicate trackedBlockStatePredicate) {
        return this.countsByFlag[trackedBlockStatePredicate.getIndex()] != 0;
    }

    @Redirect(method={"calculateCounts()V"}, at=@At(value="INVOKE", target="Lnet/minecraft/world/chunk/PalettedContainer;count(Lnet/minecraft/world/chunk/PalettedContainer$Counter;)V"))
    private void initFlagCounters(class_2841<class_2680> palettedContainer, class_2841.class_4464<class_2680> consumer) {
        palettedContainer.method_21732((state, count) -> {
            consumer.accept(state, count);
            int flags = ((BlockStateFlagHolder)state).getAllFlags();
            int size = this.countsByFlag.length;
            for (int i = 0; i < size && flags != 0; flags >>>= 1, ++i) {
                if ((flags & 1) == 0) continue;
                int n = i;
                this.countsByFlag[n] = (short)(this.countsByFlag[n] + count);
            }
        });
    }

    @Inject(method={"calculateCounts()V"}, at={@At(value="HEAD")})
    private void resetFlagCounters(CallbackInfo ci) {
        this.countsByFlag = new short[BlockStateFlags.NUM_FLAGS];
    }

    @Inject(method={"setBlockState(IIILnet/minecraft/block/BlockState;Z)Lnet/minecraft/block/BlockState;"}, at={@At(value="INVOKE", target="Lnet/minecraft/block/BlockState;getFluidState()Lnet/minecraft/fluid/FluidState;", ordinal=0, shift=At.Shift.BEFORE)}, locals=LocalCapture.CAPTURE_FAILHARD)
    private void updateFlagCounters(int x, int y, int z, class_2680 newState, boolean lock, CallbackInfoReturnable<class_2680> cir, class_2680 oldState) {
        int i;
        int prevFlags = ((BlockStateFlagHolder)oldState).getAllFlags();
        int flags = ((BlockStateFlagHolder)newState).getAllFlags();
        int flagsXOR = prevFlags ^ flags;
        while ((i = Integer.numberOfTrailingZeros(flagsXOR)) < 32) {
            int n = i;
            this.countsByFlag[n] = (short)(this.countsByFlag[n] + (1 - ((prevFlags >>> i & 1) << 1)));
            flagsXOR &= ~(1 << i);
        }
    }

    @Inject(method={"fromPacket(Lnet/minecraft/network/PacketByteBuf;)V"}, at={@At(value="RETURN")})
    private void initCounts(class_2540 packetByteBuf, CallbackInfo ci) {
        this.method_12253();
    }
}

