import time
import logging
import os
import sys
import clr
import ctypes
import re
import shutil
import psutil
from datetime import datetime
from library.lcd.lcd_comm import Orientation
from library.lcd.lcd_comm_rev_a import LcdCommRevA
import socket
import subprocess

# Root logger: INFO level with millisecond-resolution timestamps.
# force=True removes any handlers already attached to the root logger
# (e.g. by imported library code) before applying this configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s.%(msecs)03d [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    force=True
)

# ====== Display Settings (★ customized color theme ★) ======
# Panel resolution in pixels; passed to LcdCommRevA in initialize_display().
DISPLAY_WIDTH = 480
DISPLAY_HEIGHT = 320
# Font used for every text element (path relative to the working directory).
FONT_PATH = "res/fonts/jetbrains-mono/JetBrainsMono-Bold.ttf"
# Colors are (R, G, B) tuples.
BG_COLOR = (0, 0, 0)
CPU_COLOR = (255, 128, 0)   # orange
RAM_COLOR = (0, 255, 255)   # cyan
GPU_COLOR = (128, 255, 0)   # lime green
NET_COLOR = (0, 128, 255)   # sky blue
SSD_COLOR = (255, 0, 255)   # magenta
TEXT_COLOR = (255, 255, 255) # white (section titles)
# Seconds to sleep between full refreshes of the dynamic values (see main()).
UPDATE_INTERVAL = 0.5
# Optional bitmap drawn once at startup if the file exists.
BACKGROUND_IMG = "background.png"
# Small vertical nudge (px) used to line progress bars up with their text labels.
BAR_OFFSET = 1

# ====== Setup LibreHardwareMonitor ======
# Load the LibreHardwareMonitor .NET assembly through pythonnet (`clr`). The DLL
# path is built from the current working directory, so the script has to be
# started from the directory that contains `external/`.
clr.AddReference(os.path.join(os.getcwd(), 'external', 'LibreHardwareMonitor', 'LibreHardwareMonitorLib.dll'))
from LibreHardwareMonitor import Hardware

# One shared Computer instance for the whole script, with only the sensor groups
# this dashboard displays enabled. It is closed in the `__main__` block's `finally`.
handle = Hardware.Computer()
handle.IsCpuEnabled = True
handle.IsGpuEnabled = True
handle.IsMemoryEnabled = True
handle.IsNetworkEnabled = True
handle.IsStorageEnabled = True
handle.Open()

# ====== Global Variables ======
# CPU model name shown on screen; replaced by initialize_hardware_names() at startup.
CPU_NAME = "Unknown CPU"

# ====== Sensor Query Functions ======
def get_sensor_value(hw_list, hw_type, sensor_type, sensor_name, hw_name=None):
    """Return the value of the first matching sensor, or ``None`` if there is none.

    Every device in ``hw_list`` whose ``HardwareType`` equals ``hw_type`` (and,
    when ``hw_name`` is non-empty, whose name contains it case-insensitively)
    is refreshed with ``Update()`` and searched for a sensor whose type string
    equals ``sensor_type`` and whose name equals ``sensor_name`` exactly.
    The matching sensor's ``Value`` is returned as-is (it may itself be ``None``).
    """
    name_fragment = hw_name.lower() if hw_name else None
    for device in hw_list:
        if device.HardwareType != hw_type:
            continue
        if name_fragment is not None and name_fragment not in device.Name.lower():
            continue
        device.Update()
        found = next(
            (s for s in device.Sensors
             if str(s.SensorType) == sensor_type and s.Name == sensor_name),
            None,
        )
        if found is not None:
            return found.Value
    return None

def get_hardware_name(hw_list, hw_type):
    """Return the ``Name`` of the first device in ``hw_list`` of ``hw_type``, or ``None``."""
    return next(
        (device.Name for device in hw_list if device.HardwareType == hw_type),
        None,
    )

def truncate_first_word(name_str):
    """Drop the first whitespace-separated word of ``name_str``.

    The remaining words are re-joined with single spaces. A string with zero
    or one word is returned unchanged (including its original whitespace).
    """
    words = name_str.split()
    if len(words) <= 1:
        return name_str
    return " ".join(words[1:])

