from __future__ import division, print_function

# Iterated Prisoner's Dilemma
# Doug Blank, Fall 2013
# CS361 Emergence

from conx import SRN
import random
import copy
import random
import math
import time

# Payoff to the player whose move comes first in the key, keyed by
# "<my move><opponent's move>".  Ordering: DC (temptation) > CC (mutual
# cooperation) > DD (mutual defection) > CD (sucker's payoff).
rewards = dict(DC=5, CC=3, DD=1, CD=0)

def random_player(history1, history2):
    """Pick "C" or "D" uniformly at random, ignoring both histories."""
    options = ["C", "D"]
    return random.choice(options)

def always_defect(history1, history2):
    """Ignore both histories and defect every round."""
    move = "D"
    return move

def always_cooperate(history1, history2):
    """Ignore both histories and cooperate every round."""
    move = "C"
    return move

def tit_for_tat(history1, history2):
    """Open with "C", then repeat whatever the opponent played last round."""
    if not history2:
        return "C"
    return history2[-1]

def tit_for_tat2(history1, history2):
    """Delayed tit for tat: copy the opponent's move from two rounds ago.

    Cooperates for the first two rounds.  When exactly 49 opponent moves
    are recorded (the final round of play()'s 50-round game), it defects
    unconditionally.
    """
    played = len(history2)
    if played == 49:
        return "D"
    return history2[-2] if played > 1 else "C"

def besan_and_emily(history1, history2):
    """Cooperate while 25 or fewer rounds have been played, then defect.

    Only the round count (``len(history1)``) is consulted.  The opponent's
    moves in *history2* are never examined, so this does not react to
    whether the opponent is defecting.
    """
    if len(history1) <= 25:
        return "C"
    return "D"


def allie_grace_dan(history1, history2):
    """Cooperate early, defect late, otherwise punish a mostly-defecting opponent.

    - First five rounds (0-4 moves played): cooperate.
    - Once 48 or more moves have been played: defect.
    - In between: defect only if strictly more than half of the opponent's
      moves so far were "D".
    """
    rounds_played = len(history1)
    if rounds_played <= 4:
        return "C"
    if rounds_played >= 48:
        return "D"
    defect_rate = history2.count("D") / len(history2)
    return "D" if defect_rate > 0.5 else "C"

def michelle_gabby(history1, history2):
    """Mostly-defecting random player; both histories are ignored.

    A first draw defects outright with probability 2/3.  Otherwise a second
    draw cooperates with probability 1/7, so "C" comes out 1/21 of the time.
    """
    if random.choice(["A", "A", "B"]) == "A":
        return "D"
    return random.choice(["D", "D", "D", "D", "D", "D", "C"])

def michelle_gabby2(history1, history2):
    """Like a 2/3-defect coin flip, but always defect against four straight "D"s.

    The first random draw is made on every call, even when the opponent's
    last four moves were all "D" and the draw is not needed.
    """
    coin = random.choice(["A", "A", "B"])
    opponent_hostile = history2[-4:] == ["D", "D", "D", "D"]
    if opponent_hostile or coin == "A":
        return "D"
    return random.choice(["D", "D", "D", "D", "D", "D", "C"])

def iris_sam(history1, history2):
    """Cruel tit for tat: open with "D", copy the opponent, defect from move 49 on.

    The opponent's last move is copied while 1-47 of their moves are
    recorded.  Otherwise (no history yet, or 48+ moves played) it defects.
    """
    if 0 < len(history2) < 48:
        return history2[-1]
    return "D"

def iris_sam2(history1, history2):
    """Nice tit for tat: open with "C", copy the opponent, defect at the very end.

    Cooperates on the first move, then repeats the opponent's previous move.
    Once more than 48 opponent moves are recorded (the final round of a
    50-round game), it defects.
    """
    # The end-game check must come first.  Placed after the
    # ``len(history2) > 0`` test, as it originally was, it could never run.
    if len(history2) > 48:
        return "D"
    if history2:
        return history2[-1]
    return "C"

def allie_grace_dan2(history1, history2):
    """Tit for tat that opens with "C" and defects once 48+ moves are played.

    In play()'s 50-round game, 48+ moves played means the last two rounds.
    """
    turn = len(history1)
    if turn >= 48:
        return "D"
    if turn == 0:
        return "C"
    # Mirror the opponent's most recent move.
    return history2[-1]

