import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

# Sphere properties
radius = 0.05 / 3  # one third of the original 0.05 radius
mass = 0.02  # single value for every sphere (the collision model assumes equal masses)

# Box properties
box_size = 0.3  # half-width: the walls sit at x = +/-box_size and y = +/-box_size

# Initial conditions
# random_start=True scatters n_particles spheres with positions and velocities
# drawn uniformly from [-0.1, 0.1) on each axis. Otherwise n_particles picks a
# preset scenario:
#   1   one sphere moving diagonally
#   2   two spheres on a head-on collision course
#   >2  a lattice of spheres moving right, the first one slightly perturbed
random_start = False
n_particles = 1

if n_particles < 1:
    raise ValueError(f"n_particles must be at least 1, got {n_particles}")

if random_start:
    positions = np.random.uniform(-0.1, 0.1, size=(n_particles, 2))
    velocities = np.random.uniform(-0.1, 0.1, size=(n_particles, 2))
elif n_particles == 1:
    positions = np.array([[-0.1, -0.1]])
    velocities = np.array([[0.1, 0.14]])
elif n_particles == 2:
    positions = np.array([[-0.1, -0.1], [0.1, 0.1]])
    velocities = np.array([[0.03, 0.04], [-0.03, -0.04]])
else:
    # Spheres sit 1/20 apart starting at (-0.2, -0.2). A column is filled
    # bottom-to-top before the next column (1/20 further right) is started.
    # A single column would run past the top wall beyond 10 spheres, so the
    # wrap keeps every sphere inside the box.
    lattice_start = -0.2
    per_column = int((box_size - radius - lattice_start) * 20) + 1
    if n_particles > per_column ** 2:
        raise ValueError(
            f"n_particles={n_particles} does not fit in the lattice "
            f"(at most {per_column ** 2})"
        )
    positions = np.array(
        [
            [lattice_start + (i // per_column) / 20, lattice_start + (i % per_column) / 20]
            for i in range(n_particles)
        ]
    )
    velocities = np.array([[0.1, 0.0] for _ in range(n_particles)])
    # Perturb one sphere so the motion is not perfectly symmetric.
    velocities[0] = [0.11, 0.01]

# Time properties
dt = 0.01  # simulation time advanced per animation frame
# The number of frames (n_steps) is set where the animation is created.

# Animation setup: square axes spanning the whole box, one circle per sphere.
fig, ax = plt.subplots()
ax.set_xlim(-box_size, box_size)
ax.set_ylim(-box_size, box_size)
ax.set_aspect('equal', 'box')

# Each sphere gets a random RGB colour. add_patch is the dedicated Axes method
# for patches such as Circle (add_artist is meant for artists without one).
circles = [plt.Circle(pos, radius, color=np.random.rand(3,)) for pos in positions]
for circ in circles:
    ax.add_patch(circ)

def update(num):
    """Advance the simulation by one time step ``dt`` and move the circles.

    Used as the ``func`` of ``FuncAnimation``: called once per frame, it
    mutates the module-level ``positions`` and ``velocities`` arrays in place
    and returns the circle artists so blitting knows what to redraw.

    Args:
        num: Frame index supplied by ``FuncAnimation``. It is unused because
            all state lives in the module-level arrays.

    Returns:
        The module-level ``circles`` list.
    """
    global positions, velocities
    # Free flight for one time step.
    positions += velocities * dt

    # Wall collisions. Reflect a velocity component only when the sphere is
    # past the wall AND still moving outward. Flipping unconditionally would
    # flip it again on the next frame if it has not yet moved back inside,
    # trapping the sphere jittering against the wall.
    limit = box_size - radius
    for axis in (0, 1):
        coord = positions[:, axis]
        outward = (np.abs(coord) > limit) & (coord * velocities[:, axis] > 0)
        velocities[outward, axis] *= -1

    # Sphere-sphere collisions. The model treats all spheres as having equal
    # mass, so an elastic collision exchanges the velocity components along
    # the line of centres and leaves the tangential components unchanged.
    count = len(positions)
    for i in range(count):
        for j in range(i + 1, count):
            offset = positions[j] - positions[i]
            dist = np.hypot(offset[0], offset[1])
            # Skip pairs that do not touch. Also skip exactly coincident
            # centres: they have no defined collision normal, and dividing by
            # zero would turn the velocities into NaN.
            if dist >= 2 * radius or dist == 0:
                continue
            normal = offset / dist  # unit vector from sphere i towards sphere j
            v1n = np.dot(velocities[i], normal)
            v2n = np.dot(velocities[j], normal)
            # Resolve only approaching pairs. A pair that still overlaps after
            # an earlier bounce is already separating and must not be bounced
            # back together.
            if v1n <= v2n:
                continue
            exchange = (v2n - v1n) * normal
            velocities[i] += exchange
            velocities[j] -= exchange

    # Move the drawn circles to the new positions.
    for circ, (x, y) in zip(circles, positions):
        circ.center = x, y
    return circles

# Run the animation: one simulation step per frame. FuncAnimation's interval
# is in milliseconds, so frames are spaced dt seconds apart.
n_steps = 5000
ani = animation.FuncAnimation(fig, update, frames=n_steps, interval=dt * 1000, blit=True)

# Set to True to also write the animation to disk (the 'ffmpeg' writer needs
# the ffmpeg executable to be installed).
save_movie = False
if save_movie:
    ani.save('gas.mp4', writer='ffmpeg')

plt.show()
