50 lines
1.7 KiB
Python
50 lines
1.7 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.database import get_db
|
|
from app.models import Follow
|
|
from app.schemas.schemas import FollowCreate, FollowSchema
|
|
|
|
# Router holding the follow endpoints below; they register the empty path,
# so the URL prefix presumably comes from wherever this router is included.
router: APIRouter = APIRouter()

# Accepted values for a follow's ``follow_type``; enforced by ``add_follow``.
VALID_FOLLOW_TYPES: set[str] = {"bill", "member", "topic"}
|
|
|
|
|
|
@router.get("", response_model=list[FollowSchema])
async def list_follows(db: AsyncSession = Depends(get_db)):
    """List all follows, most recently created first.

    Args:
        db: Async database session injected by FastAPI via ``get_db``.

    Returns:
        Every ``Follow`` row ordered by ``created_at`` descending; FastAPI
        serializes each one through ``FollowSchema``.
    """
    # NOTE(review): unpaginated — returns the whole table.
    newest_first = select(Follow).order_by(Follow.created_at.desc())
    rows = await db.execute(newest_first)
    return rows.scalars().all()
|
|
|
|
|
|
@router.post("", response_model=FollowSchema, status_code=201)
async def add_follow(body: FollowCreate, db: AsyncSession = Depends(get_db)):
    """Create a follow, or return the matching one if it already exists.

    Args:
        body: Request payload carrying ``follow_type`` and ``follow_value``.
        db: Async database session injected by FastAPI via ``get_db``.

    Returns:
        The newly created ``Follow``; or, when the insert fails with an
        ``IntegrityError`` and a follow with the same type/value is found,
        that existing row. Both cases respond with status 201.

    Raises:
        HTTPException: 400 if ``follow_type`` is not in ``VALID_FOLLOW_TYPES``.
        IntegrityError: if the insert fails and no follow with the same
            type/value exists, i.e. the violation was not a duplicate.
    """
    if body.follow_type not in VALID_FOLLOW_TYPES:
        # Sort so the message is stable: iteration order of a set of strings
        # varies between interpreter runs under hash randomization.
        allowed = ", ".join(sorted(VALID_FOLLOW_TYPES))
        raise HTTPException(
            status_code=400, detail=f"follow_type must be one of {allowed}"
        )

    follow = Follow(follow_type=body.follow_type, follow_value=body.follow_value)
    db.add(follow)
    try:
        await db.commit()
    except IntegrityError:
        # The session must be rolled back before it can be used again.
        await db.rollback()
        # Already following — look up and return the existing row.
        # NOTE(review): assumes a unique constraint on
        # (follow_type, follow_value) is what normally raises here; confirm
        # against the Follow model.
        result = await db.execute(
            select(Follow).where(
                Follow.follow_type == body.follow_type,
                Follow.follow_value == body.follow_value,
            )
        )
        existing = result.scalar_one_or_none()
        if existing is None:
            # Not a duplicate: re-raise the real constraint violation rather
            # than masking it with a NoResultFound from scalar_one().
            raise
        return existing

    await db.refresh(follow)
    return follow
|
|
|
|
|
|
@router.delete("/{follow_id}", status_code=204)
async def remove_follow(follow_id: int, db: AsyncSession = Depends(get_db)):
    """Delete the follow identified by *follow_id* and commit.

    Args:
        follow_id: Primary-key value passed to ``db.get(Follow, ...)``.
        db: Async database session injected by FastAPI via ``get_db``.

    Returns:
        ``None``; the route responds 204 with no body.

    Raises:
        HTTPException: 404 when ``db.get`` finds no matching follow.
    """
    target = await db.get(Follow, follow_id)
    if target:
        await db.delete(target)
        await db.commit()
        return None
    raise HTTPException(status_code=404, detail="Follow not found")
|