Daniel Lemire's blog


Sampling efficiently from groups

Suppose that you have to sample a student at random in a school. However, you cannot go into a classroom and just pick a student. All you are allowed to do is to pick a classroom, and then let the teacher pick a student at random. The student is then removed from the classroom. You may then have to pick a student at random, again, but with a classroom that is now missing a student. Importantly, I care about the order in which the students were picked. But I also have a limited memory: there are too many students, and I can’t remember too many. I can barely remember how many students I have per classroom.

If the classrooms contain a variable number of students, you cannot just pick classrooms at random. If you do so, students in classrooms with lots of students are going to be less likely to be picked.
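To make the bias concrete, here is a small sketch with two classrooms, one holding a single student and one holding nine. The class and method names are hypothetical and only serve the illustration. If you pick a classroom uniformly and then a student uniformly within it, the lone student is selected half of the time, while each of the nine others is selected with probability 1/18. A fair draw would give everyone 1/10.

```java
// Hypothetical illustration: the class and method names are not from the post.
final class UniformClassroomBias {
    // Probability that one particular student of classroom `room` is selected
    // when a classroom is chosen uniformly at random, and then a student is
    // chosen uniformly at random within that classroom.
    static double perStudentProbability(int[] counts, int room) {
        return 1.0 / counts.length / counts[room];
    }

    public static void main(String[] args) {
        int[] counts = {1, 9}; // one tiny classroom, one large classroom
        // the lone student in the small classroom
        System.out.println(perStudentProbability(counts, 0));
        // each student in the large classroom
        System.out.println(perStudentProbability(counts, 1));
        // what a fair draw over all 10 students would give
        System.out.println(1.0 / 10);
    }
}
```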

So instead, you can do it as follows. If there are N students in total, pick a number x between 1 and N. If x is smaller than or equal to the number of students in the first classroom, pick the first classroom and stop. Otherwise, if x is smaller than or equal to the number of students in the first two classrooms, pick the second classroom, and so forth. After you remove the student, you must account for the fact that both the total (N) and the number of students in that classroom have decreased by one.

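    // r is a java.util.Random; runninghisto[z] is the number of students in classroom z
    while(N > 0) {
        int y = r.nextInt(N); // random integer in [0,N)
        int runningsum = 0;
        int z = 0;
        for(; z < runninghisto.length; z++) {
            runningsum += runninghisto[z];
            if(y < runningsum) { break; }
        }
        // found z: classroom z is picked
        runninghisto[z] -= 1; // the student leaves the classroom
        N -= 1;
    }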

If you have many classrooms, this naive algorithm can be slow because each time you pick a classroom, you may need to iterate over all of the classroom counts. Suppose you have K classrooms: you may need about N K operations to solve the problem.

Importantly, the problem is dynamic: the number of students in each classroom is not fixed. You cannot simply precompute some kind of map once and for all.

Can you do better than the naive algorithm? I can craft a “logarithmic” algorithm, one where the number of operations per pick is proportional to the logarithm of the number of classrooms. I believe that my algorithm is efficient in practice.

So divide your classrooms into two sets of classrooms containing R and S students respectively, so that R + S = N. Pick an integer in [0,N). If it is smaller than R, decrement R and pick a classroom in the corresponding set. If it is larger than or equal to R, subtract R from it and go with the other set of classrooms. So we have cut the size of the problem in half. We can repeat the same trick recursively: divide each of the two subsets of classrooms, and so forth. Instead of merely maintaining the total number of students in all of the classrooms, you keep track of the counts in a tree-like manner.

We can operationalize this strategy efficiently in software:

  • Build an array containing the number of students in each classroom.
  • Replace every two counts by the sum of the two counts: instead of storing the number of students in the second classroom, store the number of students in the first two classrooms.
  • Replace every four counts by the sum of the four counts: where the count of the fourth class was, store the sum of the number of students in the first four classrooms.
  • And so forth.

In Java, it might look like this:

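    int level = 0;
    for(; (1<<level) < runninghisto.length; level++) {
        // z is the last index of a block of 2^level classrooms;
        // fold its sum into the last index of the next block
        for(int z = (1<<level) - 1;
            z + (1<<level) < runninghisto.length;
            z += 2*(1<<level)) {
            runninghisto[z + (1<<level)] += runninghisto[z];
        }
    }
    int maxlevel = level - 1; // level at which the search starts

You effectively have a tree of prefixes which you can then use to find the right classroom without visiting all of the classroom counts. You visit a tree with a height that is the logarithm of the number of classrooms. At each level, you may have to decrement a counter.

    while(N > 0) {
        int y = r.nextInt(N); // random integer in [0,N)
        // select in logarithmic time
        level = maxlevel;
        int offset = 0;
        for(; level >= 0; level -= 1) {
            // entry holding the student count of the left half of the current range
            int left = offset + (1<<level) - 1;
            if(y >= runninghisto[left]) {
                y -= runninghisto[left];
                offset += (1<<level); // go right
            } else {
                runninghisto[left] -= 1; // go left: that half loses a student
            }
        }
        // found offset: classroom offset is picked
        N -= 1;
    }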

I have written a Java demonstration of this idea. For simplicity, I assume that the number of classrooms is a power of two, but this limitation could be lifted by doing a bit more work. Instead of N K operations, you can now do only N log K operations.
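For readers who want something they can run directly, here is a minimal self-contained sketch of the tree-based procedure. It is not the demonstration code mentioned above: the class name `TreeSampler` and the method `sample` are hypothetical. It also makes one assumption of its own to handle a number of classrooms that is not a power of two: it pads the array with empty classrooms. This is only one possible way to lift the limitation, and it costs some extra memory.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch: the class and method names are illustrative only.
final class TreeSampler {
    // Returns the classroom index of each pick, in the order picked.
    static int[] sample(int[] counts, Random r) {
        // Assumption: pad with empty classrooms up to a power of two.
        int K = 1;
        while (K < counts.length) { K *= 2; }
        int[] tree = Arrays.copyOf(counts, K);
        // Build: the last entry of each block of 2^(level+1) classrooms
        // absorbs the sum stored at the end of the block's left half.
        int level = 0;
        for (; (1 << level) < K; level++) {
            for (int z = (1 << level) - 1; z + (1 << level) < K; z += 2 << level) {
                tree[z + (1 << level)] += tree[z];
            }
        }
        int maxlevel = level - 1;
        int N = tree[K - 1]; // the last entry holds the grand total
        int[] picks = new int[N];
        for (int i = 0; N > 0; i++, N--) {
            int y = r.nextInt(N); // random integer in [0,N)
            int offset = 0;
            for (level = maxlevel; level >= 0; level--) {
                int left = offset + (1 << level) - 1; // count for the left half
                if (y >= tree[left]) {
                    y -= tree[left];
                    offset += 1 << level; // go right
                } else {
                    tree[left] -= 1; // go left: that half loses a student
                }
            }
            picks[i] = offset;
        }
        return picks;
    }
}
```

Calling `TreeSampler.sample(new int[]{3, 1, 4, 1, 5}, new Random())` returns 14 classroom indices, and each classroom appears exactly as many times as it has students. The padded classrooms are empty, so they are never selected.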

Importantly, the second approach uses no more memory than the first one. All you need is enough memory to maintain the histogram.

My Java code includes a benchmark.

Update. If you are willing to use batch processing and a little bit of extra memory, you might be able to do even better in some cases. Go back to the original algorithm. Instead of picking one integer in [0,N), pick C distinct integers, for some number C. Then find out in which classroom each of these C integers falls, as before, decrementing the counts. If you pick the C integers in sorted order, a simple sequential scan should be enough. For C small compared to N, you can pick the C integers in linear time using effective reservoir sampling. Then you can shuffle the result again using Knuth’s random shuffle, in linear time. So you get to do about N K / C operations. (Credit: Travis Downs and Richard Startin.)
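The batch idea might be sketched as follows. The names `BatchSampler` and `nextBatch` are hypothetical. To keep the example short, I assume a sorted set filled by rejection instead of reservoir sampling, so picking the integers costs about C log C operations rather than linear time. That is reasonable only when C is small compared to N. The shuffle relies on `Collections.shuffle`.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.TreeSet;

// Hypothetical sketch: the class and method names are illustrative only.
final class BatchSampler {
    // Draws up to C students in a single scan over the counts and returns
    // their classrooms in random order. counts is updated in place; the
    // caller is responsible for reducing N by the size of the batch.
    static List<Integer> nextBatch(int[] counts, int N, int C, Random r) {
        C = Math.min(C, N);
        // Assumption: rejection into a sorted set stands in for reservoir sampling.
        TreeSet<Integer> chosen = new TreeSet<>();
        while (chosen.size() < C) { chosen.add(r.nextInt(N)); }
        List<Integer> rooms = new ArrayList<>(C);
        int z = -1;
        int runningsum = 0;
        for (int y : chosen) { // visited in ascending order
            // advance until classroom z covers the integer y
            while (y >= runningsum) { z++; runningsum += counts[z]; }
            rooms.add(z);
        }
        // remove the picked students only after the scan is complete
        for (int room : rooms) { counts[room] -= 1; }
        Collections.shuffle(rooms, r); // random order within the batch
        return rooms;
    }
}
```

You would call `nextBatch` repeatedly, subtracting the size of each returned batch from N, until N reaches zero.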

Further reading: Hierarchical bin buffering: Online local moments for dynamic external memory arrays, ACM Transactions on Algorithms.

Follow-up: Sampling Whisky by Richard Startin