Files
Fishing2/Assets/Obi/Scripts/Common/Backends/Compute/DataStructures/ComputePrefixSum.cs
2025-11-10 00:08:26 +08:00

82 lines
3.1 KiB
C#

using System.Collections;
using System.Collections.Generic;
using Obi;
using UnityEngine;
namespace Obi
{
    /// <summary>
    /// GPU prefix sum (scan) over a <see cref="GraphicsBuffer"/> with a fixed element count,
    /// driven by the "Compute/Scan" compute shader loaded from Resources.
    /// </summary>
    /// <remarks>
    /// The scan is hierarchical:
    /// <list type="number">
    /// <item>The input is split into buckets of <c>threadsPerGroup</c> elements.</item>
    /// <item>Each bucket is scanned in one thread group, and the bucket totals are written to a per-level block-sum buffer.</item>
    /// <item>The block sums are scanned recursively.</item>
    /// <item>The scanned block sums are added back onto every bucket of the result.</item>
    /// </list>
    /// Judging by the kernel name "ScanInBucketExclusive", the scan is presumably exclusive.
    /// NOTE(review): confirm against the shader.
    /// Instances own GPU buffers; call <see cref="Dispose"/> when finished.
    /// </remarks>
    public class ComputePrefixSum : System.IDisposable
    {
        // Elements per bucket = thread groups are sized as ceil(count / threadsPerGroup).
        // NOTE(review): presumably must match [numthreads] in Compute/Scan — verify.
        private const int threadsPerGroup = 512;

        private readonly ComputeShader scanShader;
        private readonly int scanInBucketKernel;
        private readonly int scanAddBucketResult;

        // One entry per recursion level of Sum():
        // - blockSums[i] receives the per-bucket totals written while scanning level i.
        // - prefixBlockSums[i] receives the scan of blockSums[i].
        private readonly List<GraphicsBuffer> blockSums = new List<GraphicsBuffer>();
        private readonly List<GraphicsBuffer> prefixBlockSums = new List<GraphicsBuffer>();

        // Element count this instance was sized for; Sum() only accepts inputs of this count.
        private readonly int inputSize;

        /// <summary>
        /// Loads the scan shader and allocates the auxiliary block-sum buffers needed to scan
        /// buffers of exactly <paramref name="inputSize"/> elements.
        /// </summary>
        /// <param name="inputSize">Element count of the buffers later passed to <see cref="Sum(GraphicsBuffer, GraphicsBuffer)"/>.</param>
        /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="inputSize"/> is negative.</exception>
        /// <exception cref="System.InvalidOperationException">The "Compute/Scan" compute shader could not be loaded from Resources.</exception>
        public ComputePrefixSum(int inputSize)
        {
            if (inputSize < 0)
            {
                throw new System.ArgumentOutOfRangeException(nameof(inputSize), inputSize, "Input size must not be negative.");
            }

            scanShader = Resources.Load<ComputeShader>("Compute/Scan");

            // Unity objects overload ==, so compare with null explicitly rather than using `is null`.
            if (scanShader == null)
            {
                throw new System.InvalidOperationException("Could not load compute shader 'Compute/Scan' from Resources.");
            }

            scanInBucketKernel = scanShader.FindKernel("ScanInBucketExclusive");
            scanAddBucketResult = scanShader.FindKernel("ScanAddBucketResult");
            this.inputSize = inputSize;

            // Allocate one level of auxiliary buffers per recursion step of Sum().
            // Level 0 is always allocated, because the first dispatch binds blockSums[0] even
            // when the whole input fits in a single bucket (inputSize <= 1).
            // Each level is clamped to at least one element, since a zero-length
            // GraphicsBuffer cannot be created.
            int levelSize = inputSize;
            do
            {
                levelSize = System.Math.Max(1, (levelSize + threadsPerGroup - 1) / threadsPerGroup);

                // 4-byte stride: one 32-bit value per bucket.
                blockSums.Add(new GraphicsBuffer(GraphicsBuffer.Target.Structured, levelSize, 4));
                prefixBlockSums.Add(new GraphicsBuffer(GraphicsBuffer.Target.Structured, levelSize, 4));
            }
            while (levelSize > 1);
        }

        /// <summary>
        /// Releases all auxiliary GPU buffers.
        /// Safe to call more than once. The instance cannot run <see cref="Sum(GraphicsBuffer, GraphicsBuffer)"/> afterwards.
        /// </summary>
        public void Dispose()
        {
            ReleaseAll(blockSums);
            ReleaseAll(prefixBlockSums);
        }

        // Disposes every still-valid buffer in the list, then empties it.
        private static void ReleaseAll(List<GraphicsBuffer> buffers)
        {
            foreach (var buffer in buffers)
            {
                if (buffer != null && buffer.IsValid())
                {
                    buffer.Dispose();
                }
            }
            buffers.Clear();
        }

        /// <summary>
        /// Scans <paramref name="input"/> into <paramref name="result"/> by dispatching the Scan
        /// kernels. No data is read back to the CPU here.
        /// </summary>
        /// <param name="input">Buffer to scan; must contain exactly the element count given to the constructor.</param>
        /// <param name="result">Buffer that receives the scanned values.</param>
        /// <remarks>
        /// Does nothing when <c>input.count</c> differs from the constructor's size. The auxiliary
        /// buffers were allocated for that size only.
        /// </remarks>
        /// <exception cref="System.ObjectDisposedException">Called after <see cref="Dispose"/>.</exception>
        public void Sum(GraphicsBuffer input, GraphicsBuffer result)
        {
            // The constructor always allocates at least one level, so an empty list means Dispose() ran.
            if (blockSums.Count == 0)
            {
                throw new System.ObjectDisposedException(nameof(ComputePrefixSum));
            }

            if (input.count != inputSize)
            {
                return;
            }

            Sum(input, result, input.count, 0);
        }

        // Scans `count` elements of `input` into `result`, using the auxiliary buffers of `level`.
        // Recurses on the block sums until they fit in a single thread group.
        private void Sum(GraphicsBuffer input, GraphicsBuffer result, int count, int level)
        {
            int groups = (count + threadsPerGroup - 1) / threadsPerGroup;

            // Scan each bucket independently and write one total per bucket into blockSums[level].
            scanShader.SetInt("count", count);
            scanShader.SetBuffer(scanInBucketKernel, "_Input", input);
            scanShader.SetBuffer(scanInBucketKernel, "_Result", result);
            scanShader.SetBuffer(scanInBucketKernel, "_BlockSum", blockSums[level]);
            scanShader.Dispatch(scanInBucketKernel, groups, 1, 1);

            // A single bucket is already fully scanned.
            if (groups <= 1)
            {
                return;
            }

            // Recursively scan the per-bucket totals into prefixBlockSums[level].
            Sum(blockSums[level], prefixBlockSums[level], groups, level + 1);

            // Add each bucket's scanned offset onto its elements of the result.
            // "count" must be set again: the recursive call overwrote it with a smaller value.
            scanShader.SetInt("count", count);
            scanShader.SetBuffer(scanAddBucketResult, "_Input", prefixBlockSums[level]);
            scanShader.SetBuffer(scanAddBucketResult, "_Result", result);
            scanShader.Dispatch(scanAddBucketResult, groups, 1, 1);
        }
    }
}