// Returns the smallest power of two >= x (returns 1 for x == 0 or x == 1).
// Host-side helper used to pad scan sizes up to a power of two, as the
// Blelloch scan kernel requires.
// Fix: the previous "shift until >= x" loop spun forever when x > 2^31
// (power wrapped to 0). The bit-smearing form below terminates for every
// input in O(log bits).
// NOTE(review): the result is returned as int to keep the existing callers'
// interface; for x > 2^30 the value 2^31 does not fit in a signed int —
// callers pass element counts far below that in practice.
int find_next_power_of_2(unsigned int x){
    if (x <= 1) return 1;
    unsigned int v = x - 1;     // so that exact powers of two map to themselves
    v |= v >> 1;                // smear the highest set bit downward...
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;               // ...until all lower bits are set
    return (int)(v + 1);        // +1 rolls over to the next power of two
}
// Block-level Blelloch exclusive prefix sum (scan), operating directly on
// global memory (d_out), used as the scan stage of a radix sort.
//
// Per block, scans either:
//   - the input values themselves (current_bit == -1), or
//   - the predicate "bit `current_bit` of d_in[i] is 0" (1 if the bit is 0).
// d_out receives the per-block EXCLUSIVE scan; d_block_sums[blockIdx.x]
// receives the total sum of the block (exclusive last element + last value).
//
// Preconditions (from the stride arithmetic below):
//   - blockDim.x must be a power of two;
//   - the grid must cover the input padded up to a power of two, so that
//     threads with global_pos >= nElems zero out the padding region
//     (d_out must be allocated to that padded size).
__global__ void cumulative_sum(unsigned int* const d_in,
                               unsigned int * d_out,
                               unsigned int * d_block_sums,
                               int nElems,
                               int current_bit){
    // Sum-sum reduce
    int tidx = threadIdx.x;
    int block_pos = blockDim.x * blockIdx.x;
    int global_pos = tidx + block_pos;
    if (global_pos >= nElems) {
        d_out[global_pos] = 0; // We are in the padding region, artificially zeroed to reach a power of two.
    } else {
        if (current_bit == -1){ // -1 selects the plain scan of the input values.
            d_out[global_pos] = d_in[global_pos];
        } else { // Compute the compact predicate: 1 when the current_bit-th bit is 0, else 0.
            d_out[global_pos] = 1 - ((d_in[global_pos] & (1 << current_bit)) >> current_bit); 
        }
    }
    __syncthreads();
    int last = 0;  // The scan is exclusive, so remember the last element to add it back into the block total at the end.
    if (tidx == blockDim.x - 1){  // We work per block, not over the whole array.
        last = d_out[global_pos];
    }
    // Upsweep: pairwise reduce at doubling strides. The final step
    // (i == blockDim.x) is deliberately skipped: it would only set the root
    // (last element), which is overwritten with 0 before the downsweep.
    for (int i = 2; i < blockDim.x; i <<= 1){
        if (tidx % i == i - 1) {
            d_out[global_pos] += d_out[global_pos - (i >> 1)];
        }
        __syncthreads();
    }
    // Downsweep
    if (tidx == blockDim.x - 1){
        d_out[global_pos] = 0;  // Identity at the root (exclusive scan).
    }
    __syncthreads();
    for (int i = blockDim.x; i > 0; i >>= 1){
        // Each thread snapshots its value BEFORE the barrier so the
        // parent/left-child swap below needs no extra temporary storage.
        int current = d_out[global_pos];
        __syncthreads();
        int next_shift = i >> 1;
        if (tidx % i == i - 1) {
            // Parent node: push its old value down to the left child.
            d_out[global_pos - next_shift] = current;
        } else if ((tidx + next_shift) % i == i - 1) {
            // Left child: add its old value into the parent position.
            d_out[global_pos + next_shift] += current;
        }
        __syncthreads();
    }
    if (tidx == blockDim.x - 1){  // Publish the block's total sum (exclusive last entry + saved last value).
        d_block_sums[blockIdx.x] = d_out[global_pos] + last;
    }
}
// Final stage of the two-level scan: add each block's starting offset
// (the scanned block total d_block_sum[blockIdx.x]) to every element of
// that block in d_scan. Threads past nElems do nothing.
__global__ void add_block_sum(unsigned int* d_scan,
                              unsigned int* d_block_sum,
                              int nElems) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < nElems) {
        d_scan[idx] += d_block_sum[blockIdx.x];
    }
}
// Exclusive scan of d_inputVals into d_scan (sized to the next power of two
// above numElems). current_bit == -1 scans the raw values; otherwise scans
// the predicate "bit current_bit is 0" (see cumulative_sum), producing the
// scatter offsets for radix-sort compaction.
//
// Strategy: (1) per-block Blelloch scan, collecting block totals;
// (2) single-block scan of the block totals; (3) add each block's offset
// back into its elements.
void compact(unsigned int* const d_inputVals,
             unsigned int* const d_scan,
             const size_t numElems,
             int current_bit){
    unsigned int threads_per_block = 1024;
    // Pad the scan size up to a power of two (required by the Blelloch scan).
    int scan_size = find_next_power_of_2(numElems);
    // Fix: with numElems < 1024 the old code computed scan_size / 1024 == 0
    // blocks, which is an invalid launch configuration. Shrink the block to
    // scan_size instead (still a power of two, as cumulative_sum requires),
    // so we always launch at least one block.
    if ((unsigned int) scan_size < threads_per_block) {
        threads_per_block = (unsigned int) scan_size;
    }
    unsigned int n_blocks_scan = scan_size / threads_per_block;
    int block_sum_size = find_next_power_of_2(n_blocks_scan);
    // NOTE(review): the block totals are scanned by ONE block below, so this
    // supports at most 1024 scan blocks (numElems <= 1024 * 1024); larger
    // inputs would need a third scan level.
    unsigned int* d_block_sum;
    checkCudaErrors(cudaMalloc((void**) &d_block_sum,  block_sum_size * sizeof(unsigned int)));
    // Fix: cudaMemset takes the device pointer itself (void*); the previous
    // (void**) casts happened to work but were misleading.
    checkCudaErrors(cudaMemset(d_block_sum, 0, block_sum_size * sizeof(unsigned int)));
    // Dummy block-sum output for the second-level scan; never read, but the
    // kernel needs somewhere to write its single block total.
    unsigned int* d_block_sum_dummy;
    checkCudaErrors(cudaMalloc((void**) &d_block_sum_dummy,  sizeof(unsigned int)));
    checkCudaErrors(cudaMemset(d_block_sum_dummy, 0, sizeof(unsigned int)));
    cudaDeviceSynchronize(); checkCudaErrors(cudaGetLastError());
    // (1) Per-block exclusive scan; also fills d_block_sum with block totals.
    cumulative_sum<<<n_blocks_scan, threads_per_block>>>(d_inputVals, d_scan, d_block_sum, numElems, current_bit);
    cudaDeviceSynchronize(); checkCudaErrors(cudaGetLastError());
    // (2) Exclusive scan of the block totals, in place (-1 = plain scan).
    cumulative_sum<<<1, block_sum_size>>>(d_block_sum, d_block_sum, d_block_sum_dummy, block_sum_size, -1);
    cudaDeviceSynchronize(); checkCudaErrors(cudaGetLastError());
    // (3) Add each block's starting offset back into its scanned elements.
    add_block_sum<<<n_blocks_scan, threads_per_block>>>(d_scan, d_block_sum, numElems);
    cudaDeviceSynchronize(); checkCudaErrors(cudaGetLastError());
    checkCudaErrors(cudaFree(d_block_sum));
    checkCudaErrors(cudaFree(d_block_sum_dummy));
}
void your_sort(unsigned int* const d_inputVals,
               unsigned int* const d_inputPos,
               unsigned int* const d_outputVals,
               unsigned int* const d_outputPos,
               size_t numElems)
{
    unsigned int threads_per_block = 1024;
    unsigned int n_blocks = numElems / threads_per_block + 1;
    printf("numElems: %d\nNum blocks: %d\n", (int) numElems, n_blocks);
    unsigned int numBins = 2;
    unsigned int *d_hist;
    unsigned int  h_hist[2];
    checkCudaErrors(cudaMalloc((void**) &d_hist,  numBins * sizeof(unsigned int)));
    int scan_size = find_next_power_of_2(numElems);
    unsigned int* d_scanned;
    checkCudaErrors(cudaMalloc(&d_scanned, scan_size * sizeof(unsigned int)));
    for (int i = 0; i < sizeof(unsigned int) * 8; i++){
        checkCudaErrors(cudaMemset((void**) d_hist, 0, numBins * sizeof(unsigned int)));
        histogram_kernel<<<n_blocks, threads_per_block>>>(i, d_hist, d_inputVals, numElems);
        cudaDeviceSynchronize(); checkCudaErrors(cudaGetLastError());
        compact(d_inputVals, d_scanned, numElems, i);
        cudaDeviceSynchronize(); checkCudaErrors(cudaGetLastError());
        checkCudaErrors(cudaMemcpy(h_hist, d_hist, 2 * sizeof(unsigned int), cudaMemcpyDeviceToHost));
	cudaDeviceSynchronize(); checkCudaErrors(cudaGetLastError());
    } ...