#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
split_drums_2bar_save_v4a.py
- Drum-only (CH10) patterns extracted as non-overlapping 2-bar slices
- Deduplicate identical patterns (by canonical note sequence + tempo map)
- Save as 8.3 filename: <GEN>_P<NNN>.MID (e.g., RCK_P001.MID)
- Infer <GEN> from INPUT filename, with **leading digits stripped** (e.g., "2FUNK1.MID" -> detect FUNK -> FNK)
- Options:
    --start NNN   : starting pattern index (default 1 => 001)
    --genre XYZ   : force 3-letter genre code, bypass inference
Usage:
    python split_drums_2bar_save_v4a.py INPUT.MID [--start 101] [--genre RCK]
"""
import argparse
import re
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import mido

# -------- Genre inference helpers --------
# Ordered table of (compiled case-insensitive pattern, 3-letter genre code).
# infer_genre_code_from_name() walks it top to bottom and returns the code of
# the first pattern found in the filename stem, so earlier rows take priority.
GENRE_MAP = [
    (re.compile(expr, re.IGNORECASE), code)
    for expr, code in (
        (r'rock', 'RCK'),
        (r'bossa|bossanova', 'BOS'),
        (r'funk', 'FNK'),
        (r'jazz', 'JZZ'),
        (r'blues?', 'BLU'),
        (r'latin', 'LAT'),
        (r'samba', 'SMB'),
        (r'waltz|wlz', 'WLZ'),
        (r'swing|swg', 'SWG'),
        (r'shuffle|shf', 'SHF'),
        (r'reggae', 'REG'),
        (r'metal', 'MTL'),
        (r'hip\s*-?\s*hop|hiphop|hhp', 'HHP'),
        (r'house|hse', 'HSE'),
        (r'techno|tno', 'TNO'),
    )
]

def strip_leading_digits(name_noext: str) -> str:
    """Drop a numeric prefix (and any separators right after it) from a filename stem.

    The prefix is one or more ASCII digits followed by any run of space,
    underscore, hyphen or dot characters. A stem that does not start with a
    digit is returned unchanged.
       e.g., '2FUNK1' -> 'FUNK1', '12_ROCK_A' -> 'ROCK_A'
    """
    prefix = re.match(r'^[0-9]+[ _\-.]*', name_noext)
    if prefix is None:
        return name_noext
    return name_noext[prefix.end():]

def infer_genre_code_from_name(filename: str) -> str:
    """Guess the 3-letter genre code from a filename.

    The extension is dropped and leading digits are stripped from the stem;
    the code of the first GENRE_MAP entry whose pattern is found in what
    remains is returned. Falls back to 'DRM' when nothing matches.
    """
    cleaned_stem = strip_leading_digits(Path(filename).stem)
    hits = (code for pat, code in GENRE_MAP if pat.search(cleaned_stem))
    return next(hits, 'DRM')  # default

def sanitize_83(s: str) -> str:
    """Return *s* upper-cased with everything except ASCII letters, digits and '_' removed.

    No truncation happens here; callers cut the result to the length they need.
    """
    # Splitting on each disallowed character and re-joining drops those characters.
    kept_pieces = re.split(r'[^A-Za-z0-9_]', s)
    return ''.join(kept_pieces).upper()

# -------- MIDI helpers --------
def current_time_signature(track: mido.MidiTrack, abs_tick: int, ticks_per_beat: int) -> Tuple[int, int]:
    """Look up the time signature in force at an absolute tick.

    Walks *track* accumulating delta times and remembers the most recent
    'time_signature' event whose absolute time is <= abs_tick (an event
    exactly at abs_tick counts). Returns (numerator, denominator), or (4, 4)
    if no such event exists. *ticks_per_beat* is accepted but not used by
    the lookup.
    """
    signature = (4, 4)
    elapsed = 0
    for event in track:
        elapsed += event.time
        if elapsed > abs_tick:
            # Everything from here on lies after the requested tick.
            return signature
        if event.type == 'time_signature':
            signature = (event.numerator, event.denominator)
    return signature

def current_tempo(track: mido.MidiTrack, abs_tick: int) -> int:
    """Look up the tempo (microseconds per beat) in force at an absolute tick.

    Walks *track* accumulating delta times and keeps the value of the most
    recent 'set_tempo' event whose absolute time is <= abs_tick (an event
    exactly at abs_tick counts). Returns 500000 (120 bpm) if there is none.
    """
    active = 500000
    elapsed = 0
    for event in track:
        elapsed += event.time
        if elapsed > abs_tick:
            # Later events cannot affect the tempo at abs_tick.
            return active
        if event.type == 'set_tempo':
            active = event.tempo
    return active

def ticks_per_bar(tpb: int, num: int, den: int) -> int:
    """Return the length of one bar in ticks for a num/den time signature.

    *tpb* is ticks per quarter note, so a bar is tpb * num * (4/den) ticks,
    rounded to the nearest integer.
    """
    quarters_per_unit = 4.0 / den
    exact_length = tpb * num * quarters_per_unit
    return int(round(exact_length))

def build_absolute_track(track: mido.MidiTrack) -> List[mido.Message]:
    """Return copies of the messages in *track* with delta times replaced by absolute ticks.

    Each output message is `msg.copy(time=<running sum of deltas>)`; the
    input messages themselves are not modified.
    """
    def _with_absolute_times():
        clock = 0
        for event in track:
            clock += event.time
            yield event.copy(time=clock)

    return list(_with_absolute_times())

def extract_tempo_changes(abs_msgs: List[mido.Message], start: int, end: int) -> List[Tuple[int,int]]:
    """Collect the 'set_tempo' events whose absolute time lies in [start, end).

    Returns (rel_tick, tempo) pairs in the order the events appear in
    *abs_msgs*, where rel_tick is the event time minus *start*.
    """
    return [
        (event.time - start, event.tempo)
        for event in abs_msgs
        if event.type == 'set_tempo' and start <= event.time < end
    ]

def canonical_pattern_signature(abs_msgs: List[mido.Message], start: int, end: int, channel: int=9) -> Tuple:
    """Build a hashable fingerprint of the window [start, end) for deduplication.

    The result is (tempos, notes):
      - tempos: (rel_tick, tempo) for every 'set_tempo' inside the window
      - notes:  (rel_tick, note, velocity, is_on) for every note_on/note_off
                on *channel* inside the window, in message order. A note_on
                with velocity 0 is recorded as an off event with velocity 0;
                a note_off keeps its velocity (0 if the attribute is missing).
    """
    tempo_part = []
    note_part = []
    for msg in abs_msgs:
        if not (start <= msg.time < end):
            continue
        offset = msg.time - start
        kind = msg.type
        if kind == 'set_tempo':
            tempo_part.append((offset, msg.tempo))
            continue
        if kind not in ('note_on', 'note_off'):
            continue
        if getattr(msg, 'channel', None) != channel:
            continue
        if kind == 'note_off':
            note_part.append((offset, msg.note, getattr(msg, 'velocity', 0), 0))
        elif msg.velocity == 0:
            # note_on with velocity 0 is a note_off in disguise
            note_part.append((offset, msg.note, 0, 0))
        else:
            note_part.append((offset, msg.note, msg.velocity, 1))
    return (tuple(tempo_part), tuple(note_part))

def _build_slice_file(abs_msgs, src_track, start_tick: int, end_tick: int, ticks_per_beat: int):
    """Render the window [start_tick, end_tick) as a single-track mido.MidiFile.

    abs_msgs is the absolute-time message list, src_track the delta-time track
    it was built from (used for the time-signature / tempo lookups). The track
    holds the time signature and tempo in force at start_tick, the tempo
    changes inside the window, and the channel-10 (index 9) note events,
    re-timed relative to start_tick.
    """
    out = mido.MidiFile(type=1, ticks_per_beat=ticks_per_beat)
    tr = mido.MidiTrack()
    out.tracks.append(tr)

    # Header state: time signature and tempo active at the slice start.
    num, den = current_time_signature(src_track, start_tick, ticks_per_beat)
    tr.append(mido.MetaMessage('time_signature', numerator=num, denominator=den, time=0))
    tr.append(mido.MetaMessage('set_tempo', tempo=current_tempo(src_track, start_tick), time=0))

    # Channel-10 notes inside the window, with times relative to the slice start.
    window_msgs = []
    for m in abs_msgs:
        if not (start_tick <= m.time < end_tick):
            continue
        if m.type not in ('note_on', 'note_off') or getattr(m, 'channel', None) != 9:
            continue
        t_rel = m.time - start_tick
        if m.type == 'note_on' and m.velocity == 0:
            # normalize vel=0 note_on => note_off
            window_msgs.append(mido.Message('note_off', note=m.note, velocity=0, channel=9, time=t_rel))
        else:
            window_msgs.append(m.copy(time=t_rel))

    # Tempo changes inside the window; the stable sort puts tempo before notes at the same tick.
    for t_rel, tp in extract_tempo_changes(abs_msgs, start_tick, end_tick):
        window_msgs.append(mido.MetaMessage('set_tempo', tempo=tp, time=t_rel))
    window_msgs.sort(key=lambda x: (x.time, 0 if x.is_meta else 1))

    # Convert the relative absolute times back to delta times.
    last = 0
    for m in window_msgs:
        dt = m.time - last
        last = m.time
        m.time = dt
        tr.append(m)

    # Pad the end of track so the file lasts exactly the window length; with a
    # zero delta the pattern would end at its last note event instead of on
    # the bar line.
    tr.append(mido.MetaMessage('end_of_track', time=(end_tick - start_tick) - last))
    return out


def slice_and_save_2bars(infile: Path, start_idx: int, forced_genre: Optional[str]) -> None:
    """Split *infile* into non-overlapping 2-bar drum patterns and save the distinct ones.

    Each distinct pattern (by canonical_pattern_signature on channel index 9)
    is written once, in order of first appearance, next to the input file as
    <GEN>_P<NNN>.MID. A frequency report of all patterns is printed afterwards.

    Args:
        infile: path of a Type 0 or Type 1 MIDI file.
        start_idx: number used for the first saved pattern; must stay within
            0..999 so the name fits the 3-digit 8.3 scheme.
        forced_genre: genre code to use instead of inferring one from the
            filename; sanitized and cut to 3 characters. If nothing usable
            remains, the code is inferred from the filename.

    Raises:
        ValueError: if the file is not MIDI Type 0 or 1.

    Note:
        The bar length comes from the time signature at tick 0 only; time
        signature changes later in the file do not move the slice boundaries.
    """
    mf = mido.MidiFile(infile)
    # Explicit check instead of `assert`, which is skipped under `python -O`.
    if mf.type not in (0, 1):
        raise ValueError("Type 0 or 1 only.")

    # Single combined track view for timing/tempo meta lookups.
    if mf.type == 0:
        src_track = mf.tracks[0]
    else:
        src_track = mido.merge_tracks(mf.tracks)

    abs_msgs = build_absolute_track(src_track)

    # Bar length from the time signature active at t=0.
    num0, den0 = current_time_signature(src_track, 0, mf.ticks_per_beat)
    bar_ticks = ticks_per_bar(mf.ticks_per_beat, num0, den0)
    if bar_ticks <= 0:
        # A zero numerator or a very coarse resolution rounds to 0 ticks per bar.
        print("Cannot determine a bar length from the time signature.")
        return

    # Total length in ticks is the absolute time of the last message.
    total_ticks = abs_msgs[-1].time if abs_msgs else 0
    total_bars = total_ticks // bar_ticks
    if total_bars < 2:
        print("Not enough bars for 2-bar slicing.")
        return

    # Genre code: forced value if it survives sanitizing, otherwise inferred.
    genre = sanitize_83(forced_genre)[:3] if forced_genre else ''
    if not genre:
        genre = infer_genre_code_from_name(infile.name)

    out_dir = infile.parent
    slice_ticks = 2 * bar_ticks

    # One pass over the slices: count every signature and remember where each
    # distinct one first starts (dict insertion order = order of appearance).
    sig_counts: Dict[Tuple, int] = {}
    first_starts: List[int] = []
    for b in range(0, total_bars - 1, 2):
        start_tick = b * bar_ticks
        sig = canonical_pattern_signature(abs_msgs, start_tick, start_tick + slice_ticks, channel=9)
        if sig not in sig_counts:
            first_starts.append(start_tick)
        sig_counts[sig] = sig_counts.get(sig, 0) + 1

    # Save the first occurrence of each signature.
    idx = start_idx
    for done, start_tick in enumerate(first_starts):
        if not 0 <= idx <= 999:
            # A 4-digit (or negative) index would be cut by the 8-character
            # limit and overwrite an earlier file (e.g. P1000 -> P100).
            print(f"Index {idx} does not fit <GEN>_P<NNN>; {len(first_starts) - done} pattern(s) not saved.")
            break

        out = _build_slice_file(abs_msgs, src_track, start_tick, start_tick + slice_ticks, mf.ticks_per_beat)

        # 8.3 filename: "XXX_PNNN.MID" (3 + 1 + 4 = 8)
        basename = sanitize_83(f"{genre[:3]}_P{idx:03d}")[:8]
        outpath = out_dir / f"{basename}.MID"
        out.save(outpath)
        print(f"Saved: {outpath.name}")
        idx += 1

    # Report frequency (the sort is stable, so ties keep order of appearance).
    print("== Pattern frequency (desc) ==")
    ranked = sorted(sig_counts.values(), reverse=True)
    for rank, count in enumerate(ranked, 1):
        print(f"{rank:02d}. count={count}")

def main():
    """Command-line entry point: parse arguments, check the input exists, run the slicer.

    Exits via SystemExit with a message when the input path does not exist.
    """
    parser = argparse.ArgumentParser()
    # (positional/flag name, keyword options) for every supported argument
    arg_table = (
        ('input', dict(help='Input MIDI file (Type 0 or 1).')),
        ('--start', dict(type=int, default=1, help='Starting index (default: 1 => 001).')),
        ('--genre', dict(type=str, default=None, help='Force 3-letter genre code (e.g., RCK).')),
    )
    for flag, options in arg_table:
        parser.add_argument(flag, **options)
    opts = parser.parse_args()

    infile = Path(opts.input)
    if infile.exists():
        slice_and_save_2bars(infile, start_idx=opts.start, forced_genre=opts.genre)
        return
    raise SystemExit(f"Input not found: {infile}")

if __name__ == '__main__':
    main()