def initialize_hardware_names():
    """Set the global ``CPU_NAME`` from the first CPU LibreHardwareMonitor reports.

    The first word of the model name is dropped via ``truncate_first_word``;
    "Unknown CPU" is used when no CPU (or an empty name) is found.
    """
    global CPU_NAME
    detected_name = get_hardware_name(handle.Hardware, Hardware.HardwareType.Cpu)
    CPU_NAME = truncate_first_word(detected_name or "Unknown CPU")

def get_gpu_stats(hw_list, filter_str):
    """Collect load, temperature, clock and VRAM figures for one GPU.

    ``filter_str`` selects the vendor (it must contain "amd" or "nvidia",
    case-insensitively) and must also appear in the device's name. Every
    value is read from that single device after one ``Update()`` call, so the
    reported name and readings always belong to the same GPU.

    Returns:
        dict with ``name``, ``util`` (%), ``temp`` (C), ``clock`` (MHz),
        ``mem_used``/``mem_total`` (MB) and ``mem_percent``, or ``None`` when
        the vendor is unsupported or no matching GPU is present. Missing
        sensors read as 0.0. When "GPU Memory Total" is unavailable the total
        is derived from used + "GPU Memory Free" if that sensor exists;
        otherwise it is reported as 0 and ``mem_percent`` as 0 (instead of a
        bogus percentage against a 1 MB placeholder).
    """
    needle = filter_str.lower()
    if "amd" in needle:
        hw_type = Hardware.HardwareType.GpuAmd
    elif "nvidia" in needle:
        hw_type = Hardware.HardwareType.GpuNvidia
    else:
        return None

    for hw in hw_list:
        if hw.HardwareType != hw_type or needle not in hw.Name.lower():
            continue
        hw.Update()
        # First sensor wins for each (type, name) pair, like get_sensor_value().
        readings = {}
        for sensor in hw.Sensors:
            readings.setdefault((str(sensor.SensorType), sensor.Name), sensor.Value)

        mem_used = readings.get(("SmallData", "GPU Memory Used")) or 0.0
        mem_total = readings.get(("SmallData", "GPU Memory Total")) or 0.0
        if not mem_total:
            mem_free = readings.get(("SmallData", "GPU Memory Free"))
            if mem_free is not None:
                mem_total = mem_used + mem_free

        return {
            "name": hw.Name,
            "util": readings.get(("Load", "GPU Core")) or 0.0,
            "temp": readings.get(("Temperature", "GPU Core")) or 0.0,
            "clock": readings.get(("Clock", "GPU Core")) or 0.0,
            "mem_used": mem_used,
            "mem_total": mem_total,
            "mem_percent": (mem_used / mem_total) * 100 if mem_total > 0 else 0,
        }
    return None

# Extracts the first "#<n>" number from a core sensor name such as "CPU Core #3".
_CORE_NUMBER_RE = re.compile(r'#(\d+)')

def get_sorted_core_loads(hw_list):
    """Return per-core CPU load readings ordered by core number.

    Every CPU device in ``hw_list`` and each of its ``SubHardware`` entries is
    refreshed with ``Update()`` and scanned for "Load" sensors whose name
    contains "Core #".

    Returns:
        List of ``(core_number, sensor_name, load_percent)`` tuples, one per
        distinct sensor name (first occurrence wins), sorted by core number and
        then by name. The name tie-break keeps sensors that share a core number
        in a stable order from frame to frame. Missing values read as 0.0.
    """
    loads_by_name = {}
    for hw in hw_list:
        if hw.HardwareType != Hardware.HardwareType.Cpu:
            continue
        for device in (hw, *(getattr(hw, 'SubHardware', None) or ())):
            device.Update()
            for sensor in device.Sensors:
                name = sensor.Name
                if str(sensor.SensorType) != "Load" or "Core #" not in name or name in loads_by_name:
                    continue
                match = _CORE_NUMBER_RE.search(name)
                if match:
                    loads_by_name[name] = (int(match.group(1)), name, sensor.Value or 0.0)
    # Sorting on the core number alone left ties in arbitrary set() order,
    # which made bars swap places between refreshes.
    return sorted(loads_by_name.values(), key=lambda entry: (entry[0], entry[1]))

# Previous sample of the active adapter's cumulative counters, kept by
# get_network_stats() to turn totals into per-second rates.
last_net_check_time = None  # time.monotonic() of the previous sample; None before the first one
last_net_uploaded = 0
last_net_downloaded = 0
last_net_adapter_name = None  # adapter the previous sample was taken from

