from typing import Type, Optional

from fastapi import Depends, HTTPException, APIRouter
from sqlalchemy.orm import Session

from ..database import session_provider
from ..models.base import Base
from ..models.hotspot_analysis.hotspot_analysis import HotspotAnalysisCreate, HotspotAnalysisInDB, HotspotAnalysis, \
    HotspotAnalysisBaseCreate, HotspotAnalysisCustomField, HotspotAnalysisCustomChartRow
from ..models.hotspot_analysis.region import RegionCreate, Region, RegionInDB, RegionSubmitThreads, \
    RegionPresentThreads, RegionBaseCreate, RegionCustomFields
from ..models.hotspot_analysis.region_thread import RegionThreadCreate, RegionThreadInDB, RegionThreadBaseCreate, \
    RegionThread, RegionThreadSearchInDB, RegionThreadSearchQuery
from ..models.hotspot_analysis.thread_hotspots import ThreadHotspotCollectionBaseCreate, ThreadHotspotRecord, \
    ThreadHotspotCollection
from ..models.models import Report, FrameTimes, RegionAppDxgKernelProfileRanges, RegionEtwEvents, RegionIssues, \
    RegionMetrics, ThreadCallstack

# Router holding every hotspot-analysis endpoint in this module. Presumably it is
# mounted on the application elsewhere (e.g. via ``app.include_router``) — confirm.
router = APIRouter()


@router.get("/hotspot_analysis/{analysis_id}", tags=['Hotspot Analysis'], response_model=HotspotAnalysisInDB)
def get_hotspot_analysis(analysis_id: int, db: Session = Depends(session_provider.get_session)):
    """Return the hotspot analysis whose primary key is ``analysis_id``.

    Raises:
        HTTPException: 404 when no such analysis exists.
    """
    found = (
        db.query(HotspotAnalysis)
        .filter(HotspotAnalysis.id == analysis_id)
        .first()
    )
    if not found:
        raise HTTPException(status_code=404, detail=f"Analysis with id={analysis_id} not found")
    return HotspotAnalysisInDB.model_validate(found)


@router.post("/hotspot_analysis/add", tags=['Hotspot Analysis'], response_model=HotspotAnalysisInDB)
def create_hotspot_analysis(hotspot_analysis: HotspotAnalysisCreate, db: Session = Depends(session_provider.get_session)):
    """Create a hotspot analysis together with its child rows in one transaction.

    The parent row is flushed first so its generated ``id`` can be used as the
    foreign key for the frame-time, custom-field and custom-chart-row children.

    Raises:
        HTTPException: 404 if the referenced report does not exist; 500 if any
            insert or the final commit fails (the session is rolled back).
    """
    db_report: Optional[Report] = db.query(Report).filter(Report.id == hotspot_analysis.report_id).first()
    if not db_report:
        raise HTTPException(status_code=404, detail=f"Report with id={hotspot_analysis.report_id} not found")

    # Only the base (scalar) fields map onto HotspotAnalysis columns; the nested
    # collections are inserted as separate rows below.
    db_hotspot_analysis = HotspotAnalysis(**hotspot_analysis.model_dump(include=set(HotspotAnalysisBaseCreate.model_fields.keys())))

    try:
        db.add(db_hotspot_analysis)
        db.flush()  # populates db_hotspot_analysis.id for the child foreign keys

        for data in hotspot_analysis.frame_times or []:
            db.add(FrameTimes(hotspot_analysis_id=db_hotspot_analysis.id, **data.model_dump()))

        for data in hotspot_analysis.custom_fields or []:
            db.add(HotspotAnalysisCustomField(hotspot_analysis_id=db_hotspot_analysis.id, **data.model_dump()))

        for data in hotspot_analysis.custom_chart_rows or []:
            db.add(HotspotAnalysisCustomChartRow(hotspot_analysis_id=db_hotspot_analysis.id, **data.model_dump()))

        # Commit inside the try: the child rows above are still pending and are
        # only flushed by commit(), so their constraint errors surface here and
        # must also trigger the rollback.
        db.commit()
    except Exception as e:
        db.rollback()
        # NOTE(review): str(e) exposes raw database error text to API clients.
        raise HTTPException(status_code=500, detail=f"Failed to add analysis data: {str(e)}") from e

    db.refresh(db_hotspot_analysis)
    return HotspotAnalysisInDB.model_validate(db_hotspot_analysis)


@router.get("/hotspot_analysis/region/{region_id}", tags=['Hotspot Analysis'], response_model=RegionInDB)
def get_region(region_id: int, db: Session = Depends(session_provider.get_session)):
    """Return the region whose primary key is ``region_id``.

    Raises:
        HTTPException: 404 when no such region exists.
    """
    match = (
        db.query(Region)
        .filter(Region.id == region_id)
        .first()
    )
    if not match:
        raise HTTPException(status_code=404, detail=f"Region with id={region_id} not found")
    return RegionInDB.model_validate(match)


