from collections.abc import Iterable
from datetime import date, timedelta

from fastapi import APIRouter, Depends
from sqlalchemy import desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.core.dependencies import get_current_user
from app.database import get_db
from app.models import Bill, BillBrief, Follow, TrendScore
from app.models.user import User
from app.schemas.schemas import BillSchema

router = APIRouter()


def _bill_select():
    """Return ``select(Bill)`` with the relations ``_serialize_bill`` reads eagerly loaded."""
    # selectinload issues separate SELECT ... IN queries, so these options never
    # multiply the Bill rows (and so never interfere with LIMIT).
    return select(Bill).options(
        selectinload(Bill.sponsor),
        selectinload(Bill.briefs),
        selectinload(Bill.trend_scores),
    )


def _append_unseen(feed: list[Bill], seen_ids: set[str], bills: Iterable[Bill]) -> None:
    """Append each bill whose ``bill_id`` is not yet in *seen_ids*, recording it there.

    Preserves the order of *bills*; mutates both *feed* and *seen_ids* in place.
    """
    for bill in bills:
        if bill.bill_id not in seen_ids:
            feed.append(bill)
            seen_ids.add(bill.bill_id)


def _serialize_bill(bill: Bill) -> dict:
    """Convert a loaded ``Bill`` into a plain dict via ``BillSchema``.

    Attaches the first brief and first trend score, if any.
    """
    b = BillSchema.model_validate(bill)
    # NOTE(review): taking index 0 as "latest" assumes the briefs/trend_scores
    # relationships are ordered newest-first in the model — TODO confirm.
    if bill.briefs:
        b.latest_brief = bill.briefs[0]
    if bill.trend_scores:
        b.latest_trend = bill.trend_scores[0]
    return b.model_dump()


@router.get("")
async def get_dashboard(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build the current user's dashboard.

    Returns a dict with:
      * ``"feed"``: up to 50 bills drawn from the user's follows, deduplicated
        by ``bill_id`` and sorted by ``latest_action_date`` (newest first,
        missing dates last). Sources:
          - directly followed bills (up to 20);
          - bills sponsored by followed members (up to 20);
          - up to 10 bills per followed topic.
      * ``"trending"``: up to 10 distinct bills with the highest composite
        trend score recorded since yesterday.
      * ``"follows"``: how many bill, member and topic follows the user has.
    """
    # Load follows for the current user
    follows_result = await db.execute(
        select(Follow).where(Follow.user_id == current_user.id)
    )
    follows = follows_result.scalars().all()

    followed_bill_ids = [f.follow_value for f in follows if f.follow_type == "bill"]
    followed_member_ids = [f.follow_value for f in follows if f.follow_type == "member"]
    followed_topics = [f.follow_value for f in follows if f.follow_type == "topic"]

    feed_bills: list[Bill] = []
    seen_ids: set[str] = set()

    # 1. Directly followed bills
    if followed_bill_ids:
        result = await db.execute(
            _bill_select()
            .where(Bill.bill_id.in_(followed_bill_ids))
            .order_by(desc(Bill.latest_action_date))
            .limit(20)
        )
        _append_unseen(feed_bills, seen_ids, result.scalars().all())

    # 2. Bills from followed members
    if followed_member_ids:
        result = await db.execute(
            _bill_select()
            .where(Bill.sponsor_id.in_(followed_member_ids))
            .order_by(desc(Bill.latest_action_date))
            .limit(20)
        )
        _append_unseen(feed_bills, seen_ids, result.scalars().all())

    # 3. Bills matching followed topics
    for topic in followed_topics:
        # Filter through an IN-subquery rather than JOIN BillBrief: a join
        # yields one row per matching brief, so a bill with several matching
        # briefs would consume several of the LIMIT 10 slots and the topic
        # would contribute fewer distinct bills than intended.
        matching_bill_ids = select(BillBrief.bill_id).where(
            BillBrief.topic_tags.contains([topic])
        )
        result = await db.execute(
            _bill_select()
            .where(Bill.bill_id.in_(matching_bill_ids))
            .order_by(desc(Bill.latest_action_date))
            .limit(10)
        )
        _append_unseen(feed_bills, seen_ids, result.scalars().all())

    # Sort feed by latest action date; bills without a date sink to the end.
    feed_bills.sort(key=lambda b: b.latest_action_date or date.min, reverse=True)

    # 4. Trending bills: top 10 by composite score over the window since
    #    yesterday. The window spans two days, so a bill can have two
    #    TrendScore rows. Collapse to one row per bill (its best score) *before*
    #    ORDER BY/LIMIT so duplicates cannot eat into the ten slots.
    cutoff = date.today() - timedelta(days=1)
    best_scores = (
        select(
            TrendScore.bill_id,
            func.max(TrendScore.composite_score).label("best_score"),
        )
        .where(TrendScore.score_date >= cutoff)
        .group_by(TrendScore.bill_id)
        .subquery()
    )
    trending_result = await db.execute(
        _bill_select()
        .join(best_scores, Bill.bill_id == best_scores.c.bill_id)
        .order_by(desc(best_scores.c.best_score))
        .limit(10)
    )
    trending_bills = trending_result.scalars().all()

    return {
        "feed": [_serialize_bill(b) for b in feed_bills[:50]],
        "trending": [_serialize_bill(b) for b in trending_bills],
        "follows": {
            "bills": len(followed_bill_ids),
            "members": len(followed_member_ids),
            "topics": len(followed_topics),
        },
    }