#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import math
N = 100  # number of oscillators
q = 100  # number of quanta
plot_q_A = True  # also plot q_A (quanta in system A) versus time in the lower panel
fixed_y_scale = True  # fix the occupancy panel's y-axis to [0, q]; otherwise autoscale
iters_per_plot = 1000  # exchange steps between figure redraws
pause_time = 0.01  # seconds passed to plt.pause() on each redraw


N_A = N // 2  # number of oscillators in system A (the remaining ones form system B)
n_steps = 1_000_000  # number of entries in the q_A time series (including t = 0)


def multiplicity(n_osc: int, n_quanta: int) -> int:
    """Return the number of ways to place ``n_quanta`` indistinguishable quanta
    into ``n_osc`` oscillators.

    This is C(n_quanta + n_osc - 1, n_quanta) = (q + N - 1)! / (q! (N - 1)!),
    computed exactly with integer arithmetic (no huge intermediate factorials).
    """
    return math.comb(n_quanta + n_osc - 1, n_quanta)


def format_sci(value: int) -> str:
    """Format an integer in ``.1e`` scientific notation.

    Uses float formatting when the value fits in a float; otherwise falls back
    to ``Decimal`` so very large multiplicities do not raise OverflowError.
    """
    try:
        return f"{value:.1e}"
    except OverflowError:
        from decimal import Decimal
        return f"{Decimal(value):.1e}"


def exchange_quantum(n_occ, rng=None) -> None:
    """Move one quantum between oscillators, modifying ``n_occ`` in place.

    The source is drawn uniformly among oscillators holding at least one
    quantum (by rejection sampling). The destination is drawn uniformly among
    all oscillators and may equal the source, in which case nothing changes.
    The total number of quanta is conserved.

    Args:
        n_occ: 1-D integer array of quanta per oscillator.
        rng: object with a ``randint(high)`` method (e.g. ``np.random.RandomState``);
            defaults to the global ``np.random`` state.
    """
    if not n_occ.any():
        # No quanta anywhere: the rejection loop below would never terminate.
        return
    randint = np.random.randint if rng is None else rng.randint
    n_osc = len(n_occ)
    i_src = randint(n_osc)
    while n_occ[i_src] == 0:
        i_src = randint(n_osc)
    i_dst = randint(n_osc)
    n_occ[i_src] -= 1
    n_occ[i_dst] += 1


def _draw_occupancy(ax, n_occ, n_A, n_quanta, info_label, fixed_scale) -> None:
    """Redraw the per-oscillator occupancy panel (system A red, system B blue)."""
    ax.clear()
    ax.set_xlabel('which oscillator', fontsize=18)
    ax.set_ylabel('quanta in oscillator', fontsize=18)
    n_osc = len(n_occ)
    if fixed_scale:
        ax.set_ylim(0, n_quanta)
        label_y_position = 0.7 * n_quanta
    else:
        label_y_position = 0.7 * n_occ.max()
    ax.text(0.45 * n_osc, label_y_position, info_label, fontsize=16, color='green')
    ax.text(0.2 * n_osc, label_y_position, "A", fontsize=18, color='red')
    ax.text(0.9 * n_osc, label_y_position, "B", fontsize=18, color='blue')
    # Hide the numbers on the x-axis
    ax.set_xticks([])  # Pass an empty list to remove tick marks and numbers

    idx = np.arange(n_osc)
    # System B is shifted right by 0.2*N to leave a visual gap between A and B.
    ax.scatter(idx[:n_A], n_occ[:n_A], marker='o', color='r')
    ax.scatter(idx[n_A:] + 0.2 * n_osc, n_occ[n_A:], marker='o', color='b')


def main() -> None:
    """Run the quantum-exchange simulation and animate it with matplotlib."""
    n_occ = np.zeros(N, dtype=int)
    n_occ[0] = q  # put all quanta in 1 oscillator of system A

    q_A = np.zeros(n_steps, dtype=int)
    x = np.arange(n_steps)
    q_A[0] = q  # oscillator 0 belongs to system A, so A starts with all q quanta

    # Omega depends only on N and q, so build the annotation once.
    omega = multiplicity(N, q)
    info_label = (r"$N=$" + f"{N}\n" + r"$q=$" + f"{q}\n"
                  + r"$\Omega \approx " + f"{format_sci(omega)}$")

    # Create a figure with two vertically stacked subplots
    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6))

    if plot_q_A:
        ax2.set_ylim(0, q)
        ax2.set_xlabel('time', fontsize=18)
        ax2.set_ylabel(r'$q_A$', fontsize=18)
        # One persistent line whose data is updated on each redraw, instead of
        # stacking a new full-history line every time.
        (q_A_line,) = ax2.plot([], [], color='black')

    # Show the initial configuration before any exchange happens.
    _draw_occupancy(ax1, n_occ, N_A, q, info_label, fixed_y_scale)
    plt.tight_layout()
    plt.pause(pause_time)

    for i in range(1, n_steps):
        exchange_quantum(n_occ)
        q_A[i] = n_occ[:N_A].sum()
        if i % iters_per_plot == 0:
            if plot_q_A:
                q_A_line.set_data(x[:i + 1], q_A[:i + 1])
                ax2.set_xlim(0, i)
            _draw_occupancy(ax1, n_occ, N_A, q, info_label, fixed_y_scale)
            # Adjust spacing between subplots
            plt.tight_layout()
            plt.pause(pause_time)

    plt.show()  # keep the final frame on screen until the window is closed


if __name__ == "__main__":
    main()