def _read_adapter_sensor(hw, sensor_type, sensor_name):
    """Return the value of ``hw``'s sensor with this type/name, or ``None``.

    Reads ``hw.Sensors`` directly (the caller is responsible for ``Update()``),
    so adapters whose names contain one another ("Ethernet" / "Ethernet 2")
    can never be confused the way a name-substring lookup would.
    """
    for sensor in hw.Sensors:
        if str(sensor.SensorType) == sensor_type and sensor.Name == sensor_name:
            return sensor.Value
    return None

def get_network_stats(hw_list):
    """Find the busiest network adapter and report its name and current throughput.

    The adapter with the largest cumulative "Data Downloaded" counter is
    treated as active. Its "Data Uploaded"/"Data Downloaded" counters
    (gigabytes) are compared with the previous call's sample to get rates.

    Returns:
        dict with ``'up'`` and ``'down'`` in Mbit/s (1 GB = 1024 MB) and
        ``'adapter_name'`` (``'N/A'`` when no adapter reports data). Rates are
        0.0 on the first call, right after the active adapter changes, and
        whenever a counter goes backwards (e.g. after a counter reset).
    """
    global last_net_check_time, last_net_uploaded, last_net_downloaded, last_net_adapter_name

    stats = {'up': 0.0, 'down': 0.0, 'adapter_name': 'N/A'}

    # 1. Pick the adapter with the most data transferred (the active one).
    active_adapter = None
    max_data = -1
    for hw in hw_list:
        if hw.HardwareType != Hardware.HardwareType.Network:
            continue
        hw.Update()
        total_downloaded = _read_adapter_sensor(hw, "Data", "Data Downloaded")
        if total_downloaded is not None and total_downloaded > max_data:
            max_data = total_downloaded
            active_adapter = hw

    if active_adapter is None:
        return stats

    adapter_name = active_adapter.Name
    stats['adapter_name'] = adapter_name

    # 2. Read the cumulative counters (gigabytes) of that adapter.
    current_uploaded_gb = _read_adapter_sensor(active_adapter, "Data", "Data Uploaded") or 0.0
    current_downloaded_gb = _read_adapter_sensor(active_adapter, "Data", "Data Downloaded") or 0.0
    current_time = time.monotonic()  # immune to wall-clock adjustments

    # 3. Rates from the previous sample, but only if it came from the same
    #    adapter; otherwise the delta would mix two adapters' totals.
    if last_net_check_time is not None and last_net_adapter_name == adapter_name:
        time_delta = current_time - last_net_check_time
        if time_delta > 0:
            upload_delta = max(current_uploaded_gb - last_net_uploaded, 0.0)
            download_delta = max(current_downloaded_gb - last_net_downloaded, 0.0)
            # GB -> Mbit: x 8 bits per byte x 1024 MB per GB.
            stats['up'] = (upload_delta * 8 * 1024) / time_delta
            stats['down'] = (download_delta * 8 * 1024) / time_delta

    # Remember this sample for the next call.
    last_net_check_time = current_time
    last_net_uploaded = current_uploaded_gb
    last_net_downloaded = current_downloaded_gb
    last_net_adapter_name = adapter_name

    return stats

# Gateway labels printed by `ipconfig` on English and Korean Windows.
_GATEWAY_LABELS = ("Default Gateway", "기본 게이트웨이")
# Last get_ip_info() result and the time.monotonic() at which it was produced.
_ip_info_cache = {'time': None, 'info': None}

