Wednesday, January 25, 2017

Solitaire of sorts

Suppose you had a well-shuffled standard deck of 52 cards. From the top of the pack you turned over the cards one at a time. As you turned over the first card you called out 'Ace', then '2' for the next, followed by 3, 4, 5, 6, 7, 8, 9, 10, Jack, Queen, King. On the 14th card you called out 'Ace' again, followed by 2, 3, 4 and so on until you reached the end of the deck.
What would be the probability that none of the cards that were turned over had the value (rank) that you called out?

Before working it out, let's consider a related but simpler question. Suppose you had 3 cards with the values 1, 2, 3. After you shuffled them, what would be the probability that none of the 3 was in the same position as at the start? In mathematical terms, this is the probability of a derangement.
In this case we can look at all 3! (i.e. 6) permutations, and we see that 2 of the 6 (namely 2,3,1 and 3,1,2) are derangements. So the probability of a derangement is 1/3.
You might be tempted to say that, for each card, the probability that it is not in its original position is 2/3, and so for the 3 cards the probability of a derangement is (2/3)^3 = 8/27.
However, the flaw in that method is that the three events are not independent: for example, if cards 1 and 2 have swapped places, card 3 is forced to stay where it was.
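To make the 3-card numbers concrete, here is a small brute-force sketch (the class and method names are illustrative, not from the original post). It counts the orderings of the cards 1..n in which no card sits in its original position, and prints the exact probability next to the naive (2/3)^3 estimate.

```java
// Brute-force check of the 3-card example (illustrative names).
public class Derangements {

    // Counts derangements of 1..n by filling positions one at a time,
    // never allowing card k into position k.
    static int countDerangements(int n) {
        return count(new boolean[n + 1], 1, n);
    }

    // used[k] marks card k as already placed; pos is the next position to fill.
    private static int count(boolean[] used, int pos, int n) {
        if (pos > n) return 1; // all positions filled, none with its own card
        int total = 0;
        for (int card = 1; card <= n; card++) {
            if (used[card] || card == pos) continue; // skip used cards and fixed points
            used[card] = true;
            total += count(used, pos + 1, n);
            used[card] = false;
        }
        return total;
    }

    static long factorial(int n) {
        long f = 1;
        for (int i = 2; i <= n; i++) f *= i;
        return f;
    }

    public static void main(String[] args) {
        int n = 3;
        int d = countDerangements(n);
        System.out.println(d + " of " + factorial(n) + " orderings are derangements"); // 2 of 6
        System.out.println("exact probability: " + (double) d / factorial(n));
        System.out.println("naive (2/3)^3:     " + Math.pow(2.0 / 3.0, n));
    }
}
```

With n = 4 the same code finds 9 of 24 orderings, i.e. 0.375. With only one suit the game is exactly a derangement problem, which is why that matches Prob(1,4) in the comment of the Calculator class below.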

Returning now to the deck of 52 cards with 4 suits and 13 ranks: there is, of course, more than one way to work it out. One method is to write a Monte Carlo simulation. I've done that, and the code is at the end of this post.
Here is a sample output after 100 million simulations:

Found 1623551 survivors out of 1.0E8
So the survival probability is estimated to be: 0.01623551
with a standard deviation of: 1.2638005465673727E-5
Time elapsed: 77.742 seconds


So the probability of surviving to the end without getting any cards right is approximately 1.62%

OK, OK, Monte Carlo is useful, but not very satisfying. You might think that you could just write a program to go through all 52! permutations and work out the exact answer. But alas, 52! (about 8.07 × 10^67) is a very big number. So if you want an answer before you die, it might be a good idea to try a different approach.
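For a sense of scale, the sketch below (names are illustrative) uses java.math.BigInteger to compute 52! exactly, together with 52!/(4!)^13, the number of distinct rank sequences if the suits are ignored. Even the second number is far beyond anything that could be enumerated.

```java
import java.math.BigInteger;

// How many orderings would a brute-force approach have to visit?
public class DeckCount {

    static BigInteger factorial(int n) {
        BigInteger f = BigInteger.ONE;
        for (int i = 2; i <= n; i++) f = f.multiply(BigInteger.valueOf(i));
        return f;
    }

    // Distinct sequences of ranks when suits are ignored: (suits*ranks)! / (suits!)^ranks
    static BigInteger rankSequences(int suits, int ranks) {
        BigInteger denom = factorial(suits).pow(ranks);
        return factorial(suits * ranks).divide(denom);
    }

    public static void main(String[] args) {
        BigInteger all = factorial(52);
        System.out.println("52! = " + all + " (" + all.toString().length() + " digits)");
        BigInteger seq = rankSequences(4, 13);
        System.out.println("rank sequences = " + seq + " (" + seq.toString().length() + " digits)");
    }
}
```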

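Well, there is another way... In fact I'm sure there are many other ways, but here is the method that I used. Think of each position in the deck as a slot labelled with the rank called out there, and for simplicity number both the slot labels and the card ranks 1 to 13. A card must not land in a slot that carries its own rank. At the start we have 13 different ranks, each with 4 cards, and 4 slots available for each rank.
We can represent that as a table counting how many ranks have a given number of slots (rows) and cards (columns) remaining: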
        Cards
Slots   0   1   2   3   4
  0     0   0   0   0   0
  1     0   0   0   0   0
  2     0   0   0   0   0
  3     0   0   0   0   0
  4     0   0   0   0  13
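The table maps naturally onto a small 2-D array. In this sketch (illustrative names, not the post's own code) table[s][c] is the number of ranks with s slots and c cards remaining; the total number of slots is the sum of s·table[s][c] and the total number of cards is the sum of c·table[s][c]. The State class below stores the same counts as one character per cell of a string key.

```java
// table[s][c] = number of ranks that still have s empty slots and c unturned cards.
public class RankTable {

    static int[][] initialTable(int suits, int ranks) {
        int[][] table = new int[suits + 1][suits + 1];
        table[suits][suits] = ranks; // every rank starts with all its slots and all its cards
        return table;
    }

    // Each of the table[s][c] ranks contributes s empty slots.
    static int totalSlots(int[][] table) {
        int sum = 0;
        for (int s = 0; s < table.length; s++)
            for (int c = 0; c < table[s].length; c++)
                sum += s * table[s][c];
        return sum;
    }

    // Each of the table[s][c] ranks contributes c unturned cards.
    static int totalCards(int[][] table) {
        int sum = 0;
        for (int s = 0; s < table.length; s++)
            for (int c = 0; c < table[s].length; c++)
                sum += c * table[s][c];
        return sum;
    }

    public static void main(String[] args) {
        int[][] t = initialTable(4, 13);
        for (int s = 0; s < t.length; s++) {
            StringBuilder row = new StringBuilder("Slots " + s + ":");
            for (int c = 0; c < t[s].length; c++) row.append(String.format("%4d", t[s][c]));
            System.out.println(row);
        }
        System.out.println("slots = " + totalSlots(t) + ", cards = " + totalCards(t)); // slots = 52, cards = 52
    }
}
```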

Now suppose we pick one card. Then we'll have 12 ranks with 4 cards remaining and 1 rank with 3 cards remaining, all of them still with 4 slots.
So in our table of ranks we represent that as:

        Cards
Slots   0   1   2   3   4
  0     0   0   0   0   0
  1     0   0   0   0   0
  2     0   0   0   0   0
  3     0   0   0   0   0
  4     0   0   0   1  12

But we'll have to put the card in one of the available slots. Of the 52 slots, 48 are allowed: the other 4 carry the card's own rank.
When we do that we'll have (1 + 1 + 11 = 13 ranks in total)
 1 rank with 4 slots and 3 cards remaining
 1 rank with 3 slots and 4 cards remaining
11 ranks with 4 slots and 4 cards remaining

In our table we represent that as:

        Cards
Slots   0   1   2   3   4
  0     0   0   0   0   0
  1     0   0   0   0   0
  2     0   0   0   0   0
  3     0   0   0   0   1
  4     0   0   0   1  11
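The two steps just described (turn a card over, then put it in an allowed slot) are each a move of one rank from one table cell to another. Here is a minimal sketch of that bookkeeping on a 2-D array, using illustrative names rather than the post's own code:

```java
// Apply the first card's two moves to the rank table (illustrative sketch).
public class FirstMove {

    // Moves one rank from cell (s, c) to cell (s2, c2).
    static void moveRank(int[][] t, int s, int c, int s2, int c2) {
        if (t[s][c] == 0)
            throw new IllegalStateException("no rank in cell (" + s + ", " + c + ")");
        t[s][c]--;
        t[s2][c2]++;
    }

    static int[][] afterFirstCard() {
        int[][] t = new int[5][5];
        t[4][4] = 13;            // 13 ranks, each with 4 slots and 4 cards
        moveRank(t, 4, 4, 4, 3); // turn over a card: its rank now has 3 cards left
        moveRank(t, 4, 4, 3, 4); // place it in a slot of a different, untouched rank
        return t;
    }

    public static void main(String[] args) {
        int[][] t = afterFirstCard();
        System.out.println("(4 slots, 4 cards): " + t[4][4]); // 11
        System.out.println("(4 slots, 3 cards): " + t[4][3]); // 1
        System.out.println("(3 slots, 4 cards): " + t[3][4]); // 1
        // Only the card's own 4 slots are forbidden, so the first card is safe with probability 48/52.
        System.out.println("P(first card safe) = " + 48.0 / 52.0);
    }
}
```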

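Using such a table, together with the probabilities of the transitions between tables, we can write code that solves the problem recursively. Many different sequences of moves lead to the same table, so the code caches the survival probability of every table it has already evaluated.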
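As I read calcSurvivalProb and calcProbForCard in the code below, the recursion is as follows. Write n_{s,c} for the number of ranks with s slots and c cards remaining, T for the total number of slots (equal to the total number of cards), and [·] for 1 when the condition holds and 0 otherwise. A uniformly random remaining card is placed into a uniformly random remaining slot, and the move is only counted if the slot belongs to a different rank:

```latex
P(n) \;=\; \sum_{(s,c)} \sum_{(s',c')} \frac{c\,n_{s,c}}{T}\cdot\frac{s'\,m_{s',c'}}{T}\cdot P(n'),
\qquad
T \;=\; \sum_{(s,c)} s\,n_{s,c} \;=\; \sum_{(s,c)} c\,n_{s,c},
\qquad
m_{s',c'} \;=\; n_{s',c'} - \big[(s',c') = (s,c)\big]
```

Here n' is the table obtained by moving one rank from (s,c) to (s,c-1) and then one rank from (s',c') to (s'-1,c'), and P(n) = 1 when T = 0. As a check, the starting table has n_{4,4} = 13 and T = 52, so the only term is (4·13/52)·(4·12/52)·P(n') = (48/52)·P(n'), matching the 48 allowed slots above.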

This is what the code looks like:

import java.io.PrintWriter;

// check if the back slash char is causing problems

public class Calculator {
       
    public static void main ( String[] args){
       
        long startTime = System.currentTimeMillis();

        System.out.println("Starting...");
       
        Calculator calc = new Calculator();

        calc.calc(); // calc() is the main calculator, calcDebug() is for debugging
       
        long   endTime     = System.currentTimeMillis();
        double elapsedSec  = ((double) endTime - (double) startTime) * 0.001d;
       
        System.out.println("\nElapsed time: " + Double.toString(elapsedSec) + " sec.");       
        System.out.println("Finished.");
    }
    ///////////////////////////////////////////////////////////////////
    public void calc(){
        System.out.println("Doing calc.");

        int suits = 4;
        int ranks = 9; // Total number of cards will be: suits * ranks (use 13 for a standard deck)
       
        State s = new State(suits, ranks);
        double survivalProb = s.getSurvivalProb();
       
        String key = s.getKey();
        String resStr = "\nFor key: " + key + " found prob to be: "  + String.format("%1.14f", survivalProb);
        System.out.println(resStr);
       
        String pathToFile = "C:/temp/res_" + key + ".txt";
        writeToFile(pathToFile, resStr );

    }
    ///////////////////////////////////////////////////////////////
    public void calcDebug(){
        System.out.println("Doing calcDebug.");
        //           "0000000000111111111122222"
        //           "0123456789012345678901234"
        String key = "1001010000000000100000000";

        State s = new State(key, 0);
        double survivalProb = s.getSurvivalProb();
       
        String resStr = "For key: " + key + " found prob to be: "  + String.format("%1.14f", survivalProb);
        System.out.println(resStr);
       
        String pathToFile = "C:/temp/res_" + key + ".txt";
        writeToFile(pathToFile, resStr );
    }
    ///////////////////////////////////////////////////////////////
    public void writeToFile( String pathToFile, String str){
        try{
            PrintWriter writer = new PrintWriter(pathToFile, "UTF-8");
            writer.println(str);
            writer.close();
        } catch(Exception e){
            System.out.println("Caught exception: " + e.getMessage());
        }
    }
    /*  Prob(1,4)  = 0.375
     *  Prob(2,2)  = 0.166667
     *  Prob(4,4)  = 0.011869
     *  Prob(4,8)  = 0.014967   ( takes 11.6 secs )
     *  Prob(4,13) = 0.016232
     */
   
}

///////////////////////////////////////////////////////////////////////////////        ///////////////////////////////////////////////////////////////////////////////        ///////////////////////////////////////////////////////////////////////////////        


import java.util.Random;
import java.util.TreeMap;

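// Survival probability of a state = sum over every (card category, destination category) of
//     (numCardsInCategory / totalSlots) * (numSafeSlotsInDestination / totalSlots) * survivalProb(nextState)

// The key is a string of (suits+1)*(suits+1) chars (25 for 4 suits). The char at index
// slot*(suits+1) + card holds, encoded as (char)('0' + count), the number of ranks that
// have that many slots and cards remaining.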

public class State {
   private static final boolean            m_debug               = false; 
   private static final long               m_seed                = 1;
   private static final int                m_stepsBetweenCaching = 0; // 0: cache every state; n > 0: cache roughly 1 in n states

   private static int                      m_numSuits;
   private static TreeMap<String, Double>  m_map;
   private static Random                   m_rnd;

   private        String                   m_key;  
   private        double                   m_survivalProb;  
   private        int                      m_totalSlots;
   private        int                      m_depth;
  

   ///////////////////////////////////////////////////////
   State ( int suits, int ranks){
      
       m_depth    = 0;
       m_numSuits = suits;
       String      zeros = new String(new char[(suits+1) * (suits+1)]).replace("\0", "0");
       String      key   = adjustKey(suits, suits, zeros, ranks);
       initialize( key );    
      
       if ( m_debug) printKey  ( key );
   }
   ///////////////////////////////////////////////////////
   State( String key, int depth){
       m_depth = depth;
       initialize(key);
   }
   ///////////////////////////////////////////
   private void initialize(String key){
      
       if ( m_rnd == null)
            m_rnd =  new Random(m_seed); // formerly had: ThreadLocalRandom.current();

      
       if ( m_map == null)
            m_map =  new TreeMap<String, Double>();
      
       if( m_numSuits == 0){
           double sqrt = Math.sqrt((double) key.length());
           m_numSuits  = (int) Math.round(sqrt - 0.5f) -1;
           if ( m_debug)  {
              System.out.println(   "Have key of length: " + Integer.toString(key.length())
                                     + " and num suits: "     + Integer.toString(m_numSuits)   );
           }
       }  
      
       m_key          = key;
      
       m_survivalProb = -1;   // This indicates that it has not been calculated.      
       m_totalSlots   = calcTotalSlots();   
      
       if ( m_debug)  printKey(key);
   }
   ///////////////////////////////////////////////////////
   private int calcTotalSlots(){
       int sum = 0;
       for ( int slot = 0 ; slot <= m_numSuits; slot++){
           sum += getNumSlots(slot);
       }
       if ( m_debug)  System.out.println("Key: " + m_key + ", total number of slots found is: " + Integer.toString(sum));
       return sum;
   }
   ///////////////////////////////////////////////////////
   String getKey() { return m_key; }
   ///////////////////////////////////////////////////////
   double calcSurvivalProb(){
      
       if ( m_totalSlots == 0)
           return 1.0; // if we reach the end and there are no slots left, then we have survived.
      
       double prob = 0;
      
       for     ( int slot = 0; slot <= m_numSuits; slot++){
           for ( int card = 1; card <= m_numSuits; card++){
              
                  double probForCard = calcProbForCard(slot, card);
                  prob += probForCard;
                 
                  if ( m_depth < 3) {
                      System.out.println(  "Depth: "       + Integer.toString(m_depth)
                                         + ", slot: "      + Integer.toString(slot)
                                         + ", card: "      + Integer.toString(card)
                                         + ", num slots: " + Integer.toString(m_totalSlots)
                                         + ", prob: "      + Double.toString (probForCard)   );
                  }
           }
       }
      
       if ( m_debug) {
          printKey(m_key);
          System.out.println("Found prob to be: " + String.format("%1.14f",prob) + "\n");
       }      
       return prob;
   }
   ///////////////////////////////////////////////////////
   public double calcProbForCard(int slot, int card){
      
       if ( card == 0)
           return 0;

       int count = getCount(slot, card);
      
       if ( count == 0)
           return 0;
      
       int numCards = count * card;
      
       double probOfCardChoice = (double) numCards / (double) m_totalSlots;
      
       double probOfSlot = 0.0;
      
       for     ( int destSlot = 1; destSlot <= m_numSuits; destSlot++){
           for ( int destCard = 0; destCard <= m_numSuits; destCard++){
              
                  int destCount = getCount(destSlot, destCard);
                  if ( (slot == destSlot) && (card == destCard))
                      destCount--;  // a card cannot go into its own slot.
                 
                  int availableSlots = destSlot * destCount;

                  if ( availableSlots > 0){
                      double probOfContinuedSurvival = getProbOfContinuedSurvival(slot, card, destSlot, destCard);
                     
                      if ( probOfContinuedSurvival > 0) {          
                          probOfSlot += probOfContinuedSurvival * probOfCardChoice *(double) availableSlots /(double) m_totalSlots;
                          if ( probOfSlot > 1.0000001) { // We don't have >= 1.0 to allow a small rounding error
                              printKey(m_key);
                              System.out.println("ERROR: Prob is too big: " + Double.toString(probOfSlot));
                          }
                      }
                  }
           }
       }
       if ( m_debug) System.out.println("Found prob of slot to be: " + Double.toString(probOfSlot));
       return probOfSlot;
   }
   ///////////////////////////////////////////////////////
   double getProbOfContinuedSurvival(int sourceSlot, int sourceCard, int destSlot, int destCard){
      
          String keyForChosenCard   = adjustKeyForTakenCard (sourceSlot, sourceCard, m_key);
          String keyForContinuation = adjustKeyForCardInSlot(destSlot,   destCard,   keyForChosenCard);   
         
          // check if state obj is in map
          Double contProbDouble = m_map.get(keyForContinuation);
          double contProb;
         
          if ( contProbDouble == null ){
              State contState = new State(keyForContinuation, m_depth + 1);
              contProb        = contState.getSurvivalProb();
             
              if (    ( m_stepsBetweenCaching == 0 )
                   || ( m_rnd.nextInt(m_stepsBetweenCaching) == 0 )) { // when sampling, only about one in n key,value pairs is stored
                  contProbDouble = Double.valueOf(contProb);
                  m_map.put(keyForContinuation, contProbDouble); // we store it for later use

                  if ( m_map.size() % 10000 == 0)
                      System.out.println(   "Map elm: " + Integer.toString(m_map.size())
                                          + ", key: "   + keyForContinuation
                                          + ", prob: "  + Double.toString(contProbDouble));

                  if ( m_debug)
                      System.out.println(  "Added new key to map: " + keyForContinuation
                                         + " that is item: " + Integer.toString(m_map.size()) );
              }
          } else {
              contProb = contProbDouble.doubleValue();
          }

          if ( m_debug) {
             System.out.println("Key: " + keyForContinuation + ", found cont prob: " + Double.toString(contProb));
             printKey(keyForContinuation);
          }
          return contProb;
   }
   ///////////////////////////////////////////////////////
   double getSurvivalProb() {
      
       if ( m_survivalProb == -1) // i.e. not yet set
            m_survivalProb = calcSurvivalProb();
             
       return m_survivalProb;
   }
   //////////////////////////////////////////////////////
   int getCount(int slot, int card){
             
       return ( getCount(slot, card, m_key));            
   }
   //////////////////////////////////////////////////////
   static int getCount(int slot, int card, String key){
      
       int    index = getIndex(slot, card);
       // System.out.println("Key: " + key + ", index: " + Integer.toString(index));
       String c     = key.substring(index, index + 1);
      
       return charToInt(c);
       // return ( Integer.parseInt(c, 16));            
   }
   //////////////////////////////////////////////////////
   public int getNumSlots(int slot){
      
       if ( slot == 0)
           return 0;
      
       int sum = 0;

       for ( int card = 0; card <= m_numSuits; card++)
           sum += slot * getCount(slot, card);
      
       return sum;      
   }
   //////////////////////////////////////////////////////
   int getTotalSlots(){
       return m_totalSlots;
   }
   //////////////////////////////////////////////////////
   public static void printKey(String key){
   
       if ( key == null){
           System.out.println("Key is null.");
           return;
       }

       System.out.println("key: " + key);

       String header = "   Card ";
      
       for( int j = 0; j <= m_numSuits; j++){
           header += Integer.toString(j) + " ";
       }
      
       System.out.println(header);
      
       for ( int slot = 0; slot <= m_numSuits; slot++){
           String str = "Slot " + Integer.toString(slot) + ": ";
          
           for ( int card = 0; card <= m_numSuits; card++){
               int index = getIndex(slot, card);
               str += key.substring(index, index+1) + " ";
           }
           System.out.println(str);
       }   
   }
   //////////////////////////////////////////////////////
   public static int getIndex(int slot, int card){
       int index = slot * (m_numSuits + 1) + card;
       /*
       System.out.println(    "For slot: "   + Integer.toString(slot)
                           + " and card: "   + Integer.toString(card)
                           + " have index: " + Integer.toString(index));
       */
       return (index);
   }
   //////////////////////////////////////////////////////
   public static String adjustKeyForTakenCard(int slot, int card, String key){
      
       if( card == 0)
           return null; // no cards to give

       String decrementedCurrent = adjustKey(slot, card   , key,                -1);
       return                      adjustKey(slot, card -1, decrementedCurrent, +1);
   }
   //////////////////////////////////////////////////////
   public static String adjustKeyForCardInSlot(int slot, int card, String key){
       if ( slot == 0)
           return null; // have no slot
      
       String adjKey  = adjustKey(slot    , card, key   , -1);
       adjKey         = adjustKey(slot - 1, card, adjKey, +1);          
             
       for( int i = m_numSuits; i > 1; i--){
          
           int countWithZeroCards = getCount(i,0, adjKey );
           if( countWithZeroCards > 0){
               adjKey = adjustKey(i,0, adjKey, -countWithZeroCards);
               adjKey = adjustKey(1,0, adjKey, countWithZeroCards * i);
           }

           int countWithZeroSuits = getCount(0,i, adjKey );
           if( countWithZeroSuits > 0){
               adjKey = adjustKey(0, i, adjKey, -countWithZeroSuits);
               adjKey = adjustKey(0, 1, adjKey,  countWithZeroSuits * i);
           }
       }
      
       int zeroZeroCount         = getCount(0,0, adjKey);
      
       return adjustKey( 0,0, adjKey, - zeroZeroCount);
   }
   //////////////////////////////////////////////////////
   public static String adjustKey(int slot, int card, String key, int adj){
      
       if ( key == null)
           return null;
      
       int count = getCount(slot, card, key);
      
       if ( 0 > count + adj){
           System.out.println("Cannot adjust below zero for slot" );
           return null;
       /*      
       } else if ( count + adj > 15){
           System.out.println("Cannot adjust above 1 char hex limit.");
           return null;       
        */  
       } else {
           int    index   = getIndex(slot, card);
          
           String start   = key.substring(0, index );
           String end     = key.substring(index + 1);
          
           String current = intToChar(count+ adj);
           return ( start + current + end);          
       }          
   }
   //////////////////////////////////////////////////////
   public static String intToChar(int i ){
       char c = (char)(i+48);
      
       return "" + c;
   }
   //////////////////////////////////////////////////////
   public static int charToInt(String s){
       char c = s.charAt(0);
       int  i = (int)c - 48;
      
       if ( m_debug) System.out.println("Char: " + s + " is deemed to be: " + Integer.toString(i));
      
       return (int) i;
   }
   /////////////////////////////////////////////////////
}
 

///////////////////////////////////////////////////////////////////////////////       





And here's the Java code for the Monte Carlo:

import java.util.Random;

public class RanksAndSuitsSurvival {

    private final long   m_simsPerBatch = 1_000_000L;
    private final int    m_numBatches   = 100;

    private final int    m_numSuits     = 4;   // should be  4
    private final int    m_numRanks     = 13;  // should be 13
   
    private       int    m_numCards;
    private       int[]  m_hand;
   
    private       Random m_rnd;
    private final int    m_rndSeed      = 1;
    ///////////////////////////////////////////////////////////////////////////////
    public static void main ( String[] args){
               
        System.out.println("Started...");
       
        RanksAndSuitsSurvival  rass = new RanksAndSuitsSurvival();
        rass.calc();
       
        System.out.println("\nFinished");   
    }
    ///////////////////////////////////////////////////////////////////////////////
    public RanksAndSuitsSurvival(){
       
        m_rnd      = new Random(m_rndSeed);
       
        m_numCards = m_numSuits * m_numRanks;
        m_hand     = new int[m_numCards];
       
        for ( int i = 0; i < m_numCards; i++)
            m_hand[i] = i % m_numRanks;
    }
    ///////////////////////////////////////////////////////////////////////////////
    public void calc(){
       
        long startTime = System.currentTimeMillis();
               
        long survivors = 0;
       
        // The only reason for the nested 'for' loops rather than a single 'for' loop
        // is because we wanted to print out the progress intermittently and we didn't
        // want to introduce another 'if' inside the main 'for' loop for reasons of speed.
        for ( int batch = 1; batch <= m_numBatches; batch++ ) {
       
            for ( long counter = 0; counter < m_simsPerBatch; counter ++){
       
                if (survived())  
                    survivors++;
            }
           
            System.out.println("Have now completed batch: " + Integer.toString(batch) + ", after "
                               + Double.toString((double)( System.currentTimeMillis() - startTime) * 0.001) + " seconds");
        }
   
        double numSims = (double) m_numBatches * (double) m_simsPerBatch;
        double prob    = (double) survivors              / numSims;
        double stdDev  = Math.sqrt( (1.0 - prob ) * prob / numSims);

        long   endTime = System.currentTimeMillis();
       
        printoutResults(survivors, numSims, prob, stdDev, endTime - startTime);
    }
    /////////////////////////////////////////////////////////////////////////////////
    public void printoutResults(long survivors, double numSims, double prob, double stdDev, long elapsedTimeMS){
   
        System.out.println("\nFound " + Long.toString(survivors) + " survivors out of " + Double.toString(numSims) );
        System.out.println("So the survival probability is estimated to be: "           + Double.toString(prob     ) );
        System.out.println("with a standard deviation of:                   "           + Double.toString(stdDev   ) );
           
        System.out.println("Time elapsed: " + Double.toString((double)( elapsedTimeMS) * 0.001d) + " seconds");
       
    }
    ///////////////////////////////////////////////////////////////////////////////
    public boolean survived(){
        
         shuffleHand();     // will shuffle m_hand
         return checkHand();
    }
    ///////////////////////////////////////////////////////////////////////////////   
    // Implementing Fisher–Yates shuffle
    public void shuffleHand(){         
        for (int i = m_hand.length - 1; i > 0; i--)  {
          int index     = m_rnd.nextInt(i + 1);
          // Simple swap
          int a         = m_hand[index];
          m_hand[index] = m_hand[i];
          m_hand[i]     = a;
        }
    }
    ////////////////////////////////////////////////////////////////////
    // Returns true when the hand 'survived'.
    public boolean checkHand(){      
                   
           for (int i = 0; i < m_hand.length; i++){
              
               if ( ((m_hand[i] - i) % m_numRanks) == 0 )
                      return false; // did not survive
           }
              
           return true;  // survived              
    }     
    ///////////////////////////////////////////////////////////////////////////////   
}
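As a sanity check on the small-deck values listed in the comment at the end of the Calculator class, here is a brute-force sketch (again with illustrative names). It enumerates every ordering of a tiny deck in which card i has rank i mod ranks and position j is called as rank j mod ranks, and counts the orderings with no match. It treats the cards as distinct, which does not change the probability. It is only practical for a handful of cards: even the 16 cards of Prob(4,4) would mean about 2 × 10^13 orderings.

```java
// Exhaustive survival probability for tiny decks (illustrative sketch).
public class BruteForceCheck {

    static double survivalProb(int suits, int ranks) {
        int n = suits * ranks;
        int[] rankOf = new int[n];
        for (int i = 0; i < n; i++) rankOf[i] = i % ranks; // card i has rank i mod ranks
        long survivors = count(rankOf, new boolean[n], 0, ranks);
        long total = 1;
        for (int i = 2; i <= n; i++) total *= i; // n! orderings of distinct cards
        return (double) survivors / total;
    }

    // Places an unused card at position pos, skipping any whose rank is the one called there.
    private static long count(int[] rankOf, boolean[] used, int pos, int ranks) {
        if (pos == rankOf.length) return 1; // reached the end without a match
        long total = 0;
        for (int card = 0; card < rankOf.length; card++) {
            if (used[card] || rankOf[card] == pos % ranks) continue;
            used[card] = true;
            total += count(rankOf, used, pos + 1, ranks);
            used[card] = false;
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println("Prob(1,4) = " + survivalProb(1, 4)); // 0.375
        System.out.println("Prob(2,2) = " + survivalProb(2, 2)); // 1/6
    }
}
```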