from util import get_input
from itertools import product
from more_itertools import flatten
from math import sqrt, ceil, floor

# get_input is assumed to return the puzzle file as a list of lines; the
# first line should look like "target area: x=20..30, y=-10..-5".
lines = get_input("17.input")

stuff = lines[0].split()

# stuff[2] is "x=<min>..<max>," and stuff[3] is "y=<min>..<max>": drop the
# "x="/"y=" prefix (and the trailing comma on the x part) before splitting.
tx = [int(a) for a in stuff[2][2:-1].split("..")]
ty = [int(a) for a in stuff[3][2:].split("..")]

def hits(velx, tx, ty):
    """Return every launch velocity with x-velocity ``velx`` that hits the target.

    Rather than simulating, this solves the motion in closed form. Each step
    the x-velocity drops by 1 until it reaches 0 (drag) and the y-velocity
    drops by 1 (gravity), so after ``n`` steps (``n <= velx`` for x)::

        x(n) = velx * n - n * (n - 1) / 2
        y(n) = vely * n - n * (n - 1) / 2

    First the steps ``n`` at which x lies inside ``tx`` are found by solving
    the quadratic for x. Then, for each such step, the range of initial
    y-velocities that put y inside ``ty`` at exactly that step follows from
    the (linear) equation for y.

    Args:
        velx: Initial x-velocity. Only rightward motion is modelled, so a
            negative value returns ``[]``.
        tx: ``[xmin, xmax]`` of the target area (inclusive).
        ty: ``[ymin, ymax]`` of the target area (inclusive). It is assumed to
            lie entirely above or entirely below y = 0. If it contained 0,
            the set of hitting y-velocities would be unbounded.

    Returns:
        Sorted list of distinct ``(velx, vely)`` tuples.
    """
    if velx < 0:
        return []

    # x stops increasing after velx steps, at this triangular number.
    final_x = velx * (velx + 1) // 2
    if final_x < tx[0]:
        # Drag stops the probe before it reaches the target's left edge.
        return []

    # High school math comes in handy once again: rearranging x(n) = X gives
    #   n^2 - (2 velx + 1) n + 2 X = 0
    # and its smaller root is the (fractional) step at which x first reaches X.
    # The discriminant equals 2 * (final_x - X) + 1/4, so it is positive
    # whenever final_x >= X, and sqrt never sees a negative argument below.
    half_b = (2 * velx + 1) / 2

    def first_step_reaching(target_x):
        # Smaller root of the quadratic above for X = target_x.
        return half_b - sqrt(half_b * half_b - 2 * target_x)

    # Step 0 is the launch point itself, not a step, hence the clamp to 1
    # (also avoids dividing by zero below when tx[0] <= 0).
    minsteps = max(1, ceil(first_step_reaching(tx[0])))

    if final_x <= tx[1]:
        # The probe comes to rest inside the target's x-range and stays there
        # (this includes resting exactly on the right edge), so only y limits
        # which steps count. No hit is possible after step
        # 2 * max(|ymin|, |ymax|) + 1 for the two cases below.
        # Target below the origin: the fastest faller that still lands does
        # so at step 2 * |ymin|.
        # Target above the origin: it needs vely <= ymax, and the probe is
        # back at y = 0 by step 2 * vely + 1.
        maxsteps = 2 * max(abs(ty[0]), abs(ty[1])) + 1
    else:
        # The probe overshoots the right edge; the last in-range step is the
        # one just before x passes xmax.
        maxsteps = floor(first_step_reaching(tx[1]))

    found = set()
    for nstep in range(minsteps, maxsteps + 1):
        # Rearranging y(n) = Y gives vely = (n * (n - 1) / 2 + Y) / n.
        # n * (n - 1) is always even, so integer arithmetic stays exact.
        drop = nstep * (nstep - 1) // 2
        minvely = -(-(drop + ty[0]) // nstep)  # ceiling division
        maxvely = (drop + ty[1]) // nstep
        found.update((velx, vely) for vely in range(minvely, maxvely + 1))
    return sorted(found)


# Brute-force reference implementation that steps the probe one tick at a
# time. Much slower than hits(); kept for cross-checking (and so you can
# laugh at me).
def naive_hits(vel, tx, ty):
    """Simulate a single launch and report whether it lands in the target.

    Args:
        vel: ``(velx, vely)`` launch velocity.
        tx: ``[xmin, xmax]`` of the target area (inclusive).
        ty: ``[ymin, ymax]`` of the target area (inclusive).

    Returns:
        ``(topy, vel)`` on a hit, where ``topy`` is the highest y reached so
        far (the launch height 0 included). On a miss it returns
        ``(-1000, vel)``. ``vel`` is always the launch velocity passed in.
    """
    topy = 0
    pos = (0, 0)
    startvel = vel
    while True:
        x, y = pos
        if tx[0] <= x <= tx[1] and ty[0] <= y <= ty[1]:
            return (topy, startvel)
        # Below the target and no longer rising: it can never come back up.
        # Checking the y-velocity keeps targets above the launch point
        # reachable while the probe is still climbing.
        if y < ty[0] and vel[1] <= 0:
            return (-1000, startvel)
        # Past the right edge; drag never reverses the x-direction.
        if x > tx[1]:
            return (-1000, startvel)
        # Drag pulls the x-velocity one step toward 0.
        if vel[0] == 0:
            drag = 0
        elif vel[0] < 0:
            drag = 1
        else:
            drag = -1
        pos = (x + vel[0], y + vel[1])
        vel = (vel[0] + drag, vel[1] - 1)
        topy = max(pos[1], topy)


# Collect every launch velocity that lands in the target. Any velx larger
# than the target's right edge overshoots on the very first step, so
# 0..tx[1] covers every candidate (assumes the target lies at x >= 0).
# The result gets its own name so the function `hits` is not shadowed.
all_hits = list(flatten(hits(x, tx, ty) for x in range(tx[1] + 1)))

# Part 1: the highest arc belongs to the largest initial y-velocity. Launched
# with vely > 0, the probe climbs vely + (vely - 1) + ... + 1 before stalling,
# i.e. vely * (vely + 1) / 2. A non-positive vely never rises above the
# launch height 0.
maxvely = max(h[1] for h in all_hits)
t = max(maxvely, 0)
maxy = t * (t + 1) // 2

print("Part 1:", maxy)

# Part 2: hits() returns distinct (velx, vely) pairs and each call uses a
# different velx, so the total count is just the list length.
print("Part 2:", len(all_hits))