@router.post("/hotspot_analysis/region/add", tags=['Hotspot Analysis'], response_model=RegionInDB)
def create_region(region: RegionCreate, db: Session = Depends(session_provider.get_session)):
    """Create a region and all of its related rows in one transaction.

    The region is flushed first so its generated ``id`` can be used as
    ``region_id`` on every related row. Any client-supplied ``id`` on the
    related items is dropped before insertion.

    Raises:
        HTTPException: 404 if the parent analysis does not exist; 500 if any
            insert or the final commit fails (the session is rolled back).
    """
    db_analysis: Optional[HotspotAnalysis] = db.query(HotspotAnalysis).filter(HotspotAnalysis.id == region.hotspot_analysis_id).first()
    if not db_analysis:
        raise HTTPException(status_code=404, detail=f"Analysis with id={region.hotspot_analysis_id} not found")

    # (attribute on the request model, ORM model for its rows), in insertion order.
    related_tables = (
        ('submit_threads', RegionSubmitThreads),
        ('present_threads', RegionPresentThreads),
        ('region_app_dxg_kernel_profile_ranges', RegionAppDxgKernelProfileRanges),
        ('region_etw_events', RegionEtwEvents),
        ('region_issues', RegionIssues),
        ('custom_fields', RegionCustomFields),
        ('region_metrics', RegionMetrics),
    )

    db_region = Region(**region.model_dump(include=set(RegionBaseCreate.model_fields.keys())))
    try:
        db.add(db_region)
        db.flush()  # populates db_region.id for the related rows

        for field_name, orm_model in related_tables:
            for data in getattr(region, field_name) or []:
                db.add(orm_model(region_id=db_region.id, **data.model_dump(exclude={'id'})))

        # Commit inside the try: related rows are only flushed by commit(), so
        # their constraint errors must be caught and rolled back here as well.
        db.commit()
    except Exception as e:
        db.rollback()
        # NOTE(review): str(e) exposes raw database error text to API clients.
        raise HTTPException(status_code=500, detail=f"Failed to add region data: {str(e)}") from e

    db.refresh(db_region)
    return RegionInDB.model_validate(db_region)


def update_region_related_table(db_region: Region, region_update: RegionCreate.model_as_partial(), related_table_name: str, model: Type[Base], db: Session):
    """Stage one new ``model`` row per item in ``region_update.<related_table_name>``.

    Rows are only added — existing related rows are neither updated nor deleted.
    Each item's ``id`` is excluded from the dump, and ``region_id`` is set to
    ``db_region.id``. A ``None``/empty collection adds nothing. The rows are only
    added to the session; flushing and committing are left to the caller.
    """
    pending_items = getattr(region_update, related_table_name) or []
    for item in pending_items:
        db.add(model(region_id=db_region.id, **item.model_dump(exclude={'id'})))


@router.patch("/hotspot_analysis/region/{region_id}", tags=['Hotspot Analysis'], response_model=RegionInDB)
def patch_region(region_id: int, region_update: RegionCreate.model_as_partial(), db: Session = Depends(session_provider.get_session)):
    """Partially update a region in one transaction.

    Scalar base fields whose value in ``region_update`` is not ``None`` overwrite
    the stored values. A field therefore cannot be cleared to ``None`` through
    this endpoint, and ``hotspot_analysis_id`` is never changed. Items supplied
    for the related collections are appended as new rows via
    ``update_region_related_table``; existing related rows are left untouched.

    Raises:
        HTTPException: 404 if the region does not exist; 500 if any change or the
            final commit fails (the session is rolled back).
    """
    db_region: Optional[Region] = db.query(Region).filter(Region.id == region_id).first()
    if not db_region:
        raise HTTPException(status_code=404, detail=f"Region with id={region_id} not found")

    # (attribute on the request model, ORM model for its rows), in update order.
    related_tables = (
        ('submit_threads', RegionSubmitThreads),
        ('present_threads', RegionPresentThreads),
        ('region_app_dxg_kernel_profile_ranges', RegionAppDxgKernelProfileRanges),
        ('region_etw_events', RegionEtwEvents),
        ('region_issues', RegionIssues),
        ('custom_fields', RegionCustomFields),
        ('region_metrics', RegionMetrics),
    )

    try:
        scalar_updates = region_update.model_dump(include=set(RegionBaseCreate.model_fields.keys()), exclude={'hotspot_analysis_id'})
        for field, value in scalar_updates.items():
            if hasattr(db_region, field) and value is not None:
                setattr(db_region, field, value)

        for field_name, orm_model in related_tables:
            update_region_related_table(db_region, region_update, field_name, orm_model, db)

        # Commit inside the try: the newly added related rows are only flushed
        # by commit(), so their constraint errors must be rolled back here too.
        db.commit()
    except Exception as e:
        db.rollback()
        # NOTE(review): str(e) exposes raw database error text to API clients.
        raise HTTPException(status_code=500, detail=f"Failed to update region data: {str(e)}") from e

    db.refresh(db_region)
    return RegionInDB.model_validate(db_region)