def ipd_network(history1, history2):
    """Strategy driven by the evolved recurrent network.

    Loads the global ``GENE`` into ``network`` and feeds one neutral step
    (0.5 inputs, 0.5 context).  It then replays the game so far one round at
    a time: own move and opponent move, each encoded through ``pattern``,
    with the hidden-layer activation fed back as the next context.  Returns
    "C" if the final output is below 0.5, otherwise "D".
    """
    network.unArrayify(GENE)
    # Presumably one context value per hidden unit (addLayers(2, 5, 1)).
    neutral_context = [0.5] * 5
    output = network.propagate(input=[0.5, 0.5], context=neutral_context)
    for own_move, their_move in zip(history1, history2):
        encoded = [pattern[own_move], pattern[their_move]]
        output = network.propagate(input=encoded,
                                   context=network["hidden"].activation)
    return "C" if output[0] < .5 else "D"

def play(ipd1, ipd2, rounds=50, payoffs=None):
    """Play an iterated prisoner's dilemma between two strategies.

    Each strategy is called as ``strategy(my_history, their_history)`` and
    must return "C" or "D".

    :param ipd1: first strategy.
    :param ipd2: second strategy.
    :param rounds: number of rounds to play.  Defaults to 50; several
        strategies in this file hard-code end-game moves for a 50-round game.
    :param payoffs: mapping from "<my move><their move>" to my score.
        Defaults to the module-level ``rewards`` table.
    :returns: ``(score1, score2)``, the totals for ipd1 and ipd2.
    :raises KeyError: if a strategy returns a move not covered by *payoffs*.
    """
    if payoffs is None:
        payoffs = rewards
    score1 = score2 = 0
    history1 = []
    history2 = []
    for _ in range(rounds):
        # Give each strategy copies so a misbehaving strategy cannot rewrite
        # the shared game record (or its opponent's view of it).
        move1 = ipd1(list(history1), list(history2))
        move2 = ipd2(list(history2), list(history1))
        score1 += payoffs[move1 + move2]
        score2 += payoffs[move2 + move1]
        history1.append(move1)
        history2.append(move2)
    return (score1, score2)

def compute_fitness(gene):
    """
    Score one gene by playing its network against every strategy in ``players``.

    Sets the global ``GENE`` (read by ipd_network) and plays ipd_network
    against each entry of ``players``.  Each matchup is printed, followed by
    a scoreboard of total score and wins per strategy, highest score first.
    A tie counts as a win for both sides.  Returns ipd_network's total score.

    NOTE(review): if ``players`` contains ipd_network itself, both sides of
    that match are credited to the same name.  That match then adds both
    scores (and, on a tie, two wins) to ipd_network's totals.
    """
    global GENE
    GENE = gene
    total = {}
    wins = {}
    me = ipd_network
    for opponent in players:
        my_name, their_name = me.__name__, opponent.__name__
        print(my_name, "vs", their_name, "...")
        my_score, their_score = play(me, opponent)
        # ">=" on both checks: a strict winner gets one win, a tie credits both.
        if my_score >= their_score:
            wins[my_name] = wins.get(my_name, 0) + 1
        if their_score >= my_score:
            wins[their_name] = wins.get(their_name, 0) + 1
        total[my_name] = total.get(my_name, 0) + my_score
        total[their_name] = total.get(their_name, 0) + their_score

    row = "%20s%10s%10s"
    rule = "*" * 70
    print(row % ("Team", "Score", "Wins"))
    print(rule)
    ranking = sorted(total.items(), key=lambda item: item[1], reverse=True)
    for name, score in ranking:
        print(row % (name, score, wins.get(name, 0)))
    print(rule)
    print()
    return total["ipd_network"]

def make_gene():
    """Return a new random gene: ``SIZE`` weights, each in [-1.0, 1.0)."""
    gene = []
    for _ in range(SIZE):
        gene.append(random.random() * 2 - 1)
    return gene

def make_population(pop_size):
    """
    Build a fresh population of ``pop_size`` members.

    Each member is a mutable ``[score, gene]`` list, with the score set to 0
    and the gene produced by make_gene().
    """
    return [[0, make_gene()] for _ in range(pop_size)]

def select(population, fitness_sum):
    """
    Roulette-wheel selection: return the index of one population member.

    ``population`` is a list of ``[score, gene]`` entries.  A spin is drawn
    uniformly in ``[0, fitness_sum)`` and compared against the running total
    of scores; the first entry whose running total reaches the spin wins.
    If the total never reaches the spin, the last index is returned.
    Assuming non-negative scores and ``fitness_sum`` equal to their sum,
    each member is chosen with probability proportional to its score.
    """
    spin = random.random() * fitness_sum
    running_total = 0.0
    for index, member in enumerate(population):
        running_total += member[0]
        if running_total >= spin:
            return index
    return index

