#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
midi2adt.py (ADT v2.2)
- 2-bar MIDI(드럼) -> ADT 텍스트로 변환
- 기본: 4/4, GRID=16(스트레이트 16분), LENGTH=32, GM 12슬롯
- 트리플렛/3박 확장: --grid 8T/16T, --length 24/48

필수: pip install mido
"""

import argparse, sys, os, pathlib
from mido import MidiFile

# --- v2.2 기본 상수 ---
ADT_VERSION_STR = "ADT v2.2"
DEFAULT_GRID = "16"     # "16", "8T", "16T"
DEFAULT_LENGTH = 32     # 24/32/48
DEFAULT_TIME_SIG = "4/4"
DEFAULT_KIT = "GM_STD"
DEFAULT_ORIENTATION = "STEP"
DEFAULT_SLOTS = 12
DEFAULT_PPQN_NOTE = 96  # (정보용: 내부 ADP용 권고치, 여기선 미사용)

# GRID → 박 당 분할 수
GRID_SUBDIV = {
    "16": 4,   # 16분 = 4 subdivision per beat
    "8T": 3,   # 8분 트리플렛
    "16T": 6,  # 16분 트리플렛
}

# GM 12-slot preset: (note, abbr, name)
GM12 = [
    (36,"KK","KICK"), (38,"SN","SNARE"), (42,"CH","HH_CL"), (46,"OH","HH_OP"),
    (45,"LT","TOM_L"), (47,"MT","TOM_M"), (50,"HT","TOM_H"), (51,"RD","RIDE"),
    (49,"CR","CRASH"), (37,"RM","RIM"),  (39,"CL","CLAP"),  (44,"PH","HH_PED"),
]
NOTE2SLOT = {n:i for i,(n,_,_) in enumerate(GM12)}

def parse_args():
    p = argparse.ArgumentParser(description="2-bar MIDI (drums) → ADT (v2.2)")
    p.add_argument("input", nargs="?", help="입력 MIDI 파일 경로(.mid). --in-dir 사용 시 생략 가능")
    p.add_argument("--in-dir", type=str, default=None, help="입력 폴더(내의 .mid 일괄 변환)")
    p.add_argument("--out-dir", type=str, default=None, help="출력 폴더(기본: 입력과 동일)")
    p.add_argument("--recursive", action="store_true", help="--in-dir 사용 시 하위 폴더까지 재귀 처리")
    p.add_argument("--grid", type=str, choices=["16","8T","16T"], default=DEFAULT_GRID, help="그리드: 16/8T/16T")
    p.add_argument("--length", type=int, choices=[24,32,48], default=DEFAULT_LENGTH, help="스텝 수: 24/32/48")
    p.add_argument("--time-sig", type=str, default=DEFAULT_TIME_SIG, help="표시용 박자 (예: 4/4, 3/4, 12/8)")
    p.add_argument("--kit", type=str, default=DEFAULT_KIT, help="KIT 힌트 (예: GM_STD)")
    p.add_argument("--orientation", type=str, choices=["STEP","SLOT"], default=DEFAULT_ORIENTATION,
                   help="ADT 본문 배치 (출력은 STEP 권장)")
    p.add_argument("--channel", type=int, default=10,
                   help="드럼 채널 (1~16, 기본 10=GM drum). 0/음수 입력 금지.")
    p.add_argument("--vel-thresholds", type=str, default="64,96,112",
                   help="ACC 구간 경계 (세 개, 예: '64,96,112')")
    p.add_argument("--overwrite", action="store_true", help="기존 .ADT가 있어도 덮어쓰기")
    return p.parse_args()

def acc_from_velocity(v, thresholds):
    # thresholds: [t1,t2,t3] (예: [64,96,112]); v<=0이면 rest
    if v <= 0: return 0
    if v < thresholds[0]: return 1
    if v < thresholds[1]: return 2
    # t3는 상한 힌트지만, 등급은 3이 최대
    return 3

def acc_to_char(a):
    return ['-','.', 'o','X'][a]

def quantize_step(abs_ticks, tpq, grid, length):
    """
    abs_ticks: 현재 메시지의 절대 tick
    tpq: ticks per quarter (MIDI header)
    grid: "16"/"8T"/"16T"
    length: 총 스텝 수 (24/32/48)
    """
    subdiv = GRID_SUBDIV[grid]
    ticks_per_step = tpq / subdiv
    if ticks_per_step <= 0:
        step = 0
    else:
        step = int(round(abs_ticks / ticks_per_step))
    # 2마디(패턴 길이)에서 벗어난 이벤트는 클램프
    if step < 0: step = 0
    if step > length - 1: step = length - 1
    return step

def extract_grid_from_midi(mid: MidiFile, drum_channel_one_based: int, grid: str, length: int,
                           thresholds):
    """
    mid: mido.MidiFile
    drum_channel_one_based: 1~16 (GM 드럼은 10)
    """
    tpq = mid.ticks_per_beat
    ch_idx = drum_channel_one_based - 1  # 0~15
    grid_data = [[0]*DEFAULT_SLOTS for _ in range(length)]

    for tr in mid.tracks:
        abs_t = 0
        for msg in tr:
            abs_t += msg.time
            if not hasattr(msg, "type"): 
                continue
            if msg.type not in ("note_on","note_off"):
                continue
            # 일부 DAW가 channel 필드 없는 meta를 넣기도 함
            ch = getattr(msg, "channel", None)
            if ch is None or ch != ch_idx:
                continue
            note = getattr(msg, "note", None)
            if note is None or note not in NOTE2SLOT:
                continue
            vel = getattr(msg, "velocity", 0) if msg.type == "note_on" else 0
            if vel <= 0:  # note_off 또는 vel=0은 무시 (게이트는 엔진에서 처리)
                continue

            step = quantize_step(abs_t, tpq, grid, length)
            slot = NOTE2SLOT[note]
            acc  = acc_from_velocity(vel, thresholds)
            if acc > grid_data[step][slot]:
                grid_data[step][slot] = acc

    return tpq, grid_data

def write_adt(path_out: pathlib.Path, name_base: str, grid: str, length: int,
              time_sig: str, kit: str, orientation: str, grid_data):
    """
    grid_data: STEP-우선( length x 12 ) 형태의 ACC 등급 그리드
    orientation이 SLOT이면 출력 전에 90도 회전 출력
    """
    lines = []
    lines.append(f"; {ADT_VERSION_STR}")
    lines.append(f"NAME={name_base}")
    lines.append(f"TIME_SIG={time_sig}")
    lines.append(f"GRID={grid}")
    lines.append(f"LENGTH={length}")
    lines.append(f"SLOTS={DEFAULT_SLOTS}")
    lines.append(f"KIT={kit}")
    lines.append(f"ORIENTATION={orientation}")

    # SLOT 선언
    for idx,(note,abbr,name) in enumerate(GM12):
        lines.append(f"SLOT{idx}={abbr}@{note},{name}")

    # 본문
    if orientation == "STEP":
        # length줄 × 12문자
        for s in range(length):
            row = ''.join(acc_to_char(grid_data[s][j]) for j in range(DEFAULT_SLOTS))
            lines.append(row)
    else:
        # SLOT-우선(12줄 × length문자) — 90도 회전하여 출력
        for j in range(DEFAULT_SLOTS):
            row = ''.join(acc_to_char(grid_data[s][j]) for s in range(length))
            lines.append(row)

    text = "\n".join(lines) + "\n"
    path_out.write_text(text, encoding="utf-8")

def convert_file(path_in: pathlib.Path, out_dir: pathlib.Path, args):
    if not path_in.exists() or path_in.suffix.lower() not in [".mid", ".midi"]:
        return False, f"skip (not midi): {path_in}"

    name_base = path_in.stem
    path_out  = (out_dir / name_base).with_suffix(".ADT")

    if path_out.exists() and not args.overwrite:
        return False, f"exists: {path_out.name} (use --overwrite)"

    # 파라미터 준비
    length = args.length
    grid   = args.grid
    time_sig = args.time_sig
    kit      = args.kit
    orientation = args.orientation
    ch = args.channel
    if not (1 <= ch <= 16):
        return False, f"--channel must be 1..16 (got {ch})"
    # velocity thresholds
    try:
        th = [int(x.strip()) for x in args.vel_thresholds.split(",")]
        if len(th) != 3:
            raise ValueError
        th.sort()  # 오름차순 보장
    except Exception:
        return False, f"--vel-thresholds must be like '64,96,112'"

    # MIDI 로드 및 추출
    try:
        mid = MidiFile(str(path_in))
    except Exception as e:
        return False, f"mido load error: {path_in.name}: {e}"

    tpq, grid_data = extract_grid_from_midi(mid, ch, grid, length, th)

    # 출력
    try:
        out_dir.mkdir(parents=True, exist_ok=True)
        write_adt(path_out, name_base, grid, length, time_sig, kit, orientation, grid_data)
    except Exception as e:
        return False, f"write error: {path_out.name}: {e}"

    return True, f"ok: {path_in.name} -> {path_out.name} (tpq={tpq}, grid={grid}, len={length})"

def iter_midi_files(root: pathlib.Path, recursive: bool):
    if not recursive:
        for p in root.glob("*.mid"):
            yield p
        for p in root.glob("*.MID"):
            yield p
    else:
        for p in root.rglob("*"):
            if p.suffix.lower() in (".mid",".midi"):
                yield p

def main():
    args = parse_args()

    if args.in_dir:
        in_root = pathlib.Path(args.in_dir)
        if not in_root.exists():
            print(f"[ERR] no such dir: {in_root}", file=sys.stderr); sys.exit(1)
        out_root = pathlib.Path(args.out_dir) if args.out_dir else in_root
        total = ok = 0
        for p in iter_midi_files(in_root, args.recursive):
            total += 1
            success, msg = convert_file(p, out_root, args)
            print(("[OK] " if success else "[SKIP] ") + msg)
            if success: ok += 1
        print(f"\nDone. {ok}/{total} converted.")
        sys.exit(0)

    # 단일 파일 모드
    if not args.input:
        print("Usage: midi2adt.py <file.mid> [--grid 16|8T|16T] [--length 24|32|48] ...", file=sys.stderr)
        sys.exit(1)

    path_in = pathlib.Path(args.input)
    out_dir = pathlib.Path(args.out_dir) if args.out_dir else path_in.parent
    success, msg = convert_file(path_in, out_dir, args)
    print(("[OK] " if success else "[ERR] ") + msg)
    sys.exit(0 if success else 1)

if __name__ == "__main__":
    main()
