import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Starting configuration: one (x, y) row per sphere. Coordinates are in the
# same units as the plot limits, and velocities are the displacement added to
# the positions on every animation frame.
positions = np.array([
    [0.2, 0.53],
    [0.8, 0.5],
])
velocities = np.array([
    [-0.008, 0.0],
    [0.008, 0.0],
])
radius = 0.06  # collision radius shared by both spheres

# Length of the animation (seconds) and its playback rate.
total_time = 30
frames_per_second = 30

# One zero-filled (frame, xy) buffer per sphere, with row 0 seeded by that
# sphere's starting position.
history = [
    np.vstack([start, np.zeros((int(total_time * frames_per_second) - 1, 2))])
    for start in positions
]

# Define the update function
def update(num, positions, velocities, scatters, lines, *, trail=None, sphere_radius=None):
    """Advance the two-sphere simulation by one frame and refresh the artists.

    FuncAnimation calls this as ``update(num, *fargs)``. ``positions`` and
    ``velocities`` are modified in place, and the new centres are written to
    row ``num`` of each sphere's trail buffer.

    Args:
        num: Frame index; selects the trail row that is written.
        positions: Array with one (x, y) row per sphere, in plot units.
        velocities: Per-frame displacement of each sphere, same shape as
            ``positions``.
        scatters: Per-sphere scatter artists (moved with ``set_offsets``).
        lines: Per-sphere line artists that draw the trails.
        trail: Per-sphere position buffers indexed ``[frame, coord]``.
            Defaults to the module-level ``history``.
        sphere_radius: Collision radius. Defaults to the module-level
            ``radius``.

    Returns:
        list: ``scatters + lines``, the artists to redraw when ``blit=True``.
    """
    if trail is None:
        trail = history
    if sphere_radius is None:
        sphere_radius = radius

    # Move every sphere one step and record where it ended up.
    positions += velocities
    for track, centre in zip(trail, positions):
        track[num] = centre

    # Bounce off the walls of the unit box. Only flip a component that still
    # points outward: a sphere already past a wall but heading back in would
    # otherwise be flipped outward again every frame and stick to the wall.
    past_low = (positions < sphere_radius) & (velocities < 0)
    past_high = (positions > 1 - sphere_radius) & (velocities > 0)
    velocities[past_low | past_high] *= -1

    # Elastic collision of two equal masses: exchange the velocity components
    # along the line of centres.
    delta_pos = positions[1] - positions[0]
    distance = np.linalg.norm(delta_pos)
    if 0 < distance < 2 * sphere_radius:  # coincident centres have no normal
        normal = delta_pos / distance
        # Relative velocity along the normal; negative while approaching.
        normal_rel_vel = np.dot(velocities[1] - velocities[0], normal)
        # Resolve only while the spheres approach each other. If they still
        # overlap on a later frame but are already separating, applying the
        # exchange again would turn them back together and lock them.
        if normal_rel_vel < 0:
            delta_vel = normal_rel_vel * normal
            velocities[0] += delta_vel
            velocities[1] -= delta_vel

    for scatter, centre in zip(scatters, positions):
        scatter.set_offsets([centre])

    # Slice through row ``num`` inclusive so each trail reaches its marker.
    for line, track in zip(lines, trail):
        line.set_data(track[:num + 1, 0], track[:num + 1, 1])

    return scatters + lines

# Set up the figure
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_aspect('equal')

# Scatter sizes are marker areas in points**2, so derive the size from the
# axes geometry to make each drawn disc match the collision radius. This is
# exact for the initial figure size; resizing the window rescales the axes
# but not the markers.
axes_box = ax.get_position()  # figure-fraction bbox (equal aspect applied)
x_min, x_max = ax.get_xlim()
points_per_unit = axes_box.width * fig.get_size_inches()[0] * 72 / (x_max - x_min)
marker_area = (2 * radius * points_per_unit) ** 2
scatters = [ax.scatter(*pos, s=marker_area) for pos in positions]
lines = [ax.plot([], [], lw=1, linestyle='dotted')[0] for _ in positions]


def init_animation():
    """Draw the starting state without advancing the simulation.

    If no ``init_func`` is given, FuncAnimation draws its initial frame by
    calling the frame function with the first frame. Here that would step the
    physics an extra time before frame 0 is shown.
    """
    for scatter, pos in zip(scatters, positions):
        scatter.set_offsets([pos])
    for line in lines:
        line.set_data([], [])
    return scatters + lines


# Number of frames and interval between frames
frames = total_time * frames_per_second
interval = 1000 / frames_per_second  # Interval in milliseconds

# Create the animation
ani = animation.FuncAnimation(fig, update, frames=int(frames), init_func=init_animation,
                              fargs=(positions, velocities, scatters, lines),
                              interval=int(interval), blit=True)

# Uncomment to save as an MP4 file (uses the ffmpeg writer)
#ani.save('hard_spheres_scattering.mp4', writer='ffmpeg')

plt.show()
