from contextvars import ContextVar
from fastapi import Depends, Request
from sqlalchemy.orm import Session
from logger import logger

# models
from model.db import get_db
from model.user import User

# schema
from schema.user import UserTokenData, UserModel, UserStatus

# utils
from utils.exceptions import AuthException

# Request-scoped state, held in context variables.
# Each variable's `default` is the value `.get()` returns until `.set()` has
# been called in the current context.

# Database session for the current request; `.get()` yields None until
# build_request_context() stores the injected session.
context_db_session: ContextVar[Session] = ContextVar("db_session", default=None)
# Token data of the calling user; `.get()` yields None when nothing has set it.
# It is never set in the visible part of this module — presumably populated by
# authentication code elsewhere; TODO confirm.
context_user_data: ContextVar[UserTokenData] = ContextVar("user_data", default=None)

# Boolean flag, False by default. It is neither read nor written in the visible
# part of this module.
# NOTE(review): the name suggests it marks the request's session for rollback —
# confirm against the code that consumes it.
context_set_db_session_rollback: ContextVar[bool] = ContextVar(
    "set_db_session_rollback", default=False
)


# request-level dependency: fills the context variables for the current call
async def build_request_context(db: Session = Depends(get_db)):
    """Store the request's DB session in context and validate the calling user.

    The session supplied through ``Depends(get_db)`` is written to
    ``context_db_session``. If ``context_user_data`` holds token data, the
    matching user is looked up by uuid and must exist and be active. Finally a
    ``REQUEST_INITIATED`` log entry is emitted.

    Args:
        db: Database session injected by FastAPI from ``get_db``.

    Raises:
        AuthException: With status code 401 when the user referenced by the
            token data is not found, or its status is not
            ``UserStatus.ACTIVE``.
    """
    # The session is stored before the user lookup.
    # NOTE(review): presumably User.get_by_uuid reads the session back through
    # get_db_session(), which would make this ordering required — confirm.
    context_db_session.set(db)

    token_data: UserTokenData = context_user_data.get()

    # Requests without token data skip the user checks entirely.
    if token_data:
        account: UserModel = User.get_by_uuid(token_data.uuid)

        if not account:
            raise AuthException(
                status_code=401,
                message="Invalid authentication credentials, user not found",
            )

        if account.status != UserStatus.ACTIVE.value:
            raise AuthException(
                status_code=401,
                message="Invalid authentication credentials, user is not active",
            )

    # NOTE(review): `extra` receives the raw context value (token data or None).
    # The standard logging API expects a dict here — confirm the project logger
    # accepts this object.
    logger.info(msg="REQUEST_INITIATED", extra=context_user_data.get())


# single access point for the request's db session
def get_db_session() -> Session:
    """Return the session currently held in ``context_db_session``.

    The value is whatever build_request_context() stored for the current
    context; if nothing has been stored, the variable's default (None) is
    returned.
    """
    current_session = context_db_session.get()
    return current_session