def crossover(gene1, gene2):
    """
    Single-point crossover of two equal-length genes.

    A cut point is drawn uniformly in ``[0, len(gene1))``.  Positions below
    the cut are copied straight from their own parent; positions at or above
    it are swapped between parents.  Returns ``(child1, child2)`` as new
    lists and leaves the parents unmodified.
    """
    length = len(gene1)
    cut = random.random() * length
    child1 = [gene1[i] if i < cut else gene2[i] for i in range(length)]
    child2 = [gene2[i] if i < cut else gene1[i] for i in range(length)]
    return child1, child2

def mutate(gene, mutation_rate):
    """
    Mutate a gene in place; returns None.

    Each weight is independently replaced, with probability
    ``mutation_rate``, by a fresh value drawn uniformly from [-1, 1).
    """
    for position, _ in enumerate(gene):
        if random.random() < mutation_rate:
            gene[position] = random.random() * 2 - 1

def save_population(filename):
    """
    Save the population to a file.

    Writes the generation number (global ``GEN``) on the first line,
    followed by a pickle of the global ``POPULATION``.

    The file is opened in binary mode because ``pickle`` produces bytes on
    Python 3; the old text-mode write raised TypeError there.  Protocol 0,
    the ASCII protocol Python 2's ``pickle.dumps`` used by default, keeps
    the payload line-oriented like the original format.
    """
    import pickle
    with open(filename, "wb") as fp:
        fp.write((str(GEN) + "\n").encode("ascii"))
        pickle.dump(POPULATION, fp, 0)

def load_population(filename):
    """
    Load the generation from a file.

    Reads the format written by save_population: the generation number on
    the first line, then a pickled population.  Sets the module globals
    ``GEN`` and ``POPULATION``.

    The file is opened in binary mode so ``pickle.loads`` receives bytes on
    Python 3; the old text-mode version raised TypeError there.

    WARNING: unpickling can execute arbitrary code.  Only load checkpoint
    files this program wrote itself, never untrusted input.
    """
    global POPULATION, GEN
    import pickle
    with open(filename, "rb") as fp:
        GEN = int(fp.readline().decode("ascii"))
        POPULATION = pickle.loads(fp.read())

def evolve(generations=10000, mutation_rate=.05, elite_percent=.10):
    """
    Run the genetic algorithm on the global ``POPULATION``.

    Each generation:
    1. Increment ``GEN`` and score every member with compute_fitness.
    2. Sort best-first and copy the top ``elite_percent`` over unchanged.
    3. Fill the remaining slots with mutated crossover children of
       roulette-selected parents.
    4. Checkpoint to ``pop<GEN>.pickle``.

    :param generations: number of generations to run.
    :param mutation_rate: per-weight replacement probability passed to mutate.
    :param elite_percent: fraction of the population copied unchanged.
    """
    global POPULATION, GEN
    elite = int(len(POPULATION) * elite_percent)  # e.g. 20 members * .10 -> 2
    for _ in range(generations):
        GEN += 1
        print("Generation %s..." % GEN)
        for member in POPULATION:
            member[0] = compute_fitness(member[1])
        POPULATION.sort(key=lambda m: m[0], reverse=True)  # best first, stable
        fitness_sum = sum(m[0] for m in POPULATION)
        print("Best score:", int(POPULATION[0][0]),
              "total:", sum(int(m[0]) for m in POPULATION))
        size = len(POPULATION)
        offspring = [0] * size
        for rank in range(elite):
            offspring[rank] = POPULATION[rank]  # elites survive as-is
        for slot in range(elite, size, 2):
            mom = select(POPULATION, fitness_sum)
            dad = select(POPULATION, fitness_sum)
            kid1, kid2 = crossover(POPULATION[mom][1], POPULATION[dad][1])
            mutate(kid1, mutation_rate)
            mutate(kid2, mutation_rate)
            offspring[slot] = [0, kid1]
            # With an odd number of non-elite slots the final second child
            # is generated and mutated but not kept.
            if slot + 1 < size:
                offspring[slot + 1] = [0, kid2]
        POPULATION = offspring
        save_population("pop%d.pickle" % GEN)

# --- Experiment setup (runs at import time) ---------------------------------
# Activation used to encode each move when it is fed to ipd_network.
pattern = {"C": 0.1, "D": 0.9}
# Strategies every gene is scored against in compute_fitness.
# NOTE(review): ipd_network is itself in this list, so each gene also plays
# itself and both sides of that match are credited to "ipd_network".
players = [always_defect, always_cooperate,
           tit_for_tat, ipd_network, allie_grace_dan]

GEN = 0
# conx simple recurrent network; presumably 2 input, 5 hidden, 1 output units.
network = SRN()
network.addLayers(2, 5, 1)
SIZE = len(network.arrayify())  # length of the flattened weights = gene length
POPULATION = make_population(20)  # even number

evolve()