def get_ip_info(max_age=30.0):
    """Return ``{'ip': ..., 'gateway': ...}`` for this host; unknown values are ``'N/A'``.

    The IP is what ``socket.gethostbyname(socket.gethostname())`` resolves to.
    The gateway is the first non-empty "Default Gateway" entry (English or
    Korean label) in ``ipconfig`` output. Everything after the first colon is
    kept, so IPv6 gateways such as ``fe80::1%12`` are not truncated.

    Results are cached for ``max_age`` seconds, because this is called on every
    display refresh and spawning ``ipconfig`` each time is expensive. Pass
    ``max_age=0`` to force a fresh lookup. Failures are logged and cached too,
    so a persistent error does not flood the log.

    Returns:
        A new dict each call (callers may mutate it safely).
    """
    now = time.monotonic()
    cached_at = _ip_info_cache['time']
    if cached_at is not None and now - cached_at < max_age:
        return dict(_ip_info_cache['info'])

    info = {'ip': 'N/A', 'gateway': 'N/A'}
    try:
        hostname = socket.gethostname()
        info['ip'] = socket.gethostbyname(hostname)
        # Argument list instead of shell=True; errors="replace" tolerates
        # console-codepage output that does not decode cleanly.
        result = subprocess.check_output(
            ["ipconfig"], text=True, errors="replace",
            stderr=subprocess.DEVNULL, timeout=5,
        )
        for line in result.splitlines():
            label, sep, value = line.partition(":")
            if not sep or not any(name in label for name in _GATEWAY_LABELS):
                continue
            gateway = value.strip()
            if len(gateway) > 1:
                info['gateway'] = gateway
                break
    except Exception as e:
        logging.error(f"IP 정보 가져오기 실패: {e}")

    _ip_info_cache['time'] = now
    _ip_info_cache['info'] = dict(info)
    return info

# Previous system-wide disk I/O sample, kept by get_storage_stats() to turn
# psutil's cumulative byte counters into per-second rates.
last_io_check_time = None  # time.monotonic() of the previous sample; None before the first one
last_bytes_read = 0
last_bytes_written = 0

def get_storage_stats(path="C:\\"):
    """Return disk usage for ``path`` plus current disk read/write throughput.

    Usage comes from ``shutil.disk_usage(path)``. Throughput comes from
    ``psutil.disk_io_counters()`` (totals across all disks), compared with the
    sample taken on the previous call.

    Args:
        path: any path on the volume to report; defaults to ``C:\\``.

    Returns:
        dict with ``used_gb``/``total_gb`` (GiB), ``used_pct`` (0-100) and
        ``read``/``write`` in MB/s (MiB). Rates are 0.0 on the first call and
        whenever a counter goes backwards. Failures are logged and leave the
        affected fields at 0.
    """
    global last_io_check_time, last_bytes_read, last_bytes_written
    stats = {'used_gb': 0, 'total_gb': 0, 'used_pct': 0, 'read': 0.0, 'write': 0.0}

    # 1. Capacity / usage of the volume containing `path`.
    try:
        total, used, _free = shutil.disk_usage(path)
        stats['used_gb'] = used / (1024**3)
        stats['total_gb'] = total / (1024**3)
        stats['used_pct'] = (used / total) * 100 if total else 0
    except Exception as e:
        logging.error(f"disk usage failed: {e}")

    # 2. Read/write throughput from psutil's cumulative counters.
    try:
        disk_io = psutil.disk_io_counters()
    except Exception as e:
        logging.error(f"disk I/O counters failed: {e}")
        return stats
    if disk_io is None:
        # psutil may return None when it cannot find any disk.
        return stats

    now = time.monotonic()  # immune to wall-clock adjustments
    if last_io_check_time is not None:
        elapsed = now - last_io_check_time
        if elapsed > 0:
            read_delta = max(disk_io.read_bytes - last_bytes_read, 0)
            write_delta = max(disk_io.write_bytes - last_bytes_written, 0)
            # Bytes/s -> MB/s
            stats['read'] = read_delta / elapsed / (1024 * 1024)
            stats['write'] = write_delta / elapsed / (1024 * 1024)

    # Remember this sample for the next call.
    last_io_check_time = now
    last_bytes_read = disk_io.read_bytes
    last_bytes_written = disk_io.write_bytes
    return stats

# ====== Drawing Functions ======

