#!/usr/bin/env python3

import os
import sys
import argparse
import wave


class bcolors:
    """ANSI escape sequences used to colour terminal messages.

    Wrap text as ``f"{bcolors.FAIL}text{bcolors.ENDC}"``; ``ENDC`` resets the
    terminal back to its default attributes.
    """

    # Bright foreground colours.
    HEADER = "\x1b[95m"  # magenta
    OKBLUE = "\x1b[94m"
    OKCYAN = "\x1b[96m"
    OKGREEN = "\x1b[92m"
    WARNING = "\x1b[93m"  # yellow
    FAIL = "\x1b[91m"  # red

    # Text attributes and reset.
    BOLD = "\x1b[1m"
    UNDERLINE = "\x1b[4m"
    ENDC = "\x1b[0m"


# Command-line interface: an input WAV path, an output path, and --force to
# allow overwriting an existing output file.
parser = argparse.ArgumentParser(
    prog="wav_to_wavetable.py",
    description="Converts a wave file into a C-style array of 10 bit signed values.",
    # Implicit string concatenation keeps the help text free of the stray
    # indentation that backslash line continuations would embed.
    epilog=(
        "The result is written to the output file as a C array. Copy it into "
        "the appropriate place to use the wavetable (replacing the old one). "
        "If a stereo wav-file is provided, only the left channel is used."
    ),
)

parser.add_argument("-i", "--input", required=True, help="Path of the input wav file")
parser.add_argument("-o", "--output", required=True, help="Path for the output")
parser.add_argument(
    "-f", "--force", action="store_true", help="overwrite output file if it exists"
)

args = parser.parse_args()

# Validate both paths before doing any audio work. All error output goes to
# stderr so stdout stays clean.
if not os.path.isfile(args.input):
    print(
        f"{bcolors.FAIL}Error:{bcolors.ENDC} The input file at the specified path does not exist:",
        file=sys.stderr,
    )
    print(f"       {args.input}", file=sys.stderr)
    sys.exit(1)

# Refuse to clobber an existing output file unless --force was given.
if os.path.isfile(args.output) and not args.force:
    print(
        f"{bcolors.FAIL}Error:{bcolors.ENDC} There is already a file at the specified output path:",
        file=sys.stderr,
    )
    # Show the conflicting *output* path (previously the input path was echoed).
    print(f"       {args.output}", file=sys.stderr)
    print(
        f"{bcolors.OKGREEN}Hint:{bcolors.ENDC}  Run this command again with the --force flag to overwrite",
        file=sys.stderr,
    )
    sys.exit(1)

# Open the input WAV and report its format. The wave module raises
# wave.Error for inputs it cannot parse and EOFError for truncated/empty
# files; report those in the script's usual error style instead of a traceback.
path = args.input
try:
    w = wave.open(path, "rb")
except (wave.Error, EOFError) as err:
    print(
        f"{bcolors.FAIL}Error:{bcolors.ENDC} Could not read the input as a wave file: {err}",
        file=sys.stderr,
    )
    sys.exit(1)
channels = w.getnchannels()
samplewidth = w.getsampwidth()
print(
    f"Audio has {channels} channels and each sample is {samplewidth} bytes wide ({samplewidth * 8} bit audio)"
)

# Collects the 10 bit samples produced from the input frames.
resampled = []


def resample(sample: int, bits=24, target_bits: int = 10) -> int:
    """Reduce a signed PCM sample from ``bits`` to ``target_bits`` of resolution.

    The sample is divided by ``2 ** (bits - target_bits)``, truncated toward
    zero, and clamped to the signed range of ``target_bits`` (``-512..511``
    for the default 10 bits). When ``bits < target_bits`` the sample is scaled
    up instead.

    Args:
        sample: Signed sample value as decoded from the wave file.
        bits: Bit depth of ``sample`` (sample width in bytes * 8).
        target_bits: Bit depth of the returned value.

    Returns:
        The rescaled sample, guaranteed to fit a signed ``target_bits`` integer.
    """
    # A signed N-bit integer spans [-2**(N-1), 2**(N-1) - 1]; the previous
    # bound of +/-2**10 admitted values that do not fit in 10 bits.
    upper = 2 ** (target_bits - 1) - 1
    lower = -(2 ** (target_bits - 1))
    steps = 2 ** (bits - target_bits)
    # int() truncates toward zero, matching the original conversion.
    return max(lower, min(upper, int(sample / steps)))


# Read all frames, keep the first (left) channel, and reduce it to 10 bits.
with w:  # closes the wave file once the data has been read
    # Only mono and stereo inputs are handled; bail out before reading data.
    if channels not in (1, 2):
        print(
            f"{bcolors.FAIL}Error:{bcolors.ENDC} Sorry, we do not support wave files with {channels} channels.. yet",
            file=sys.stderr,
        )
        sys.exit(1)

    # readframes(k) returns the next k frames, so one call with the total
    # frame count yields the whole file in a single read.
    raw = w.readframes(w.getnframes())

# 8-bit WAV PCM is unsigned (silence = 128); wider sample widths are signed.
# WAV data is little endian regardless of the host's native byte order.
is_unsigned = samplewidth == 1
offset = 128 if is_unsigned else 0
samples = [
    int.from_bytes(
        raw[i : i + samplewidth], byteorder="little", signed=not is_unsigned
    )
    - offset
    for i in range(0, len(raw) - samplewidth + 1, samplewidth)
]

# Channels are interleaved (L, R, L, R, ... for stereo); every
# `channels`-th sample starting at index 0 is the first (left) channel.
# Each sample is scaled from its actual bit depth, for mono and stereo alike.
bits = samplewidth * 8
resampled.extend(resample(sample, bits) for sample in samples[::channels])

print(f"Resampled {len(resampled)} samples")
# NOTE: to keep only a prefix of the table, slice `resampled` here,
# e.g. resampled = resampled[: 48000 * 2].

# Build a C initializer with `elements_per_line` values per line. Every value
# is followed by a comma; C permits a trailing comma in initializer lists.
elements_per_line = 64
rows = (
    ",".join(str(s) for s in resampled[start : start + elements_per_line])
    for start in range(0, len(resampled), elements_per_line)
)
c_array = (
    f"short wavetable[{len(resampled)}] = {{"
    + "".join(f"{row},\n    " for row in rows)
    + "};"
)

with open(args.output, "w", encoding="utf-8") as f:
    f.write(c_array)
    f.write("\n")

print(f"Written {len(resampled)} 10 bit samples to {args.output}")
