"""API routes for managing the current user's follows (bills, members, topics)."""

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.dependencies import get_current_user
from app.database import get_db
from app.models import Follow
from app.models.user import User
from app.schemas.schemas import FollowCreate, FollowModeUpdate, FollowSchema

router = APIRouter()

# Accepted values for Follow.follow_type / Follow.follow_mode; checked at the
# API boundary before any database work is done.
VALID_FOLLOW_TYPES = {"bill", "member", "topic"}
VALID_MODES = {"neutral", "pocket_veto", "pocket_boost"}


def _format_choices(choices: set[str]) -> str:
    """Return *choices* as a sorted, comma-separated string.

    Sorting keeps validation messages stable: the iteration order of a set of
    strings can differ between interpreter runs because of hash randomization.
    """
    return ", ".join(sorted(choices))


async def _get_owned_follow(db: AsyncSession, follow_id: int, user: User) -> Follow:
    """Load the follow with primary key *follow_id* and verify *user* owns it.

    Raises:
        HTTPException: 404 if no such follow exists, 403 if it belongs to a
            different user.
    """
    follow = await db.get(Follow, follow_id)
    if follow is None:
        raise HTTPException(status_code=404, detail="Follow not found")
    # NOTE(review): answering 403 rather than 404 reveals that a follow with
    # this id exists for some other user; kept for API compatibility.
    if follow.user_id != user.id:
        raise HTTPException(status_code=403, detail="Not your follow")
    return follow


@router.get("", response_model=list[FollowSchema])
async def list_follows(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return all of the current user's follows, newest first."""
    result = await db.execute(
        select(Follow)
        .where(Follow.user_id == current_user.id)
        .order_by(Follow.created_at.desc())
    )
    return result.scalars().all()


@router.post("", response_model=FollowSchema, status_code=201)
async def add_follow(
    body: FollowCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a follow of ``body.follow_type`` / ``body.follow_value`` for the current user.

    If the commit raises ``IntegrityError`` and a follow with the same
    (user, follow_type, follow_value) already exists, that existing row is
    returned instead (still with status 201). This presumably relies on a
    unique constraint over those columns defined on the model; it is not
    visible here.

    Raises:
        HTTPException: 400 if ``body.follow_type`` is not a valid follow type.
        IntegrityError: re-raised when the insert fails but no matching follow
            exists, i.e. the violation was not a duplicate.
    """
    if body.follow_type not in VALID_FOLLOW_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"follow_type must be one of {_format_choices(VALID_FOLLOW_TYPES)}",
        )

    follow = Follow(
        user_id=current_user.id,
        follow_type=body.follow_type,
        follow_value=body.follow_value,
    )
    db.add(follow)
    try:
        await db.commit()
    except IntegrityError:
        await db.rollback()
        # Already following (e.g. a repeated or concurrent request): return
        # the existing row so the endpoint behaves idempotently.
        result = await db.execute(
            select(Follow).where(
                Follow.user_id == current_user.id,
                Follow.follow_type == body.follow_type,
                Follow.follow_value == body.follow_value,
            )
        )
        existing = result.scalar_one_or_none()
        if existing is None:
            # Some other constraint failed; surface the real error instead of
            # a misleading NoResultFound.
            raise
        return existing

    await db.refresh(follow)
    return follow


@router.patch("/{follow_id}/mode", response_model=FollowSchema)
async def update_follow_mode(
    follow_id: int,
    body: FollowModeUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Set ``follow_mode`` on one of the current user's follows.

    Raises:
        HTTPException: 400 for an unknown mode, 404 if the follow does not
            exist, 403 if it belongs to another user.
    """
    if body.follow_mode not in VALID_MODES:
        raise HTTPException(
            status_code=400,
            detail=f"follow_mode must be one of {_format_choices(VALID_MODES)}",
        )

    follow = await _get_owned_follow(db, follow_id, current_user)
    follow.follow_mode = body.follow_mode
    await db.commit()
    await db.refresh(follow)
    return follow


@router.delete("/{follow_id}", status_code=204)
async def remove_follow(
    follow_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete one of the current user's follows (204, no response body).

    Raises:
        HTTPException: 404 if the follow does not exist, 403 if it belongs to
            another user.
    """
    follow = await _get_owned_follow(db, follow_id, current_user)
    await db.delete(follow)
    await db.commit()