def draw_gpu_section(lcd, x, y, stats):
    """Draw one GPU block (name, load bar, temperature/clock, VRAM bar) at (x, y).

    ``stats`` is the dict produced by ``get_gpu_stats``. Returns the y
    coordinate just below the block (114 px further down), or ``y`` unchanged
    when ``stats`` is empty/None.
    """
    if not stats:
        return y

    bar_width = 180

    def put_text(text, top):
        lcd.DisplayText(text, x, top, font=FONT_PATH, font_size=15,
                        font_color=GPU_COLOR, background_color=BG_COLOR)

    def put_bar(percent, top):
        lcd.DisplayProgressBar(x, top, bar_width, 8, 0, 100, percent,
                               bar_color=GPU_COLOR, bar_outline=True,
                               background_color=BG_COLOR)

    put_text(stats["name"], y)
    put_text(f"Util: {int(stats['util']):3d}%", y + 20)
    put_bar(int(stats['util']), y + 38)

    # Trailing padding erases longer text from a previous frame.
    temp_freq_str = f"Temp: {int(stats['temp']):2d}C   Freq: {int(stats['clock']):4d}MHz"
    put_text(f"{temp_freq_str:<30}", y + 56)

    mem_str = f"Mem: {int(stats['mem_used'])}MB/{int(stats['mem_total'])}MB"
    put_text(f"{mem_str:<30}", y + 76)
    put_bar(int(stats['mem_percent']), y + 94)

    return y + 114

def draw_network_section(lcd, x, y, max_speed_mbps=100):
    """Draw the network block (title, upload/download bars, IP) starting at (x, y).

    Args:
        lcd: display to draw on.
        x, y: top-left corner of the block.
        max_speed_mbps: rate in Mbit/s that fills the up/down bars (default 100).

    Returns:
        The y coordinate just below the block (78 px further down).
    """
    net_stats = get_network_stats(handle.Hardware)
    ip_info = get_ip_info()
    bar_width = 100
    rate_x = x + 40 + bar_width + 5  # rate text sits just right of the bar

    # Replace the Korean word for "Ethernet" (Korean Windows adapter names) with
    # English, keeping any suffix such as " 2" so adapters stay distinguishable.
    # Presumably done because the display font lacks Hangul glyphs — confirm.
    adapter_name = net_stats.get('adapter_name', 'N/A').replace('이더넷', 'Ethernet')
    title_str = f"Network Stats ({adapter_name})"
    # Trailing padding erases a longer title left over from a previous frame.
    lcd.DisplayText(f"{title_str:<24}", x, y, font=FONT_PATH, font_size=17, font_color=TEXT_COLOR, background_color=BG_COLOR)
    y += 22

    for label, rate, row_height in (("Up", net_stats['up'], 18), ("Down", net_stats['down'], 20)):
        rate_str = f"{rate:.1f} Mbps"
        lcd.DisplayText(label, x, y, font=FONT_PATH, font_size=15, font_color=NET_COLOR, background_color=BG_COLOR)
        lcd.DisplayProgressBar(x + 40, y + BAR_OFFSET, bar_width, 8, 0, max_speed_mbps, rate, bar_color=NET_COLOR, bar_outline=True, background_color=BG_COLOR)
        lcd.DisplayText(f"{rate_str:<15}", rate_x, y, font=FONT_PATH, font_size=15, font_color=NET_COLOR, background_color=BG_COLOR)
        y += row_height

    # Padded like the other value strings so a shorter IP fully overwrites a longer one.
    ip_str = f"IP: {ip_info['ip']}"
    lcd.DisplayText(f"{ip_str:<20}", x, y, font=FONT_PATH, font_size=15, font_color=NET_COLOR, background_color=BG_COLOR)
    y += 18

    return y

def draw_storage_section(lcd, x, y):
    """Draw the "SSD Stats" block (usage text and bar, read/write rates) at (x, y).

    Returns the y coordinate just below the block (72 px further down).
    """
    storage_stats = get_storage_stats()
    bar_width = 180

    lcd.DisplayText("SSD Stats", x, y, font=FONT_PATH, font_size=17, font_color=TEXT_COLOR, background_color=BG_COLOR)
    y += 22

    # Used / total capacity (GB / GB). Padded like the other value strings so a
    # shorter value fully overwrites the previous frame's text.
    usage_str = f"{storage_stats['used_gb']:.0f}GB / {storage_stats['total_gb']:.0f}GB"
    lcd.DisplayText(f"{usage_str:<20}", x, y, font=FONT_PATH, font_size=15,
                    font_color=SSD_COLOR, background_color=BG_COLOR)
    y += 20
    lcd.DisplayProgressBar(x, y, bar_width, 8,
                           0, 100, storage_stats['used_pct'],
                           bar_color=SSD_COLOR, bar_outline=True, background_color=BG_COLOR)
    y += 10

    read_str = f"R: {storage_stats['read']:.0f}MB/s"
    write_str = f"W: {storage_stats['write']:.0f}MB/s"

    # Read rate in the left column at x, write rate in a fixed column at x + 120.
    # Both get trailing spaces to erase longer text from the previous frame.
    lcd.DisplayText(f"{read_str:<15}", x, y, font=FONT_PATH, font_size=15, font_color=SSD_COLOR, background_color=BG_COLOR)
    lcd.DisplayText(f"{write_str:<15}", x + 120, y, font=FONT_PATH, font_size=15, font_color=SSD_COLOR, background_color=BG_COLOR)

    return y + 20

