module Shuffling

# Experiment based on
# D. Bayer, P. Diaconis: Trailing the Dovetail Shuffle to its Lair
#   The Annals of Applied Probability, Vol. 2(2), pp. 294-313, 1992.

"""
    cut(deck)

Cut `deck` at a uniformly random point: draw `i` from `1:length(deck)` and
return the cards below position `i` followed by the top `i` cards. When `i`
equals the deck length the order is unchanged. An empty deck yields an empty
copy. `deck` itself is not modified.
"""
function cut(deck)
    # rand(1:0) would throw, so an empty deck is handled explicitly
    isempty(deck) && return copy(deck)
    n = length(deck)
    i = rand(1:n)
    # Rotating left by i moves the top i cards to the bottom of the deck
    return circshift(deck, -i)
end

"""
    riffle_old(deck; cut_error = 3, riffle_error = 2)

Earlier riffle-shuffle model, kept for comparison with `riffle`.

Cut `deck` at a uniformly random point within `cut_error` cards of the middle
(restricted to valid cut points `0:length(deck)`). Then rebuild the deck from the
top in rounds. Each round drops a packet of `1:riffle_error+1` cards from the top
portion and then a packet from the bottom portion, until both portions are
exhausted. Returns a new vector; `deck` is not modified.

Throws `ArgumentError` if `cut_error` or `riffle_error` is negative.
"""
function riffle_old(deck; cut_error = 3, riffle_error = 2)
    cut_error >= 0 ||
        throw(ArgumentError("cut_error must be non-negative, got $cut_error"))
    riffle_error >= 0 ||
        throw(ArgumentError("riffle_error must be non-negative, got $riffle_error"))
    n = length(deck)
    half = n ÷ 2
    # Restrict the cut to 0:n so decks with fewer than 6 cards stay in bounds
    i = rand(max(0, half - cut_error):min(n, half + cut_error))
    a, b = deck[1:i], deck[i+1:n]
    result = copy(deck)
    i1, i2, index = 1, 1, 1
    while true
        # Packet sizes minus one; shrink to what is left, -1 means exhausted
        c1 = min(rand(0:riffle_error), length(a) - i1)
        c2 = min(rand(0:riffle_error), length(b) - i2)
        c1 < 0 && c2 < 0 && break
        if c1 >= 0
            result[index:index+c1] = a[i1:i1+c1]
            index += c1 + 1
            i1 += c1 + 1
        end
        if c2 >= 0
            result[index:index+c2] = b[i2:i2+c2]
            index += c2 + 1
            i2 += c2 + 1
        end
    end
    return result
end

"""
    riffle(deck; cut_error = 3)

Riffle-shuffle `deck` and return the result as a new vector; `deck` is not
modified.

The deck is cut at a uniformly random point within `cut_error` cards of the middle,
restricted to valid cut points `0:length(deck)`. The result is then filled from the
top one card at a time. The next card comes from the top portion with probability
proportional to the number of cards still left in it. This is the dropping rule of
the Gilbert–Shannon–Reeds model; the cut here is uniform near the middle rather
than binomial.

Throws `ArgumentError` if `cut_error` is negative.
"""
function riffle(deck; cut_error = 3)
    cut_error >= 0 ||
        throw(ArgumentError("cut_error must be non-negative, got $cut_error"))
    n = length(deck)
    half = n ÷ 2
    # Restrict the cut to 0:n so decks with fewer than 6 cards stay in bounds
    i = rand(max(0, half - cut_error):min(n, half + cut_error))
    a, b = deck[1:i], deck[i+1:n]
    result = copy(deck)
    i1, i2 = 1, 1
    for index in 1:n
        left_a = length(a) - i1 + 1
        left_b = length(b) - i2 + 1
        # left_a + left_b == n - index + 1 >= 1, so the division is always defined;
        # an exhausted portion has probability 0 (or its partner probability 1)
        if rand(Float64) < left_a / (left_a + left_b)
            result[index] = a[i1]
            i1 += 1
        else
            result[index] = b[i2]
            i2 += 1
        end
    end
    return result
end

"""
    guess(deck)

Guess which card of `deck` was moved, where `deck` is a permutation of `1:n`.

For each card value `v`, measure the cyclic distance (in positions) back to
card `v-1` and forward to card `v+1`, with values wrapping around `1:n`. The
card whose smaller distance is largest is returned; ties go to the smallest
value. Returns `0` when every score is `0` (e.g. an empty or one-card deck).
Distances are taken modulo `n`, so the result does not change when the deck is cut.

Throws `ArgumentError` if `deck` is not a permutation of `1:n`.
"""
function guess(deck)
    n = length(deck)
    # pos[v] is the position of card v; one inverse permutation replaces
    # three linear `findfirst` scans per card (O(n) instead of O(n^2))
    pos = invperm(collect(deck))
    best_gap, best_card = 0, 0
    for card in 1:n
        here = pos[card]
        d_prev = mod(here - pos[mod1(card - 1, n)], n)
        d_next = mod(pos[mod1(card + 1, n)] - here, n)
        # The paper scores d_prev + d_next - 1; the author found the smaller
        # of the two gaps to work better in this experiment.
        gap = min(d_prev, d_next)
        if gap > best_gap
            best_gap, best_card = gap, card
        end
    end
    return best_card
end

"""
    trick(n = 52)

Run one round of the card-guessing trick on a deck of cards `1:n` and return
`true` if the moved card is identified.

The deck is riffled three times, with a cut between consecutive riffles. Its
top card is then reinserted at a uniformly random position in the central half
(positions `n÷2 - n÷4` to `n÷2 + n÷4`, kept within `1:n`). Finally the deck is
cut once more and `guess` names the card it believes was moved.

Throws `ArgumentError` if `n < 1`.
"""
function trick(n = 52)
    n >= 1 || throw(ArgumentError("deck size must be positive, got $n"))

    # Riffle three times, cutting between riffles
    deck = collect(1:n)
    deck = riffle(deck)
    deck = cut(deck)
    deck = riffle(deck)
    deck = cut(deck)
    deck = riffle(deck)

    # Put the top card at a random position i in the central half
    # (clamped so i is a valid position even for very small decks)
    cut_error = n ÷ 4
    half = n ÷ 2
    i = rand(clamp(half - cut_error, 1, n):clamp(half + cut_error, 1, n))
    card = popfirst!(deck)
    insert!(deck, i, card)

    # Cut a final time and guess
    deck = cut(deck)
    return card == guess(deck)
end

"""
    test(n, decksize = 52)

Estimate the success rate of `trick` by running it `n` times on a deck of
`decksize` cards. Returns the fraction of runs in which the moved card was
guessed correctly.

Throws `ArgumentError` if `n < 1` (the rate would otherwise be `0/0`).
"""
function test(n, decksize = 52)
    n >= 1 || throw(ArgumentError("number of trials must be positive, got $n"))
    successes = count(_ -> trick(decksize), 1:n)
    return successes / n
end

end