Files
2026-02-21 16:45:37 +08:00

74 lines
1.3 KiB
C#

using System;
using System.Collections.Generic;
namespace Mtree
{
/// <summary>
/// Samples indices in proportion to a fixed set of non-negative weights using
/// Vose's alias method: O(n) table construction, O(1) work per sample.
/// </summary>
public class WeightedRandom
{
    // probs[i]: probability of keeping column i once it has been picked uniformly.
    private readonly double[] probs;

    // alias[i]: index returned instead of i when column i is not kept.
    private readonly int[] alias;

    /// <summary>
    /// Builds the alias table for <paramref name="weights"/>.
    /// </summary>
    /// <param name="weights">
    /// Relative, non-negative weights. They do not need to sum to 1; they are
    /// normalised by their total. If every weight is zero, all indices are
    /// treated as equally likely.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="weights"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="weights"/> is empty, or contains a negative, NaN or infinite value.
    /// </exception>
    public WeightedRandom(float[] weights)
    {
        if (weights == null)
        {
            throw new ArgumentNullException(nameof(weights));
        }
        if (weights.Length == 0)
        {
            throw new ArgumentException("At least one weight is required.", nameof(weights));
        }

        double[] scaled = ScaleToMeanOne(weights);
        probs = new double[scaled.Length];
        alias = new int[scaled.Length];
        BuildAliasTable(scaled, probs, alias);
    }

    /// <summary>
    /// Validates <paramref name="weights"/> and returns them rescaled so that they
    /// sum to the element count (mean 1), the form the alias construction expects.
    /// </summary>
    private static double[] ScaleToMeanOne(float[] weights)
    {
        double total = 0.0;
        for (int i = 0; i < weights.Length; i++)
        {
            float w = weights[i];
            if (float.IsNaN(w) || float.IsInfinity(w) || w < 0f)
            {
                throw new ArgumentException(
                    $"Weight at index {i} must be a finite, non-negative number but was {w}.",
                    nameof(weights));
            }
            total += w;
        }

        int count = weights.Length;
        double[] scaled = new double[count];
        for (int i = 0; i < count; i++)
        {
            // All-zero weights express no preference: fall back to a uniform distribution
            // rather than dividing by zero. Computed in double so large weights cannot overflow.
            scaled[i] = total > 0.0 ? (double)weights[i] * count / total : 1.0;
        }
        return scaled;
    }

    /// <summary>
    /// Fills <paramref name="keepProbs"/> and <paramref name="aliases"/> from weights
    /// that sum to their count. <paramref name="scaled"/> is consumed (modified in place).
    /// </summary>
    private static void BuildAliasTable(double[] scaled, double[] keepProbs, int[] aliases)
    {
        var small = new Queue<int>();
        var large = new Queue<int>();
        for (int i = 0; i < scaled.Length; i++)
        {
            if (scaled[i] < 1.0)
            {
                small.Enqueue(i);
            }
            else
            {
                large.Enqueue(i);
            }
        }

        while (small.Count > 0 && large.Count > 0)
        {
            int less = small.Dequeue();
            int more = large.Dequeue();
            keepProbs[less] = scaled[less];
            aliases[less] = more;

            // Column `less` is topped up to 1 with mass donated by `more`.
            // Adding before subtracting limits rounding error.
            scaled[more] = (scaled[more] + scaled[less]) - 1.0;
            if (scaled[more] < 1.0)
            {
                small.Enqueue(more);
            }
            else
            {
                large.Enqueue(more);
            }
        }

        // Every remaining column should hold exactly 1. Rounding error can leave one
        // in either queue, so treat all leftovers as always kept (alias is never consulted).
        while (large.Count > 0)
        {
            keepProbs[large.Dequeue()] = 1.0;
        }
        while (small.Count > 0)
        {
            keepProbs[small.Dequeue()] = 1.0;
        }
    }

    /// <summary>
    /// Draws one index in [0, weight count) with probability proportional to its weight.
    /// </summary>
    /// <param name="random">Source of randomness; two uniform draws are consumed per call.</param>
    /// <returns>The sampled index.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="random"/> is null.</exception>
    public int GetRandomIndex(Random random)
    {
        if (random == null)
        {
            throw new ArgumentNullException(nameof(random));
        }

        // Pick a column uniformly, then flip a biased coin between it and its alias.
        int column = random.Next(0, probs.Length);
        return random.NextDouble() < probs[column] ? column : alias[column];
    }
}
}