146 lines
4.5 KiB
Python
146 lines
4.5 KiB
Python
from typing import Optional
|
|
|
|
from fastapi import APIRouter, Depends, Query
|
|
from sqlalchemy import desc, func, or_, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.database import get_db
|
|
from app.models import Bill, BillAction, BillBrief, NewsArticle, TrendScore
|
|
from app.schemas.schemas import (
|
|
BillDetailSchema,
|
|
BillSchema,
|
|
BillActionSchema,
|
|
NewsArticleSchema,
|
|
PaginatedResponse,
|
|
TrendScoreSchema,
|
|
)
|
|
|
|
# Router that the bill endpoints in this module register on via @router.get(...).
# Route paths declared here are relative ("" and "/{bill_id}..."), so the URL
# prefix is presumably supplied where this router is included — TODO confirm.
router = APIRouter()
|
|
|
|
|
|
@router.get("", response_model=PaginatedResponse[BillSchema])
async def list_bills(
    chamber: Optional[str] = Query(None),
    topic: Optional[str] = Query(None),
    sponsor_id: Optional[str] = Query(None),
    q: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
    sort: str = Query("latest_action_date"),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of bills, optionally filtered, sorted descending.

    Args:
        chamber: Keep only bills whose ``chamber`` equals this value.
        topic: Keep only bills that have at least one brief whose
            ``topic_tags`` contains this value.
        sponsor_id: Keep only bills whose ``sponsor_id`` equals this value.
        q: Case-insensitive substring searched for in ``bill_id``, ``title``
            and ``short_title``. ``%`` and ``_`` are matched literally.
        page: 1-based page number.
        per_page: Page size (1-100).
        sort: Name of a mapped ``Bill`` column to order by (descending).
            Anything that is not a mapped column falls back to
            ``latest_action_date``.
        db: Async database session.

    Returns:
        A ``PaginatedResponse`` with the page of ``BillSchema`` items, the
        total number of matching bills, and the page count (at least 1).
    """
    query = select(Bill).options(
        selectinload(Bill.sponsor),
        selectinload(Bill.briefs),
        selectinload(Bill.trend_scores),
    )

    if chamber:
        query = query.where(Bill.chamber == chamber)
    if sponsor_id:
        query = query.where(Bill.sponsor_id == sponsor_id)
    if topic:
        # Filter through a subquery rather than a JOIN: a join yields one row
        # per matching brief, which inflates `total` and lets duplicate bills
        # consume LIMIT slots before `.unique()` drops them.
        tagged_bill_ids = select(BillBrief.bill_id).where(
            BillBrief.topic_tags.contains([topic])
        )
        query = query.where(Bill.bill_id.in_(tagged_bill_ids))
    if q:
        # Escape LIKE metacharacters so user input is a literal substring.
        escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        pattern = f"%{escaped}%"
        query = query.where(
            or_(
                Bill.bill_id.ilike(pattern, escape="\\"),
                Bill.title.ilike(pattern, escape="\\"),
                Bill.short_title.ilike(pattern, escape="\\"),
            )
        )

    # Count total matches before ordering/pagination is applied.
    count_query = select(func.count()).select_from(query.subquery())
    total = await db.scalar(count_query) or 0

    # Only mapped columns are sortable; `sort` comes straight from the query
    # string, and relationships/methods/dunder attributes would otherwise be
    # handed to ORDER BY and fail at execution time.
    if sort in Bill.__mapper__.column_attrs:
        sort_col = getattr(Bill, sort)
    else:
        sort_col = Bill.latest_action_date

    # `bill_id` is a tie-breaker so OFFSET/LIMIT pages are deterministic when
    # many bills share the same sort value.
    query = (
        query.order_by(desc(sort_col), Bill.bill_id)
        .offset((page - 1) * per_page)
        .limit(per_page)
    )

    result = await db.execute(query)
    bills = result.scalars().unique().all()

    # Attach the first brief / trend score to each item.
    # NOTE(review): treating index 0 as "latest" presumably relies on the
    # relationships' ordering — confirm against the model definitions.
    items = []
    for bill in bills:
        item = BillSchema.model_validate(bill)
        if bill.briefs:
            item.latest_brief = bill.briefs[0]
        if bill.trend_scores:
            item.latest_trend = bill.trend_scores[0]
        items.append(item)

    return PaginatedResponse(
        items=items,
        total=total,
        page=page,
        per_page=per_page,
        pages=max(1, (total + per_page - 1) // per_page),
    )
|
|
|
|
|
|
@router.get("/{bill_id}", response_model=BillDetailSchema)
async def get_bill(bill_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single bill with its related records.

    Loads the bill whose ``bill_id`` matches the path parameter together with
    its sponsor, actions, briefs, news articles and trend scores.

    Raises:
        HTTPException: 404 with detail "Bill not found" when no bill matches.

    Returns:
        A ``BillDetailSchema`` built from the bill, with ``latest_brief`` and
        ``latest_trend`` set to the first brief / trend score when present.
    """
    eager_loads = (
        selectinload(Bill.sponsor),
        selectinload(Bill.actions),
        selectinload(Bill.briefs),
        selectinload(Bill.news_articles),
        selectinload(Bill.trend_scores),
    )
    stmt = select(Bill).options(*eager_loads).where(Bill.bill_id == bill_id)

    rows = await db.execute(stmt)
    bill = rows.scalar_one_or_none()
    if not bill:
        from fastapi import HTTPException

        raise HTTPException(status_code=404, detail="Bill not found")

    detail = BillDetailSchema.model_validate(bill)

    # Index 0 is used as the "latest" entry of each collection.
    briefs = bill.briefs
    if briefs:
        detail.latest_brief = briefs[0]
    trend_scores = bill.trend_scores
    if trend_scores:
        detail.latest_trend = trend_scores[0]

    return detail
|
|
|
|
|
|
@router.get("/{bill_id}/actions", response_model=list[BillActionSchema])
async def get_bill_actions(bill_id: str, db: AsyncSession = Depends(get_db)):
    """List every action recorded for a bill, most recent ``action_date`` first.

    Returns an empty list when no action rows match ``bill_id``.
    """
    stmt = select(BillAction)
    stmt = stmt.where(BillAction.bill_id == bill_id)
    stmt = stmt.order_by(desc(BillAction.action_date))

    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
@router.get("/{bill_id}/news", response_model=list[NewsArticleSchema])
async def get_bill_news(bill_id: str, db: AsyncSession = Depends(get_db)):
    """List up to 20 news articles for a bill, newest ``published_at`` first.

    Returns an empty list when no article rows match ``bill_id``.
    """
    stmt = select(NewsArticle)
    stmt = stmt.where(NewsArticle.bill_id == bill_id)
    stmt = stmt.order_by(desc(NewsArticle.published_at))
    stmt = stmt.limit(20)

    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
@router.get("/{bill_id}/trend", response_model=list[TrendScoreSchema])
async def get_bill_trend(
    bill_id: str,
    days: int = Query(30, ge=7, le=365),
    db: AsyncSession = Depends(get_db),
):
    """Return a bill's trend scores for the last ``days`` days, oldest first.

    Args:
        bill_id: Identifier of the bill whose scores are requested.
        days: Size of the look-back window in days (7-365, default 30).
        db: Async database session.

    Returns:
        Trend score rows with ``score_date`` on or after today's date minus
        ``days``, ordered by ``score_date`` ascending.
    """
    from datetime import date, timedelta

    window = timedelta(days=days)
    cutoff = date.today() - window

    stmt = (
        select(TrendScore)
        .where(TrendScore.bill_id == bill_id)
        .where(TrendScore.score_date >= cutoff)
        .order_by(TrendScore.score_date)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|