feat: PocketVeto v1.0.0 — initial public release
Self-hosted US Congress monitoring platform with AI policy briefs, bill/member/topic follows, ntfy + RSS + email notifications, alignment scoring, collections, and draft-letter generator. Authored by: Jack Levy
This commit is contained in:
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
497
backend/app/api/admin.py
Normal file
497
backend/app/api/admin.py
Normal file
@@ -0,0 +1,497 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.dependencies import get_current_admin
|
||||
from app.database import get_db
|
||||
from app.models import Bill, BillBrief, BillDocument, Follow
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import UserResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── User Management ───────────────────────────────────────────────────────────
|
||||
|
||||
class UserWithStats(UserResponse):
    """User payload extended with the number of follows the user has created."""

    # Count of Follow rows owned by this user (0 when none).
    follow_count: int
|
||||
|
||||
|
||||
@router.get("/users", response_model=list[UserWithStats])
async def list_users(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """List all users with their follow counts."""
    # Oldest accounts first.
    all_users = (await db.execute(select(User).order_by(User.created_at))).scalars().all()

    # One grouped query for follow counts instead of a query per user.
    follow_rows = await db.execute(
        select(Follow.user_id, func.count(Follow.id).label("cnt")).group_by(Follow.user_id)
    )
    follows_per_user = {row.user_id: row.cnt for row in follow_rows}

    payload = []
    for account in all_users:
        payload.append(
            UserWithStats(
                id=account.id,
                email=account.email,
                is_admin=account.is_admin,
                notification_prefs=account.notification_prefs or {},
                created_at=account.created_at,
                follow_count=follows_per_user.get(account.id, 0),
            )
        )
    return payload
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=204)
async def delete_user(
    user_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a user account (cascades to their follows). Cannot delete yourself."""
    # Guard clause: an admin may never remove their own account.
    if current_user.id == user_id:
        raise HTTPException(status_code=400, detail="Cannot delete your own account")

    target = await db.get(User, user_id)
    if target is None:
        raise HTTPException(status_code=404, detail="User not found")

    await db.delete(target)
    await db.commit()
|
||||
|
||||
|
||||
@router.patch("/users/{user_id}/toggle-admin", response_model=UserResponse)
async def toggle_admin(
    user_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Promote or demote a user's admin status."""
    # Guard clause: admins cannot toggle their own flag (avoids self-lockout).
    if current_user.id == user_id:
        raise HTTPException(status_code=400, detail="Cannot change your own admin status")

    target = await db.get(User, user_id)
    if target is None:
        raise HTTPException(status_code=404, detail="User not found")

    # Flip the flag, persist, and return the refreshed row.
    target.is_admin = not target.is_admin
    await db.commit()
    await db.refresh(target)
    return target
|
||||
|
||||
|
||||
# ── Analysis Stats ────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
async def get_stats(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Return analysis pipeline progress counters.

    Counts bills/documents/briefs at each stage of the ingest → document
    fetch → LLM pipeline so the admin UI can show progress and spot stalls.

    Note: ``remaining`` is clamped at zero. A bill can carry more than one
    brief (this endpoint itself counts separate "full" and "amendment"
    briefs), so the naive ``total_bills - total_briefs`` difference could
    otherwise go negative.
    """

    async def _count(stmt) -> int:
        # Shared helper: execute a COUNT(*) statement (Core select or
        # textual SQL) and normalize the scalar result to an int.
        return (await db.execute(stmt)).scalar() or 0

    total_bills = await _count(select(func.count()).select_from(Bill))
    docs_fetched = await _count(
        select(func.count()).select_from(BillDocument).where(BillDocument.raw_text.isnot(None))
    )
    total_briefs = await _count(select(func.count()).select_from(BillBrief))
    full_briefs = await _count(
        select(func.count()).select_from(BillBrief).where(BillBrief.brief_type == "full")
    )
    amendment_briefs = await _count(
        select(func.count()).select_from(BillBrief).where(BillBrief.brief_type == "amendment")
    )
    # Legacy briefs whose key_points are bare strings (pre-citation format;
    # see the /backfill-citations endpoint which re-generates them).
    uncited_briefs = await _count(
        text("""
            SELECT COUNT(*) FROM bill_briefs
            WHERE key_points IS NOT NULL
              AND jsonb_array_length(key_points) > 0
              AND jsonb_typeof(key_points->0) = 'string'
        """)
    )
    # Bills with null sponsor
    bills_missing_sponsor = await _count(
        text("SELECT COUNT(*) FROM bills WHERE sponsor_id IS NULL")
    )
    # Bills with null metadata (introduced_date / chamber / congress_url)
    bills_missing_metadata = await _count(
        text("SELECT COUNT(*) FROM bills WHERE introduced_date IS NULL OR chamber IS NULL OR congress_url IS NULL")
    )
    # Bills with no document record at all (text not yet published on GovInfo)
    no_text_bills = await _count(
        text("""
            SELECT COUNT(*) FROM bills b
            LEFT JOIN bill_documents bd ON bd.bill_id = b.bill_id
            WHERE bd.id IS NULL
        """)
    )
    # Documents that have text but no brief (LLM not yet run / failed)
    pending_llm = await _count(
        text("""
            SELECT COUNT(*) FROM bill_documents bd
            LEFT JOIN bill_briefs bb ON bb.document_id = bd.id
            WHERE bb.id IS NULL AND bd.raw_text IS NOT NULL
        """)
    )
    # Bills that have never had their action history fetched
    bills_missing_actions = await _count(
        text("SELECT COUNT(*) FROM bills WHERE actions_fetched_at IS NULL")
    )
    # Cited brief points (objects) that have no label yet
    unlabeled_briefs = await _count(
        text("""
            SELECT COUNT(*) FROM bill_briefs
            WHERE (
                key_points IS NOT NULL AND EXISTS (
                    SELECT 1 FROM jsonb_array_elements(key_points) AS p
                    WHERE jsonb_typeof(p) = 'object' AND (p->>'label') IS NULL
                )
            ) OR (
                risks IS NOT NULL AND EXISTS (
                    SELECT 1 FROM jsonb_array_elements(risks) AS r
                    WHERE jsonb_typeof(r) = 'object' AND (r->>'label') IS NULL
                )
            )
        """)
    )

    return {
        "total_bills": total_bills,
        "docs_fetched": docs_fetched,
        "briefs_generated": total_briefs,
        "full_briefs": full_briefs,
        "amendment_briefs": amendment_briefs,
        "uncited_briefs": uncited_briefs,
        "no_text_bills": no_text_bills,
        "pending_llm": pending_llm,
        "bills_missing_sponsor": bills_missing_sponsor,
        "bills_missing_metadata": bills_missing_metadata,
        "bills_missing_actions": bills_missing_actions,
        "unlabeled_briefs": unlabeled_briefs,
        # Clamped: multiple briefs per bill would otherwise drive this negative.
        "remaining": max(total_bills - total_briefs, 0),
    }
|
||||
|
||||
|
||||
# ── Celery Tasks ──────────────────────────────────────────────────────────────
|
||||
|
||||
# Each endpoint below queues one background Celery job and immediately returns
# its task id; progress can be polled via GET /task-status/{task_id}.

@router.post("/backfill-citations")
async def backfill_citations(current_user: User = Depends(get_current_admin)):
    """Delete pre-citation briefs and re-queue LLM processing using stored document text."""
    from app.workers.llm_processor import backfill_brief_citations

    queued = backfill_brief_citations.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/backfill-sponsors")
async def backfill_sponsors(current_user: User = Depends(get_current_admin)):
    # Queue the sponsor-id backfill job.
    from app.workers.congress_poller import backfill_sponsor_ids

    queued = backfill_sponsor_ids.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/trigger-poll")
async def trigger_poll(current_user: User = Depends(get_current_admin)):
    # Run the congress bill-poll job immediately.
    from app.workers.congress_poller import poll_congress_bills

    queued = poll_congress_bills.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/trigger-member-sync")
async def trigger_member_sync(current_user: User = Depends(get_current_admin)):
    # Queue the member-sync job.
    from app.workers.congress_poller import sync_members

    queued = sync_members.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/trigger-fetch-actions")
async def trigger_fetch_actions(current_user: User = Depends(get_current_admin)):
    # Queue the action-fetch job for active bills.
    from app.workers.congress_poller import fetch_actions_for_active_bills

    queued = fetch_actions_for_active_bills.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/backfill-all-actions")
async def backfill_all_actions(current_user: User = Depends(get_current_admin)):
    """Queue action fetches for every bill that has never had actions fetched."""
    from app.workers.congress_poller import backfill_all_bill_actions

    queued = backfill_all_bill_actions.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/backfill-metadata")
async def backfill_metadata(current_user: User = Depends(get_current_admin)):
    """Fill in null introduced_date, congress_url, chamber for existing bills."""
    from app.workers.congress_poller import backfill_bill_metadata

    queued = backfill_bill_metadata.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/backfill-labels")
async def backfill_labels(current_user: User = Depends(get_current_admin)):
    """Classify existing cited brief points as fact or inference without re-generating briefs."""
    from app.workers.llm_processor import backfill_brief_labels

    queued = backfill_brief_labels.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/backfill-cosponsors")
async def backfill_cosponsors(current_user: User = Depends(get_current_admin)):
    """Fetch co-sponsor data from Congress.gov for all bills that haven't been fetched yet."""
    from app.workers.bill_classifier import backfill_all_bill_cosponsors

    queued = backfill_all_bill_cosponsors.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/backfill-categories")
async def backfill_categories(current_user: User = Depends(get_current_admin)):
    """Classify all bills with text but no category as substantive/commemorative/administrative."""
    from app.workers.bill_classifier import backfill_bill_categories

    queued = backfill_bill_categories.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/calculate-effectiveness")
async def calculate_effectiveness(current_user: User = Depends(get_current_admin)):
    """Recalculate member effectiveness scores and percentiles now."""
    from app.workers.bill_classifier import calculate_effectiveness_scores

    queued = calculate_effectiveness_scores.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/resume-analysis")
async def resume_analysis(current_user: User = Depends(get_current_admin)):
    """Re-queue LLM processing for docs with no brief, and document fetching for bills with no doc."""
    from app.workers.llm_processor import resume_pending_analysis

    queued = resume_pending_analysis.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/trigger-weekly-digest")
async def trigger_weekly_digest(current_user: User = Depends(get_current_admin)):
    """Send the weekly bill activity summary to all eligible users now."""
    from app.workers.notification_dispatcher import send_weekly_digest

    queued = send_weekly_digest.delay()
    return {"task_id": queued.id, "status": "queued"}


@router.post("/trigger-trend-scores")
async def trigger_trend_scores(current_user: User = Depends(get_current_admin)):
    # Queue the trend-score recalculation job.
    from app.workers.trend_scorer import calculate_all_trend_scores

    queued = calculate_all_trend_scores.delay()
    return {"task_id": queued.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/bills/{bill_id}/reprocess")
async def reprocess_bill(bill_id: str, current_user: User = Depends(get_current_admin)):
    """Queue document and action fetches for a specific bill. Useful for debugging."""
    from app.workers.congress_poller import fetch_bill_actions
    from app.workers.document_fetcher import fetch_bill_documents

    # Two independent jobs; return both ids so each can be polled separately.
    documents_job = fetch_bill_documents.delay(bill_id)
    actions_job = fetch_bill_actions.delay(bill_id)
    return {"task_ids": {"documents": documents_job.id, "actions": actions_job.id}}
|
||||
|
||||
|
||||
@router.get("/newsapi-quota")
async def get_newsapi_quota(current_user: User = Depends(get_current_admin)):
    """Return today's remaining NewsAPI daily quota (calls used vs. 100/day limit)."""
    import asyncio

    from app.services.news_service import get_newsapi_quota_remaining

    # Synchronous helper — run it in a worker thread so the event loop isn't blocked.
    calls_left = await asyncio.to_thread(get_newsapi_quota_remaining)
    return {"remaining": calls_left, "limit": 100}


@router.post("/clear-gnews-cache")
async def clear_gnews_cache_endpoint(current_user: User = Depends(get_current_admin)):
    """Flush the Google News RSS Redis cache so fresh data is fetched on next run."""
    import asyncio

    from app.services.news_service import clear_gnews_cache

    # Synchronous helper — run it in a worker thread so the event loop isn't blocked.
    wiped = await asyncio.to_thread(clear_gnews_cache)
    return {"cleared": wiped}
|
||||
|
||||
|
||||
@router.post("/submit-llm-batch")
async def submit_llm_batch_endpoint(current_user: User = Depends(get_current_admin)):
    """Submit all unbriefed documents to the Batch API (OpenAI/Anthropic only)."""
    from app.workers.llm_batch_processor import submit_llm_batch

    queued = submit_llm_batch.delay()
    return {"task_id": queued.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.get("/llm-batch-status")
async def get_llm_batch_status(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Return the current batch job state, or no_active_batch if none.

    The state is persisted as a JSON blob in the ``llm_active_batch``
    AppSetting row: no row means no batch is active, and an unparseable
    value is reported as ``unknown`` rather than raising a 500.
    """
    import json

    from app.models.setting import AppSetting

    row = await db.get(AppSetting, "llm_active_batch")
    if not row:
        return {"status": "no_active_batch"}
    try:
        return json.loads(row.value)
    # Narrowed from a bare `except Exception`: json.JSONDecodeError is a
    # ValueError, and TypeError covers a NULL/non-string stored value.
    # Anything else is a genuine bug and should surface.
    except (TypeError, ValueError):
        return {"status": "unknown"}
|
||||
|
||||
|
||||
@router.get("/api-health")
async def api_health(current_user: User = Depends(get_current_admin)):
    """Test each external API and return status + latency for each."""
    import asyncio

    # Response key → blocking probe function. Dict order fixes result order.
    probes = {
        "congress_gov": _test_congress,
        "govinfo": _test_govinfo,
        "newsapi": _test_newsapi,
        "google_news": _test_gnews,
        "rep_lookup": _test_rep_lookup,
    }
    # Run all probes concurrently in worker threads; gather preserves order
    # and return_exceptions keeps one failure from aborting the rest.
    outcomes = await asyncio.gather(
        *(asyncio.to_thread(fn) for fn in probes.values()),
        return_exceptions=True,
    )

    report = {}
    for name, outcome in zip(probes, outcomes):
        if isinstance(outcome, dict):
            report[name] = outcome
        else:
            report[name] = {"status": "error", "detail": str(outcome)}
    return report
|
||||
|
||||
|
||||
def _timed(fn):
|
||||
"""Run fn(), return its dict merged with latency_ms."""
|
||||
import time as _time
|
||||
t0 = _time.perf_counter()
|
||||
result = fn()
|
||||
result["latency_ms"] = round((_time.perf_counter() - t0) * 1000)
|
||||
return result
|
||||
|
||||
|
||||
def _test_congress() -> dict:
    """Probe the Congress.gov API; return a {"status", "detail"} dict."""
    from app.config import settings
    from app.services import congress_api

    # Without an API key there is nothing to probe.
    if not settings.DATA_GOV_API_KEY:
        return {"status": "error", "detail": "DATA_GOV_API_KEY not configured"}

    def _probe():
        payload = congress_api.get_bills(119, limit=1)
        total = payload.get("pagination", {}).get("count") or len(payload.get("bills", []))
        return {"status": "ok", "detail": f"{total:,} bills available in 119th Congress"}

    try:
        return _timed(_probe)
    except Exception as exc:
        return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
def _test_govinfo() -> dict:
    """Probe the GovInfo API; return a {"status", "detail"} dict."""
    import requests as req

    from app.config import settings

    if not settings.DATA_GOV_API_KEY:
        return {"status": "error", "detail": "DATA_GOV_API_KEY not configured"}

    def _probe():
        # /collections lists all available collections — simple health check endpoint
        response = req.get(
            "https://api.govinfo.gov/collections",
            params={"api_key": settings.DATA_GOV_API_KEY},
            timeout=15,
        )
        response.raise_for_status()
        collections = response.json().get("collections", [])
        bills_entry = next((c for c in collections if c.get("collectionCode") == "BILLS"), None)
        if bills_entry is None:
            return {"status": "ok", "detail": f"GovInfo reachable — {len(collections)} collections available"}
        count = bills_entry.get("packageCount", "?")
        if isinstance(count, int):
            return {"status": "ok", "detail": f"BILLS collection: {count:,} packages"}
        return {"status": "ok", "detail": "GovInfo reachable, BILLS collection found"}

    try:
        return _timed(_probe)
    except Exception as exc:
        return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
def _test_newsapi() -> dict:
    """Probe NewsAPI; return a {"status", "detail"} dict ("skipped" if unconfigured)."""
    import requests as req

    from app.config import settings

    if not settings.NEWSAPI_KEY:
        return {"status": "skipped", "detail": "NEWSAPI_KEY not configured"}

    def _probe():
        payload = req.get(
            "https://newsapi.org/v2/top-headlines",
            params={"country": "us", "pageSize": 1, "apiKey": settings.NEWSAPI_KEY},
            timeout=10,
        ).json()
        # Failure is signalled via the JSON body's "status" field rather
        # than the HTTP status code, hence no raise_for_status here.
        if payload.get("status") != "ok":
            return {"status": "error", "detail": payload.get("message", "Unknown error")}
        return {"status": "ok", "detail": f"{payload.get('totalResults', 0):,} headlines available"}

    try:
        return _timed(_probe)
    except Exception as exc:
        return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
def _test_gnews() -> dict:
    """Probe the Google News RSS feed; return a {"status", "detail"} dict."""
    import requests as req

    def _probe():
        feed = req.get(
            "https://news.google.com/rss/search",
            params={"q": "congress", "hl": "en-US", "gl": "US", "ceid": "US:en"},
            timeout=10,
            headers={"User-Agent": "Mozilla/5.0"},
        )
        feed.raise_for_status()
        # Cheap sanity check: count raw <item> tags instead of parsing XML.
        items = feed.text.count("<item>")
        return {"status": "ok", "detail": f"{items} items in test RSS feed"}

    try:
        return _timed(_probe)
    except Exception as exc:
        return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
def _test_rep_lookup() -> dict:
    """Health-check the two-step representative lookup chain.

    Geocodes a fixed test ZIP (20001) via Nominatim, then asks the Census
    TIGERweb identify service which congressional district contains the
    resulting point. Returns a {"status", "detail"} dict; on success the
    result is wrapped by _timed and carries latency_ms.
    """
    import re as _re
    import requests as req
    def _call():
        # Step 1: Nominatim ZIP → lat/lng
        r1 = req.get(
            "https://nominatim.openstreetmap.org/search",
            params={"postalcode": "20001", "country": "US", "format": "json", "limit": "1"},
            headers={"User-Agent": "PocketVeto/1.0"},
            timeout=10,
        )
        r1.raise_for_status()
        places = r1.json()
        if not places:
            return {"status": "error", "detail": "Nominatim: no result for test ZIP 20001"}
        lat, lng = places[0]["lat"], places[0]["lon"]
        # Half-degree box around the point, used only to build TIGERweb's
        # required mapExtent parameter below.
        half = 0.5
        # Step 2: TIGERweb identify → congressional district
        r2 = req.get(
            "https://tigerweb.geo.census.gov/arcgis/rest/services/TIGERweb/Legislative/MapServer/identify",
            params={
                "f": "json",
                "geometry": f"{lng},{lat}",
                "geometryType": "esriGeometryPoint",
                "sr": "4326",
                "layers": "all",
                "tolerance": "2",
                "mapExtent": f"{float(lng)-half},{float(lat)-half},{float(lng)+half},{float(lat)+half}",
                "imageDisplay": "100,100,96",
            },
            timeout=15,
        )
        r2.raise_for_status()
        results = r2.json().get("results", [])
        for item in results:
            attrs = item.get("attributes", {})
            # The district number lives under a key matching CD<digits>FP
            # (presumably the digits are the Congress number — verify
            # against the TIGERweb schema).
            cd_field = next((k for k in attrs if _re.match(r"CD\d+FP$", k)), None)
            if cd_field:
                # An all-zeros value ("00") strips to the empty string and
                # is rendered "At-large"; otherwise int() drops the
                # zero-padding for display.
                district = str(int(str(attrs[cd_field]))) if str(attrs[cd_field]).strip("0") else "At-large"
                return {"status": "ok", "detail": f"Nominatim + TIGERweb reachable — district {district} found for ZIP 20001"}
        # Reachable but no district attribute anywhere — report which layers
        # came back to aid debugging.
        layers = [r.get("layerName") for r in results]
        return {"status": "error", "detail": f"Reachable but no CD field found. Layers: {layers}"}
    try:
        return _timed(_call)
    except Exception as exc:
        return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
@router.get("/task-status/{task_id}")
async def get_task_status(task_id: str, current_user: User = Depends(get_current_admin)):
    """Look up a Celery task by id and report its status (and result once ready)."""
    from app.workers.celery_app import celery_app

    handle = celery_app.AsyncResult(task_id)
    payload = {"task_id": task_id, "status": handle.status}
    # Only expose the result after the task has finished.
    payload["result"] = handle.result if handle.ready() else None
    return payload
|
||||
161
backend/app/api/alignment.py
Normal file
161
backend/app/api/alignment.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Representation Alignment API.
|
||||
|
||||
Returns how well each followed member's voting record aligns with the
|
||||
current user's bill stances (pocket_veto / pocket_boost).
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models import Follow, Member
|
||||
from app.models.user import User
|
||||
from app.models.vote import BillVote, MemberVotePosition
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
async def get_alignment(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Cross-reference the user's stanced bill follows with how their
    followed members voted on those same bills.

    pocket_boost + Yea → aligned
    pocket_veto + Nay → aligned
    All other combinations with an actual Yea/Nay vote → opposed
    Not Voting / Present → excluded from tally

    Returns a dict with a per-member list (sorted by alignment percentage
    descending, members without a percentage last) plus two totals:
    bills the user has a stance on, and how many of those have recorded
    votes. Each early return below preserves whichever totals are known
    at that point.
    """
    # 1. Bill follows with a stance
    bill_follows_result = await db.execute(
        select(Follow).where(
            Follow.user_id == current_user.id,
            Follow.follow_type == "bill",
            Follow.follow_mode.in_(["pocket_veto", "pocket_boost"]),
        )
    )
    bill_follows = bill_follows_result.scalars().all()

    # No stanced bills — nothing to compare against.
    if not bill_follows:
        return {
            "members": [],
            "total_bills_with_stance": 0,
            "total_bills_with_votes": 0,
        }

    # bill_id → "pocket_veto" | "pocket_boost"
    stance_map = {f.follow_value: f.follow_mode for f in bill_follows}

    # 2. Followed members
    member_follows_result = await db.execute(
        select(Follow).where(
            Follow.user_id == current_user.id,
            Follow.follow_type == "member",
        )
    )
    member_follows = member_follows_result.scalars().all()
    followed_member_ids = {f.follow_value for f in member_follows}

    if not followed_member_ids:
        return {
            "members": [],
            "total_bills_with_stance": len(stance_map),
            "total_bills_with_votes": 0,
        }

    # 3. Bulk fetch votes for all stanced bills
    bill_ids = list(stance_map.keys())
    votes_result = await db.execute(
        select(BillVote).where(BillVote.bill_id.in_(bill_ids))
    )
    votes = votes_result.scalars().all()

    if not votes:
        return {
            "members": [],
            "total_bills_with_stance": len(stance_map),
            "total_bills_with_votes": 0,
        }

    vote_ids = [v.id for v in votes]
    # vote id → bill id, for mapping each position back to its bill's stance.
    bill_id_by_vote = {v.id: v.bill_id for v in votes}
    # Distinct bills that have at least one recorded vote.
    bills_with_votes = len({v.bill_id for v in votes})

    # 4. Bulk fetch positions for followed members on those votes
    positions_result = await db.execute(
        select(MemberVotePosition).where(
            MemberVotePosition.vote_id.in_(vote_ids),
            MemberVotePosition.bioguide_id.in_(followed_member_ids),
        )
    )
    positions = positions_result.scalars().all()

    # 5. Aggregate per member
    # bioguide_id → {"aligned": n, "opposed": n}
    tally: dict[str, dict] = defaultdict(lambda: {"aligned": 0, "opposed": 0})

    for pos in positions:
        if pos.position not in ("Yea", "Nay"):
            # Skip Not Voting / Present — not a real position signal
            continue
        bill_id = bill_id_by_vote.get(pos.vote_id)
        if not bill_id:
            continue
        stance = stance_map.get(bill_id)
        is_aligned = (
            (stance == "pocket_boost" and pos.position == "Yea") or
            (stance == "pocket_veto" and pos.position == "Nay")
        )
        if is_aligned:
            tally[pos.bioguide_id]["aligned"] += 1
        else:
            tally[pos.bioguide_id]["opposed"] += 1

    if not tally:
        return {
            "members": [],
            "total_bills_with_stance": len(stance_map),
            "total_bills_with_votes": bills_with_votes,
        }

    # 6. Load member details
    member_ids = list(tally.keys())
    members_result = await db.execute(
        select(Member).where(Member.bioguide_id.in_(member_ids))
    )
    members = members_result.scalars().all()
    member_map = {m.bioguide_id: m for m in members}

    # 7. Build response
    result = []
    for bioguide_id, counts in tally.items():
        # Member row may be missing locally; fall back to the bioguide id.
        m = member_map.get(bioguide_id)
        aligned = counts["aligned"]
        opposed = counts["opposed"]
        total = aligned + opposed
        result.append({
            "bioguide_id": bioguide_id,
            "name": m.name if m else bioguide_id,
            "party": m.party if m else None,
            "state": m.state if m else None,
            "chamber": m.chamber if m else None,
            "photo_url": m.photo_url if m else None,
            "effectiveness_percentile": m.effectiveness_percentile if m else None,
            "aligned": aligned,
            "opposed": opposed,
            "total": total,
            "alignment_pct": round(aligned / total * 100, 1) if total > 0 else None,
        })

    # Highest alignment first; None percentages sort to the end.
    result.sort(key=lambda x: (x["alignment_pct"] is None, -(x["alignment_pct"] or 0)))

    return {
        "members": result,
        "total_bills_with_stance": len(stance_map),
        "total_bills_with_votes": bills_with_votes,
    }
|
||||
58
backend/app/api/auth.py
Normal file
58
backend/app/api/auth.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.core.security import create_access_token, hash_password, verify_password
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import TokenResponse, UserCreate, UserResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse, status_code=201)
async def register(body: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create an account and return a bearer token for it.

    The first account ever registered is promoted to admin. Raises 400
    for a short password or malformed email, 409 for a duplicate email.
    """
    from sqlalchemy.exc import IntegrityError

    if len(body.password) < 8:
        raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
    if "@" not in body.email:
        raise HTTPException(status_code=400, detail="Invalid email address")

    # Emails are stored lower-cased so lookups are case-insensitive.
    email = body.email.lower()

    # Check for duplicate email
    existing = await db.execute(select(User).where(User.email == email))
    if existing.scalar_one_or_none():
        raise HTTPException(status_code=409, detail="Email already registered")

    # First registered user becomes admin
    count_result = await db.execute(select(func.count()).select_from(User))
    is_first_user = count_result.scalar() == 0

    user = User(
        email=email,
        hashed_password=hash_password(body.password),
        is_admin=is_first_user,
    )
    db.add(user)
    try:
        await db.commit()
    except IntegrityError:
        # Two concurrent registrations can both pass the SELECT check above.
        # Map the constraint violation to the same 409 (assumes a unique
        # constraint on users.email — if absent, this never fires and
        # behavior is unchanged).
        await db.rollback()
        raise HTTPException(status_code=409, detail="Email already registered")
    await db.refresh(user)

    return TokenResponse(access_token=create_access_token(user.id), user=user)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
async def login(body: UserCreate, db: AsyncSession = Depends(get_db)):
    """Exchange email + password for a bearer token."""
    lookup = await db.execute(select(User).where(User.email == body.email.lower()))
    account = lookup.scalar_one_or_none()

    # One generic 401 for both unknown email and wrong password, so the
    # response doesn't reveal which addresses have accounts.
    if account is None or not verify_password(body.password, account.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
        )

    return TokenResponse(access_token=create_access_token(account.id), user=account)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
async def me(current_user: User = Depends(get_current_user)):
    """Return the authenticated caller's own user record."""
    return current_user
|
||||
277
backend/app/api/bills.py
Normal file
277
backend/app/api/bills.py
Normal file
@@ -0,0 +1,277 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import desc, func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Bill, BillAction, BillBrief, BillDocument, NewsArticle, TrendScore
|
||||
from app.schemas.schemas import (
|
||||
BillDetailSchema,
|
||||
BillSchema,
|
||||
BillActionSchema,
|
||||
BillVoteSchema,
|
||||
NewsArticleSchema,
|
||||
PaginatedResponse,
|
||||
TrendScoreSchema,
|
||||
)
|
||||
|
||||
# Maps lowercase bill-type slugs (as used in bill ids) to their conventional
# printed citation prefixes, e.g. "hjres" → "H.J.Res.".
_BILL_TYPE_LABELS: dict[str, str] = {
    "hr": "H.R.",
    "s": "S.",
    "hjres": "H.J.Res.",
    "sjres": "S.J.Res.",
    "hconres": "H.Con.Res.",
    "sconres": "S.Con.Res.",
    "hres": "H.Res.",
    "sres": "S.Res.",
}
|
||||
|
||||
|
||||
class DraftLetterRequest(BaseModel):
    """Input to the draft-letter generator endpoint."""

    # Position the letter argues for ("yes" = support, "no" = oppose).
    stance: Literal["yes", "no"]
    # Which chamber's office the letter addresses.
    recipient: Literal["house", "senate"]
    # Requested writing style for the generated letter.
    tone: Literal["short", "polite", "firm"]
    # Brief points the user selected to include in the letter.
    selected_points: list[str]
    include_citations: bool = True
    zip_code: str | None = None  # not stored, not logged
    rep_name: str | None = None  # not stored, not logged
|
||||
|
||||
|
||||
class DraftLetterResponse(BaseModel):
    """Output of the draft-letter generator endpoint."""

    # The complete generated letter text.
    draft: str
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[BillSchema])
async def list_bills(
    chamber: Optional[str] = Query(None),
    topic: Optional[str] = Query(None),
    sponsor_id: Optional[str] = Query(None),
    q: Optional[str] = Query(None),
    has_document: Optional[bool] = Query(None),
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
    sort: str = Query("latest_action_date"),
    db: AsyncSession = Depends(get_db),
):
    """List bills with optional filters and pagination.

    Filters: chamber, sponsor, AI-brief topic tag, case-insensitive substring
    search over id/title/short title, and presence/absence of a stored full
    text document. Results are sorted descending on *sort* (unknown names
    fall back to ``latest_action_date``).
    """
    query = (
        select(Bill)
        .options(
            selectinload(Bill.sponsor),
            selectinload(Bill.briefs),
            selectinload(Bill.trend_scores),
        )
    )

    if chamber:
        query = query.where(Bill.chamber == chamber)
    if sponsor_id:
        query = query.where(Bill.sponsor_id == sponsor_id)
    if topic:
        query = query.join(BillBrief, Bill.bill_id == BillBrief.bill_id).where(
            BillBrief.topic_tags.contains([topic])
        )
    if q:
        query = query.where(
            or_(
                Bill.bill_id.ilike(f"%{q}%"),
                Bill.title.ilike(f"%{q}%"),
                Bill.short_title.ilike(f"%{q}%"),
            )
        )
    if has_document is not None:
        # Correlated EXISTS: does any BillDocument row reference this bill?
        # Built once and applied positively or negated (was duplicated).
        doc_subq = (
            select(BillDocument.bill_id)
            .where(BillDocument.bill_id == Bill.bill_id)
            .exists()
        )
        query = query.where(doc_subq if has_document else ~doc_subq)

    # Count total rows matching the filters (before pagination).
    count_query = select(func.count()).select_from(query.subquery())
    total = await db.scalar(count_query) or 0

    # Sort: getattr falls back to latest_action_date for unknown column names.
    sort_col = getattr(Bill, sort, Bill.latest_action_date)
    query = query.order_by(desc(sort_col)).offset((page - 1) * per_page).limit(per_page)

    result = await db.execute(query)
    bills = result.scalars().unique().all()

    # Single batch query: which of these bills have at least one document?
    bill_ids = [b.bill_id for b in bills]
    doc_result = await db.execute(
        select(BillDocument.bill_id).where(BillDocument.bill_id.in_(bill_ids)).distinct()
    )
    bills_with_docs = {row[0] for row in doc_result}

    # Attach latest brief, trend, and has_document to each bill.
    # NOTE(review): briefs[0]/trend_scores[0] are assumed newest-first —
    # presumably guaranteed by the relationship's order_by; verify on the model.
    items = []
    for bill in bills:
        bill_dict = BillSchema.model_validate(bill)
        if bill.briefs:
            bill_dict.latest_brief = bill.briefs[0]
        if bill.trend_scores:
            bill_dict.latest_trend = bill.trend_scores[0]
        bill_dict.has_document = bill.bill_id in bills_with_docs
        items.append(bill_dict)

    return PaginatedResponse(
        items=items,
        total=total,
        page=page,
        per_page=per_page,
        pages=max(1, (total + per_page - 1) // per_page),
    )
|
||||
|
||||
|
||||
@router.get("/{bill_id}", response_model=BillDetailSchema)
async def get_bill(bill_id: str, db: AsyncSession = Depends(get_db)):
    """Return one bill with sponsor, actions, briefs, news, and trend data.

    Raises 404 when the bill is unknown. Also queues a best-effort background
    news fetch when trend data reports upstream articles but none are stored.
    """
    result = await db.execute(
        select(Bill)
        .options(
            selectinload(Bill.sponsor),
            selectinload(Bill.actions),
            selectinload(Bill.briefs),
            selectinload(Bill.news_articles),
            selectinload(Bill.trend_scores),
        )
        .where(Bill.bill_id == bill_id)
    )
    bill = result.scalar_one_or_none()
    if not bill:
        # Fixed: dropped the redundant local `from fastapi import HTTPException`
        # that shadowed the module-level import.
        raise HTTPException(status_code=404, detail="Bill not found")

    detail = BillDetailSchema.model_validate(bill)
    if bill.briefs:
        detail.latest_brief = bill.briefs[0]
    if bill.trend_scores:
        detail.latest_trend = bill.trend_scores[0]
    doc_exists = await db.scalar(
        select(func.count()).select_from(BillDocument).where(BillDocument.bill_id == bill_id)
    )
    detail.has_document = bool(doc_exists)

    # Trigger a background news refresh if no articles are stored but trend
    # data shows there are gnews results out there waiting to be fetched.
    latest_trend = bill.trend_scores[0] if bill.trend_scores else None
    has_gnews = latest_trend and (latest_trend.gnews_count or 0) > 0
    if not bill.news_articles and has_gnews:
        try:
            from app.workers.news_fetcher import fetch_news_for_bill
            fetch_news_for_bill.delay(bill_id)
        except Exception:
            # Best-effort: a broker/worker outage must not break the bill page.
            pass

    return detail
|
||||
|
||||
|
||||
@router.get("/{bill_id}/actions", response_model=list[BillActionSchema])
async def get_bill_actions(bill_id: str, db: AsyncSession = Depends(get_db)):
    """Return every recorded action for a bill, most recent first."""
    stmt = (
        select(BillAction)
        .where(BillAction.bill_id == bill_id)
        .order_by(desc(BillAction.action_date))
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bill_id}/news", response_model=list[NewsArticleSchema])
async def get_bill_news(bill_id: str, db: AsyncSession = Depends(get_db)):
    """Return up to 20 stored news articles for a bill, newest first."""
    stmt = (
        select(NewsArticle)
        .where(NewsArticle.bill_id == bill_id)
        .order_by(desc(NewsArticle.published_at))
        .limit(20)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bill_id}/trend", response_model=list[TrendScoreSchema])
async def get_bill_trend(bill_id: str, days: int = Query(30, ge=7, le=365), db: AsyncSession = Depends(get_db)):
    """Return a bill's daily trend scores over the last *days* days, oldest first."""
    from datetime import date, timedelta

    earliest = date.today() - timedelta(days=days)
    stmt = (
        select(TrendScore)
        .where(TrendScore.bill_id == bill_id, TrendScore.score_date >= earliest)
        .order_by(TrendScore.score_date)
    )
    return (await db.execute(stmt)).scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bill_id}/votes", response_model=list[BillVoteSchema])
async def get_bill_votes_endpoint(bill_id: str, db: AsyncSession = Depends(get_db)):
    """Return recorded roll-call votes (with member positions) for a bill.

    When none are stored and the bill exists, a best-effort background fetch
    is queued so votes can appear on a later request.
    """
    # Fixed: dropped the redundant local re-import of selectinload, which is
    # already imported at module top. BillVote stays local (not in top imports).
    from app.models.vote import BillVote

    result = await db.execute(
        select(BillVote)
        .where(BillVote.bill_id == bill_id)
        .options(selectinload(BillVote.positions))
        .order_by(desc(BillVote.vote_date))
    )
    votes = result.scalars().unique().all()

    # Trigger background fetch if no votes are stored yet
    if not votes:
        bill = await db.get(Bill, bill_id)
        if bill:
            try:
                from app.workers.vote_fetcher import fetch_bill_votes
                fetch_bill_votes.delay(bill_id)
            except Exception:
                # Best-effort: broker failures must not break the endpoint.
                pass

    return votes
|
||||
|
||||
|
||||
def _llm_error_detail(msg: str) -> str:
    """Map a raw LLM client error message to a safe, actionable detail string."""
    lowered = msg.lower()
    if "insufficient_quota" in msg or "quota" in lowered:
        return "LLM quota exceeded. Check your API key billing."
    if "rate_limit" in lowered or "429" in msg:
        return "LLM rate limit hit. Wait a moment and try again."
    if "auth" in lowered or "401" in msg or "403" in msg:
        return "LLM authentication failed. Check your API key."
    return f"LLM error: {msg[:200]}"


@router.post("/{bill_id}/draft-letter", response_model=DraftLetterResponse)
async def generate_letter(bill_id: str, body: DraftLetterRequest, db: AsyncSession = Depends(get_db)):
    """Generate a constituent letter draft for a bill via the configured LLM.

    Raises 404 for an unknown bill, 422 when no talking points were selected,
    and 502 with a sanitized message when the LLM call fails. ZIP code and
    rep name are only passed through to the generator — never stored/logged.
    """
    from app.models.setting import AppSetting
    from app.services.llm_service import generate_draft_letter

    bill = await db.get(Bill, bill_id)
    if not bill:
        raise HTTPException(status_code=404, detail="Bill not found")

    if not body.selected_points:
        raise HTTPException(status_code=422, detail="At least one point must be selected")

    # Admin-configured provider/model overrides, if present in app settings.
    prov_row = await db.get(AppSetting, "llm_provider")
    model_row = await db.get(AppSetting, "llm_model")
    llm_provider_override = prov_row.value if prov_row else None
    llm_model_override = model_row.value if model_row else None

    # e.g. "hr" + 1234 -> "H.R. 1234"; unknown type codes fall back to upper case.
    type_label = _BILL_TYPE_LABELS.get((bill.bill_type or "").lower(), (bill.bill_type or "").upper())
    bill_label = f"{type_label} {bill.bill_number}"

    try:
        draft = generate_draft_letter(
            bill_label=bill_label,
            bill_title=bill.short_title or bill.title or bill_label,
            stance=body.stance,
            recipient=body.recipient,
            tone=body.tone,
            selected_points=body.selected_points,
            include_citations=body.include_citations,
            zip_code=body.zip_code,
            rep_name=body.rep_name,
            llm_provider=llm_provider_override,
            llm_model=llm_model_override,
        )
    except Exception as exc:
        # Classify provider failures into user-facing messages without leaking
        # full tracebacks; chain the cause for server-side debugging.
        raise HTTPException(status_code=502, detail=_llm_error_detail(str(exc))) from exc
    return {"draft": draft}
|
||||
319
backend/app/api/collections.py
Normal file
319
backend/app/api/collections.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Collections API — named, curated groups of bills with share links.
|
||||
"""
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.bill import Bill, BillDocument
|
||||
from app.models.collection import Collection, CollectionBill
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import (
|
||||
BillSchema,
|
||||
CollectionCreate,
|
||||
CollectionDetailSchema,
|
||||
CollectionSchema,
|
||||
CollectionUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _slugify(text: str) -> str:
|
||||
text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode()
|
||||
text = re.sub(r"[^\w\s-]", "", text.lower())
|
||||
return re.sub(r"[-\s]+", "-", text).strip("-")
|
||||
|
||||
|
||||
async def _unique_slug(db: AsyncSession, user_id: int, name: str, exclude_id: int | None = None) -> str:
    """Derive a slug from *name* that no other collection of this user uses.

    On collision, appends -2, -3, ... until free. *exclude_id* lets an update
    keep its own current slug without counting it as a collision.
    """
    base = _slugify(name) or "collection"
    candidate = base
    suffix = 1
    while True:
        stmt = select(Collection).where(
            Collection.user_id == user_id, Collection.slug == candidate
        )
        if exclude_id is not None:
            stmt = stmt.where(Collection.id != exclude_id)
        taken = (await db.execute(stmt)).scalar_one_or_none()
        if taken is None:
            return candidate
        suffix += 1
        candidate = f"{base}-{suffix}"
|
||||
|
||||
|
||||
def _to_schema(collection: Collection) -> CollectionSchema:
    """Map an ORM Collection (collection_bills must be loaded) to its API schema."""
    payload = {
        "id": collection.id,
        "name": collection.name,
        "slug": collection.slug,
        "is_public": collection.is_public,
        "share_token": collection.share_token,
        "bill_count": len(collection.collection_bills),
        "created_at": collection.created_at,
    }
    return CollectionSchema(**payload)
|
||||
|
||||
|
||||
async def _detail_schema(db: AsyncSession, collection: Collection) -> CollectionDetailSchema:
    """Build CollectionDetailSchema with bills (including has_document)."""
    links = collection.collection_bills
    member_bills = [link.bill for link in links]

    # One batched query: which member bills have a stored document?
    ids = [b.bill_id for b in member_bills]
    documented: set[str] = set()
    if ids:
        doc_rows = await db.execute(
            select(BillDocument.bill_id).where(BillDocument.bill_id.in_(ids)).distinct()
        )
        documented = {row[0] for row in doc_rows}

    serialized = []
    for b in member_bills:
        schema = BillSchema.model_validate(b)
        if b.briefs:
            schema.latest_brief = b.briefs[0]
        if b.trend_scores:
            schema.latest_trend = b.trend_scores[0]
        schema.has_document = b.bill_id in documented
        serialized.append(schema)

    return CollectionDetailSchema(
        id=collection.id,
        name=collection.name,
        slug=collection.slug,
        is_public=collection.is_public,
        share_token=collection.share_token,
        bill_count=len(links),
        created_at=collection.created_at,
        bills=serialized,
    )
|
||||
|
||||
|
||||
async def _load_collection(db: AsyncSession, collection_id: int) -> Collection | None:
    """Load a collection with its bills (plus briefs/trends/sponsor) eagerly.

    Returns None when no collection has *collection_id*. (Fixed: the return
    annotation previously claimed a non-optional Collection even though
    scalar_one_or_none() can return None.)
    """
    result = await db.execute(
        select(Collection)
        .options(
            selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.briefs),
            selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.trend_scores),
            selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.sponsor),
        )
        .where(Collection.id == collection_id)
    )
    return result.scalar_one_or_none()
|
||||
|
||||
|
||||
# ── List ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("", response_model=list[CollectionSchema])
async def list_collections(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's collections, newest first, with bill counts."""
    stmt = (
        select(Collection)
        .options(selectinload(Collection.collection_bills))
        .where(Collection.user_id == current_user.id)
        .order_by(Collection.created_at.desc())
    )
    rows = (await db.execute(stmt)).scalars().unique().all()
    return [_to_schema(row) for row in rows]
|
||||
|
||||
|
||||
# ── Create ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("", response_model=CollectionSchema, status_code=201)
async def create_collection(
    body: CollectionCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a collection for the current user with a unique per-user slug.

    Raises 422 when the (stripped) name is empty or longer than 100 chars.
    """
    name = body.name.strip()
    if not 1 <= len(name) <= 100:
        raise HTTPException(status_code=422, detail="name must be 1–100 characters")

    slug = await _unique_slug(db, current_user.id, name)
    collection = Collection(
        user_id=current_user.id,
        name=name,
        slug=slug,
        is_public=body.is_public,
    )
    db.add(collection)
    # Flush to get the generated id; capture it before commit expires the
    # instance (avoids touching expired attributes under the async session).
    await db.flush()
    new_id = collection.id
    await db.commit()

    # Re-select with collection_bills eagerly loaded so _to_schema can count
    # them. (Fixed: removed a dead pre-commit SELECT and a redundant refresh.)
    result = await db.execute(
        select(Collection)
        .options(selectinload(Collection.collection_bills))
        .where(Collection.id == new_id)
    )
    return _to_schema(result.scalar_one())
|
||||
|
||||
|
||||
# ── Share (public — no auth) ──────────────────────────────────────────────────
|
||||
|
||||
@router.get("/share/{share_token}", response_model=CollectionDetailSchema)
async def get_collection_by_share_token(
    share_token: str,
    db: AsyncSession = Depends(get_db),
):
    """Public (unauthenticated) lookup of a collection by its share token."""
    stmt = (
        select(Collection)
        .options(
            selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.briefs),
            selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.trend_scores),
            selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.sponsor),
        )
        .where(Collection.share_token == share_token)
    )
    shared = (await db.execute(stmt)).scalar_one_or_none()
    if shared is None:
        raise HTTPException(status_code=404, detail="Collection not found")
    return await _detail_schema(db, shared)
|
||||
|
||||
|
||||
# ── Get (owner) ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{collection_id}", response_model=CollectionDetailSchema)
async def get_collection(
    collection_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a collection's detail view; 404 if missing, 403 if not the owner."""
    found = await _load_collection(db, collection_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Collection not found")
    if found.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")
    return await _detail_schema(db, found)
|
||||
|
||||
|
||||
# ── Update ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.patch("/{collection_id}", response_model=CollectionSchema)
async def update_collection(
    collection_id: int,
    body: CollectionUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Rename a collection and/or toggle its public flag (owner only).

    Renaming also regenerates the slug, keeping it unique per user.
    Raises 404 if missing, 403 if not the owner, 422 on a bad name.
    """
    result = await db.execute(
        select(Collection)
        .options(selectinload(Collection.collection_bills))
        .where(Collection.id == collection_id)
    )
    collection = result.scalar_one_or_none()
    if not collection:
        raise HTTPException(status_code=404, detail="Collection not found")
    if collection.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")

    if body.name is not None:
        name = body.name.strip()
        if not 1 <= len(name) <= 100:
            raise HTTPException(status_code=422, detail="name must be 1–100 characters")
        collection.name = name
        collection.slug = await _unique_slug(db, current_user.id, name, exclude_id=collection_id)

    if body.is_public is not None:
        collection.is_public = body.is_public

    await db.commit()

    # Re-select so collection_bills is eagerly loaded for _to_schema's count.
    # (Fixed: removed the redundant refresh that immediately preceded this.)
    result = await db.execute(
        select(Collection)
        .options(selectinload(Collection.collection_bills))
        .where(Collection.id == collection_id)
    )
    return _to_schema(result.scalar_one())
|
||||
|
||||
|
||||
# ── Delete ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.delete("/{collection_id}", status_code=204)
async def delete_collection(
    collection_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a collection owned by the current user (404 missing, 403 not owner)."""
    row = await db.execute(select(Collection).where(Collection.id == collection_id))
    target = row.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Collection not found")
    if target.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")
    await db.delete(target)
    await db.commit()
|
||||
|
||||
|
||||
# ── Add bill ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/{collection_id}/bills/{bill_id}", status_code=204)
async def add_bill_to_collection(
    collection_id: int,
    bill_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Add a bill to one of the current user's collections (idempotent)."""
    row = await db.execute(select(Collection).where(Collection.id == collection_id))
    target = row.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Collection not found")
    if target.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")

    if await db.get(Bill, bill_id) is None:
        raise HTTPException(status_code=404, detail="Bill not found")

    dup_check = await db.execute(
        select(CollectionBill).where(
            CollectionBill.collection_id == collection_id,
            CollectionBill.bill_id == bill_id,
        )
    )
    if dup_check.scalar_one_or_none():
        return  # idempotent

    db.add(CollectionBill(collection_id=collection_id, bill_id=bill_id))
    await db.commit()
|
||||
|
||||
|
||||
# ── Remove bill ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.delete("/{collection_id}/bills/{bill_id}", status_code=204)
async def remove_bill_from_collection(
    collection_id: int,
    bill_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Remove a bill from one of the current user's collections."""
    row = await db.execute(select(Collection).where(Collection.id == collection_id))
    target = row.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Collection not found")
    if target.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")

    link_row = await db.execute(
        select(CollectionBill).where(
            CollectionBill.collection_id == collection_id,
            CollectionBill.bill_id == bill_id,
        )
    )
    link = link_row.scalar_one_or_none()
    if link is None:
        raise HTTPException(status_code=404, detail="Bill not in collection")
    await db.delete(link)
    await db.commit()
|
||||
121
backend/app/api/dashboard.py
Normal file
121
backend/app/api/dashboard.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from datetime import date, timedelta
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import APIRouter
|
||||
from sqlalchemy import desc, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.dependencies import get_optional_user
|
||||
from app.database import get_db
|
||||
from app.models import Bill, BillBrief, Follow, TrendScore
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import BillSchema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_trending(db: AsyncSession) -> list[dict]:
    """Top 10 trending bills; widens the score-date window until results appear."""
    for window in (1, 3, 7, 30):
        cutoff = date.today() - timedelta(days=window)
        stmt = (
            select(Bill)
            .options(selectinload(Bill.sponsor), selectinload(Bill.briefs), selectinload(Bill.trend_scores))
            .join(TrendScore, Bill.bill_id == TrendScore.bill_id)
            .where(TrendScore.score_date >= cutoff)
            .order_by(desc(TrendScore.composite_score))
            .limit(10)
        )
        hits = (await db.execute(stmt)).scalars().unique().all()
        if hits:
            return [_serialize_bill(b) for b in hits]
    return []
|
||||
|
||||
|
||||
def _serialize_bill(bill: Bill) -> dict:
    """Convert an ORM Bill to a plain dict, attaching the newest brief/trend."""
    schema = BillSchema.model_validate(bill)
    if bill.briefs:
        schema.latest_brief = bill.briefs[0]
    if bill.trend_scores:
        schema.latest_trend = bill.trend_scores[0]
    return schema.model_dump()
|
||||
|
||||
|
||||
def _feed_stmt():
    """Base feed query: bills with sponsor, briefs, and trend scores eager-loaded."""
    return select(Bill).options(
        selectinload(Bill.sponsor), selectinload(Bill.briefs), selectinload(Bill.trend_scores)
    )


def _append_unique(bills, feed_bills: list, seen_ids: set) -> None:
    """Append bills to the feed, skipping any bill_id already present."""
    for bill in bills:
        if bill.bill_id not in seen_ids:
            feed_bills.append(bill)
            seen_ids.add(bill.bill_id)


@router.get("")
async def get_dashboard(
    db: AsyncSession = Depends(get_db),
    current_user: User | None = Depends(get_optional_user),
):
    """Home dashboard: trending bills plus, for signed-in users, a personal feed.

    The feed merges directly-followed bills, bills from followed members, and
    bills whose brief matches a followed topic — deduplicated, sorted by
    latest action date, capped at 50. Anonymous users get trending only.
    (Refactored: the three copy-pasted merge loops now share helpers.)
    """
    trending = await _get_trending(db)

    if current_user is None:
        return {"feed": [], "trending": trending, "follows": {"bills": 0, "members": 0, "topics": 0}}

    # Load follows for the current user
    follows_result = await db.execute(
        select(Follow).where(Follow.user_id == current_user.id)
    )
    follows = follows_result.scalars().all()

    followed_bill_ids = [f.follow_value for f in follows if f.follow_type == "bill"]
    followed_member_ids = [f.follow_value for f in follows if f.follow_type == "member"]
    followed_topics = [f.follow_value for f in follows if f.follow_type == "topic"]

    feed_bills: list[Bill] = []
    seen_ids: set[str] = set()

    # 1. Directly followed bills
    if followed_bill_ids:
        result = await db.execute(
            _feed_stmt()
            .where(Bill.bill_id.in_(followed_bill_ids))
            .order_by(desc(Bill.latest_action_date))
            .limit(20)
        )
        _append_unique(result.scalars().all(), feed_bills, seen_ids)

    # 2. Bills from followed members
    if followed_member_ids:
        result = await db.execute(
            _feed_stmt()
            .where(Bill.sponsor_id.in_(followed_member_ids))
            .order_by(desc(Bill.latest_action_date))
            .limit(20)
        )
        _append_unique(result.scalars().all(), feed_bills, seen_ids)

    # 3. Bills matching followed topics (single query with OR across all topics)
    if followed_topics:
        result = await db.execute(
            _feed_stmt()
            .join(BillBrief, Bill.bill_id == BillBrief.bill_id)
            .where(or_(*[BillBrief.topic_tags.contains([t]) for t in followed_topics]))
            .order_by(desc(Bill.latest_action_date))
            .limit(20)
        )
        _append_unique(result.scalars().all(), feed_bills, seen_ids)

    # Sort the merged feed by latest action date, newest first.
    feed_bills.sort(key=lambda b: b.latest_action_date or date.min, reverse=True)

    return {
        "feed": [_serialize_bill(b) for b in feed_bills[:50]],
        "trending": trending,
        "follows": {
            "bills": len(followed_bill_ids),
            "members": len(followed_member_ids),
            "topics": len(followed_topics),
        },
    }
|
||||
94
backend/app/api/follows.py
Normal file
94
backend/app/api/follows.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models import Follow
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import FollowCreate, FollowModeUpdate, FollowSchema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
VALID_FOLLOW_TYPES = {"bill", "member", "topic"}
|
||||
VALID_MODES = {"neutral", "pocket_veto", "pocket_boost"}
|
||||
|
||||
|
||||
@router.get("", response_model=list[FollowSchema])
async def list_follows(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return all of the current user's follows, newest first."""
    stmt = (
        select(Follow)
        .where(Follow.user_id == current_user.id)
        .order_by(Follow.created_at.desc())
    )
    return (await db.execute(stmt)).scalars().all()
|
||||
|
||||
|
||||
@router.post("", response_model=FollowSchema, status_code=201)
async def add_follow(
    body: FollowCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Follow a bill, member, or topic for the current user.

    Idempotent: a duplicate follow triggers IntegrityError on commit
    (presumably a unique constraint on user/type/value — verify on the
    Follow model), which is caught and answered with the existing row.
    Raises 400 for an unknown follow_type.
    """
    if body.follow_type not in VALID_FOLLOW_TYPES:
        raise HTTPException(status_code=400, detail=f"follow_type must be one of {VALID_FOLLOW_TYPES}")
    follow = Follow(
        user_id=current_user.id,
        follow_type=body.follow_type,
        follow_value=body.follow_value,
    )
    db.add(follow)
    try:
        await db.commit()
        await db.refresh(follow)
    except IntegrityError:
        # Race-safe duplicate handling: rely on the DB constraint rather than
        # a check-then-insert, then roll back and fetch the existing row.
        await db.rollback()
        # Already following — return existing
        result = await db.execute(
            select(Follow).where(
                Follow.user_id == current_user.id,
                Follow.follow_type == body.follow_type,
                Follow.follow_value == body.follow_value,
            )
        )
        return result.scalar_one()
    return follow
|
||||
|
||||
|
||||
@router.patch("/{follow_id}/mode", response_model=FollowSchema)
async def update_follow_mode(
    follow_id: int,
    body: FollowModeUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Set a follow's mode (neutral / pocket_veto / pocket_boost)."""
    if body.follow_mode not in VALID_MODES:
        raise HTTPException(status_code=400, detail=f"follow_mode must be one of {VALID_MODES}")
    record = await db.get(Follow, follow_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Follow not found")
    if record.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not your follow")
    record.follow_mode = body.follow_mode
    await db.commit()
    await db.refresh(record)
    return record
|
||||
|
||||
|
||||
@router.delete("/{follow_id}", status_code=204)
async def remove_follow(
    follow_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete one of the current user's follows (404 missing, 403 not owner)."""
    record = await db.get(Follow, follow_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Follow not found")
    if record.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not your follow")
    await db.delete(record)
    await db.commit()
|
||||
43
backend/app/api/health.py
Normal file
43
backend/app/api/health.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import redis as redis_lib
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
async def health():
    """Liveness probe: always reports ok plus the current UTC timestamp."""
    now = datetime.now(timezone.utc)
    return {"status": "ok", "timestamp": now.isoformat()}
|
||||
|
||||
|
||||
@router.get("/detailed")
async def health_detailed(db: AsyncSession = Depends(get_db)):
    """Readiness probe: checks DB and Redis; overall status is ok or degraded."""
    # Database: a trivial round-trip query.
    db_ok = False
    try:
        await db.execute(text("SELECT 1"))
        db_ok = True
    except Exception:
        pass

    # Redis: open a short-lived client, ping, and always close it so repeated
    # health polling does not leak connections (fixed: client was never closed).
    redis_ok = False
    try:
        r = redis_lib.from_url(settings.REDIS_URL)
        try:
            redis_ok = bool(r.ping())
        finally:
            r.close()
    except Exception:
        pass

    status = "ok" if (db_ok and redis_ok) else "degraded"
    return {
        "status": status,
        "database": "ok" if db_ok else "error",
        "redis": "ok" if redis_ok else "error",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
|
||||
313
backend/app/api/members.py
Normal file
313
backend/app/api/members.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
# Census FIPS state/territory codes -> USPS abbreviations. Used to translate
# the TIGERweb identify response into the state code stored on Member rows.
_FIPS_TO_STATE = {
    "01": "AL", "02": "AK", "04": "AZ", "05": "AR", "06": "CA",
    "08": "CO", "09": "CT", "10": "DE", "11": "DC", "12": "FL",
    "13": "GA", "15": "HI", "16": "ID", "17": "IL", "18": "IN",
    "19": "IA", "20": "KS", "21": "KY", "22": "LA", "23": "ME",
    "24": "MD", "25": "MA", "26": "MI", "27": "MN", "28": "MS",
    "29": "MO", "30": "MT", "31": "NE", "32": "NV", "33": "NH",
    "34": "NJ", "35": "NM", "36": "NY", "37": "NC", "38": "ND",
    "39": "OH", "40": "OK", "41": "OR", "42": "PA", "44": "RI",
    "45": "SC", "46": "SD", "47": "TN", "48": "TX", "49": "UT",
    "50": "VT", "51": "VA", "53": "WA", "54": "WV", "55": "WI",
    "56": "WY", "60": "AS", "66": "GU", "69": "MP", "72": "PR", "78": "VI",
}
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import desc, func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Bill, Member, MemberTrendScore, MemberNewsArticle
|
||||
from app.schemas.schemas import (
|
||||
BillSchema, MemberSchema, MemberTrendScoreSchema,
|
||||
MemberNewsArticleSchema, PaginatedResponse,
|
||||
)
|
||||
from app.services import congress_api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/by-zip/{zip_code}", response_model=list[MemberSchema])
async def get_members_by_zip(zip_code: str, db: AsyncSession = Depends(get_db)):
    """Return the House rep and senators for a ZIP code.

    Step 1: Nominatim (OpenStreetMap) — ZIP → lat/lng.
    Step 2: TIGERweb Legislative identify — lat/lng → congressional district.

    Lookup failures (no geocode hit, TIGERweb error, network problems) return
    an empty list rather than an HTTP error, so the UI degrades gracefully.
    Raises 400 only for a malformed (non-5-digit) ZIP.
    """
    if not re.fullmatch(r"\d{5}", zip_code):
        raise HTTPException(status_code=400, detail="ZIP code must be 5 digits")

    state_code: str | None = None
    district_num: str | None = None

    try:
        async with httpx.AsyncClient(timeout=20.0) as client:
            # Step 1: ZIP → lat/lng
            r1 = await client.get(
                "https://nominatim.openstreetmap.org/search",
                params={"postalcode": zip_code, "country": "US", "format": "json", "limit": "1"},
                # Nominatim's usage policy requires an identifying User-Agent.
                headers={"User-Agent": "PocketVeto/1.0"},
            )
            places = r1.json() if r1.status_code == 200 else []
            if not places:
                logger.warning("Nominatim: no result for ZIP %s", zip_code)
                return []

            lat = places[0]["lat"]
            lng = places[0]["lon"]

            # Step 2: lat/lng → congressional district via TIGERweb identify (all layers)
            half = 0.5  # degrees of padding for the mapExtent bbox the API requires
            r2 = await client.get(
                "https://tigerweb.geo.census.gov/arcgis/rest/services/TIGERweb/Legislative/MapServer/identify",
                params={
                    "f": "json",
                    "geometry": f"{lng},{lat}",
                    "geometryType": "esriGeometryPoint",
                    "sr": "4326",
                    "layers": "all",
                    "tolerance": "2",
                    "mapExtent": f"{float(lng)-half},{float(lat)-half},{float(lng)+half},{float(lat)+half}",
                    "imageDisplay": "100,100,96",
                },
            )
            if r2.status_code != 200:
                logger.warning("TIGERweb returned %s for ZIP %s", r2.status_code, zip_code)
                return []

            identify_results = r2.json().get("results", [])
            logger.info(
                "TIGERweb ZIP %s layers: %s",
                zip_code, [r.get("layerName") for r in identify_results],
            )

            for item in identify_results:
                # Only congressional-district layers are relevant here.
                if "Congressional" not in (item.get("layerName") or ""):
                    continue
                attrs = item.get("attributes", {})
                # GEOID = 2-char state FIPS + 2-char district (e.g. "1218" = FL-18)
                geoid = str(attrs.get("GEOID") or "").strip()
                if len(geoid) == 4:
                    state_fips = geoid[:2]
                    district_fips = geoid[2:]
                    state_code = _FIPS_TO_STATE.get(state_fips)
                    # All-zero district (e.g. "00") → at-large: no district number.
                    district_num = str(int(district_fips)) if district_fips.strip("0") else None
                    if state_code:
                        break

                # Fallback: explicit field names (e.g. CD118FP + STATEFP)
                cd_field = next((k for k in attrs if re.match(r"CD\d+FP$", k)), None)
                state_field = next((k for k in attrs if "STATEFP" in k.upper()), None)
                if cd_field and state_field:
                    state_fips = str(attrs[state_field]).zfill(2)
                    district_fips = str(attrs[cd_field])
                    state_code = _FIPS_TO_STATE.get(state_fips)
                    district_num = str(int(district_fips)) if district_fips.strip("0") else None
                    if state_code:
                        break

            if not state_code:
                logger.warning(
                    "ZIP %s: no CD found. Layers: %s",
                    zip_code, [r.get("layerName") for r in identify_results],
                )

    except Exception as exc:
        # Any network or parse failure degrades to "no members found".
        logger.warning("ZIP lookup error for %s: %s", zip_code, exc)
        return []

    if not state_code:
        return []

    members: list[MemberSchema] = []
    seen: set[str] = set()  # bioguide_ids already added, to avoid duplicates

    if district_num:
        result = await db.execute(
            select(Member).where(
                Member.state == state_code,
                Member.district == district_num,
                Member.chamber == "House of Representatives",
            )
        )
        member = result.scalar_one_or_none()
        if member:
            seen.add(member.bioguide_id)
            members.append(MemberSchema.model_validate(member))
    else:
        # At-large states (AK, DE, MT, ND, SD, VT, WY)
        result = await db.execute(
            select(Member).where(
                Member.state == state_code,
                Member.chamber == "House of Representatives",
            ).limit(1)
        )
        member = result.scalar_one_or_none()
        if member:
            seen.add(member.bioguide_id)
            members.append(MemberSchema.model_validate(member))

    # Both senators for the state.
    result = await db.execute(
        select(Member).where(
            Member.state == state_code,
            Member.chamber == "Senate",
        )
    )
    for member in result.scalars().all():
        if member.bioguide_id not in seen:
            seen.add(member.bioguide_id)
            members.append(MemberSchema.model_validate(member))

    return members
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[MemberSchema])
async def list_members(
    chamber: Optional[str] = Query(None),
    party: Optional[str] = Query(None),
    state: Optional[str] = Query(None),
    q: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    per_page: int = Query(50, ge=1, le=250),
    db: AsyncSession = Depends(get_db),
):
    """Paginated member listing with optional chamber/party/state filters and name search."""
    conditions = []
    if chamber:
        conditions.append(Member.chamber == chamber)
    if party:
        conditions.append(Member.party == party)
    if state:
        conditions.append(Member.state == state)
    if q:
        # name is stored as "Last, First" — also match "First Last" order
        flipped_name = func.concat(
            func.split_part(Member.name, ", ", 2), " ",
            func.split_part(Member.name, ", ", 1),
        )
        conditions.append(or_(
            Member.name.ilike(f"%{q}%"),
            flipped_name.ilike(f"%{q}%"),
        ))

    query = select(Member)
    if conditions:
        query = query.where(*conditions)

    total = await db.scalar(select(func.count()).select_from(query.subquery())) or 0

    paged = (
        query.order_by(Member.last_name, Member.first_name)
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    rows = (await db.execute(paged)).scalars().all()

    return PaginatedResponse(
        items=rows,
        total=total,
        page=page,
        per_page=per_page,
        pages=max(1, (total + per_page - 1) // per_page),
    )
|
||||
|
||||
|
||||
@router.get("/{bioguide_id}", response_model=MemberSchema)
async def get_member(bioguide_id: str, db: AsyncSession = Depends(get_db)):
    """Return one member, lazily enriching detail data on first view.

    While ``detail_fetched`` is still NULL this (a) queues a background
    interest-sync task and (b) pulls detail fields from Congress.gov inline,
    then stamps ``detail_fetched`` so neither runs again. The member's most
    recent trend score, if any, is attached to the response. Raises 404 for
    an unknown bioguide_id.
    """
    member = await db.get(Member, bioguide_id)
    if not member:
        raise HTTPException(status_code=404, detail="Member not found")

    # Kick off member interest on first view — single combined task avoids duplicate API calls
    if member.detail_fetched is None:
        try:
            from app.workers.member_interest import sync_member_interest
            sync_member_interest.delay(bioguide_id)
        except Exception:
            # Best-effort: an unavailable task broker must not break the page.
            pass

    # Lazy-enrich with detail data from Congress.gov on first view
    if member.detail_fetched is None:
        try:
            # NOTE(review): congress_api.get_member_detail looks like a
            # synchronous call inside an async handler — confirm it does not
            # block the event loop under load.
            detail_raw = congress_api.get_member_detail(bioguide_id)
            enriched = congress_api.parse_member_detail_from_api(detail_raw)
            for field, value in enriched.items():
                # Only overwrite with real values; keep existing data on None.
                if value is not None:
                    setattr(member, field, value)
            member.detail_fetched = datetime.now(timezone.utc)
            await db.commit()
            await db.refresh(member)
        except Exception as e:
            logger.warning(f"Could not enrich member detail for {bioguide_id}: {e}")

    # Attach latest trend score
    result_schema = MemberSchema.model_validate(member)
    latest_trend = (
        await db.execute(
            select(MemberTrendScore)
            .where(MemberTrendScore.member_id == bioguide_id)
            .order_by(desc(MemberTrendScore.score_date))
            .limit(1)
        )
    )
    trend = latest_trend.scalar_one_or_none()
    if trend:
        result_schema.latest_trend = MemberTrendScoreSchema.model_validate(trend)
    return result_schema
|
||||
|
||||
|
||||
@router.get("/{bioguide_id}/trend", response_model=list[MemberTrendScoreSchema])
async def get_member_trend(
    bioguide_id: str,
    days: int = Query(30, ge=7, le=365),
    db: AsyncSession = Depends(get_db),
):
    """Trend scores for a member over the trailing ``days`` window, oldest first."""
    from datetime import date, timedelta

    since = date.today() - timedelta(days=days)
    stmt = (
        select(MemberTrendScore)
        .where(MemberTrendScore.member_id == bioguide_id, MemberTrendScore.score_date >= since)
        .order_by(MemberTrendScore.score_date)
    )
    scores = await db.execute(stmt)
    return scores.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bioguide_id}/news", response_model=list[MemberNewsArticleSchema])
async def get_member_news(bioguide_id: str, db: AsyncSession = Depends(get_db)):
    """Up to 20 most recent news articles linked to the member, newest first."""
    stmt = (
        select(MemberNewsArticle)
        .where(MemberNewsArticle.member_id == bioguide_id)
        .order_by(desc(MemberNewsArticle.published_at))
        .limit(20)
    )
    articles = await db.execute(stmt)
    return articles.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bioguide_id}/bills", response_model=PaginatedResponse[BillSchema])
async def get_member_bills(
    bioguide_id: str,
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
):
    """Paginated bills sponsored by the given member, newest introduction first."""
    base = (
        select(Bill)
        .options(selectinload(Bill.briefs), selectinload(Bill.sponsor))
        .where(Bill.sponsor_id == bioguide_id)
    )
    total = await db.scalar(select(func.count()).select_from(base.subquery())) or 0

    page_stmt = (
        base.order_by(desc(Bill.introduced_date))
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    rows = (await db.execute(page_stmt)).scalars().all()

    items = []
    for bill in rows:
        schema = BillSchema.model_validate(bill)
        if bill.briefs:
            # First brief in relationship order is surfaced as the latest.
            schema.latest_brief = bill.briefs[0]
        items.append(schema)

    return PaginatedResponse(
        items=items,
        total=total,
        page=page,
        per_page=per_page,
        pages=max(1, (total + per_page - 1) // per_page),
    )
|
||||
89
backend/app/api/notes.py
Normal file
89
backend/app/api/notes.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Bill Notes API — private per-user notes on individual bills.
|
||||
One note per (user, bill). PUT upserts, DELETE removes.
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.bill import Bill
|
||||
from app.models.note import BillNote
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import BillNoteSchema, BillNoteUpsert
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{bill_id}", response_model=BillNoteSchema)
async def get_note(
    bill_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch the current user's private note for a bill; 404 if none exists."""
    stmt = select(BillNote).where(
        BillNote.bill_id == bill_id,
        BillNote.user_id == current_user.id,
    )
    note = (await db.execute(stmt)).scalar_one_or_none()
    if note is None:
        raise HTTPException(status_code=404, detail="No note for this bill")
    return note
|
||||
|
||||
|
||||
@router.put("/{bill_id}", response_model=BillNoteSchema)
async def upsert_note(
    bill_id: str,
    body: BillNoteUpsert,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create or replace the current user's note on a bill (one note per user/bill)."""
    if await db.get(Bill, bill_id) is None:
        raise HTTPException(status_code=404, detail="Bill not found")

    stmt = select(BillNote).where(
        BillNote.bill_id == bill_id,
        BillNote.user_id == current_user.id,
    )
    note = (await db.execute(stmt)).scalar_one_or_none()

    if note is None:
        note = BillNote(
            user_id=current_user.id,
            bill_id=bill_id,
            content=body.content,
            pinned=body.pinned,
        )
        db.add(note)
    else:
        note.content = body.content
        note.pinned = body.pinned

    await db.commit()
    await db.refresh(note)
    return note
|
||||
|
||||
|
||||
@router.delete("/{bill_id}", status_code=204)
async def delete_note(
    bill_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Remove the current user's note for a bill; 404 if there is none."""
    stmt = select(BillNote).where(
        BillNote.bill_id == bill_id,
        BillNote.user_id == current_user.id,
    )
    existing = (await db.execute(stmt)).scalar_one_or_none()
    if existing is None:
        raise HTTPException(status_code=404, detail="No note for this bill")
    await db.delete(existing)
    await db.commit()
|
||||
465
backend/app/api/notifications.py
Normal file
465
backend/app/api/notifications.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
Notifications API — user notification settings and per-user RSS feed.
|
||||
"""
|
||||
import base64
|
||||
import secrets
|
||||
from xml.etree.ElementTree import Element, SubElement, tostring
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import HTMLResponse, Response
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings as app_settings
|
||||
from app.core.crypto import decrypt_secret, encrypt_secret
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.notification import NotificationEvent
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import (
|
||||
FollowModeTestRequest,
|
||||
NotificationEventSchema,
|
||||
NotificationSettingsResponse,
|
||||
NotificationSettingsUpdate,
|
||||
NotificationTestResult,
|
||||
NtfyTestRequest,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Human-readable titles for notification event types; used when rendering
# per-user RSS feed items. Unknown types fall back to "Update" at the call site.
_EVENT_LABELS = {
    "new_document": "New Bill Text",
    "new_amendment": "Amendment Filed",
    "bill_updated": "Bill Updated",
    "weekly_digest": "Weekly Digest",
}
|
||||
|
||||
|
||||
def _prefs_to_response(prefs: dict, rss_token: str | None) -> NotificationSettingsResponse:
    """Translate a user's stored notification_prefs dict into the API response model."""
    get = prefs.get
    return NotificationSettingsResponse(
        ntfy_topic_url=get("ntfy_topic_url", ""),
        ntfy_auth_method=get("ntfy_auth_method", "none"),
        ntfy_token=get("ntfy_token", ""),
        ntfy_username=get("ntfy_username", ""),
        # Never expose the stored password — only report whether one decrypts.
        ntfy_password_set=bool(decrypt_secret(get("ntfy_password", ""))),
        ntfy_enabled=get("ntfy_enabled", False),
        rss_enabled=get("rss_enabled", False),
        rss_token=rss_token,
        email_enabled=get("email_enabled", False),
        email_address=get("email_address", ""),
        digest_enabled=get("digest_enabled", False),
        digest_frequency=get("digest_frequency", "daily"),
        quiet_hours_start=get("quiet_hours_start"),
        quiet_hours_end=get("quiet_hours_end"),
        timezone=get("timezone"),
        alert_filters=get("alert_filters"),
    )
|
||||
|
||||
|
||||
@router.get("/settings", response_model=NotificationSettingsResponse)
async def get_notification_settings(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's notification settings, minting an RSS token if absent."""
    user = await db.get(User, current_user.id)
    # First visit: generate the RSS token so the feed URL is always available.
    if not user.rss_token:
        user.rss_token = secrets.token_urlsafe(32)
        await db.commit()
        await db.refresh(user)
    stored_prefs = user.notification_prefs or {}
    return _prefs_to_response(stored_prefs, user.rss_token)
|
||||
|
||||
|
||||
@router.put("/settings", response_model=NotificationSettingsResponse)
async def update_notification_settings(
    body: NotificationSettingsUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Partially update the current user's notification preferences.

    Only fields present (non-None) in the request body are written; all other
    stored preferences are left untouched. Returns the full settings object
    after the update.
    """
    user = await db.get(User, current_user.id)
    # Work on a fresh dict so reassigning notification_prefs marks the JSON
    # column as changed.
    prefs = dict(user.notification_prefs or {})

    if body.ntfy_topic_url is not None:
        prefs["ntfy_topic_url"] = body.ntfy_topic_url.strip()
    if body.ntfy_auth_method is not None:
        prefs["ntfy_auth_method"] = body.ntfy_auth_method
    if body.ntfy_token is not None:
        prefs["ntfy_token"] = body.ntfy_token.strip()
    if body.ntfy_username is not None:
        prefs["ntfy_username"] = body.ntfy_username.strip()
    if body.ntfy_password is not None:
        # Stored encrypted at rest; _prefs_to_response only reports presence.
        prefs["ntfy_password"] = encrypt_secret(body.ntfy_password.strip())
    if body.ntfy_enabled is not None:
        prefs["ntfy_enabled"] = body.ntfy_enabled
    if body.rss_enabled is not None:
        prefs["rss_enabled"] = body.rss_enabled
    if body.email_enabled is not None:
        prefs["email_enabled"] = body.email_enabled
    if body.email_address is not None:
        prefs["email_address"] = body.email_address.strip()
    if body.digest_enabled is not None:
        prefs["digest_enabled"] = body.digest_enabled
    if body.digest_frequency is not None:
        prefs["digest_frequency"] = body.digest_frequency
    if body.quiet_hours_start is not None:
        prefs["quiet_hours_start"] = body.quiet_hours_start
    if body.quiet_hours_end is not None:
        prefs["quiet_hours_end"] = body.quiet_hours_end
    if body.timezone is not None:
        prefs["timezone"] = body.timezone
    if body.alert_filters is not None:
        prefs["alert_filters"] = body.alert_filters
    # Allow clearing quiet hours by passing -1. This runs AFTER the
    # assignments above, so -1 first sets and then removes the keys —
    # net effect: quiet hours (and timezone) cleared.
    if body.quiet_hours_start == -1:
        prefs.pop("quiet_hours_start", None)
        prefs.pop("quiet_hours_end", None)
        prefs.pop("timezone", None)

    user.notification_prefs = prefs

    if not user.rss_token:
        user.rss_token = secrets.token_urlsafe(32)
    # Generate unsubscribe token the first time an email address is saved
    if prefs.get("email_address") and not user.email_unsubscribe_token:
        user.email_unsubscribe_token = secrets.token_urlsafe(32)

    await db.commit()
    await db.refresh(user)
    return _prefs_to_response(user.notification_prefs or {}, user.rss_token)
|
||||
|
||||
|
||||
@router.post("/settings/rss-reset", response_model=NotificationSettingsResponse)
async def reset_rss_token(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Rotate the RSS token; the previous feed URL stops working immediately."""
    user = await db.get(User, current_user.id)
    fresh_token = secrets.token_urlsafe(32)
    user.rss_token = fresh_token
    await db.commit()
    await db.refresh(user)
    return _prefs_to_response(user.notification_prefs or {}, user.rss_token)
|
||||
|
||||
|
||||
@router.post("/test/ntfy", response_model=NotificationTestResult)
async def test_ntfy(
    body: NtfyTestRequest,
    current_user: User = Depends(get_current_user),
):
    """Send a test push notification to verify ntfy settings."""
    topic = body.ntfy_topic_url.strip()
    if not topic:
        return NotificationTestResult(status="error", detail="Topic URL is required")

    site = (app_settings.PUBLIC_URL or app_settings.LOCAL_URL).rstrip("/")
    headers: dict[str, str] = {
        "Title": "PocketVeto: Test Notification",
        "Priority": "default",
        "Tags": "white_check_mark",
        "Click": f"{site}/notifications",
    }
    token = body.ntfy_token.strip()
    username = body.ntfy_username.strip()
    if body.ntfy_auth_method == "token" and token:
        headers["Authorization"] = f"Bearer {token}"
    elif body.ntfy_auth_method == "basic" and username:
        raw = f"{username}:{body.ntfy_password}".encode()
        headers["Authorization"] = "Basic " + base64.b64encode(raw).decode()

    message = (
        "Your PocketVeto notification settings are working correctly. "
        "Real alerts will link directly to the relevant bill page."
    )
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.post(
                topic,
                content=message.encode("utf-8"),
                headers=headers,
            )
            resp.raise_for_status()
            return NotificationTestResult(status="ok", detail=f"Test notification sent (HTTP {resp.status_code})")
    except httpx.HTTPStatusError as e:
        return NotificationTestResult(status="error", detail=f"HTTP {e.response.status_code}: {e.response.text[:200]}")
    except httpx.RequestError as e:
        return NotificationTestResult(status="error", detail=f"Connection error: {e}")
|
||||
|
||||
|
||||
@router.post("/test/email", response_model=NotificationTestResult)
async def test_email(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Send a test email to the user's configured email address.

    Requires a saved email address in the user's prefs and SMTP_HOST in
    server settings. Always returns a NotificationTestResult (never raises)
    so the UI can surface the failure reason inline.
    """
    import smtplib
    from email.mime.text import MIMEText

    user = await db.get(User, current_user.id)
    prefs = user.notification_prefs or {}
    email_addr = prefs.get("email_address", "").strip()
    if not email_addr:
        return NotificationTestResult(status="error", detail="No email address saved. Save your address first.")

    if not app_settings.SMTP_HOST:
        return NotificationTestResult(status="error", detail="SMTP not configured on this server. Set SMTP_HOST in .env")

    try:
        from_addr = app_settings.SMTP_FROM or app_settings.SMTP_USER
        base_url = (app_settings.PUBLIC_URL or app_settings.LOCAL_URL).rstrip("/")
        body = (
            "This is a test email from PocketVeto.\n\n"
            "Your email notification settings are working correctly. "
            "Real alerts will include bill titles, summaries, and direct links.\n\n"
            f"Visit your notifications page: {base_url}/notifications"
        )
        msg = MIMEText(body, "plain", "utf-8")
        msg["Subject"] = "PocketVeto: Test Email Notification"
        msg["From"] = from_addr
        msg["To"] = email_addr

        # Port 465 means implicit TLS (SMTP_SSL); any other port uses plain
        # SMTP with optional STARTTLS upgrade below.
        use_ssl = app_settings.SMTP_PORT == 465
        if use_ssl:
            ctx = smtplib.SMTP_SSL(app_settings.SMTP_HOST, app_settings.SMTP_PORT, timeout=10)
        else:
            ctx = smtplib.SMTP(app_settings.SMTP_HOST, app_settings.SMTP_PORT, timeout=10)
        with ctx as s:
            if not use_ssl and app_settings.SMTP_STARTTLS:
                s.starttls()
            if app_settings.SMTP_USER:
                s.login(app_settings.SMTP_USER, app_settings.SMTP_PASSWORD)
            s.sendmail(from_addr, [email_addr], msg.as_string())

        return NotificationTestResult(status="ok", detail=f"Test email sent to {email_addr}")
    except smtplib.SMTPAuthenticationError:
        return NotificationTestResult(status="error", detail="SMTP authentication failed — check SMTP_USER and SMTP_PASSWORD in .env")
    except smtplib.SMTPConnectError:
        return NotificationTestResult(status="error", detail=f"Could not connect to {app_settings.SMTP_HOST}:{app_settings.SMTP_PORT}")
    except Exception as e:
        return NotificationTestResult(status="error", detail=str(e))
|
||||
|
||||
|
||||
@router.post("/test/rss", response_model=NotificationTestResult)
async def test_rss(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Verify the user's RSS feed is reachable and return its event count.

    Fix: the original fetched every NotificationEvent row for the user just
    to ``len()`` the list; this now counts in SQL instead.
    """
    from sqlalchemy import func  # local import: module top level only imports `select`

    user = await db.get(User, current_user.id)
    if not user.rss_token:
        return NotificationTestResult(status="error", detail="RSS token not generated — save settings first")

    event_count = await db.scalar(
        select(func.count())
        .select_from(NotificationEvent)
        .where(NotificationEvent.user_id == user.id)
    ) or 0

    return NotificationTestResult(
        status="ok",
        detail=f"RSS feed is active with {event_count} event{'s' if event_count != 1 else ''}. Subscribe to the URL shown above.",
        event_count=event_count,
    )
|
||||
|
||||
|
||||
@router.get("/history", response_model=list[NotificationEventSchema])
async def get_notification_history(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the 50 most recent notification events for the current user."""
    stmt = (
        select(NotificationEvent)
        .where(NotificationEvent.user_id == current_user.id)
        .order_by(NotificationEvent.created_at.desc())
        .limit(50)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
||||
|
||||
|
||||
@router.post("/test/follow-mode", response_model=NotificationTestResult)
async def test_follow_mode(
    body: FollowModeTestRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Simulate dispatcher behaviour for a given follow mode + event type.

    Uses the user's first bill follow as the subject. In "pocket_veto" mode,
    document/amendment events are reported as (correctly) suppressed without
    sending anything; all other combinations send a real test push via the
    user's configured ntfy settings.

    Fix: basic-auth previously sent the encrypted ``ntfy_password`` value
    straight from prefs — the update endpoint stores it via encrypt_secret(),
    so the ciphertext was used as the password. It is now decrypted first,
    matching _prefs_to_response.
    """
    from sqlalchemy import select as sa_select
    from app.models.follow import Follow

    VALID_MODES = {"pocket_veto", "pocket_boost"}
    VALID_EVENTS = {"new_document", "new_amendment", "bill_updated"}
    if body.mode not in VALID_MODES:
        return NotificationTestResult(status="error", detail=f"mode must be one of {VALID_MODES}")
    if body.event_type not in VALID_EVENTS:
        return NotificationTestResult(status="error", detail=f"event_type must be one of {VALID_EVENTS}")

    # Need at least one followed bill to build a realistic notification.
    result = await db.execute(
        sa_select(Follow).where(
            Follow.user_id == current_user.id,
            Follow.follow_type == "bill",
        ).limit(1)
    )
    follow = result.scalar_one_or_none()
    if not follow:
        return NotificationTestResult(
            status="error",
            detail="No bill follows found — follow at least one bill first",
        )

    # Pocket Veto suppression: brief events are silently dropped
    if body.mode == "pocket_veto" and body.event_type in ("new_document", "new_amendment"):
        return NotificationTestResult(
            status="ok",
            detail=(
                f"✓ Suppressed — Pocket Veto correctly blocked a '{body.event_type}' event. "
                "No ntfy was sent (this is the expected behaviour)."
            ),
        )

    # Everything else would send ntfy — check the user has it configured
    user = await db.get(User, current_user.id)
    prefs = user.notification_prefs or {}
    ntfy_url = prefs.get("ntfy_topic_url", "").strip()
    ntfy_enabled = prefs.get("ntfy_enabled", False)
    if not ntfy_enabled or not ntfy_url:
        return NotificationTestResult(
            status="error",
            detail="ntfy not configured or disabled — enable it in Notification Settings first.",
        )

    bill_url = f"{(app_settings.PUBLIC_URL or app_settings.LOCAL_URL).rstrip('/')}/bills/{follow.follow_value}"
    event_titles = {
        "new_document": "New Bill Text",
        "new_amendment": "Amendment Filed",
        "bill_updated": "Bill Updated",
    }
    mode_label = body.mode.replace("_", " ").title()
    headers: dict[str, str] = {
        "Title": f"[{mode_label} Test] {event_titles[body.event_type]}: {follow.follow_value.upper()}",
        "Priority": "default",
        "Tags": "test_tube",
        "Click": bill_url,
    }
    # Pocket Boost adds ntfy action buttons ("Actions" header format).
    if body.mode == "pocket_boost":
        headers["Actions"] = (
            f"view, View Bill, {bill_url}; "
            "view, Find Your Rep, https://www.house.gov/representatives/find-your-representative"
        )

    auth_method = prefs.get("ntfy_auth_method", "none")
    ntfy_token = prefs.get("ntfy_token", "").strip()
    ntfy_username = prefs.get("ntfy_username", "").strip()
    # Stored encrypted (see update_notification_settings) — decrypt before use.
    ntfy_password = decrypt_secret(prefs.get("ntfy_password", "")).strip()
    if auth_method == "token" and ntfy_token:
        headers["Authorization"] = f"Bearer {ntfy_token}"
    elif auth_method == "basic" and ntfy_username:
        creds = base64.b64encode(f"{ntfy_username}:{ntfy_password}".encode()).decode()
        headers["Authorization"] = f"Basic {creds}"

    message_lines = [
        f"This is a test of {mode_label} mode for bill {follow.follow_value.upper()}.",
        f"Event type: {event_titles[body.event_type]}",
    ]
    if body.mode == "pocket_boost":
        message_lines.append("Tap the action buttons below to view the bill or find your representative.")

    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.post(
                ntfy_url,
                content="\n".join(message_lines).encode("utf-8"),
                headers=headers,
            )
            resp.raise_for_status()
            detail = f"✓ ntfy sent (HTTP {resp.status_code})"
            if body.mode == "pocket_boost":
                detail += " — check your phone for 'View Bill' and 'Find Your Rep' action buttons"
            return NotificationTestResult(status="ok", detail=detail)
    except httpx.HTTPStatusError as e:
        return NotificationTestResult(status="error", detail=f"HTTP {e.response.status_code}: {e.response.text[:200]}")
    except httpx.RequestError as e:
        return NotificationTestResult(status="error", detail=f"Connection error: {e}")
|
||||
|
||||
|
||||
@router.get("/unsubscribe/{token}", response_class=HTMLResponse, include_in_schema=False)
async def email_unsubscribe(token: str, db: AsyncSession = Depends(get_db)):
    """One-click email unsubscribe — no login required."""
    lookup = await db.execute(
        select(User).where(User.email_unsubscribe_token == token)
    )
    user = lookup.scalar_one_or_none()
    if user is None:
        return HTMLResponse(
            _unsubscribe_page("Invalid or expired link", success=False),
            status_code=404,
        )

    # Replace the dict wholesale so the JSON column is flagged as modified.
    updated = dict(user.notification_prefs or {})
    updated["email_enabled"] = False
    user.notification_prefs = updated
    await db.commit()

    return HTMLResponse(_unsubscribe_page("You've been unsubscribed from PocketVeto email notifications.", success=True))
|
||||
|
||||
|
||||
def _unsubscribe_page(message: str, success: bool) -> str:
    """Render the standalone HTML page shown after an unsubscribe attempt.

    message: text displayed in the card body.
    success: True → green check styling; False → red cross styling.
    """
    color = "#16a34a" if success else "#dc2626"
    icon = "✓" if success else "✗"
    return f"""<!DOCTYPE html>
<html lang="en">
<head><meta charset="utf-8"><meta name="viewport" content="width=device-width,initial-scale=1">
<title>PocketVeto — Unsubscribe</title>
<style>
body{{font-family:system-ui,sans-serif;background:#f9fafb;display:flex;align-items:center;justify-content:center;min-height:100vh;margin:0}}
.card{{background:#fff;border:1px solid #e5e7eb;border-radius:12px;padding:2.5rem;max-width:420px;width:100%;text-align:center;box-shadow:0 1px 3px rgba(0,0,0,.08)}}
.icon{{font-size:2.5rem;color:{color};margin-bottom:1rem}}
h1{{font-size:1.1rem;font-weight:600;color:#111827;margin:0 0 .5rem}}
p{{font-size:.9rem;color:#6b7280;margin:0 0 1.5rem;line-height:1.5}}
a{{color:#2563eb;text-decoration:none;font-size:.875rem}}a:hover{{text-decoration:underline}}
</style></head>
<body><div class="card">
<div class="icon">{icon}</div>
<h1>Email Notifications</h1>
<p>{message}</p>
<a href="/">Return to PocketVeto</a>
</div></body></html>"""
|
||||
|
||||
|
||||
@router.get("/feed/{rss_token}.xml", include_in_schema=False)
async def rss_feed(rss_token: str, db: AsyncSession = Depends(get_db)):
    """Public tokenized RSS feed — no auth required."""
    owner = (
        await db.execute(select(User).where(User.rss_token == rss_token))
    ).scalar_one_or_none()
    if owner is None:
        raise HTTPException(status_code=404, detail="Feed not found")

    recent = await db.execute(
        select(NotificationEvent)
        .where(NotificationEvent.user_id == owner.id)
        .order_by(NotificationEvent.created_at.desc())
        .limit(50)
    )
    xml = _build_rss(recent.scalars().all())
    return Response(content=xml, media_type="application/rss+xml")
|
||||
|
||||
|
||||
def _build_rss(events: list) -> bytes:
    """Serialize NotificationEvent rows into an RSS 2.0 document (UTF-8 bytes)."""
    root = Element("rss", version="2.0")
    channel = SubElement(root, "channel")
    for tag, value in (
        ("title", "PocketVeto — Bill Alerts"),
        ("description", "Updates on your followed bills"),
        ("language", "en-us"),
    ):
        SubElement(channel, tag).text = value

    for ev in events:
        data = ev.payload or {}
        item = SubElement(channel, "item")
        kind = _EVENT_LABELS.get(ev.event_type, "Update")
        bill_label = data.get("bill_label", ev.bill_id.upper())
        SubElement(item, "title").text = f"{kind}: {bill_label} — {data.get('bill_title', '')}"
        SubElement(item, "description").text = data.get("brief_summary", "")
        if data.get("bill_url"):
            SubElement(item, "link").text = data["bill_url"]
        # NOTE(review): hard-coded +0000 assumes created_at is stored as UTC — confirm.
        SubElement(item, "pubDate").text = ev.created_at.strftime("%a, %d %b %Y %H:%M:%S +0000")
        # Event PK doubles as the feed GUID (stable across fetches).
        SubElement(item, "guid").text = str(ev.id)

    return tostring(root, encoding="unicode").encode("utf-8")
|
||||
60
backend/app/api/search.py
Normal file
60
backend/app/api/search.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func, or_, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Bill, Member
|
||||
from app.schemas.schemas import BillSchema, MemberSchema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def search(
|
||||
q: str = Query(..., min_length=2, max_length=500),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
# Bill ID direct match
|
||||
id_results = await db.execute(
|
||||
select(Bill).where(Bill.bill_id.ilike(f"%{q}%")).limit(20)
|
||||
)
|
||||
id_bills = id_results.scalars().all()
|
||||
|
||||
# Full-text search on title/content via tsvector
|
||||
fts_results = await db.execute(
|
||||
select(Bill)
|
||||
.where(text("search_vector @@ plainto_tsquery('english', :q)"))
|
||||
.order_by(text("ts_rank(search_vector, plainto_tsquery('english', :q)) DESC"))
|
||||
.limit(20)
|
||||
.params(q=q)
|
||||
)
|
||||
fts_bills = fts_results.scalars().all()
|
||||
|
||||
# Merge, dedup, preserve order (ID matches first)
|
||||
seen = set()
|
||||
bills = []
|
||||
for b in id_bills + fts_bills:
|
||||
if b.bill_id not in seen:
|
||||
seen.add(b.bill_id)
|
||||
bills.append(b)
|
||||
|
||||
# Fuzzy member search — matches "Last, First" and "First Last"
|
||||
first_last = func.concat(
|
||||
func.split_part(Member.name, ", ", 2), " ",
|
||||
func.split_part(Member.name, ", ", 1),
|
||||
)
|
||||
member_results = await db.execute(
|
||||
select(Member)
|
||||
.where(or_(
|
||||
Member.name.ilike(f"%{q}%"),
|
||||
first_last.ilike(f"%{q}%"),
|
||||
))
|
||||
.order_by(Member.last_name)
|
||||
.limit(10)
|
||||
)
|
||||
members = member_results.scalars().all()
|
||||
|
||||
return {
|
||||
"bills": [BillSchema.model_validate(b) for b in bills],
|
||||
"members": [MemberSchema.model_validate(m) for m in members],
|
||||
}
|
||||
225
backend/app/api/settings.py
Normal file
225
backend/app/api/settings.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.core.dependencies import get_current_admin, get_current_user
|
||||
from app.database import get_db
|
||||
from app.models import AppSetting
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import SettingUpdate, SettingsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsResponse)
|
||||
async def get_settings(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return current effective settings (env + DB overrides)."""
|
||||
# DB overrides take precedence over env vars
|
||||
overrides: dict[str, str] = {}
|
||||
result = await db.execute(select(AppSetting))
|
||||
for row in result.scalars().all():
|
||||
overrides[row.key] = row.value
|
||||
|
||||
return SettingsResponse(
|
||||
llm_provider=overrides.get("llm_provider", settings.LLM_PROVIDER),
|
||||
llm_model=overrides.get("llm_model", _current_model(overrides.get("llm_provider", settings.LLM_PROVIDER))),
|
||||
congress_poll_interval_minutes=int(overrides.get("congress_poll_interval_minutes", settings.CONGRESS_POLL_INTERVAL_MINUTES)),
|
||||
newsapi_enabled=bool(settings.NEWSAPI_KEY),
|
||||
pytrends_enabled=settings.PYTRENDS_ENABLED,
|
||||
api_keys_configured={
|
||||
"openai": bool(settings.OPENAI_API_KEY),
|
||||
"anthropic": bool(settings.ANTHROPIC_API_KEY),
|
||||
"gemini": bool(settings.GEMINI_API_KEY),
|
||||
"ollama": True, # no API key required
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.put("")
|
||||
async def update_setting(
|
||||
body: SettingUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Update a runtime setting."""
|
||||
ALLOWED_KEYS = {"llm_provider", "llm_model", "congress_poll_interval_minutes"}
|
||||
if body.key not in ALLOWED_KEYS:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=400, detail=f"Allowed setting keys: {ALLOWED_KEYS}")
|
||||
|
||||
existing = await db.get(AppSetting, body.key)
|
||||
if existing:
|
||||
existing.value = body.value
|
||||
else:
|
||||
db.add(AppSetting(key=body.key, value=body.value))
|
||||
await db.commit()
|
||||
return {"key": body.key, "value": body.value}
|
||||
|
||||
|
||||
@router.post("/test-llm")
|
||||
async def test_llm_connection(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Ping the configured LLM provider with a minimal request."""
|
||||
import asyncio
|
||||
prov_row = await db.get(AppSetting, "llm_provider")
|
||||
model_row = await db.get(AppSetting, "llm_model")
|
||||
provider_name = prov_row.value if prov_row else settings.LLM_PROVIDER
|
||||
model_name = model_row.value if model_row else None
|
||||
try:
|
||||
return await asyncio.to_thread(_ping_provider, provider_name, model_name)
|
||||
except Exception as exc:
|
||||
return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
_PING = "Reply with exactly three words: Connection test successful."
|
||||
|
||||
|
||||
def _ping_provider(provider_name: str, model_name: str | None) -> dict:
|
||||
if provider_name == "openai":
|
||||
from openai import OpenAI
|
||||
model = model_name or settings.OPENAI_MODEL
|
||||
client = OpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
resp = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": _PING}],
|
||||
max_tokens=20,
|
||||
)
|
||||
reply = resp.choices[0].message.content.strip()
|
||||
return {"status": "ok", "provider": "openai", "model": model, "reply": reply}
|
||||
|
||||
if provider_name == "anthropic":
|
||||
import anthropic
|
||||
model = model_name or settings.ANTHROPIC_MODEL
|
||||
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
resp = client.messages.create(
|
||||
model=model,
|
||||
max_tokens=20,
|
||||
messages=[{"role": "user", "content": _PING}],
|
||||
)
|
||||
reply = resp.content[0].text.strip()
|
||||
return {"status": "ok", "provider": "anthropic", "model": model, "reply": reply}
|
||||
|
||||
if provider_name == "gemini":
|
||||
import google.generativeai as genai
|
||||
model = model_name or settings.GEMINI_MODEL
|
||||
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||
resp = genai.GenerativeModel(model_name=model).generate_content(_PING)
|
||||
reply = resp.text.strip()
|
||||
return {"status": "ok", "provider": "gemini", "model": model, "reply": reply}
|
||||
|
||||
if provider_name == "ollama":
|
||||
import requests as req
|
||||
model = model_name or settings.OLLAMA_MODEL
|
||||
resp = req.post(
|
||||
f"{settings.OLLAMA_BASE_URL}/api/generate",
|
||||
json={"model": model, "prompt": _PING, "stream": False},
|
||||
timeout=30,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
reply = resp.json().get("response", "").strip()
|
||||
return {"status": "ok", "provider": "ollama", "model": model, "reply": reply}
|
||||
|
||||
raise ValueError(f"Unknown provider: {provider_name}")
|
||||
|
||||
|
||||
@router.get("/llm-models")
|
||||
async def list_llm_models(
|
||||
provider: str,
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Fetch available models directly from the provider's API."""
|
||||
import asyncio
|
||||
handlers = {
|
||||
"openai": _list_openai_models,
|
||||
"anthropic": _list_anthropic_models,
|
||||
"gemini": _list_gemini_models,
|
||||
"ollama": _list_ollama_models,
|
||||
}
|
||||
fn = handlers.get(provider)
|
||||
if not fn:
|
||||
return {"models": [], "error": f"Unknown provider: {provider}"}
|
||||
try:
|
||||
return await asyncio.to_thread(fn)
|
||||
except Exception as exc:
|
||||
return {"models": [], "error": str(exc)}
|
||||
|
||||
|
||||
def _list_openai_models() -> dict:
    """List chat-capable OpenAI model IDs, sorted descending (newest-looking first)."""
    from openai import OpenAI

    if not settings.OPENAI_API_KEY:
        return {"models": [], "error": "OPENAI_API_KEY not configured"}
    available = OpenAI(api_key=settings.OPENAI_API_KEY).models.list().data
    # Keep only chat-family IDs; drop modality-specific variants.
    CHAT_PREFIXES = ("gpt-", "o1", "o3", "o4", "chatgpt-")
    EXCLUDE = ("realtime", "audio", "tts", "whisper", "embedding", "dall-e", "instruct")
    chat_ids = []
    for entry in available:
        if entry.id.startswith(CHAT_PREFIXES) and not any(bad in entry.id for bad in EXCLUDE):
            chat_ids.append(entry.id)
    chat_ids.sort(reverse=True)
    return {"models": [{"id": mid, "name": mid} for mid in chat_ids]}
|
||||
|
||||
|
||||
def _list_anthropic_models() -> dict:
    """List Anthropic models via the public /v1/models endpoint."""
    import requests as req

    if not settings.ANTHROPIC_API_KEY:
        return {"models": [], "error": "ANTHROPIC_API_KEY not configured"}
    response = req.get(
        "https://api.anthropic.com/v1/models",
        headers={
            "x-api-key": settings.ANTHROPIC_API_KEY,
            "anthropic-version": "2023-06-01",
        },
        timeout=10,
    )
    response.raise_for_status()
    listed = response.json().get("data", [])
    return {
        "models": [
            {"id": entry["id"], "name": entry.get("display_name", entry["id"])}
            for entry in listed
        ]
    }
|
||||
|
||||
|
||||
def _list_gemini_models() -> dict:
    """List Gemini models that support text generation, sorted by ID."""
    import google.generativeai as genai

    if not settings.GEMINI_API_KEY:
        return {"models": [], "error": "GEMINI_API_KEY not configured"}
    genai.configure(api_key=settings.GEMINI_API_KEY)
    usable = []
    for entry in genai.list_models():
        # Only models that can actually answer prompts are useful here.
        if "generateContent" in entry.supported_generation_methods:
            usable.append({"id": entry.name.replace("models/", ""), "name": entry.display_name})
    usable.sort(key=lambda item: item["id"])
    return {"models": usable}
|
||||
|
||||
|
||||
def _list_ollama_models() -> dict:
    """List locally pulled Ollama models; degrade gracefully when the daemon is down."""
    import requests as req

    try:
        response = req.get(f"{settings.OLLAMA_BASE_URL}/api/tags", timeout=5)
        response.raise_for_status()
        pulled = response.json().get("models", [])
        return {"models": [{"id": entry["name"], "name": entry["name"]} for entry in pulled]}
    except Exception as exc:
        return {"models": [], "error": f"Ollama unreachable: {exc}"}
|
||||
|
||||
|
||||
def _current_model(provider: str) -> str:
|
||||
if provider == "openai":
|
||||
return settings.OPENAI_MODEL
|
||||
elif provider == "anthropic":
|
||||
return settings.ANTHROPIC_MODEL
|
||||
elif provider == "gemini":
|
||||
return settings.GEMINI_MODEL
|
||||
elif provider == "ollama":
|
||||
return settings.OLLAMA_MODEL
|
||||
return "unknown"
|
||||
113
backend/app/api/share.py
Normal file
113
backend/app/api/share.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
Public share router — no authentication required.
|
||||
Serves shareable read-only views for briefs and collections.
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.bill import Bill, BillDocument
|
||||
from app.models.brief import BillBrief
|
||||
from app.models.collection import Collection, CollectionBill
|
||||
from app.schemas.schemas import (
|
||||
BillSchema,
|
||||
BriefSchema,
|
||||
BriefShareResponse,
|
||||
CollectionDetailSchema,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Brief share ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/brief/{token}", response_model=BriefShareResponse)
|
||||
async def get_shared_brief(
|
||||
token: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(BillBrief)
|
||||
.options(
|
||||
selectinload(BillBrief.bill).selectinload(Bill.sponsor),
|
||||
selectinload(BillBrief.bill).selectinload(Bill.briefs),
|
||||
selectinload(BillBrief.bill).selectinload(Bill.trend_scores),
|
||||
)
|
||||
.where(BillBrief.share_token == token)
|
||||
)
|
||||
brief = result.scalar_one_or_none()
|
||||
if not brief:
|
||||
raise HTTPException(status_code=404, detail="Brief not found")
|
||||
|
||||
bill = brief.bill
|
||||
bill_schema = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
bill_schema.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
bill_schema.latest_trend = bill.trend_scores[0]
|
||||
|
||||
doc_result = await db.execute(
|
||||
select(BillDocument.bill_id).where(BillDocument.bill_id == bill.bill_id).limit(1)
|
||||
)
|
||||
bill_schema.has_document = doc_result.scalar_one_or_none() is not None
|
||||
|
||||
return BriefShareResponse(
|
||||
brief=BriefSchema.model_validate(brief),
|
||||
bill=bill_schema,
|
||||
)
|
||||
|
||||
|
||||
# ── Collection share ──────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/collection/{token}", response_model=CollectionDetailSchema)
|
||||
async def get_shared_collection(
|
||||
token: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Collection)
|
||||
.options(
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.briefs),
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.trend_scores),
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.sponsor),
|
||||
)
|
||||
.where(Collection.share_token == token)
|
||||
)
|
||||
collection = result.scalar_one_or_none()
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
|
||||
cb_list = collection.collection_bills
|
||||
bills = [cb.bill for cb in cb_list]
|
||||
bill_ids = [b.bill_id for b in bills]
|
||||
|
||||
if bill_ids:
|
||||
doc_result = await db.execute(
|
||||
select(BillDocument.bill_id).where(BillDocument.bill_id.in_(bill_ids)).distinct()
|
||||
)
|
||||
bills_with_docs = {row[0] for row in doc_result}
|
||||
else:
|
||||
bills_with_docs = set()
|
||||
|
||||
bill_schemas = []
|
||||
for bill in bills:
|
||||
bs = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
bs.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
bs.latest_trend = bill.trend_scores[0]
|
||||
bs.has_document = bill.bill_id in bills_with_docs
|
||||
bill_schemas.append(bs)
|
||||
|
||||
return CollectionDetailSchema(
|
||||
id=collection.id,
|
||||
name=collection.name,
|
||||
slug=collection.slug,
|
||||
is_public=collection.is_public,
|
||||
share_token=collection.share_token,
|
||||
bill_count=len(cb_list),
|
||||
created_at=collection.created_at,
|
||||
bills=bill_schemas,
|
||||
)
|
||||
86
backend/app/config.py
Normal file
86
backend/app/config.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from functools import lru_cache
|
||||
from pydantic import model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Environment-driven application configuration (reads `.env`, ignores extras)."""

    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    # ── URLs ─────────────────────────────────────────────────────────────────
    LOCAL_URL: str = "http://localhost"
    PUBLIC_URL: str = ""

    # ── Auth / JWT ───────────────────────────────────────────────────────────
    JWT_SECRET_KEY: str = "change-me-in-production"
    JWT_EXPIRE_MINUTES: int = 60 * 24 * 7  # one week

    # Symmetric encryption for sensitive user prefs (ntfy password, etc.).
    # Generate a dedicated key with:
    #   python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
    # When unset, a key is derived from JWT_SECRET_KEY instead — not
    # recommended for production.
    ENCRYPTION_SECRET_KEY: str = ""

    # ── Database ─────────────────────────────────────────────────────────────
    DATABASE_URL: str = "postgresql+asyncpg://congress:congress@postgres:5432/pocketveto"
    SYNC_DATABASE_URL: str = "postgresql://congress:congress@postgres:5432/pocketveto"

    # ── Redis ────────────────────────────────────────────────────────────────
    REDIS_URL: str = "redis://redis:6379/0"

    # ── api.data.gov (shared key for Congress.gov and GovInfo) ───────────────
    DATA_GOV_API_KEY: str = ""
    CONGRESS_POLL_INTERVAL_MINUTES: int = 30

    # ── LLM providers ────────────────────────────────────────────────────────
    LLM_PROVIDER: str = "openai"  # openai | anthropic | gemini | ollama

    OPENAI_API_KEY: str = ""
    OPENAI_MODEL: str = "gpt-4o-mini"  # excellent JSON quality at ~10x lower cost than gpt-4o

    ANTHROPIC_API_KEY: str = ""
    ANTHROPIC_MODEL: str = "claude-sonnet-4-6"  # matches Opus on structured tasks at ~5x lower cost

    GEMINI_API_KEY: str = ""
    GEMINI_MODEL: str = "gemini-2.0-flash"

    OLLAMA_BASE_URL: str = "http://host.docker.internal:11434"
    OLLAMA_MODEL: str = "llama3.1"

    # Global LLM requests-per-minute budget, enforced by Celery across all
    # workers. Reference limits: free Gemini=15 RPM, paid Anthropic=50 RPM,
    # paid OpenAI=500 RPM. Lower this in .env if your tier throttles you.
    LLM_RATE_LIMIT_RPM: int = 50

    # Google Civic Information API (zip → representative lookup).
    # Free key: https://console.cloud.google.com/apis/library/civicinfo.googleapis.com
    CIVIC_API_KEY: str = ""

    # ── News / trends ────────────────────────────────────────────────────────
    NEWSAPI_KEY: str = ""
    PYTRENDS_ENABLED: bool = True

    # ── SMTP (email notifications) ───────────────────────────────────────────
    SMTP_HOST: str = ""
    SMTP_PORT: int = 587
    SMTP_USER: str = ""
    SMTP_PASSWORD: str = ""
    SMTP_FROM: str = ""  # defaults to SMTP_USER if blank
    SMTP_STARTTLS: bool = True

    @model_validator(mode="after")
    def check_secrets(self) -> "Settings":
        """Refuse to boot with the placeholder JWT secret."""
        if self.JWT_SECRET_KEY == "change-me-in-production":
            raise ValueError(
                "JWT_SECRET_KEY must be set to a secure random value in .env. "
                "Generate one with: python -c \"import secrets; print(secrets.token_hex(32))\""
            )
        return self
|
||||
|
||||
|
||||
@lru_cache
def get_settings() -> Settings:
    """Build the Settings object once; lru_cache memoizes it for all callers."""
    return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
0
backend/app/core/__init__.py
Normal file
0
backend/app/core/__init__.py
Normal file
44
backend/app/core/crypto.py
Normal file
44
backend/app/core/crypto.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Symmetric encryption for sensitive user prefs (e.g. ntfy password).
|
||||
|
||||
Key priority:
|
||||
1. ENCRYPTION_SECRET_KEY env var (recommended — dedicated key, easily rotatable)
|
||||
2. Derived from JWT_SECRET_KEY (fallback for existing installs)
|
||||
|
||||
Generate a dedicated key:
|
||||
python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
"""
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
_PREFIX = "enc:"
|
||||
_fernet_instance: Fernet | None = None
|
||||
|
||||
|
||||
def _fernet() -> Fernet:
    """Return the process-wide Fernet, constructing it lazily on first call."""
    global _fernet_instance
    if _fernet_instance is not None:
        return _fernet_instance
    # Imported here to avoid a circular import at module load.
    from app.config import settings

    if settings.ENCRYPTION_SECRET_KEY:
        # Dedicated key path — must already be a valid 32-byte base64url key.
        key = settings.ENCRYPTION_SECRET_KEY.encode()
    else:
        # Fallback: deterministically derive a Fernet key from the JWT secret.
        digest = hashlib.sha256(settings.JWT_SECRET_KEY.encode()).digest()
        key = base64.urlsafe_b64encode(digest)
    _fernet_instance = Fernet(key)
    return _fernet_instance
|
||||
|
||||
|
||||
def encrypt_secret(plaintext: str) -> str:
    """Encrypt *plaintext* into an "enc:"-prefixed token; empty input passes through."""
    if not plaintext:
        return plaintext
    token = _fernet().encrypt(plaintext.encode()).decode()
    return _PREFIX + token
|
||||
|
||||
|
||||
def decrypt_secret(value: str) -> str:
    """Reverse encrypt_secret; non-prefixed (legacy plaintext) values pass through unchanged."""
    if value and value.startswith(_PREFIX):
        ciphertext = value[len(_PREFIX):]
        return _fernet().decrypt(ciphertext.encode()).decode()
    return value
|
||||
55
backend/app/core/dependencies.py
Normal file
55
backend/app/core/dependencies.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.security import decode_token
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
|
||||
# Standard bearer scheme — missing/invalid header raises 401 automatically.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
# Optional variant for endpoints that work anonymously: yields None instead of 401.
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="/api/auth/login", auto_error=False)
|
||||
|
||||
|
||||
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the Bearer token to a User, or raise 401.

    Catches ValueError in addition to JWTError: decode_token does
    ``int(sub)``, so a token with a non-numeric ``sub`` previously escaped
    the except clause and surfaced as a 500. The sibling get_optional_user
    already catches both — this makes the two dependencies consistent.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired token",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        user_id = decode_token(token)
    except (JWTError, ValueError):
        raise credentials_error

    user = await db.get(User, user_id)
    if user is None:
        # Token is valid but the account no longer exists.
        raise credentials_error
    return user
|
||||
|
||||
|
||||
async def get_optional_user(
    token: str | None = Depends(oauth2_scheme_optional),
    db: AsyncSession = Depends(get_db),
) -> User | None:
    """Best-effort auth: return the User for a valid token, otherwise None."""
    if not token:
        return None
    try:
        return await db.get(User, decode_token(token))
    except (JWTError, ValueError):
        # Bad or expired token — treat the request as anonymous.
        return None
|
||||
|
||||
|
||||
async def get_current_admin(
    current_user: User = Depends(get_current_user),
) -> User:
    """Gate an endpoint to admin users; non-admins get 403."""
    if current_user.is_admin:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin access required",
    )
|
||||
36
backend/app/core/security.py
Normal file
36
backend/app/core/security.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
    """Hash a plaintext password with bcrypt (a fresh salt is generated per call)."""
    return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against a stored bcrypt hash; True on match."""
    return pwd_context.verify(plain, hashed)
|
||||
|
||||
|
||||
def create_access_token(user_id: int) -> str:
    """Mint a signed JWT for *user_id*, expiring after JWT_EXPIRE_MINUTES."""
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
    # "sub" is stringified per RFC 7519; decode_token converts it back to int.
    claims = {"sub": str(user_id), "exp": expires_at}
    return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> int:
    """Decode a JWT and return the user_id from its ``sub`` claim.

    Raises JWTError for every failure mode — bad signature, expiry,
    missing ``sub``, or a non-numeric ``sub``. The original leaked a bare
    ValueError from ``int(user_id)``, contradicting this documented
    contract and turning malformed tokens into 500s for callers that only
    catch JWTError.
    """
    payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[ALGORITHM])
    user_id = payload.get("sub")
    if user_id is None:
        raise JWTError("Missing sub claim")
    try:
        return int(user_id)
    except (TypeError, ValueError) as exc:
        raise JWTError("Invalid sub claim") from exc
|
||||
53
backend/app/database.py
Normal file
53
backend/app/database.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model in the app."""
    pass
|
||||
|
||||
|
||||
# ─── Async engine (FastAPI) ───────────────────────────────────────────────────

# Request-path engine: 10 pooled connections plus up to 20 overflow.
async_engine = create_async_engine(
    settings.DATABASE_URL,
    echo=False,  # flip to True to log emitted SQL while debugging
    pool_size=10,
    max_overflow=20,
)

# expire_on_commit=False keeps ORM objects readable after commit — response
# serialization often happens post-commit in the request handlers.
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    expire_on_commit=False,
    class_=AsyncSession,
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency: yield a request-scoped async session, closed on exit."""
    async with AsyncSessionLocal() as session:
        yield session
||||
|
||||
|
||||
# ─── Sync engine (Celery workers) ────────────────────────────────────────────

# pool_pre_ping guards long-lived worker processes against stale connections.
sync_engine = create_engine(
    settings.SYNC_DATABASE_URL,
    pool_size=5,
    max_overflow=10,
    pool_pre_ping=True,
)

SyncSessionLocal = sessionmaker(
    bind=sync_engine,
    autoflush=False,
    autocommit=False,
)
|
||||
|
||||
|
||||
def get_sync_db() -> Session:
    """Create a new sync session (Celery workers); the caller must close it."""
    return SyncSessionLocal()
|
||||
34
backend/app/main.py
Normal file
34
backend/app/main.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import bills, members, follows, dashboard, search, settings, admin, health, auth, notifications, notes, collections, share, alignment
|
||||
from app.config import settings as config
|
||||
|
||||
app = FastAPI(
    title="PocketVeto",
    description="Monitor US Congressional activity with AI-powered bill summaries.",
    version="1.0.0",
)

# Allow the local dev origin, plus the public origin when one is configured.
_origins = [origin for origin in (config.LOCAL_URL, config.PUBLIC_URL) if origin]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount every API router under /api/<name>, tagged identically for the docs.
for module, name in (
    (auth, "auth"),
    (bills, "bills"),
    (members, "members"),
    (follows, "follows"),
    (dashboard, "dashboard"),
    (search, "search"),
    (settings, "settings"),
    (admin, "admin"),
    (health, "health"),
    (notifications, "notifications"),
    (notes, "notes"),
    (collections, "collections"),
    (share, "share"),
    (alignment, "alignment"),
):
    app.include_router(module.router, prefix=f"/api/{name}", tags=[name])
|
||||
0
backend/app/management/__init__.py
Normal file
0
backend/app/management/__init__.py
Normal file
117
backend/app/management/backfill.py
Normal file
117
backend/app/management/backfill.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Historical data backfill script.
|
||||
|
||||
Usage (run inside the api or worker container):
|
||||
python -m app.management.backfill --congress 118 119
|
||||
python -m app.management.backfill --congress 119 --skip-llm
|
||||
|
||||
This script fetches all bills from the specified Congress numbers,
|
||||
stores them in the database, and (optionally) enqueues document fetch
|
||||
and LLM processing tasks for each bill.
|
||||
|
||||
Cost note: LLM processing 15,000+ bills can be expensive.
|
||||
Consider using --skip-llm for initial backfill and processing
|
||||
manually / in batches.
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def backfill_congress(congress_number: int, skip_llm: bool = False, dry_run: bool = False):
    """Fetch every bill for one Congress and insert the ones not yet stored.

    Pages through the Congress.gov API 250 bills at a time. Existing bills
    are skipped; new ones get a synced sponsor and (unless ``skip_llm``)
    enqueue a low-priority document-fetch task. Returns the count of newly
    inserted bills. Ctrl-C commits whatever has accumulated and exits.
    """
    # Deferred imports: this module is runnable as a standalone CLI and the
    # app imports are only needed once a backfill actually starts.
    from app.database import get_sync_db
    # NOTE(review): AppSetting and Member appear unused in this function — confirm before removing.
    from app.models import AppSetting, Bill, Member
    from app.services import congress_api
    from app.workers.congress_poller import _sync_sponsor

    db = get_sync_db()
    offset = 0            # API pagination cursor
    total_processed = 0   # every bill seen, including skips
    total_new = 0         # bills actually inserted

    logger.info(f"Starting backfill for Congress {congress_number} (skip_llm={skip_llm}, dry_run={dry_run})")

    try:
        while True:
            response = congress_api.get_bills(congress=congress_number, offset=offset, limit=250)
            bills_data = response.get("bills", [])

            if not bills_data:
                break

            for bill_data in bills_data:
                parsed = congress_api.parse_bill_from_api(bill_data, congress_number)
                bill_id = parsed["bill_id"]

                if dry_run:
                    # Count only — no DB writes, no task enqueues, no throttling sleep.
                    logger.info(f"[DRY RUN] Would process: {bill_id}")
                    total_processed += 1
                    continue

                existing = db.get(Bill, bill_id)
                if existing:
                    total_processed += 1
                    continue

                # Sync sponsor
                sponsor_id = _sync_sponsor(db, bill_data)
                parsed["sponsor_id"] = sponsor_id

                db.add(Bill(**parsed))
                total_new += 1
                total_processed += 1

                # Commit every 50 new bills so an interruption loses little work.
                if total_new % 50 == 0:
                    db.commit()
                    logger.info(f"Progress: {total_processed} processed, {total_new} new")

                # Enqueue document + LLM at low priority
                if not skip_llm:
                    from app.workers.document_fetcher import fetch_bill_documents
                    fetch_bill_documents.apply_async(args=[bill_id], priority=3)

                # Stay well under Congress.gov rate limit (5,000/hr = ~1.4/sec)
                time.sleep(0.25)

            # Flush the remainder of this page before moving the cursor.
            db.commit()
            offset += 250

            if len(bills_data) < 250:
                break  # Last page

            logger.info(f"Fetched page ending at offset {offset}, total processed: {total_processed}")
            time.sleep(1)  # Polite pause between pages

    except KeyboardInterrupt:
        # Preserve work done so far; the run can be resumed later.
        logger.info("Interrupted by user")
        db.commit()
    finally:
        db.close()

    logger.info(f"Backfill complete: {total_new} new bills added ({total_processed} total processed)")
    return total_new
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and backfill each requested Congress."""
    arg_parser = argparse.ArgumentParser(description="Backfill Congressional bill data")
    arg_parser.add_argument(
        "--congress", type=int, nargs="+", default=[119],
        help="Congress numbers to backfill (default: 119)",
    )
    arg_parser.add_argument(
        "--skip-llm", action="store_true",
        help="Skip LLM processing (fetch documents only, don't enqueue briefs)",
    )
    arg_parser.add_argument(
        "--dry-run", action="store_true",
        help="Count bills without actually inserting them",
    )
    cli = arg_parser.parse_args()

    # Backfill each Congress in the order given on the command line,
    # accumulating the number of newly inserted bills.
    total = sum(
        backfill_congress(number, skip_llm=cli.skip_llm, dry_run=cli.dry_run)
        for number in cli.congress
    )

    logger.info(f"All done. Total new bills: {total}")


if __name__ == "__main__":
    main()
|
||||
38
backend/app/models/__init__.py
Normal file
38
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Aggregate every ORM model into one namespace.

Importing this package registers all mappers with the shared declarative
``Base`` (needed before metadata/create_all or mapper configuration) and lets
callers write ``from app.models import Bill`` regardless of submodule layout.
"""

from app.models.bill import Bill, BillAction, BillDocument, BillCosponsor
from app.models.brief import BillBrief
from app.models.collection import Collection, CollectionBill
from app.models.follow import Follow
from app.models.member import Member
from app.models.member_interest import MemberTrendScore, MemberNewsArticle
from app.models.news import NewsArticle
from app.models.note import BillNote
from app.models.notification import NotificationEvent
from app.models.setting import AppSetting
from app.models.trend import TrendScore
from app.models.committee import Committee, CommitteeBill
from app.models.user import User
from app.models.vote import BillVote, MemberVotePosition

# Explicit public API: every model imported above is re-exported here.
__all__ = [
    "Bill",
    "BillAction",
    "BillCosponsor",
    "BillDocument",
    "BillBrief",
    "BillNote",
    "BillVote",
    "Collection",
    "CollectionBill",
    "Follow",
    "Member",
    "MemberTrendScore",
    "MemberNewsArticle",
    "MemberVotePosition",
    "NewsArticle",
    "NotificationEvent",
    "AppSetting",
    "TrendScore",
    "Committee",
    "CommitteeBill",
    "User",
]
|
||||
113
backend/app/models/bill.py
Normal file
113
backend/app/models/bill.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from sqlalchemy import (
|
||||
Column, String, Integer, Date, DateTime, Text, ForeignKey, Index, UniqueConstraint
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Bill(Base):
    """A single piece of legislation, keyed by its Congress.gov identity.

    Rows are created by the ingestion/backfill pipeline and enriched later
    (documents, briefs, cosponsors), so most columns are nullable and the
    ``*_fetched_at`` timestamps record which enrichment steps have run.
    """

    __tablename__ = "bills"

    # Natural key: "{congress}-{bill_type_lower}-{bill_number}" e.g. "119-hr-1234"
    bill_id = Column(String, primary_key=True)
    congress_number = Column(Integer, nullable=False)
    bill_type = Column(String(10), nullable=False)  # hr, s, hjres, sjres, hconres, sconres, hres, sres
    bill_number = Column(Integer, nullable=False)
    title = Column(Text)
    short_title = Column(Text)
    # Nullable: a bill row may exist before its sponsor's Member row does.
    sponsor_id = Column(String, ForeignKey("members.bioguide_id"), nullable=True)
    introduced_date = Column(Date)
    latest_action_date = Column(Date)
    latest_action_text = Column(Text)
    status = Column(String(100))
    chamber = Column(String(50))
    congress_url = Column(String)
    govtrack_url = Column(String)

    bill_category = Column(String(20), nullable=True)  # substantive | commemorative | administrative
    cosponsors_fetched_at = Column(DateTime(timezone=True))

    # Ingestion tracking
    last_checked_at = Column(DateTime(timezone=True))
    actions_fetched_at = Column(DateTime(timezone=True))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    # Child collections with a natural date are ordered newest-first.
    sponsor = relationship("Member", back_populates="bills", foreign_keys=[sponsor_id])
    actions = relationship("BillAction", back_populates="bill", order_by="desc(BillAction.action_date)")
    documents = relationship("BillDocument", back_populates="bill")
    briefs = relationship("BillBrief", back_populates="bill", order_by="desc(BillBrief.created_at)")
    news_articles = relationship("NewsArticle", back_populates="bill", order_by="desc(NewsArticle.published_at)")
    trend_scores = relationship("TrendScore", back_populates="bill", order_by="desc(TrendScore.score_date)")
    committee_bills = relationship("CommitteeBill", back_populates="bill")
    notes = relationship("BillNote", back_populates="bill", cascade="all, delete-orphan")
    cosponsors = relationship("BillCosponsor", back_populates="bill", cascade="all, delete-orphan")

    __table_args__ = (
        Index("ix_bills_congress_number", "congress_number"),
        Index("ix_bills_latest_action_date", "latest_action_date"),
        Index("ix_bills_introduced_date", "introduced_date"),
        Index("ix_bills_chamber", "chamber"),
        Index("ix_bills_sponsor_id", "sponsor_id"),
    )
|
||||
|
||||
|
||||
class BillAction(Base):
    """One entry in a bill's official action/timeline history.

    Rows are deleted with their parent bill via the FK's ON DELETE CASCADE.
    """

    __tablename__ = "bill_actions"

    id = Column(Integer, primary_key=True, autoincrement=True)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    action_date = Column(Date)
    action_text = Column(Text)
    action_type = Column(String(100))
    chamber = Column(String(50))
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    bill = relationship("Bill", back_populates="actions")

    __table_args__ = (
        Index("ix_bill_actions_bill_id", "bill_id"),
        Index("ix_bill_actions_action_date", "action_date"),
    )
|
||||
|
||||
|
||||
class BillDocument(Base):
    """A fetched document (e.g. full bill text) attached to a bill.

    ``raw_text`` stores the document body inline; ``fetched_at`` marks when
    the fetch worker retrieved it.
    """

    __tablename__ = "bill_documents"

    id = Column(Integer, primary_key=True, autoincrement=True)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    doc_type = Column(String(50))  # bill_text | committee_report | amendment
    doc_version = Column(String(50))  # Introduced, Enrolled, etc.
    govinfo_url = Column(String)
    raw_text = Column(Text)
    fetched_at = Column(DateTime(timezone=True))
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    bill = relationship("Bill", back_populates="documents")
    briefs = relationship("BillBrief", back_populates="document")

    __table_args__ = (
        Index("ix_bill_documents_bill_id", "bill_id"),
    )
|
||||
|
||||
|
||||
class BillCosponsor(Base):
    """A cosponsor record for a bill.

    ``name``/``party``/``state`` are stored alongside the FK, presumably as a
    denormalized snapshot so the row stays useful if the member link is
    severed (FK uses ON DELETE SET NULL) — confirm against the sync worker.
    """

    __tablename__ = "bill_cosponsors"

    id = Column(Integer, primary_key=True, autoincrement=True)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    bioguide_id = Column(String, ForeignKey("members.bioguide_id", ondelete="SET NULL"), nullable=True)
    name = Column(String(200))
    party = Column(String(50))
    state = Column(String(10))
    sponsored_date = Column(Date, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    bill = relationship("Bill", back_populates="cosponsors")

    __table_args__ = (
        Index("ix_bill_cosponsors_bill_id", "bill_id"),
        Index("ix_bill_cosponsors_bioguide_id", "bioguide_id"),
    )
|
||||
34
backend/app/models/brief.py
Normal file
34
backend/app/models/brief.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from sqlalchemy import Column, Integer, String, Text, ForeignKey, DateTime, Index
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class BillBrief(Base):
    """An LLM-generated brief for a bill (optionally tied to one document).

    JSONB columns hold structured brief sections; ``llm_provider``/``llm_model``
    record which model produced it. ``share_token`` is a server-generated UUID
    (an opaque token, presumably for public share links — confirm in the API).
    """

    __tablename__ = "bill_briefs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    # SET NULL: the brief outlives its source document if that document is removed.
    document_id = Column(Integer, ForeignKey("bill_documents.id", ondelete="SET NULL"), nullable=True)
    brief_type = Column(String(20), nullable=False, server_default="full")  # full | amendment
    summary = Column(Text)
    key_points = Column(JSONB)  # list[{text, citation, quote}]
    risks = Column(JSONB)  # list[{text, citation, quote}]
    deadlines = Column(JSONB)  # list[{date: str, description: str}]
    topic_tags = Column(JSONB)  # list[str]
    llm_provider = Column(String(50))
    llm_model = Column(String(100))
    govinfo_url = Column(String, nullable=True)
    # gen_random_uuid() is generated server-side (PostgreSQL).
    share_token = Column(postgresql.UUID(as_uuid=False), nullable=True, server_default=func.gen_random_uuid())
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    bill = relationship("Bill", back_populates="briefs")
    document = relationship("BillDocument", back_populates="briefs")

    __table_args__ = (
        Index("ix_bill_briefs_bill_id", "bill_id"),
        # GIN index enables containment queries over the topic_tags JSONB array.
        Index("ix_bill_briefs_topic_tags", "topic_tags", postgresql_using="gin"),
    )
|
||||
51
backend/app/models/collection.py
Normal file
51
backend/app/models/collection.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, Integer, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Collection(Base):
    """A user-owned, optionally shareable list of bills.

    Each collection has a per-user unique ``slug`` and a globally unique
    server-generated ``share_token`` UUID; ``is_public`` presumably gates
    public access — confirm in the collections API.
    """

    __tablename__ = "collections"

    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    name = Column(String(100), nullable=False)
    slug = Column(String(120), nullable=False)
    is_public = Column(Boolean, nullable=False, default=False, server_default="false")
    share_token = Column(UUID(as_uuid=False), nullable=False, server_default=func.gen_random_uuid())
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    user = relationship("User", back_populates="collections")
    # FIX: CollectionBill.collection declares back_populates="collection_bills",
    # so this side must declare back_populates="collection" in return. Without
    # the reciprocal argument, SQLAlchemy raises ArgumentError ("does not
    # reference mutually") during mapper configuration.
    collection_bills = relationship(
        "CollectionBill",
        back_populates="collection",
        cascade="all, delete-orphan",
        order_by="CollectionBill.added_at.desc()",
    )

    __table_args__ = (
        UniqueConstraint("user_id", "slug", name="uq_collections_user_slug"),
        UniqueConstraint("share_token", name="uq_collections_share_token"),
        Index("ix_collections_user_id", "user_id"),
        Index("ix_collections_share_token", "share_token"),
    )
|
||||
|
||||
|
||||
class CollectionBill(Base):
    """Association row linking one bill into one collection.

    Each (collection, bill) pair is unique; ``added_at`` drives the
    newest-first ordering used by the parent relationship.
    """

    __tablename__ = "collection_bills"

    id = Column(Integer, primary_key=True, autoincrement=True)
    collection_id = Column(Integer, ForeignKey("collections.id", ondelete="CASCADE"), nullable=False)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    added_at = Column(DateTime(timezone=True), server_default=func.now())

    collection = relationship("Collection", back_populates="collection_bills")
    bill = relationship("Bill")

    __table_args__ = (
        UniqueConstraint("collection_id", "bill_id", name="uq_collection_bills_collection_bill"),
        Index("ix_collection_bills_collection_id", "collection_id"),
        Index("ix_collection_bills_bill_id", "bill_id"),
    )
|
||||
33
backend/app/models/committee.py
Normal file
33
backend/app/models/committee.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from sqlalchemy import Column, Integer, String, Date, ForeignKey, Index
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Committee(Base):
    """A congressional committee, keyed externally by its committee code."""

    __tablename__ = "committees"

    id = Column(Integer, primary_key=True, autoincrement=True)
    committee_code = Column(String(20), unique=True, nullable=False)
    name = Column(String(500))
    chamber = Column(String(10))
    committee_type = Column(String(50))  # Standing, Select, Joint, etc.

    committee_bills = relationship("CommitteeBill", back_populates="committee")
|
||||
|
||||
|
||||
class CommitteeBill(Base):
    """Association row recording a bill's referral to a committee."""

    __tablename__ = "committee_bills"

    id = Column(Integer, primary_key=True, autoincrement=True)
    committee_id = Column(Integer, ForeignKey("committees.id", ondelete="CASCADE"), nullable=False)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    referral_date = Column(Date)

    committee = relationship("Committee", back_populates="committee_bills")
    bill = relationship("Bill", back_populates="committee_bills")

    __table_args__ = (
        Index("ix_committee_bills_bill_id", "bill_id"),
        Index("ix_committee_bills_committee_id", "committee_id"),
    )
|
||||
22
backend/app/models/follow.py
Normal file
22
backend/app/models/follow.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Follow(Base):
    """A user's subscription to a bill, member, or topic.

    ``follow_value`` is polymorphic on ``follow_type`` (bill_id, bioguide_id,
    or a topic tag string); the unique constraint prevents duplicate follows.
    Note: ``follow_mode`` has only a client-side default — rows inserted
    outside the ORM get no default.
    """

    __tablename__ = "follows"

    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    follow_type = Column(String(20), nullable=False)  # bill | member | topic
    follow_value = Column(String, nullable=False)  # bill_id | bioguide_id | tag string
    follow_mode = Column(String(20), nullable=False, default="neutral")  # neutral | pocket_veto | pocket_boost
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    user = relationship("User", back_populates="follows")

    __table_args__ = (
        UniqueConstraint("user_id", "follow_type", "follow_value", name="uq_follows_user_type_value"),
    )
|
||||
45
backend/app/models/member.py
Normal file
45
backend/app/models/member.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import Column, Integer, JSON, String, DateTime
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Member(Base):
    """A member of Congress, keyed by Bioguide ID.

    Basic identity plus denormalized roster details; ``detail_fetched`` marks
    when the detail sync last populated the extended fields.
    """

    __tablename__ = "members"

    bioguide_id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    first_name = Column(String)
    last_name = Column(String)
    party = Column(String(50))
    state = Column(String(50))
    chamber = Column(String(50))
    district = Column(String(50))
    photo_url = Column(String)
    official_url = Column(String)
    congress_url = Column(String)
    # Stored as a string (presumably as supplied by the upstream API — confirm).
    birth_year = Column(String(10))
    address = Column(String)
    phone = Column(String(50))
    # Raw structured payloads kept as JSON rather than normalized tables.
    terms_json = Column(JSON)
    leadership_json = Column(JSON)
    sponsored_count = Column(Integer)
    cosponsored_count = Column(Integer)
    effectiveness_score = Column(sa.Float, nullable=True)
    effectiveness_percentile = Column(sa.Float, nullable=True)
    effectiveness_tier = Column(String(20), nullable=True)  # junior | mid | senior
    detail_fetched = Column(DateTime(timezone=True))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    bills = relationship("Bill", back_populates="sponsor", foreign_keys="Bill.sponsor_id")
    trend_scores = relationship(
        "MemberTrendScore", back_populates="member",
        order_by="desc(MemberTrendScore.score_date)", cascade="all, delete-orphan"
    )
    news_articles = relationship(
        "MemberNewsArticle", back_populates="member",
        order_by="desc(MemberNewsArticle.published_at)", cascade="all, delete-orphan"
    )
|
||||
47
backend/app/models/member_interest.py
Normal file
47
backend/app/models/member_interest.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from sqlalchemy import Column, Integer, String, Date, Float, Text, DateTime, ForeignKey, Index, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class MemberTrendScore(Base):
    """Daily trend snapshot for a member (one row per member per date).

    Per-source counts plus a combined ``composite_score``; the unique
    constraint makes the (member, date) snapshot idempotent.
    """

    __tablename__ = "member_trend_scores"

    id = Column(Integer, primary_key=True, autoincrement=True)
    member_id = Column(String, ForeignKey("members.bioguide_id", ondelete="CASCADE"), nullable=False)
    score_date = Column(Date, nullable=False)
    newsapi_count = Column(Integer, default=0)
    gnews_count = Column(Integer, default=0)
    gtrends_score = Column(Float, default=0.0)
    composite_score = Column(Float, default=0.0)

    member = relationship("Member", back_populates="trend_scores")

    __table_args__ = (
        UniqueConstraint("member_id", "score_date", name="uq_member_trend_scores_member_date"),
        Index("ix_member_trend_scores_member_id", "member_id"),
        Index("ix_member_trend_scores_score_date", "score_date"),
        Index("ix_member_trend_scores_composite", "composite_score"),
    )
|
||||
|
||||
|
||||
class MemberNewsArticle(Base):
    """A news article associated with a member; deduplicated per (member, url)."""

    __tablename__ = "member_news_articles"

    id = Column(Integer, primary_key=True, autoincrement=True)
    member_id = Column(String, ForeignKey("members.bioguide_id", ondelete="CASCADE"), nullable=False)
    source = Column(String(200))
    headline = Column(Text)
    url = Column(String)
    published_at = Column(DateTime(timezone=True))
    relevance_score = Column(Float, default=0.0)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    member = relationship("Member", back_populates="news_articles")

    __table_args__ = (
        UniqueConstraint("member_id", "url", name="uq_member_news_member_url"),
        Index("ix_member_news_articles_member_id", "member_id"),
        Index("ix_member_news_articles_published_at", "published_at"),
    )
|
||||
26
backend/app/models/news.py
Normal file
26
backend/app/models/news.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import Column, Integer, String, Text, Float, DateTime, ForeignKey, Index, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class NewsArticle(Base):
    """A news article associated with a bill; deduplicated per (bill, url)."""

    __tablename__ = "news_articles"

    id = Column(Integer, primary_key=True, autoincrement=True)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    source = Column(String(200))
    headline = Column(Text)
    url = Column(String)
    published_at = Column(DateTime(timezone=True))
    relevance_score = Column(Float, default=0.0)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    bill = relationship("Bill", back_populates="news_articles")

    __table_args__ = (
        UniqueConstraint("bill_id", "url", name="uq_news_articles_bill_url"),
        Index("ix_news_articles_bill_id", "bill_id"),
        Index("ix_news_articles_published_at", "published_at"),
    )
|
||||
26
backend/app/models/note.py
Normal file
26
backend/app/models/note.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class BillNote(Base):
    """A user's private note on a bill — at most one note per (user, bill)."""

    __tablename__ = "bill_notes"

    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    content = Column(Text, nullable=False)
    # Client-side default only; rows inserted outside the ORM must supply it.
    pinned = Column(Boolean, nullable=False, default=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    user = relationship("User", back_populates="bill_notes")
    bill = relationship("Bill", back_populates="notes")

    __table_args__ = (
        UniqueConstraint("user_id", "bill_id", name="uq_bill_notes_user_bill"),
        Index("ix_bill_notes_user_id", "user_id"),
        Index("ix_bill_notes_bill_id", "bill_id"),
    )
|
||||
27
backend/app/models/notification.py
Normal file
27
backend/app/models/notification.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class NotificationEvent(Base):
    """A queued notification for a user about a bill.

    ``dispatched_at`` is NULL until delivery; the index on it supports the
    dispatcher's scan for pending events.
    """

    __tablename__ = "notification_events"

    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    # new_document | new_amendment | bill_updated
    event_type = Column(String(50), nullable=False)
    # {bill_title, bill_label, brief_summary, bill_url}
    payload = Column(JSONB)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    dispatched_at = Column(DateTime(timezone=True), nullable=True)

    user = relationship("User", back_populates="notification_events")

    __table_args__ = (
        Index("ix_notification_events_user_id", "user_id"),
        Index("ix_notification_events_dispatched_at", "dispatched_at"),
    )
|
||||
12
backend/app/models/setting.py
Normal file
12
backend/app/models/setting.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from sqlalchemy import Column, String, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class AppSetting(Base):
    """A simple key/value store for application-wide settings."""

    __tablename__ = "app_settings"

    key = Column(String, primary_key=True)
    value = Column(String)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
25
backend/app/models/trend.py
Normal file
25
backend/app/models/trend.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from sqlalchemy import Column, Integer, String, Date, Float, ForeignKey, Index, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class TrendScore(Base):
    """Daily trend snapshot for a bill (one row per bill per date).

    Mirrors MemberTrendScore: per-source counts plus a combined
    ``composite_score``, unique on (bill, date).
    """

    __tablename__ = "trend_scores"

    id = Column(Integer, primary_key=True, autoincrement=True)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    score_date = Column(Date, nullable=False)
    newsapi_count = Column(Integer, default=0)
    gnews_count = Column(Integer, default=0)
    gtrends_score = Column(Float, default=0.0)
    composite_score = Column(Float, default=0.0)

    bill = relationship("Bill", back_populates="trend_scores")

    __table_args__ = (
        UniqueConstraint("bill_id", "score_date", name="uq_trend_scores_bill_date"),
        Index("ix_trend_scores_bill_id", "bill_id"),
        Index("ix_trend_scores_score_date", "score_date"),
        Index("ix_trend_scores_composite", "composite_score"),
    )
|
||||
24
backend/app/models/user.py
Normal file
24
backend/app/models/user.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from sqlalchemy import Boolean, Column, DateTime, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class User(Base):
    """An account holder; owns follows, notes, collections, and notifications.

    ``rss_token`` and ``email_unsubscribe_token`` are opaque per-user tokens
    (both unique, nullable until generated).
    """

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, autoincrement=True)
    email = Column(String, unique=True, nullable=False, index=True)
    hashed_password = Column(String, nullable=False)
    is_admin = Column(Boolean, nullable=False, default=False)
    notification_prefs = Column(JSONB, nullable=False, default=dict)
    rss_token = Column(String, unique=True, nullable=True, index=True)
    email_unsubscribe_token = Column(String(64), unique=True, nullable=True, index=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # All user-owned rows are removed with the account (ORM cascade here,
    # backed by ON DELETE CASCADE on the child FKs).
    follows = relationship("Follow", back_populates="user", cascade="all, delete-orphan")
    notification_events = relationship("NotificationEvent", back_populates="user", cascade="all, delete-orphan")
    bill_notes = relationship("BillNote", back_populates="user", cascade="all, delete-orphan")
    collections = relationship("Collection", back_populates="user", cascade="all, delete-orphan")
|
||||
53
backend/app/models/vote.py
Normal file
53
backend/app/models/vote.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from sqlalchemy import Column, Date, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class BillVote(Base):
    """A recorded roll-call vote on a bill.

    Uniquely identified upstream by (congress, chamber, session, roll_number);
    per-member positions live in MemberVotePosition.
    """

    __tablename__ = "bill_votes"

    id = Column(Integer, primary_key=True, autoincrement=True)
    bill_id = Column(String, ForeignKey("bills.bill_id", ondelete="CASCADE"), nullable=False)
    congress = Column(Integer, nullable=False)
    chamber = Column(String(50), nullable=False)
    session = Column(Integer, nullable=False)
    roll_number = Column(Integer, nullable=False)
    question = Column(Text)
    description = Column(Text)
    vote_date = Column(Date)
    # Aggregate tallies for the roll call.
    yeas = Column(Integer)
    nays = Column(Integer)
    not_voting = Column(Integer)
    result = Column(String(200))
    source_url = Column(String)
    fetched_at = Column(DateTime(timezone=True))
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    positions = relationship("MemberVotePosition", back_populates="vote", cascade="all, delete-orphan")

    __table_args__ = (
        Index("ix_bill_votes_bill_id", "bill_id"),
        UniqueConstraint("congress", "chamber", "session", "roll_number", name="uq_bill_votes_roll"),
    )
|
||||
|
||||
|
||||
class MemberVotePosition(Base):
    """One member's position on a roll-call vote.

    ``member_name``/``party``/``state`` are stored alongside the FK,
    presumably as a snapshot so the row survives if the member link is
    severed (FK uses ON DELETE SET NULL) — confirm against the vote sync.
    """

    __tablename__ = "member_vote_positions"

    id = Column(Integer, primary_key=True, autoincrement=True)
    vote_id = Column(Integer, ForeignKey("bill_votes.id", ondelete="CASCADE"), nullable=False)
    bioguide_id = Column(String, ForeignKey("members.bioguide_id", ondelete="SET NULL"), nullable=True)
    member_name = Column(String(200))
    party = Column(String(50))
    state = Column(String(10))
    position = Column(String(50), nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    vote = relationship("BillVote", back_populates="positions")

    __table_args__ = (
        Index("ix_member_vote_positions_vote_id", "vote_id"),
        Index("ix_member_vote_positions_bioguide_id", "bioguide_id"),
    )
|
||||
0
backend/app/schemas/__init__.py
Normal file
0
backend/app/schemas/__init__.py
Normal file
381
backend/app/schemas/schemas.py
Normal file
381
backend/app/schemas/schemas.py
Normal file
@@ -0,0 +1,381 @@
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
# ── Bill Notes ────────────────────────────────────────────────────────────────
|
||||
|
||||
class BillNoteSchema(BaseModel):
    """API representation of a user's note on a bill (reads from ORM rows)."""

    id: int
    bill_id: str
    content: str
    pinned: bool
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class BillNoteUpsert(BaseModel):
    """Request body for creating or replacing a bill note."""

    content: str
    pinned: bool = False
|
||||
|
||||
|
||||
# ── Notifications ──────────────────────────────────────────────────────────────
|
||||
|
||||
class NotificationSettingsResponse(BaseModel):
    """A user's notification settings as returned by the API.

    Secrets are not echoed back directly: only ``ntfy_password_set`` signals
    whether a password is stored (cf. NotificationSettingsUpdate, which
    accepts the raw password).
    """

    ntfy_topic_url: str = ""
    ntfy_auth_method: str = "none"  # none | token | basic
    ntfy_token: str = ""
    ntfy_username: str = ""
    ntfy_password_set: bool = False
    ntfy_enabled: bool = False
    rss_enabled: bool = False
    rss_token: Optional[str] = None
    email_enabled: bool = False
    email_address: str = ""
    # Digest
    digest_enabled: bool = False
    digest_frequency: str = "daily"  # daily | weekly
    # Quiet hours — stored as local-time hour integers (0-23); timezone is IANA name
    quiet_hours_start: Optional[int] = None
    quiet_hours_end: Optional[int] = None
    timezone: Optional[str] = None  # IANA name, e.g. "America/New_York"
    alert_filters: Optional[dict] = None

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class NotificationSettingsUpdate(BaseModel):
    """Partial update of notification settings.

    Every field defaults to None; presumably None means "leave unchanged" —
    confirm in the settings endpoint that applies this schema.
    """

    ntfy_topic_url: Optional[str] = None
    ntfy_auth_method: Optional[str] = None
    ntfy_token: Optional[str] = None
    ntfy_username: Optional[str] = None
    # Raw password accepted on write only; responses expose ntfy_password_set instead.
    ntfy_password: Optional[str] = None
    ntfy_enabled: Optional[bool] = None
    rss_enabled: Optional[bool] = None
    email_enabled: Optional[bool] = None
    email_address: Optional[str] = None
    digest_enabled: Optional[bool] = None
    digest_frequency: Optional[str] = None
    quiet_hours_start: Optional[int] = None
    quiet_hours_end: Optional[int] = None
    timezone: Optional[str] = None  # IANA name sent by the browser on save
    alert_filters: Optional[dict] = None
|
||||
|
||||
|
||||
class NotificationEventSchema(BaseModel):
    """API representation of a queued/delivered notification event."""

    id: int
    bill_id: str
    event_type: str
    payload: Optional[Any] = None
    # None until the event has been delivered.
    dispatched_at: Optional[datetime] = None
    created_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class NtfyTestRequest(BaseModel):
    """Request body for sending a test push to an ntfy topic."""

    ntfy_topic_url: str
    ntfy_auth_method: str = "none"
    ntfy_token: str = ""
    ntfy_username: str = ""
    ntfy_password: str = ""
|
||||
|
||||
|
||||
class FollowModeTestRequest(BaseModel):
    """Request body for previewing/testing a follow-mode notification."""

    mode: str  # pocket_veto | pocket_boost
    event_type: str  # new_document | new_amendment | bill_updated
|
||||
|
||||
|
||||
class NotificationTestResult(BaseModel):
    """Outcome of a notification test call."""

    status: str  # "ok" | "error"
    detail: str
    event_count: Optional[int] = None  # RSS only
|
||||
|
||||
# Item type carried by a paginated response.
T = TypeVar("T")


class PaginatedResponse(BaseModel, Generic[T]):
    """Generic pagination envelope: one page of items plus page metadata."""

    items: list[T]
    total: int
    page: int
    per_page: int
    pages: int
|
||||
|
||||
|
||||
# ── Member ────────────────────────────────────────────────────────────────────
|
||||
|
||||
class MemberSchema(BaseModel):
    """A member of Congress, including enrichment and effectiveness data."""

    bioguide_id: str
    name: str
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    party: Optional[str] = None
    state: Optional[str] = None
    chamber: Optional[str] = None
    district: Optional[str] = None
    photo_url: Optional[str] = None
    official_url: Optional[str] = None
    congress_url: Optional[str] = None
    birth_year: Optional[str] = None
    address: Optional[str] = None
    phone: Optional[str] = None
    terms_json: Optional[list[Any]] = None
    leadership_json: Optional[list[Any]] = None
    sponsored_count: Optional[int] = None
    cosponsored_count: Optional[int] = None
    effectiveness_score: Optional[float] = None
    effectiveness_percentile: Optional[float] = None
    effectiveness_tier: Optional[str] = None
    # Forward reference: MemberTrendScoreSchema is declared later in this module.
    latest_trend: Optional["MemberTrendScoreSchema"] = None

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ── Bill Brief ────────────────────────────────────────────────────────────────
|
||||
|
||||
class BriefSchema(BaseModel):
    """An AI-generated policy brief attached to a bill."""

    id: int
    brief_type: str = "full"
    summary: Optional[str] = None
    key_points: Optional[list[Any]] = None
    risks: Optional[list[Any]] = None
    deadlines: Optional[list[dict[str, Any]]] = None
    topic_tags: Optional[list[str]] = None
    llm_provider: Optional[str] = None  # provider/model that produced the brief
    llm_model: Optional[str] = None
    govinfo_url: Optional[str] = None
    share_token: Optional[str] = None  # enables public share links when set
    created_at: Optional[datetime] = None

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ── Bill Action ───────────────────────────────────────────────────────────────
|
||||
|
||||
class BillActionSchema(BaseModel):
    """A single legislative action in a bill's history."""

    id: int
    action_date: Optional[date] = None
    action_text: Optional[str] = None
    action_type: Optional[str] = None
    chamber: Optional[str] = None

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ── News Article ──────────────────────────────────────────────────────────────
|
||||
|
||||
class NewsArticleSchema(BaseModel):
    """A news article matched to a bill."""

    id: int
    source: Optional[str] = None
    headline: Optional[str] = None
    url: Optional[str] = None
    published_at: Optional[datetime] = None
    relevance_score: Optional[float] = None

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ── Trend Score ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TrendScoreSchema(BaseModel):
    """Daily media/search-attention score for a bill."""

    score_date: date
    newsapi_count: int
    gnews_count: int
    gtrends_score: float
    composite_score: float  # blended score derived from the sources above

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class MemberTrendScoreSchema(BaseModel):
    """Daily media/search-attention score for a member of Congress."""

    score_date: date
    newsapi_count: int
    gnews_count: int
    gtrends_score: float
    composite_score: float  # blended score derived from the sources above

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class MemberNewsArticleSchema(BaseModel):
    """A news article matched to a member of Congress."""

    id: int
    source: Optional[str] = None
    headline: Optional[str] = None
    url: Optional[str] = None
    published_at: Optional[datetime] = None
    relevance_score: Optional[float] = None

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ── Bill ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
class BillSchema(BaseModel):
    """Summary view of a bill as returned by list endpoints."""

    bill_id: str            # canonical internal ID, e.g. "119-hr-1234"
    congress_number: int
    bill_type: str
    bill_number: int
    title: Optional[str] = None
    short_title: Optional[str] = None
    introduced_date: Optional[date] = None
    latest_action_date: Optional[date] = None
    latest_action_text: Optional[str] = None
    status: Optional[str] = None
    chamber: Optional[str] = None
    congress_url: Optional[str] = None
    sponsor: Optional[MemberSchema] = None
    latest_brief: Optional[BriefSchema] = None
    latest_trend: Optional[TrendScoreSchema] = None
    updated_at: Optional[datetime] = None
    bill_category: Optional[str] = None
    has_document: bool = False  # True when the full bill text is stored

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class BillDetailSchema(BillSchema):
    """Full bill view: the summary fields plus related collections.

    Pydantic deep-copies the mutable `[]` defaults per instance, so the
    shared-default pitfall does not apply here.
    """

    actions: list[BillActionSchema] = []
    news_articles: list[NewsArticleSchema] = []
    trend_scores: list[TrendScoreSchema] = []
    briefs: list[BriefSchema] = []
    has_document: bool = False
|
||||
|
||||
|
||||
# ── Follow ────────────────────────────────────────────────────────────────────
|
||||
|
||||
class FollowCreate(BaseModel):
    """Payload for creating a follow."""

    follow_type: str  # bill | member | topic
    follow_value: str  # bill_id, bioguide_id, or topic tag depending on type
|
||||
|
||||
|
||||
class FollowSchema(BaseModel):
    """A user's follow of a bill, member, or topic."""

    id: int
    user_id: int
    follow_type: str
    follow_value: str
    follow_mode: str = "neutral"  # neutral | pocket_veto | pocket_boost
    created_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class FollowModeUpdate(BaseModel):
    """Payload for changing the mode of an existing follow."""

    follow_mode: str
|
||||
|
||||
|
||||
# ── Settings ──────────────────────────────────────────────────────────────────
|
||||
|
||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
class UserCreate(BaseModel):
    """Registration payload: plain-text password is hashed server-side."""

    email: str
    password: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
    """Public view of a user account (no credential fields)."""

    id: int
    email: str
    is_admin: bool
    notification_prefs: dict
    created_at: Optional[datetime] = None

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
    """Login response: bearer token plus the authenticated user."""

    access_token: str
    token_type: str = "bearer"
    user: "UserResponse"
|
||||
|
||||
|
||||
# ── Settings ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class SettingUpdate(BaseModel):
    """A single key/value pair for updating an app setting."""

    key: str
    value: str
|
||||
|
||||
|
||||
class SettingsResponse(BaseModel):
    """Current application settings exposed to the admin UI."""

    llm_provider: str
    llm_model: str
    congress_poll_interval_minutes: int
    newsapi_enabled: bool
    pytrends_enabled: bool
    # Which API keys are set, without revealing the keys themselves.
    api_keys_configured: dict[str, bool]
|
||||
|
||||
|
||||
# ── Collections ────────────────────────────────────────────────────────────────
|
||||
|
||||
class CollectionCreate(BaseModel):
    """Payload for creating a bill collection."""

    name: str
    is_public: bool = False

    @field_validator("name")
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Trim surrounding whitespace and enforce a 1-100 character name."""
        v = v.strip()
        if not 1 <= len(v) <= 100:
            raise ValueError("name must be 1–100 characters")
        return v
|
||||
|
||||
|
||||
class CollectionUpdate(BaseModel):
    """Partial-update payload for a collection; None fields are untouched."""

    name: Optional[str] = None
    is_public: Optional[bool] = None
|
||||
|
||||
|
||||
class CollectionSchema(BaseModel):
    """Summary view of a bill collection."""

    id: int
    name: str
    slug: str
    is_public: bool
    share_token: str  # used to build public share links
    bill_count: int
    created_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class CollectionDetailSchema(CollectionSchema):
    """Collection summary plus the full list of contained bills."""

    bills: list[BillSchema]
|
||||
|
||||
|
||||
class BriefShareResponse(BaseModel):
    """Payload served for a publicly shared brief link."""

    brief: BriefSchema
    bill: BillSchema
|
||||
|
||||
|
||||
# ── Votes ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
class MemberVotePositionSchema(BaseModel):
    """One member's recorded position on a roll-call vote."""

    bioguide_id: Optional[str] = None
    member_name: Optional[str] = None
    party: Optional[str] = None
    state: Optional[str] = None
    position: str  # e.g. Yea/Nay/Present/Not Voting — TODO confirm exact values

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class BillVoteSchema(BaseModel):
    """A roll-call vote on a bill, including per-member positions."""

    id: int
    congress: int
    chamber: str
    session: int
    roll_number: int
    question: Optional[str] = None
    description: Optional[str] = None
    vote_date: Optional[date] = None
    yeas: Optional[int] = None
    nays: Optional[int] = None
    not_voting: Optional[int] = None
    result: Optional[str] = None
    source_url: Optional[str] = None
    positions: list[MemberVotePositionSchema] = []

    model_config = {"from_attributes": True}
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
228
backend/app/services/congress_api.py
Normal file
228
backend/app/services/congress_api.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Congress.gov API client.
|
||||
|
||||
Rate limit: 5,000 requests/hour (enforced server-side by Congress.gov).
|
||||
We track usage in Redis to stay well under the limit.
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from app.config import settings
|
||||
|
||||
BASE_URL = "https://api.congress.gov/v3"
|
||||
|
||||
# Maps Congress.gov bill-type abbreviations to the URL slugs used on the
# public congress.gov website (consumed by build_bill_public_url).
_BILL_TYPE_SLUG = {
    "hr": "house-bill",
    "s": "senate-bill",
    "hjres": "house-joint-resolution",
    "sjres": "senate-joint-resolution",
    "hres": "house-resolution",
    "sres": "senate-resolution",
    "hconres": "house-concurrent-resolution",
    "sconres": "senate-concurrent-resolution",
}
|
||||
|
||||
|
||||
def _congress_ordinal(n: int) -> str:
    """Format *n* as an English ordinal string (e.g. 119 -> "119th")."""
    # 11th-13th are irregular: they take "th" despite ending in 1/2/3.
    if n % 100 in (11, 12, 13):
        return f"{n}th"
    last_digit = n % 10
    suffix = ("th", "st", "nd", "rd")[last_digit] if last_digit < 4 else "th"
    return f"{n}{suffix}"
|
||||
|
||||
|
||||
def build_bill_public_url(congress: int, bill_type: str, bill_number: int) -> str:
    """Return the public congress.gov page URL for a bill (not the API endpoint)."""
    # Unknown bill types fall back to the raw lowercase abbreviation.
    slug = _BILL_TYPE_SLUG.get(bill_type.lower(), bill_type.lower())
    return f"https://www.congress.gov/bill/{_congress_ordinal(congress)}-congress/{slug}/{bill_number}"
|
||||
|
||||
|
||||
def _get_current_congress() -> int:
    """Calculate the current Congress number from today's UTC date.

    A new Congress convenes every odd year on January 3 (the 119th began
    Jan 3, 2025). The 1st Congress convened in 1789, so for a given
    legislative year the Congress number is (year - 1789) // 2 + 1.

    Fixes the previous formula (118 + ((year - 2023) // 2 + odd-year bonus)),
    which yielded 120 for 2025 and 119 for 2023 — both off by one.
    """
    now = datetime.utcnow()
    year = now.year
    # Before Jan 3, the Congress elected two years earlier is still seated.
    if now.month == 1 and now.day < 3:
        year -= 1
    return (year - 1789) // 2 + 1
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
def _get(endpoint: str, params: dict) -> dict:
    """GET a Congress.gov API endpoint and return the parsed JSON body.

    The API key and JSON format flag are merged into a copy of *params* so
    the caller's dict is never mutated (the previous version wrote into it).
    Retries up to 3 times with exponential backoff on any raised exception
    (network errors, HTTP errors via raise_for_status).
    """
    query = {**params, "api_key": settings.DATA_GOV_API_KEY, "format": "json"}
    response = requests.get(f"{BASE_URL}{endpoint}", params=query, timeout=30)
    response.raise_for_status()
    return response.json()
|
||||
|
||||
|
||||
def get_current_congress() -> int:
    """Public wrapper around _get_current_congress for external callers."""
    return _get_current_congress()
|
||||
|
||||
|
||||
def build_bill_id(congress: int, bill_type: str, bill_number: int) -> str:
    """Compose the canonical internal bill ID, e.g. "119-hr-1234"."""
    parts = (str(congress), bill_type.lower(), str(bill_number))
    return "-".join(parts)
|
||||
|
||||
|
||||
def get_bills(
    congress: int,
    offset: int = 0,
    limit: int = 250,
    from_date_time: Optional[str] = None,
) -> dict:
    """List bills for a Congress, most recently updated first.

    from_date_time: ISO timestamp; when given, only bills updated since
    then are returned (maps to the API's fromDateTime filter).
    """
    # NOTE(review): requests percent-encodes the "+" in "updateDate+desc";
    # confirm the API accepts the encoded form of this sort value.
    params: dict = {"offset": offset, "limit": limit, "sort": "updateDate+desc"}
    if from_date_time:
        params["fromDateTime"] = from_date_time
    return _get(f"/bill/{congress}", params)
|
||||
|
||||
|
||||
def get_bill_detail(congress: int, bill_type: str, bill_number: int) -> dict:
    """Fetch the full detail record for one bill."""
    return _get(f"/bill/{congress}/{bill_type.lower()}/{bill_number}", {})
|
||||
|
||||
|
||||
def get_bill_actions(congress: int, bill_type: str, bill_number: int, offset: int = 0) -> dict:
    """Fetch one page (up to 250) of a bill's legislative actions."""
    return _get(f"/bill/{congress}/{bill_type.lower()}/{bill_number}/actions", {"offset": offset, "limit": 250})
|
||||
|
||||
|
||||
def get_bill_cosponsors(congress: int, bill_type: str, bill_number: int, offset: int = 0) -> dict:
    """Fetch one page (up to 250) of a bill's cosponsors."""
    return _get(f"/bill/{congress}/{bill_type.lower()}/{bill_number}/cosponsors", {"offset": offset, "limit": 250})
|
||||
|
||||
|
||||
def get_bill_text_versions(congress: int, bill_type: str, bill_number: int) -> dict:
    """Fetch the list of published text versions (and their format URLs) for a bill."""
    return _get(f"/bill/{congress}/{bill_type.lower()}/{bill_number}/text", {})
|
||||
|
||||
|
||||
def get_vote_detail(congress: int, chamber: str, session: int, roll_number: int) -> dict:
    """Fetch one roll-call vote; any chamber other than "house" maps to senate."""
    chamber_slug = "house" if chamber.lower() == "house" else "senate"
    return _get(f"/vote/{congress}/{chamber_slug}/{session}/{roll_number}", {})
|
||||
|
||||
|
||||
def get_members(offset: int = 0, limit: int = 250, current_member: bool = True) -> dict:
    """List members of Congress, optionally restricted to currently serving ones."""
    params: dict = {"offset": offset, "limit": limit}
    if current_member:
        # The API expects the literal string "true", not a boolean.
        params["currentMember"] = "true"
    return _get("/member", params)
|
||||
|
||||
|
||||
def get_member_detail(bioguide_id: str) -> dict:
    """Fetch the detail record for one member by Bioguide ID."""
    return _get(f"/member/{bioguide_id}", {})
|
||||
|
||||
|
||||
def get_committees(offset: int = 0, limit: int = 250) -> dict:
    """Fetch one page of congressional committees."""
    return _get("/committee", {"offset": offset, "limit": limit})
|
||||
|
||||
|
||||
def parse_bill_from_api(data: dict, congress: int) -> dict:
    """Normalize raw API bill data into our model fields.

    Accepts one element of the API's bill list response; missing keys
    become None/empty defaults rather than raising.
    """
    bill_type = data.get("type", "").lower()
    bill_number = data.get("number", 0)
    latest_action = data.get("latestAction") or {}
    return {
        "bill_id": build_bill_id(congress, bill_type, bill_number),
        "congress_number": congress,
        "bill_type": bill_type,
        "bill_number": bill_number,
        "title": data.get("title"),
        "short_title": data.get("shortTitle"),
        "introduced_date": data.get("introducedDate"),
        "latest_action_date": latest_action.get("actionDate"),
        "latest_action_text": latest_action.get("text"),
        # "status" is just a truncated copy of the latest action text, not an enum.
        "status": latest_action.get("text", "")[:100] if latest_action.get("text") else None,
        # All House bill types start with "h" (hr, hres, hjres, hconres).
        "chamber": "House" if bill_type.startswith("h") else "Senate",
        "congress_url": build_bill_public_url(congress, bill_type, bill_number),
    }
|
||||
|
||||
|
||||
# Full state/territory names -> two-letter USPS codes, used by _normalize_state
# to canonicalize the "state" field returned by the Congress.gov API.
_STATE_NAME_TO_CODE: dict[str, str] = {
    "Alabama": "AL", "Alaska": "AK", "Arizona": "AZ", "Arkansas": "AR",
    "California": "CA", "Colorado": "CO", "Connecticut": "CT", "Delaware": "DE",
    "Florida": "FL", "Georgia": "GA", "Hawaii": "HI", "Idaho": "ID",
    "Illinois": "IL", "Indiana": "IN", "Iowa": "IA", "Kansas": "KS",
    "Kentucky": "KY", "Louisiana": "LA", "Maine": "ME", "Maryland": "MD",
    "Massachusetts": "MA", "Michigan": "MI", "Minnesota": "MN", "Mississippi": "MS",
    "Missouri": "MO", "Montana": "MT", "Nebraska": "NE", "Nevada": "NV",
    "New Hampshire": "NH", "New Jersey": "NJ", "New Mexico": "NM", "New York": "NY",
    "North Carolina": "NC", "North Dakota": "ND", "Ohio": "OH", "Oklahoma": "OK",
    "Oregon": "OR", "Pennsylvania": "PA", "Rhode Island": "RI", "South Carolina": "SC",
    "South Dakota": "SD", "Tennessee": "TN", "Texas": "TX", "Utah": "UT",
    "Vermont": "VT", "Virginia": "VA", "Washington": "WA", "West Virginia": "WV",
    "Wisconsin": "WI", "Wyoming": "WY",
    # Territories and DC also elect (non-voting) members.
    "American Samoa": "AS", "Guam": "GU", "Northern Mariana Islands": "MP",
    "Puerto Rico": "PR", "Virgin Islands": "VI", "District of Columbia": "DC",
}
|
||||
|
||||
|
||||
def _normalize_state(state: str | None) -> str | None:
|
||||
if not state:
|
||||
return None
|
||||
s = state.strip()
|
||||
if len(s) == 2:
|
||||
return s.upper()
|
||||
return _STATE_NAME_TO_CODE.get(s, s)
|
||||
|
||||
|
||||
def parse_member_from_api(data: dict) -> dict:
    """Normalize raw Congress.gov member-list data into our model fields.

    Fix: House at-large seats are reported as district 0, which is falsy —
    the previous truthiness check dropped them to None. We now test
    explicitly against None so "0" is preserved.
    """
    terms = data.get("terms", {}).get("item", [])
    # Assumes terms are listed chronologically with the current term last —
    # TODO confirm against API responses.
    current_term = terms[-1] if terms else {}
    district = data.get("district")
    return {
        "bioguide_id": data.get("bioguideId"),
        "name": data.get("name", ""),
        "first_name": data.get("firstName"),
        "last_name": data.get("lastName"),
        # Coerce empty party strings to None.
        "party": data.get("partyName") or None,
        "state": _normalize_state(data.get("state")),
        "chamber": current_term.get("chamber"),
        "district": str(district) if district is not None else None,
        "photo_url": data.get("depiction", {}).get("imageUrl"),
        "official_url": data.get("officialWebsiteUrl"),
    }
|
||||
|
||||
|
||||
def parse_member_detail_from_api(data: dict) -> dict:
    """Normalize Congress.gov member detail response into enrichment fields."""
    # The detail endpoint wraps the payload in {"member": {...}}; tolerate both shapes.
    member = data.get("member", data)
    addr = member.get("addressInformation") or {}
    # terms/leadership may arrive either as a bare list or as {"item": [...]}.
    terms_raw = member.get("terms", [])
    if isinstance(terms_raw, dict):
        terms_raw = terms_raw.get("item", [])
    leadership_raw = member.get("leadership") or []
    if isinstance(leadership_raw, dict):
        leadership_raw = leadership_raw.get("item", [])
    first = member.get("firstName", "")
    last = member.get("lastName", "")
    bioguide_id = member.get("bioguideId", "")
    # congress.gov member pages use a "first-last" slug plus the Bioguide ID.
    slug = f"{first}-{last}".lower().replace(" ", "-").replace("'", "")
    return {
        "birth_year": str(member["birthYear"]) if member.get("birthYear") else None,
        "address": addr.get("officeAddress"),
        "phone": addr.get("phoneNumber"),
        "official_url": member.get("officialWebsiteUrl"),
        "photo_url": (member.get("depiction") or {}).get("imageUrl"),
        "congress_url": f"https://www.congress.gov/member/{slug}/{bioguide_id}" if bioguide_id else None,
        # Keep only the term fields our UI consumes.
        "terms_json": [
            {
                "congress": t.get("congress"),
                "chamber": t.get("chamber"),
                "partyName": t.get("partyName"),
                "stateCode": t.get("stateCode"),
                "stateName": t.get("stateName"),
                "startYear": t.get("startYear"),
                "endYear": t.get("endYear"),
                "district": t.get("district"),
            }
            for t in terms_raw
        ],
        "leadership_json": [
            {
                "type": l.get("type"),
                "congress": l.get("congress"),
                "current": l.get("current"),
            }
            for l in leadership_raw
        ],
        "sponsored_count": (member.get("sponsoredLegislation") or {}).get("count"),
        "cosponsored_count": (member.get("cosponsoredLegislation") or {}).get("count"),
    }
|
||||
138
backend/app/services/govinfo_api.py
Normal file
138
backend/app/services/govinfo_api.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
GovInfo API client for fetching actual bill text.
|
||||
|
||||
Priority order for text formats: htm > txt > pdf
|
||||
ETag support: stores ETags in Redis so repeat fetches skip unchanged documents.
|
||||
"""
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GOVINFO_BASE = "https://api.govinfo.gov"
# Text formats tried in order of preference when picking a bill text URL.
FORMAT_PRIORITY = ["htm", "html", "txt", "pdf"]
_ETAG_CACHE_TTL = 86400 * 30  # 30 days
|
||||
|
||||
|
||||
class DocumentUnchangedError(Exception):
    """Raised when GovInfo confirms the document is unchanged via ETag (HTTP 304)."""
|
||||
|
||||
|
||||
def _etag_redis():
    """Return a fresh Redis client for ETag storage (string-decoded responses)."""
    # Imported lazily so this module can be imported without redis installed/running.
    import redis
    return redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
|
||||
|
||||
def _etag_key(url: str) -> str:
    """Redis key for the stored ETag of *url* (MD5-hashed to bound key length)."""
    digest = hashlib.md5(url.encode()).hexdigest()
    return "govinfo:etag:" + digest
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=15))
def _get(url: str, params: Optional[dict] = None) -> requests.Response:
    """GET a GovInfo URL with the data.gov API key merged into the query.

    Retries up to 3 times with exponential backoff; raises on HTTP errors.
    """
    p = {"api_key": settings.DATA_GOV_API_KEY, **(params or {})}
    response = requests.get(url, params=p, timeout=60)
    response.raise_for_status()
    return response
|
||||
|
||||
|
||||
def get_package_summary(package_id: str) -> dict:
    """Fetch the GovInfo summary metadata for a package."""
    response = _get(f"{GOVINFO_BASE}/packages/{package_id}/summary")
    return response.json()
|
||||
|
||||
|
||||
def get_package_content_detail(package_id: str) -> dict:
    """Fetch the GovInfo content-detail listing (available renditions) for a package."""
    response = _get(f"{GOVINFO_BASE}/packages/{package_id}/content-detail")
    return response.json()
|
||||
|
||||
|
||||
def find_best_text_url(
    text_versions: list[dict],
    priority: Optional[list[str]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """
    From a list of text version objects (from the Congress.gov API), find the
    best available text format.

    Formats are tried in *priority* order (default FORMAT_PRIORITY:
    htm > html > txt > pdf), scanning every version for one format before
    falling back to the next.

    Returns (url, format), or (None, None) when nothing matches.
    (Fixes the previous annotation/docstring, which claimed an
    Optional[tuple[str, str]] / "or None" return while the code actually
    returned the two-element tuple (None, None).)
    """
    for fmt in (priority or FORMAT_PRIORITY):
        for version in text_versions:
            for fmt_info in version.get("formats", []):
                # Defensive: some payloads contain non-dict entries.
                if not isinstance(fmt_info, dict):
                    continue
                url = fmt_info.get("url", "")
                if url.lower().endswith(f".{fmt}"):
                    return url, fmt
    return None, None
|
||||
|
||||
|
||||
def fetch_text_from_url(url: str, fmt: str) -> Optional[str]:
    """
    Download and extract plain text from a GovInfo document URL.

    Uses ETag conditional GET: if GovInfo returns 304 Not Modified,
    raises DocumentUnchangedError so the caller can skip reprocessing.
    On a successful 200 response, stores the new ETag in Redis for next time.

    Returns None on fetch/extraction failure, and also when *fmt* is not
    one of htm/html/txt/pdf.
    """
    headers = {}
    try:
        stored_etag = _etag_redis().get(_etag_key(url))
        if stored_etag:
            headers["If-None-Match"] = stored_etag
    except Exception:
        # Redis being unavailable must not block the fetch; just skip the
        # conditional GET and download unconditionally.
        pass

    try:
        response = requests.get(url, headers=headers, timeout=120)

        if response.status_code == 304:
            raise DocumentUnchangedError(f"Document unchanged (ETag match): {url}")

        response.raise_for_status()

        # Persist ETag for future conditional requests
        etag = response.headers.get("ETag")
        if etag:
            try:
                _etag_redis().setex(_etag_key(url), _ETAG_CACHE_TTL, etag)
            except Exception:
                # Best-effort cache write; failure only costs a re-download later.
                pass

        if fmt in ("htm", "html"):
            return _extract_from_html(response.text)
        elif fmt == "txt":
            return response.text
        elif fmt == "pdf":
            return _extract_from_pdf(response.content)
        # Unrecognized format: fall through and return None implicitly.

    except DocumentUnchangedError:
        # Not a failure — re-raise so callers can distinguish "unchanged".
        raise
    except Exception as e:
        logger.error(f"Failed to fetch text from {url}: {e}")
        return None
|
||||
|
||||
|
||||
def _extract_from_html(html: str) -> str:
    """Strip HTML tags and collapse excess whitespace into plain text."""
    soup = BeautifulSoup(html, "lxml")
    # Drop non-content elements before extracting the text.
    for junk in soup(["script", "style", "nav", "header", "footer"]):
        junk.decompose()
    extracted = soup.get_text(separator="\n")
    # Collapse runs of blank lines and repeated spaces left by the markup.
    extracted = re.sub(r"\n{3,}", "\n\n", extracted)
    extracted = re.sub(r" {2,}", " ", extracted)
    return extracted.strip()
|
||||
|
||||
|
||||
def _extract_from_pdf(content: bytes) -> Optional[str]:
    """Extract text from PDF bytes using pdfminer.

    Returns None (after logging) if extraction fails for any reason.
    """
    try:
        # Imported lazily: pdfminer is heavy and only needed for the PDF fallback.
        from io import BytesIO
        from pdfminer.high_level import extract_text as pdf_extract
        return pdf_extract(BytesIO(content))
    except Exception as e:
        logger.error(f"PDF extraction failed: {e}")
        return None
|
||||
523
backend/app/services/llm_service.py
Normal file
523
backend/app/services/llm_service.py
Normal file
@@ -0,0 +1,523 @@
|
||||
"""
|
||||
LLM provider abstraction.
|
||||
|
||||
All providers implement generate_brief(doc_text, bill_metadata) -> ReverseBrief.
|
||||
Select provider via LLM_PROVIDER env var.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitError(Exception):
    """Raised when a provider returns a rate-limit response (HTTP 429 / quota exceeded)."""

    def __init__(self, provider: str, retry_after: int = 60):
        super().__init__(f"{provider} rate limit exceeded; retry after {retry_after}s")
        # Kept as attributes so dispatch code can back off per provider.
        self.provider = provider
        self.retry_after = retry_after
|
||||
|
||||
|
||||
def _detect_rate_limit(exc: Exception) -> bool:
    """Return True if exc represents a provider rate-limit / quota error."""
    type_name = type(exc).__name__.lower()
    message = str(exc).lower()

    # OpenAI / Anthropic SDKs raise classes named *RateLimitError;
    # the Google Gemini SDK raises ResourceExhausted.
    type_markers = ("ratelimit", "rate_limit", "resourceexhausted")
    if any(marker in type_name for marker in type_markers):
        return True

    # Generic HTTP 429 or quota wording (e.g. Ollama, raw requests).
    message_markers = ("429", "rate limit", "quota")
    return any(marker in message for marker in message_markers)
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """You are a nonpartisan legislative analyst specializing in translating complex \
|
||||
legislation into clear, accurate summaries for informed citizens. You analyze bills objectively \
|
||||
without political bias.
|
||||
|
||||
Always respond with valid JSON matching exactly this schema:
|
||||
{
|
||||
"summary": "2-4 paragraph plain-language summary of what this bill does",
|
||||
"key_points": [
|
||||
{"text": "specific concrete fact", "citation": "Section X(y)", "quote": "verbatim excerpt from bill ≤80 words", "label": "cited_fact"}
|
||||
],
|
||||
"risks": [
|
||||
{"text": "legitimate concern or challenge", "citation": "Section X(y)", "quote": "verbatim excerpt from bill ≤80 words", "label": "cited_fact"}
|
||||
],
|
||||
"deadlines": [{"date": "YYYY-MM-DD or null", "description": "what happens on this date"}],
|
||||
"topic_tags": ["healthcare", "taxation"]
|
||||
}
|
||||
|
||||
Rules:
|
||||
- summary: Explain WHAT the bill does, not whether it is good or bad. Be factual and complete.
|
||||
- key_points: 5-10 specific, concrete things the bill changes, authorizes, or appropriates. \
|
||||
Each item MUST include "text" (your claim), "citation" (the section number, e.g. "Section 301(a)(2)"), \
|
||||
"quote" (a verbatim excerpt of ≤80 words from that section that supports your claim), and "label".
|
||||
- risks: Legitimate concerns from any perspective — costs, implementation challenges, \
|
||||
constitutional questions, unintended consequences. Include at least 2 even for benign bills. \
|
||||
Each item MUST include "text", "citation", "quote", and "label" just like key_points.
|
||||
- label: "cited_fact" if the claim is directly and explicitly stated in the quoted text. \
|
||||
"inference" if the claim is an analytical interpretation, projection, or implication that goes \
|
||||
beyond what the text literally says (e.g. projected costs, likely downstream effects, \
|
||||
constitutional questions). When in doubt, use "inference".
|
||||
- deadlines: Only include if explicitly stated in the text. Use null for date if a deadline \
|
||||
is mentioned without a specific date. Empty list if none.
|
||||
- topic_tags: 3-8 lowercase tags. Prefer these standard tags: healthcare, taxation, defense, \
|
||||
education, immigration, environment, housing, infrastructure, technology, agriculture, judiciary, \
|
||||
foreign-policy, veterans, social-security, trade, budget, energy, banking, transportation, \
|
||||
public-lands, labor, civil-rights, science.
|
||||
|
||||
Respond with ONLY valid JSON. No preamble, no explanation, no markdown code blocks."""
|
||||
|
||||
MAX_TOKENS_DEFAULT = 6000
MAX_TOKENS_OLLAMA = 3000
TOKENS_PER_CHAR = 0.25  # rough approximation: 4 chars ≈ 1 token


@dataclass
class ReverseBrief:
    """Structured brief parsed from an LLM's JSON response."""

    summary: str
    key_points: list[dict]
    risks: list[dict]
    deadlines: list[dict]
    topic_tags: list[str]
    llm_provider: str
    llm_model: str


def smart_truncate(text: str, max_tokens: int) -> str:
    """Truncate bill text intelligently if it exceeds the token budget.

    Keeps roughly 75% of the budget from the start (the purpose/preamble)
    and 25% from the end (effective dates, enforcement sections), with an
    omission marker in between. Text within budget is returned unchanged.
    """
    if len(text) * TOKENS_PER_CHAR <= max_tokens:
        return text

    head_len = int(max_tokens * 0.75 / TOKENS_PER_CHAR)
    tail_len = int(max_tokens * 0.25 / TOKENS_PER_CHAR)
    skipped = len(text) - head_len - tail_len
    marker = f"\n\n[... {skipped:,} characters omitted for length ...]\n\n"
    return text[:head_len] + marker + text[-tail_len:]
|
||||
|
||||
|
||||
AMENDMENT_SYSTEM_PROMPT = """You are a nonpartisan legislative analyst. A bill has been updated \
|
||||
and you must summarize what changed between the previous and new version.
|
||||
|
||||
Always respond with valid JSON matching exactly this schema:
|
||||
{
|
||||
"summary": "2-3 paragraph plain-language description of what changed in this version",
|
||||
"key_points": [
|
||||
{"text": "specific change", "citation": "Section X(y)", "quote": "verbatim excerpt from new version ≤80 words", "label": "cited_fact"}
|
||||
],
|
||||
"risks": [
|
||||
{"text": "new concern introduced by this change", "citation": "Section X(y)", "quote": "verbatim excerpt from new version ≤80 words", "label": "cited_fact"}
|
||||
],
|
||||
"deadlines": [{"date": "YYYY-MM-DD or null", "description": "new deadline added"}],
|
||||
"topic_tags": ["healthcare", "taxation"]
|
||||
}
|
||||
|
||||
Rules:
|
||||
- summary: Focus ONLY on what is different from the previous version. Be specific.
|
||||
- key_points: List concrete additions, removals, or modifications in this version. \
|
||||
Each item MUST include "text" (your claim), "citation" (the section number, e.g. "Section 301(a)(2)"), \
|
||||
"quote" (a verbatim excerpt of ≤80 words from the NEW version that supports your claim), and "label".
|
||||
- risks: Only include risks that are new or changed relative to the previous version. \
|
||||
Each item MUST include "text", "citation", "quote", and "label" just like key_points.
|
||||
- label: "cited_fact" if the claim is directly and explicitly stated in the quoted text. \
|
||||
"inference" if the claim is an analytical interpretation, projection, or implication that goes \
|
||||
beyond what the text literally says. When in doubt, use "inference".
|
||||
- deadlines: Only new or changed deadlines. Empty list if none.
|
||||
- topic_tags: Same standard tags as before — include any new topics this version adds.
|
||||
|
||||
Respond with ONLY valid JSON. No preamble, no explanation, no markdown code blocks."""
|
||||
|
||||
|
||||
def build_amendment_prompt(new_text: str, previous_text: str, bill_metadata: dict, max_tokens: int) -> str:
    """Build the user prompt for an amendment (diff-style) brief.

    The token budget is split evenly between the previous and new versions,
    each smart-truncated independently.
    """
    half = max_tokens // 2
    truncated_new = smart_truncate(new_text, half)
    truncated_prev = smart_truncate(previous_text, half)
    return f"""A bill has been updated. Summarize what changed between the previous and new version.

BILL METADATA:
- Title: {bill_metadata.get('title', 'Unknown')}
- Sponsor: {bill_metadata.get('sponsor_name', 'Unknown')} \
({bill_metadata.get('party', '?')}-{bill_metadata.get('state', '?')})
- Latest Action: {bill_metadata.get('latest_action_text', 'None')} \
({bill_metadata.get('latest_action_date', 'Unknown')})

PREVIOUS VERSION:
{truncated_prev}

NEW VERSION:
{truncated_new}

Produce the JSON amendment summary now:"""
|
||||
|
||||
|
||||
def build_prompt(doc_text: str, bill_metadata: dict, max_tokens: int) -> str:
    """Build the user prompt for a full bill brief (metadata + truncated text)."""
    truncated = smart_truncate(doc_text, max_tokens)
    return f"""Analyze this legislation and produce a structured brief.

BILL METADATA:
- Title: {bill_metadata.get('title', 'Unknown')}
- Sponsor: {bill_metadata.get('sponsor_name', 'Unknown')} \
({bill_metadata.get('party', '?')}-{bill_metadata.get('state', '?')})
- Introduced: {bill_metadata.get('introduced_date', 'Unknown')}
- Chamber: {bill_metadata.get('chamber', 'Unknown')}
- Latest Action: {bill_metadata.get('latest_action_text', 'None')} \
({bill_metadata.get('latest_action_date', 'Unknown')})

BILL TEXT:
{truncated}

Produce the JSON brief now:"""
|
||||
|
||||
|
||||
def parse_brief_json(raw: str | dict, provider: str, model: str) -> ReverseBrief:
    """Parse and validate LLM JSON response into a ReverseBrief."""
    if isinstance(raw, str):
        # Tolerate markdown code fences around the JSON payload
        cleaned = re.sub(r"^```(?:json)?\s*", "", raw.strip())
        cleaned = re.sub(r"\s*```$", "", cleaned.strip())
        data = json.loads(cleaned)
    else:
        data = raw

    def _list_of(key: str) -> list:
        # Coerce to a fresh list; missing keys become empty lists
        return list(data.get(key, []))

    return ReverseBrief(
        summary=str(data.get("summary", "")),
        key_points=_list_of("key_points"),
        risks=_list_of("risks"),
        deadlines=_list_of("deadlines"),
        topic_tags=_list_of("topic_tags"),
        llm_provider=provider,
        llm_model=model,
    )
||||
class LLMProvider(ABC):
|
||||
_provider_name: str = "unknown"
|
||||
|
||||
def _call(self, fn):
|
||||
"""Invoke fn(), translating provider-specific rate-limit errors to RateLimitError."""
|
||||
try:
|
||||
return fn()
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if _detect_rate_limit(exc):
|
||||
raise RateLimitError(self._provider_name) from exc
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
def generate_brief(self, doc_text: str, bill_metadata: dict) -> ReverseBrief:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def generate_amendment_brief(self, new_text: str, previous_text: str, bill_metadata: dict) -> ReverseBrief:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def generate_text(self, prompt: str) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIProvider(LLMProvider):
    """OpenAI chat-completions backend.

    JSON output is enforced via ``response_format={"type": "json_object"}``
    with a low temperature for stable, parseable briefs.
    """

    _provider_name = "openai"

    def __init__(self, model: str | None = None):
        # Imported lazily so the SDK is only required when this provider is selected.
        from openai import OpenAI
        self.client = OpenAI(api_key=settings.OPENAI_API_KEY)
        # Explicit model argument (e.g. a DB override) wins over the env default.
        self.model = model or settings.OPENAI_MODEL

    def generate_brief(self, doc_text: str, bill_metadata: dict) -> ReverseBrief:
        """Generate a structured brief from a bill's full text."""
        prompt = build_prompt(doc_text, bill_metadata, MAX_TOKENS_DEFAULT)
        # _call() translates provider rate-limit errors into RateLimitError.
        response = self._call(lambda: self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
            temperature=0.1,
        ))
        raw = response.choices[0].message.content
        return parse_brief_json(raw, "openai", self.model)

    def generate_amendment_brief(self, new_text: str, previous_text: str, bill_metadata: dict) -> ReverseBrief:
        """Generate a diff-style brief comparing two versions of a bill."""
        prompt = build_amendment_prompt(new_text, previous_text, bill_metadata, MAX_TOKENS_DEFAULT)
        response = self._call(lambda: self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": AMENDMENT_SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
            temperature=0.1,
        ))
        raw = response.choices[0].message.content
        return parse_brief_json(raw, "openai", self.model)

    def generate_text(self, prompt: str) -> str:
        """Free-form text generation (no JSON constraint, higher temperature)."""
        response = self._call(lambda: self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
        ))
        # The API may return None content; normalize to empty string.
        return response.choices[0].message.content or ""
||||
class AnthropicProvider(LLMProvider):
    """Anthropic Messages API backend.

    The Messages API has no JSON response mode, so JSON-only output is
    requested via an appended system-prompt instruction. The system block
    carries ``cache_control: ephemeral`` to opt into Anthropic prompt caching.
    """

    _provider_name = "anthropic"

    def __init__(self, model: str | None = None):
        # Imported lazily so the SDK is only required when this provider is selected.
        import anthropic
        self.client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
        # Explicit model argument (e.g. a DB override) wins over the env default.
        self.model = model or settings.ANTHROPIC_MODEL

    def generate_brief(self, doc_text: str, bill_metadata: dict) -> ReverseBrief:
        """Generate a structured brief from a bill's full text."""
        prompt = build_prompt(doc_text, bill_metadata, MAX_TOKENS_DEFAULT)
        response = self._call(lambda: self.client.messages.create(
            model=self.model,
            max_tokens=4096,
            system=[{
                "type": "text",
                "text": SYSTEM_PROMPT + "\n\nIMPORTANT: Respond with ONLY valid JSON. No other text.",
                "cache_control": {"type": "ephemeral"},
            }],
            messages=[{"role": "user", "content": prompt}],
        ))
        # First content block is the text completion.
        raw = response.content[0].text
        return parse_brief_json(raw, "anthropic", self.model)

    def generate_amendment_brief(self, new_text: str, previous_text: str, bill_metadata: dict) -> ReverseBrief:
        """Generate a diff-style brief comparing two versions of a bill."""
        prompt = build_amendment_prompt(new_text, previous_text, bill_metadata, MAX_TOKENS_DEFAULT)
        response = self._call(lambda: self.client.messages.create(
            model=self.model,
            max_tokens=4096,
            system=[{
                "type": "text",
                "text": AMENDMENT_SYSTEM_PROMPT + "\n\nIMPORTANT: Respond with ONLY valid JSON. No other text.",
                "cache_control": {"type": "ephemeral"},
            }],
            messages=[{"role": "user", "content": prompt}],
        ))
        raw = response.content[0].text
        return parse_brief_json(raw, "anthropic", self.model)

    def generate_text(self, prompt: str) -> str:
        """Free-form text generation (no system prompt, smaller output cap)."""
        response = self._call(lambda: self.client.messages.create(
            model=self.model,
            max_tokens=1024,
            messages=[{"role": "user", "content": prompt}],
        ))
        return response.content[0].text
||||
class GeminiProvider(LLMProvider):
    """Google Gemini backend (google-generativeai SDK).

    JSON output is requested via the ``response_mime_type`` generation
    setting instead of prompt instructions.
    """

    _provider_name = "gemini"

    def __init__(self, model: str | None = None):
        # Imported lazily so the SDK is only required when this provider is selected.
        import google.generativeai as genai
        genai.configure(api_key=settings.GEMINI_API_KEY)
        self._genai = genai
        # Explicit model argument (e.g. a DB override) wins over the env default.
        self.model_name = model or settings.GEMINI_MODEL

    def _make_model(self, system_prompt: str):
        # Build a fresh model per call: the system instruction differs
        # between the brief and amendment flows.
        return self._genai.GenerativeModel(
            model_name=self.model_name,
            generation_config={"response_mime_type": "application/json", "temperature": 0.1},
            system_instruction=system_prompt,
        )

    def generate_brief(self, doc_text: str, bill_metadata: dict) -> ReverseBrief:
        """Generate a structured brief from a bill's full text."""
        prompt = build_prompt(doc_text, bill_metadata, MAX_TOKENS_DEFAULT)
        response = self._call(lambda: self._make_model(SYSTEM_PROMPT).generate_content(prompt))
        return parse_brief_json(response.text, "gemini", self.model_name)

    def generate_amendment_brief(self, new_text: str, previous_text: str, bill_metadata: dict) -> ReverseBrief:
        """Generate a diff-style brief comparing two versions of a bill."""
        prompt = build_amendment_prompt(new_text, previous_text, bill_metadata, MAX_TOKENS_DEFAULT)
        response = self._call(lambda: self._make_model(AMENDMENT_SYSTEM_PROMPT).generate_content(prompt))
        return parse_brief_json(response.text, "gemini", self.model_name)

    def generate_text(self, prompt: str) -> str:
        """Free-form text generation (no JSON constraint, higher temperature)."""
        model = self._genai.GenerativeModel(
            model_name=self.model_name,
            generation_config={"temperature": 0.3},
        )
        response = self._call(lambda: model.generate_content(prompt))
        return response.text
||||
class OllamaProvider(LLMProvider):
    """Local Ollama backend (plain HTTP API, no SDK).

    JSON output is requested via the API's ``"format": "json"`` option;
    ``_generate`` additionally validates the response and retries once with
    a stricter instruction when the model returns invalid JSON.
    """

    _provider_name = "ollama"

    def __init__(self, model: str | None = None):
        self.base_url = settings.OLLAMA_BASE_URL.rstrip("/")
        # Explicit model argument (e.g. a DB override) wins over the env default.
        self.model = model or settings.OLLAMA_MODEL

    def _generate(self, system_prompt: str, user_prompt: str) -> str:
        """POST to /api/generate and return the raw response string.

        BUG FIX: the original wrapped a bare ``return raw`` in try/except, so
        the strict-JSON retry branch below could never execute. The response
        is now actually validated with json.loads before being returned.
        """
        import requests as req
        full_prompt = f"{system_prompt}\n\n{user_prompt}"
        response = req.post(
            f"{self.base_url}/api/generate",
            json={"model": self.model, "prompt": full_prompt, "stream": False, "format": "json"},
            timeout=300,  # local models can be slow on large prompts
        )
        response.raise_for_status()
        raw = response.json().get("response", "")
        try:
            json.loads(raw)  # validate; fall through to a strict retry if invalid
            return raw
        except (json.JSONDecodeError, TypeError):
            strict = f"{full_prompt}\n\nCRITICAL: Your response MUST be valid JSON only."
            r2 = req.post(
                f"{self.base_url}/api/generate",
                json={"model": self.model, "prompt": strict, "stream": False, "format": "json"},
                timeout=300,
            )
            r2.raise_for_status()
            return r2.json().get("response", "")

    def generate_brief(self, doc_text: str, bill_metadata: dict) -> ReverseBrief:
        """Generate a structured brief; retries once if the JSON cannot be parsed."""
        prompt = build_prompt(doc_text, bill_metadata, MAX_TOKENS_OLLAMA)
        raw = self._generate(SYSTEM_PROMPT, prompt)
        try:
            return parse_brief_json(raw, "ollama", self.model)
        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"Ollama JSON parse failed, retrying: {e}")
            raw2 = self._generate(
                SYSTEM_PROMPT,
                prompt + "\n\nCRITICAL: Your response MUST be valid JSON only. No text before or after the JSON object."
            )
            return parse_brief_json(raw2, "ollama", self.model)

    def generate_amendment_brief(self, new_text: str, previous_text: str, bill_metadata: dict) -> ReverseBrief:
        """Generate a diff-style brief; retries once if the JSON cannot be parsed."""
        prompt = build_amendment_prompt(new_text, previous_text, bill_metadata, MAX_TOKENS_OLLAMA)
        raw = self._generate(AMENDMENT_SYSTEM_PROMPT, prompt)
        try:
            return parse_brief_json(raw, "ollama", self.model)
        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"Ollama amendment JSON parse failed, retrying: {e}")
            raw2 = self._generate(
                AMENDMENT_SYSTEM_PROMPT,
                prompt + "\n\nCRITICAL: Your response MUST be valid JSON only. No text before or after the JSON object."
            )
            return parse_brief_json(raw2, "ollama", self.model)

    def generate_text(self, prompt: str) -> str:
        """Free-form generation (no JSON constraint, shorter timeout)."""
        import requests as req
        response = req.post(
            f"{self.base_url}/api/generate",
            json={"model": self.model, "prompt": prompt, "stream": False},
            timeout=120,
        )
        response.raise_for_status()
        return response.json().get("response", "")
||||
def get_llm_provider(provider: str | None = None, model: str | None = None) -> LLMProvider:
|
||||
"""Factory — returns the configured LLM provider.
|
||||
|
||||
Pass ``provider`` and/or ``model`` explicitly (e.g. from DB overrides) to bypass env defaults.
|
||||
"""
|
||||
if provider is None:
|
||||
provider = settings.LLM_PROVIDER
|
||||
provider = provider.lower()
|
||||
if provider == "openai":
|
||||
return OpenAIProvider(model=model)
|
||||
elif provider == "anthropic":
|
||||
return AnthropicProvider(model=model)
|
||||
elif provider == "gemini":
|
||||
return GeminiProvider(model=model)
|
||||
elif provider == "ollama":
|
||||
return OllamaProvider(model=model)
|
||||
raise ValueError(f"Unknown LLM_PROVIDER: '{provider}'. Must be one of: openai, anthropic, gemini, ollama")
|
||||
|
||||
|
||||
# Maps Congress.gov bill-type codes to their conventional citation labels
# (e.g. "hr" -> "H.R." for use in human-readable bill references).
_BILL_TYPE_LABELS: dict[str, str] = {
    "hr": "H.R.",
    "s": "S.",
    "hjres": "H.J.Res.",
    "sjres": "S.J.Res.",
    "hconres": "H.Con.Res.",
    "sconres": "S.Con.Res.",
    "hres": "H.Res.",
    "sres": "S.Res.",
}

# Per-tone instruction injected into the draft-letter prompt; unknown tones
# fall back to "polite" (see generate_draft_letter).
_TONE_INSTRUCTIONS: dict[str, str] = {
    "short": "Keep the letter brief — 6 to 8 sentences total.",
    "polite": "Use a respectful, formal, and courteous tone throughout the letter.",
    "firm": "Use a direct, firm tone that makes clear the constituent's strong conviction.",
}
|
||||
def generate_draft_letter(
    bill_label: str,
    bill_title: str,
    stance: str,
    recipient: str,
    tone: str,
    selected_points: list[str],
    include_citations: bool,
    zip_code: str | None,
    rep_name: str | None = None,
    llm_provider: str | None = None,
    llm_model: str | None = None,
) -> str:
    """Generate a plain-text constituent letter draft using the configured LLM provider.

    Args:
        bill_label: Formatted bill citation, e.g. "H.R. 1234".
        bill_title: Bill title shown alongside the label.
        stance: "yes" asks for a YES vote; any other value asks for NO.
        recipient: "house" or "senate" — controls chamber wording and salutation.
        tone: Key into _TONE_INSTRUCTIONS; unknown values fall back to "polite".
        selected_points: The only points the letter body may reference.
        include_citations: Whether citation labels may be mentioned.
        zip_code: Optional constituent ZIP, passed to the model when present.
        rep_name: Optional recipient name for a personalized salutation.
        llm_provider / llm_model: Optional overrides for the provider factory.

    Returns:
        The generated letter as plain text.
    """
    vote_word = "YES" if stance == "yes" else "NO"
    chamber_word = "House" if recipient == "house" else "Senate"
    tone_instruction = _TONE_INSTRUCTIONS.get(tone, _TONE_INSTRUCTIONS["polite"])

    points_block = "\n".join(f"- {p}" for p in selected_points)

    citation_instruction = (
        "You may reference the citation label for each point (e.g. 'as noted in Section 3') if it adds clarity."
        if include_citations
        else "Do not include any citation references."
    )

    location_line = f"The constituent is writing from ZIP code {zip_code}." if zip_code else ""

    # FIX: the salutation rule previously embedded its own "- " bullet while the
    # prompt template also prefixes "- ", which rendered as "- - Open with ...".
    # The bullet now comes only from the template, matching the other rules.
    if rep_name:
        title = "Senator" if recipient == "senate" else "Representative"
        salutation_instruction = f'Open with "Dear {title} {rep_name},"'
    else:
        salutation_instruction = f'Open with "Dear {chamber_word} Member,"'

    prompt = f"""Write a short constituent letter to a {chamber_word} member of Congress.

RULES:
- {tone_instruction}
- 6 to 12 sentences total.
- {salutation_instruction}
- Second sentence must be a clear, direct ask: "Please vote {vote_word} on {bill_label}."
- The body must reference ONLY the points listed below — do not invent any other claims or facts.
- {citation_instruction}
- Close with a brief sign-off and the placeholder "[Your Name]".
- Plain text only. No markdown, no bullet points, no headers, no partisan framing.
- Do not mention any political party.

BILL: {bill_label} — {bill_title}
STANCE: Vote {vote_word}
{location_line}

SELECTED POINTS TO REFERENCE:
{points_block}

Write the letter now:"""

    return get_llm_provider(provider=llm_provider, model=llm_model).generate_text(prompt)
||||
308
backend/app/services/news_service.py
Normal file
308
backend/app/services/news_service.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
News correlation service.
|
||||
|
||||
- NewsAPI.org: structured news articles per bill (100 req/day limit)
|
||||
- Google News RSS: volume signal for zeitgeist scoring (no limit)
|
||||
"""
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import urllib.parse
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
import feedparser
|
||||
import redis
|
||||
import requests
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NEWSAPI_BASE = "https://newsapi.org/v2"
|
||||
GOOGLE_NEWS_RSS = "https://news.google.com/rss/search"
|
||||
NEWSAPI_DAILY_LIMIT = 95 # Leave 5 as buffer
|
||||
NEWSAPI_BATCH_SIZE = 4 # Bills per OR-combined API call
|
||||
|
||||
_NEWSAPI_REDIS_PREFIX = "newsapi:daily_calls:"
|
||||
_GNEWS_CACHE_TTL = 7200 # 2 hours — both trend_scorer and news_fetcher share cache
|
||||
|
||||
|
||||
def _redis():
    """Return a Redis client for the configured URL, decoding responses to str."""
    client = redis.from_url(settings.REDIS_URL, decode_responses=True)
    return client
|
||||
def _newsapi_quota_ok() -> bool:
    """Return True if we have quota remaining for today."""
    try:
        today_key = f"{_NEWSAPI_REDIS_PREFIX}{date.today().isoformat()}"
        return int(_redis().get(today_key) or 0) < NEWSAPI_DAILY_LIMIT
    except Exception:
        # Don't block fetching just because Redis is unavailable.
        return True
|
||||
def _newsapi_record_call():
    """Increment today's NewsAPI call counter (best-effort; Redis errors ignored)."""
    try:
        today_key = f"{_NEWSAPI_REDIS_PREFIX}{date.today().isoformat()}"
        pipe = _redis().pipeline()
        pipe.incr(today_key)
        pipe.expire(today_key, 90000)  # 25h TTL — expires safely after midnight
        pipe.execute()
    except Exception:
        pass
|
||||
def get_newsapi_quota_remaining() -> int:
    """Return the number of NewsAPI calls still available today."""
    try:
        today_key = f"{_NEWSAPI_REDIS_PREFIX}{date.today().isoformat()}"
        calls_used = int(_redis().get(today_key) or 0)
    except Exception:
        # Assume the full quota when the counter is unreadable.
        return NEWSAPI_DAILY_LIMIT
    return max(0, NEWSAPI_DAILY_LIMIT - calls_used)
|
||||
def clear_gnews_cache() -> int:
    """Delete all cached Google News RSS results. Returns number of keys deleted."""
    try:
        client = _redis()
        stale = client.keys("gnews:*")
        return client.delete(*stale) if stale else 0
    except Exception:
        return 0
|
||||
@retry(stop=stop_after_attempt(2), wait=wait_exponential(min=1, max=5))
def _newsapi_get(endpoint: str, params: dict) -> dict:
    """GET a NewsAPI endpoint with the API key injected; retried once on failure.

    Args:
        endpoint: Path segment under the NewsAPI base URL (e.g. "everything").
        params: Query parameters; NOT mutated (the key is added to a copy —
            the original version wrote "apiKey" into the caller's dict).

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: on non-2xx responses (after retry).
    """
    query = {**params, "apiKey": settings.NEWSAPI_KEY}
    response = requests.get(f"{NEWSAPI_BASE}/{endpoint}", params=query, timeout=30)
    response.raise_for_status()
    return response.json()
|
||||
def build_news_query(bill_title: str, short_title: Optional[str], sponsor_name: Optional[str],
                     bill_type: str, bill_number: int) -> str:
    """Build a NewsAPI search query for a bill.

    The short title is the strongest signal; otherwise the first six words of
    the full title are used as a phrase (only when long enough to be
    distinctive). The formal designator (e.g. "HR 1234") is always appended.
    ``sponsor_name`` is currently unused but kept for interface stability.
    """
    phrases: list[str] = []
    if short_title:
        phrases.append(f'"{short_title}"')
    else:
        title_words = bill_title.split()[:6] if bill_title else []
        if len(title_words) >= 3:
            phrases.append('"{}"'.format(" ".join(title_words)))
    phrases.append(f'"{bill_type.upper()} {bill_number}"')
    # At most two phrases — keeps queries short for relevance
    return " OR ".join(phrases[:2])
||||
|
||||
def fetch_newsapi_articles(query: str, days: int = 30) -> list[dict]:
    """Fetch articles from NewsAPI.org. Returns empty list if quota is exhausted or key not set."""
    if not settings.NEWSAPI_KEY:
        return []
    if not _newsapi_quota_ok():
        logger.warning("NewsAPI daily quota exhausted — skipping fetch")
        return []
    try:
        since = (datetime.now(timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
        payload = _newsapi_get("everything", {
            "q": query,
            "language": "en",
            "sortBy": "relevancy",
            "pageSize": 10,
            "from": since,
        })
        _newsapi_record_call()

        results: list[dict] = []
        for item in payload.get("articles", []):
            # Skip entries missing either a URL or a headline
            if not (item.get("url") and item.get("title")):
                continue
            results.append({
                "source": item.get("source", {}).get("name", ""),
                "headline": item.get("title", ""),
                "url": item.get("url", ""),
                "published_at": item.get("publishedAt"),
            })
        return results
    except Exception as e:
        logger.error(f"NewsAPI fetch failed: {e}")
        return []
||||
|
||||
def fetch_newsapi_articles_batch(
    bill_queries: list[tuple[str, str]],
    days: int = 30,
) -> dict[str, list[dict]]:
    """
    Fetch NewsAPI articles for up to NEWSAPI_BATCH_SIZE bills in ONE API call
    using OR syntax. Returns {bill_id: [articles]} — each article attributed
    to the bill whose query terms appear in the headline/description.

    Args:
        bill_queries: (bill_id, query) pairs; queries come from build_news_query.
        days: Look-back window for article publication dates.

    Returns:
        Mapping of every input bill_id to its (possibly empty) article list.
        A single article can be attributed to multiple bills when their query
        terms overlap. Returns all-empty lists when the key is unset, quota
        is exhausted, or the request fails.
    """
    empty = {bill_id: [] for bill_id, _ in bill_queries}
    if not settings.NEWSAPI_KEY or not bill_queries:
        return empty
    if not _newsapi_quota_ok():
        logger.warning("NewsAPI daily quota exhausted — skipping batch fetch")
        return empty

    # One combined query keeps us inside the daily request budget.
    combined_q = " OR ".join(q for _, q in bill_queries)
    try:
        from_date = (datetime.now(timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
        data = _newsapi_get("everything", {
            "q": combined_q,
            "language": "en",
            "sortBy": "relevancy",
            "pageSize": 20,
            "from": from_date,
        })
        _newsapi_record_call()
        articles = data.get("articles", [])

        result: dict[str, list[dict]] = {bill_id: [] for bill_id, _ in bill_queries}
        for article in articles:
            # Attribution corpus: headline + description, lowercased.
            content = " ".join([
                article.get("title", ""),
                article.get("description", "") or "",
            ]).lower()
            for bill_id, query in bill_queries:
                # Match if any meaningful term from this bill's query appears in the article
                # (terms shorter than 4 chars are ignored to avoid noise matches).
                terms = [t.strip('" ').lower() for t in query.split(" OR ")]
                if any(len(t) > 3 and t in content for t in terms):
                    result[bill_id].append({
                        "source": article.get("source", {}).get("name", ""),
                        "headline": article.get("title", ""),
                        "url": article.get("url", ""),
                        "published_at": article.get("publishedAt"),
                    })
        return result
    except Exception as e:
        logger.error(f"NewsAPI batch fetch failed: {e}")
        return empty
|
||||
# ── Google News RSS ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _gnews_cache_key(query: str, kind: str, days: int) -> str:
    """Redis key for a cached gnews result: gnews:<kind>:<12-hex digest of query+days>."""
    digest = hashlib.md5(f"{query}:{days}".encode()).hexdigest()
    return f"gnews:{kind}:{digest[:12]}"
|
||||
def fetch_gnews_count(query: str, days: int = 30) -> int:
    """Count articles in Google News RSS. Results cached in Redis for 2 hours."""
    cache_key = _gnews_cache_key(query, "count", days)

    # Cache lookup is best-effort: Redis trouble just means a live fetch.
    try:
        hit = _redis().get(cache_key)
        if hit is not None:
            return int(hit)
    except Exception:
        pass

    count = _fetch_gnews_count_raw(query, days)

    # Cache write is equally best-effort.
    try:
        _redis().setex(cache_key, _GNEWS_CACHE_TTL, count)
    except Exception:
        pass
    return count
|
||||
def _fetch_gnews_count_raw(query: str, days: int) -> int:
    """Fetch gnews article count directly (no cache)."""
    try:
        q = urllib.parse.quote(f"{query} when:{days}d")
        feed_url = f"{GOOGLE_NEWS_RSS}?q={q}&hl=en-US&gl=US&ceid=US:en"
        time.sleep(1)  # Polite delay
        return len(feedparser.parse(feed_url).entries)
    except Exception as e:
        logger.error(f"Google News RSS fetch failed: {e}")
        return 0
|
||||
def _gnews_entry_url(entry) -> str:
    """Extract the article URL from a feedparser Google News RSS entry."""
    # feedparser entries answer both attribute and mapping access; try both.
    primary = getattr(entry, "link", None) or entry.get("link", "")
    if primary:
        return primary
    # Fall back to the first non-empty href in the links list.
    for alt in getattr(entry, "links", []):
        candidate = alt.get("href", "")
        if candidate:
            return candidate
    return ""
|
||||
def fetch_gnews_articles(query: str, days: int = 30) -> list[dict]:
    """Fetch articles from Google News RSS. Results cached in Redis for 2 hours.

    Cache reads/writes are best-effort: any Redis failure falls through to a
    live fetch and the result is still returned.

    Fix: removed an unused local ``import time as time_mod`` (the alias is
    only needed in _fetch_gnews_articles_raw, never here).
    """
    cache_key = _gnews_cache_key(query, "articles", days)
    try:
        cached = _redis().get(cache_key)
        if cached is not None:
            return json.loads(cached)
    except Exception:
        pass

    articles = _fetch_gnews_articles_raw(query, days)

    try:
        _redis().setex(cache_key, _GNEWS_CACHE_TTL, json.dumps(articles))
    except Exception:
        pass
    return articles
|
||||
def _fetch_gnews_articles_raw(query: str, days: int) -> list[dict]:
    """Fetch gnews articles directly (no cache).

    Parses the Google News RSS search feed and returns up to 20 article
    dicts (source/headline/url/published_at). Returns [] on any failure.
    """
    # Aliased so it doesn't shadow the module-level `time` used for sleep().
    import time as time_mod
    try:
        encoded = urllib.parse.quote(f"{query} when:{days}d")
        url = f"{GOOGLE_NEWS_RSS}?q={encoded}&hl=en-US&gl=US&ceid=US:en"
        time.sleep(1)  # Polite delay
        feed = feedparser.parse(url)
        articles = []
        for entry in feed.entries[:20]:  # cap the result size
            # Published timestamp: convert feedparser's struct_time to UTC ISO-8601.
            pub_at = None
            if getattr(entry, "published_parsed", None):
                try:
                    pub_at = datetime.fromtimestamp(
                        time_mod.mktime(entry.published_parsed), tz=timezone.utc
                    ).isoformat()
                except Exception:
                    pass  # malformed dates are tolerated; pub_at stays None
            # Source name: entries expose it as attribute or mapping depending on feed.
            source = ""
            src = getattr(entry, "source", None)
            if src:
                source = getattr(src, "title", "") or src.get("title", "")
            headline = entry.get("title", "") or getattr(entry, "title", "")
            article_url = _gnews_entry_url(entry)
            # Entries missing a URL or headline are dropped.
            if article_url and headline:
                articles.append({
                    "source": source or "Google News",
                    "headline": headline,
                    "url": article_url,
                    "published_at": pub_at,
                })
        return articles
    except Exception as e:
        logger.error(f"Google News RSS article fetch failed: {e}")
        return []
|
||||
def build_member_query(first_name: str, last_name: str, chamber: Optional[str] = None) -> str:
    """Build a news search query for a member of Congress.

    When a chamber is given, the query also matches the titled form
    ('Senator Smith' or 'Rep. Smith').
    """
    full_name = f"{first_name} {last_name}".strip()
    honorific = ""
    if chamber:
        honorific = "Senator" if "senate" in chamber.lower() else "Rep."
    if honorific:
        return f'"{full_name}" OR "{honorific} {last_name}"'
    return f'"{full_name}"'
||||
112
backend/app/services/trends_service.py
Normal file
112
backend/app/services/trends_service.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Google Trends service (via pytrends).
|
||||
|
||||
pytrends is unofficial web scraping — Google blocks it sporadically.
|
||||
All calls are wrapped in try/except and return 0 on any failure.
|
||||
"""
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_trends_score(keywords: list[str]) -> float:
    """
    Return a 0–100 interest score for the given keywords over the past 90 days.
    Returns 0.0 on any failure (rate limit, empty data, exception).
    """
    if not settings.PYTRENDS_ENABLED or not keywords:
        return 0.0
    try:
        # Imported lazily: pytrends is optional and only needed when enabled.
        from pytrends.request import TrendReq

        # Jitter to avoid detection as bot
        time.sleep(random.uniform(2.0, 5.0))

        pytrends = TrendReq(hl="en-US", tz=0, timeout=(10, 25))
        kw_list = [k for k in keywords[:5] if k]  # max 5 keywords (Google's limit)
        if not kw_list:
            return 0.0

        pytrends.build_payload(kw_list, timeframe="today 3-m", geo="US")
        data = pytrends.interest_over_time()

        if data is None or data.empty:
            return 0.0

        # Average the most recent 14 data points for the primary keyword
        primary = kw_list[0]
        if primary not in data.columns:
            return 0.0

        recent = data[primary].tail(14)
        return float(recent.mean())

    except Exception as e:
        # pytrends is unofficial scraping — failures are expected and non-fatal.
        logger.debug(f"pytrends failed (non-critical): {e}")
        return 0.0
|
||||
def get_trends_scores_batch(keyword_groups: list[list[str]]) -> list[float]:
    """
    Get pytrends scores for up to 5 keyword groups in a SINGLE pytrends call.
    Takes the first (most relevant) keyword from each group and compares them
    relative to each other. Falls back to per-group individual calls if the
    batch fails.

    Returns a list of scores (0–100) in the same order as keyword_groups.
    """
    if not settings.PYTRENDS_ENABLED or not keyword_groups:
        return [0.0] * len(keyword_groups)

    # Extract the primary (first) keyword from each group, skip empty groups
    primaries = [(i, kws[0]) for i, kws in enumerate(keyword_groups) if kws]
    if not primaries:
        return [0.0] * len(keyword_groups)

    try:
        # Imported lazily: pytrends is optional and only needed when enabled.
        from pytrends.request import TrendReq

        # Jitter to avoid detection as bot.
        time.sleep(random.uniform(2.0, 5.0))
        pytrends = TrendReq(hl="en-US", tz=0, timeout=(10, 25))
        kw_list = [kw for _, kw in primaries[:5]]  # Google caps payloads at 5 terms

        pytrends.build_payload(kw_list, timeframe="today 3-m", geo="US")
        data = pytrends.interest_over_time()

        # Groups beyond the first 5 (and empty groups) keep a 0.0 score.
        scores = [0.0] * len(keyword_groups)
        if data is not None and not data.empty:
            for idx, kw in primaries[:5]:
                if kw in data.columns:
                    # Mean of the most recent 14 samples, as in get_trends_score.
                    scores[idx] = float(data[kw].tail(14).mean())
        return scores

    except Exception as e:
        logger.debug(f"pytrends batch failed (non-critical): {e}")
        # Fallback: return zeros (individual calls would just multiply failures)
        return [0.0] * len(keyword_groups)
|
||||
def keywords_for_member(first_name: str, last_name: str) -> list[str]:
    """Extract meaningful search keywords for a member of Congress."""
    full_name = f"{first_name} {last_name}".strip()
    return [full_name] if full_name else []
|
||||
def keywords_for_bill(title: str, short_title: str, topic_tags: list[str]) -> list[str]:
    """Extract meaningful search keywords for a bill.

    The short title is the strongest signal; otherwise the first five words
    of the full title are used (when at least two words long). Up to three
    topic tags (dashes replaced by spaces) are appended; at most five
    keywords are returned.
    """
    result: list[str] = []
    if short_title:
        result.append(short_title)
    elif title:
        prefix = title.split()[:5]
        if len(prefix) >= 2:
            result.append(" ".join(prefix))
    for tag in (topic_tags or [])[:3]:
        result.append(tag.replace("-", " "))
    return result[:5]
||||
0
backend/app/workers/__init__.py
Normal file
0
backend/app/workers/__init__.py
Normal file
361
backend/app/workers/bill_classifier.py
Normal file
361
backend/app/workers/bill_classifier.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Bill classifier and Member Effectiveness Score workers.
|
||||
|
||||
Tasks:
|
||||
classify_bill_category — lightweight LLM call; triggered after brief generation
|
||||
fetch_bill_cosponsors — Congress.gov cosponsor fetch; triggered on new bill
|
||||
calculate_effectiveness_scores — nightly beat task
|
||||
backfill_bill_categories — one-time backfill for existing bills
|
||||
backfill_all_bill_cosponsors — one-time backfill for existing bills
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_sync_db
|
||||
from app.models import Bill, BillCosponsor, BillDocument, Member
|
||||
from app.models.setting import AppSetting
|
||||
from app.services import congress_api
|
||||
from app.services.llm_service import RateLimitError, get_llm_provider
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Classification ─────────────────────────────────────────────────────────────
|
||||
|
||||
# One-shot classification prompt, filled via str.format(title=..., excerpt=...).
# The doubled braces escape the literal JSON example from str.format.
_CLASSIFICATION_PROMPT = """\
Classify this bill into exactly one category.

Categories:
- substantive: Creates, modifies, or repeals policy, programs, regulations, funding, or rights. Real legislative work.
- commemorative: Names buildings/post offices, recognizes awareness days/weeks, honors individuals or events with no policy effect.
- administrative: Technical corrections, routine reauthorizations, housekeeping changes with no new policy substance.

Respond with ONLY valid JSON: {{"category": "substantive" | "commemorative" | "administrative"}}

BILL TITLE: {title}

BILL TEXT (excerpt):
{excerpt}

Classify now:"""

# Accepted bill_category values; anything else falls back to "substantive"
# (see classify_bill_category).
_VALID_CATEGORIES = {"substantive", "commemorative", "administrative"}
|
||||
@celery_app.task(
    bind=True,
    max_retries=3,
    rate_limit=f"{settings.LLM_RATE_LIMIT_RPM}/m",  # throttle LLM calls cluster-wide
    name="app.workers.bill_classifier.classify_bill_category",
)
def classify_bill_category(self, bill_id: str, document_id: int):
    """Set bill_category via a cheap one-shot LLM call. Idempotent.

    Args:
        bill_id: Primary key of the Bill row to classify.
        document_id: BillDocument whose raw_text supplies the excerpt.

    Skips silently when the bill is missing or already categorized.
    Rate-limit errors retry after the provider-suggested delay; any other
    failure retries in 120s (up to max_retries).
    """
    db = get_sync_db()
    try:
        bill = db.get(Bill, bill_id)
        if not bill or bill.bill_category:
            return {"status": "skipped"}

        doc = db.get(BillDocument, document_id)
        # The first ~1200 chars is enough signal for a 3-way classification.
        excerpt = (doc.raw_text[:1200] if doc and doc.raw_text else "").strip()

        # DB-stored overrides (AppSetting rows) take precedence over env defaults.
        prov_row = db.get(AppSetting, "llm_provider")
        model_row = db.get(AppSetting, "llm_model")
        provider = get_llm_provider(
            prov_row.value if prov_row else None,
            model_row.value if model_row else None,
        )

        prompt = _CLASSIFICATION_PROMPT.format(
            title=bill.title or "Unknown",
            excerpt=excerpt or "(no text available)",
        )

        raw = provider.generate_text(prompt).strip()
        # Strip markdown fences if present.
        # NOTE(review): lstrip("json")/rstrip("```") strip *character sets*,
        # not literal prefixes — harmless here because JSON bodies start
        # with '{', but worth confirming if the payload format changes.
        if raw.startswith("```"):
            raw = raw.split("```")[1].lstrip("json").strip()
            raw = raw.rstrip("```").strip()

        data = json.loads(raw)
        category = data.get("category", "").lower()
        # Defensive default: never persist an unexpected category value.
        if category not in _VALID_CATEGORIES:
            logger.warning(f"classify_bill_category: invalid category '{category}' for {bill_id}, defaulting to substantive")
            category = "substantive"

        bill.bill_category = category
        db.commit()
        logger.info(f"Bill {bill_id} classified as '{category}'")
        return {"status": "ok", "bill_id": bill_id, "category": category}

    except RateLimitError as exc:
        db.rollback()
        # Honor the provider's suggested backoff.
        raise self.retry(exc=exc, countdown=exc.retry_after)
    except Exception as exc:
        db.rollback()
        logger.error(f"classify_bill_category failed for {bill_id}: {exc}")
        raise self.retry(exc=exc, countdown=120)
    finally:
        db.close()
||||
|
||||
@celery_app.task(bind=True, max_retries=3, name="app.workers.bill_classifier.backfill_bill_categories")
def backfill_bill_categories(self):
    """Queue classification for all bills with text but no category.

    One classify_bill_category task is enqueued per stored document whose
    bill is still unclassified; returns the number of tasks queued.
    """
    db = get_sync_db()
    try:
        rows = db.execute(text("""
            SELECT bd.bill_id, bd.id AS document_id
            FROM bill_documents bd
            JOIN bills b ON b.bill_id = bd.bill_id
            WHERE b.bill_category IS NULL AND bd.raw_text IS NOT NULL
        """)).fetchall()

        queued = 0
        for row in rows:
            classify_bill_category.delay(row.bill_id, row.document_id)
            queued += 1
            time.sleep(0.05)  # pace enqueueing so Redis isn't hit in one burst

        logger.info(f"backfill_bill_categories: queued {queued} classification tasks")
        return {"queued": queued}
    finally:
        db.close()
|
||||
|
||||
|
||||
# ── Co-sponsor fetching ────────────────────────────────────────────────────────
|
||||
|
||||
@celery_app.task(bind=True, max_retries=3, name="app.workers.bill_classifier.fetch_bill_cosponsors")
def fetch_bill_cosponsors(self, bill_id: str):
    """Fetch and store cosponsor list from Congress.gov. Idempotent.

    Pages through the cosponsors endpoint 250 at a time. Deduplicates both
    against existing DB rows (across runs) and within the current run. Only
    cosponsors whose bioguide_id matches an already-ingested Member are
    linked; unknown ids are stored with bioguide_id=None so name/party/state
    are still kept.
    """
    db = get_sync_db()
    try:
        bill = db.get(Bill, bill_id)
        # Idempotency guard: cosponsors_fetched_at marks a completed run.
        if not bill or bill.cosponsors_fetched_at:
            return {"status": "skipped"}

        known_bioguides = {row[0] for row in db.execute(text("SELECT bioguide_id FROM members")).fetchall()}
        # Track bioguide_ids already inserted this run to handle within-page dupes
        # (Congress.gov sometimes lists the same member twice with different dates)
        inserted_this_run: set[str] = set()
        inserted = 0
        offset = 0

        while True:
            data = congress_api.get_bill_cosponsors(
                bill.congress_number, bill.bill_type, bill.bill_number, offset=offset
            )
            cosponsors = data.get("cosponsors", [])
            if not cosponsors:
                break

            for cs in cosponsors:
                bioguide_id = cs.get("bioguideId")
                # Only link to members we've already ingested
                if bioguide_id and bioguide_id not in known_bioguides:
                    bioguide_id = None

                # Skip dupes — both across runs (DB check) and within this page
                if bioguide_id:
                    if bioguide_id in inserted_this_run:
                        continue
                    exists = db.query(BillCosponsor).filter_by(
                        bill_id=bill_id, bioguide_id=bioguide_id
                    ).first()
                    if exists:
                        inserted_this_run.add(bioguide_id)
                        continue

                date_str = cs.get("sponsorshipDate")
                try:
                    sponsored_date = datetime.strptime(date_str, "%Y-%m-%d").date() if date_str else None
                except ValueError:
                    # Malformed date from the API — keep the row, drop the date.
                    sponsored_date = None

                db.add(BillCosponsor(
                    bill_id=bill_id,
                    bioguide_id=bioguide_id,
                    name=cs.get("fullName") or cs.get("name"),
                    party=cs.get("party"),
                    state=cs.get("state"),
                    sponsored_date=sponsored_date,
                ))
                if bioguide_id:
                    inserted_this_run.add(bioguide_id)
                inserted += 1

            db.commit()
            offset += 250
            # A short page means we've reached the end of the list.
            if len(cosponsors) < 250:
                break
            time.sleep(0.25)  # pace page fetches against the Congress.gov API

        bill.cosponsors_fetched_at = datetime.now(timezone.utc)
        db.commit()
        return {"bill_id": bill_id, "inserted": inserted}

    except Exception as exc:
        db.rollback()
        logger.error(f"fetch_bill_cosponsors failed for {bill_id}: {exc}")
        raise self.retry(exc=exc, countdown=60)
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.bill_classifier.backfill_all_bill_cosponsors")
def backfill_all_bill_cosponsors(self):
    """Queue cosponsor fetches for all bills that haven't been fetched yet."""
    session = get_sync_db()
    try:
        pending = session.execute(text(
            "SELECT bill_id FROM bills WHERE cosponsors_fetched_at IS NULL"
        )).fetchall()

        total_queued = 0
        for pending_bill in pending:
            fetch_bill_cosponsors.delay(pending_bill.bill_id)
            total_queued += 1
            # Brief pause so Redis isn't flooded in a single burst.
            time.sleep(0.05)

        logger.info(f"backfill_all_bill_cosponsors: queued {total_queued} tasks")
        return {"queued": total_queued}
    finally:
        session.close()
|
||||
|
||||
|
||||
# ── Effectiveness scoring ──────────────────────────────────────────────────────
|
||||
|
||||
def _distance_points(latest_action_text: str | None) -> int:
|
||||
"""Map latest action text to a distance-traveled score."""
|
||||
text = (latest_action_text or "").lower()
|
||||
if "became public law" in text or "signed by president" in text or "enacted" in text:
|
||||
return 50
|
||||
if "passed house" in text or "passed senate" in text or "agreed to in" in text:
|
||||
return 20
|
||||
if "placed on" in text and "calendar" in text:
|
||||
return 10
|
||||
if "reported by" in text or "ordered to be reported" in text or "discharged" in text:
|
||||
return 5
|
||||
return 1
|
||||
|
||||
|
||||
def _bipartisan_multiplier(db, bill_id: str, sponsor_party: str | None) -> float:
    """1.5x if ≥20% of cosponsors are from the opposing party.

    Args:
        db: Sync SQLAlchemy session.
        bill_id: Bill whose cosponsors are examined.
        sponsor_party: Party code of the bill's sponsor, or None if unknown.

    Returns:
        1.5 when the cosponsor list is meaningfully cross-party, else 1.0.
    """
    if not sponsor_party:
        return 1.0
    cosponsors = db.query(BillCosponsor).filter_by(bill_id=bill_id).all()
    if not cosponsors:
        return 1.0
    # Cosponsors with no recorded party are not counted as opposing.
    opposing = [c for c in cosponsors if c.party and c.party != sponsor_party]
    # `cosponsors` is guaranteed non-empty here, so the original's extra
    # `len(cosponsors) > 0` guard was redundant and has been dropped.
    if len(opposing) / len(cosponsors) >= 0.20:
        return 1.5
    return 1.0
|
||||
|
||||
|
||||
def _substance_multiplier(bill_category: str | None) -> float:
|
||||
return 0.1 if bill_category == "commemorative" else 1.0
|
||||
|
||||
|
||||
def _leadership_multiplier(member: Member, congress_number: int) -> float:
    """1.2x if member chaired a committee during this Congress."""
    roles = member.leadership_json or []
    held_chair = any(
        role.get("congress") == congress_number
        and "chair" in (role.get("type") or "").lower()
        for role in roles
    )
    return 1.2 if held_chair else 1.0
|
||||
|
||||
|
||||
def _seniority_tier(terms_json: list | None) -> str:
|
||||
"""Return 'junior' | 'mid' | 'senior' based on number of terms served."""
|
||||
if not terms_json:
|
||||
return "junior"
|
||||
count = len(terms_json)
|
||||
if count <= 2:
|
||||
return "junior"
|
||||
if count <= 5:
|
||||
return "mid"
|
||||
return "senior"
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.bill_classifier.calculate_effectiveness_scores")
def calculate_effectiveness_scores(self):
    """Nightly: compute effectiveness score and within-tier percentile for all members.

    Raw score per member = sum over their sponsored bills in the current
    Congress of distance points * bipartisan * substance * leadership
    multipliers. Percentiles are then computed within (seniority tier, party)
    buckets so members are only ranked against comparable peers.
    """
    db = get_sync_db()
    try:
        members = db.query(Member).all()
        if not members:
            return {"status": "no_members"}

        # Map bioguide_id → Member for quick lookup
        member_map = {m.bioguide_id: m for m in members}

        # Load all bills sponsored by current members (current congress only)
        current_congress = congress_api.get_current_congress()
        bills = db.query(Bill).filter_by(congress_number=current_congress).all()

        # Compute raw score per member
        raw_scores: dict[str, float] = {m.bioguide_id: 0.0 for m in members}

        for bill in bills:
            # Skip bills with no linked sponsor, or whose sponsor is no
            # longer in the members table.
            if not bill.sponsor_id or bill.sponsor_id not in member_map:
                continue
            sponsor = member_map[bill.sponsor_id]

            pts = _distance_points(bill.latest_action_text)
            bipartisan = _bipartisan_multiplier(db, bill.bill_id, sponsor.party)
            substance = _substance_multiplier(bill.bill_category)
            leadership = _leadership_multiplier(sponsor, current_congress)

            raw_scores[bill.sponsor_id] = raw_scores.get(bill.sponsor_id, 0.0) + (
                pts * bipartisan * substance * leadership
            )

        # Group members by (tier, party) for percentile normalisation
        # We treat party as a proxy for majority/minority — grouped separately so
        # a minority-party junior isn't unfairly compared to a majority-party senior.
        from collections import defaultdict
        buckets: dict[tuple, list[str]] = defaultdict(list)
        for m in members:
            tier = _seniority_tier(m.terms_json)
            party_bucket = m.party or "Unknown"
            buckets[(tier, party_bucket)].append(m.bioguide_id)

        # Compute percentile within each bucket
        percentiles: dict[str, float] = {}
        tiers: dict[str, str] = {}
        for (tier, _), ids in buckets.items():
            scores = [(bid, raw_scores.get(bid, 0.0)) for bid in ids]
            scores.sort(key=lambda x: x[1])
            n = len(scores)
            for rank, (bid, _) in enumerate(scores):
                # rank/(n-1) scales to 0–100; max(n-1, 1) guards single-member buckets.
                percentiles[bid] = round((rank / max(n - 1, 1)) * 100, 1)
                tiers[bid] = tier

        # Bulk update members
        updated = 0
        for m in members:
            score = raw_scores.get(m.bioguide_id, 0.0)
            pct = percentiles.get(m.bioguide_id)
            tier = tiers.get(m.bioguide_id, _seniority_tier(m.terms_json))
            m.effectiveness_score = round(score, 2)
            m.effectiveness_percentile = pct
            m.effectiveness_tier = tier
            updated += 1

        db.commit()
        logger.info(f"calculate_effectiveness_scores: updated {updated} members for Congress {current_congress}")
        return {"status": "ok", "updated": updated, "congress": current_congress}

    except Exception as exc:
        db.rollback()
        logger.error(f"calculate_effectiveness_scores failed: {exc}")
        raise
    finally:
        db.close()
|
||||
112
backend/app/workers/celery_app.py
Normal file
112
backend/app/workers/celery_app.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
from kombu import Queue
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# Single shared Celery application. Redis serves as both broker and result
# backend; `include` pre-registers every worker module so their tasks are
# known to workers at startup.
celery_app = Celery(
    "pocketveto",
    broker=settings.REDIS_URL,
    backend=settings.REDIS_URL,
    include=[
        "app.workers.congress_poller",
        "app.workers.document_fetcher",
        "app.workers.llm_processor",
        "app.workers.news_fetcher",
        "app.workers.trend_scorer",
        "app.workers.member_interest",
        "app.workers.notification_dispatcher",
        "app.workers.llm_batch_processor",
        "app.workers.bill_classifier",
        "app.workers.vote_fetcher",
    ],
)
|
||||
|
||||
# Global Celery configuration: serialization, reliability, queue routing,
# and the Beat schedule for all periodic tasks.
celery_app.conf.update(
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    timezone="UTC",
    enable_utc=True,
    # Late ack: task is only removed from queue after completion, not on pickup.
    # Combined with idempotent tasks, this ensures no work is lost if a worker crashes.
    task_acks_late=True,
    # Prevent workers from prefetching LLM tasks and blocking other workers.
    worker_prefetch_multiplier=1,
    # Route tasks to named queues so slow work (LLM, documents) never starves
    # the fast polling/notification paths.
    task_routes={
        "app.workers.congress_poller.*": {"queue": "polling"},
        "app.workers.document_fetcher.*": {"queue": "documents"},
        "app.workers.llm_processor.*": {"queue": "llm"},
        "app.workers.llm_batch_processor.*": {"queue": "llm"},
        "app.workers.bill_classifier.*": {"queue": "llm"},
        "app.workers.news_fetcher.*": {"queue": "news"},
        "app.workers.trend_scorer.*": {"queue": "news"},
        "app.workers.member_interest.*": {"queue": "news"},
        "app.workers.notification_dispatcher.*": {"queue": "polling"},
        "app.workers.vote_fetcher.*": {"queue": "polling"},
    },
    task_queues=[
        Queue("polling"),
        Queue("documents"),
        Queue("llm"),
        Queue("news"),
    ],
    # RedBeat stores schedule in Redis — restart-safe and dynamically updatable
    redbeat_redis_url=settings.REDIS_URL,
    beat_scheduler="redbeat.RedBeatScheduler",
    # All times below are UTC (enable_utc above).
    beat_schedule={
        "poll-congress-bills": {
            "task": "app.workers.congress_poller.poll_congress_bills",
            "schedule": crontab(minute=f"*/{settings.CONGRESS_POLL_INTERVAL_MINUTES}"),
        },
        "fetch-news-active-bills": {
            "task": "app.workers.news_fetcher.fetch_news_for_active_bills",
            "schedule": crontab(hour="*/6", minute=0),
        },
        "calculate-trend-scores": {
            "task": "app.workers.trend_scorer.calculate_all_trend_scores",
            "schedule": crontab(hour=2, minute=0),
        },
        "fetch-news-active-members": {
            "task": "app.workers.member_interest.fetch_news_for_active_members",
            "schedule": crontab(hour="*/12", minute=30),
        },
        "calculate-member-trend-scores": {
            "task": "app.workers.member_interest.calculate_all_member_trend_scores",
            "schedule": crontab(hour=3, minute=0),
        },
        "sync-members": {
            "task": "app.workers.congress_poller.sync_members",
            "schedule": crontab(hour=1, minute=0),  # 1 AM UTC daily — refreshes chamber/district/contact info
        },
        "fetch-actions-active-bills": {
            "task": "app.workers.congress_poller.fetch_actions_for_active_bills",
            "schedule": crontab(hour=4, minute=0),  # 4 AM UTC, after trend + member scoring
        },
        "fetch-votes-for-stanced-bills": {
            "task": "app.workers.vote_fetcher.fetch_votes_for_stanced_bills",
            "schedule": crontab(hour=4, minute=30),  # 4:30 AM UTC daily
        },
        "dispatch-notifications": {
            "task": "app.workers.notification_dispatcher.dispatch_notifications",
            "schedule": crontab(minute="*/5"),  # Every 5 minutes
        },
        "send-notification-digest": {
            "task": "app.workers.notification_dispatcher.send_notification_digest",
            "schedule": crontab(hour=8, minute=0),  # 8 AM UTC daily
        },
        "send-weekly-digest": {
            "task": "app.workers.notification_dispatcher.send_weekly_digest",
            "schedule": crontab(hour=8, minute=30, day_of_week=1),  # Monday 8:30 AM UTC
        },
        "poll-llm-batch-results": {
            "task": "app.workers.llm_batch_processor.poll_llm_batch_results",
            "schedule": crontab(minute="*/30"),
        },
        "calculate-effectiveness-scores": {
            "task": "app.workers.bill_classifier.calculate_effectiveness_scores",
            "schedule": crontab(hour=5, minute=0),  # 5 AM UTC, after all other nightly tasks
        },
    },
)
|
||||
480
backend/app/workers/congress_poller.py
Normal file
480
backend/app/workers/congress_poller.py
Normal file
@@ -0,0 +1,480 @@
|
||||
"""
|
||||
Congress.gov poller — incremental bill and member sync.
|
||||
|
||||
Runs on Celery Beat schedule (every 30 min by default).
|
||||
Uses fromDateTime to fetch only recently updated bills.
|
||||
All operations are idempotent.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from app.database import get_sync_db
|
||||
from app.models import Bill, BillAction, Member, AppSetting
|
||||
from app.services import congress_api
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_setting(db, key: str, default=None) -> str | None:
    """Return the AppSetting value stored under *key*, or *default* if absent."""
    setting = db.get(AppSetting, key)
    if setting is None:
        return default
    return setting.value
|
||||
|
||||
|
||||
def _set_setting(db, key: str, value: str) -> None:
    """Upsert the AppSetting row for *key* and commit immediately."""
    setting = db.get(AppSetting, key)
    if setting is None:
        db.add(AppSetting(key=key, value=value))
    else:
        setting.value = value
    db.commit()
|
||||
|
||||
|
||||
# Only track legislation that can become law. Simple/concurrent resolutions
# (hres, sres, hconres, sconres) are procedural and not worth analyzing.
TRACKED_BILL_TYPES = {"hr", "s", "hjres", "sjres"}

# Action categories that produce new bill text versions on GovInfo.
# Procedural/administrative actions (referral to committee, calendar placement)
# rarely produce a new text version, so we skip document fetching for them.
# NOTE(review): these names presumably match what
# notification_utils.categorize_action emits — verify if that function changes.
_DOC_PRODUCING_CATEGORIES = {"vote", "committee_report", "presidential", "new_document", "new_amendment"}
|
||||
|
||||
|
||||
def _is_congress_off_hours() -> bool:
|
||||
"""Return True during periods when Congress.gov is unlikely to publish new content."""
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
now_est = datetime.now(ZoneInfo("America/New_York"))
|
||||
except Exception:
|
||||
return False
|
||||
# Weekends
|
||||
if now_est.weekday() >= 5:
|
||||
return True
|
||||
# Nights: before 9 AM or after 9 PM EST
|
||||
if now_est.hour < 9 or now_est.hour >= 21:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@celery_app.task(bind=True, max_retries=3, name="app.workers.congress_poller.poll_congress_bills")
def poll_congress_bills(self):
    """Fetch recently updated bills from Congress.gov and enqueue document + LLM processing.

    Incremental: only bills updated since the stored congress_last_polled_at
    setting are requested (a first run seeds from 60 days back). New bills
    fan out to document, action, sponsor, and cosponsor fetch tasks; existing
    bills are diffed via _update_bill_if_changed.
    """
    db = get_sync_db()
    try:
        last_polled = _get_setting(db, "congress_last_polled_at")

        # Adaptive: skip off-hours polls if last poll was recent (< 1 hour ago)
        if _is_congress_off_hours() and last_polled:
            try:
                last_dt = datetime.fromisoformat(last_polled.replace("Z", "+00:00"))
                if (datetime.now(timezone.utc) - last_dt) < timedelta(hours=1):
                    logger.info("Skipping poll — off-hours and last poll < 1 hour ago")
                    return {"new": 0, "updated": 0, "skipped": "off_hours"}
            except Exception:
                # Malformed stored timestamp — fall through and poll anyway.
                pass
        # On first run, seed from 2 months back rather than the full congress history
        if not last_polled:
            two_months_ago = datetime.now(timezone.utc) - timedelta(days=60)
            last_polled = two_months_ago.strftime("%Y-%m-%dT%H:%M:%SZ")
        current_congress = congress_api.get_current_congress()
        logger.info(f"Polling Congress {current_congress} (since {last_polled})")

        new_count = 0
        updated_count = 0
        offset = 0

        while True:
            response = congress_api.get_bills(
                congress=current_congress,
                offset=offset,
                limit=250,
                from_date_time=last_polled,
            )
            bills_data = response.get("bills", [])
            if not bills_data:
                break

            for bill_data in bills_data:
                parsed = congress_api.parse_bill_from_api(bill_data, current_congress)
                # Ignore bill types that cannot become law (see TRACKED_BILL_TYPES).
                if parsed.get("bill_type") not in TRACKED_BILL_TYPES:
                    continue
                bill_id = parsed["bill_id"]
                existing = db.get(Bill, bill_id)

                if existing is None:
                    # Save bill immediately; fetch sponsor detail asynchronously
                    parsed["sponsor_id"] = None
                    parsed["last_checked_at"] = datetime.now(timezone.utc)
                    db.add(Bill(**parsed))
                    db.commit()
                    new_count += 1
                    # Enqueue document, action, sponsor, and cosponsor fetches
                    from app.workers.document_fetcher import fetch_bill_documents
                    fetch_bill_documents.delay(bill_id)
                    fetch_bill_actions.delay(bill_id)
                    fetch_sponsor_for_bill.delay(
                        bill_id, current_congress, parsed["bill_type"], parsed["bill_number"]
                    )
                    from app.workers.bill_classifier import fetch_bill_cosponsors
                    fetch_bill_cosponsors.delay(bill_id)
                else:
                    _update_bill_if_changed(db, existing, parsed)
                    updated_count += 1

            db.commit()
            offset += 250
            # A short page means this was the final page of results.
            if len(bills_data) < 250:
                break

        # Update last polled timestamp
        _set_setting(db, "congress_last_polled_at", datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"))
        logger.info(f"Poll complete: {new_count} new, {updated_count} updated")
        return {"new": new_count, "updated": updated_count}

    except Exception as exc:
        db.rollback()
        logger.error(f"Poll failed: {exc}")
        raise self.retry(exc=exc, countdown=60)
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, max_retries=3, name="app.workers.congress_poller.sync_members")
def sync_members(self):
    """Sync current Congress members.

    Pages through the Congress.gov member list (250 per page) and upserts
    each into the members table; existing rows have every parsed field
    overwritten so stale values are refreshed.
    """
    db = get_sync_db()
    try:
        offset = 0
        synced = 0
        while True:
            response = congress_api.get_members(offset=offset, limit=250, current_member=True)
            members_data = response.get("members", [])
            if not members_data:
                break

            for member_data in members_data:
                parsed = congress_api.parse_member_from_api(member_data)
                # bioguide_id is the primary key — skip rows without one.
                if not parsed.get("bioguide_id"):
                    continue
                existing = db.get(Member, parsed["bioguide_id"])
                if existing is None:
                    db.add(Member(**parsed))
                else:
                    for k, v in parsed.items():
                        setattr(existing, k, v)
                synced += 1

            db.commit()
            offset += 250
            if len(members_data) < 250:
                break

        logger.info(f"Synced {synced} members")
        return {"synced": synced}
    except Exception as exc:
        db.rollback()
        raise self.retry(exc=exc, countdown=120)
    finally:
        db.close()
|
||||
|
||||
|
||||
def _sync_sponsor(db, bill_data: dict) -> str | None:
    """Ensure the bill sponsor exists in the members table. Returns bioguide_id or None.

    Creates a minimal Member row (name/party/state only) when the sponsor has
    not been ingested yet; the first entry in the API's sponsors list is used.
    """
    sponsors = bill_data.get("sponsors", [])
    if not sponsors:
        return None
    sponsor_raw = sponsors[0]
    bioguide_id = sponsor_raw.get("bioguideId")
    if not bioguide_id:
        return None
    existing = db.get(Member, bioguide_id)
    if existing is None:
        db.add(Member(
            bioguide_id=bioguide_id,
            name=sponsor_raw.get("fullName", ""),
            first_name=sponsor_raw.get("firstName"),
            last_name=sponsor_raw.get("lastName"),
            # Truncated to 10 chars — presumably the column width; confirm in the model.
            party=sponsor_raw.get("party", "")[:10] if sponsor_raw.get("party") else None,
            state=sponsor_raw.get("state"),
        ))
        db.commit()
    return bioguide_id
|
||||
|
||||
|
||||
@celery_app.task(bind=True, max_retries=3, name="app.workers.congress_poller.fetch_sponsor_for_bill")
def fetch_sponsor_for_bill(self, bill_id: str, congress: int, bill_type: str, bill_number: str):
    """Async sponsor fetch: get bill detail from Congress.gov and link the sponsor. Idempotent.

    Skips bills that are missing or already have a sponsor_id set; otherwise
    fetches the bill detail, upserts the sponsor via _sync_sponsor, and links it.
    """
    db = get_sync_db()
    try:
        bill = db.get(Bill, bill_id)
        if not bill:
            return {"status": "not_found"}
        # Idempotency guard: a previous run already linked the sponsor.
        if bill.sponsor_id:
            return {"status": "already_set", "sponsor_id": bill.sponsor_id}
        detail = congress_api.get_bill_detail(congress, bill_type, bill_number)
        sponsor_id = _sync_sponsor(db, detail.get("bill", {}))
        if sponsor_id:
            bill.sponsor_id = sponsor_id
            db.commit()
        # sponsor_id may be None when the API listed no usable sponsor.
        return {"status": "ok", "sponsor_id": sponsor_id}
    except Exception as exc:
        db.rollback()
        raise self.retry(exc=exc, countdown=60)
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.congress_poller.backfill_sponsor_ids")
def backfill_sponsor_ids(self):
    """Backfill sponsor_id for all bills where it is NULL by fetching bill detail from Congress.gov.

    One-shot maintenance task. Commits after each successful link so progress
    survives interruption; per-bill failures are logged and skipped.
    (The redundant function-local ``import time`` was removed — ``time`` is
    already imported at module level.)
    """
    db = get_sync_db()
    try:
        bills = db.query(Bill).filter(Bill.sponsor_id.is_(None)).all()
        total = len(bills)
        updated = 0
        logger.info(f"Backfilling sponsors for {total} bills")
        for bill in bills:
            try:
                detail = congress_api.get_bill_detail(bill.congress_number, bill.bill_type, bill.bill_number)
                sponsor_id = _sync_sponsor(db, detail.get("bill", {}))
                if sponsor_id:
                    bill.sponsor_id = sponsor_id
                    db.commit()
                    updated += 1
            except Exception as e:
                logger.warning(f"Could not backfill sponsor for {bill.bill_id}: {e}")
            time.sleep(0.1)  # ~10 req/sec, well under Congress.gov 5000/hr limit
        logger.info(f"Sponsor backfill complete: {updated}/{total} updated")
        return {"total": total, "updated": updated}
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, max_retries=3, name="app.workers.congress_poller.fetch_bill_actions")
def fetch_bill_actions(self, bill_id: str):
    """Fetch and sync all actions for a bill from Congress.gov. Idempotent.

    Uses a PostgreSQL upsert (ON CONFLICT DO NOTHING against the
    uq_bill_actions_bill_date_text constraint) so repeated runs never
    duplicate action rows; stamps actions_fetched_at on completion.
    """
    db = get_sync_db()
    try:
        bill = db.get(Bill, bill_id)
        if not bill:
            logger.warning(f"fetch_bill_actions: bill {bill_id} not found")
            return

        offset = 0
        inserted = 0
        while True:
            try:
                response = congress_api.get_bill_actions(
                    bill.congress_number, bill.bill_type, bill.bill_number, offset=offset
                )
            except Exception as exc:
                # API failure — retry the whole task with backoff.
                raise self.retry(exc=exc, countdown=60)

            actions_data = response.get("actions", [])
            if not actions_data:
                break

            for action in actions_data:
                # Insert-or-skip; rowcount is 0 when the row already existed.
                stmt = pg_insert(BillAction.__table__).values(
                    bill_id=bill_id,
                    action_date=action.get("actionDate"),
                    action_text=action.get("text", ""),
                    action_type=action.get("type"),
                    chamber=action.get("chamber"),
                ).on_conflict_do_nothing(constraint="uq_bill_actions_bill_date_text")
                result = db.execute(stmt)
                inserted += result.rowcount

            db.commit()
            offset += 250
            if len(actions_data) < 250:
                break

        bill.actions_fetched_at = datetime.now(timezone.utc)
        db.commit()
        logger.info(f"fetch_bill_actions: {bill_id} — inserted {inserted} new actions")
        return {"bill_id": bill_id, "inserted": inserted}
    except Exception as exc:
        # Note: self.retry's Retry exception also lands here — roll back
        # the session, then let it propagate.
        db.rollback()
        raise
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.congress_poller.fetch_actions_for_active_bills")
def fetch_actions_for_active_bills(self):
    """Nightly batch: enqueue action fetches for recently active bills missing action data."""
    db = get_sync_db()
    try:
        # Bills active in the last 30 days that have either never had actions
        # fetched, or have acted again since the last fetch.
        cutoff = datetime.now(timezone.utc).date() - timedelta(days=30)
        bills = (
            db.query(Bill)
            .filter(
                Bill.latest_action_date >= cutoff,
                or_(
                    Bill.actions_fetched_at.is_(None),
                    # NOTE(review): compares latest_action_date with
                    # actions_fetched_at — presumably date vs datetime columns;
                    # confirm the comparison semantics against the Bill model.
                    Bill.latest_action_date > Bill.actions_fetched_at,
                ),
            )
            .limit(200)  # cap per night to bound API usage
            .all()
        )
        queued = 0
        for bill in bills:
            fetch_bill_actions.delay(bill.bill_id)
            queued += 1
            time.sleep(0.2)  # ~5 tasks/sec to avoid Redis burst
        logger.info(f"fetch_actions_for_active_bills: queued {queued} bills")
        return {"queued": queued}
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.congress_poller.backfill_all_bill_actions")
def backfill_all_bill_actions(self):
    """One-time backfill: enqueue action fetches for every bill that has never had actions fetched."""
    session = get_sync_db()
    try:
        never_fetched = (
            session.query(Bill)
            .filter(Bill.actions_fetched_at.is_(None))
            .order_by(Bill.latest_action_date.desc())  # most recently active first
            .all()
        )

        task_count = 0
        for pending_bill in never_fetched:
            fetch_bill_actions.delay(pending_bill.bill_id)
            task_count += 1
            # ~20 tasks/sec — workers will self-throttle against Congress.gov
            time.sleep(0.05)

        logger.info(f"backfill_all_bill_actions: queued {task_count} bills")
        return {"queued": task_count}
    finally:
        session.close()
|
||||
|
||||
|
||||
def _update_bill_if_changed(db, existing: Bill, parsed: dict) -> bool:
    """Update bill fields if anything has changed. Returns True if updated.

    `changed` tracks meaningful updates (title/status/action fields) that
    trigger follow-up fetches and notifications; `dirty` additionally covers
    static fields back-filled from NULL, which only require a commit.
    """
    changed = False
    dirty = False

    # Meaningful change fields — trigger document + action fetch when updated
    track_fields = ["title", "short_title", "latest_action_date", "latest_action_text", "status"]
    for field in track_fields:
        new_val = parsed.get(field)
        if new_val and getattr(existing, field) != new_val:
            setattr(existing, field, new_val)
            changed = True
            dirty = True

    # Static fields — only fill in if currently null; no change trigger needed
    fill_null_fields = ["introduced_date", "congress_url", "chamber"]
    for field in fill_null_fields:
        new_val = parsed.get(field)
        if new_val and getattr(existing, field) is None:
            setattr(existing, field, new_val)
            dirty = True

    if changed:
        existing.last_checked_at = datetime.now(timezone.utc)
    if dirty:
        db.commit()
    if changed:
        # Imported at call time — presumably to avoid a circular import at
        # module load; confirm against notification_utils' own imports.
        from app.workers.notification_utils import (
            emit_bill_notification,
            emit_member_follow_notifications,
            emit_topic_follow_notifications,
            categorize_action,
        )
        action_text = parsed.get("latest_action_text", "")
        action_category = categorize_action(action_text)
        # Only fetch new documents for actions that produce new text versions on GovInfo.
        # Skip procedural/administrative actions (referral, calendar) to avoid unnecessary calls.
        if not action_category or action_category in _DOC_PRODUCING_CATEGORIES:
            from app.workers.document_fetcher import fetch_bill_documents
            fetch_bill_documents.delay(existing.bill_id)
        fetch_bill_actions.delay(existing.bill_id)
        if action_category:
            emit_bill_notification(db, existing, "bill_updated", action_text, action_category=action_category)
            emit_member_follow_notifications(db, existing, "bill_updated", action_text, action_category=action_category)
            # Topic followers — pull tags from the bill's latest brief
            from app.models.brief import BillBrief
            latest_brief = (
                db.query(BillBrief)
                .filter_by(bill_id=existing.bill_id)
                .order_by(BillBrief.created_at.desc())
                .first()
            )
            topic_tags = latest_brief.topic_tags or [] if latest_brief else []
            emit_topic_follow_notifications(
                db, existing, "bill_updated", action_text, topic_tags, action_category=action_category
            )
    return changed
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.congress_poller.backfill_bill_metadata")
def backfill_bill_metadata(self):
    """
    Find bills with null introduced_date (or other static fields) and
    re-fetch their detail from Congress.gov to fill in the missing values.
    No document or LLM calls — metadata only.
    """
    db = get_sync_db()
    try:
        # Aliased so the local name doesn't shadow anything else in scope.
        from sqlalchemy import text as sa_text
        rows = db.execute(sa_text("""
            SELECT bill_id, congress_number, bill_type, bill_number
            FROM bills
            WHERE introduced_date IS NULL
               OR congress_url IS NULL
               OR chamber IS NULL
        """)).fetchall()

        updated = 0
        skipped = 0
        for row in rows:
            try:
                detail = congress_api.get_bill_detail(
                    row.congress_number, row.bill_type, row.bill_number
                )
                bill_data = detail.get("bill", {})
                # Re-shape the detail payload into the list-endpoint format
                # that parse_bill_from_api expects.
                parsed = congress_api.parse_bill_from_api(
                    {
                        "type": row.bill_type,
                        "number": row.bill_number,
                        "introducedDate": bill_data.get("introducedDate"),
                        "title": bill_data.get("title"),
                        "shortTitle": bill_data.get("shortTitle"),
                        "latestAction": bill_data.get("latestAction") or {},
                    },
                    row.congress_number,
                )
                bill = db.get(Bill, row.bill_id)
                if not bill:
                    skipped += 1
                    continue
                # Fill-from-NULL only: never overwrite an existing value.
                fill_null_fields = ["introduced_date", "congress_url", "chamber", "title", "short_title"]
                dirty = False
                for field in fill_null_fields:
                    new_val = parsed.get(field)
                    if new_val and getattr(bill, field) is None:
                        setattr(bill, field, new_val)
                        dirty = True
                if dirty:
                    db.commit()
                    updated += 1
                else:
                    skipped += 1
                time.sleep(0.2)  # ~300 req/min — well under the 5k/hr limit
            except Exception as exc:
                logger.warning(f"backfill_bill_metadata: failed for {row.bill_id}: {exc}")
                skipped += 1

        logger.info(f"backfill_bill_metadata: {updated} updated, {skipped} skipped")
        return {"updated": updated, "skipped": skipped}
    finally:
        db.close()
|
||||
92
backend/app/workers/document_fetcher.py
Normal file
92
backend/app/workers/document_fetcher.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
Document fetcher — retrieves bill text from GovInfo and stores it.
|
||||
Triggered by congress_poller when a new bill is detected.
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.database import get_sync_db
|
||||
from app.models import Bill, BillDocument
|
||||
from app.services import congress_api, govinfo_api
|
||||
from app.services.govinfo_api import DocumentUnchangedError
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, max_retries=3, name="app.workers.document_fetcher.fetch_bill_documents")
def fetch_bill_documents(self, bill_id: str):
    """Fetch bill text from GovInfo and store it. Then enqueue LLM processing.

    Steps: look up the bill, list its text versions via Congress.gov, pick the
    best downloadable format/URL, download the text, persist a BillDocument
    row, and queue the synchronous LLM brief task for the new document.

    Returns a small status dict describing the outcome. Unexpected failures
    roll back the session and retry the whole task (max 3 tries, 120 s apart).
    """
    db = get_sync_db()
    try:
        bill = db.get(Bill, bill_id)
        if not bill:
            logger.warning(f"Bill {bill_id} not found in DB")
            return {"status": "not_found"}

        # Get text versions from Congress.gov
        try:
            text_response = congress_api.get_bill_text_versions(
                bill.congress_number, bill.bill_type, bill.bill_number
            )
        except Exception as e:
            # An API error here means there is nothing to fetch yet, so this
            # exits without retrying rather than raising into the retry path.
            logger.warning(f"No text versions for {bill_id}: {e}")
            return {"status": "no_text_versions"}

        text_versions = text_response.get("textVersions", [])
        if not text_versions:
            return {"status": "no_text_versions"}

        url, fmt = govinfo_api.find_best_text_url(text_versions)
        if not url:
            return {"status": "no_suitable_format"}

        # Idempotency: skip if we already have this exact document version
        # (same bill + same GovInfo URL, with text actually stored).
        existing = (
            db.query(BillDocument)
            .filter_by(bill_id=bill_id, govinfo_url=url)
            .filter(BillDocument.raw_text.isnot(None))
            .first()
        )
        if existing:
            return {"status": "already_fetched", "bill_id": bill_id}

        logger.info(f"Fetching {bill_id} document ({fmt}) from {url}")
        try:
            raw_text = govinfo_api.fetch_text_from_url(url, fmt)
        except DocumentUnchangedError:
            # The log message indicates this is an ETag-based conditional
            # fetch: the remote document matches what we already have.
            logger.info(f"Document unchanged for {bill_id} (ETag match) — skipping")
            return {"status": "unchanged", "bill_id": bill_id}
        if not raw_text:
            # Raising (instead of returning) routes this through the
            # rollback-and-retry handler below.
            raise ValueError(f"Empty text returned for {bill_id}")

        # Get version label from first text version
        # NOTE(review): "type" appears to be either a dict with a "name" key
        # or a plain string depending on the API response shape — both are
        # handled here; confirm against the Congress.gov payload.
        type_obj = text_versions[0].get("type", {}) if text_versions else {}
        doc_version = type_obj.get("name") if isinstance(type_obj, dict) else type_obj

        doc = BillDocument(
            bill_id=bill_id,
            doc_type="bill_text",
            doc_version=doc_version,
            govinfo_url=url,
            raw_text=raw_text,
            fetched_at=datetime.now(timezone.utc),
        )
        db.add(doc)
        db.commit()
        db.refresh(doc)  # populate the generated primary key before logging/queuing

        logger.info(f"Stored document {doc.id} for bill {bill_id} ({len(raw_text):,} chars)")

        # Enqueue LLM processing (imported here to avoid a circular import
        # between the fetcher and processor worker modules).
        from app.workers.llm_processor import process_document_with_llm
        process_document_with_llm.delay(doc.id)

        return {"status": "ok", "document_id": doc.id, "chars": len(raw_text)}

    except Exception as exc:
        db.rollback()
        logger.error(f"Document fetch failed for {bill_id}: {exc}")
        raise self.retry(exc=exc, countdown=120)
    finally:
        db.close()
|
||||
401
backend/app/workers/llm_batch_processor.py
Normal file
401
backend/app/workers/llm_batch_processor.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
LLM Batch processor — submits and polls OpenAI/Anthropic Batch API jobs.
|
||||
50% cheaper than synchronous calls; 24-hour processing window.
|
||||
New bills still use the synchronous llm_processor task.
|
||||
"""
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_sync_db
|
||||
from app.models import Bill, BillBrief, BillDocument, Member
|
||||
from app.models.setting import AppSetting
|
||||
from app.services.llm_service import (
|
||||
AMENDMENT_SYSTEM_PROMPT,
|
||||
MAX_TOKENS_DEFAULT,
|
||||
SYSTEM_PROMPT,
|
||||
build_amendment_prompt,
|
||||
build_prompt,
|
||||
parse_brief_json,
|
||||
)
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BATCH_SETTING_KEY = "llm_active_batch"
|
||||
|
||||
|
||||
# ── State helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _save_batch_state(db, state: dict):
    """Upsert the active-batch state blob into the app-settings table."""
    serialized = json.dumps(state)
    existing = db.get(AppSetting, _BATCH_SETTING_KEY)
    if existing is None:
        db.add(AppSetting(key=_BATCH_SETTING_KEY, value=serialized))
    else:
        existing.value = serialized
    db.commit()
|
||||
|
||||
|
||||
def _clear_batch_state(db):
    """Delete the persisted batch-state row, if one exists."""
    existing = db.get(AppSetting, _BATCH_SETTING_KEY)
    if existing is None:
        return
    db.delete(existing)
    db.commit()
|
||||
|
||||
|
||||
# ── Request builder ────────────────────────────────────────────────────────────
|
||||
|
||||
def _build_request_data(db, doc_id: int, bill_id: str) -> tuple[str, str, str]:
    """Returns (custom_id, system_prompt, user_prompt) for a document.

    If the bill already has a "full" brief generated from an earlier document
    version — and that version's text is still stored — an amendment-diff
    prompt is built; otherwise a fresh full-brief prompt is built.

    The custom_id encodes both the document id and brief type so the poll
    task can route each batch result (see _parse_custom_id).

    Raises:
        ValueError: if the document is missing/empty or the bill is missing.
    """
    doc = db.get(BillDocument, doc_id)
    if not doc or not doc.raw_text:
        raise ValueError(f"Document {doc_id} missing or has no text")

    bill = db.get(Bill, bill_id)
    if not bill:
        raise ValueError(f"Bill {bill_id} not found")

    sponsor = db.get(Member, bill.sponsor_id) if bill.sponsor_id else None

    # Prompt context; "Unknown"/"None" placeholders keep the prompt
    # well-formed when metadata is missing.
    bill_metadata = {
        "title": bill.title or "Unknown Title",
        "sponsor_name": sponsor.name if sponsor else "Unknown",
        "party": sponsor.party if sponsor else "Unknown",
        "state": sponsor.state if sponsor else "Unknown",
        "chamber": bill.chamber or "Unknown",
        "introduced_date": str(bill.introduced_date) if bill.introduced_date else "Unknown",
        "latest_action_text": bill.latest_action_text or "None",
        "latest_action_date": str(bill.latest_action_date) if bill.latest_action_date else "Unknown",
    }

    previous_full_brief = (
        db.query(BillBrief)
        .filter_by(bill_id=bill_id, brief_type="full")
        .order_by(BillBrief.created_at.desc())
        .first()
    )

    # Resolve the previous version's document up front so the amendment/full
    # decision collapses to a single branch (the original duplicated the
    # full-brief branch in two places).
    previous_doc = None
    if previous_full_brief and previous_full_brief.document_id:
        previous_doc = db.get(BillDocument, previous_full_brief.document_id)

    if previous_doc and previous_doc.raw_text:
        brief_type = "amendment"
        prompt = build_amendment_prompt(doc.raw_text, previous_doc.raw_text, bill_metadata, MAX_TOKENS_DEFAULT)
        # JSON-only instruction appended explicitly for the batch path.
        system_prompt = AMENDMENT_SYSTEM_PROMPT + "\n\nIMPORTANT: Respond with ONLY valid JSON. No other text."
    else:
        brief_type = "full"
        prompt = build_prompt(doc.raw_text, bill_metadata, MAX_TOKENS_DEFAULT)
        system_prompt = SYSTEM_PROMPT

    custom_id = f"doc-{doc_id}-{brief_type}"
    return custom_id, system_prompt, prompt
|
||||
|
||||
|
||||
# ── Submit task ────────────────────────────────────────────────────────────────
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.llm_batch_processor.submit_llm_batch")
def submit_llm_batch(self):
    """Submit all unbriefed documents to the OpenAI or Anthropic Batch API.

    Provider/model come from AppSetting runtime overrides, falling back to
    static settings. Only one batch may be in flight at a time: if persisted
    state says a batch is still processing, this is a no-op. Up to 1000
    documents with raw text but no brief are submitted per run.

    The batch id and submitted document ids are persisted via
    _save_batch_state so poll_llm_batch_results can import results later.
    Returns a status dict describing the outcome.
    """
    # Module-level import only brings in `datetime`; `timezone` is needed for
    # an aware UTC timestamp (datetime.utcnow() is naive and deprecated).
    from datetime import timezone

    db = get_sync_db()
    try:
        prov_row = db.get(AppSetting, "llm_provider")
        model_row = db.get(AppSetting, "llm_model")
        provider_name = ((prov_row.value if prov_row else None) or settings.LLM_PROVIDER).lower()

        if provider_name not in ("openai", "anthropic"):
            return {"status": "unsupported", "provider": provider_name}

        # Check for already-active batch
        active_row = db.get(AppSetting, _BATCH_SETTING_KEY)
        if active_row:
            try:
                active = json.loads(active_row.value)
                if active.get("status") == "processing":
                    return {"status": "already_active", "batch_id": active.get("batch_id")}
            except Exception:
                # Corrupt state blob: treat as no active batch and resubmit.
                pass

        # Find docs with text but no brief
        rows = db.execute(text("""
            SELECT bd.id AS doc_id, bd.bill_id, bd.govinfo_url
            FROM bill_documents bd
            LEFT JOIN bill_briefs bb ON bb.document_id = bd.id
            WHERE bd.raw_text IS NOT NULL AND bb.id IS NULL
            LIMIT 1000
        """)).fetchall()

        if not rows:
            return {"status": "nothing_to_process"}

        doc_ids = [r.doc_id for r in rows]

        if provider_name == "openai":
            model = (model_row.value if model_row else None) or settings.OPENAI_MODEL
            batch_id = _submit_openai_batch(db, rows, model)
        else:
            model = (model_row.value if model_row else None) or settings.ANTHROPIC_MODEL
            batch_id = _submit_anthropic_batch(db, rows, model)

        state = {
            "batch_id": batch_id,
            "provider": provider_name,
            "model": model,
            "doc_ids": doc_ids,
            "doc_count": len(doc_ids),
            # Timezone-aware UTC, consistent with document timestamps stored
            # elsewhere in this package (document_fetcher uses now(timezone.utc)).
            "submitted_at": datetime.now(timezone.utc).isoformat(),
            "status": "processing",
        }
        _save_batch_state(db, state)
        logger.info(f"Submitted {len(doc_ids)}-doc batch to {provider_name}: {batch_id}")
        return {"status": "submitted", "batch_id": batch_id, "doc_count": len(doc_ids)}

    finally:
        db.close()
|
||||
|
||||
|
||||
def _submit_openai_batch(db, rows, model: str) -> str:
    """Build a JSONL batch file and submit it to the OpenAI Batch API.

    Documents whose request data cannot be built are logged and skipped.
    Returns the provider-assigned batch id.
    """
    from openai import OpenAI
    client = OpenAI(api_key=settings.OPENAI_API_KEY)

    request_lines = []
    for row in rows:
        try:
            custom_id, system_prompt, prompt = _build_request_data(db, row.doc_id, row.bill_id)
        except Exception as exc:
            logger.warning(f"Skipping doc {row.doc_id}: {exc}")
            continue
        body = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            "response_format": {"type": "json_object"},
            "temperature": 0.1,
            "max_tokens": MAX_TOKENS_DEFAULT,
        }
        request_lines.append(json.dumps({
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": body,
        }))

    # Upload the JSONL payload, then create the batch referencing it.
    upload = client.files.create(
        file=("batch.jsonl", io.BytesIO("\n".join(request_lines).encode()), "application/jsonl"),
        purpose="batch",
    )
    batch = client.batches.create(
        input_file_id=upload.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    return batch.id
|
||||
|
||||
|
||||
def _submit_anthropic_batch(db, rows, model: str) -> str:
    """Submit one Message Batches request per document to Anthropic.

    Documents whose request data cannot be built are logged and skipped.
    Returns the provider-assigned batch id.
    """
    import anthropic
    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)

    batch_requests = []
    for row in rows:
        try:
            custom_id, system_prompt, prompt = _build_request_data(db, row.doc_id, row.bill_id)
        except Exception as exc:
            logger.warning(f"Skipping doc {row.doc_id}: {exc}")
            continue
        # NOTE(review): max_tokens is hardcoded to 4096 here while the OpenAI
        # path uses MAX_TOKENS_DEFAULT — confirm the asymmetry is intentional.
        params = {
            "model": model,
            "max_tokens": 4096,
            "system": [{"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}],
            "messages": [{"role": "user", "content": prompt}],
        }
        batch_requests.append({"custom_id": custom_id, "params": params})

    return client.messages.batches.create(requests=batch_requests).id
|
||||
|
||||
|
||||
# ── Poll task ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.llm_batch_processor.poll_llm_batch_results")
def poll_llm_batch_results(self):
    """Check active batch status and import completed results (runs every 30 min via beat)."""
    db = get_sync_db()
    try:
        stored = db.get(AppSetting, _BATCH_SETTING_KEY)
        if not stored:
            return {"status": "no_active_batch"}

        try:
            state = json.loads(stored.value)
        except Exception:
            # Unparseable state blob: drop it so a new batch can be submitted.
            _clear_batch_state(db)
            return {"status": "invalid_state"}

        batch_id = state["batch_id"]
        provider_name = state["provider"]
        model = state["model"]

        # Dispatch to the provider-specific poller.
        pollers = {"openai": _poll_openai, "anthropic": _poll_anthropic}
        poller = pollers.get(provider_name)
        if poller is None:
            _clear_batch_state(db)
            return {"status": "unknown_provider"}
        return poller(db, state, batch_id, model)

    finally:
        db.close()
|
||||
|
||||
|
||||
# ── Result processing helpers ──────────────────────────────────────────────────
|
||||
|
||||
def _save_brief(db, doc_id: int, bill_id: str, brief, brief_type: str, govinfo_url) -> bool:
    """Idempotency check + save. Returns True if saved, False if already exists."""
    already_present = db.query(BillBrief).filter_by(document_id=doc_id).first()
    if already_present:
        return False

    record = BillBrief(
        bill_id=bill_id,
        document_id=doc_id,
        brief_type=brief_type,
        summary=brief.summary,
        key_points=brief.key_points,
        risks=brief.risks,
        deadlines=brief.deadlines,
        topic_tags=brief.topic_tags,
        llm_provider=brief.llm_provider,
        llm_model=brief.llm_model,
        govinfo_url=govinfo_url,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return True
|
||||
|
||||
|
||||
def _emit_notifications_and_news(db, bill_id: str, brief, brief_type: str):
    """Fan out follower notifications for a saved brief, then queue a news fetch."""
    bill = db.get(Bill, bill_id)
    if bill is None:
        return

    from app.workers.notification_utils import (
        emit_bill_notification,
        emit_member_follow_notifications,
        emit_topic_follow_notifications,
    )

    event_type = "new_amendment" if brief_type == "amendment" else "new_document"
    emit_bill_notification(db, bill, event_type, brief.summary)
    emit_member_follow_notifications(db, bill, event_type, brief.summary)
    tags = brief.topic_tags or []
    emit_topic_follow_notifications(db, bill, event_type, brief.summary, tags)

    from app.workers.news_fetcher import fetch_news_for_bill
    fetch_news_for_bill.delay(bill_id)
|
||||
|
||||
|
||||
def _parse_custom_id(custom_id: str) -> tuple[int, str]:
    """Parse a batch custom id of the form 'doc-{doc_id}-{brief_type}'.

    Uses maxsplit so a brief_type that ever contains a hyphen is returned
    intact (a plain split('-') would silently truncate it to its first
    segment).

    Raises:
        ValueError: if the id does not have three '-'-separated parts or the
            doc id segment is not an integer.
    """
    parts = custom_id.split("-", 2)
    if len(parts) != 3:
        raise ValueError(f"Malformed custom_id: {custom_id!r}")
    _, doc_id_str, brief_type = parts
    return int(doc_id_str), brief_type
|
||||
|
||||
|
||||
def _poll_openai(db, state: dict, batch_id: str, model: str) -> dict:
    """Poll an OpenAI batch; when complete, import each JSONL result as a brief.

    Terminal failure states ("failed"/"cancelled"/"expired") clear the
    persisted batch state so a new batch can be submitted; an in-flight batch
    leaves state untouched for the next poll cycle.
    """
    from openai import OpenAI
    client = OpenAI(api_key=settings.OPENAI_API_KEY)

    batch = client.batches.retrieve(batch_id)
    logger.info(f"OpenAI batch {batch_id} status: {batch.status}")

    if batch.status in ("failed", "cancelled", "expired"):
        _clear_batch_state(db)
        return {"status": batch.status}

    if batch.status != "completed":
        return {"status": "processing", "batch_status": batch.status}

    # Completed: the output file is JSONL — one result object per line.
    content = client.files.content(batch.output_file_id).read().decode()
    saved = failed = 0

    for line in content.strip().split("\n"):
        if not line.strip():
            continue
        try:
            item = json.loads(line)
            custom_id = item["custom_id"]
            doc_id, brief_type = _parse_custom_id(custom_id)

            # Per-request errors are reported inline in the result object.
            if item.get("error"):
                logger.warning(f"Batch result error for {custom_id}: {item['error']}")
                failed += 1
                continue

            raw = item["response"]["body"]["choices"][0]["message"]["content"]
            brief = parse_brief_json(raw, "openai", model)

            doc = db.get(BillDocument, doc_id)
            if not doc:
                failed += 1
                continue

            # _save_brief is idempotent (False if a brief already exists for
            # the document), so notifications fire only on first save.
            if _save_brief(db, doc_id, doc.bill_id, brief, brief_type, doc.govinfo_url):
                _emit_notifications_and_news(db, doc.bill_id, brief, brief_type)
                saved += 1
        except Exception as exc:
            # One bad result line must not abort the rest of the batch.
            logger.warning(f"Failed to process OpenAI batch result line: {exc}")
            failed += 1

    _clear_batch_state(db)
    logger.info(f"OpenAI batch {batch_id} complete: {saved} saved, {failed} failed")
    return {"status": "completed", "saved": saved, "failed": failed}
|
||||
|
||||
|
||||
def _poll_anthropic(db, state: dict, batch_id: str, model: str) -> dict:
    """Poll an Anthropic message batch; when ended, import its results as briefs.

    NOTE(review): unlike the OpenAI poller there is no terminal-failure branch
    here — presumably Anthropic batches always reach "ended" with per-result
    error types; confirm against the Message Batches API.
    """
    import anthropic
    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)

    batch = client.messages.batches.retrieve(batch_id)
    logger.info(f"Anthropic batch {batch_id} processing_status: {batch.processing_status}")

    if batch.processing_status != "ended":
        return {"status": "processing", "batch_status": batch.processing_status}

    saved = failed = 0

    # results() yields one entry per submitted request.
    for result in client.messages.batches.results(batch_id):
        try:
            custom_id = result.custom_id
            doc_id, brief_type = _parse_custom_id(custom_id)

            # Anything other than "succeeded" (errored, canceled, expired)
            # counts as a failure for this document.
            if result.result.type != "succeeded":
                logger.warning(f"Batch result {custom_id} type: {result.result.type}")
                failed += 1
                continue

            raw = result.result.message.content[0].text
            brief = parse_brief_json(raw, "anthropic", model)

            doc = db.get(BillDocument, doc_id)
            if not doc:
                failed += 1
                continue

            # _save_brief is idempotent; notifications fire only on first save.
            if _save_brief(db, doc_id, doc.bill_id, brief, brief_type, doc.govinfo_url):
                _emit_notifications_and_news(db, doc.bill_id, brief, brief_type)
                saved += 1
        except Exception as exc:
            # One bad result must not abort the rest of the batch.
            logger.warning(f"Failed to process Anthropic batch result: {exc}")
            failed += 1

    _clear_batch_state(db)
    logger.info(f"Anthropic batch {batch_id} complete: {saved} saved, {failed} failed")
    return {"status": "completed", "saved": saved, "failed": failed}
|
||||
380
backend/app/workers/llm_processor.py
Normal file
380
backend/app/workers/llm_processor.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
LLM processor — generates AI briefs for fetched bill documents.
|
||||
Triggered by document_fetcher after successful text retrieval.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_sync_db
|
||||
from app.models import Bill, BillBrief, BillDocument, Member
|
||||
from app.services.llm_service import RateLimitError, get_llm_provider
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(
    bind=True,
    max_retries=8,
    rate_limit=f"{settings.LLM_RATE_LIMIT_RPM}/m",
    name="app.workers.llm_processor.process_document_with_llm",
)
def process_document_with_llm(self, document_id: int):
    """Generate an AI brief for a bill document. Full brief for first version, amendment brief for subsequent versions.

    Queued by document_fetcher after a document's text is stored. After the
    brief is saved it emits follower notifications, queues a news fetch, and
    queues bill-category classification. Rate-limit errors retry after the
    provider-suggested delay; other failures retry after 5 minutes (max 8
    attempts either way).
    """
    db = get_sync_db()
    try:
        # Idempotency: skip if brief already exists for this document
        existing = db.query(BillBrief).filter_by(document_id=document_id).first()
        if existing:
            return {"status": "already_processed", "brief_id": existing.id}

        doc = db.get(BillDocument, document_id)
        if not doc or not doc.raw_text:
            logger.warning(f"Document {document_id} not found or has no text")
            return {"status": "no_document"}

        bill = db.get(Bill, doc.bill_id)
        if not bill:
            return {"status": "no_bill"}

        sponsor = db.get(Member, bill.sponsor_id) if bill.sponsor_id else None

        # Prompt context; "Unknown"/"None" placeholders keep the prompt
        # well-formed when metadata is missing.
        bill_metadata = {
            "title": bill.title or "Unknown Title",
            "sponsor_name": sponsor.name if sponsor else "Unknown",
            "party": sponsor.party if sponsor else "Unknown",
            "state": sponsor.state if sponsor else "Unknown",
            "chamber": bill.chamber or "Unknown",
            "introduced_date": str(bill.introduced_date) if bill.introduced_date else "Unknown",
            "latest_action_text": bill.latest_action_text or "None",
            "latest_action_date": str(bill.latest_action_date) if bill.latest_action_date else "Unknown",
        }

        # Check if a full brief already exists for this bill (from an earlier document version)
        previous_full_brief = (
            db.query(BillBrief)
            .filter_by(bill_id=doc.bill_id, brief_type="full")
            .order_by(BillBrief.created_at.desc())
            .first()
        )

        # Runtime provider/model overrides stored in AppSetting rows take
        # precedence; get_llm_provider receives None to fall back otherwise.
        from app.models.setting import AppSetting
        prov_row = db.get(AppSetting, "llm_provider")
        model_row = db.get(AppSetting, "llm_model")
        provider = get_llm_provider(
            prov_row.value if prov_row else None,
            model_row.value if model_row else None,
        )

        if previous_full_brief and previous_full_brief.document_id:
            # New version of a bill we've already analyzed — generate amendment brief
            previous_doc = db.get(BillDocument, previous_full_brief.document_id)
            if previous_doc and previous_doc.raw_text:
                logger.info(f"Generating amendment brief for document {document_id} (bill {doc.bill_id})")
                brief = provider.generate_amendment_brief(doc.raw_text, previous_doc.raw_text, bill_metadata)
                brief_type = "amendment"
            else:
                # Previous version's text is gone, so a diff is impossible.
                logger.info(f"Previous document unavailable, generating full brief for document {document_id}")
                brief = provider.generate_brief(doc.raw_text, bill_metadata)
                brief_type = "full"
        else:
            logger.info(f"Generating full brief for document {document_id} (bill {doc.bill_id})")
            brief = provider.generate_brief(doc.raw_text, bill_metadata)
            brief_type = "full"

        db_brief = BillBrief(
            bill_id=doc.bill_id,
            document_id=document_id,
            brief_type=brief_type,
            summary=brief.summary,
            key_points=brief.key_points,
            risks=brief.risks,
            deadlines=brief.deadlines,
            topic_tags=brief.topic_tags,
            llm_provider=brief.llm_provider,
            llm_model=brief.llm_model,
            govinfo_url=doc.govinfo_url,
        )
        db.add(db_brief)
        db.commit()
        db.refresh(db_brief)  # populate the generated id for logging/return

        logger.info(f"{brief_type.capitalize()} brief {db_brief.id} created for bill {doc.bill_id} using {brief.llm_provider}/{brief.llm_model}")

        # Emit notification events for bill followers, sponsor followers, and topic followers
        from app.workers.notification_utils import (
            emit_bill_notification,
            emit_member_follow_notifications,
            emit_topic_follow_notifications,
        )
        event_type = "new_amendment" if brief_type == "amendment" else "new_document"
        emit_bill_notification(db, bill, event_type, brief.summary)
        emit_member_follow_notifications(db, bill, event_type, brief.summary)
        emit_topic_follow_notifications(db, bill, event_type, brief.summary, brief.topic_tags or [])

        # Trigger news fetch now that we have topic tags
        from app.workers.news_fetcher import fetch_news_for_bill
        fetch_news_for_bill.delay(doc.bill_id)

        # Classify bill as substantive / commemorative / administrative
        from app.workers.bill_classifier import classify_bill_category
        classify_bill_category.delay(doc.bill_id, document_id)

        return {"status": "ok", "brief_id": db_brief.id, "brief_type": brief_type}

    except RateLimitError as exc:
        db.rollback()
        logger.warning(f"LLM rate limit hit ({exc.provider}); retrying in {exc.retry_after}s")
        raise self.retry(exc=exc, countdown=exc.retry_after)
    except Exception as exc:
        db.rollback()
        logger.error(f"LLM processing failed for document {document_id}: {exc}")
        raise self.retry(exc=exc, countdown=300)  # 5 min backoff for other failures
    finally:
        db.close()
|
||||
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.llm_processor.backfill_brief_citations")
def backfill_brief_citations(self):
    """
    Re-generate pre-citation-era briefs.

    Briefs whose key_points contain plain JSON strings predate citation
    support ({text, citation, quote} objects). Each such brief is deleted and
    its already-stored document text is re-queued for LLM processing.

    No Congress.gov or GovInfo calls — only LLM calls.
    """
    db = get_sync_db()
    try:
        uncited = db.execute(text("""
            SELECT id, document_id, bill_id
            FROM bill_briefs
            WHERE key_points IS NOT NULL
              AND jsonb_array_length(key_points) > 0
              AND jsonb_typeof(key_points->0) = 'string'
        """)).fetchall()

        total = len(uncited)
        queued = 0
        skipped = 0

        for row in uncited:
            # Only briefs whose document text is still stored can be redone;
            # the stale brief is deleted only after that is confirmed.
            doc = db.get(BillDocument, row.document_id) if row.document_id else None
            if doc is None or not doc.raw_text:
                skipped += 1
                continue

            stale = db.get(BillBrief, row.id)
            if stale is not None:
                db.delete(stale)
                db.commit()

            process_document_with_llm.delay(row.document_id)
            queued += 1
            time.sleep(0.1)  # Avoid burst-queuing all LLM tasks at once

        logger.info(
            f"backfill_brief_citations: {total} uncited briefs found, "
            f"{queued} re-queued, {skipped} skipped (no document text)"
        )
        return {"total": total, "queued": queued, "skipped": skipped}
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.llm_processor.backfill_brief_labels")
def backfill_brief_labels(self):
    """
    Add fact/inference labels to existing cited brief points without re-generating them.
    Sends one compact classification call per brief (all unlabeled points batched).
    Skips briefs already fully labeled and plain-string points (no quote to classify).
    """
    import json
    from sqlalchemy.orm.attributes import flag_modified
    from app.models.setting import AppSetting

    db = get_sync_db()
    try:
        # Step 1: Bulk auto-label quoteless unlabeled points as "inference" via raw SQL.
        # This runs before any ORM objects are loaded so the session identity map cannot
        # interfere with the commit (the classic "ORM flush overwrites raw UPDATE" trap).
        # {col} is interpolated only from the fixed tuple below — never user input —
        # and the doubled braces emit the literal jsonb fragment {"label":"inference"}.
        _BULK_AUTO_LABEL = """
            UPDATE bill_briefs SET {col} = (
                SELECT jsonb_agg(
                    CASE
                        WHEN jsonb_typeof(p) = 'object'
                         AND (p->>'label') IS NULL
                         AND (p->>'quote') IS NULL
                        THEN p || '{{"label":"inference"}}'
                        ELSE p
                    END
                )
                FROM jsonb_array_elements({col}) AS p
            )
            WHERE {col} IS NOT NULL AND EXISTS (
                SELECT 1 FROM jsonb_array_elements({col}) AS p
                WHERE jsonb_typeof(p) = 'object'
                  AND (p->>'label') IS NULL
                  AND (p->>'quote') IS NULL
            )
        """
        auto_rows = 0
        for col in ("key_points", "risks"):
            result = db.execute(text(_BULK_AUTO_LABEL.format(col=col)))
            auto_rows += result.rowcount
        db.commit()
        logger.info(f"backfill_brief_labels: bulk auto-labeled {auto_rows} rows (quoteless → inference)")

        # Step 2: Find briefs that still have unlabeled points (must have quotes → need LLM).
        unlabeled_ids = db.execute(text("""
            SELECT id FROM bill_briefs
            WHERE (
                key_points IS NOT NULL AND EXISTS (
                    SELECT 1 FROM jsonb_array_elements(key_points) AS p
                    WHERE jsonb_typeof(p) = 'object' AND (p->>'label') IS NULL
                )
            ) OR (
                risks IS NOT NULL AND EXISTS (
                    SELECT 1 FROM jsonb_array_elements(risks) AS r
                    WHERE jsonb_typeof(r) = 'object' AND (r->>'label') IS NULL
                )
            )
        """)).fetchall()

        total = len(unlabeled_ids)
        updated = 0
        skipped = 0

        # Runtime provider/model overrides (AppSetting rows) take precedence.
        prov_row = db.get(AppSetting, "llm_provider")
        model_row = db.get(AppSetting, "llm_model")
        provider = get_llm_provider(
            prov_row.value if prov_row else None,
            model_row.value if model_row else None,
        )

        for row in unlabeled_ids:
            brief = db.get(BillBrief, row.id)
            if not brief:
                skipped += 1
                continue

            # Only points with a quote can be LLM-classified as cited_fact vs inference
            # Each entry: (field name, index within that field, the point dict).
            to_classify: list[tuple[str, int, dict]] = []
            for field_name in ("key_points", "risks"):
                for i, p in enumerate(getattr(brief, field_name) or []):
                    if isinstance(p, dict) and p.get("label") is None and p.get("quote"):
                        to_classify.append((field_name, i, p))

            if not to_classify:
                skipped += 1
                continue

            # One numbered line per point; the model must answer with a JSON
            # array of labels in the same order.
            lines = [
                f'{i + 1}. TEXT: "{p["text"]}" | QUOTE: "{p.get("quote", "")}"'
                for i, (_, __, p) in enumerate(to_classify)
            ]
            prompt = (
                "Classify each item as 'cited_fact' or 'inference'.\n"
                "cited_fact = the claim is explicitly and directly stated in the quoted text.\n"
                "inference = analytical interpretation, projection, or implication not literally stated.\n\n"
                "Return ONLY a JSON array of strings, one per item, in order. No explanation.\n\n"
                "Items:\n" + "\n".join(lines)
            )

            try:
                raw = provider.generate_text(prompt).strip()
                # Strip an optional ```json ... ``` code fence before parsing.
                if raw.startswith("```"):
                    raw = raw.split("```")[1]
                if raw.startswith("json"):
                    raw = raw[4:]
                labels = json.loads(raw.strip())
                if not isinstance(labels, list) or len(labels) != len(to_classify):
                    logger.warning(f"Brief {brief.id}: label count mismatch, skipping")
                    skipped += 1
                    continue
            except Exception as exc:
                logger.warning(f"Brief {brief.id}: classification failed: {exc}")
                skipped += 1
                time.sleep(0.5)
                continue

            # Write labels back in place, then tell SQLAlchemy which JSON
            # columns mutated (in-place changes to JSON values are not
            # auto-detected, hence flag_modified).
            fields_modified: set[str] = set()
            for (field_name, point_idx, _), label in zip(to_classify, labels):
                if label in ("cited_fact", "inference"):
                    getattr(brief, field_name)[point_idx]["label"] = label
                    fields_modified.add(field_name)

            for field_name in fields_modified:
                flag_modified(brief, field_name)

            db.commit()
            updated += 1
            time.sleep(0.2)  # throttle between per-brief LLM calls

        logger.info(
            f"backfill_brief_labels: {total} briefs needing LLM, "
            f"{updated} updated, {skipped} skipped"
        )
        return {"auto_labeled_rows": auto_rows, "total_llm": total, "updated": updated, "skipped": skipped}
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.llm_processor.resume_pending_analysis")
def resume_pending_analysis(self):
    """
    Re-queue stalled analysis work in two passes.

    Pass 1: documents that have stored text but no brief (an earlier LLM task
    failed or timed out) are re-queued for LLM processing.

    Pass 2: bills with no stored document at all are re-queued for document
    fetching, which chains into LLM processing when GovInfo has text.
    """
    db = get_sync_db()
    try:
        # Pass 1: docs with raw_text but no brief
        docs_no_brief = db.execute(text("""
            SELECT bd.id
            FROM bill_documents bd
            LEFT JOIN bill_briefs bb ON bb.document_id = bd.id
            WHERE bb.id IS NULL AND bd.raw_text IS NOT NULL
        """)).fetchall()

        for doc_row in docs_no_brief:
            process_document_with_llm.delay(doc_row.id)
            time.sleep(0.1)  # stagger submissions instead of burst-queuing
        queued_llm = len(docs_no_brief)

        # Pass 2: bills with no document at all
        bills_no_doc = db.execute(text("""
            SELECT b.bill_id
            FROM bills b
            LEFT JOIN bill_documents bd ON bd.bill_id = b.bill_id
            WHERE bd.id IS NULL
        """)).fetchall()

        from app.workers.document_fetcher import fetch_bill_documents
        for bill_row in bills_no_doc:
            fetch_bill_documents.delay(bill_row.bill_id)
            time.sleep(0.1)
        queued_fetch = len(bills_no_doc)

        logger.info(
            f"resume_pending_analysis: {queued_llm} LLM tasks queued, "
            f"{queued_fetch} document fetch tasks queued"
        )
        return {"queued_llm": queued_llm, "queued_fetch": queued_fetch}
    finally:
        db.close()
|
||||
252
backend/app/workers/member_interest.py
Normal file
252
backend/app/workers/member_interest.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
Member interest worker — tracks public interest in members of Congress.
|
||||
|
||||
Fetches news articles and calculates trend scores for members using the
|
||||
same composite scoring model as bills (NewsAPI + Google News RSS + pytrends).
|
||||
Runs on a schedule and can also be triggered per-member.
|
||||
"""
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
|
||||
from app.database import get_sync_db
|
||||
from app.models import Member, MemberNewsArticle, MemberTrendScore
|
||||
from app.services import news_service, trends_service
|
||||
from app.workers.celery_app import celery_app
|
||||
from app.workers.trend_scorer import calculate_composite_score
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_pub_at(raw: str | None) -> datetime | None:
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(raw.replace("Z", "+00:00"))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@celery_app.task(bind=True, max_retries=2, name="app.workers.member_interest.sync_member_interest")
def sync_member_interest(self, bioguide_id: str):
    """
    Fetch news and score a member in a single API pass.
    Called on first profile view — avoids the 2x NewsAPI + GNews calls that
    result from queuing fetch_member_news and calculate_member_trend_score separately.

    Returns {"status": "skipped"} when the member is missing name data,
    otherwise {"status": "ok", "saved": <article count>}. Retries up to
    2 times with a 5-minute backoff on failure.
    """
    db = get_sync_db()
    try:
        member = db.get(Member, bioguide_id)
        # Both name parts are required to build a meaningful news query.
        if not member or not member.first_name or not member.last_name:
            return {"status": "skipped"}

        query = news_service.build_member_query(
            first_name=member.first_name,
            last_name=member.last_name,
            chamber=member.chamber,
        )

        # Single fetch — results reused for both article storage and scoring
        newsapi_articles = news_service.fetch_newsapi_articles(query, days=30)
        gnews_articles = news_service.fetch_gnews_articles(query, days=30)
        all_articles = newsapi_articles + gnews_articles

        saved = 0
        for article in all_articles:
            url = article.get("url")
            if not url:
                continue
            # Dedupe on (member_id, url) so repeated syncs don't re-insert.
            existing = (
                db.query(MemberNewsArticle)
                .filter_by(member_id=bioguide_id, url=url)
                .first()
            )
            if existing:
                continue
            db.add(MemberNewsArticle(
                member_id=bioguide_id,
                source=article.get("source", "")[:200],
                headline=article.get("headline", ""),
                url=url,
                published_at=_parse_pub_at(article.get("published_at")),
                relevance_score=1.0,
            ))
            saved += 1

        # Score using counts already in hand — no second API round-trip.
        # Skipped if a score row for (member, today) already exists.
        today = date.today()
        if not db.query(MemberTrendScore).filter_by(member_id=bioguide_id, score_date=today).first():
            keywords = trends_service.keywords_for_member(member.first_name, member.last_name)
            gtrends_score = trends_service.get_trends_score(keywords)
            composite = calculate_composite_score(
                len(newsapi_articles), len(gnews_articles), gtrends_score
            )
            db.add(MemberTrendScore(
                member_id=bioguide_id,
                score_date=today,
                newsapi_count=len(newsapi_articles),
                gnews_count=len(gnews_articles),
                gtrends_score=gtrends_score,
                composite_score=composite,
            ))

        db.commit()
        logger.info(f"Synced member interest for {bioguide_id}: {saved} articles saved")
        return {"status": "ok", "saved": saved}

    except Exception as exc:
        db.rollback()
        logger.error(f"Member interest sync failed for {bioguide_id}: {exc}")
        raise self.retry(exc=exc, countdown=300)
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, max_retries=2, name="app.workers.member_interest.fetch_member_news")
def fetch_member_news(self, bioguide_id: str):
    """Fetch and store recent news articles for a specific member.

    Pulls the last 30 days from NewsAPI and Google News, dedupes on
    (member_id, url) and returns {"status": "ok", "saved": <count>}.
    NOTE(review): the save loop duplicates sync_member_interest — consider
    extracting a shared helper.
    """
    db = get_sync_db()
    try:
        member = db.get(Member, bioguide_id)
        # Both name parts are required to build the news query.
        if not member or not member.first_name or not member.last_name:
            return {"status": "skipped"}

        query = news_service.build_member_query(
            first_name=member.first_name,
            last_name=member.last_name,
            chamber=member.chamber,
        )

        newsapi_articles = news_service.fetch_newsapi_articles(query, days=30)
        gnews_articles = news_service.fetch_gnews_articles(query, days=30)
        all_articles = newsapi_articles + gnews_articles

        saved = 0
        for article in all_articles:
            url = article.get("url")
            if not url:
                continue
            # Skip articles already stored for this member.
            existing = (
                db.query(MemberNewsArticle)
                .filter_by(member_id=bioguide_id, url=url)
                .first()
            )
            if existing:
                continue
            db.add(MemberNewsArticle(
                member_id=bioguide_id,
                source=article.get("source", "")[:200],
                headline=article.get("headline", ""),
                url=url,
                published_at=_parse_pub_at(article.get("published_at")),
                relevance_score=1.0,
            ))
            saved += 1

        db.commit()
        logger.info(f"Saved {saved} news articles for member {bioguide_id}")
        return {"status": "ok", "saved": saved}

    except Exception as exc:
        db.rollback()
        logger.error(f"Member news fetch failed for {bioguide_id}: {exc}")
        raise self.retry(exc=exc, countdown=300)
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.member_interest.calculate_member_trend_score")
def calculate_member_trend_score(self, bioguide_id: str):
    """Calculate and store today's public interest score for a member.

    Idempotent per day: returns {"status": "already_scored"} if a
    MemberTrendScore row already exists for (member, today). Combines the
    NewsAPI article count, Google News count, and Google Trends score into
    a composite via calculate_composite_score.
    """
    db = get_sync_db()
    try:
        member = db.get(Member, bioguide_id)
        if not member or not member.first_name or not member.last_name:
            return {"status": "skipped"}

        today = date.today()
        existing = (
            db.query(MemberTrendScore)
            .filter_by(member_id=bioguide_id, score_date=today)
            .first()
        )
        if existing:
            return {"status": "already_scored"}

        query = news_service.build_member_query(
            first_name=member.first_name,
            last_name=member.last_name,
            chamber=member.chamber,
        )
        keywords = trends_service.keywords_for_member(member.first_name, member.last_name)

        newsapi_articles = news_service.fetch_newsapi_articles(query, days=30)
        newsapi_count = len(newsapi_articles)
        gnews_count = news_service.fetch_gnews_count(query, days=30)
        gtrends_score = trends_service.get_trends_score(keywords)

        composite = calculate_composite_score(newsapi_count, gnews_count, gtrends_score)

        db.add(MemberTrendScore(
            member_id=bioguide_id,
            score_date=today,
            newsapi_count=newsapi_count,
            gnews_count=gnews_count,
            gtrends_score=gtrends_score,
            composite_score=composite,
        ))
        db.commit()
        logger.info(f"Scored member {bioguide_id}: composite={composite:.1f}")
        return {"status": "ok", "composite": composite}

    except Exception as exc:
        db.rollback()
        logger.error(f"Member trend scoring failed for {bioguide_id}: {exc}")
        # NOTE(review): re-raised without self.retry — unlike the fetch tasks
        # above, a failure here is not retried; confirm this is intentional.
        raise
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.member_interest.fetch_news_for_active_members")
def fetch_news_for_active_members(self):
    """
    Scheduled task: fetch news for members who have been viewed or followed.
    Prioritises members with detail_fetched set (profile has been viewed).
    """
    session = get_sync_db()
    try:
        # Eligible members: profile viewed at least once, with a usable name.
        eligible = session.query(Member).filter(
            Member.detail_fetched.isnot(None),
            Member.first_name.isnot(None),
        ).all()

        queued = 0
        for m in eligible:
            fetch_member_news.delay(m.bioguide_id)
            queued += 1

        logger.info(f"Queued news fetch for {queued} members")
        return {"queued": queued}
    finally:
        session.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.member_interest.calculate_all_member_trend_scores")
def calculate_all_member_trend_scores(self):
    """
    Scheduled nightly task: score all members that have been viewed.
    Members are scored only after their profile has been loaded at least once.
    """
    session = get_sync_db()
    try:
        # Same eligibility rule as the news-fetch scheduler above.
        eligible = session.query(Member).filter(
            Member.detail_fetched.isnot(None),
            Member.first_name.isnot(None),
        ).all()

        for m in eligible:
            calculate_member_trend_score.delay(m.bioguide_id)

        logger.info(f"Queued trend scoring for {len(eligible)} members")
        return {"queued": len(eligible)}
    finally:
        session.close()
|
||||
159
backend/app/workers/news_fetcher.py
Normal file
159
backend/app/workers/news_fetcher.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
News fetcher — correlates bills with news articles.
|
||||
Triggered after LLM brief creation and on a 6-hour schedule for active bills.
|
||||
"""
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import and_
|
||||
|
||||
from app.database import get_sync_db
|
||||
from app.models import Bill, BillBrief, NewsArticle
|
||||
from app.services import news_service
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _save_articles(db, bill_id: str, articles: list[dict]) -> int:
    """Persist a list of article dicts for a bill, skipping duplicates.

    Articles without a ``url`` are ignored; duplicates are detected by
    (bill_id, url). Returns the number of rows added (not yet committed —
    the caller owns the transaction).
    """
    saved = 0
    for article in articles:
        url = article.get("url")
        if not url:
            continue
        if db.query(NewsArticle).filter_by(bill_id=bill_id, url=url).first():
            continue  # already stored for this bill
        pub_at = None
        raw_pub = article.get("published_at")
        if raw_pub:
            try:
                pub_at = datetime.fromisoformat(raw_pub.replace("Z", "+00:00"))
            except Exception:
                pass  # unparseable timestamp — store the article without a date
        db.add(NewsArticle(
            bill_id=bill_id,
            # BUG FIX: .get("source", "") still returns None when the key is
            # present with a None value, and None[:200] raises TypeError.
            # Coerce with `or ""` before slicing; same for headline.
            source=(article.get("source") or "")[:200],
            headline=article.get("headline") or "",
            url=url,
            published_at=pub_at,
            relevance_score=1.0,
        ))
        saved += 1
    return saved
|
||||
|
||||
|
||||
@celery_app.task(bind=True, max_retries=2, name="app.workers.news_fetcher.fetch_news_for_bill")
def fetch_news_for_bill(self, bill_id: str):
    """Fetch news articles for a specific bill.

    Builds a search query from the bill's titles and number, pulls articles
    from NewsAPI and Google News, and persists them via _save_articles
    (which dedupes on url). Retries up to 2 times with a 5-minute backoff.
    """
    db = get_sync_db()
    try:
        bill = db.get(Bill, bill_id)
        if not bill:
            return {"status": "not_found"}

        query = news_service.build_news_query(
            bill_title=bill.title,
            short_title=bill.short_title,
            sponsor_name=None,  # sponsor is deliberately omitted from the query
            bill_type=bill.bill_type,
            bill_number=bill.bill_number,
        )

        newsapi_articles = news_service.fetch_newsapi_articles(query)
        gnews_articles = news_service.fetch_gnews_articles(query)
        saved = _save_articles(db, bill_id, newsapi_articles + gnews_articles)
        db.commit()
        logger.info(f"Saved {saved} news articles for bill {bill_id}")
        return {"status": "ok", "saved": saved}

    except Exception as exc:
        db.rollback()
        logger.error(f"News fetch failed for {bill_id}: {exc}")
        raise self.retry(exc=exc, countdown=300)
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, max_retries=2, name="app.workers.news_fetcher.fetch_news_for_bill_batch")
def fetch_news_for_bill_batch(self, bill_ids: list):
    """
    Fetch news for a batch of bills in ONE NewsAPI call using OR query syntax
    (up to NEWSAPI_BATCH_SIZE bills per call). Google News is fetched per-bill
    but served from the 2-hour Redis cache so the RSS is only hit once per query.

    Returns {"status": "ok", "bills": <count>, "saved": <articles saved>}.
    """
    db = get_sync_db()
    try:
        # Resolve ids to rows, silently dropping bills that no longer exist.
        bills = [b for b in (db.get(Bill, bid) for bid in bill_ids) if b]
        if not bills:
            return {"status": "no_bills"}

        # Build (bill_id, query) pairs for the batch NewsAPI call
        bill_queries = [
            (
                bill.bill_id,
                news_service.build_news_query(
                    bill_title=bill.title,
                    short_title=bill.short_title,
                    sponsor_name=None,
                    bill_type=bill.bill_type,
                    bill_number=bill.bill_number,
                ),
            )
            for bill in bills
        ]
        # O(1) lookup per bill instead of the previous linear next() scan.
        query_by_id = dict(bill_queries)

        # One NewsAPI call for the whole batch
        newsapi_batch = news_service.fetch_newsapi_articles_batch(bill_queries)

        total_saved = 0
        for bill in bills:
            query = query_by_id[bill.bill_id]
            newsapi_articles = newsapi_batch.get(bill.bill_id, [])
            # Google News is cached — fine to call per-bill (cache hit after first)
            gnews_articles = news_service.fetch_gnews_articles(query)
            total_saved += _save_articles(db, bill.bill_id, newsapi_articles + gnews_articles)

        db.commit()
        logger.info(f"Batch saved {total_saved} articles for {len(bills)} bills")
        return {"status": "ok", "bills": len(bills), "saved": total_saved}

    except Exception as exc:
        db.rollback()
        logger.error(f"Batch news fetch failed: {exc}")
        raise self.retry(exc=exc, countdown=300)
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.news_fetcher.fetch_news_for_active_bills")
def fetch_news_for_active_bills(self):
    """
    Scheduled task: fetch news for bills with recent actions (last 7 days).
    Groups bills into batches of NEWSAPI_BATCH_SIZE to multiply effective quota.
    """
    session = get_sync_db()
    try:
        week_ago = date.today() - timedelta(days=7)
        # Most recently active bills first, capped at 80 per run.
        recent = (
            session.query(Bill)
            .filter(Bill.latest_action_date >= week_ago)
            .order_by(Bill.latest_action_date.desc())
            .limit(80)
            .all()
        )

        ids = [bill.bill_id for bill in recent]
        per_batch = news_service.NEWSAPI_BATCH_SIZE

        # Chunk the id list into batches of per_batch.
        chunks = []
        start = 0
        while start < len(ids):
            chunks.append(ids[start:start + per_batch])
            start += per_batch

        for chunk in chunks:
            fetch_news_for_bill_batch.delay(chunk)

        logger.info(
            f"Queued {len(chunks)} news batches for {len(recent)} active bills "
            f"({per_batch} bills/batch)"
        )
        return {"queued_batches": len(chunks), "total_bills": len(recent)}
    finally:
        session.close()
|
||||
572
backend/app/workers/notification_dispatcher.py
Normal file
572
backend/app/workers/notification_dispatcher.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""
|
||||
Notification dispatcher — sends pending notification events via ntfy.
|
||||
|
||||
RSS is pull-based so no dispatch is needed for it; events are simply
|
||||
marked dispatched once ntfy is sent (or immediately if the user has no
|
||||
ntfy configured but has an RSS token, so the feed can clean up old items).
|
||||
|
||||
Runs every 5 minutes on Celery Beat.
|
||||
"""
|
||||
import base64
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import requests
|
||||
|
||||
from app.core.crypto import decrypt_secret
|
||||
from app.database import get_sync_db
|
||||
from app.models.follow import Follow
|
||||
from app.models.notification import NotificationEvent
|
||||
from app.models.user import User
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NTFY_TIMEOUT = 10
|
||||
|
||||
_EVENT_TITLES = {
|
||||
"new_document": "New Bill Text",
|
||||
"new_amendment": "Amendment Filed",
|
||||
"bill_updated": "Bill Updated",
|
||||
"weekly_digest": "Weekly Digest",
|
||||
}
|
||||
|
||||
_EVENT_TAGS = {
|
||||
"new_document": "page_facing_up",
|
||||
"new_amendment": "memo",
|
||||
"bill_updated": "rotating_light",
|
||||
}
|
||||
|
||||
# Milestone events are more urgent than LLM brief events
|
||||
_EVENT_PRIORITY = {
|
||||
"bill_updated": "high",
|
||||
"new_document": "default",
|
||||
"new_amendment": "default",
|
||||
}
|
||||
|
||||
|
||||
_FILTER_DEFAULTS = {
    "new_document": False, "new_amendment": False, "vote": False,
    "presidential": False, "committee_report": False, "calendar": False,
    "procedural": False, "referral": False,
}


def _should_dispatch(event, prefs: dict, follow_mode: str = "neutral") -> bool:
    """Decide whether a notification event passes the user's alert filters.

    Unconfigured filter sections default to "send everything"; configured
    sections honour the master toggle, per-entity mutes, and per-key
    defaults from _FILTER_DEFAULTS.
    """
    payload = event.payload or {}
    source = payload.get("source", "bill_follow")

    # Resolve the filter key for this event type.
    etype = event.event_type
    if etype in ("new_document", "new_amendment"):
        key = etype
    else:
        # Prefer action_category (new events); legacy events fall back
        # from milestone_tier, defaulting to "vote".
        key = payload.get("action_category") or (
            "referral" if payload.get("milestone_tier") == "referral" else "vote"
        )

    all_filters = prefs.get("alert_filters")
    if all_filters is None:
        return True  # user hasn't configured filters yet — send everything

    if source not in ("member_follow", "topic_follow"):
        # Bill follow — use follow mode filters.
        mode_filters = all_filters.get(follow_mode) or {}
        return bool(mode_filters.get(key, _FILTER_DEFAULTS.get(key, True)))

    section = all_filters.get(source)
    if section is None:
        return True  # section not configured — send everything
    if not section.get("enabled", True):
        return False  # master toggle off
    # Per-entity mute checks
    if source == "member_follow" and payload.get("matched_member_id") in (section.get("muted_ids") or []):
        return False
    if source == "topic_follow" and payload.get("matched_topic") in (section.get("muted_tags") or []):
        return False
    return bool(section.get(key, _FILTER_DEFAULTS.get(key, True)))
|
||||
|
||||
|
||||
def _in_quiet_hours(prefs: dict, now: datetime) -> bool:
    """Return True if the current local time falls within the user's quiet window.

    Quiet hours are stored as local-time hour integers. If the user has a stored
    IANA timezone name we convert ``now`` (UTC) to that zone before comparing.
    Falls back to UTC if the timezone is absent or unrecognised.
    """
    start, end = prefs.get("quiet_hours_start"), prefs.get("quiet_hours_end")
    if start is None or end is None:
        return False

    hour = now.hour  # default: compare in UTC
    tz_name = prefs.get("timezone")
    if tz_name:
        try:
            from zoneinfo import ZoneInfo
            hour = now.astimezone(ZoneInfo(tz_name)).hour
        except Exception:
            pass  # unrecognised timezone — degrade gracefully to UTC

    if start <= end:
        return start <= hour < end
    # Window wraps midnight (e.g. 22 → 8)
    return hour >= start or hour < end
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.notification_dispatcher.dispatch_notifications")
def dispatch_notifications(self):
    """Fan out pending notification events to ntfy and mark dispatched.

    Processes up to 200 undispatched events per run, oldest first. Events
    filtered out by user preferences are marked dispatched without sending;
    events held for quiet hours or digest mode are left undispatched so a
    later run (or the digest task) picks them up.
    """
    db = get_sync_db()
    try:
        pending = (
            db.query(NotificationEvent)
            .filter(NotificationEvent.dispatched_at.is_(None))
            .order_by(NotificationEvent.created_at)
            .limit(200)
            .all()
        )

        sent = 0
        failed = 0
        held = 0
        now = datetime.now(timezone.utc)

        for event in pending:
            user = db.get(User, event.user_id)
            if not user:
                # Orphaned event (user gone) — consume it so it isn't retried.
                event.dispatched_at = now
                db.commit()
                continue

            # Look up follow mode for this (user, bill) pair
            follow = db.query(Follow).filter_by(
                user_id=event.user_id, follow_type="bill", follow_value=event.bill_id
            ).first()
            follow_mode = follow.follow_mode if follow else "neutral"

            prefs = user.notification_prefs or {}

            if not _should_dispatch(event, prefs, follow_mode):
                # Filtered out by preferences — consume without sending.
                event.dispatched_at = now
                db.commit()
                continue
            ntfy_url = prefs.get("ntfy_topic_url", "").strip()
            ntfy_auth_method = prefs.get("ntfy_auth_method", "none")
            ntfy_token = prefs.get("ntfy_token", "").strip()
            ntfy_username = prefs.get("ntfy_username", "").strip()
            # Password is stored encrypted in prefs; decrypt just-in-time.
            ntfy_password = decrypt_secret(prefs.get("ntfy_password", "").strip())
            ntfy_enabled = prefs.get("ntfy_enabled", False)
            rss_enabled = prefs.get("rss_enabled", False)  # NOTE(review): read but never used below
            digest_enabled = prefs.get("digest_enabled", False)

            ntfy_configured = ntfy_enabled and bool(ntfy_url)

            # Hold events when ntfy is configured but delivery should be deferred
            in_quiet = _in_quiet_hours(prefs, now) if ntfy_configured else False
            hold = ntfy_configured and (in_quiet or digest_enabled)

            if hold:
                held += 1
                continue  # Leave undispatched — digest task or next run after quiet hours

            if ntfy_configured:
                try:
                    _send_ntfy(
                        event, ntfy_url, ntfy_auth_method, ntfy_token,
                        ntfy_username, ntfy_password, follow_mode=follow_mode,
                    )
                    sent += 1
                except Exception as e:
                    logger.warning(f"ntfy dispatch failed for event {event.id}: {e}")
                    failed += 1

            email_enabled = prefs.get("email_enabled", False)
            email_address = prefs.get("email_address", "").strip()
            if email_enabled and email_address:
                try:
                    _send_email(event, email_address, unsubscribe_token=user.email_unsubscribe_token)
                    sent += 1
                except Exception as e:
                    logger.warning(f"email dispatch failed for event {event.id}: {e}")
                    failed += 1

            # Mark dispatched: channels were attempted, or user has no channels configured (RSS-only)
            # NOTE(review): a failed send is still marked dispatched, so the event
            # is never retried on a later run — confirm this is intentional.
            event.dispatched_at = now
            db.commit()

        logger.info(
            f"dispatch_notifications: {sent} sent, {failed} failed, "
            f"{held} held (quiet hours/digest), {len(pending)} total pending"
        )
        return {"sent": sent, "failed": failed, "held": held, "total": len(pending)}
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.notification_dispatcher.send_notification_digest")
def send_notification_digest(self):
    """
    Send a bundled ntfy digest for users with digest mode enabled.
    Runs daily; weekly-frequency users only receive on Mondays.

    Only users with digest mode, ntfy enabled, and a topic URL configured are
    considered. Events older than the lookback window (24h daily / 168h
    weekly) are excluded from the bundle and stay undispatched.
    """
    db = get_sync_db()
    try:
        now = datetime.now(timezone.utc)
        users = db.query(User).all()
        # Users who opted into digests AND have a working ntfy destination.
        digest_users = [
            u for u in users
            if (u.notification_prefs or {}).get("digest_enabled", False)
            and (u.notification_prefs or {}).get("ntfy_enabled", False)
            and (u.notification_prefs or {}).get("ntfy_topic_url", "").strip()
        ]

        sent = 0
        for user in digest_users:
            prefs = user.notification_prefs or {}
            frequency = prefs.get("digest_frequency", "daily")

            # Weekly digests only fire on Mondays (weekday 0)
            if frequency == "weekly" and now.weekday() != 0:
                continue

            lookback_hours = 168 if frequency == "weekly" else 24
            cutoff = now - timedelta(hours=lookback_hours)

            events = (
                db.query(NotificationEvent)
                .filter_by(user_id=user.id)
                .filter(
                    NotificationEvent.dispatched_at.is_(None),
                    NotificationEvent.created_at > cutoff,
                )
                .order_by(NotificationEvent.created_at.desc())
                .all()
            )

            if not events:
                continue

            try:
                ntfy_url = prefs.get("ntfy_topic_url", "").strip()
                _send_digest_ntfy(events, ntfy_url, prefs)
                # Only consume the events once the digest actually went out.
                for event in events:
                    event.dispatched_at = now
                db.commit()
                sent += 1
            except Exception as e:
                logger.warning(f"Digest send failed for user {user.id}: {e}")

        logger.info(f"send_notification_digest: digests sent to {sent} users")
        return {"sent": sent}
    finally:
        db.close()
|
||||
|
||||
|
||||
def _build_reason(payload: dict) -> str | None:
|
||||
source = payload.get("source", "bill_follow")
|
||||
mode_labels = {"pocket_veto": "Pocket Veto", "pocket_boost": "Pocket Boost", "neutral": "Following"}
|
||||
if source == "bill_follow":
|
||||
mode = payload.get("follow_mode", "neutral")
|
||||
return f"\U0001f4cc {mode_labels.get(mode, 'Following')} this bill"
|
||||
if source == "member_follow":
|
||||
name = payload.get("matched_member_name")
|
||||
return f"\U0001f464 You follow {name}" if name else "\U0001f464 Member you follow"
|
||||
if source == "topic_follow":
|
||||
topic = payload.get("matched_topic")
|
||||
return f"\U0001f3f7 You follow \"{topic}\"" if topic else "\U0001f3f7 Topic you follow"
|
||||
return None
|
||||
|
||||
|
||||
def _send_email(
    event: NotificationEvent,
    email_address: str,
    unsubscribe_token: str | None = None,
) -> None:
    """Send a plain-text email notification via SMTP.

    Silently returns when SMTP_HOST is not configured or no address is given.
    Adds RFC 8058 List-Unsubscribe headers when an unsubscribe token is
    available. SMTP errors propagate — the caller counts them as failures.
    """
    import smtplib
    from email.mime.multipart import MIMEMultipart
    from email.mime.text import MIMEText

    from app.config import settings as app_settings

    if not app_settings.SMTP_HOST or not email_address:
        return

    payload = event.payload or {}
    bill_label = payload.get("bill_label", event.bill_id.upper())
    bill_title = payload.get("bill_title", "")
    event_label = _EVENT_TITLES.get(event.event_type, "Bill Update")
    base_url = (app_settings.PUBLIC_URL or app_settings.LOCAL_URL).rstrip("/")

    subject = f"PocketVeto: {event_label} — {bill_label}"

    # Assemble the plain-text body section by section.
    lines = [f"{event_label}: {bill_label}"]
    if bill_title:
        lines.append(bill_title)
    if payload.get("brief_summary"):
        lines.append("")
        lines.append(payload["brief_summary"][:500])
    reason = _build_reason(payload)
    if reason:
        lines.append("")
        lines.append(reason)
    if payload.get("bill_url"):
        lines.append("")
        lines.append(f"View bill: {payload['bill_url']}")

    unsubscribe_url = f"{base_url}/api/notifications/unsubscribe/{unsubscribe_token}" if unsubscribe_token else None
    if unsubscribe_url:
        lines.append("")
        lines.append(f"Unsubscribe from email alerts: {unsubscribe_url}")

    body = "\n".join(lines)

    from_addr = app_settings.SMTP_FROM or app_settings.SMTP_USER
    msg = MIMEMultipart()
    msg["Subject"] = subject
    msg["From"] = from_addr
    msg["To"] = email_address
    if unsubscribe_url:
        # One-click unsubscribe headers for mail clients (RFC 8058).
        msg["List-Unsubscribe"] = f"<{unsubscribe_url}>"
        msg["List-Unsubscribe-Post"] = "List-Unsubscribe=One-Click"
    msg.attach(MIMEText(body, "plain", "utf-8"))

    # Port 465 implies implicit TLS; otherwise optionally upgrade via STARTTLS.
    use_ssl = app_settings.SMTP_PORT == 465
    if use_ssl:
        smtp_ctx = smtplib.SMTP_SSL(app_settings.SMTP_HOST, app_settings.SMTP_PORT, timeout=10)
    else:
        smtp_ctx = smtplib.SMTP(app_settings.SMTP_HOST, app_settings.SMTP_PORT, timeout=10)
    with smtp_ctx as s:
        if not use_ssl and app_settings.SMTP_STARTTLS:
            s.starttls()
        if app_settings.SMTP_USER:
            s.login(app_settings.SMTP_USER, app_settings.SMTP_PASSWORD)
        s.sendmail(from_addr, [email_address], msg.as_string())
|
||||
|
||||
|
||||
def _send_ntfy(
    event: NotificationEvent,
    topic_url: str,
    auth_method: str = "none",
    token: str = "",
    username: str = "",
    password: str = "",
    follow_mode: str = "neutral",
) -> None:
    """POST a single notification event to an ntfy topic.

    Title, priority, and tags come from the module-level event lookup tables.
    Raises requests.HTTPError on a non-2xx response so the caller can count
    the event as a failed dispatch.
    """
    payload = event.payload or {}
    bill_label = payload.get("bill_label", event.bill_id.upper())
    bill_title = payload.get("bill_title", "")
    event_label = _EVENT_TITLES.get(event.event_type, "Bill Update")

    title = f"{event_label}: {bill_label}"

    # Build the message body; fall back to the bare label if everything is empty.
    lines = [bill_title] if bill_title else []
    if payload.get("brief_summary"):
        lines.append("")
        lines.append(payload["brief_summary"][:300])
    reason = _build_reason(payload)
    if reason:
        lines.append("")
        lines.append(reason)
    message = "\n".join(lines) or bill_label

    headers = {
        "Title": title,
        "Priority": _EVENT_PRIORITY.get(event.event_type, "default"),
        "Tags": _EVENT_TAGS.get(event.event_type, "bell"),
    }
    if payload.get("bill_url"):
        headers["Click"] = payload["bill_url"]

    # Pocket Boost followers get action buttons to act on the update.
    if follow_mode == "pocket_boost":
        headers["Actions"] = (
            f"view, View Bill, {payload.get('bill_url', '')}; "
            "view, Find Your Rep, https://www.house.gov/representatives/find-your-representative"
        )

    # Optional ntfy authentication: bearer token or HTTP basic.
    if auth_method == "token" and token:
        headers["Authorization"] = f"Bearer {token}"
    elif auth_method == "basic" and username:
        creds = base64.b64encode(f"{username}:{password}".encode()).decode()
        headers["Authorization"] = f"Basic {creds}"

    resp = requests.post(topic_url, data=message.encode("utf-8"), headers=headers, timeout=NTFY_TIMEOUT)
    resp.raise_for_status()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.notification_dispatcher.send_weekly_digest")
def send_weekly_digest(self):
    """
    Proactive week-in-review summary for followed bills.

    Runs every Monday at 8:30 AM UTC. Queries bills followed by each user
    for any activity in the past 7 days and dispatches a low-noise summary
    via ntfy and/or creates a NotificationEvent for the RSS feed.

    Unlike send_notification_digest (which bundles queued events), this task
    generates a fresh summary independent of the notification event queue.

    Returns:
        dict with "ntfy_sent" (pushes delivered) and "rss_created"
        (NotificationEvent rows added) counts.
    """
    # Local imports avoid circular imports at worker module load time.
    from app.config import settings as app_settings
    from app.models.bill import Bill

    db = get_sync_db()
    try:
        now = datetime.now(timezone.utc)
        cutoff = now - timedelta(days=7)
        base_url = (app_settings.PUBLIC_URL or app_settings.LOCAL_URL).rstrip("/")

        users = db.query(User).all()
        ntfy_sent = 0
        rss_created = 0

        for user in users:
            prefs = user.notification_prefs or {}
            ntfy_enabled = prefs.get("ntfy_enabled", False)
            ntfy_url = prefs.get("ntfy_topic_url", "").strip()
            rss_enabled = prefs.get("rss_enabled", False)
            # ntfy needs both the toggle and a non-empty topic URL.
            ntfy_configured = ntfy_enabled and bool(ntfy_url)

            # Skip users with no delivery channel at all.
            if not ntfy_configured and not rss_enabled:
                continue

            bill_follows = db.query(Follow).filter_by(
                user_id=user.id, follow_type="bill"
            ).all()
            if not bill_follows:
                continue

            bill_ids = [f.follow_value for f in bill_follows]

            # Followed bills touched in the last 7 days, newest first,
            # capped at 20 to bound the digest size.
            active_bills = (
                db.query(Bill)
                .filter(
                    Bill.bill_id.in_(bill_ids),
                    Bill.updated_at >= cutoff,
                )
                .order_by(Bill.updated_at.desc())
                .limit(20)
                .all()
            )

            if not active_bills:
                continue

            count = len(active_bills)
            # The most recently updated bill anchors the event's bill_id/URL.
            anchor = active_bills[0]

            # Bullet list: at most 10 bills, each action text truncated to 80 chars.
            summary_lines = []
            for bill in active_bills[:10]:
                lbl = _format_bill_label(bill)
                action = (bill.latest_action_text or "")[:80]
                summary_lines.append(f"• {lbl}: {action}" if action else f"• {lbl}")
            if count > 10:
                summary_lines.append(f"  …and {count - 10} more")
            summary = "\n".join(summary_lines)

            # Mark dispatched_at immediately so dispatch_notifications skips this event;
            # it still appears in the RSS feed since that endpoint reads all events.
            event = NotificationEvent(
                user_id=user.id,
                bill_id=anchor.bill_id,
                event_type="weekly_digest",
                dispatched_at=now,
                payload={
                    "bill_label": "Weekly Digest",
                    "bill_title": f"{count} followed bill{'s' if count != 1 else ''} had activity this week",
                    "brief_summary": summary,
                    "bill_count": count,
                    "bill_url": f"{base_url}/bills/{anchor.bill_id}",
                },
            )
            db.add(event)
            rss_created += 1

            if ntfy_configured:
                # Best-effort push: one user's ntfy failure must not abort
                # the whole weekly run.
                try:
                    _send_weekly_digest_ntfy(count, summary, ntfy_url, prefs)
                    ntfy_sent += 1
                except Exception as e:
                    logger.warning(f"Weekly digest ntfy failed for user {user.id}: {e}")

        # Single commit for all users' events.
        db.commit()
        logger.info(f"send_weekly_digest: {ntfy_sent} ntfy sent, {rss_created} events created")
        return {"ntfy_sent": ntfy_sent, "rss_created": rss_created}
    finally:
        db.close()
|
||||
|
||||
|
||||
def _format_bill_label(bill) -> str:
|
||||
_TYPE_MAP = {
|
||||
"hr": "H.R.", "s": "S.", "hjres": "H.J.Res.", "sjres": "S.J.Res.",
|
||||
"hconres": "H.Con.Res.", "sconres": "S.Con.Res.", "hres": "H.Res.", "sres": "S.Res.",
|
||||
}
|
||||
prefix = _TYPE_MAP.get(bill.bill_type.lower(), bill.bill_type.upper())
|
||||
return f"{prefix} {bill.bill_number}"
|
||||
|
||||
|
||||
def _send_weekly_digest_ntfy(count: int, summary: str, ntfy_url: str, prefs: dict) -> None:
    """Publish the weekly digest to a user's ntfy topic.

    Sends a low-priority message whose body is the prebuilt *summary*;
    raises requests.HTTPError if the ntfy server rejects the publish.
    """
    plural = "s" if count != 1 else ""
    headers = {
        "Title": f"PocketVeto Weekly — {count} bill{plural} updated",
        "Priority": "low",
        "Tags": "newspaper,calendar",
    }

    # Optional ntfy auth, per the user's stored preferences.
    method = prefs.get("ntfy_auth_method", "none")
    token = prefs.get("ntfy_token", "").strip()
    username = prefs.get("ntfy_username", "").strip()
    password = prefs.get("ntfy_password", "").strip()
    if method == "token" and token:
        headers["Authorization"] = f"Bearer {token}"
    elif method == "basic" and username:
        encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
        headers["Authorization"] = f"Basic {encoded}"

    response = requests.post(ntfy_url, data=summary.encode("utf-8"), headers=headers, timeout=NTFY_TIMEOUT)
    response.raise_for_status()
|
||||
|
||||
|
||||
def _send_digest_ntfy(events: list, ntfy_url: str, prefs: dict) -> None:
    """Publish a bundled digest of queued notification events to ntfy.

    Groups *events* by bill and lists up to 10 bills, each with the set of
    event kinds seen; raises requests.HTTPError on a failed publish.
    """
    total = len(events)
    headers = {
        "Title": f"PocketVeto Digest — {total} update{'s' if total != 1 else ''}",
        "Priority": "default",
        "Tags": "newspaper",
    }

    # Optional ntfy auth, per the user's stored preferences.
    method = prefs.get("ntfy_auth_method", "none")
    token = prefs.get("ntfy_token", "").strip()
    username = prefs.get("ntfy_username", "").strip()
    password = prefs.get("ntfy_password", "").strip()
    if method == "token" and token:
        headers["Authorization"] = f"Bearer {token}"
    elif method == "basic" and username:
        encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
        headers["Authorization"] = f"Basic {encoded}"

    # Group by bill, show up to 10
    grouped: dict = defaultdict(list)
    for ev in events:
        grouped[ev.bill_id].append(ev)

    body_lines = []
    for bill_id, bill_events in list(grouped.items())[:10]:
        first_payload = bill_events[0].payload or {}
        label = first_payload.get("bill_label", bill_id.upper())
        kinds = list({_EVENT_TITLES.get(e.event_type, "Update") for e in bill_events})
        body_lines.append(f"• {label}: {', '.join(kinds)}")

    if len(grouped) > 10:
        body_lines.append(f"  …and {len(grouped) - 10} more bills")

    body = "\n".join(body_lines)
    response = requests.post(ntfy_url, data=body.encode("utf-8"), headers=headers, timeout=NTFY_TIMEOUT)
    response.raise_for_status()
|
||||
164
backend/app/workers/notification_utils.py
Normal file
164
backend/app/workers/notification_utils.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Shared notification utilities — used by llm_processor, congress_poller, etc.
|
||||
Centralised here to avoid circular imports.
|
||||
"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
_VOTE_KW = ["passed", "failed", "agreed to", "roll call"]
|
||||
_PRES_KW = ["signed", "vetoed", "enacted", "presented to the president"]
|
||||
_COMMITTEE_KW = ["markup", "ordered to be reported", "ordered reported", "reported by", "discharged"]
|
||||
_CALENDAR_KW = ["placed on"]
|
||||
_PROCEDURAL_KW = ["cloture", "conference"]
|
||||
_REFERRAL_KW = ["referred to"]
|
||||
|
||||
# Events created within this window for the same (user, bill, event_type) are suppressed
|
||||
_DEDUP_MINUTES = 30
|
||||
|
||||
|
||||
def categorize_action(action_text: str) -> str | None:
|
||||
"""Return the action category string, or None if not notification-worthy."""
|
||||
t = (action_text or "").lower()
|
||||
if any(kw in t for kw in _VOTE_KW): return "vote"
|
||||
if any(kw in t for kw in _PRES_KW): return "presidential"
|
||||
if any(kw in t for kw in _COMMITTEE_KW): return "committee_report"
|
||||
if any(kw in t for kw in _CALENDAR_KW): return "calendar"
|
||||
if any(kw in t for kw in _PROCEDURAL_KW): return "procedural"
|
||||
if any(kw in t for kw in _REFERRAL_KW): return "referral"
|
||||
return None
|
||||
|
||||
|
||||
def _build_payload(
    bill, action_summary: str, action_category: str, source: str = "bill_follow"
) -> dict:
    """Assemble the JSON payload stored on a NotificationEvent for *bill*."""
    from app.config import settings

    site_root = (settings.PUBLIC_URL or settings.LOCAL_URL).rstrip("/")
    summary_snippet = (action_summary or "")[:300]
    # Referrals keep their own tier; every other category counts as progress.
    tier = "referral" if action_category == "referral" else "progress"
    return {
        "bill_title": bill.short_title or bill.title or "",
        "bill_label": f"{bill.bill_type.upper()} {bill.bill_number}",
        "brief_summary": summary_snippet,
        "bill_url": f"{site_root}/bills/{bill.bill_id}",
        "action_category": action_category,
        # kept for RSS/history backwards compat
        "milestone_tier": tier,
        "source": source,
    }
|
||||
|
||||
|
||||
def _is_duplicate(db, user_id: int, bill_id: str, event_type: str) -> bool:
    """True if an identical event was already created within the dedup window.

    "Identical" means the same (user_id, bill_id, event_type) triple with a
    created_at newer than now minus _DEDUP_MINUTES.
    """
    from app.models.notification import NotificationEvent

    window_start = datetime.now(timezone.utc) - timedelta(minutes=_DEDUP_MINUTES)
    existing = (
        db.query(NotificationEvent)
        .filter_by(user_id=user_id, bill_id=bill_id, event_type=event_type)
        .filter(NotificationEvent.created_at > window_start)
        .first()
    )
    return existing is not None
|
||||
|
||||
|
||||
def emit_bill_notification(
    db, bill, event_type: str, action_summary: str, action_category: str = "vote"
) -> int:
    """Create NotificationEvent rows for every user following this bill. Returns count.

    Duplicate events inside the dedup window are skipped per user; commits
    only when at least one event was added.
    """
    from app.models.follow import Follow
    from app.models.notification import NotificationEvent

    follower_rows = (
        db.query(Follow)
        .filter_by(follow_type="bill", follow_value=bill.bill_id)
        .all()
    )
    if not follower_rows:
        return 0

    base_payload = _build_payload(bill, action_summary, action_category, source="bill_follow")
    created = 0
    for row in follower_rows:
        if _is_duplicate(db, row.user_id, bill.bill_id, event_type):
            continue
        db.add(NotificationEvent(
            user_id=row.user_id,
            bill_id=bill.bill_id,
            event_type=event_type,
            # Copy the shared payload so each event records its own follow_mode.
            payload={**base_payload, "follow_mode": row.follow_mode},
        ))
        created += 1

    if created:
        db.commit()
    return created
|
||||
|
||||
|
||||
def emit_member_follow_notifications(
    db, bill, event_type: str, action_summary: str, action_category: str = "vote"
) -> int:
    """Notify users following the bill's sponsor (dedup prevents double-alerts for bill+member followers).

    Args:
        db: synchronous SQLAlchemy session.
        bill: Bill ORM object; its sponsor_id selects the member follows to match.
        event_type: NotificationEvent.event_type value recorded on each row.
        action_summary: action text; truncated to 300 chars by _build_payload.
        action_category: category string stored in the payload.

    Returns:
        Number of NotificationEvent rows created (0 when the bill has no
        sponsor or nobody follows the sponsor).
    """
    # A bill without a recorded sponsor can't match any member follow.
    if not bill.sponsor_id:
        return 0

    # Local imports avoid circular imports at module load time.
    from app.models.follow import Follow
    from app.models.notification import NotificationEvent

    followers = db.query(Follow).filter_by(follow_type="member", follow_value=bill.sponsor_id).all()
    if not followers:
        return 0

    from app.models.member import Member
    # member may be None if the sponsor isn't in our members table yet.
    member = db.get(Member, bill.sponsor_id)
    payload = _build_payload(bill, action_summary, action_category, source="member_follow")
    payload["matched_member_name"] = member.name if member else None
    payload["matched_member_id"] = bill.sponsor_id
    count = 0
    # NOTE: unlike emit_bill_notification, the same payload dict object is
    # shared by every event added below (no per-follower copy is made).
    for follow in followers:
        # _is_duplicate keys on (user, bill, event_type), so a user who also
        # follows the bill directly and was already notified is skipped here.
        if _is_duplicate(db, follow.user_id, bill.bill_id, event_type):
            continue
        db.add(NotificationEvent(
            user_id=follow.user_id,
            bill_id=bill.bill_id,
            event_type=event_type,
            payload=payload,
        ))
        count += 1
    if count:
        db.commit()
    return count
|
||||
|
||||
|
||||
def emit_topic_follow_notifications(
    db, bill, event_type: str, action_summary: str, topic_tags: list,
    action_category: str = "vote",
) -> int:
    """Notify users following any of the bill's topic tags.

    Each user gets at most one event even when several of their followed
    topics match; the payload's "matched_topic" records which follow matched.

    Returns:
        Number of NotificationEvent rows created.
    """
    if not topic_tags:
        return 0

    from app.models.follow import Follow
    from app.models.notification import NotificationEvent

    # Single query for all topic followers, then deduplicate by user_id
    all_follows = db.query(Follow).filter(
        Follow.follow_type == "topic",
        Follow.follow_value.in_(topic_tags),
    ).all()

    seen_user_ids: set[int] = set()
    followers = []
    follower_topic: dict[int, str] = {}
    for follow in all_follows:
        # First matching row per user wins. NOTE(review): which topic that is
        # depends on the database's result order (no ORDER BY here) — confirm
        # whether deterministic matched_topic attribution matters.
        if follow.user_id not in seen_user_ids:
            seen_user_ids.add(follow.user_id)
            followers.append(follow)
            follower_topic[follow.user_id] = follow.follow_value

    if not followers:
        return 0

    payload = _build_payload(bill, action_summary, action_category, source="topic_follow")
    count = 0
    for follow in followers:
        # Same dedup window as bill/member follows: skip users already
        # notified for this (bill, event_type) recently.
        if _is_duplicate(db, follow.user_id, bill.bill_id, event_type):
            continue
        db.add(NotificationEvent(
            user_id=follow.user_id,
            bill_id=bill.bill_id,
            event_type=event_type,
            payload={**payload, "matched_topic": follower_topic.get(follow.user_id)},
        ))
        count += 1
    if count:
        db.commit()
    return count
|
||||
126
backend/app/workers/trend_scorer.py
Normal file
126
backend/app/workers/trend_scorer.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
Trend scorer — calculates the daily zeitgeist score for bills.
|
||||
Runs nightly via Celery Beat.
|
||||
"""
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
|
||||
from sqlalchemy import and_
|
||||
|
||||
from app.database import get_sync_db
|
||||
from app.models import Bill, BillBrief, TrendScore
|
||||
from app.services import news_service, trends_service
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PYTRENDS_BATCH = 5 # max keywords pytrends accepts per call
|
||||
|
||||
|
||||
def calculate_composite_score(newsapi_count: int, gnews_count: int, gtrends_score: float) -> float:
    """
    Weighted composite score (0–100):
        NewsAPI article count   → 0–40 pts (saturates at 20 articles)
        Google News RSS count   → 0–30 pts (saturates at 50 articles)
        Google Trends score     → 0–30 pts (0–100 input)

    Inputs are clamped to their documented ranges so malformed upstream data
    (negative counts, a trends score above 100) can never push the composite
    outside 0–100. Result is rounded to 2 decimal places.
    """
    newsapi_pts = min(max(newsapi_count, 0) / 20, 1.0) * 40
    gnews_pts = min(max(gnews_count, 0) / 50, 1.0) * 30
    gtrends_pts = min(max(gtrends_score, 0.0), 100.0) / 100 * 30
    return round(newsapi_pts + gnews_pts + gtrends_pts, 2)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.trend_scorer.calculate_all_trend_scores")
def calculate_all_trend_scores(self):
    """Nightly task: calculate trend scores for bills active in the last 90 days.

    For each bill not yet scored today, combines a NewsAPI article count, a
    Google News RSS count, and a Google Trends score into one composite
    (see calculate_composite_score) and stores it as a TrendScore row.

    Returns:
        dict with the number of bills scored.
    """
    db = get_sync_db()
    try:
        cutoff = date.today() - timedelta(days=90)
        active_bills = (
            db.query(Bill)
            .filter(Bill.latest_action_date >= cutoff)
            .all()
        )

        today = date.today()

        # Filter to bills not yet scored today
        bills_to_score = []
        for bill in active_bills:
            existing = (
                db.query(TrendScore)
                .filter_by(bill_id=bill.bill_id, score_date=today)
                .first()
            )
            if not existing:
                bills_to_score.append(bill)

        scored = 0

        # Process in batches of _PYTRENDS_BATCH so one pytrends call covers multiple bills
        for batch_start in range(0, len(bills_to_score), _PYTRENDS_BATCH):
            batch = bills_to_score[batch_start: batch_start + _PYTRENDS_BATCH]

            # Collect keyword groups for pytrends batch call
            keyword_groups = []
            bill_queries = []
            for bill in batch:
                # Latest brief supplies the topic tags used to build keywords.
                latest_brief = (
                    db.query(BillBrief)
                    .filter_by(bill_id=bill.bill_id)
                    .order_by(BillBrief.created_at.desc())
                    .first()
                )
                topic_tags = latest_brief.topic_tags if latest_brief else []
                query = news_service.build_news_query(
                    bill_title=bill.title,
                    short_title=bill.short_title,
                    sponsor_name=None,
                    bill_type=bill.bill_type,
                    bill_number=bill.bill_number,
                )
                keywords = trends_service.keywords_for_bill(
                    title=bill.title or "",
                    short_title=bill.short_title or "",
                    topic_tags=topic_tags,
                )
                keyword_groups.append(keywords)
                bill_queries.append(query)

            # One pytrends call for the whole batch
            gtrends_scores = trends_service.get_trends_scores_batch(keyword_groups)

            for i, bill in enumerate(batch):
                # Per-bill try/except: one bill's API failure must not abort
                # the rest of the batch.
                try:
                    query = bill_queries[i]
                    # NewsAPI + Google News counts (gnews served from 2-hour cache)
                    newsapi_articles = news_service.fetch_newsapi_articles(query, days=30)
                    newsapi_count = len(newsapi_articles)
                    gnews_count = news_service.fetch_gnews_count(query, days=30)
                    gtrends_score = gtrends_scores[i]

                    composite = calculate_composite_score(newsapi_count, gnews_count, gtrends_score)

                    db.add(TrendScore(
                        bill_id=bill.bill_id,
                        score_date=today,
                        newsapi_count=newsapi_count,
                        gnews_count=gnews_count,
                        gtrends_score=gtrends_score,
                        composite_score=composite,
                    ))
                    scored += 1
                except Exception as exc:
                    logger.warning(f"Trend scoring skipped for {bill.bill_id}: {exc}")

        # Single commit for the whole run.
        db.commit()

        logger.info(f"Scored {scored} bills")
        return {"scored": scored}

    except Exception as exc:
        db.rollback()
        logger.error(f"Trend scoring failed: {exc}")
        raise
    finally:
        db.close()
|
||||
271
backend/app/workers/vote_fetcher.py
Normal file
271
backend/app/workers/vote_fetcher.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Vote fetcher — fetches roll-call vote data for bills.
|
||||
|
||||
Roll-call votes are referenced in bill actions as recordedVotes objects.
|
||||
Each recordedVote contains a direct URL to the source XML:
|
||||
- House: https://clerk.house.gov/evs/{year}/roll{NNN}.xml
|
||||
- Senate: https://www.senate.gov/legislative/LIS/roll_call_votes/...
|
||||
|
||||
We fetch and parse that XML directly rather than going through a
|
||||
Congress.gov API endpoint (which doesn't expose vote detail).
|
||||
|
||||
Triggered on-demand from GET /api/bills/{bill_id}/votes when no votes
|
||||
are stored yet.
|
||||
"""
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
from datetime import date, datetime, timezone
|
||||
|
||||
import requests
|
||||
|
||||
from app.database import get_sync_db
|
||||
from app.models.bill import Bill
|
||||
from app.models.member import Member
|
||||
from app.models.vote import BillVote, MemberVotePosition
|
||||
from app.services.congress_api import get_bill_actions as _api_get_bill_actions
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_FETCH_TIMEOUT = 15
|
||||
|
||||
|
||||
def _parse_date(s) -> date | None:
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
return date.fromisoformat(str(s)[:10])
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_xml(url: str) -> ET.Element:
    """Download *url* and return the root element of its parsed XML body.

    Raises requests.HTTPError on a non-2xx response and ET.ParseError on
    malformed XML.
    """
    response = requests.get(url, timeout=_FETCH_TIMEOUT)
    response.raise_for_status()
    return ET.fromstring(response.content)
|
||||
|
||||
|
||||
def _parse_house_xml(root: ET.Element) -> dict:
|
||||
"""Parse House Clerk roll-call XML (clerk.house.gov/evs/...)."""
|
||||
meta = root.find("vote-metadata")
|
||||
question = (meta.findtext("vote-question") or "").strip() if meta is not None else ""
|
||||
result = (meta.findtext("vote-result") or "").strip() if meta is not None else ""
|
||||
|
||||
totals = root.find(".//totals-by-vote")
|
||||
yeas = int((totals.findtext("yea-total") or "0").strip()) if totals is not None else 0
|
||||
nays = int((totals.findtext("nay-total") or "0").strip()) if totals is not None else 0
|
||||
not_voting = int((totals.findtext("not-voting-total") or "0").strip()) if totals is not None else 0
|
||||
|
||||
members = []
|
||||
for rv in root.findall(".//recorded-vote"):
|
||||
leg = rv.find("legislator")
|
||||
if leg is None:
|
||||
continue
|
||||
members.append({
|
||||
"bioguide_id": leg.get("name-id"),
|
||||
"member_name": (leg.text or "").strip(),
|
||||
"party": leg.get("party"),
|
||||
"state": leg.get("state"),
|
||||
"position": (rv.findtext("vote") or "Not Voting").strip(),
|
||||
})
|
||||
|
||||
return {"question": question, "result": result, "yeas": yeas, "nays": nays,
|
||||
"not_voting": not_voting, "members": members}
|
||||
|
||||
|
||||
def _parse_senate_xml(root: ET.Element) -> dict:
|
||||
"""Parse Senate LIS roll-call XML (senate.gov/legislative/LIS/...)."""
|
||||
question = (root.findtext("vote_question_text") or root.findtext("question") or "").strip()
|
||||
result = (root.findtext("vote_result_text") or "").strip()
|
||||
|
||||
counts = root.find("vote_counts")
|
||||
yeas = int((counts.findtext("yeas") or "0").strip()) if counts is not None else 0
|
||||
nays = int((counts.findtext("nays") or "0").strip()) if counts is not None else 0
|
||||
not_voting = int((counts.findtext("absent") or "0").strip()) if counts is not None else 0
|
||||
|
||||
members = []
|
||||
for m in root.findall(".//member"):
|
||||
first = (m.findtext("first_name") or "").strip()
|
||||
last = (m.findtext("last_name") or "").strip()
|
||||
members.append({
|
||||
"bioguide_id": (m.findtext("bioguide_id") or "").strip() or None,
|
||||
"member_name": f"{first} {last}".strip(),
|
||||
"party": m.findtext("party"),
|
||||
"state": m.findtext("state"),
|
||||
"position": (m.findtext("vote_cast") or "Not Voting").strip(),
|
||||
})
|
||||
|
||||
return {"question": question, "result": result, "yeas": yeas, "nays": nays,
|
||||
"not_voting": not_voting, "members": members}
|
||||
|
||||
|
||||
def _parse_vote_xml(url: str, chamber: str) -> dict:
    """Fetch a roll-call XML document and parse it with the chamber-appropriate parser."""
    document = _fetch_xml(url)
    parser = _parse_house_xml if chamber.lower() == "house" else _parse_senate_xml
    return parser(document)
|
||||
|
||||
|
||||
def _collect_recorded_votes(congress: int, bill_type: str, bill_number: int) -> list[dict]:
    """Page through all bill actions and collect unique recordedVotes entries.

    Dedup key is (chamber, session, roll_number) — the same roll call can be
    referenced by multiple actions. Entries without a rollNumber are dropped.

    Returns:
        list of dicts with chamber, session, roll_number, date (the action's
        actionDate string), and url (source XML link, may be None).
    """
    seen: set[tuple] = set()
    recorded: list[dict] = []
    offset = 0

    while True:
        data = _api_get_bill_actions(congress, bill_type, bill_number, offset=offset)
        actions = data.get("actions", [])
        pagination = data.get("pagination", {})

        for action in actions:
            for rv in action.get("recordedVotes", []):
                chamber = rv.get("chamber", "")
                # API payloads vary: sessionNumber vs session; default to 1.
                session = int(rv.get("sessionNumber") or rv.get("session") or 1)
                roll_number = rv.get("rollNumber")
                if not roll_number:
                    continue
                roll_number = int(roll_number)
                key = (chamber, session, roll_number)
                if key not in seen:
                    seen.add(key)
                    recorded.append({
                        "chamber": chamber,
                        "session": session,
                        "roll_number": roll_number,
                        "date": action.get("actionDate"),
                        "url": rv.get("url"),
                    })

        # Advance by the page we just consumed; stop at the reported total or
        # on an empty page (guards against a bad/missing pagination count).
        total = pagination.get("count", 0)
        offset += len(actions)
        if offset >= total or not actions:
            break

    return recorded
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.vote_fetcher.fetch_bill_votes")
def fetch_bill_votes(self, bill_id: str) -> dict:
    """Fetch and store roll-call votes for a single bill.

    Collects recordedVotes references from the bill's actions, downloads and
    parses each vote's source XML, and stores a BillVote plus per-member
    MemberVotePosition rows. Votes already stored are skipped, so re-runs
    are idempotent.

    Returns:
        dict with bill_id, stored, and skipped counts — or {"error": ...}
        when the bill isn't in the database.
    """
    db = get_sync_db()
    try:
        bill = db.get(Bill, bill_id)
        if not bill:
            return {"error": f"Bill {bill_id} not found"}

        recorded = _collect_recorded_votes(bill.congress_number, bill.bill_type, bill.bill_number)

        if not recorded:
            logger.info(f"fetch_bill_votes({bill_id}): no recorded votes in actions")
            return {"bill_id": bill_id, "stored": 0, "skipped": 0}

        now = datetime.now(timezone.utc)
        stored = 0
        skipped = 0

        # Cache known bioguide IDs to avoid N+1 member lookups
        known_bioguides: set[str] = {
            row[0] for row in db.query(Member.bioguide_id).all()
        }

        for rv in recorded:
            chamber = rv["chamber"]
            session = rv["session"]
            roll_number = rv["roll_number"]
            source_url = rv.get("url")

            # A roll call is unique per (congress, chamber, session, roll_number);
            # skip any we already stored on a previous run.
            existing = (
                db.query(BillVote)
                .filter_by(
                    congress=bill.congress_number,
                    chamber=chamber,
                    session=session,
                    roll_number=roll_number,
                )
                .first()
            )
            if existing:
                skipped += 1
                continue

            if not source_url:
                logger.warning(f"No URL for {chamber} roll {roll_number} — skipping")
                continue

            # Best-effort per vote: a fetch/parse failure drops this roll call
            # but doesn't abort the rest.
            try:
                parsed = _parse_vote_xml(source_url, chamber)
            except Exception as exc:
                logger.warning(f"Could not parse vote XML {source_url}: {exc}")
                continue

            bill_vote = BillVote(
                bill_id=bill_id,
                congress=bill.congress_number,
                chamber=chamber,
                session=session,
                roll_number=roll_number,
                question=parsed["question"],
                description=None,
                vote_date=_parse_date(rv.get("date")),
                yeas=parsed["yeas"],
                nays=parsed["nays"],
                not_voting=parsed["not_voting"],
                result=parsed["result"],
                source_url=source_url,
                fetched_at=now,
            )
            db.add(bill_vote)
            # Flush so bill_vote.id is assigned before the position rows reference it.
            db.flush()

            for pos in parsed["members"]:
                bioguide_id = pos.get("bioguide_id")
                # Null out IDs we don't have a Member row for, keeping the
                # position record while avoiding a dangling reference.
                if bioguide_id and bioguide_id not in known_bioguides:
                    bioguide_id = None
                db.add(MemberVotePosition(
                    vote_id=bill_vote.id,
                    bioguide_id=bioguide_id,
                    member_name=pos.get("member_name"),
                    party=pos.get("party"),
                    state=pos.get("state"),
                    position=pos.get("position") or "Not Voting",
                ))

            # Commit per vote so earlier votes survive a later failure.
            db.commit()
            stored += 1

        logger.info(f"fetch_bill_votes({bill_id}): {stored} stored, {skipped} skipped")
        return {"bill_id": bill_id, "stored": stored, "skipped": skipped}
    finally:
        db.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.workers.vote_fetcher.fetch_votes_for_stanced_bills")
def fetch_votes_for_stanced_bills(self) -> dict:
    """
    Nightly task: queue vote fetches for every bill any user has a stance on
    (pocket_veto or pocket_boost). Only queues bills that don't already have
    a vote stored, so re-runs are cheap after the first pass.

    Returns:
        dict with the number of fetch_bill_votes tasks queued.
    """
    from sqlalchemy import text as sa_text

    db = get_sync_db()
    try:
        # Anti-join: stanced bill follows that have no row yet in bill_votes.
        rows = db.execute(sa_text("""
            SELECT DISTINCT f.follow_value AS bill_id
            FROM follows f
            LEFT JOIN bill_votes bv ON bv.bill_id = f.follow_value
            WHERE f.follow_type = 'bill'
              AND f.follow_mode IN ('pocket_veto', 'pocket_boost')
              AND bv.id IS NULL
        """)).fetchall()

        queued = 0
        for row in rows:
            # Each bill runs in its own worker task so one failure doesn't
            # abort the whole batch.
            fetch_bill_votes.delay(row.bill_id)
            queued += 1

        logger.info(f"fetch_votes_for_stanced_bills: queued {queued} bills")
        return {"queued": queued}
    finally:
        db.close()
|
||||
Reference in New Issue
Block a user