def get_uptime_str(uptime_ms=None):
    """Return system uptime formatted as ``"Uptime: <days>d HH:MM:SS"``.

    Args:
        uptime_ms: uptime in milliseconds to format. When omitted, it is read
            from the Windows ``GetTickCount64`` API.
    """
    if uptime_ms is None:
        get_tick_count = ctypes.windll.kernel32.GetTickCount64
        # ctypes assumes a C int return type unless told otherwise, which
        # truncates the 64-bit tick count and turns uptimes beyond ~24.8 days
        # (2**31 ms) negative.
        get_tick_count.restype = ctypes.c_ulonglong
        uptime_ms = get_tick_count()
    days, rem = divmod(int(uptime_ms) // 1000, 86400)
    hours, rem = divmod(rem, 3600)
    minutes, seconds = divmod(rem, 60)
    return f"Uptime: {days}d {hours:02d}:{minutes:02d}:{seconds:02d}"

# ====== Main Drawing and Loop ======

def initialize_display():
    """Open the Rev. A LCD (``com_port="AUTO"``) and prepare it for drawing.

    Resets the panel, initialises communication, sets brightness to 50 and
    switches to landscape orientation, in that order.

    Returns:
        The configured ``LcdCommRevA`` instance.
    """
    display = LcdCommRevA(
        com_port="AUTO",
        display_width=DISPLAY_WIDTH,
        display_height=DISPLAY_HEIGHT,
    )
    display.Reset()
    display.InitializeComm()
    display.SetBrightness(50)
    display.SetOrientation(Orientation.LANDSCAPE)
    return display

def draw_static_text(lcd):
    """Draw the labels that never change: section titles and the CPU model name."""
    static_labels = (
        # (text, x, y, font size, color)
        ("CPU Stats", 8, 5, 17, TEXT_COLOR),
        (CPU_NAME, 8, 25, 15, CPU_COLOR),
        ("GPU Stats", 240, 5, 17, TEXT_COLOR),
    )
    for text, pos_x, pos_y, size, color in static_labels:
        lcd.DisplayText(text, pos_x, pos_y, font=FONT_PATH, font_size=size,
                        font_color=color, background_color=BG_COLOR)

# CPU temperature sensor names, tried in order until one reports a value.
# "Core (Tctl/Tdie)" and "Package" are the names this script always used.
# "CPU Package" is appended as a further fallback; it is believed to be
# LibreHardwareMonitor's Intel package-temperature name (verify against the
# bundled LHM version).
_CPU_TEMP_SENSOR_NAMES = ("Core (Tctl/Tdie)", "Package", "CPU Package")

def _draw_cpu_section(lcd, hw_list):
    """Draw CPU total load/temperature and the per-core load grid; return the y below the last row."""
    cpu_type = Hardware.HardwareType.Cpu
    cpu_load = get_sensor_value(hw_list, cpu_type, "Load", "CPU Total") or 0.0
    cpu_temp = 0.0
    for temp_name in _CPU_TEMP_SENSOR_NAMES:
        value = get_sensor_value(hw_list, cpu_type, "Temperature", temp_name)
        if value:
            cpu_temp = value
            break

    lcd.DisplayText(f"Total: {int(cpu_load):>3}%  Temp: {int(cpu_temp):>3}C", 8, 45, font=FONT_PATH, font_size=15, font_color=CPU_COLOR, background_color=BG_COLOR)

    y_cpu = 68
    bar_width = 50
    columns = ((8, 63), (123, 178))  # (bar x, percentage-label x) for the two cores per row
    core_loads = get_sorted_core_loads(hw_list)
    for row_start in range(0, len(core_loads), 2):
        if y_cpu > 240:  # cap the number of core rows drawn
            break
        for (bar_x, text_x), (_core, _name, load) in zip(columns, core_loads[row_start:row_start + 2]):
            lcd.DisplayProgressBar(bar_x, y_cpu, bar_width, 5, 0, 100, int(load), bar_color=CPU_COLOR, bar_outline=True, background_color=BG_COLOR)
            lcd.DisplayText(f"{int(load):>3}%", text_x, y_cpu - BAR_OFFSET - 3, font=FONT_PATH, font_size=12, font_color=CPU_COLOR, background_color=BG_COLOR)
        y_cpu += 10  # row pitch of the core grid
    return y_cpu

def _draw_ram_section(lcd, hw_list, y):
    """Draw the RAM title, used/total text and usage bar starting at ``y``."""
    mem_type = Hardware.HardwareType.Memory
    mem_used = get_sensor_value(hw_list, mem_type, "Data", "Memory Used") or 0.0
    mem_total = mem_used + (get_sensor_value(hw_list, mem_type, "Data", "Memory Available") or 0.0)
    mem_pct = (mem_used / mem_total) * 100 if mem_total > 0 else 0

    lcd.DisplayText("RAM Stats", 10, y, font=FONT_PATH, font_size=17, font_color=TEXT_COLOR, background_color=BG_COLOR)
    y += 22
    # The * 1024 converts the readings to the MB shown (presumably LHM reports GB).
    # Padded so a shorter string fully overwrites the previous frame's text.
    usage_str = f" {int(mem_used * 1024)}MB / {int(mem_total * 1024)}MB"
    lcd.DisplayText(f"{usage_str:<21}", 1, y, font=FONT_PATH, font_size=15, font_color=RAM_COLOR, background_color=BG_COLOR)
    lcd.DisplayProgressBar(10, y + 20, 180, 8, 0, 100, int(mem_pct), bar_color=RAM_COLOR, bar_outline=True, background_color=BG_COLOR)

def _draw_right_column(lcd, hw_list):
    """Draw the GPU section(s), then network and storage, down the right half (x = 240)."""
    y_right = 25
    for vendor in ("AMD", "NVIDIA"):
        gpu_stats = get_gpu_stats(hw_list, vendor)
        if gpu_stats:
            y_right = draw_gpu_section(lcd, 240, y_right, gpu_stats)
    # NOTE(review): each GPU section is 114 px tall. With both an AMD and an
    # NVIDIA GPU present, the storage section starts around y = 343, below the
    # 320 px display height. Check how LcdComm handles out-of-range coordinates.
    y_right = draw_network_section(lcd, 240, y_right + 2)
    draw_storage_section(lcd, 240, y_right + 10)

def _draw_footer(lcd):
    """Draw the uptime and the current local date/time along the bottom edge."""
    clock_str = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
    uptime_str = get_uptime_str()
    lcd.DisplayText(uptime_str, 100, 305, font=FONT_PATH, font_size=10, font_color=TEXT_COLOR, background_color=BG_COLOR)
    lcd.DisplayText(clock_str, 240, 305, font=FONT_PATH, font_size=10, font_color=TEXT_COLOR, background_color=BG_COLOR)

def draw_dynamic_stats(lcd):
    """Redraw every value that changes between frames (one full screen refresh).

    Left half: CPU, then RAM directly below the core grid. Right half: GPU,
    network, storage. Bottom: uptime and clock.
    """
    hw_list = handle.Hardware
    y_cpu = _draw_cpu_section(lcd, hw_list)
    _draw_ram_section(lcd, hw_list, y_cpu + 5)
    _draw_right_column(lcd, hw_list)
    _draw_footer(lcd)

def main():
    """Set up sensors and the LCD, draw the fixed layout, then refresh forever.

    The refresh loop only ends when an exception (including Ctrl+C) escapes.
    """
    initialize_hardware_names()
    screen = initialize_display()

    # Optional full-screen backdrop; skipped when the file is absent.
    has_background = os.path.exists(BACKGROUND_IMG)
    if has_background:
        screen.DisplayBitmap(BACKGROUND_IMG)

    draw_static_text(screen)

    while True:
        draw_dynamic_stats(screen)
        time.sleep(UPDATE_INTERVAL)

if __name__ == "__main__":
    # Release LibreHardwareMonitor's resources however main() exits
    # (it loops forever, so in practice via an exception or Ctrl+C).
    try:
        main()
    finally:
        handle.Close()