import logging

from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from fastapi import Depends, HTTPException, status
from jose import jwt, JWTError
from sqlalchemy.orm import selectinload

from src.config import settings
from src.user.models import User
from src.database import get_db

logger = logging.getLogger(__name__)

# Shared bearer-token extractor. HTTPBearer() keeps its default auto_error=True,
# so FastAPI rejects a request lacking an "Authorization: Bearer ..." header
# before any dependency that uses this scheme is executed.
bearer_scheme = HTTPBearer()


async def get_current_user(token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
                           session: AsyncSession = Depends(get_db)):
    """Resolve the authenticated ``User`` from a bearer JWT.

    The token is verified with ``settings.JWT_SECRET_KEY`` /
    ``settings.JWT_ALGORITHM``; its ``user_id`` claim selects the ``User``
    row, which is loaded with its ``role`` relationship eagerly fetched
    (``selectinload``) so ``user.role`` can be read afterwards, e.g. by
    ``require_role``, without a further lazy load.

    Args:
        token: Bearer credentials extracted by ``bearer_scheme``.
        session: Async database session provided by ``get_db``.

    Returns:
        The ``User`` whose ``user_id`` matches the token's ``user_id`` claim.

    Raises:
        HTTPException: 401 ("Invalid token") if the JWT fails to decode or
            verify, ("Invalid token payload") if it carries no ``user_id``,
            or ("User not found") if no matching user exists.
    """
    # RFC 7235 §3.1: every 401 response must carry a WWW-Authenticate challenge.
    challenge = {"WWW-Authenticate": "Bearer"}

    # Keep the try narrow: only jwt.decode can raise JWTError.
    try:
        payload = jwt.decode(token.credentials, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except JWTError as exc:
        logger.debug("Rejected bearer token: %s", exc)
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token",
                            headers=challenge) from exc

    user_id = payload.get("user_id")
    if not user_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload",
                            headers=challenge)

    result = await session.execute(
        select(User).options(selectinload(User.role)).where(User.user_id == user_id)
    )
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found",
                            headers=challenge)
    return user

async def decode_token(token: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
    """Verify a bearer JWT and return its decoded claims.

    Args:
        token: Bearer credentials extracted by ``bearer_scheme``.

    Returns:
        The claims dict produced by ``jwt.decode``.

    Raises:
        HTTPException: 401 ("Token Required") if no credentials were given,
            or ("Invalid token") if the JWT fails to decode or verify.
    """
    # RFC 7235 §3.1: every 401 response must carry a WWW-Authenticate challenge.
    challenge = {"WWW-Authenticate": "Bearer"}

    # bearer_scheme is HTTPBearer() with auto_error=True, so FastAPI normally
    # rejects a missing header before we get here; this guard still covers
    # direct calls and dependency overrides that pass None.
    if not token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token Required",
                            headers=challenge)

    try:
        return jwt.decode(token.credentials, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except JWTError as exc:
        logger.debug("Rejected bearer token: %s", exc)
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token",
                            headers=challenge) from exc

async def get_current_user_optional(token: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
                                    session: AsyncSession = Depends(get_db)):
    """Resolve the current ``User`` if a valid bearer JWT is present, else ``None``.

    Unlike ``get_current_user`` this never rejects the request: a missing
    header, an undecodable/unverifiable token, a token without a ``user_id``
    claim, or a ``user_id`` with no matching row all yield ``None``.

    Args:
        token: Bearer credentials, or ``None`` when the request has no usable
            ``Authorization: Bearer`` header. ``auto_error=False`` is what
            makes the header optional; the shared ``bearer_scheme`` would
            reject such requests before this function ran.
        session: Async database session provided by ``get_db``.

    Returns:
        The matching ``User`` (with ``role`` eagerly loaded) or ``None``.
    """
    if not token:
        return None

    try:
        payload = jwt.decode(token.credentials, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except JWTError as exc:
        logger.debug("Ignoring invalid optional bearer token: %s", exc)
        return None

    user_id = payload.get("user_id")
    # Skip the query entirely when the claim is absent rather than filtering on NULL.
    if not user_id:
        return None

    result = await session.execute(
        select(User).options(selectinload(User.role)).where(User.user_id == user_id)
    )
    return result.scalar_one_or_none()


async def decode_token_optional(token: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False))):
    """Decode a bearer JWT if one was supplied; otherwise return ``None``.

    Args:
        token: Bearer credentials, or ``None`` when the request has no usable
            ``Authorization: Bearer`` header (``auto_error=False``).

    Returns:
        The claims dict from ``jwt.decode``, or ``None`` when no token was sent
        or the token failed to decode/verify.
    """
    if not token:
        return None
    try:
        return jwt.decode(token.credentials, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except JWTError as exc:
        # An invalid token is treated like an anonymous request; report it via
        # the module logger instead of printing to stdout.
        logger.warning("Failed to decode optional bearer token: %s", exc)
        return None

def require_role(*allowed_roles: str):
    """Build a FastAPI dependency that only admits users with one of ``allowed_roles``.

    Args:
        *allowed_roles: Role names (compared against ``user.role.name``) that
            are permitted through.

    Returns:
        An async dependency that resolves the user via ``get_current_user``
        and returns it when its role matches. A user without a role never matches.
        Otherwise the dependency raises ``HTTPException`` 403 ("Permission denied").
    """
    async def _role_guard(user: User = Depends(get_current_user)):
        current_role = None
        if user.role:
            current_role = user.role.name

        if current_role in allowed_roles:
            return user
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied")

    return _role_guard