@router.get("/hotspot_analysis/thread/{thread_id}", tags=['Hotspot Analysis'], response_model=RegionThreadInDB)
def get_region_thread(thread_id: int, db: Session = Depends(session_provider.get_session)):
    """Return the region thread whose primary key is ``thread_id``.

    Raises:
        HTTPException: 404 when no such thread exists.
    """
    # Annotation fixed: the query yields a RegionThread (or None), not a Region.
    thread_db: Optional[RegionThread] = db.query(RegionThread).filter(RegionThread.id == thread_id).first()
    if not thread_db:
        raise HTTPException(status_code=404, detail=f"Thread with id={thread_id} not found")
    return RegionThreadInDB.model_validate(thread_db)


@router.post("/hotspot_analysis/thread/search", tags=['Hotspot Analysis'], response_model=RegionThreadSearchInDB)
def search_region_thread(query: RegionThreadSearchQuery, db: Session = Depends(session_provider.get_session)):
    """Search region threads with optional filters and offset/limit pagination.

    Every filter in ``query`` that is not ``None`` narrows the result, and the
    filters are combined with AND. Filtering by ``hotspot_analysis_id`` joins
    through ``Region``.

    Returns:
        RegionThreadSearchInDB holding the requested page of threads (ordered by
        id) and the total number of matches before pagination.
    """
    thread_query = db.query(RegionThread)

    if query.region_id is not None:
        thread_query = thread_query.filter(RegionThread.region_id == query.region_id)
    if query.thread_id is not None:
        thread_query = thread_query.filter(RegionThread.thread_id == query.thread_id)
    if query.hotspot_analysis_id is not None:
        thread_query = thread_query.join(Region).filter(Region.hotspot_analysis_id == query.hotspot_analysis_id)

    # Total number of matches, computed before pagination.
    total_count = thread_query.count()

    # OFFSET/LIMIT needs a deterministic ORDER BY. Without one the database may
    # return rows in any order, so consecutive pages can overlap or skip rows.
    region_threads = (
        thread_query.order_by(RegionThread.id)
        .offset(query.offset)
        .limit(query.limit)
        .all()
    )
    return RegionThreadSearchInDB(
        results=[RegionThreadInDB.model_validate(region_thread) for region_thread in region_threads],
        count=total_count
    )


@router.post("/hotspot_analysis/thread/add", tags=['Hotspot Analysis'], response_model=RegionThreadInDB)
def create_region_thread(region_thread: RegionThreadCreate, db: Session = Depends(session_provider.get_session)):
    """Create a region thread with its callstack and hotspot collections in one transaction.

    The thread is flushed first so its ``id`` can be used for the callstack rows
    and hotspot collections. Each collection is flushed in turn so its ``id`` can
    be used for that collection's hotspot records.

    Raises:
        HTTPException: 404 if the parent region does not exist; 500 if any
            insert or the final commit fails (the session is rolled back).
    """
    db_region: Optional[Region] = db.query(Region).filter(Region.id == region_thread.region_id).first()
    if not db_region:
        raise HTTPException(status_code=404, detail=f"Region with id={region_thread.region_id} not found")

    db_region_thread = RegionThread(**region_thread.model_dump(include=set(RegionThreadBaseCreate.model_fields.keys())))
    # Loop-invariant: the scalar fields that map onto ThreadHotspotCollection columns.
    collection_fields = set(ThreadHotspotCollectionBaseCreate.model_fields.keys())
    try:
        db.add(db_region_thread)
        db.flush()  # populates db_region_thread.id for the child foreign keys

        # NOTE: Session.bulk_insert_mappings could speed up very large callstacks;
        # per-object add() is kept for simplicity and ORM consistency.
        for data in region_thread.thread_callstack or []:
            db.add(ThreadCallstack(region_thread_id=db_region_thread.id, **data.model_dump()))

        for collection in region_thread.thread_hotspot_collections or []:
            db_collection = ThreadHotspotCollection(region_thread_id=db_region_thread.id, **collection.model_dump(include=collection_fields))
            db.add(db_collection)
            db.flush()  # populates db_collection.id for its records
            for row in collection.thread_hotspot_records or []:
                db.add(ThreadHotspotRecord(thread_hotspot_collection_id=db_collection.id, **row.model_dump()))

        # Commit inside the try: rows still pending after the last flush are
        # written by commit(), so their errors must be rolled back here too.
        db.commit()
    except Exception as e:
        db.rollback()
        # NOTE(review): str(e) exposes raw database error text to API clients.
        raise HTTPException(status_code=500, detail=f"Failed to add region thread data: {str(e)}") from e

    db.refresh(db_region_thread)
    return RegionThreadInDB.model_validate(db_region_thread)
